Coverage for src/flag_gems/ops/topk.py: 24%

327 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7import triton.language.core as core 

8 

9try: 

10 # TODO: Triton 2.1 does not implement _log2. 

11 # Remove the try-catch block once all vendors upgrade to a newer version of Triton. 

12 from triton.language.standard import _log2, zeros_like 

13except ImportError: 

14 pass 

15 

16from flag_gems.runtime import torch_device_fn 

17from flag_gems.utils import libentry 

18from flag_gems.utils import triton_lang_extension as tle 

19from flag_gems.utils.limits import get_dtype_max, get_dtype_min 

20from flag_gems.utils.triton_version_utils import HAS_TLE 

21 

22if HAS_TLE: 

23 import triton.experimental.tle.language as tle_gpu 

24else: 

25 tle_gpu = None 

26 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Extreme representable values (torch.finfo / torch.iinfo ``min`` and ``max``)
# for several dtypes, wrapped in ``tl.constexpr`` so that @triton.jit code in
# this module can reference them as compile-time constants.
# NOTE(review): in the part of the file reviewed here only the floating-point
# constants and the INT32 pair are referenced; the INT8/INT16/INT64 pairs may
# be unused - confirm before removing.
_MIN_FLOAT32_VAL = tl.constexpr(torch.finfo(torch.float32).min)
_MAX_FLOAT32_VAL = tl.constexpr(torch.finfo(torch.float32).max)
_MIN_FLOAT16_VAL = tl.constexpr(torch.finfo(torch.float16).min)
_MAX_FLOAT16_VAL = tl.constexpr(torch.finfo(torch.float16).max)
_MIN_BFLOAT16_VAL = tl.constexpr(torch.finfo(torch.bfloat16).min)
_MAX_BFLOAT16_VAL = tl.constexpr(torch.finfo(torch.bfloat16).max)
_MIN_INT8_VAL = tl.constexpr(torch.iinfo(torch.int8).min)
_MAX_INT8_VAL = tl.constexpr(torch.iinfo(torch.int8).max)
_MIN_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).min)
_MAX_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).max)
_MIN_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).min)
_MAX_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).max)
_MIN_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).min)
_MAX_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).max)

42 

43 

@triton.jit
def _get_finfo_val(
    dtype,
    return_max,
):
    """Pick the ``finfo`` maximum or minimum constant for a float dtype.

    Handles ``tl.float32``, ``tl.float16`` and ``tl.bfloat16``. For any other
    dtype no branch matches and the function falls off the end without
    returning a value.

    Args:
        dtype: Triton element type to look up.
        return_max: When true the ``_MAX_*`` constant is returned, otherwise
            the ``_MIN_*`` constant.
    """
    if dtype is tl.float32:
        return _MAX_FLOAT32_VAL if return_max else _MIN_FLOAT32_VAL
    elif dtype is tl.float16:
        return _MAX_FLOAT16_VAL if return_max else _MIN_FLOAT16_VAL
    elif dtype is tl.bfloat16:
        return _MAX_BFLOAT16_VAL if return_max else _MIN_BFLOAT16_VAL

64 

65 

@triton.jit
def _get_iinfo_val(
    dtype,
    return_max,
):
    """Return ``get_dtype_max(dtype)`` or ``get_dtype_min(dtype)``.

    Args:
        dtype: Triton element type forwarded to the limit helper.
        return_max: Selects the maximum when true, the minimum otherwise.
    """
    return get_dtype_max(dtype) if return_max else get_dtype_min(dtype)

75 

76 

@libentry()
@triton.jit
def topk_stage1_kernel(
    y_ptr,
    index_ptr,
    x_ptr,
    k,
    N: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    """Stage 1 of the two-stage top-k: pick k candidates from one chunk.

    Program axis 0 selects the row and axis 1 the chunk of ``CHUNK_SIZE``
    columns inside that row. The program runs ``k`` rounds; each round takes
    the max (``DESCENDING``) or min of the chunk, writes the value and its
    row-relative column index, and then overwrites the chosen lane with the
    opposite float32 extreme.

    Args:
        y_ptr: Output values; this program writes ``k`` consecutive slots
            starting at ``(row * num_chunks + chunk) * k``.
        index_ptr: Output indices, same layout as ``y_ptr``.
        x_ptr: Input, addressed as ``row * N + column``.
        k: Number of candidates to emit per chunk.
        N: Row length.
        CHUNK_SIZE: Number of columns handled by one program.
        DESCENDING: True to select the largest values, False the smallest.
    """
    row = tle.program_id(0)
    chunk = tle.program_id(1)
    n_chunks = tle.num_programs(1)

    # Both outputs share the same per-(row, chunk) base offset.
    out_base = row * n_chunks * k + chunk * k
    y_ptr += out_base
    index_ptr += out_base

    first_col = chunk * CHUNK_SIZE
    x_ptr += row * N + first_col

    lane = tl.arange(0, CHUNK_SIZE)
    in_bounds = (first_col + lane) < N

    # Lanes past the end of the row are filled with the extreme of the input
    # dtype that loses against real data for the requested direction.
    pad_val = _get_finfo_val(x_ptr.dtype.element_ty, return_max=not DESCENDING)
    vals = tl.load(x_ptr + lane, mask=in_bounds, other=pad_val).to(tl.float32)

    # Value written over a lane once it has been emitted (float32 extreme).
    taken_val = _get_finfo_val(tl.float32, return_max=not DESCENDING)

    for slot in range(k):
        if DESCENDING:
            best_val = tl.max(vals)
            best_lane = tl.argmax(vals, axis=0)
        else:
            best_val = tl.min(vals)
            best_lane = tl.argmin(vals, axis=0)

        tl.store(y_ptr + slot, best_val)
        tl.store(index_ptr + slot, best_lane + first_col)

        vals = tl.where(lane == best_lane, taken_val, vals)

126 

127 

128""" 

Note(Zhengzekang):
The sorting helpers below are adapted from the official Triton 2.2 `sort` implementation:
https://github.com/triton-lang/triton/blob/release/2.2.x/python/triton/language/standard.py#L392-L404
The only change is that a tensor of indices is carried along and permuted together with the values.

133""" 

134 

135 

@triton.jit
def _compare_and_swap(x, ids, flip, i: core.constexpr, n_dims: core.constexpr):
    """Run one compare-and-swap round on values ``x`` and companion ``ids``.

    ``x`` and ``ids`` are viewed as ``[outer, 2, inner]`` with
    ``inner = 2 ** (n_dims - i - 1)``; the two entries along the middle axis
    form a pair. Wherever ``(left > right) ^ flip`` holds, each element is
    replaced by its partner, and the same exchange is applied to ``ids``.

    Args:
        x: Values being sorted.
        ids: Tensor permuted in lock-step with ``x`` (same reshape is applied).
        flip: Scalar or tensor XOR-ed into the comparison to reverse the
            direction of the exchange.
        i: Compile-time index of the round; determines the pair stride.
        n_dims: Compile-time log2 of the length being sorted.

    Returns:
        ``(x, ids)`` after the exchange, in their original dtypes.
    """
    n_outer: core.constexpr = x.numel >> n_dims
    shape: core.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)]

    # tl.device_print("shape is: ", shape)
    y = core.reshape(x, shape)
    y_idx = core.reshape(ids, shape)

    # slice left/right with 'stride' 2**(n_dims - i - 1)
    # `mask` is 0 for the first element of each pair and 1 for the second, so
    # summing y * (1 - mask) / y * mask over the pair axis extracts one of
    # them; the result is broadcast back so both pair slots see both values.
    mask = core.arange(0, 2)[None, :, None]
    left = core.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape).to(x.dtype)
    right = core.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape).to(x.dtype)
    left = core.reshape(left, x.shape)
    right = core.reshape(right, x.shape)

    # Same extraction for the companion index tensor.
    left_idx = core.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape).to(
        ids.dtype
    )
    right_idx = core.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape).to(
        ids.dtype
    )
    left_idx = core.reshape(left_idx, ids.shape)
    right_idx = core.reshape(right_idx, ids.shape)

    # actual compare-and-swap
    # Choose the signed integer type with the same bit width as x so the
    # values can be bit-cast and exchanged with XOR.
    if core.constexpr(x.dtype.primitive_bitwidth) == 8:
        idtype = core.int8
    elif core.constexpr(x.dtype.primitive_bitwidth) == 16:
        idtype = core.int16
    elif core.constexpr(x.dtype.primitive_bitwidth) == 32:
        idtype = core.int32
    elif core.constexpr(x.dtype.primitive_bitwidth) == 64:
        idtype = core.int64
    else:
        raise ValueError("Unsupported dtype")

    ileft = left.to(idtype, bitcast=True)
    iright = right.to(idtype, bitcast=True)
    ix = x.to(idtype, bitcast=True)

    # ix is either ileft or iright, so ix ^ (ileft ^ iright) is the partner;
    # XOR with zero leaves the element unchanged where no swap is needed.
    cond = (left > right) ^ flip
    ret = ix ^ core.where(cond, ileft ^ iright, zeros_like(ix))

    # Repeat the integer-type selection for the index tensor.
    if core.constexpr(ids.dtype.primitive_bitwidth) == 8:
        idx_dtype = core.int8
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 16:
        idx_dtype = core.int16
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 32:
        idx_dtype = core.int32
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 64:
        idx_dtype = core.int64
    else:
        raise ValueError("Unsupported dtype")

    ileft_idx = left_idx.to(idx_dtype, bitcast=True)
    iright_idx = right_idx.to(idx_dtype, bitcast=True)
    ix_idx = ids.to(idx_dtype, bitcast=True)
    # The indices follow the decision made on the values (same `cond`).
    ret_idx = ix_idx ^ core.where(cond, ileft_idx ^ iright_idx, zeros_like(ix_idx))

    return ret.to(x.dtype, bitcast=True), ret_idx.to(ids.dtype, bitcast=True)

197 

198 

@triton.jit
def _bitonic_merge(
    x, ids, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr
):
    """
    Apply `stage` compare-and-swap rounds to `x`, permuting `ids` alongside.

    `order` selects the direction of the rounds:
    order 0 == ascending
    order 1 == descending
    order 2 == alternating

    `n_dims` is the log2 of the length being sorted; `stage` must not exceed
    it (checked with a static assert). Returns the updated `(x, ids)` pair.
    """
    n_outer: core.constexpr = x.numel >> n_dims
    core.static_assert(stage <= n_dims)
    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
    # descending order.
    # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
    # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
    # a stride of 2) at this stage
    if order == 2:
        # Blocks of 2**stage elements alternate between flip = 0 and flip = 1.
        shape: core.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage]
        flip = core.reshape(
            core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape
        )
    else:
        # Uniform direction: the order value itself is used as the flip flag.
        flip = order
    # perform `stage` rounds of `compare-and-swap`
    for i in core.static_range(stage):
        x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims)
    return x, ids

226 

227 

@triton.jit
def argsort(x, ids, dim: tl.constexpr, descending: core.constexpr):
    """Sort ``x`` with a bitonic network and permute ``ids`` the same way.

    The number of merge stages is ``_log2(x.shape[dim])``. Every stage except
    the last uses the alternating order (2); the last stage uses
    ``descending`` as its order.

    Args:
        x: Values to sort.
        ids: Companion tensor reordered together with ``x``.
        dim: Compile-time dimension whose extent sets the stage count.
        descending: Compile-time order flag for the final merge stage.

    Returns:
        The sorted values and the correspondingly permuted ``ids``.
    """
    log_n: core.constexpr = _log2(x.shape[dim])
    for stage in core.static_range(1, log_n + 1):
        x, ids = _bitonic_merge(
            x, ids, stage, 2 if stage < log_n else descending, log_n
        )
    return x, ids

236 

237 

@libentry()
@triton.jit
def topk_stage2_kernel(
    y_ptr,
    index_ptr,
    chunk_x,
    chunk_index,
    sort_dim: tl.constexpr,
    k: tl.constexpr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    """Stage 2 of the two-stage top-k: sort a row's candidates, keep the first k.

    One program per row (program axis 0). It loads the row's ``N`` candidate
    values and indices, sorts them with ``argsort`` and writes the leading
    ``k`` entries.

    Args:
        y_ptr: Output values, ``k`` per row.
        index_ptr: Output indices, ``k`` per row.
        chunk_x: Candidate values, ``N`` per row.
        chunk_index: Candidate indices, ``N`` per row.
        sort_dim: Not read by this kernel.
        k: Number of results to keep per row.
        N: Number of candidates per row.
        BLOCK_SIZE: Width of the sorted block; lanes at or beyond ``N`` are
            filled with padding.
        DESCENDING: True to sort from largest to smallest.
    """
    row = tle.program_id(0)

    # Move every pointer to the start of this row.
    in_base = row * N
    out_base = row * k
    chunk_x += in_base
    chunk_index += in_base
    y_ptr += out_base
    index_ptr += out_base

    lane = tl.arange(0, BLOCK_SIZE)
    valid = lane < N
    keep = lane < k

    # Padding that loses against real candidates for the requested direction.
    pad_val = _get_finfo_val(chunk_x.dtype.element_ty, return_max=not DESCENDING)
    pad_idx = _MIN_INT32_VAL if DESCENDING else _MAX_INT32_VAL

    cand_val = tl.load(chunk_x + lane, mask=valid, other=pad_val).to(tl.float32)
    cand_idx = tl.load(chunk_index + lane, mask=valid, other=pad_idx).to(tl.int32)

    sorted_val, sorted_idx = argsort(cand_val, cand_idx, 0, descending=DESCENDING)

    tl.store(y_ptr + lane, sorted_val, mask=keep)
    tl.store(index_ptr + lane, sorted_idx, mask=keep)

273 

274 

# The radix-select kernel and its helpers use `tle_gpu`, which is only imported
# when HAS_TLE is true, so they are defined conditionally.
if HAS_TLE:

    @triton.jit
    def _get_topmask_and_fullmask(x):
        """Return (top-bit mask, all-ones mask) tensors shaped and typed like ``x``.

        ``x`` must have an unsigned integer dtype (statically asserted).
        """
        tl.static_assert(
            x.dtype.is_int_unsigned(),
            "floating-point value must be passed as bits",
        )
        # tm has only the most significant bit set; fm has every bit set.
        tm: tl.constexpr = 1 << (-1 + x.dtype.primitive_bitwidth)
        fm: tl.constexpr = (1 << x.dtype.primitive_bitwidth) - 1
        tm_arr = tl.full(x.shape, tm, dtype=x.dtype)
        fm_arr = tl.full(x.shape, fm, dtype=x.dtype)
        return tm_arr, fm_arr

    @triton.jit
    def _fpval_to_key_with_nan(x, x_bits):
        """Map float bit patterns to unsigned keys; NaN maps to the all-ones key.

        ``x`` is the floating-point value and ``x_bits`` the same data bit-cast
        to an unsigned integer. When the top (sign) bit is set every bit is
        inverted, otherwise only the top bit is flipped, so that comparing
        keys as unsigned integers follows the numeric order of the floats.
        """
        tm, fm = _get_topmask_and_fullmask(x_bits)
        mask = tl.where((x_bits & tm) != 0, fm, tm)
        key = x_bits ^ mask
        # x == x is false only for NaN, which therefore gets the largest key.
        return tl.where(x == x, key, fm)

    @triton.jit
    def _key_to_fpval(x):
        """Undo the key mapping of ``_fpval_to_key_with_nan``, returning raw bits.

        Keys with the top bit set get only that bit flipped back; other keys
        have all bits inverted. The result is still an unsigned integer and is
        bit-cast to the float dtype by the caller.
        """
        tm, fm = _get_topmask_and_fullmask(x)
        mask = tl.where((x & tm) != 0, tm, fm)
        return x ^ mask

    @libentry()
    @triton.jit
    def topk_kernel_radix_tle(
        X,
        Yv,
        Yi,
        stride_xm,
        stride_ym,
        n_cols,
        K: tl.constexpr,
        K_PAD: tl.constexpr,
        BLOCK_N: tl.constexpr,
        RADIX_BITS: tl.constexpr,
    ):
        """Per-row top-K (largest first) using a radix select on float keys.

        One program handles one row (``tl.program_id(0)``). It proceeds in
        three phases:

        1. Radix select: digit by digit, from the most significant
           ``RADIX_BITS`` down, histogram the keys that match the prefix found
           so far and choose the digit containing the K-th largest key. After
           the last digit ``desired`` is the key of the K-th largest element.
        2. Gather: write every element whose key is greater than that
           threshold, then elements equal to it, into a ``K_PAD``-slot buffer,
           each packed as ``(key << 16) | (n_cols - column)``.
        3. Sort the packed buffer in descending order, unpack, and store the
           first ``K`` values and indices.

        Args:
            X: Input; element ``(row, col)`` is read at ``X + row * stride_xm + col``.
            Yv: Output values, written at ``Yv + row * stride_ym + j``.
            Yi: Output indices, same addressing as ``Yv``.
            stride_xm: Row stride of ``X`` in elements.
            stride_ym: Row stride used for both outputs.
            n_cols: Number of columns per row.
            K: Number of results per row.
            K_PAD: Size of the selection buffer and of the sorted block.
            BLOCK_N: Number of columns processed per tile.
            RADIX_BITS: Bits per radix digit.

        NOTE(review): the packing reserves the low 16 bits for
        ``n_cols - column``, so ``n_cols`` is assumed to fit in 16 bits -
        confirm that every caller enforces this.
        """
        pid = tl.program_id(0)
        x_dtype = X.dtype.element_ty
        x_nbits: tl.constexpr = x_dtype.primitive_bitwidth
        # Width of the packed (key, index) word: twice the key width, with a
        # floor of 32 bits for element types narrower than 16 bits.
        if x_nbits < 16:
            y_nbits: tl.constexpr = 32
        else:
            y_nbits: tl.constexpr = x_nbits * 2
        x_utype = tl.dtype(f"uint{x_nbits}")
        x_ultype = tl.dtype(f"uint{y_nbits}")

        RADIX_SIZE: tl.constexpr = 1 << RADIX_BITS
        RADIX_MASK: tl.constexpr = RADIX_SIZE - 1
        bins = tl.arange(0, RADIX_SIZE)
        one = tl.full([BLOCK_N], 1, tl.int32)

        # `desired` accumulates the digits chosen so far; `desired_mask` marks
        # which bit positions of `desired` are already fixed.
        desired = tl.full((), 0, dtype=x_utype)
        desired_mask = tl.full((), 0, dtype=x_utype)
        # Rank still to be located among the keys matching the current prefix.
        k_to_find = tl.full((), K, dtype=tl.int32)
        n_tiles = tl.cdiv(n_cols, BLOCK_N)

        # Per-digit histogram, allocated with scope `smem`.
        # NOTE(review): `local_ptr` presumably yields pointers to the given
        # indices of the allocation - confirm against the TLE documentation.
        smem_counts = tle_gpu.gpu.alloc(
            [RADIX_SIZE],
            dtype=tl.int32,
            layout=None,
            scope=tle_gpu.gpu.smem,
            nv_mma_shared_layout=False,
        )
        smem_count_ptrs = tle_gpu.gpu.local_ptr(smem_counts, (bins,))

        # Phase 1: walk the digits from most to least significant.
        for digit_pos in tl.static_range(x_nbits - RADIX_BITS, -1, -RADIX_BITS):
            tl.store(smem_count_ptrs, tl.zeros([RADIX_SIZE], dtype=tl.int32))
            for t in tl.range(0, n_tiles):
                offs_n = t * BLOCK_N + tl.arange(0, BLOCK_N)
                mask_n = offs_n < n_cols
                x_ptrs = X + pid * stride_xm + offs_n
                x = tl.load(x_ptrs, mask=mask_n, other=float("-inf"))
                x_bits = x.to(x_utype, bitcast=True)
                x_key = _fpval_to_key_with_nan(x, x_bits)
                # Count only keys that agree with every digit fixed so far.
                matches = (x_key & desired_mask) == desired
                digit = ((x_key >> digit_pos) & RADIX_MASK).to(tl.int32)
                valid = mask_n & matches
                count_addrs = tle_gpu.gpu.local_ptr(smem_counts, (digit,))
                tl.atomic_add(count_addrs, one, mask=valid, sem="relaxed", scope="cta")

            counts = tl.load(smem_count_ptrs)

            # After the reverse cumulative sum, entry d holds the number of
            # matching keys whose current digit is >= d.
            cumsum_desc = tl.cumsum(counts, axis=0, reverse=True)
            tl.store(smem_count_ptrs, cumsum_desc)

            selected_scalar = 0
            counts_gt_scalar = 0
            found = 0
            # Scan digits from high to low for the one where the running count
            # first reaches k_to_find.
            for rev in tl.static_range(RADIX_SIZE):
                d = RADIX_SIZE - 1 - rev
                cum_d = tl.load(tle_gpu.gpu.local_ptr(smem_counts, (d,)))
                if d + 1 < RADIX_SIZE:
                    cum_next = tl.load(tle_gpu.gpu.local_ptr(smem_counts, (d + 1,)))
                else:
                    cum_next = 0
                take = (found == 0) & (cum_d >= k_to_find) & (cum_next < k_to_find)
                selected_scalar = tl.where(take, d, selected_scalar)
                counts_gt_scalar = tl.where(take, cum_next, counts_gt_scalar)
                found = tl.where(take, 1, found)

            # Fix the chosen digit in the prefix, and discount the keys with a
            # strictly larger digit from the remaining rank.
            selected_u = selected_scalar.to(x_utype)
            desired = desired | (selected_u << digit_pos)
            desired_mask = desired_mask | (
                tl.full((), RADIX_MASK, dtype=x_utype) << digit_pos
            )
            k_to_find = k_to_find - counts_gt_scalar

        # Key of the K-th largest element.
        thr_key = desired

        # Phase 2 setup: every slot of the selection buffer starts as the
        # packed key of -inf with a zero index field.
        min_val = tl.full((), float("-inf"), tl.float32).to(x_dtype)
        min_bits = min_val.to(x_utype, bitcast=True)
        min_key = _fpval_to_key_with_nan(min_val, min_bits)
        min_packed = min_key.to(x_ultype) << 16
        offs_k = tl.arange(0, K_PAD)

        smem_selected = tle_gpu.gpu.alloc(
            [K_PAD],
            dtype=x_ultype,
            layout=None,
            scope=tle_gpu.gpu.smem,
            nv_mma_shared_layout=False,
        )
        smem_selected_ptrs = tle_gpu.gpu.local_ptr(smem_selected, (offs_k,))
        tl.store(smem_selected_ptrs, tl.full([K_PAD], min_packed, dtype=x_ultype))

        # Single counter handing out destination slots; every lane's pointer
        # refers to element 0 of the allocation.
        smem_write_count = tle_gpu.gpu.alloc(
            [1],
            dtype=tl.int32,
            layout=None,
            scope=tle_gpu.gpu.smem,
            nv_mma_shared_layout=False,
        )
        tl.store(tle_gpu.gpu.local_ptr(smem_write_count, (0,)), 0)
        write_count_ptrs = tle_gpu.gpu.local_ptr(
            smem_write_count, (tl.zeros([BLOCK_N], dtype=tl.int32),)
        )

        # Phase 2a: collect every element strictly above the threshold.
        for t in tl.range(0, n_tiles):
            offs_n = t * BLOCK_N + tl.arange(0, BLOCK_N)
            mask_n = offs_n < n_cols
            x_ptrs = X + pid * stride_xm + offs_n
            x = tl.load(x_ptrs, mask=mask_n, other=float("-inf"))
            x_bits = x.to(x_utype, bitcast=True)
            x_key = _fpval_to_key_with_nan(x, x_bits)
            # Index stored as n_cols - column, so that for equal keys the
            # smaller column yields the larger packed word.
            idx_key = (n_cols - offs_n).to(x_ultype)
            packed = (x_key.to(x_ultype) << 16) | idx_key
            take_gt = mask_n & (x_key > thr_key)
            pos = tl.atomic_add(
                write_count_ptrs, one, mask=take_gt, sem="relaxed", scope="cta"
            )
            write_mask = take_gt & (pos < K_PAD)
            dst_ptrs = tle_gpu.gpu.local_ptr(smem_selected, (pos.to(tl.int32),))
            tl.store(dst_ptrs, packed, mask=write_mask)

        # Phase 2b: add elements equal to the threshold; writes whose slot
        # number reaches K_PAD are masked off.
        for t in tl.range(0, n_tiles):
            offs_n = t * BLOCK_N + tl.arange(0, BLOCK_N)
            mask_n = offs_n < n_cols
            x_ptrs = X + pid * stride_xm + offs_n
            x = tl.load(x_ptrs, mask=mask_n, other=float("-inf"))
            x_bits = x.to(x_utype, bitcast=True)
            x_key = _fpval_to_key_with_nan(x, x_bits)
            idx_key = (n_cols - offs_n).to(x_ultype)
            packed = (x_key.to(x_ultype) << 16) | idx_key
            take_eq = mask_n & (x_key == thr_key)
            pos = tl.atomic_add(
                write_count_ptrs, one, mask=take_eq, sem="relaxed", scope="cta"
            )
            write_mask = take_eq & (pos < K_PAD)
            dst_ptrs = tle_gpu.gpu.local_ptr(smem_selected, (pos.to(tl.int32),))
            tl.store(dst_ptrs, packed, mask=write_mask)

        selected_packed = tl.load(smem_selected_ptrs)

        # Phase 3: sort the packed words, then split them back into the
        # column index (low 16 bits) and the key (remaining high bits).
        topk = tl.sort(selected_packed, dim=0, descending=True)
        idx_mask = tl.full(topk.shape, (1 << 16) - 1, dtype=topk.dtype)
        idx_raw = (topk & idx_mask).to(tl.uint32)
        y_indices = (n_cols - idx_raw.to(tl.int32)).to(tl.int32)
        y_values_raw = (topk >> 16).to(x_utype)
        y_values = _key_to_fpval(y_values_raw).to(x_dtype, bitcast=True)

        # Only the first K of the K_PAD sorted entries are written out.
        mask_k = offs_k < K
        yv_ptrs = Yv + pid * stride_ym + offs_k
        yi_ptrs = Yi + pid * stride_ym + offs_k
        tl.store(yv_ptrs, y_values, mask=mask_k)
        tl.store(yi_ptrs, y_indices, mask=mask_k)

465 

466 

def topk(x, k, dim=-1, largest=True, sorted=True):
    """Return the ``k`` largest or smallest entries along the last dimension.

    Two implementations are available. A single-kernel radix-select path is
    used when TLE is present and the request is a sorted, largest-first top-k
    on a CUDA float16/bfloat16/float32 tensor with ``k >= 8``, at most 65535
    columns and ``next_power_of_2(k) <= 1024``. Every other request goes
    through the two-stage path: stage 1 picks ``k`` candidates per chunk of
    the row and stage 2 sorts the candidates and keeps the first ``k``.

    Args:
        x: Input tensor. It is made contiguous before launching, because both
            kernel paths address each row as densely packed memory.
        k: Number of entries to return; must satisfy ``0 <= k <= x.shape[dim]``.
        dim: Dimension to reduce; only the last dimension is supported.
        largest: Select the largest entries when true, the smallest otherwise.
        sorted: Only consulted when choosing the radix path; the two-stage
            path returns sorted results regardless.

    Returns:
        A ``(values, indices)`` tuple whose last dimension has size ``k``;
        ``values`` has the dtype of ``x`` and ``indices`` is int64.

    Raises:
        AssertionError: If ``dim`` is not the last dimension.
        RuntimeError: If ``k`` is negative or larger than ``x.shape[dim]``.
    """
    logger.debug("GEMS TOPK")
    # Normalize a negative dim to its non-negative equivalent.
    if dim < 0:
        dim = dim + x.ndim

    assert dim == x.ndim - 1, "Currently only support topk in last dimension"
    # assert sorted, "Currently only support sorted == True"

    topk_elem_cnt = x.shape[dim]

    # Without this check, k > topk_elem_cnt would return padding values with
    # out-of-range indices, and an empty last dimension with k > 0 would
    # divide by zero below.
    if k < 0 or k > topk_elem_cnt:
        raise RuntimeError("selected index k out of range")

    # Early return for k=0 to avoid Triton kernel compilation error.
    # Triton's tl.arange(0, BLOCK_SIZE) requires BLOCK_SIZE > 0.
    # When k=0, stage2_elem_cnt becomes 0, leading to BLOCK_SIZE=0.
    if k == 0:
        out_shape = list(x.shape[:-1]) + [0]
        return (
            torch.empty(out_shape, device=x.device, dtype=x.dtype),
            torch.empty(out_shape, device=x.device, dtype=torch.int64),
        )

    # The kernels compute addresses as row * row_length + column (unit stride
    # in the last dimension), so a strided view must be materialized first.
    # This is a no-op for tensors that are already contiguous.
    x = x.contiguous()

    descending = bool(largest)
    batch_size = math.prod(x.shape) // topk_elem_cnt
    k_pad = triton.next_power_of_2(k)

    if (
        HAS_TLE
        and sorted
        and descending
        and x.is_cuda
        and x.dtype in (torch.float16, torch.float32, torch.bfloat16)
        and k >= 8
        and topk_elem_cnt <= 65535
        and k_pad <= 1024
    ):
        out_shape = x.shape[:-1] + (k,)
        y_vals = torch.empty(out_shape, device=x.device, dtype=x.dtype)
        # The kernel writes int32 indices; they are widened on return.
        y_idx = torch.empty(out_shape, device=x.device, dtype=torch.int32)
        block_n_radix = max(k_pad, min(512, triton.next_power_of_2(topk_elem_cnt)))
        block_n_radix = min(block_n_radix, 1024)

        x_2d = x.reshape(batch_size, topk_elem_cnt)
        y_vals_2d = y_vals.reshape(batch_size, k)
        y_idx_2d = y_idx.reshape(batch_size, k)
        with torch_device_fn.device(x.device):
            topk_kernel_radix_tle[(batch_size,)](
                x_2d,
                y_vals_2d,
                y_idx_2d,
                x_2d.stride(0),
                y_vals_2d.stride(0),
                topk_elem_cnt,
                K=k,
                K_PAD=k_pad,
                BLOCK_N=block_n_radix,
                RADIX_BITS=4,
                num_warps=4,
                num_stages=1,
            )
        return (y_vals, y_idx.to(torch.int64))

    # Note(Zhengzekang): Maybe we should add a heuristic search in selecting a proper chunk size.
    if topk_elem_cnt < 1024:
        chunk_size = 256
    else:
        chunk_size = 1024

    # Note(Zhengzekang): We should promise chunk_size is larger than k.
    if chunk_size < k:
        chunk_size = k_pad

    chunk_num = triton.cdiv(topk_elem_cnt, chunk_size)

    # Stage 1 emits k candidates for every (row, chunk) pair.
    stage1_out = torch.empty(batch_size * chunk_num * k, device=x.device, dtype=x.dtype)
    stage1_out_idx = torch.empty(
        batch_size * chunk_num * k, device=x.device, dtype=torch.int64
    )

    out_shape = x.shape[:-1] + (k,)
    stage2_out = torch.empty(out_shape, device=x.device, dtype=x.dtype)
    stage2_out_idx = torch.empty(out_shape, device=x.device, dtype=torch.int64)

    with torch_device_fn.device(x.device):
        topk_stage1_kernel[
            batch_size,
            chunk_num,
        ](
            stage1_out,  # pointer to the output
            stage1_out_idx,  # pointer to the output
            x,  # pointer to the input
            k,
            topk_elem_cnt,
            chunk_size,
            descending,
        )

    # Stage 2 sorts all candidates of a row inside one power-of-two block.
    stage2_elem_cnt = chunk_num * k
    BLOCK_SIZE = triton.next_power_of_2(stage2_elem_cnt)

    with torch_device_fn.device(x.device):
        topk_stage2_kernel[batch_size,](
            stage2_out,
            stage2_out_idx,
            stage1_out,
            stage1_out_idx,
            dim,
            k,
            stage2_elem_cnt,
            BLOCK_SIZE,
            descending,
        )

    return (stage2_out, stage2_out_idx)