Coverage for src/flag_gems/fused/DSA/bin_topk.py: 6%

540 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5from flag_gems.utils.triton_version_utils import has_triton_tle 

6 

# Optional Triton Language Extensions (TLE). The experimental module is only
# looked up when has_triton_tle(3, 6, 0) reports a suitable Triton build; any
# ImportError leaves the TLE path disabled (tle is None, HAS_TLE is False).
tle = None
HAS_TLE = False
if has_triton_tle(3, 6, 0):
    try:
        import triton.experimental.tle.language as tle
    except ImportError:
        tle = None
    else:
        HAS_TLE = True


# Fixed launch parameters for the TLE kernels.
TLE_FIXED_BLOCK_SIZE = 1024
TLE_FIXED_NUM_WARPS = TLE_FIXED_BLOCK_SIZE // 32
TLE_FIXED_NUM_STAGES = 1
# NOTE(review): presumably the sequence length above which the radix-based
# final selection is used; the consumer is not visible in this chunk.
TLE_RADIX_FINAL_SEQ_LEN_THRESHOLD = 12288

24 

25 

@triton.jit
def convert_to_uint16(x):
    """Return the most significant byte of the order-preserving key of ``x``.

    The byte comes from ``convert_to_uint32`` (so larger floats never give a
    smaller byte), lies in [0, 255] and is returned as uint16.
    """
    order_key = convert_to_uint32(x)
    top_byte = (order_key >> 24) & 0xFF
    return top_byte.to(tl.uint16)

30 

31 

@triton.jit
def convert_to_uint32(x):
    """Map a 32-bit float ``x`` to a uint32 key whose unsigned order follows the float order.

    Negative values get every bit flipped; other values get the sign bit set,
    so larger floats produce larger keys. Both +0.0 and -0.0 map to 0x80000000.
    """
    raw = x.to(tl.uint32, bitcast=True)
    all_ones = tl.full(raw.shape, 0xFFFFFFFF, tl.uint32)
    sign_bit = tl.full(raw.shape, 0x80000000, tl.uint32)
    flipped = ~raw & all_ones
    with_sign = raw | sign_bit
    return tl.where(x < 0, flipped, with_sign)

41 

42 

@triton.autotune(
    configs=[
        triton.Config({"BS": 32, "BSS": 32}, num_stages=1, num_warps=1),
        triton.Config({"BS": 64, "BSS": 32}, num_stages=1, num_warps=1),
        triton.Config({"BS": 128, "BSS": 32}, num_stages=2, num_warps=1),
        triton.Config({"BS": 256, "BSS": 32}, num_stages=2, num_warps=2),
        triton.Config({"BS": 512, "BSS": 64}, num_stages=2, num_warps=2),
        triton.Config({"BS": 1024, "BSS": 256}, num_stages=2, num_warps=2),
        triton.Config({"BS": 2048, "BSS": 256}, num_stages=2, num_warps=4),
        triton.Config({"BS": 4096, "BSS": 512}, num_stages=3, num_warps=4),
        triton.Config({"BS": 8192, "BSS": 512}, num_stages=3, num_warps=8),
        triton.Config({"BS": 8192, "BSS": 1024}, num_stages=3, num_warps=8),
    ],
    key=["S", "K"],
)
@triton.jit
def kernel_bucket_sort_topk(  # grid: (B,) -- one program per row
    inputs,  # (B, S) Note: no H because MLA is based on MQA and MHA, not GQA
    indices,  # (B, K) topk index array
    s_input_ids,  # (B, SMEM_INPUT_SIZE) scratch: candidate indices for the next round
    starts,  # for variable length
    ends,  # for variable length
    S: tl.constexpr,  # sequence length
    K: tl.constexpr,  # k of topk
    HISTOGRAM_SIZE: tl.constexpr,
    SMEM_INPUT_SIZE: tl.constexpr,  # capacity of one row of s_input_ids
    BS: tl.constexpr,  # block size of S
    BSS: tl.constexpr,  # block size of SMEM_INPUT
):
    """Write the column indices of the K largest values of one row into ``indices``.

    Only positions in ``[starts[row], ends[row])`` take part. Pass 1 buckets
    every value by the top byte of its order-preserving key
    (``convert_to_uint16``). It finds the bucket holding the K-th largest value
    and writes the indices of all higher buckets to ``indices``. It stashes the
    indices of the threshold bucket in the row's ``s_input_ids`` scratch.
    Rounds on key bytes 2, 3 and 4 then refine the stashed candidates the same
    way. Any slots still free afterwards are filled with the first remaining
    candidates (exact ties on the 32-bit key).

    Indices are written bucket by bucket, not sorted by value. Slots that
    cannot be filled (row shorter than K) are left untouched. At most
    SMEM_INPUT_SIZE candidates are kept per row. If the threshold bucket holds
    more, the extra ones are dropped rather than written out of bounds.
    """
    # One program handles one row.
    i_b = tl.program_id(0)

    # Per-row base pointers.
    s_base = inputs + i_b * S
    indices_base = indices + i_b * K
    s_input_ids_base = s_input_ids + i_b * SMEM_INPUT_SIZE

    # Support variable length.
    l_start_idx = tl.load(starts + i_b).to(tl.int32)
    l_end_idx = tl.load(ends + i_b).to(tl.int32)

    # Number of output slots still to be filled.
    l_new_topk = K

    # bin_ids[b] = b; mv_idx[b] = b + 1 (wrapping) to read the next suffix sum.
    bin_ids = tl.arange(0, HISTOGRAM_SIZE)
    mv_idx = tl.arange(1, HISTOGRAM_SIZE + 1) % HISTOGRAM_SIZE

    # ---- Pass 1: histogram of the top key byte over the row ----
    s_histogram = tl.zeros([HISTOGRAM_SIZE], dtype=tl.int32)
    TS = tl.cdiv(S, BS)
    for s in range(TS):
        input_idx = s * BS + tl.arange(0, BS)
        input_mask = (
            (input_idx < l_end_idx) & (input_idx >= l_start_idx) & (input_idx < S)
        )
        # Out-of-range lanes read -inf: convert_to_uint32(-inf) = 0x007FFFFF,
        # whose top byte 0x00 is the lowest bin. This avoids masking the
        # histogram itself.
        vals = tl.load(s_base + input_idx, input_mask, other=float("-inf")).to(
            tl.float32
        )
        inval_int16 = convert_to_uint16(vals)
        s_histogram += inval_int16.to(tl.int32).histogram(HISTOGRAM_SIZE)

    s_histogram = s_histogram.cumsum(0, reverse=True)  # suffix sum: count in bins >= b

    # Threshold bin: the bin b whose suffix count exceeds the free slots while
    # the suffix count of b + 1 does not (bin HISTOGRAM_SIZE - 1 has no successor).
    cond = (s_histogram > l_new_topk) & (
        (s_histogram.gather(mv_idx, 0) <= l_new_topk) | (mv_idx == 0)
    )
    l_threshold_bin_id = cond.argmax(0)

    # Everything strictly above the threshold bin is in the top-K.
    l_new_topk -= tl.where(bin_ids == l_threshold_bin_id + 1, s_histogram, 0).max(0)

    n_selected = 0  # indices written to indices_base so far
    thre_bin_sum = 0  # candidates (threshold-bin members) seen so far
    for s in range(TS):
        input_idx = s * BS + tl.arange(0, BS)
        input_mask = (
            (input_idx < l_end_idx) & (input_idx >= l_start_idx) & (input_idx < S)
        )
        vals = tl.load(s_base + input_idx, input_mask, other=float("-inf")).to(
            tl.float32
        )
        bucket = convert_to_uint16(vals).to(tl.int32)

        # Padding lanes sit in bin 0, so they are never "over" the threshold.
        # They would match it when the threshold itself is bin 0 (e.g. the row
        # holds at most K in-range values), so exclude them from candidates.
        over_thre = bucket > l_threshold_bin_id
        eq_thre = (bucket == l_threshold_bin_id) & input_mask
        cur_sum = over_thre.to(tl.int32).sum(-1)
        thre_bin_cur_sum = eq_thre.to(tl.int32).sum(-1)

        topk_idx = over_thre.to(tl.int32).cumsum(-1)
        thre_bin_idx = eq_thre.to(tl.int32).cumsum(-1)

        # Never write past this row's scratch slice.
        eq_store = eq_thre & (thre_bin_sum + thre_bin_idx <= SMEM_INPUT_SIZE)

        concat_mask = tl.cat(over_thre, eq_store, True)
        concat_input = tl.cat(input_idx, input_idx, True)
        concat_pointer_matrix = tl.cat(
            indices_base + n_selected + topk_idx - 1,
            s_input_ids_base + thre_bin_sum + thre_bin_idx - 1,
            True,
        )
        tl.store(concat_pointer_matrix, concat_input, mask=concat_mask)

        thre_bin_sum += thre_bin_cur_sum
        n_selected += cur_sum
    # Only the stored candidates can be read back.
    thre_bin_sum = tl.minimum(thre_bin_sum, SMEM_INPUT_SIZE)

    # ---- Refinement rounds on key bytes 2, 3 and 4 of the candidates ----
    # All candidates share the top byte used by pass 1, so refinement starts
    # at the second byte (shift 16).
    # Lanes past the candidate count are padded with 0.0. Its key 0x80000000
    # has 0x00 in every lower byte, so padding lands in the lowest bin. -inf
    # would be wrong here: its key 0x007FFFFF has lower bytes 0x7F, 0xFF, 0xFF
    # and could rank above real candidates.
    radix_round = 1
    while radix_round < 4 and l_new_topk > 0:
        shift = 24 - radix_round * 8
        ss = tl.cdiv(thre_bin_sum, BSS)
        s_histogram = tl.zeros([HISTOGRAM_SIZE], dtype=tl.int32)
        for s in range(ss):
            s_input_idx = s * BSS + tl.arange(0, BSS)
            s_input_idx_mask = s_input_idx < thre_bin_sum
            input_idx = tl.load(
                s_input_ids_base + s_input_idx, s_input_idx_mask, other=-1
            )
            s_input = tl.load(s_base + input_idx, s_input_idx_mask, other=0.0).to(
                tl.float32
            )
            # Keep only the 8 bits of the current byte.
            inval_int32 = (convert_to_uint32(s_input) >> shift) & 0xFF
            s_histogram += inval_int32.to(tl.int32).histogram(HISTOGRAM_SIZE)
        s_histogram = s_histogram.cumsum(0, reverse=True)  # suffix sum
        cond = (s_histogram > l_new_topk) & (
            (s_histogram.gather(mv_idx, 0) <= l_new_topk) | (mv_idx == 0)
        )
        l_threshold_bin_id = cond.argmax(0)
        l_new_topk -= tl.where(
            bin_ids == l_threshold_bin_id + 1, s_histogram, 0
        ).max(0)
        thre_bin_sum, old_thre_bin_sum = 0, thre_bin_sum

        # Compact the surviving candidates to the front of the scratch row.
        # Write positions never exceed read positions, so this is done in place.
        for s in range(ss):
            s_input_idx = s * BSS + tl.arange(0, BSS)
            s_input_idx_mask = s_input_idx < old_thre_bin_sum
            input_idx = tl.load(
                s_input_ids_base + s_input_idx, s_input_idx_mask, other=-1
            )
            s_input = tl.load(s_base + input_idx, s_input_idx_mask, other=0.0).to(
                tl.float32
            )
            bucket = ((convert_to_uint32(s_input) >> shift) & 0xFF).to(tl.int32)

            over_thre = bucket > l_threshold_bin_id
            # Padding lanes (bin 0) must not be re-stashed as candidates.
            eq_thre = (bucket == l_threshold_bin_id) & s_input_idx_mask
            cur_sum = over_thre.to(tl.int32).sum(-1)
            thre_bin_cur_sum = eq_thre.to(tl.int32).sum(-1)

            topk_idx = over_thre.to(tl.int32).cumsum(-1)
            thre_bin_idx = eq_thre.to(tl.int32).cumsum(-1)

            concat_mask = tl.cat(over_thre, eq_thre, True)
            concat_input = tl.cat(input_idx, input_idx, True)
            concat_pointer_matrix = tl.cat(
                indices_base + n_selected + topk_idx - 1,
                s_input_ids_base + thre_bin_sum + thre_bin_idx - 1,
                True,
            )
            tl.store(concat_pointer_matrix, concat_input, mask=concat_mask)

            thre_bin_sum += thre_bin_cur_sum
            n_selected += cur_sum

        radix_round += 1

    # ---- Remaining slots: exact ties, take the first stashed candidates ----
    if l_new_topk > 0:
        n_fill = tl.minimum(l_new_topk, thre_bin_sum)
        ss = tl.cdiv(n_fill, BSS)
        for s in range(ss):
            s_input_idx = s * BSS + tl.arange(0, BSS)
            s_input_idx_mask = s_input_idx < n_fill
            input_idx = tl.load(
                s_input_ids_base + s_input_idx, s_input_idx_mask, other=-1
            )
            tl.store(
                indices_base + n_selected + s_input_idx,
                input_idx,
                mask=s_input_idx_mask,
            )

249 

250 

def bucket_sort_topk_triton(inputs, starts, ends, topk):
    """Return the column indices of the ``topk`` largest values of each row.

    Args:
        inputs: 2-D tensor ``(B, S)``. Row ``b`` is only searched inside
            ``[starts[b], ends[b])``.
        starts, ends: per-row bounds, read once per row by the kernel.
            Presumably 1-D integer tensors of length ``B`` on the same device
            as ``inputs`` (not checked here).
        topk: number of indices to return per row.

    Returns:
        int32 tensor of shape ``(B, topk)``, in no particular value order.
        Slots the kernel does not write keep their initial value -1.
    """
    B, S = inputs.shape
    K = topk
    HISTOGRAM_SIZE = 256
    # The scratch row holds the candidates of the threshold bucket. When
    # values are clustered, that bucket can hold almost the whole row, so a
    # fixed 4096-entry row is not enough. Size it to the row length, at the
    # cost of B * S int32 of scratch for long rows.
    SMEM_INPUT_SIZE = max(4096, S)
    indices = torch.full((B, topk), -1, dtype=torch.int32, device=inputs.device)
    s_input_idx = torch.zeros(
        B, SMEM_INPUT_SIZE, dtype=torch.int32, device=inputs.device
    )
    grid = (B,)
    kernel_bucket_sort_topk[grid](
        inputs,
        indices,
        s_input_idx,
        starts,
        ends,
        S,
        K,
        HISTOGRAM_SIZE,
        SMEM_INPUT_SIZE,
    )
    return indices

273 

274 

@triton.jit
def _convert_to_trt_uint32(x):
    """Map a 32-bit float ``x`` to a uint32 key whose order is the REVERSE of the float order.

    Negative values keep their bits. Non-negative values have their low 31
    bits inverted and the sign bit cleared. Ascending keys therefore mean
    descending values (the largest float gets the smallest key). Note that
    +0.0 maps to 0x7FFFFFFF and -0.0 to 0x80000000.
    """
    raw = x.to(tl.uint32, bitcast=True)
    is_negative = (raw & tl.full(raw.shape, 0x80000000, tl.uint32)) != 0
    flipped_magnitude = (~raw) & tl.full(raw.shape, 0x7FFFFFFF, tl.uint32)
    return tl.where(is_negative, raw, flipped_magnitude)

282 

283 

@triton.jit
def _convert_to_trt_uint16_hi11(x):
    """Return the 11 high bits of the reverse-ordered fp16 key of ``x`` as int32.

    ``x`` is rounded to float16 and mapped like ``_convert_to_trt_uint32`` but
    on 16 bits. Negative values keep their bits; others get the low 15 bits
    inverted. Shifting right by 5 keeps a digit in [0, 2047]; smaller digits
    mean larger values.
    """
    half_bits = x.to(tl.float16).to(tl.uint16, bitcast=True)
    is_negative = (half_bits & tl.full(half_bits.shape, 0x8000, tl.uint16)) != 0
    flipped = (~half_bits) & tl.full(half_bits.shape, 0x7FFF, tl.uint16)
    key16 = tl.where(is_negative, half_bits, flipped)
    return (key16 >> 5).to(tl.int32)

293 

294 

@triton.jit
def _tle_process_histogram_step(
    row_ptr,
    stride_xn,
    row_start,
    row_end,
    seq_len,
    step_idx: tl.constexpr,
    logit_pattern,
    s_step_thresholds_ptr,
    found_topk_values,
    hist_base_ptr,
    s_out_indices_ptr,
    s_final_cnt_ptr,
    s_found_topk_values_ptr,
    s_threshold_bin_idx_ptr,
    s_final_bin_size_ptr,
    assume_aligned,
    TOPK: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Run one radix-select step of the TLE top-K kernel over one row.

    Keys come from ``_convert_to_trt_uint32`` (smaller key = larger value), so
    "digit < threshold" selects values ranked above the threshold bin.

    * step 0: 11-bit digit from ``_convert_to_trt_uint16_hi11`` (fp16 view);
    * step 1: key bits [21, 32);
    * step 2: key bits [10, 21), only for keys whose bits [21, 32) equal the
      step-1 threshold;
    * step 3: key bits [0, 10), only for keys whose bits [10, 32) equal the
      step-1/step-2 thresholds.

    The step builds a histogram at ``hist_base_ptr``. It scans the histogram,
    starting from ``found_topk_values`` already-emitted results, for the bin
    holding the TOPK-th element and stores that bin in
    ``s_step_thresholds_ptr[step_idx]``. Every element of a lower bin is
    appended to ``s_out_indices_ptr`` through the atomic counter at
    ``s_found_topk_values_ptr``.

    For steps 0-2, a threshold bin of at most 2048 elements triggers the final
    path. Its indices and raw float bits are copied to
    ``hist_base_ptr[0:2048]`` and ``hist_base_ptr[2048:4096]`` for a final
    selection, and the step returns ``need_final_sort=True``. At step 3 the
    threshold-bin elements fill the remaining slots directly.

    Returns:
        ``(continue_to_next_step, need_final_sort, logit_pattern)``.
    """
    VEC: tl.constexpr = 4
    FINAL_SORT_ITEMS: tl.constexpr = 2048
    RADIX11_SIZE: tl.constexpr = 2048
    RADIX11_MASK: tl.constexpr = 0x7FF
    RADIX10_SIZE: tl.constexpr = 1024
    RADIX10_MASK: tl.constexpr = 0x3FF

    lane = tl.arange(0, BLOCK_SIZE)
    vec = tl.arange(0, VEC)
    ones = tl.full([BLOCK_SIZE], 1, tl.int32)
    ones_vec_2d = tl.full([BLOCK_SIZE, VEC], 1, tl.int32)
    zeros = tl.zeros([BLOCK_SIZE], dtype=tl.int32)
    zeros_vec_2d = tl.zeros([BLOCK_SIZE, VEC], dtype=tl.int32)

    # Clear the histogram: 1024 bins for the 10-bit step 3, 2048 otherwise.
    # NOTE(review): assumes BLOCK_SIZE divides the bin count. A larger
    # BLOCK_SIZE runs zero rounds and leaves the histogram uncleared.
    clear_rounds = tl.where(
        step_idx == 3,
        RADIX10_SIZE // BLOCK_SIZE,
        RADIX11_SIZE // BLOCK_SIZE,
    )
    for clear_round in tl.range(0, clear_rounds):
        clear_bins = clear_round * BLOCK_SIZE + lane
        tl.store(hist_base_ptr + clear_bins, 0)
    tl.debug_barrier()

    # Rebuild the key prefix fixed by earlier steps' thresholds.
    if step_idx == 2:
        step1_threshold = tl.load(s_step_thresholds_ptr + 1)
        logit_pattern = (step1_threshold.to(tl.uint32) & RADIX11_MASK) << 21
    elif step_idx == 3:
        step1_threshold = tl.load(s_step_thresholds_ptr + 1)
        step2_threshold = tl.load(s_step_thresholds_ptr + 2)
        logit_pattern = ((step1_threshold.to(tl.uint32) & RADIX11_MASK) << 21) | (
            (step2_threshold.to(tl.uint32) & RADIX11_MASK) << 10
        )

    n_tiles = tl.cdiv(seq_len, BLOCK_SIZE)
    n_vec_full = seq_len // (BLOCK_SIZE * VEC)
    rem_tiles = (seq_len - n_vec_full * BLOCK_SIZE * VEC) // BLOCK_SIZE

    # ---- Histogram of this step's digit ----
    if assume_aligned:
        for t in tl.range(0, n_vec_full):
            base = t * BLOCK_SIZE * VEC + lane * VEC
            offs = base[:, None] + vec[None, :]
            x_vec = tl.load(row_ptr + offs)
            key = _convert_to_trt_uint32(x_vec)
            if step_idx == 0:
                digit = _convert_to_trt_uint16_hi11(x_vec)
            elif step_idx == 1:
                digit = ((key >> 21) & RADIX11_MASK).to(tl.int32)
            elif step_idx == 2:
                digit = ((key >> 10) & RADIX11_MASK).to(tl.int32)
            else:
                digit = (key & RADIX10_MASK).to(tl.int32)

            if step_idx < 2:
                partial = tl.full([BLOCK_SIZE, VEC], True, tl.int1)
            elif step_idx == 2:
                partial = ((key ^ logit_pattern) >> 21) == 0
            else:
                partial = ((key ^ logit_pattern) >> 10) == 0

            tl.atomic_add(
                hist_base_ptr + digit,
                ones_vec_2d,
                mask=partial,
                sem="relaxed",
                scope="cta",
            )

        for t in tl.range(0, rem_tiles):
            offs = (n_vec_full * VEC + t) * BLOCK_SIZE + lane
            x = tl.load(row_ptr + offs)
            key = _convert_to_trt_uint32(x)
            if step_idx == 0:
                digit = _convert_to_trt_uint16_hi11(x)
            elif step_idx == 1:
                digit = ((key >> 21) & RADIX11_MASK).to(tl.int32)
            elif step_idx == 2:
                digit = ((key >> 10) & RADIX11_MASK).to(tl.int32)
            else:
                digit = (key & RADIX10_MASK).to(tl.int32)

            if step_idx < 2:
                partial = tl.full([BLOCK_SIZE], True, tl.int1)
            elif step_idx == 2:
                partial = ((key ^ logit_pattern) >> 21) == 0
            else:
                partial = ((key ^ logit_pattern) >> 10) == 0

            tl.atomic_add(
                hist_base_ptr + digit,
                ones,
                mask=partial,
                sem="relaxed",
                scope="cta",
            )
    else:
        for t in tl.range(0, n_tiles):
            offs = t * BLOCK_SIZE + lane
            in_range = (offs < seq_len) & (offs >= row_start) & (offs < row_end)
            x = tl.load(row_ptr + offs * stride_xn, mask=in_range, other=float("-inf"))
            key = _convert_to_trt_uint32(x)
            if step_idx == 0:
                digit = _convert_to_trt_uint16_hi11(x)
            elif step_idx == 1:
                digit = ((key >> 21) & RADIX11_MASK).to(tl.int32)
            elif step_idx == 2:
                digit = ((key >> 10) & RADIX11_MASK).to(tl.int32)
            else:
                digit = (key & RADIX10_MASK).to(tl.int32)

            if step_idx < 2:
                partial = in_range
            elif step_idx == 2:
                partial = in_range & (((key ^ logit_pattern) >> 21) == 0)
            else:
                partial = in_range & (((key ^ logit_pattern) >> 10) == 0)

            tl.atomic_add(
                hist_base_ptr + digit,
                ones,
                mask=partial,
                sem="relaxed",
                scope="cta",
            )
    tl.debug_barrier()

    # ---- Find the bin holding the TOPK-th element (exclusive prefix scan) ----
    tl.store(s_threshold_bin_idx_ptr, -1)
    tl.store(s_final_bin_size_ptr, 0)
    threshold_bin_ptrs = s_threshold_bin_idx_ptr + zeros
    final_bin_size_ptrs = s_final_bin_size_ptr + zeros
    last_value = found_topk_values
    threshold_found = False
    threshold_rounds = tl.where(
        step_idx == 3,
        RADIX10_SIZE // BLOCK_SIZE,
        RADIX11_SIZE // BLOCK_SIZE,
    )
    for round_idx in tl.range(0, threshold_rounds):
        if not threshold_found:
            bins = round_idx * BLOCK_SIZE + lane
            counts = tl.load(hist_base_ptr + bins)
            prefix_sum, counts_total = tle.cumsum(counts, axis=0, reverse=False)
            prefix_sum = prefix_sum + last_value
            total_sum = last_value + counts_total
            next_prefix_sum = prefix_sum + counts
            threshold_mask = (prefix_sum < TOPK) & (next_prefix_sum >= TOPK)
            threshold_bin = bins
            threshold_bin_size = next_prefix_sum - prefix_sum
            tl.store(threshold_bin_ptrs, threshold_bin, mask=threshold_mask)
            tl.store(final_bin_size_ptrs, threshold_bin_size, mask=threshold_mask)
            if step_idx == 3:
                # Step 3 hands out output slots for threshold-bin elements with
                # an atomic counter kept in hist[threshold_bin]. Seed it with
                # the first slot after every element of a lower bin; the raw
                # bin count would place ties at the wrong offset.
                tl.store(hist_base_ptr + bins, prefix_sum, mask=threshold_mask)
            found_round = tl.reduce_or(threshold_mask, axis=0)
            threshold_found = found_round
            last_value = total_sum
    # Publish threshold, bin size and the seeded counter to every lane.
    tl.debug_barrier()

    threshold_bin_idx = tl.load(s_threshold_bin_idx_ptr)
    final_bin_size = tl.load(s_final_bin_size_ptr)
    tl.store(s_step_thresholds_ptr + step_idx, threshold_bin_idx)

    use_final = (
        (step_idx < 3) & (threshold_bin_idx >= 0) & (final_bin_size <= FINAL_SORT_ITEMS)
    )
    if use_final:
        tl.store(s_final_cnt_ptr, 0)

    # Step 0 buckets by fp16 digits. Steps 1-3 restart on fp32 digits over the
    # whole row: step 1 does not restrict to step 0's bin, and its scan starts
    # from the found count. If step 0 committed its "above threshold" elements
    # without finishing, they would be counted and emitted twice. So step 0
    # only commits when it hands off to the final sort.

    # ---- Emit elements above the threshold; stash/fill the threshold bin ----
    found_ptrs = s_found_topk_values_ptr + zeros
    final_cnt_ptrs = s_final_cnt_ptr + zeros
    if assume_aligned:
        found_ptrs_vec_2d = s_found_topk_values_ptr + zeros_vec_2d
        final_cnt_ptrs_vec_2d = s_final_cnt_ptr + zeros_vec_2d
        for t in tl.range(0, n_vec_full):
            base = t * BLOCK_SIZE * VEC + lane * VEC
            offs = base[:, None] + vec[None, :]
            x_vec = tl.load(row_ptr + offs)
            key = _convert_to_trt_uint32(x_vec)
            if step_idx == 0:
                digit = _convert_to_trt_uint16_hi11(x_vec)
            elif step_idx == 1:
                digit = ((key >> 21) & RADIX11_MASK).to(tl.int32)
            elif step_idx == 2:
                digit = ((key >> 10) & RADIX11_MASK).to(tl.int32)
            else:
                digit = (key & RADIX10_MASK).to(tl.int32)

            if step_idx < 2:
                partial = tl.full([BLOCK_SIZE, VEC], True, tl.int1)
            elif step_idx == 2:
                partial = ((key ^ logit_pattern) >> 21) == 0
            else:
                partial = ((key ^ logit_pattern) >> 10) == 0

            take_lt = partial & (digit < threshold_bin_idx)
            if step_idx == 0:
                take_lt = take_lt & use_final
            out_pos_lt = tl.atomic_add(
                found_ptrs_vec_2d,
                ones_vec_2d,
                mask=take_lt,
                sem="relaxed",
                scope="cta",
            )
            tl.store(
                s_out_indices_ptr + out_pos_lt,
                offs.to(tl.int32),
                mask=take_lt & (out_pos_lt < TOPK),
            )

            if step_idx == 3:
                take_eq = partial & (digit == threshold_bin_idx)
                out_pos_eq = tl.atomic_add(
                    hist_base_ptr + digit,
                    ones_vec_2d,
                    mask=take_eq,
                    sem="relaxed",
                    scope="cta",
                )
                tl.store(
                    s_out_indices_ptr + out_pos_eq,
                    offs.to(tl.int32),
                    mask=take_eq & (out_pos_eq < TOPK),
                )
            elif use_final:
                take_eq_final = partial & (digit == threshold_bin_idx)
                final_pos = tl.atomic_add(
                    final_cnt_ptrs_vec_2d,
                    ones_vec_2d,
                    mask=take_eq_final,
                    sem="relaxed",
                    scope="cta",
                )
                tl.store(
                    hist_base_ptr + final_pos,
                    offs.to(tl.int32),
                    mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS),
                )
                tl.store(
                    hist_base_ptr + (FINAL_SORT_ITEMS + final_pos),
                    x_vec.to(tl.int32, bitcast=True),
                    mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS),
                )

        for t in tl.range(0, rem_tiles):
            offs = (n_vec_full * VEC + t) * BLOCK_SIZE + lane
            x = tl.load(row_ptr + offs)
            key = _convert_to_trt_uint32(x)
            if step_idx == 0:
                digit = _convert_to_trt_uint16_hi11(x)
            elif step_idx == 1:
                digit = ((key >> 21) & RADIX11_MASK).to(tl.int32)
            elif step_idx == 2:
                digit = ((key >> 10) & RADIX11_MASK).to(tl.int32)
            else:
                digit = (key & RADIX10_MASK).to(tl.int32)

            if step_idx < 2:
                partial = tl.full([BLOCK_SIZE], True, tl.int1)
            elif step_idx == 2:
                partial = ((key ^ logit_pattern) >> 21) == 0
            else:
                partial = ((key ^ logit_pattern) >> 10) == 0

            take_lt = partial & (digit < threshold_bin_idx)
            if step_idx == 0:
                take_lt = take_lt & use_final
            out_pos_lt = tl.atomic_add(
                found_ptrs,
                ones,
                mask=take_lt,
                sem="relaxed",
                scope="cta",
            )
            tl.store(
                s_out_indices_ptr + out_pos_lt,
                offs.to(tl.int32),
                mask=take_lt & (out_pos_lt < TOPK),
            )

            if step_idx == 3:
                take_eq = partial & (digit == threshold_bin_idx)
                out_pos_eq = tl.atomic_add(
                    hist_base_ptr + digit,
                    ones,
                    mask=take_eq,
                    sem="relaxed",
                    scope="cta",
                )
                tl.store(
                    s_out_indices_ptr + out_pos_eq,
                    offs.to(tl.int32),
                    mask=take_eq & (out_pos_eq < TOPK),
                )
            elif use_final:
                take_eq_final = partial & (digit == threshold_bin_idx)
                final_pos = tl.atomic_add(
                    final_cnt_ptrs,
                    ones,
                    mask=take_eq_final,
                    sem="relaxed",
                    scope="cta",
                )
                tl.store(
                    hist_base_ptr + final_pos,
                    offs.to(tl.int32),
                    mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS),
                )
                tl.store(
                    hist_base_ptr + (FINAL_SORT_ITEMS + final_pos),
                    x.to(tl.int32, bitcast=True),
                    mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS),
                )
    else:
        for t in tl.range(0, n_tiles):
            offs = t * BLOCK_SIZE + lane
            in_range = (offs < seq_len) & (offs >= row_start) & (offs < row_end)
            x = tl.load(row_ptr + offs * stride_xn, mask=in_range, other=float("-inf"))
            key = _convert_to_trt_uint32(x)
            if step_idx == 0:
                digit = _convert_to_trt_uint16_hi11(x)
            elif step_idx == 1:
                digit = ((key >> 21) & RADIX11_MASK).to(tl.int32)
            elif step_idx == 2:
                digit = ((key >> 10) & RADIX11_MASK).to(tl.int32)
            else:
                digit = (key & RADIX10_MASK).to(tl.int32)

            if step_idx < 2:
                partial = in_range
            elif step_idx == 2:
                partial = in_range & (((key ^ logit_pattern) >> 21) == 0)
            else:
                partial = in_range & (((key ^ logit_pattern) >> 10) == 0)

            take_lt = partial & (digit < threshold_bin_idx)
            if step_idx == 0:
                take_lt = take_lt & use_final
            out_pos_lt = tl.atomic_add(
                found_ptrs,
                ones,
                mask=take_lt,
                sem="relaxed",
                scope="cta",
            )
            tl.store(
                s_out_indices_ptr + out_pos_lt,
                offs.to(tl.int32),
                mask=take_lt & (out_pos_lt < TOPK),
            )

            if step_idx == 3:
                take_eq = partial & (digit == threshold_bin_idx)
                out_pos_eq = tl.atomic_add(
                    hist_base_ptr + digit,
                    ones,
                    mask=take_eq,
                    sem="relaxed",
                    scope="cta",
                )
                tl.store(
                    s_out_indices_ptr + out_pos_eq,
                    offs.to(tl.int32),
                    mask=take_eq & (out_pos_eq < TOPK),
                )
            elif use_final:
                take_eq_final = partial & (digit == threshold_bin_idx)
                final_pos = tl.atomic_add(
                    final_cnt_ptrs,
                    ones,
                    mask=take_eq_final,
                    sem="relaxed",
                    scope="cta",
                )
                tl.store(
                    hist_base_ptr + final_pos,
                    offs.to(tl.int32),
                    mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS),
                )
                tl.store(
                    hist_base_ptr + (FINAL_SORT_ITEMS + final_pos),
                    x.to(tl.int32, bitcast=True),
                    mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS),
                )

    # Steps 0-2 either hand off to the final sort or continue; step 3 is
    # always the last step and has filled every slot.
    if step_idx < 3:
        if use_final:
            need_final_sort = True
            continue_to_next_step = False
        else:
            need_final_sort = False
            continue_to_next_step = True
    else:
        tl.store(s_found_topk_values_ptr, TOPK)
        need_final_sort = False
        continue_to_next_step = False

    tl.debug_barrier()
    return continue_to_next_step, need_final_sort, logit_pattern

705 

706 

@triton.jit
def _tle_final_select_radix(
    hist_base_ptr,
    s_out_indices_ptr,
    s_final_cnt_ptr,
    s_found_topk_values_ptr,
    TOPK: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    FINAL_SORT_ITEMS: tl.constexpr,
):
    """Fill the remaining top-K slots from the final candidate set by radix select.

    Candidates are read from ``hist_base_ptr``. Their indices sit in
    ``[0, final_cnt)`` and their raw float32 bits in
    ``[FINAL_SORT_ITEMS, FINAL_SORT_ITEMS + final_cnt)``, where ``final_cnt``
    is ``min(*s_final_cnt_ptr, FINAL_SORT_ITEMS)``. Keys come from
    ``_convert_to_trt_uint32`` (smaller key = larger value).

    Four 8-bit radix passes, most significant digit first, locate a threshold
    key. Candidates with a strictly smaller key are then appended at the atomic
    counter ``*s_found_topk_values_ptr``. Candidates equal to the threshold key
    fill what is left. Finally the found counter is set to TOPK.
    """
    RADIX_BITS_FINAL: tl.constexpr = 8
    RADIX_SIZE_FINAL: tl.constexpr = 1 << RADIX_BITS_FINAL
    RADIX_MASK_FINAL: tl.constexpr = RADIX_SIZE_FINAL - 1
    DIGIT_START: tl.constexpr = 32 - RADIX_BITS_FINAL

    lane = tl.arange(0, BLOCK_SIZE)
    ones = tl.full([BLOCK_SIZE], 1, tl.int32)
    zeros = tl.zeros([BLOCK_SIZE], dtype=tl.int32)
    bins = tl.arange(0, RADIX_SIZE_FINAL)

    # 256 per-digit counters, allocated with scope=tle.gpu.smem.
    s_radix_counts = tle.gpu.alloc(
        [RADIX_SIZE_FINAL],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    radix_count_ptr = tle.gpu.local_ptr(s_radix_counts, (0,))
    radix_count_vec_ptr = tle.gpu.local_ptr(s_radix_counts, (bins,))

    # Slots already filled, number of candidates, and slots still to fill.
    base_idx = tl.load(s_found_topk_values_ptr)
    final_cnt = tl.minimum(tl.load(s_final_cnt_ptr), FINAL_SORT_ITEMS)
    remain = tl.minimum(TOPK - base_idx, final_cnt)
    if remain > 0:
        # Key prefix fixed so far, and which bits of it are fixed.
        desired = tl.zeros((), dtype=tl.uint32)
        desired_mask = tl.zeros((), dtype=tl.uint32)
        # 1-based rank of the key searched for: one past the free slots, so at
        # most `remain` keys are strictly smaller than the threshold key.
        k_to_find = remain + 1

        # Digit positions 24, 16, 8, 0 (most significant first).
        for digit_pos in tl.static_range(DIGIT_START, -1, -RADIX_BITS_FINAL):
            # NOTE(review): clears counters with `lane`, so this assumes
            # BLOCK_SIZE >= RADIX_SIZE_FINAL.
            tl.store(radix_count_ptr + lane, 0, mask=lane < RADIX_SIZE_FINAL)
            tl.debug_barrier()

            # Count the current digit over candidates that match the prefix.
            cnt_tiles = tl.cdiv(final_cnt, BLOCK_SIZE)
            for t in tl.range(0, cnt_tiles):
                pos = t * BLOCK_SIZE + lane
                valid = pos < final_cnt
                x_bits_i32 = tl.load(
                    hist_base_ptr + (FINAL_SORT_ITEMS + pos),
                    mask=valid,
                    other=0,
                )
                x = x_bits_i32.to(tl.float32, bitcast=True)
                key = _convert_to_trt_uint32(x)
                matches = (key & desired_mask) == desired
                digit = ((key >> digit_pos) & RADIX_MASK_FINAL).to(tl.int32)
                take = valid & matches
                tl.atomic_add(
                    radix_count_ptr + digit,
                    ones,
                    mask=take,
                    sem="relaxed",
                    scope="cta",
                )

            tl.debug_barrier()
            # Exclusive prefix sum over the 256 counters (next = prefix + count).
            counts = tl.load(radix_count_vec_ptr)
            prefix_sum, _ = tle.cumsum(counts, axis=0, reverse=False)
            next_prefix_sum = prefix_sum + counts
            threshold_mask = (prefix_sum < k_to_find) & (next_prefix_sum >= k_to_find)
            threshold_init = tl.full((), RADIX_SIZE_FINAL, dtype=tl.int32)
            threshold_bin = tl.min(
                tl.where(threshold_mask, bins, threshold_init), axis=0
            ).to(tl.int32)
            # No bin reaches the rank (fewer matching keys than k_to_find):
            # fall back to the last bin.
            threshold_bin = tl.where(
                threshold_bin == RADIX_SIZE_FINAL,
                RADIX_SIZE_FINAL - 1,
                threshold_bin,
            )
            # Matching keys with a smaller digit than the chosen bin.
            counts_lt = tl.max(
                tl.where(bins == threshold_bin, prefix_sum, 0),
                axis=0,
            ).to(tl.int32)

            # Fix this digit in the prefix and continue inside the chosen bin.
            desired = desired | (threshold_bin.to(tl.uint32) << digit_pos)
            desired_mask = desired_mask | (
                tl.full((), RADIX_MASK_FINAL, dtype=tl.uint32) << digit_pos
            )
            k_to_find = k_to_find - counts_lt

        # Emit every candidate with a key strictly below the threshold key
        # (i.e. a strictly larger value).
        thr_key = desired
        found_ptrs = s_found_topk_values_ptr + zeros
        cnt_tiles = tl.cdiv(final_cnt, BLOCK_SIZE)
        for t in tl.range(0, cnt_tiles):
            pos = t * BLOCK_SIZE + lane
            valid = pos < final_cnt
            idx = tl.load(hist_base_ptr + pos, mask=valid, other=0)
            x_bits_i32 = tl.load(
                hist_base_ptr + (FINAL_SORT_ITEMS + pos),
                mask=valid,
                other=0,
            )
            x = x_bits_i32.to(tl.float32, bitcast=True)
            key = _convert_to_trt_uint32(x)
            take_lt = valid & (key < thr_key)
            out_pos_gt = tl.atomic_add(
                found_ptrs,
                ones,
                mask=take_lt,
                sem="relaxed",
                scope="cta",
            )
            tl.store(
                s_out_indices_ptr + out_pos_gt,
                idx,
                mask=take_lt & (out_pos_gt < TOPK),
            )

        # Fill any remaining slots with candidates equal to the threshold key;
        # tiles stop contributing once the counter reaches TOPK.
        cur = tl.load(s_found_topk_values_ptr)
        if cur < TOPK:
            for t in tl.range(0, cnt_tiles):
                cur = tl.load(s_found_topk_values_ptr)
                if cur < TOPK:
                    pos = t * BLOCK_SIZE + lane
                    valid = pos < final_cnt
                    idx = tl.load(hist_base_ptr + pos, mask=valid, other=0)
                    x_bits_i32 = tl.load(
                        hist_base_ptr + (FINAL_SORT_ITEMS + pos),
                        mask=valid,
                        other=0,
                    )
                    x = x_bits_i32.to(tl.float32, bitcast=True)
                    key = _convert_to_trt_uint32(x)
                    take_eq = valid & (key == thr_key)
                    out_pos_eq = tl.atomic_add(
                        found_ptrs,
                        ones,
                        mask=take_eq,
                        sem="relaxed",
                        scope="cta",
                    )
                    tl.store(
                        s_out_indices_ptr + out_pos_eq,
                        idx,
                        mask=take_eq & (out_pos_eq < TOPK),
                    )

    # Mark the output as complete.
    tl.store(s_found_topk_values_ptr, TOPK)

854 

855 

856@triton.jit 

857def kernel_tle_bucket_sort_topk( 

858 x_ptr, 

859 out_ptr, 

860 starts_ptr, 

861 ends_ptr, 

862 stride_xm, 

863 stride_xn, 

864 stride_outm, 

865 stride_outn, 

866 seq_len, 

867 K: tl.constexpr, 

868 BLOCK_SIZE: tl.constexpr, 

869 USE_RADIX_FINAL: tl.constexpr, 

870): 

871 pid = tl.program_id(0) 

872 row_start = tl.load(starts_ptr + pid).to(tl.int32) 

873 row_end = tl.load(ends_ptr + pid).to(tl.int32) 

874 

875 row_ptr = x_ptr + pid * stride_xm 

876 out_row = out_ptr + pid * stride_outm 

877 row_len = row_end - row_start 

878 

879 auto_aligned = ( 

880 (stride_xn == 1) 

881 & (stride_outn == 1) 

882 & (row_start == 0) 

883 & (row_end == seq_len) 

884 & (seq_len % BLOCK_SIZE == 0) 

885 ) 

886 assume_aligned = auto_aligned 

887 if assume_aligned: 

888 seq_len = tl.multiple_of(seq_len, BLOCK_SIZE) 

889 

890 lane = tl.arange(0, BLOCK_SIZE) 

891 if row_len <= K: 

892 chunks: tl.constexpr = (K + BLOCK_SIZE - 1) // BLOCK_SIZE 

893 for chunk_idx in tl.range(0, chunks): 

894 pos = chunk_idx * BLOCK_SIZE + lane 

895 take_row = pos < row_len 

896 tl.store( 

897 out_row + pos * stride_outn, 

898 (row_start + pos).to(tl.int32), 

899 mask=take_row, 

900 ) 

901 take_pad = (pos >= row_len) & (pos < K) 

902 tl.store(out_row + pos * stride_outn, -1, mask=take_pad) 

903 return 

904 

905 FINAL_SORT_ITEMS: tl.constexpr = 2048 

906 HIST_SIZE: tl.constexpr = 4096 

907 

908 s_histogram = tle.gpu.alloc( 

909 [HIST_SIZE], 

910 dtype=tl.int32, 

911 layout=None, 

912 scope=tle.gpu.smem, 

913 nv_mma_shared_layout=False, 

914 ) 

915 hist_base_ptr = tle.gpu.local_ptr(s_histogram, (0,)) 

916 s_out_indices = tle.gpu.alloc( 

917 [K], 

918 dtype=tl.int32, 

919 layout=None, 

920 scope=tle.gpu.smem, 

921 nv_mma_shared_layout=False, 

922 ) 

923 s_final_cnt = tle.gpu.alloc( 

924 [1], 

925 dtype=tl.int32, 

926 layout=None, 

927 scope=tle.gpu.smem, 

928 nv_mma_shared_layout=False, 

929 ) 

930 s_threshold_bin_idx = tle.gpu.alloc( 

931 [1], 

932 dtype=tl.int32, 

933 layout=None, 

934 scope=tle.gpu.smem, 

935 nv_mma_shared_layout=False, 

936 ) 

937 s_final_bin_size = tle.gpu.alloc( 

938 [1], 

939 dtype=tl.int32, 

940 layout=None, 

941 scope=tle.gpu.smem, 

942 nv_mma_shared_layout=False, 

943 ) 

944 s_found_topk_values = tle.gpu.alloc( 

945 [1], 

946 dtype=tl.int32, 

947 layout=None, 

948 scope=tle.gpu.smem, 

949 nv_mma_shared_layout=False, 

950 ) 

951 s_step_thresholds = tle.gpu.alloc( 

952 [4], 

953 dtype=tl.int32, 

954 layout=None, 

955 scope=tle.gpu.smem, 

956 nv_mma_shared_layout=False, 

957 ) 

958 s_final_cnt_ptr = tle.gpu.local_ptr(s_final_cnt, (0,)) 

959 s_threshold_bin_idx_ptr = tle.gpu.local_ptr(s_threshold_bin_idx, (0,)) 

960 s_final_bin_size_ptr = tle.gpu.local_ptr(s_final_bin_size, (0,)) 

961 s_found_topk_values_ptr = tle.gpu.local_ptr(s_found_topk_values, (0,)) 

962 s_step_thresholds_ptr = tle.gpu.local_ptr(s_step_thresholds, (0,)) 

963 s_out_indices_ptr = tle.gpu.local_ptr(s_out_indices, (0,)) 

964 tl.store(s_final_cnt_ptr, 0) 

965 tl.store(s_threshold_bin_idx_ptr, -1) 

966 tl.store(s_final_bin_size_ptr, 0) 

967 tl.store(s_found_topk_values_ptr, 0) 

968 

969 logit_pattern = tl.zeros((), dtype=tl.uint32) 

970 continue_to_next_step = True 

971 need_final_sort = False 

972 init_chunks: tl.constexpr = (K + BLOCK_SIZE - 1) // BLOCK_SIZE 

973 for init_idx in tl.range(0, init_chunks): 

974 pos = init_idx * BLOCK_SIZE + lane 

975 tl.store(tle.gpu.local_ptr(s_out_indices, (pos,)), -1, mask=pos < K) 

976 

977 for step_idx in tl.static_range(0, 4): 

978 if continue_to_next_step: 

979 found_topk_values = tl.load(s_found_topk_values_ptr) 

980 ( 

981 continue_to_next_step, 

982 step_need_final_sort, 

983 logit_pattern, 

984 ) = _tle_process_histogram_step( 

985 row_ptr, 

986 stride_xn, 

987 row_start, 

988 row_end, 

989 seq_len, 

990 step_idx, 

991 logit_pattern, 

992 s_step_thresholds_ptr, 

993 found_topk_values, 

994 hist_base_ptr, 

995 s_out_indices_ptr, 

996 s_final_cnt_ptr, 

997 s_found_topk_values_ptr, 

998 s_threshold_bin_idx_ptr, 

999 s_final_bin_size_ptr, 

1000 assume_aligned, 

1001 TOPK=K, 

1002 BLOCK_SIZE=BLOCK_SIZE, 

1003 ) 

1004 need_final_sort = need_final_sort | step_need_final_sort 

1005 

1006 if need_final_sort: 

1007 if USE_RADIX_FINAL: 

1008 _tle_final_select_radix( 

1009 hist_base_ptr, 

1010 s_out_indices_ptr, 

1011 s_final_cnt_ptr, 

1012 s_found_topk_values_ptr, 

1013 TOPK=K, 

1014 BLOCK_SIZE=BLOCK_SIZE, 

1015 FINAL_SORT_ITEMS=FINAL_SORT_ITEMS, 

1016 ) 

1017 else: 

1018 base_idx = tl.load(s_found_topk_values_ptr) 

1019 final_cnt = tl.minimum(tl.load(s_final_cnt_ptr), FINAL_SORT_ITEMS) 

1020 sort_chunks = tl.cdiv(final_cnt, BLOCK_SIZE) 

1021 for sort_chunk in tl.range(0, sort_chunks): 

1022 pos = sort_chunk * BLOCK_SIZE + lane 

1023 valid = pos < final_cnt 

1024 logit_i_bits = tl.load( 

1025 tle.gpu.local_ptr(s_histogram, (FINAL_SORT_ITEMS + pos,)), 

1026 mask=valid, 

1027 other=0, 

1028 ) 

1029 logit_i = logit_i_bits.to(tl.float32, bitcast=True) 

1030 out_rank = tl.zeros([BLOCK_SIZE], dtype=tl.int32) 

1031 for j in tl.range(0, final_cnt): 

1032 logit_j_bits = tl.load( 

1033 tle.gpu.local_ptr(s_histogram, (FINAL_SORT_ITEMS + j,)) 

1034 ) 

1035 logit_j = logit_j_bits.to(tl.float32, bitcast=True) 

1036 better = (logit_i < logit_j) | ((logit_i == logit_j) & (pos < j)) 

1037 out_rank = out_rank + (valid & better).to(tl.int32) 

1038 dst_pos = base_idx + out_rank 

1039 take = valid & (dst_pos < K) 

1040 idx_i = tl.load( 

1041 tle.gpu.local_ptr(s_histogram, (pos,)), 

1042 mask=take, 

1043 other=0, 

1044 ) 

1045 tl.store(tle.gpu.local_ptr(s_out_indices, (dst_pos,)), idx_i, mask=take) 

1046 tl.store(s_found_topk_values_ptr, K) 

1047 

1048 flush_chunks: tl.constexpr = (K + BLOCK_SIZE - 1) // BLOCK_SIZE 

1049 for flush_chunk in tl.static_range(flush_chunks): 

1050 pos = flush_chunk * BLOCK_SIZE + lane 

1051 mask = pos < K 

1052 out_vals = tl.load( 

1053 tle.gpu.local_ptr(s_out_indices, (pos,)), mask=mask, other=-1 

1054 ) 

1055 tl.store(out_row + pos * stride_outn, out_vals, mask=mask) 

1056 

1057 

def tle_bucket_sort_topk(
    inputs,
    starts,
    ends,
    topk,
):
    """Select per-row top-``topk`` indices of ``inputs`` using the TLE kernel.

    Args:
        inputs: 2D score tensor of shape ``(batch, seq_len)``. It is converted
            to float32 when it is not already float32.
        starts: 1D tensor with one entry per row of ``inputs``, forwarded to
            the kernel (presumably the per-row start bound -- confirm against
            the kernel's loads).
        ends: 1D tensor with one entry per row of ``inputs``, forwarded to the
            kernel (presumably the per-row end bound).
        topk: number of indices to produce per row. It is passed to the kernel
            as the compile-time constant ``K``.

    Returns:
        An int32 tensor of shape ``(batch, topk)`` on ``inputs.device``. It is
        pre-filled with -1, and the kernel overwrites it with selected indices.
        Slots the kernel leaves unfilled stay -1.

    Raises:
        RuntimeError: if TLE could not be imported (``HAS_TLE`` is False).
        ValueError: if ``inputs`` is not 2D, ``starts``/``ends`` are not 1D,
            or ``starts``/``ends`` do not have exactly one entry per row.
    """
    if not HAS_TLE:
        raise RuntimeError(
            "TLE is unavailable. bucket_sort_topk TLE kernel requires Triton >= 3.6 with triton.experimental.tle."
        )
    if inputs.ndim != 2:
        raise ValueError("inputs must be a 2D tensor")
    if starts.ndim != 1 or ends.ndim != 1:
        raise ValueError("starts and ends must be 1D tensors")

    batch, seq_len = inputs.shape
    # One kernel program is launched per row (grid = (batch,)). A starts/ends
    # tensor shorter than `batch` would be read out of bounds on the device.
    # NOTE(review): assumes the kernel indexes starts/ends by program id.
    if starts.shape[0] != batch or ends.shape[0] != batch:
        raise ValueError(
            "starts and ends must have one entry per row of inputs "
            f"(expected {batch}, got {starts.shape[0]} and {ends.shape[0]})"
        )

    out = torch.full((batch, topk), -1, dtype=torch.int32, device=inputs.device)
    if batch == 0 or topk == 0:
        # Nothing for the kernel to write, so skip the empty launch.
        return out

    # Validate before copying, so bad arguments never allocate a float32 copy.
    x = inputs.float() if inputs.dtype != torch.float32 else inputs
    # Long rows switch the kernel's final selection to the radix variant.
    use_radix_final = seq_len >= TLE_RADIX_FINAL_SEQ_LEN_THRESHOLD

    grid = (batch,)
    kernel_tle_bucket_sort_topk[grid](
        x,
        out,
        starts,
        ends,
        x.stride(0),
        x.stride(1),
        out.stride(0),
        out.stride(1),
        seq_len,
        K=topk,
        BLOCK_SIZE=TLE_FIXED_BLOCK_SIZE,
        USE_RADIX_FINAL=use_radix_final,
        num_warps=TLE_FIXED_NUM_WARPS,
        num_stages=TLE_FIXED_NUM_STAGES,
    )
    return out

1096 

1097 

def _should_use_tle_bucket_sort_topk(inputs, topk):
    """Return whether ``bucket_sort_topk`` should try the TLE kernel first.

    The TLE path is chosen only when all of the following hold:
    - TLE was importable at module load (``HAS_TLE``).
    - ``inputs`` is a ``torch.Tensor`` on a CUDA device.
    - ``inputs`` is 2D.

    The rank check mirrors the validation in ``tle_bucket_sort_topk``. Inputs
    it would reject go straight to the Triton fallback instead of raising a
    ``ValueError`` that is then caught.

    Args:
        inputs: candidate score tensor. Any object is accepted; non-tensors
            return False.
        topk: currently unused, kept so the signature stays stable.

    Returns:
        True if the TLE implementation should be attempted, else False.
    """
    if not HAS_TLE:
        return False
    if not isinstance(inputs, torch.Tensor):
        return False
    return inputs.device.type == "cuda" and inputs.ndim == 2

1104 

1105 

def bucket_sort_topk(inputs, starts, ends, topk):
    """Per-row top-k index selection, preferring the TLE kernel when usable.

    When ``_should_use_tle_bucket_sort_topk`` approves the inputs, the TLE
    implementation is tried first. If it raises any exception, a
    ``RuntimeWarning`` describing the failure is emitted and the call falls
    back to ``bucket_sort_topk_triton``. Otherwise the Triton implementation
    is used directly.

    Args:
        inputs: score tensor, forwarded unchanged to the chosen implementation.
        starts: forwarded unchanged to the chosen implementation.
        ends: forwarded unchanged to the chosen implementation.
        topk: number of indices to select per row.

    Returns:
        Whatever the chosen implementation returns.
    """
    if _should_use_tle_bucket_sort_topk(inputs, topk):
        try:
            return tle_bucket_sort_topk(inputs, starts, ends, topk)
        except Exception as err:  # deliberate best-effort fast path
            import warnings

            # Keep the fallback, but make it visible. A silent swallow hides
            # both real kernel bugs and a permanently slower code path. Under
            # the default warning filters this prints once per distinct message.
            warnings.warn(
                "TLE bucket_sort_topk failed "
                f"({type(err).__name__}: {err}); "
                "falling back to the Triton implementation.",
                RuntimeWarning,
                stacklevel=2,
            )
    return bucket_sort_topk_triton(inputs, starts, ends, topk)