Coverage for src/flag_gems/runtime/backend/_cambricon/ops/topk.py: 0%

319 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.ops.topk import topk_stage1_kernel, topk_stage2_kernel 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry, libtuner 

11from flag_gems.utils.triton_version_utils import HAS_TLE 

12 

13if HAS_TLE: 

14 import triton.experimental.tle.language as tle_gpu 

15else: 

16 tle_gpu = None 

17 

18from ..utils import TOTAL_CORE_NUM 

19 

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


def _constexpr_bounds(info):
    """Return ``(min, max)`` of a ``torch.finfo``/``torch.iinfo`` as ``tl.constexpr``."""
    return tl.constexpr(info.min), tl.constexpr(info.max)


# Representable extremes of every dtype the lookup helpers below know about.
# They are wrapped in ``tl.constexpr`` so the @triton.jit helpers can return
# them directly.
_MIN_FLOAT32_VAL, _MAX_FLOAT32_VAL = _constexpr_bounds(torch.finfo(torch.float32))
_MIN_FLOAT16_VAL, _MAX_FLOAT16_VAL = _constexpr_bounds(torch.finfo(torch.float16))
_MIN_BFLOAT16_VAL, _MAX_BFLOAT16_VAL = _constexpr_bounds(torch.finfo(torch.bfloat16))
_MIN_INT16_VAL, _MAX_INT16_VAL = _constexpr_bounds(torch.iinfo(torch.int16))
_MIN_INT32_VAL, _MAX_INT32_VAL = _constexpr_bounds(torch.iinfo(torch.int32))
_MIN_INT64_VAL, _MAX_INT64_VAL = _constexpr_bounds(torch.iinfo(torch.int64))

33 

34 

@triton.jit
def _get_finfo_val(
    dtype,
    return_max,
):
    """Return the largest or smallest finite value of a floating-point dtype.

    Args:
        dtype: Triton element type; one of ``tl.float32``, ``tl.float16`` or
            ``tl.bfloat16``.
        return_max: when true the dtype's maximum is returned, otherwise its
            minimum (the most negative finite value).

    Any other dtype is rejected at compile time instead of silently falling
    through without a return value.
    """
    if dtype is tl.float32:
        if return_max:
            return _MAX_FLOAT32_VAL
        else:
            return _MIN_FLOAT32_VAL
    elif dtype is tl.float16:
        if return_max:
            return _MAX_FLOAT16_VAL
        else:
            return _MIN_FLOAT16_VAL
    elif dtype is tl.bfloat16:
        if return_max:
            return _MAX_BFLOAT16_VAL
        else:
            return _MIN_BFLOAT16_VAL
    else:
        # Without this branch the helper would yield nothing and the caller
        # would fail later with an unrelated-looking error.
        tl.static_assert(
            False, "_get_finfo_val only supports float32, float16 and bfloat16"
        )

55 

56 

@triton.jit
def _get_iinfo_val(
    dtype,
    return_max,
):
    """Return the largest or smallest value of a signed integer dtype.

    Args:
        dtype: Triton element type; one of ``tl.int16``, ``tl.int32`` or
            ``tl.int64``.
        return_max: when true the dtype's maximum is returned, otherwise its
            minimum.

    Any other dtype is rejected at compile time instead of silently falling
    through without a return value.
    """
    if dtype is tl.int16:
        if return_max:
            return _MAX_INT16_VAL
        else:
            return _MIN_INT16_VAL
    elif dtype is tl.int32:
        if return_max:
            return _MAX_INT32_VAL
        else:
            return _MIN_INT32_VAL
    elif dtype is tl.int64:
        if return_max:
            return _MAX_INT64_VAL
        else:
            return _MIN_INT64_VAL
    else:
        # Without this branch the helper would yield nothing and the caller
        # would fail later with an unrelated-looking error.
        tl.static_assert(False, "_get_iinfo_val only supports int16, int32 and int64")

77 

78 

@triton.jit
def get_topk_bubble_res(
    buffer, buffer_ind, k, axis, mask_val, DESCENDING, BLOCK_M, BLOCK_N
):
    """Select ``k`` entries per row from ``buffer`` by repeated max/min extraction.

    Each of the ``k`` iterations picks the current maximum (``DESCENDING``) or
    minimum along ``axis``, records the value and its companion entry from
    ``buffer_ind``, and then overwrites the picked slot with ``mask_val`` so
    it is not picked again.

    Args:
        buffer: values to select from.
        buffer_ind: index payload stored alongside ``buffer``; the entry at
            the selected position is what gets returned as the index.
        k: number of entries to extract per row.
        axis: reduction axis handed to ``tl.max`` / ``tl.min``.
        mask_val: value written over already-selected slots.
        DESCENDING: select maxima when true, minima otherwise.
        BLOCK_M: number of rows; also chooses between the two code paths.
        BLOCK_N: number of columns in ``buffer``.

    Returns:
        ``(ret, ret_ind)``, both created with shape ``[BLOCK_M, k]``.

    NOTE(review): relies on ``tl.empty`` and in-place slice assignment on
    tensors being available in this backend's Triton - confirm.
    """
    kep_buffer_n = buffer
    topk_buffer_index_n = buffer_ind
    ret = tl.empty([BLOCK_M, k], dtype=buffer.dtype)
    ret_ind = tl.empty([BLOCK_M, k], dtype=buffer_ind.dtype)
    for k_ind in tl.range(0, k):
        if DESCENDING:
            sel_val, sel_index = tl.max(kep_buffer_n, axis=axis, return_indices=True)
        else:
            sel_val, sel_index = tl.min(kep_buffer_n, axis=axis, return_indices=True)

        if BLOCK_M > 1:
            # One-hot mask marking, for every row, the column just selected.
            mask_sel = tl.arange(0, BLOCK_N)[None, :] == sel_index[:, None]
            # Pick the index payload at the selected column: all other columns
            # contribute 0, so the max recovers it.
            # (assumes index payloads are non-negative - TODO confirm)
            tep_sel_index_buffer = tl.where(mask_sel, topk_buffer_index_n, 0)
            sel_index_res = tl.max(tep_sel_index_buffer, axis=axis)
            sel_val_res = sel_val
            ret[:, k_ind] = sel_val_res
            ret_ind[:, k_ind] = sel_index_res

            # Update buffer.
            kep_buffer_n = tl.where(mask_sel, mask_val, kep_buffer_n)
        else:
            # Single-row path: the selected column is read as one scalar and
            # used for direct indexing instead of a mask.
            indices = sel_index[0]
            ret[:, k_ind] = sel_val
            ret_ind[:, k_ind] = topk_buffer_index_n[:, indices]
            # Update buffer.
            kep_buffer_n[:, indices] = mask_val
    return ret, ret_ind

110 

111 

# Candidate tile sizes for the bubble kernel: rows per tile (TILE_M) and
# columns per tile (TILE_N).
BLOCK_BATCH = [1, 16]
BLOCK_N = [128, 512, 1024, 2048]


def topk_cfggen():
    """Build the tuning candidates for ``topk_bubble_kernel``.

    One ``triton.Config`` is produced for every combination of row tile,
    column tile and pipeline stage count, always with a single warp.
    """
    stage_choices = [1, 3]
    candidates = []
    for tile_m in BLOCK_BATCH:
        for tile_n in BLOCK_N:
            for stages in stage_choices:
                candidates.append(
                    triton.Config(
                        {"TILE_M": tile_m, "TILE_N": tile_n},
                        num_warps=1,
                        num_stages=stages,
                    )
                )
    return candidates

125 

126 

def topk_config_prune(configs, named_args, **kwargs):
    """Filter and specialize tuning configs for one ``(k, N, BLOCK_M)`` problem.

    Args:
        configs: candidate ``triton.Config`` objects (see ``topk_cfggen``).
        named_args: kernel arguments by name; ``"k"``, ``"N"`` and
            ``"BLOCK_M"`` are read.
        **kwargs: accepted for the pruning-hook signature; unused.

    Returns:
        A list of configs, each carrying the derived ``TILE_M_NUM`` and
        ``TILE_N_NUM`` values next to ``TILE_M`` / ``TILE_N``.

    The incoming configs are never modified: every survivor is a copy with
    its own ``kwargs`` dict. The candidate list is shared across all calls,
    so writing the per-problem tile counts into it in place would also
    rewrite any config object that was retained for a different problem size.
    """
    import copy  # local import keeps this block self-contained

    k = named_args["k"]
    N = named_args["N"]
    block_m = named_args["BLOCK_M"]
    new_configs = []

    for config in configs:
        tile_n = config.kwargs["TILE_N"]
        tile_m = config.kwargs["TILE_M"]
        # A column tile must be able to hold k results, and a row tile must
        # not exceed the rows assigned to one program.
        if tile_n < k or tile_m > block_m:
            continue
        if new_configs:
            last_tn = new_configs[-1].kwargs["TILE_N"]
            last_tm = new_configs[-1].kwargs["TILE_M"]
            # Once a kept tile already spans the whole row, wider tiles with
            # the same TILE_M add nothing.
            if tile_n > N and last_tn >= N and last_tm == tile_m:
                continue
        specialized = copy.copy(config)
        specialized.kwargs = {
            **config.kwargs,
            "TILE_M_NUM": triton.cdiv(block_m, tile_m),
            "TILE_N_NUM": triton.cdiv(N, tile_n),
        }
        new_configs.append(specialized)

    # For rows that fit in the largest tile but match no predefined width,
    # also try a single tile of exactly N columns.
    if (N not in BLOCK_N) and (N <= max(BLOCK_N)):
        for tm in BLOCK_BATCH:
            new_configs.append(
                triton.Config(
                    {
                        "TILE_M": tm,
                        "TILE_N": N,
                        "TILE_M_NUM": triton.cdiv(block_m, tm),
                        "TILE_N_NUM": 1,
                    },
                    num_warps=1,
                    num_stages=3,
                )
            )
    return new_configs

162 

163 

@libentry()
@libtuner(
    configs=topk_cfggen(),
    key=["k", "N", "M", "BLOCK_M"],
    prune_configs_by={"early_config_prune": topk_config_prune},
)
@triton.jit
def topk_bubble_kernel(
    inp_ptr,
    out_ptr,
    out_index_ptr,
    k: tl.constexpr,
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_M: tl.constexpr,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_M_NUM: tl.constexpr,
    TILE_N_NUM: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    """Top-k over the rows of an ``M x N`` input by repeated extraction.

    Each program starts at row ``pid * BLOCK_M`` and walks ``TILE_M_NUM`` row
    tiles of ``TILE_M`` rows. For every row tile the columns are scanned in
    ``TILE_N_NUM`` tiles of ``TILE_N``; each column tile contributes its own
    ``k`` candidates, and when there is more than one column tile a second
    extraction over the collected candidates yields the final ``k``.

    Input is addressed as ``row * N + col`` and output as ``row * k + j``.
    The padding value comes from ``_get_finfo_val``, which only handles
    float32, float16 and bfloat16 inputs.

    Args:
        inp_ptr: pointer to the input values.
        out_ptr: pointer to the output values (``k`` per row).
        out_index_ptr: pointer to the output column indices (``k`` per row).
        k: number of entries to select per row.
        M: number of rows.
        N: number of columns.
        BLOCK_M: rows assigned to one program.
        TILE_M, TILE_N: tile sizes chosen by the tuner.
        TILE_M_NUM, TILE_N_NUM: tile counts supplied with the tuning config.
        DESCENDING: select the largest values when true, the smallest otherwise.
    """
    pid = tl.program_id(0)
    m_st = pid * BLOCK_M

    # Padding that can never win a selection: the dtype minimum when looking
    # for maxima, the dtype maximum when looking for minima.
    mask_val = _get_finfo_val(inp_ptr.dtype.element_ty, return_max=not DESCENDING)
    mask_val = mask_val.to(inp_ptr.dtype.element_ty)

    for m_block_ind in tl.range(0, TILE_M_NUM):
        m_iter_st = m_block_ind * TILE_M + m_st
        m_offset_val = m_iter_st + tl.arange(0, TILE_M)
        m_offset = m_offset_val[:, None]
        m_offset_mask = m_offset < M

        # Per-row candidate buffers: k slots for every column tile.
        topk_buffer_n = tl.full(
            [TILE_M, TILE_N_NUM * k], value=mask_val, dtype=inp_ptr.dtype.element_ty
        )
        topk_buffer_index_n = tl.full(
            [TILE_M, TILE_N_NUM * k], value=0, dtype=out_index_ptr.dtype.element_ty
        )
        for n_block_ind in tl.range(0, TILE_N_NUM):
            n_st = n_block_ind * TILE_N
            n_offset = n_st + tl.arange(0, TILE_N)[None, :]
            n_offset_mask = n_offset < N

            inp_mask = m_offset_mask & n_offset_mask
            inp_ptrs = inp_ptr + m_offset * N + n_offset
            # Out-of-range lanes are filled with the padding value.
            block_inp_val = tl.load(inp_ptrs, mask=inp_mask, other=mask_val)

            # The column offsets double as the index payload, so the returned
            # indices are absolute column positions.
            local_buffer, local_buffer_ind = get_topk_bubble_res(
                block_inp_val,
                n_offset.to(out_index_ptr.dtype.element_ty),
                k,
                1,
                mask_val,
                DESCENDING,
                TILE_M,
                TILE_N,
            )
            tep_index = n_block_ind * k
            topk_buffer_n[:, tep_index : tep_index + k] = local_buffer
            topk_buffer_index_n[:, tep_index : tep_index + k] = local_buffer_ind
        if TILE_N_NUM > 1:
            # Merge the per-tile candidates into the final k per row.
            global_res, global_res_ind = get_topk_bubble_res(
                topk_buffer_n,
                topk_buffer_index_n,
                k,
                1,
                mask_val,
                DESCENDING,
                TILE_M,
                TILE_N_NUM * k,
            )
        else:
            # A single column tile already holds the final result.
            global_res = topk_buffer_n
            global_res_ind = topk_buffer_index_n

        # Store topk.
        store_ptrs = m_offset * k + tl.arange(0, k)[None, :]
        store_mask = m_offset_mask
        tl.store(store_ptrs + out_ptr, global_res, store_mask)
        tl.store(store_ptrs + out_index_ptr, global_res_ind, store_mask)

245 

246 

247if HAS_TLE: 

248 

249 @triton.jit 

250 def _get_topmask_and_fullmask(x): 

251 tl.static_assert( 

252 x.dtype.is_int_unsigned(), 

253 "floating-point value must be passed as bits", 

254 ) 

255 tm: tl.constexpr = 1 << (-1 + x.dtype.primitive_bitwidth) 

256 fm: tl.constexpr = (1 << x.dtype.primitive_bitwidth) - 1 

257 tm_arr = tl.full(x.shape, tm, dtype=x.dtype) 

258 fm_arr = tl.full(x.shape, fm, dtype=x.dtype) 

259 return tm_arr, fm_arr 

260 

261 @triton.jit 

262 def _fpval_to_key_with_nan(x, x_bits): 

263 tm, fm = _get_topmask_and_fullmask(x_bits) 

264 mask = tl.where((x_bits & tm) != 0, fm, tm) 

265 key = x_bits ^ mask 

266 return tl.where(x == x, key, fm) 

267 

268 @triton.jit 

269 def _key_to_fpval(x): 

270 tm, fm = _get_topmask_and_fullmask(x) 

271 mask = tl.where((x & tm) != 0, tm, fm) 

272 return x ^ mask 

273 

    @libentry()
    @triton.jit
    def topk_kernel_radix_tle(
        X,
        Yv,
        Yi,
        stride_xm,
        stride_ym,
        n_cols,
        K: tl.constexpr,
        K_PAD: tl.constexpr,
        BLOCK_N: tl.constexpr,
        RADIX_BITS: tl.constexpr,
    ):
        """Radix-select top-K (largest first) for one row per program.

        Phase 1 fixes the threshold key digit by digit, ``RADIX_BITS`` bits at
        a time from the most significant digit down, using a histogram of the
        keys that match the digits fixed so far. Phase 2 collects every
        element whose key is above the threshold, then elements equal to it,
        into a ``K_PAD``-slot buffer. Phase 3 sorts that buffer and writes the
        first ``K`` values and indices.

        Args:
            X: input pointer; row ``pid`` starts at ``X + pid * stride_xm``.
            Yv: output pointer for values.
            Yi: output pointer for indices.
            stride_xm: row stride of ``X`` in elements.
            stride_ym: row stride used for both ``Yv`` and ``Yi``.
            n_cols: number of elements per row.
            K: number of results to write per row.
            K_PAD: size of the selection buffer (``tl.arange`` length).
            BLOCK_N: number of columns processed per tile.
            RADIX_BITS: digit width in bits.

        The column position is packed into the low 16 bits next to the key,
        so it must fit in 16 bits; the ``topk`` wrapper in this file only
        takes this path when the row length is at most 65535.
        """
        pid = tl.program_id(0)
        x_dtype = X.dtype.element_ty
        x_nbits: tl.constexpr = x_dtype.primitive_bitwidth
        # Width of the packed (key, index) word: twice the input width, or
        # 32 bits for inputs narrower than 16 bits.
        if x_nbits < 16:
            y_nbits: tl.constexpr = 32
        else:
            y_nbits: tl.constexpr = x_nbits * 2
        x_utype = tl.dtype(f"uint{x_nbits}")
        x_ultype = tl.dtype(f"uint{y_nbits}")

        RADIX_SIZE: tl.constexpr = 1 << RADIX_BITS
        RADIX_MASK: tl.constexpr = RADIX_SIZE - 1
        bins = tl.arange(0, RADIX_SIZE)
        one = tl.full([BLOCK_N], 1, tl.int32)

        # ``desired`` accumulates the digits fixed so far and ``desired_mask``
        # marks which bit positions they occupy.
        desired = tl.full((), 0, dtype=x_utype)
        desired_mask = tl.full((), 0, dtype=x_utype)
        k_to_find = tl.full((), K, dtype=tl.int32)
        n_tiles = tl.cdiv(n_cols, BLOCK_N)

        # Histogram with one counter per digit value.
        smem_counts = tle_gpu.gpu.alloc(
            [RADIX_SIZE],
            dtype=tl.int32,
            layout=None,
            scope=tle_gpu.gpu.smem,
            nv_mma_shared_layout=False,
        )
        smem_count_ptrs = tle_gpu.gpu.local_ptr(smem_counts, (bins,))

        # Phase 1: determine the threshold key one digit at a time.
        for digit_pos in tl.static_range(x_nbits - RADIX_BITS, -1, -RADIX_BITS):
            tl.store(smem_count_ptrs, tl.zeros([RADIX_SIZE], dtype=tl.int32))
            for t in tl.range(0, n_tiles):
                offs_n = t * BLOCK_N + tl.arange(0, BLOCK_N)
                mask_n = offs_n < n_cols
                x_ptrs = X + pid * stride_xm + offs_n
                x = tl.load(x_ptrs, mask=mask_n, other=float("-inf"))
                x_bits = x.to(x_utype, bitcast=True)
                x_key = _fpval_to_key_with_nan(x, x_bits)
                # Only keys agreeing with the digits already fixed are counted.
                matches = (x_key & desired_mask) == desired
                digit = ((x_key >> digit_pos) & RADIX_MASK).to(tl.int32)
                valid = mask_n & matches
                count_addrs = tle_gpu.gpu.local_ptr(smem_counts, (digit,))
                tl.atomic_add(count_addrs, one, mask=valid, sem="relaxed", scope="cta")

            counts = tl.load(smem_count_ptrs)

            # After the reverse cumulative sum, entry d holds the number of
            # counted keys whose current digit is >= d.
            cumsum_desc = tl.cumsum(counts, axis=0, reverse=True)
            tl.store(smem_count_ptrs, cumsum_desc)

            selected_scalar = 0
            counts_gt_scalar = 0
            found = 0
            # Scan digits from high to low and take the first d with
            # cum[d] >= k_to_find > cum[d + 1].
            for rev in tl.static_range(RADIX_SIZE):
                d = RADIX_SIZE - 1 - rev
                cum_d = tl.load(tle_gpu.gpu.local_ptr(smem_counts, (d,)))
                if d + 1 < RADIX_SIZE:
                    cum_next = tl.load(tle_gpu.gpu.local_ptr(smem_counts, (d + 1,)))
                else:
                    cum_next = 0
                take = (found == 0) & (cum_d >= k_to_find) & (cum_next < k_to_find)
                selected_scalar = tl.where(take, d, selected_scalar)
                counts_gt_scalar = tl.where(take, cum_next, counts_gt_scalar)
                found = tl.where(take, 1, found)

            selected_u = selected_scalar.to(x_utype)
            desired = desired | (selected_u << digit_pos)
            desired_mask = desired_mask | (
                tl.full((), RADIX_MASK, dtype=x_utype) << digit_pos
            )
            # Keys with a larger digit are already accounted for; keep looking
            # for the remaining rank among keys sharing the chosen digit.
            k_to_find = k_to_find - counts_gt_scalar

        # All digits are fixed: ``desired`` is used as the threshold key.
        thr_key = desired

        # Filler for unused selection slots: the key of -inf with a zero index
        # field.
        min_val = tl.full((), float("-inf"), tl.float32).to(x_dtype)
        min_bits = min_val.to(x_utype, bitcast=True)
        min_key = _fpval_to_key_with_nan(min_val, min_bits)
        min_packed = min_key.to(x_ultype) << 16
        offs_k = tl.arange(0, K_PAD)

        smem_selected = tle_gpu.gpu.alloc(
            [K_PAD],
            dtype=x_ultype,
            layout=None,
            scope=tle_gpu.gpu.smem,
            nv_mma_shared_layout=False,
        )
        smem_selected_ptrs = tle_gpu.gpu.local_ptr(smem_selected, (offs_k,))
        tl.store(smem_selected_ptrs, tl.full([K_PAD], min_packed, dtype=x_ultype))

        # Single counter handing out slots in the selection buffer.
        smem_write_count = tle_gpu.gpu.alloc(
            [1],
            dtype=tl.int32,
            layout=None,
            scope=tle_gpu.gpu.smem,
            nv_mma_shared_layout=False,
        )
        tl.store(tle_gpu.gpu.local_ptr(smem_write_count, (0,)), 0)
        # Every lane points at that same counter.
        write_count_ptrs = tle_gpu.gpu.local_ptr(
            smem_write_count, (tl.zeros([BLOCK_N], dtype=tl.int32),)
        )

        # Phase 2a: collect every element whose key is above the threshold.
        for t in tl.range(0, n_tiles):
            offs_n = t * BLOCK_N + tl.arange(0, BLOCK_N)
            mask_n = offs_n < n_cols
            x_ptrs = X + pid * stride_xm + offs_n
            x = tl.load(x_ptrs, mask=mask_n, other=float("-inf"))
            x_bits = x.to(x_utype, bitcast=True)
            x_key = _fpval_to_key_with_nan(x, x_bits)
            # The index field stores n_cols - column, so among equal keys a
            # smaller column yields a larger packed word.
            idx_key = (n_cols - offs_n).to(x_ultype)
            packed = (x_key.to(x_ultype) << 16) | idx_key
            take_gt = mask_n & (x_key > thr_key)
            pos = tl.atomic_add(
                write_count_ptrs, one, mask=take_gt, sem="relaxed", scope="cta"
            )
            write_mask = take_gt & (pos < K_PAD)
            dst_ptrs = tle_gpu.gpu.local_ptr(smem_selected, (pos.to(tl.int32),))
            tl.store(dst_ptrs, packed, mask=write_mask)

        # Phase 2b: add elements equal to the threshold; writes past the end
        # of the buffer are dropped by the ``pos < K_PAD`` mask.
        for t in tl.range(0, n_tiles):
            offs_n = t * BLOCK_N + tl.arange(0, BLOCK_N)
            mask_n = offs_n < n_cols
            x_ptrs = X + pid * stride_xm + offs_n
            x = tl.load(x_ptrs, mask=mask_n, other=float("-inf"))
            x_bits = x.to(x_utype, bitcast=True)
            x_key = _fpval_to_key_with_nan(x, x_bits)
            idx_key = (n_cols - offs_n).to(x_ultype)
            packed = (x_key.to(x_ultype) << 16) | idx_key
            take_eq = mask_n & (x_key == thr_key)
            pos = tl.atomic_add(
                write_count_ptrs, one, mask=take_eq, sem="relaxed", scope="cta"
            )
            write_mask = take_eq & (pos < K_PAD)
            dst_ptrs = tle_gpu.gpu.local_ptr(smem_selected, (pos.to(tl.int32),))
            tl.store(dst_ptrs, packed, mask=write_mask)

        selected_packed = tl.load(smem_selected_ptrs)

        # Phase 3: sort the packed words and split them back into value and
        # column index.
        topk = tl.sort(selected_packed, dim=0, descending=True)
        idx_mask = tl.full(topk.shape, (1 << 16) - 1, dtype=topk.dtype)
        idx_raw = (topk & idx_mask).to(tl.uint32)
        y_indices = (n_cols - idx_raw.to(tl.int32)).to(tl.int32)
        y_values_raw = (topk >> 16).to(x_utype)
        y_values = _key_to_fpval(y_values_raw).to(x_dtype, bitcast=True)

        # Only the first K of the K_PAD slots are written out.
        mask_k = offs_k < K
        yv_ptrs = Yv + pid * stride_ym + offs_k
        yi_ptrs = Yi + pid * stride_ym + offs_k
        tl.store(yv_ptrs, y_values, mask=mask_k)
        tl.store(yi_ptrs, y_indices, mask=mask_k)

437 

438 

def topk(x, k, dim=-1, largest=True, sorted=True):
    """Return the ``k`` largest (or smallest) entries along the last dimension.

    Args:
        x: input tensor.
        k: number of entries to return per row.
        dim: dimension to reduce; only the last dimension is supported.
        largest: return the largest entries when true, the smallest otherwise.
        sorted: must be true; unsorted output is not supported.

    Returns:
        ``(values, indices)`` with the last dimension of ``x`` replaced by
        ``k``. ``values`` has the dtype of ``x`` and ``indices`` is int64.

    Raises:
        AssertionError: ``dim`` is not the last dimension or ``sorted`` is false.
        RuntimeError: ``k`` is negative or larger than the last dimension.

    Three implementations are dispatched: a radix-select kernel (needs the TLE
    extension, a CUDA float tensor, ``largest=True``, at most 65535 elements
    per row and a padded ``k`` of at most 1024), a repeated-extraction kernel
    when ``k <= log2(row length)``, and a two-stage chunked kernel otherwise.
    """
    logger.debug("GEMS_CAMBRICON TOPK")
    # Normalize a negative dim to its non-negative equivalent.
    if dim < 0:
        dim = dim + x.ndim

    assert dim == x.ndim - 1, "Currently only support topk in last dimension"
    assert sorted, "Currently only support sorted == True"

    # Early return for k=0 to avoid Triton kernel compilation error.
    # Triton's tl.arange(0, BLOCK_SIZE) requires BLOCK_SIZE > 0.
    # When k=0, stage2_elem_cnt becomes 0, leading to BLOCK_SIZE=0.
    if k == 0:
        out_shape = list(x.shape[:-1]) + [0]
        return (
            torch.empty(out_shape, device=x.device, dtype=x.dtype),
            torch.empty(out_shape, device=x.device, dtype=torch.int64),
        )

    descending = bool(largest)

    topk_elem_cnt = x.shape[dim]
    # Reject an impossible k up front. This also keeps the division and the
    # log2 below away from an empty last dimension.
    if k < 0 or k > topk_elem_cnt:
        raise RuntimeError(
            f"selected index k out of range: k={k}, dimension size={topk_elem_cnt}"
        )
    batch_size = math.prod(x.shape) // topk_elem_cnt
    out_shape = x.shape[:-1] + (k,)

    if (
        HAS_TLE
        and sorted
        and descending
        and x.is_cuda
        and x.dtype in (torch.float16, torch.float32, torch.bfloat16)
        and topk_elem_cnt <= 65535
        and triton.next_power_of_2(k) <= 1024
    ):
        k_pad = triton.next_power_of_2(k)
        y_vals = torch.empty(out_shape, device=x.device, dtype=x.dtype)
        # The kernel writes int32 indices; they are widened on return.
        y_idx = torch.empty(out_shape, device=x.device, dtype=torch.int32)
        block_n_radix = max(k_pad, min(512, triton.next_power_of_2(topk_elem_cnt)))
        block_n_radix = min(block_n_radix, 1024)

        x_2d = x.reshape(batch_size, topk_elem_cnt)
        with torch_device_fn.device(x.device):
            topk_kernel_radix_tle[(batch_size,)](
                x_2d,
                y_vals,
                y_idx,
                x_2d.stride(0),
                y_vals.stride(0),
                topk_elem_cnt,
                K=k,
                K_PAD=k_pad,
                BLOCK_N=block_n_radix,
                RADIX_BITS=4,
                num_warps=4,
                num_stages=1,
            )
        return (y_vals, y_idx.to(torch.int64))

    if k <= math.log2(topk_elem_cnt):
        logger.debug("GEMS_CAMBRICON TOPK USING BUBBLE")
        # topk_bubble_kernel addresses element (m, n) as m * N + n, so the
        # input must be laid out contiguously.
        x = x.contiguous()
        topk_out = torch.empty(out_shape, device=x.device, dtype=x.dtype)
        topk_out_idx = torch.empty(out_shape, device=x.device, dtype=torch.int64)

        def grid_fn(meta):
            return (min(batch_size, TOTAL_CORE_NUM),)

        # Rows handled by each program.
        block_m = triton.cdiv(batch_size, TOTAL_CORE_NUM)
        # Launch on the tensor's device, like the other two kernel paths.
        with torch_device_fn.device(x.device):
            topk_bubble_kernel[grid_fn](
                x,
                topk_out,
                topk_out_idx,
                k,
                batch_size,
                topk_elem_cnt,
                block_m,
                DESCENDING=descending,
            )
        return (topk_out, topk_out_idx)
    else:
        logger.debug("GEMS_CAMBRICON TOPK USING SORT")
        # NOTE(review): x is passed as-is here; confirm that
        # topk_stage1_kernel tolerates non-contiguous input.
        # Note(Zhengzekang): Maybe we should add a heuristic search in selecting a proper chunk size.
        if topk_elem_cnt < 1024:
            chunk_size = 256
        else:
            chunk_size = 1024

        # Note(Zhengzekang): We should promise chunk_size is larger than k.
        if chunk_size < k:
            chunk_size = triton.next_power_of_2(k)

        chunk_num = triton.cdiv(topk_elem_cnt, chunk_size)

        # Stage 1 keeps k candidates per chunk; stage 2 reduces them to k.
        stage1_out = torch.empty(
            batch_size * chunk_num * k, device=x.device, dtype=x.dtype
        )
        stage1_out_idx = torch.empty(
            batch_size * chunk_num * k, device=x.device, dtype=torch.int64
        )

        stage2_out = torch.empty(out_shape, device=x.device, dtype=x.dtype)
        stage2_out_idx = torch.empty(out_shape, device=x.device, dtype=torch.int64)

        with torch_device_fn.device(x.device):
            topk_stage1_kernel[
                batch_size,
                chunk_num,
            ](
                stage1_out,  # pointer to the output
                stage1_out_idx,  # pointer to the output
                x,  # pointer to the input
                k,
                topk_elem_cnt,
                chunk_size,
                descending,
            )
        stage2_elem_cnt = chunk_num * k
        BLOCK_SIZE = triton.next_power_of_2(stage2_elem_cnt)

        with torch_device_fn.device(x.device):
            topk_stage2_kernel[batch_size,](
                stage2_out,
                stage2_out_idx,
                stage1_out,
                stage1_out_idx,
                dim,
                k,
                stage2_elem_cnt,
                BLOCK_SIZE,
                descending,
            )

    return (stage2_out, stage2_out_idx)