Coverage for src/flag_gems/runtime/backend/_sunrise/ops/topk.py: 0%

323 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7import triton.language.core as core 

8 

9try: 

10 # TODO: Triton 2.1 does not implement _log2. 

11 # Remove the try-catch block once all vendors upgrade to a newer version of Triton. 

12 from triton.language.standard import _log2 

13except ImportError: 

14 pass 

15 

16from flag_gems.runtime import torch_device_fn 

17from flag_gems.utils import libentry 

18from flag_gems.utils import triton_lang_extension as ext 

19from flag_gems.utils.limits import get_dtype_max, get_dtype_min 

20from flag_gems.utils.triton_version_utils import HAS_TLE 

21 

22if HAS_TLE: 

23 import triton.experimental.tle.language as tle_gpu 

24else: 

25 tle_gpu = None 

26 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Finite numeric limits of the floating-point and integer dtypes, wrapped in
# tl.constexpr so that @triton.jit code in this file can reference them as
# compile-time constants (see _get_finfo_val and topk_stage2_kernel).
# These are the largest/smallest *finite* values, not +/-inf.
_MIN_FLOAT32_VAL = tl.constexpr(torch.finfo(torch.float32).min)
_MAX_FLOAT32_VAL = tl.constexpr(torch.finfo(torch.float32).max)
_MIN_FLOAT16_VAL = tl.constexpr(torch.finfo(torch.float16).min)
_MAX_FLOAT16_VAL = tl.constexpr(torch.finfo(torch.float16).max)
_MIN_BFLOAT16_VAL = tl.constexpr(torch.finfo(torch.bfloat16).min)
_MAX_BFLOAT16_VAL = tl.constexpr(torch.finfo(torch.bfloat16).max)
_MIN_INT8_VAL = tl.constexpr(torch.iinfo(torch.int8).min)
_MAX_INT8_VAL = tl.constexpr(torch.iinfo(torch.int8).max)
_MIN_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).min)
_MAX_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).max)
_MIN_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).min)
_MAX_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).max)
_MIN_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).min)
_MAX_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).max)

42 

43 

@triton.jit
def _get_finfo_val(
    dtype,
    return_max,
):
    """Return the largest or smallest finite value of a floating dtype.

    Args:
        dtype: Triton element type. Only ``tl.float32``, ``tl.float16`` and
            ``tl.bfloat16`` are recognised.
        return_max: Truthy selects the dtype maximum, falsy the minimum.

    Returns:
        The matching module-level constexpr limit. Nothing is returned for
        any other dtype.
    """
    if dtype is tl.float32:
        return _MAX_FLOAT32_VAL if return_max else _MIN_FLOAT32_VAL
    elif dtype is tl.float16:
        return _MAX_FLOAT16_VAL if return_max else _MIN_FLOAT16_VAL
    elif dtype is tl.bfloat16:
        return _MAX_BFLOAT16_VAL if return_max else _MIN_BFLOAT16_VAL

64 

65 

@triton.jit
def _get_iinfo_val(
    dtype,
    return_max,
):
    """Return the maximum or minimum value of ``dtype``.

    Delegates to ``get_dtype_max`` when ``return_max`` is truthy and to
    ``get_dtype_min`` otherwise.
    """
    return get_dtype_max(dtype) if return_max else get_dtype_min(dtype)

75 

76 

@libentry()
@triton.jit
def topk_stage1_kernel(
    y_ptr,
    index_ptr,
    x_ptr,
    k,
    N: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    """Stage 1 of top-k: select the k best elements of one chunk of one row.

    The launch grid is (rows, chunks per row). The input is addressed as
    densely packed rows of ``N`` elements. Each program writes ``k`` values
    and ``k`` row-relative indices to its own slice of the outputs, which are
    laid out as [row, chunk, k].

    Args:
        y_ptr: Output pointer for the selected values.
        index_ptr: Output pointer for the selected row-relative indices.
        x_ptr: Input pointer.
        k: Number of elements to select per chunk.
        N: Row length.
        CHUNK_SIZE: Number of lanes handled by one program.
        DESCENDING: True selects the largest elements, False the smallest.
    """
    row = ext.program_id(0)
    chunk = ext.program_id(1)
    chunks_per_row = ext.num_programs(1)

    # Both outputs share the same [row, chunk, k] layout.
    out_offset = row * chunks_per_row * k + chunk * k
    y_ptr += out_offset
    index_ptr += out_offset

    chunk_start = chunk * CHUNK_SIZE
    x_ptr += row * N + chunk_start

    lanes = tl.arange(0, CHUNK_SIZE)
    in_bounds = (chunk_start + lanes) < N

    # Lanes past the end of the row are filled with the worst finite value of
    # the input dtype so they are only chosen once real elements run out.
    pad_val = _get_finfo_val(x_ptr.dtype.element_ty, return_max=not DESCENDING)
    vals = tl.load(x_ptr + lanes, mask=in_bounds, other=pad_val).to(tl.float32)

    # A selected lane is overwritten with the worst finite float32 value so
    # that it is not picked again while better candidates remain.
    retired_val = _get_finfo_val(tl.float32, return_max=not DESCENDING)

    for k_idx in range(k):
        if DESCENDING:
            best_val = tl.max(vals)
            best_lane = tl.argmax(vals, axis=0)
        else:
            best_val = tl.min(vals)
            best_lane = tl.argmin(vals, axis=0)

        tl.store(y_ptr + k_idx, best_val)
        tl.store(index_ptr + k_idx, best_lane + chunk_start)

        vals = tl.where(lanes == best_lane, retired_val, vals)

126 

127 

"""
Note(Zhengzekang):
Adapted from the official Triton 2.2 `sort` implementation:
https://github.com/triton-lang/triton/blob/release/2.2.x/python/triton/language/standard.py#L392-L404
The only addition is that the indices are carried along and sorted together with the values.
"""

134 

135 

@triton.jit
def _compare_and_swap(x, ids, flip, i: core.constexpr, n_dims: core.constexpr):
    """Run one compare-and-swap round of the bitonic network on (value, index) pairs.

    ``x`` is viewed as ``[n_outer * 2**i, 2, 2**(n_dims - i - 1)]``; the two
    entries along the middle axis are the "left" and "right" partner of each
    pair. Partners are exchanged when they are out of order with respect to
    ``flip``.

    Args:
        x: Values being sorted.
        ids: Indices travelling with ``x``; must have the same shape.
        flip: 0/False orders each pair ascending, 1/True descending. Either a
            constexpr scalar or a tensor shaped like ``x`` whose value is the
            same for both partners of a pair.
        i: Which round of the network this is (selects the pair stride).
        n_dims: log2 of the number of elements being sorted.

    Returns:
        Tuple ``(x, ids)`` after the exchange.
    """
    n_outer: core.constexpr = x.numel >> n_dims
    shape: core.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)]

    y = core.reshape(x, shape)
    y_idx = core.reshape(ids, shape)

    # Slice left/right with 'stride' 2**(n_dims - i - 1): multiplying by a 0/1
    # mask and summing over the pair axis extracts one partner and broadcasts
    # it to both slots of the pair.
    # NOTE(review): a non-finite partner times 0 yields NaN, so this extraction
    # presumably mishandles +/-inf inputs -- confirm.
    mask = core.arange(0, 2)[None, :, None]
    left = core.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape).to(x.dtype)
    right = core.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape).to(x.dtype)
    left = core.reshape(left, x.shape)
    right = core.reshape(right, x.shape)

    left_idx = core.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape).to(
        ids.dtype
    )
    right_idx = core.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape).to(
        ids.dtype
    )
    left_idx = core.reshape(left_idx, ids.shape)
    right_idx = core.reshape(right_idx, ids.shape)

    # is_right indicator: 0 for the left, 1 for the right element of each pair.
    is_right = core.reshape(
        core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape
    )

    # The partner of a left element is the right one and vice versa.
    paired_val = core.where(is_right, left, right)
    paired_idx = core.where(is_right, left_idx, right_idx)

    # Decide the exchange once per pair from (left, right) so that both
    # partners always agree. A per-element test such as
    # `(x > paired) != (flip ^ is_right)` is asymmetric when the two values are
    # equal: only one side swaps, which duplicates one index and drops the
    # other.
    swap = (left > right) != (flip != 0)
    x = core.where(swap, paired_val, x)
    ids = core.where(swap, paired_idx, ids)

    return x, ids

180 

181 

@triton.jit
def _bitonic_merge(
    x, ids, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr
):
    """Run ``stage`` compare-and-swap rounds over (value, index) pairs.

    ``order`` selects the direction of the merged runs:
        0 -> ascending
        1 -> descending
        2 -> alternating between consecutive runs of ``2**stage`` elements

    Returns:
        Tuple ``(x, ids)`` after the merge.
    """
    core.static_assert(stage <= n_dims)
    # `flip` tells each compare-and-swap which direction to order its pairs:
    # all zeros orders everything ascending, while a 0,1,0,1,... pattern over
    # blocks of 2**stage elements orders neighbouring blocks in opposite
    # directions.
    if order == 2:
        n_outer: core.constexpr = x.numel >> n_dims
        flip_shape: core.constexpr = [
            n_outer * 2 ** (n_dims - 1 - stage),
            2,
            2**stage,
        ]
        alternating = core.broadcast_to(core.arange(0, 2)[None, :, None], flip_shape)
        flip = core.reshape(alternating, x.shape)
    else:
        flip = order
    # The rounds of this stage are the last `stage` of the n_dims possible ones.
    first_round: core.constexpr = n_dims - stage
    for step in core.static_range(stage):
        x, ids = _compare_and_swap(x, ids, flip, step + first_round, n_dims)
    return x, ids

209 

210 

@triton.jit
def argsort(x, ids, dim: tl.constexpr, descending: core.constexpr):
    """Bitonic-sort ``x`` and apply the same permutation to ``ids``.

    Args:
        x: Values to sort.
        ids: Indices carried along with ``x``.
        dim: Axis whose length determines the number of merge stages; it is
            only used to read ``x.shape[dim]``.
        descending: Direction of the final merge.

    Returns:
        Tuple ``(sorted_values, permuted_ids)``.
    """
    # Number of merge stages. Relies on `_log2`, which the module imports
    # inside a try/except.
    log_n: core.constexpr = _log2(x.shape[dim])
    for stage in core.static_range(1, log_n + 1):
        if stage < log_n:
            # Intermediate stages build alternating runs (order 2).
            x, ids = _bitonic_merge(x, ids, stage, 2, log_n)
        else:
            # The last stage merges everything in the requested direction.
            x, ids = _bitonic_merge(x, ids, stage, descending, log_n)
    return x, ids

219 

220 

@libentry()
@triton.jit
def topk_stage2_kernel(
    y_ptr,
    index_ptr,
    chunk_x,
    chunk_index,
    sort_dim: tl.constexpr,
    k: tl.constexpr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    """Stage 2 of top-k: sort one row of candidates and keep the first k.

    One program handles one row (grid axis 0).

    Args:
        y_ptr: Output values, ``k`` per row.
        index_ptr: Output indices, ``k`` per row.
        chunk_x: Candidate values, ``N`` per row.
        chunk_index: Candidate indices, ``N`` per row.
        sort_dim: Not used by this kernel.
        k: Number of results to keep per row.
        N: Number of candidates per row.
        BLOCK_SIZE: Number of lanes sorted by the program.
        DESCENDING: True sorts largest-first, False smallest-first.
    """
    row = ext.program_id(0)
    y_ptr += row * k
    index_ptr += row * k
    chunk_x += row * N
    chunk_index += row * N

    lanes = tl.arange(0, BLOCK_SIZE)
    valid = lanes < N

    # Lanes past N are filled with the worst value/index for the requested
    # direction.
    pad_val = _get_finfo_val(chunk_x.dtype.element_ty, return_max=not DESCENDING)
    pad_idx = _MIN_INT32_VAL if DESCENDING else _MAX_INT32_VAL

    vals = tl.load(chunk_x + lanes, mask=valid, other=pad_val).to(tl.float32)
    idxs = tl.load(chunk_index + lanes, mask=valid, other=pad_idx).to(tl.int32)

    sorted_vals, sorted_idxs = argsort(vals, idxs, 0, descending=DESCENDING)

    # Only the first k sorted lanes are written back.
    top_k = lanes < k
    tl.store(y_ptr + lanes, sorted_vals, mask=top_k)
    tl.store(index_ptr + lanes, sorted_idxs, mask=top_k)

256 

257 

# The radix-select kernel below relies on the experimental TLE extension
# (imported as tle_gpu), so it is only defined when that extension is present.
if HAS_TLE:

    @triton.jit
    def _get_topmask_and_fullmask(x):
        """Return (sign-bit mask, all-ones mask) tensors shaped and typed like ``x``.

        ``x`` must hold unsigned integers, i.e. floating-point values already
        bit-cast to their raw bits.
        """
        tl.static_assert(
            x.dtype.is_int_unsigned(),
            "floating-point value must be passed as bits",
        )
        # tm has only the most significant bit set; fm has every bit set.
        tm: tl.constexpr = 1 << (-1 + x.dtype.primitive_bitwidth)
        fm: tl.constexpr = (1 << x.dtype.primitive_bitwidth) - 1
        tm_arr = tl.full(x.shape, tm, dtype=x.dtype)
        fm_arr = tl.full(x.shape, fm, dtype=x.dtype)
        return tm_arr, fm_arr

    @triton.jit
    def _fpval_to_key_with_nan(x, x_bits):
        """Map float bit patterns to unsigned keys whose integer order follows the float order.

        Args:
            x: The floating-point values (used only to detect NaN).
            x_bits: The same values bit-cast to an unsigned integer type.

        Returns:
            Unsigned keys; NaN inputs map to the all-ones (largest) key.
        """
        tm, fm = _get_topmask_and_fullmask(x_bits)
        # Sign bit set: flip every bit. Sign bit clear: flip only the sign bit.
        mask = tl.where((x_bits & tm) != 0, fm, tm)
        key = x_bits ^ mask
        # x == x is False only for NaN.
        return tl.where(x == x, key, fm)

    @triton.jit
    def _key_to_fpval(x):
        """Invert the key transform: turn unsigned keys back into raw float bits."""
        tm, fm = _get_topmask_and_fullmask(x)
        # Keys with the top bit set had only their sign bit flipped; the others
        # had every bit flipped.
        mask = tl.where((x & tm) != 0, tm, fm)
        return x ^ mask

    @libentry()
    @triton.jit
    def topk_kernel_radix_tle(
        X,
        Yv,
        Yi,
        stride_xm,
        stride_ym,
        n_cols,
        K: tl.constexpr,
        K_PAD: tl.constexpr,
        BLOCK_N: tl.constexpr,
        RADIX_BITS: tl.constexpr,
    ):
        """Select the K largest elements of each row with a radix select.

        One program handles one row (``tl.program_id(0)``). The kernel first
        finds the key of the K-th largest element digit by digit, then gathers
        the elements above and at that threshold, sorts them and writes the
        first K values and indices.

        Args:
            X: Input pointer; row ``pid`` starts at ``X + pid * stride_xm`` and
                columns are read with unit stride.
            Yv: Output pointer for values; rows are ``stride_ym`` apart.
            Yi: Output pointer for indices; uses the same ``stride_ym``.
            stride_xm: Row stride of ``X`` in elements.
            stride_ym: Row stride of both outputs in elements.
            n_cols: Number of valid columns per row. The column index is
                packed into the low 16 bits of each candidate, so it must fit
                there; the caller in this file only takes this path when the
                row length is at most 65535.
            K: Number of results per row.
            K_PAD: Number of candidate slots that are sorted; lanes >= K are
                not written out.
            BLOCK_N: Number of columns processed per tile.
            RADIX_BITS: Number of key bits examined per digit pass.
        """
        pid = tl.program_id(0)
        x_dtype = X.dtype.element_ty
        x_nbits: tl.constexpr = x_dtype.primitive_bitwidth
        # The packed (key, index) word needs 16 extra bits for the index.
        if x_nbits < 16:
            y_nbits: tl.constexpr = 32
        else:
            y_nbits: tl.constexpr = x_nbits * 2
        x_utype = tl.dtype(f"uint{x_nbits}")
        x_ultype = tl.dtype(f"uint{y_nbits}")

        RADIX_SIZE: tl.constexpr = 1 << RADIX_BITS
        RADIX_MASK: tl.constexpr = RADIX_SIZE - 1
        bins = tl.arange(0, RADIX_SIZE)
        one = tl.full([BLOCK_N], 1, tl.int32)

        # `desired` accumulates the digits of the threshold key found so far;
        # `desired_mask` marks which bit positions of it are already fixed.
        desired = tl.full((), 0, dtype=x_utype)
        desired_mask = tl.full((), 0, dtype=x_utype)
        # Rank still to be located among the keys that match the fixed prefix.
        k_to_find = tl.full((), K, dtype=tl.int32)
        n_tiles = tl.cdiv(n_cols, BLOCK_N)

        # Per-digit histogram, allocated with scope=tle_gpu.gpu.smem
        # (presumably CTA shared memory -- confirm against the TLE docs).
        smem_counts = tle_gpu.gpu.alloc(
            [RADIX_SIZE],
            dtype=tl.int32,
            layout=None,
            scope=tle_gpu.gpu.smem,
            nv_mma_shared_layout=False,
        )
        smem_count_ptrs = tle_gpu.gpu.local_ptr(smem_counts, (bins,))

        # Walk the key from its most significant digit down to bit 0.
        for digit_pos in tl.static_range(x_nbits - RADIX_BITS, -1, -RADIX_BITS):
            tl.store(smem_count_ptrs, tl.zeros([RADIX_SIZE], dtype=tl.int32))
            for t in tl.range(0, n_tiles):
                offs_n = t * BLOCK_N + tl.arange(0, BLOCK_N)
                mask_n = offs_n < n_cols
                x_ptrs = X + pid * stride_xm + offs_n
                x = tl.load(x_ptrs, mask=mask_n, other=float("-inf"))
                x_bits = x.to(x_utype, bitcast=True)
                x_key = _fpval_to_key_with_nan(x, x_bits)
                # Only keys that agree with the digits fixed so far are counted.
                matches = (x_key & desired_mask) == desired
                digit = ((x_key >> digit_pos) & RADIX_MASK).to(tl.int32)
                valid = mask_n & matches
                count_addrs = tle_gpu.gpu.local_ptr(smem_counts, (digit,))
                tl.atomic_add(count_addrs, one, mask=valid, sem="relaxed", scope="cta")

            counts = tl.load(smem_count_ptrs)

            # After the reversed cumulative sum, entry d holds the number of
            # matching keys whose current digit is >= d.
            cumsum_desc = tl.cumsum(counts, axis=0, reverse=True)
            tl.store(smem_count_ptrs, cumsum_desc)

            selected_scalar = 0
            counts_gt_scalar = 0
            found = 0
            # Scan digits from largest to smallest for the one that contains
            # the k_to_find-th largest key.
            for rev in tl.static_range(RADIX_SIZE):
                d = RADIX_SIZE - 1 - rev
                cum_d = tl.load(tle_gpu.gpu.local_ptr(smem_counts, (d,)))
                if d + 1 < RADIX_SIZE:
                    cum_next = tl.load(tle_gpu.gpu.local_ptr(smem_counts, (d + 1,)))
                else:
                    cum_next = 0
                # Digit d is the one when at least k_to_find keys have a digit
                # >= d but fewer than k_to_find have a digit > d.
                take = (found == 0) & (cum_d >= k_to_find) & (cum_next < k_to_find)
                selected_scalar = tl.where(take, d, selected_scalar)
                counts_gt_scalar = tl.where(take, cum_next, counts_gt_scalar)
                found = tl.where(take, 1, found)

            # Fix this digit in the prefix and drop the keys that are already
            # known to be larger from the remaining rank.
            selected_u = selected_scalar.to(x_utype)
            desired = desired | (selected_u << digit_pos)
            desired_mask = desired_mask | (
                tl.full((), RADIX_MASK, dtype=x_utype) << digit_pos
            )
            k_to_find = k_to_find - counts_gt_scalar

        # All digits are fixed: `desired` is the key of the K-th largest element.
        thr_key = desired

        # Candidate slots start out as the key of -inf with a zero index field.
        min_val = tl.full((), float("-inf"), tl.float32).to(x_dtype)
        min_bits = min_val.to(x_utype, bitcast=True)
        min_key = _fpval_to_key_with_nan(min_val, min_bits)
        min_packed = min_key.to(x_ultype) << 16
        offs_k = tl.arange(0, K_PAD)

        smem_selected = tle_gpu.gpu.alloc(
            [K_PAD],
            dtype=x_ultype,
            layout=None,
            scope=tle_gpu.gpu.smem,
            nv_mma_shared_layout=False,
        )
        smem_selected_ptrs = tle_gpu.gpu.local_ptr(smem_selected, (offs_k,))
        tl.store(smem_selected_ptrs, tl.full([K_PAD], min_packed, dtype=x_ultype))

        # Single counter handing out slots of smem_selected; every lane of
        # write_count_ptrs points at element 0.
        smem_write_count = tle_gpu.gpu.alloc(
            [1],
            dtype=tl.int32,
            layout=None,
            scope=tle_gpu.gpu.smem,
            nv_mma_shared_layout=False,
        )
        tl.store(tle_gpu.gpu.local_ptr(smem_write_count, (0,)), 0)
        write_count_ptrs = tle_gpu.gpu.local_ptr(
            smem_write_count, (tl.zeros([BLOCK_N], dtype=tl.int32),)
        )

        # Pass 1: collect every element strictly above the threshold.
        for t in tl.range(0, n_tiles):
            offs_n = t * BLOCK_N + tl.arange(0, BLOCK_N)
            mask_n = offs_n < n_cols
            x_ptrs = X + pid * stride_xm + offs_n
            x = tl.load(x_ptrs, mask=mask_n, other=float("-inf"))
            x_bits = x.to(x_utype, bitcast=True)
            x_key = _fpval_to_key_with_nan(x, x_bits)
            # Storing n_cols - column makes a smaller column compare larger, so
            # the descending sort below orders equal keys by ascending column.
            idx_key = (n_cols - offs_n).to(x_ultype)
            packed = (x_key.to(x_ultype) << 16) | idx_key
            take_gt = mask_n & (x_key > thr_key)
            pos = tl.atomic_add(
                write_count_ptrs, one, mask=take_gt, sem="relaxed", scope="cta"
            )
            write_mask = take_gt & (pos < K_PAD)
            dst_ptrs = tle_gpu.gpu.local_ptr(smem_selected, (pos.to(tl.int32),))
            tl.store(dst_ptrs, packed, mask=write_mask)

        # Pass 2: fill the remaining slots with elements equal to the
        # threshold; writes past K_PAD slots are masked off.
        for t in tl.range(0, n_tiles):
            offs_n = t * BLOCK_N + tl.arange(0, BLOCK_N)
            mask_n = offs_n < n_cols
            x_ptrs = X + pid * stride_xm + offs_n
            x = tl.load(x_ptrs, mask=mask_n, other=float("-inf"))
            x_bits = x.to(x_utype, bitcast=True)
            x_key = _fpval_to_key_with_nan(x, x_bits)
            idx_key = (n_cols - offs_n).to(x_ultype)
            packed = (x_key.to(x_ultype) << 16) | idx_key
            take_eq = mask_n & (x_key == thr_key)
            pos = tl.atomic_add(
                write_count_ptrs, one, mask=take_eq, sem="relaxed", scope="cta"
            )
            write_mask = take_eq & (pos < K_PAD)
            dst_ptrs = tle_gpu.gpu.local_ptr(smem_selected, (pos.to(tl.int32),))
            tl.store(dst_ptrs, packed, mask=write_mask)

        selected_packed = tl.load(smem_selected_ptrs)

        # Sort the packed candidates, then split each word back into the
        # column index (low 16 bits) and the value key (remaining high bits).
        topk = tl.sort(selected_packed, dim=0, descending=True)
        idx_mask = tl.full(topk.shape, (1 << 16) - 1, dtype=topk.dtype)
        idx_raw = (topk & idx_mask).to(tl.uint32)
        y_indices = (n_cols - idx_raw.to(tl.int32)).to(tl.int32)
        y_values_raw = (topk >> 16).to(x_utype)
        y_values = _key_to_fpval(y_values_raw).to(x_dtype, bitcast=True)

        # Only the first K of the K_PAD sorted lanes are results.
        mask_k = offs_k < K
        yv_ptrs = Yv + pid * stride_ym + offs_k
        yi_ptrs = Yi + pid * stride_ym + offs_k
        tl.store(yv_ptrs, y_values, mask=mask_k)
        tl.store(yi_ptrs, y_indices, mask=mask_k)

448 

449 

def topk(x, k, dim=-1, largest=True, sorted=True):
    """Return the ``k`` largest (or smallest) elements of ``x`` along its last axis.

    Args:
        x: Input tensor.
        k: Number of elements to return per row.
        dim: Axis to search; only the last axis is supported.
        largest: True returns the largest elements, False the smallest.
        sorted: Accepted for API compatibility. It only influences whether the
            radix-select path may be used; the fallback path always returns
            sorted results.

    Returns:
        Tuple ``(values, indices)`` of shape ``x.shape[:-1] + (k,)``.
        ``values`` has the dtype of ``x`` and ``indices`` is ``torch.int64``.

    Raises:
        AssertionError: If ``dim`` does not refer to the last axis.
    """
    logger.debug("GEMS TOPK")
    # Normalize a negative dim so it can be compared against the last axis.
    if dim < 0:
        dim = dim + x.ndim

    assert dim == x.ndim - 1, "Currently only support topk in last dimension"

    # Early return for k=0 to avoid a Triton kernel compilation error:
    # tl.arange(0, BLOCK_SIZE) requires BLOCK_SIZE > 0, and with k=0 the
    # stage-2 element count (and therefore BLOCK_SIZE) would be 0.
    if k == 0:
        out_shape = list(x.shape[:-1]) + [0]
        return (
            torch.empty(out_shape, device=x.device, dtype=x.dtype),
            torch.empty(out_shape, device=x.device, dtype=torch.int64),
        )

    descending = bool(largest)

    topk_elem_cnt = x.shape[dim]
    batch_size = math.prod(x.shape) // topk_elem_cnt

    # Radix-select path. The kernel packs the column index into 16 bits, hence
    # the limit on the row length.
    if (
        HAS_TLE
        and sorted
        and descending
        and x.is_cuda
        and x.dtype in (torch.float16, torch.float32, torch.bfloat16)
        and k >= 8
        and topk_elem_cnt <= 65535
        and triton.next_power_of_2(k) <= 1024
    ):
        k_pad = triton.next_power_of_2(k)
        out_shape = x.shape[:-1] + (k,)
        y_vals = torch.empty(out_shape, device=x.device, dtype=x.dtype)
        y_idx = torch.empty(out_shape, device=x.device, dtype=torch.int32)
        block_n_radix = max(k_pad, min(512, triton.next_power_of_2(topk_elem_cnt)))
        block_n_radix = min(block_n_radix, 1024)

        x_2d = x.reshape(batch_size, topk_elem_cnt)
        y_vals_2d = y_vals.reshape(batch_size, k)
        y_idx_2d = y_idx.reshape(batch_size, k)
        with torch_device_fn.device(x.device):
            topk_kernel_radix_tle[(batch_size,)](
                x_2d,
                y_vals_2d,
                y_idx_2d,
                x_2d.stride(0),
                y_vals_2d.stride(0),
                topk_elem_cnt,
                K=k,
                K_PAD=k_pad,
                BLOCK_N=block_n_radix,
                RADIX_BITS=4,
                num_warps=4,
                num_stages=1,
            )
        return (y_vals, y_idx.to(torch.int64))

    # The fallback kernels address `x` as densely packed rows of
    # topk_elem_cnt elements, so a strided view must be materialized first.
    x = x.contiguous()

    # Note(Zhengzekang): Maybe we should add a heuristic search in selecting a proper chunk size.
    if topk_elem_cnt < 1024:
        chunk_size = 256
    else:
        chunk_size = 1024

    # Note(Zhengzekang): We should promise chunk_size is larger than k.
    if chunk_size < k:
        chunk_size = triton.next_power_of_2(k)

    chunk_num = triton.cdiv(topk_elem_cnt, chunk_size)

    stage1_out = torch.empty(batch_size * chunk_num * k, device=x.device, dtype=x.dtype)
    stage1_out_idx = torch.empty(
        batch_size * chunk_num * k, device=x.device, dtype=torch.int64
    )

    out_shape = x.shape[:-1] + (k,)
    stage2_out = torch.empty(out_shape, device=x.device, dtype=x.dtype)
    stage2_out_idx = torch.empty(out_shape, device=x.device, dtype=torch.int64)

    with torch_device_fn.device(x.device):
        topk_stage1_kernel[
            batch_size,
            chunk_num,
        ](
            stage1_out,  # pointer to the output
            stage1_out_idx,  # pointer to the output
            x,  # pointer to the input
            k,
            topk_elem_cnt,
            chunk_size,
            descending,
        )
    stage2_elem_cnt = chunk_num * k

    candidate_vals = stage1_out.view(batch_size, stage2_elem_cnt)
    candidate_indices = stage1_out_idx.view(batch_size, stage2_elem_cnt)

    # [sunrise fix] The stage-2 bitonic sort produces incorrect results once it
    # spills into the multi-warp path (BLOCK_SIZE >= 512). Reduce the candidate
    # set with additional stage-1 passes until the final sort stays within 256
    # lanes.
    #
    # 1. Launching topk_stage2_kernel with num_warps=1 works around a
    #    shared-memory path bug in the ptpu backend's multi-warp reduction. The
    #    root cause is that the cross-warp reduction path in the backend's
    #    ReduceOpToLLVM.cpp computes the linearized shared-memory offset
    #    incorrectly; the official tl.sort() fails the same way for N >= 512.
    # 2. The problem is not limited to the generic inter-warp reduce lowering:
    #    at least in topk's complete bitonic-sort path some other multi-warp
    #    interaction also goes wrong. The most promising next step is not to
    #    keep patching the generic ReduceOp, but to reproduce one specific
    #    stage of topk_stage2_kernel precisely and inspect the TTIR/LLVM IR of
    #    the last few _compare_and_swap rounds.
    safe_stage2_elem_cnt = 256
    # Each round turns m groups of k candidates into ceil(m * k / chunk) groups.
    # A chunk of at least 2 * k candidates therefore at least halves m, which
    # guarantees that the loop terminates. (A chunk of 256 made no progress for
    # 128 < k < 256 with two groups left.) The value is a power of two because
    # the kernel uses it as a tl.arange extent.
    reduction_chunk_size = max(safe_stage2_elem_cnt, triton.next_power_of_2(2 * k))
    while k <= safe_stage2_elem_cnt and stage2_elem_cnt > safe_stage2_elem_cnt:
        round_chunk_num = triton.cdiv(stage2_elem_cnt, reduction_chunk_size)
        reduced_elem_cnt = round_chunk_num * k

        reduced_vals = torch.empty(
            batch_size * reduced_elem_cnt, device=x.device, dtype=x.dtype
        )
        reduced_local_indices = torch.empty(
            batch_size * reduced_elem_cnt, device=x.device, dtype=torch.int64
        )

        with torch_device_fn.device(x.device):
            topk_stage1_kernel[
                batch_size,
                round_chunk_num,
            ](
                reduced_vals,
                reduced_local_indices,
                candidate_vals,
                k,
                stage2_elem_cnt,
                reduction_chunk_size,
                descending,
            )

        # The kernel reports positions inside the current candidate list. When
        # a chunk holds fewer than k candidates it also selects padding lanes,
        # whose positions lie past the end of the list; clamp them so that the
        # gather stays in bounds. Their values are the padding sentinel.
        local_positions = reduced_local_indices.view(
            batch_size, reduced_elem_cnt
        ).clamp(max=stage2_elem_cnt - 1)
        # Translate positions in the candidate list back to indices into `x`.
        candidate_indices = torch.gather(
            candidate_indices, 1, local_positions
        ).contiguous()
        candidate_vals = reduced_vals.view(batch_size, reduced_elem_cnt)
        stage2_elem_cnt = reduced_elem_cnt

    BLOCK_SIZE = triton.next_power_of_2(stage2_elem_cnt)

    with torch_device_fn.device(x.device):
        topk_stage2_kernel[batch_size,](
            stage2_out,
            stage2_out_idx,
            candidate_vals,
            candidate_indices,
            dim,
            k,
            stage2_elem_cnt,
            BLOCK_SIZE,
            descending,
        )

    return (stage2_out, stage2_out_idx)