Coverage for src/flag_gems/fused/DSA/bin_topk.py: 6%
540 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import torch
2import triton
3import triton.language as tl
5from flag_gems.utils.triton_version_utils import has_triton_tle
# Probe for Triton's experimental TLE language extension.  The import is only
# attempted when the version gate `has_triton_tle(3, 6, 0)` passes (the TLE
# error message below in this file states Triton >= 3.6 is required); an
# ImportError also leaves the TLE path disabled.
tle = None
HAS_TLE = False
if has_triton_tle(3, 6, 0):
    try:
        import triton.experimental.tle.language as tle
    except ImportError:
        tle = None
    else:
        HAS_TLE = True

# Fixed launch configuration of the TLE kernel: 1024 lanes per program,
# i.e. 1024 // 32 = 32 warps, single pipeline stage.
TLE_FIXED_BLOCK_SIZE = 1024
TLE_FIXED_NUM_WARPS = TLE_FIXED_BLOCK_SIZE // 32
TLE_FIXED_NUM_STAGES = 1
# Sequence length at/above which the TLE kernel finishes with the radix
# select instead of the comparison-rank final sort.
TLE_RADIX_FINAL_SEQ_LEN_THRESHOLD = 12_288
@triton.jit
def convert_to_uint16(x):
    # First-pass bucket id: the top byte of the order-preserving uint32 key
    # produced by convert_to_uint32 (larger float -> larger bucket).  Returned
    # as uint16 although only the low 8 bits can be non-zero.
    top_byte = (convert_to_uint32(x) >> 24) & 0xFF
    return top_byte.to(tl.uint16)
@triton.jit
def convert_to_uint32(x):
    # Map float bits to a uint32 whose unsigned order matches the float order:
    # negative values get every bit flipped, non-negative values get the sign
    # bit set.  -0.0 is not `< 0`, so both zeros map to 0x80000000.
    raw = x.to(tl.uint32, bitcast=True)
    all_bits = tl.full(raw.shape, 0xFFFFFFFF, tl.uint32)
    sign_bit = tl.full(raw.shape, 0x80000000, tl.uint32)
    negative_key = ~raw & all_bits
    non_negative_key = raw | sign_bit
    return tl.where(x < 0, negative_key, non_negative_key)
# Legacy per-row top-k selection using 8-bit radix buckets.
#
# One program processes one row (bucket_sort_topk_triton launches grid=(B,)).
#   1. Pass 1 histograms the top byte of the order-preserving key
#      (convert_to_uint16) over [start, end), turns it into a suffix sum and
#      picks the bucket in which the running count crosses K.
#   2. Pass 2 writes indices of elements in higher buckets straight into
#      `indices` and stages indices of threshold-bucket elements in
#      `s_input_ids`.
#   3. Up to four rounds (while loop) re-bucket the staged candidates by the
#      byte at bit offset 24 - 8 * round, compacting the survivors in place.
#      Round 0 uses shift 24, i.e. the same top byte already used in pass 1.
#   4. Any slots still open are filled with the first remaining staged
#      candidates.
#
# NOTE(review): stores into `s_input_ids` are not bounded by SMEM_INPUT_SIZE;
# a threshold bucket with more than SMEM_INPUT_SIZE entries overflows into the
# next row's scratch (or past the buffer for the last row).
# NOTE(review): masked lanes (outside [start, end) or >= S) are still counted
# by the unmasked histograms and land in bin 0 (padding -inf / 0.0, or
# input_idx = -1 in later rounds); if the threshold falls in bin 0 these
# lanes are staged and later dereferenced -- verify this cannot happen.
@triton.autotune(
    configs=[
        triton.Config({"BS": 32, "BSS": 32}, num_stages=1, num_warps=1),
        triton.Config({"BS": 64, "BSS": 32}, num_stages=1, num_warps=1),
        triton.Config({"BS": 128, "BSS": 32}, num_stages=2, num_warps=1),
        triton.Config({"BS": 256, "BSS": 32}, num_stages=2, num_warps=2),
        triton.Config({"BS": 512, "BSS": 64}, num_stages=2, num_warps=2),
        triton.Config({"BS": 1024, "BSS": 256}, num_stages=2, num_warps=2),
        triton.Config({"BS": 2048, "BSS": 256}, num_stages=2, num_warps=4),
        triton.Config({"BS": 4096, "BSS": 512}, num_stages=3, num_warps=4),
        triton.Config({"BS": 8192, "BSS": 512}, num_stages=3, num_warps=8),
        triton.Config({"BS": 8192, "BSS": 1024}, num_stages=3, num_warps=8),
    ],
    key=["S", "K"],
)
@triton.jit
def kernel_bucket_sort_topk(  # grid: (B,) -- one program per row
    inputs,  # (B, S) Note: no H because MLA is based on MQA and MHA, not GQA
    indices,  # (B, K) topk index array
    s_input_ids,  # (B, SMEM_INPUT_SIZE) scratch: candidate indices for the next round
    starts,  # per-row first valid position (variable length)
    ends,  # per-row end (exclusive) of valid positions (variable length)
    S: tl.constexpr,  # sequence length
    K: tl.constexpr,  # k of topk
    HISTOGRAM_SIZE: tl.constexpr,  # number of buckets (256 from the wrapper)
    SMEM_INPUT_SIZE: tl.constexpr,  # per-row stride of s_input_ids
    BS: tl.constexpr,  # block size over S (autotuned)
    BSS: tl.constexpr,  # block size over the staged candidates (autotuned)
):
    # Row handled by this program
    i_b = tl.program_id(0)

    # Per-row base pointers (rows are assumed contiguous with stride S)
    s_base = inputs + i_b * S
    indices_base = indices + i_b * K
    s_input_ids_base = s_input_ids + i_b * SMEM_INPUT_SIZE

    # Histogram initialization
    s_histogram = tl.zeros([HISTOGRAM_SIZE], dtype=tl.int32)

    # Support variable length
    l_start_idx = tl.load(starts + i_b).to(tl.int32)
    l_end_idx = tl.load(ends + i_b).to(tl.int32)

    # Number of topk slots that still need to be filled
    l_new_topk = K

    # Pass 1: histogram of the top byte of every element in the row.
    TS = tl.cdiv(S, BS)
    for s in range(TS):
        input_idx = s * BS + tl.arange(0, BS)
        input_mask = (
            (input_idx < l_end_idx) & (input_idx >= l_start_idx) & (input_idx < S)
        )
        input = tl.load(s_base + input_idx, input_mask, other=float("-inf")).to(
            tl.float32
        )
        inval_int16 = convert_to_uint16(input)
        s_histogram += inval_int16.to(tl.int32).histogram(HISTOGRAM_SIZE)

    s_histogram = s_histogram.cumsum(0, reverse=True)  # Suffix sum

    # mv_idx[i] = i + 1 (wrapping to 0), so gather(mv_idx) yields suffix[i + 1].
    mv_idx = (
        tl.arange(1, HISTOGRAM_SIZE + 1) % HISTOGRAM_SIZE
    )  # Construct offset index matrix

    # Threshold bucket: suffix[i] > k and (suffix[i + 1] <= k or i is the last
    # bucket).  argmax returns the first matching bucket; with no match
    # (total <= k) it presumably returns 0.
    cond = (s_histogram > l_new_topk) & (
        (s_histogram.gather(mv_idx, 0) <= l_new_topk) | (mv_idx == 0)
    )
    l_threshold_bin_id = cond.argmax(0)

    # Subtract the elements in buckets above the threshold (suffix[thr + 1];
    # 0 when thr is the last bucket).
    l_new_topk -= tl.where(
        tl.arange(0, HISTOGRAM_SIZE) == l_threshold_bin_id + 1, s_histogram, 0
    ).max(0)
    sum = 0  # entries written to `indices` so far
    thre_bin_sum = 0  # entries staged in `s_input_ids` so far
    # Pass 2: emit elements above the threshold, stage elements equal to it.
    for s in range(TS):
        input_idx = s * BS + tl.arange(0, BS)
        input_mask = (
            (input_idx < l_end_idx) & (input_idx >= l_start_idx) & (input_idx < S)
        )
        input = tl.load(s_base + input_idx, input_mask, other=float("-inf")).to(
            tl.float32
        )
        inval_int16 = convert_to_uint16(input)
        # Masking inval_int16 explicitly with tl.where(input_mask, ...) was
        # measured to be slower; the -inf padding (bucket 0) is used instead.

        over_thre = inval_int16.to(tl.int32) > l_threshold_bin_id
        cur_sum = over_thre.to(tl.int32).sum(-1)

        eq_thre = inval_int16.to(tl.int32) == l_threshold_bin_id
        thre_bin_cur_sum = eq_thre.to(tl.int32).sum(-1)

        # Inclusive cumsums give each selected lane its 1-based output slot.
        topk_idx = over_thre.to(tl.int32).cumsum(-1)
        thre_bin_idx = eq_thre.to(tl.int32).cumsum(-1)

        # Single scatter for both destinations (topk output and staging).
        concat_mask = tl.cat(over_thre, eq_thre, True)
        concat_input = tl.cat(input_idx, input_idx, True)
        concat_pointer_matrix = tl.cat(
            indices_base + sum + topk_idx - 1,
            s_input_ids_base + thre_bin_sum + thre_bin_idx - 1,
            True,
        )
        tl.store(concat_pointer_matrix, concat_input, mask=concat_mask)

        thre_bin_sum += thre_bin_cur_sum
        sum += cur_sum

    round = 0
    # Refinement rounds over the staged candidates, one byte per round.
    while round < 4 and l_new_topk > 0:
        ss = tl.cdiv(thre_bin_sum, BSS)
        s_histogram = tl.zeros([HISTOGRAM_SIZE], dtype=tl.int32)
        padding_num = 0.0 if round else float("-inf")
        # Why the padding value depends on the round:
        #
        # round == 0 with padding 0.0:
        #   0.0 = 0x00000000 -> convert_to_uint32 = 0x|80|000000, byte = 0x80.
        #   The padding would rank above every negative candidate and could be
        #   staged for the next round or even emitted into the topk output.
        #
        # round == 0 with padding -inf:
        #   float("-inf") = 0xFF800000 -> convert_to_uint32 = 0x|00|7FFFFF,
        #   byte = 0x00, i.e. the smallest bin, so real candidates are not
        #   displaced.
        #
        # round > 0 with padding -inf:
        #   convert_to_uint32(-inf) = 0x007FFF|FF| -> byte at round 3 = 0xFF,
        #   the largest bin, so padding would be preferred for the topk output.
        #   With padding 0.0 the lower bytes of 0x80000000 are all 0x00.
        # Therefore the padding value is 0.0 for every round after the first.
        for s in range(ss):
            s_input_idx = s * BSS + tl.arange(0, BSS)
            s_input_idx_mask = s_input_idx < thre_bin_sum
            input_idx = tl.load(
                s_input_ids_base + s_input_idx, s_input_idx_mask, other=-1
            )
            s_input_mask = s_input_idx_mask
            s_input = tl.load(s_base + input_idx, s_input_mask, other=padding_num).to(
                tl.float32
            )
            inval_int32 = (
                convert_to_uint32(s_input) >> (24 - round * 8)
            ) & 0xFF  # Ensure all bits except the last eight are zero
            s_histogram += inval_int32.to(tl.int32).histogram(HISTOGRAM_SIZE)
        s_histogram = s_histogram.cumsum(0, reverse=True)  # Suffix sum
        mv_idx = (
            tl.arange(1, HISTOGRAM_SIZE + 1) % HISTOGRAM_SIZE
        )  # Construct offset index matrix
        cond = (s_histogram > l_new_topk) & (
            (s_histogram.gather(mv_idx, 0) <= l_new_topk) | (mv_idx == 0)
        )
        l_threshold_bin_id = cond.argmax(0)
        l_new_topk -= tl.where(
            tl.arange(0, HISTOGRAM_SIZE) == l_threshold_bin_id + 1, s_histogram, 0
        ).max(0)
        # Survivors are compacted in place at the front of s_input_ids; writes
        # never pass the block currently being read.
        thre_bin_sum, old_thre_bin_sum = 0, thre_bin_sum

        for s in range(ss):
            s_input_idx = s * BSS + tl.arange(0, BSS)
            s_input_idx_mask = s_input_idx < old_thre_bin_sum
            input_idx = tl.load(
                s_input_ids_base + s_input_idx, s_input_idx_mask, other=-1
            )
            s_input_mask = s_input_idx_mask
            s_input = tl.load(s_base + input_idx, s_input_mask, other=padding_num).to(
                tl.float32
            )
            inval_int32 = (convert_to_uint32(s_input) >> (24 - round * 8)) & 0xFF

            over_thre = inval_int32.to(tl.int32) > l_threshold_bin_id
            cur_sum = over_thre.to(tl.int32).sum(-1)
            eq_thre = inval_int32.to(tl.int32) == l_threshold_bin_id
            thre_bin_cur_sum = eq_thre.to(tl.int32).sum(-1)

            topk_idx = over_thre.to(tl.int32).cumsum(-1)
            thre_bin_idx = eq_thre.to(tl.int32).cumsum(-1)

            concat_mask = tl.cat(over_thre, eq_thre, True)
            concat_input = tl.cat(input_idx, input_idx, True)
            concat_pointer_matrix = tl.cat(
                indices_base + sum + topk_idx - 1,
                s_input_ids_base + thre_bin_sum + thre_bin_idx - 1,
                True,
            )

            tl.store(concat_pointer_matrix, concat_input, mask=concat_mask)

            thre_bin_sum += thre_bin_cur_sum
            sum += cur_sum

        round += 1

    # Remaining slots: the staged candidates are tied on every refined byte,
    # so the first l_new_topk of them (in staging order) fill the output.
    if l_new_topk > 0:
        ss = tl.cdiv(l_new_topk, BSS)
        for s in range(ss):
            s_input_idx = s * BSS + tl.arange(0, BSS)
            s_input_idx_mask = s_input_idx < l_new_topk
            input_idx = tl.load(
                s_input_ids_base + s_input_idx, s_input_idx_mask, other=-1
            )
            s_input_mask = s_input_idx_mask
            tl.store(
                indices_base + sum + tl.arange(0, BSS), input_idx, mask=s_input_mask
            )
            sum += BSS
def bucket_sort_topk_triton(inputs, starts, ends, topk):
    """Per-row top-k indices using the legacy (non-TLE) bucket-sort kernel.

    Args:
        inputs: 2-D tensor ``(B, S)`` of scores.
        starts: 1-D tensor; ``starts[b]`` is the first valid column of row ``b``.
        ends: 1-D tensor; ``ends[b]`` is the exclusive end of row ``b``.
        topk: number of indices to select per row.

    Returns:
        ``int32`` tensor of shape ``(B, topk)`` holding column indices; slots
        the kernel does not fill keep their initial value ``-1``.
    """
    # The kernel addresses rows as ``base + b * S`` and reads starts/ends as
    # ``ptr + b``: it assumes unit-stride, row-major storage.  ``contiguous()``
    # is a no-op for already-contiguous tensors.
    inputs = inputs.contiguous()
    starts = starts.contiguous()
    ends = ends.contiguous()
    B, S = inputs.shape
    K = topk
    HISTOGRAM_SIZE = 256
    # Per-row scratch for threshold-bucket candidates.  The kernel stages every
    # element of the threshold bucket without a bounds check, and that bucket
    # can contain up to cdiv(S, BS) * BS entries (masked tail lanes included).
    # BS is autotuned up to 8192 (see the autotune configs), so round S up to a
    # multiple of 8192; 4096 is kept as the historical lower bound.
    MAX_BS = 8192
    SMEM_INPUT_SIZE = max(4096, triton.cdiv(S, MAX_BS) * MAX_BS)
    indices = torch.full((B, topk), -1, dtype=torch.int32, device=inputs.device)
    s_input_idx = torch.zeros(
        B, SMEM_INPUT_SIZE, dtype=torch.int32, device=inputs.device
    )
    grid = (B,)
    kernel_bucket_sort_topk[grid](
        inputs,
        indices,
        s_input_idx,
        starts,
        ends,
        S,
        K,
        HISTOGRAM_SIZE,
        SMEM_INPUT_SIZE,
    )
    return indices
@triton.jit
def _convert_to_trt_uint32(x):
    # Descending-order uint32 key: a larger float yields a smaller key, so
    # "digit < threshold" selects the larger values.  Inputs with the sign bit
    # set keep their raw bits (and therefore sort after every non-negative
    # input); inputs without it have their low 31 bits inverted.
    raw = x.to(tl.uint32, bitcast=True)
    is_negative = (raw & tl.full(raw.shape, 0x80000000, tl.uint32)) != 0
    inverted_low31 = (~raw) & tl.full(raw.shape, 0x7FFFFFFF, tl.uint32)
    return tl.where(is_negative, raw, inverted_low31)
@triton.jit
def _convert_to_trt_uint16_hi11(x):
    # 11-bit descending-order digit (range [0, 2047]) from the float16 rounding
    # of x: the same sign-dependent mapping as _convert_to_trt_uint32 applied to
    # the 16-bit pattern, keeping its top 11 bits.  Values outside the float16
    # range round to +/-inf and share the corresponding digit.
    half_bits = x.to(tl.float16).to(tl.uint16, bitcast=True)
    is_negative = (half_bits & tl.full(half_bits.shape, 0x8000, tl.uint16)) != 0
    inverted_low15 = (~half_bits) & tl.full(half_bits.shape, 0x7FFF, tl.uint16)
    ordered = tl.where(is_negative, half_bits, inverted_low15)
    return (ordered >> 5).to(tl.int32)
# One radix step of the TLE top-k selection for a single row.
#
# Digits per step:
#   step 0: 11-bit digit of the float16 rounding (_convert_to_trt_uint16_hi11)
#   step 1: key bits [31:21] of the descending-order fp32 key
#   step 2: key bits [20:10], only for keys whose bits [31:21] match step 1
#   step 3: key bits [9:0], only for keys whose bits [31:10] match steps 1-2
# The matching pattern (`logit_pattern`) is rebuilt from the thresholds that
# earlier steps stored in s_step_thresholds_ptr.
#
# Phases:
#   1. clear and fill the shared histogram at hist_base_ptr (atomic adds),
#   2. scan the bins, offset by `found_topk_values`, to find the bin where the
#      running count reaches TOPK (the code treats tle.cumsum as an exclusive
#      scan: next_prefix_sum = prefix_sum + counts),
#   3. append indices of elements in lower (larger-valued) bins to
#      s_out_indices_ptr via the shared found counter,
#   4. if the threshold bin holds <= FINAL_SORT_ITEMS entries (steps 0-2),
#      stage it for a final sort: indices at hist[pos], raw fp32 bits at
#      hist[FINAL_SORT_ITEMS + pos]; step 3 instead writes its ties directly.
#
# Returns (continue_to_next_step, need_final_sort, logit_pattern).
#
# NOTE(review): in step 3 the tie positions come from atomic_add on
# hist[digit], which still holds the raw bin count (the threshold scan does
# not write prefix sums back), not the output offset found + count(< thr).
# Tie writes may therefore land at wrong slots or be masked off by
# `< TOPK` -- please verify.
# NOTE(review): step 1 histograms every element (partial = True) while its
# scan starts at the found count left by step 0, whose selected elements are
# also in that histogram; confirm they are not counted/emitted twice.
@triton.jit
def _tle_process_histogram_step(
    row_ptr,
    stride_xn,
    row_start,
    row_end,
    seq_len,
    step_idx: tl.constexpr,
    logit_pattern,
    s_step_thresholds_ptr,
    found_topk_values,
    hist_base_ptr,
    s_out_indices_ptr,
    s_final_cnt_ptr,
    s_found_topk_values_ptr,
    s_threshold_bin_idx_ptr,
    s_final_bin_size_ptr,
    assume_aligned,
    TOPK: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    VEC: tl.constexpr = 4
    FINAL_SORT_ITEMS: tl.constexpr = 2048
    RADIX11_SIZE: tl.constexpr = 2048
    RADIX11_MASK: tl.constexpr = 0x7FF
    RADIX10_SIZE: tl.constexpr = 1024
    RADIX10_MASK: tl.constexpr = 0x3FF

    lane = tl.arange(0, BLOCK_SIZE)
    vec = tl.arange(0, VEC)
    ones = tl.full([BLOCK_SIZE], 1, tl.int32)
    ones_vec_2d = tl.full([BLOCK_SIZE, VEC], 1, tl.int32)
    zeros = tl.zeros([BLOCK_SIZE], dtype=tl.int32)
    zeros_vec_2d = tl.zeros([BLOCK_SIZE, VEC], dtype=tl.int32)

    # Clear the bins used by this step (1024 for step 3, otherwise 2048).
    clear_rounds = tl.where(
        step_idx == 3,
        RADIX10_SIZE // BLOCK_SIZE,
        RADIX11_SIZE // BLOCK_SIZE,
    )
    for clear_round in tl.range(0, clear_rounds):
        clear_bins = clear_round * BLOCK_SIZE + lane
        tl.store(hist_base_ptr + clear_bins, 0)
    tl.debug_barrier()

    # Rebuild the prefix pattern that candidates must match in steps 2 and 3.
    if step_idx == 2:
        step1_threshold = tl.load(s_step_thresholds_ptr + 1)
        logit_pattern = (step1_threshold.to(tl.uint32) & RADIX11_MASK) << 21
    elif step_idx == 3:
        step1_threshold = tl.load(s_step_thresholds_ptr + 1)
        step2_threshold = tl.load(s_step_thresholds_ptr + 2)
        logit_pattern = ((step1_threshold.to(tl.uint32) & RADIX11_MASK) << 21) | (
            (step2_threshold.to(tl.uint32) & RADIX11_MASK) << 10
        )

    n_tiles = tl.cdiv(seq_len, BLOCK_SIZE)
    n_vec_full = seq_len // (BLOCK_SIZE * VEC)
    rem_tiles = (seq_len - n_vec_full * BLOCK_SIZE * VEC) // BLOCK_SIZE

    # Phase 1: build the histogram.
    if assume_aligned:
        # Unmasked path: [BLOCK_SIZE, VEC] tiles, then whole BLOCK_SIZE tiles.
        for t in tl.range(0, n_vec_full):
            base = t * BLOCK_SIZE * VEC + lane * VEC
            offs = base[:, None] + vec[None, :]
            x_vec = tl.load(row_ptr + offs)
            key = _convert_to_trt_uint32(x_vec)
            if step_idx == 0:
                digit = _convert_to_trt_uint16_hi11(x_vec)
            elif step_idx == 1:
                digit = ((key >> 21) & RADIX11_MASK).to(tl.int32)
            elif step_idx == 2:
                digit = ((key >> 10) & RADIX11_MASK).to(tl.int32)
            else:
                digit = (key & RADIX10_MASK).to(tl.int32)

            if step_idx < 2:
                partial = tl.full([BLOCK_SIZE, VEC], True, tl.int1)
            elif step_idx == 2:
                partial = ((key ^ logit_pattern) >> 21) == 0
            else:
                partial = ((key ^ logit_pattern) >> 10) == 0

            tl.atomic_add(
                hist_base_ptr + digit,
                ones_vec_2d,
                mask=partial,
                sem="relaxed",
                scope="cta",
            )

        for t in tl.range(0, rem_tiles):
            offs = (n_vec_full * VEC + t) * BLOCK_SIZE + lane
            x = tl.load(row_ptr + offs)
            key = _convert_to_trt_uint32(x)
            if step_idx == 0:
                digit = _convert_to_trt_uint16_hi11(x)
            elif step_idx == 1:
                digit = ((key >> 21) & RADIX11_MASK).to(tl.int32)
            elif step_idx == 2:
                digit = ((key >> 10) & RADIX11_MASK).to(tl.int32)
            else:
                digit = (key & RADIX10_MASK).to(tl.int32)

            if step_idx < 2:
                partial = tl.full([BLOCK_SIZE], True, tl.int1)
            elif step_idx == 2:
                partial = ((key ^ logit_pattern) >> 21) == 0
            else:
                partial = ((key ^ logit_pattern) >> 10) == 0

            tl.atomic_add(
                hist_base_ptr + digit,
                ones,
                mask=partial,
                sem="relaxed",
                scope="cta",
            )
    else:
        # General path: strided, masked to [row_start, row_end) and seq_len.
        for t in tl.range(0, n_tiles):
            offs = t * BLOCK_SIZE + lane
            in_range = (offs < seq_len) & (offs >= row_start) & (offs < row_end)
            x = tl.load(row_ptr + offs * stride_xn, mask=in_range, other=float("-inf"))
            key = _convert_to_trt_uint32(x)
            if step_idx == 0:
                digit = _convert_to_trt_uint16_hi11(x)
            elif step_idx == 1:
                digit = ((key >> 21) & RADIX11_MASK).to(tl.int32)
            elif step_idx == 2:
                digit = ((key >> 10) & RADIX11_MASK).to(tl.int32)
            else:
                digit = (key & RADIX10_MASK).to(tl.int32)

            if step_idx < 2:
                partial = in_range
            elif step_idx == 2:
                partial = in_range & (((key ^ logit_pattern) >> 21) == 0)
            else:
                partial = in_range & (((key ^ logit_pattern) >> 10) == 0)

            tl.atomic_add(
                hist_base_ptr + digit,
                ones,
                mask=partial,
                sem="relaxed",
                scope="cta",
            )
    tl.debug_barrier()

    # Phase 2: scan the histogram (BLOCK_SIZE bins per round) for the bin in
    # which the count, offset by elements already found, reaches TOPK.  The
    # single lane owning that bin writes its index and size to shared memory.
    tl.store(s_threshold_bin_idx_ptr, -1)
    tl.store(s_final_bin_size_ptr, 0)
    threshold_bin_ptrs = s_threshold_bin_idx_ptr + zeros
    final_bin_size_ptrs = s_final_bin_size_ptr + zeros
    last_value = found_topk_values
    threshold_found = False
    threshold_rounds = tl.where(
        step_idx == 3,
        RADIX10_SIZE // BLOCK_SIZE,
        RADIX11_SIZE // BLOCK_SIZE,
    )
    for round_idx in tl.range(0, threshold_rounds):
        if not threshold_found:
            bins = round_idx * BLOCK_SIZE + lane
            counts = tl.load(hist_base_ptr + bins)
            prefix_sum, counts_total = tle.cumsum(counts, axis=0, reverse=False)
            prefix_sum = prefix_sum + last_value
            total_sum = last_value + counts_total
            next_prefix_sum = prefix_sum + counts
            threshold_mask = (prefix_sum < TOPK) & (next_prefix_sum >= TOPK)
            threshold_bin = bins
            threshold_bin_size = next_prefix_sum - prefix_sum
            tl.store(threshold_bin_ptrs, threshold_bin, mask=threshold_mask)
            tl.store(final_bin_size_ptrs, threshold_bin_size, mask=threshold_mask)
            found_round = tl.reduce_or(threshold_mask, axis=0)
            threshold_found = found_round
            last_value = total_sum

    threshold_bin_idx = tl.load(s_threshold_bin_idx_ptr)
    final_bin_size = tl.load(s_final_bin_size_ptr)
    # Remembered so later steps can rebuild logit_pattern.
    tl.store(s_step_thresholds_ptr + step_idx, threshold_bin_idx)

    # Finish with a final sort when the threshold bin fits the staging area.
    use_final = (
        (step_idx < 3) & (threshold_bin_idx >= 0) & (final_bin_size <= FINAL_SORT_ITEMS)
    )
    if use_final:
        tl.store(s_final_cnt_ptr, 0)

    # Phase 3/4: emit elements below the threshold bin and handle ties
    # (direct write in step 3, staging for the final sort otherwise).  The
    # three loops below mirror the histogram loops above.
    found_ptrs = s_found_topk_values_ptr + zeros
    final_cnt_ptrs = s_final_cnt_ptr + zeros
    if assume_aligned:
        found_ptrs_vec_2d = s_found_topk_values_ptr + zeros_vec_2d
        final_cnt_ptrs_vec_2d = s_final_cnt_ptr + zeros_vec_2d
        for t in tl.range(0, n_vec_full):
            base = t * BLOCK_SIZE * VEC + lane * VEC
            offs = base[:, None] + vec[None, :]
            x_vec = tl.load(row_ptr + offs)
            key = _convert_to_trt_uint32(x_vec)
            if step_idx == 0:
                digit = _convert_to_trt_uint16_hi11(x_vec)
            elif step_idx == 1:
                digit = ((key >> 21) & RADIX11_MASK).to(tl.int32)
            elif step_idx == 2:
                digit = ((key >> 10) & RADIX11_MASK).to(tl.int32)
            else:
                digit = (key & RADIX10_MASK).to(tl.int32)

            if step_idx < 2:
                partial = tl.full([BLOCK_SIZE, VEC], True, tl.int1)
            elif step_idx == 2:
                partial = ((key ^ logit_pattern) >> 21) == 0
            else:
                partial = ((key ^ logit_pattern) >> 10) == 0

            take_lt = partial & (digit < threshold_bin_idx)
            out_pos_lt = tl.atomic_add(
                found_ptrs_vec_2d,
                ones_vec_2d,
                mask=take_lt,
                sem="relaxed",
                scope="cta",
            )
            tl.store(
                s_out_indices_ptr + out_pos_lt,
                offs.to(tl.int32),
                mask=take_lt & (out_pos_lt < TOPK),
            )

            if step_idx == 3:
                # Ties: slot taken from hist[digit] (see NOTE in header).
                take_eq = partial & (digit == threshold_bin_idx)
                out_pos_eq = tl.atomic_add(
                    hist_base_ptr + digit,
                    ones_vec_2d,
                    mask=take_eq,
                    sem="relaxed",
                    scope="cta",
                )
                tl.store(
                    s_out_indices_ptr + out_pos_eq,
                    offs.to(tl.int32),
                    mask=take_eq & (out_pos_eq < TOPK),
                )
            elif use_final:
                # Stage threshold-bin candidates: index and raw fp32 bits.
                take_eq_final = partial & (digit == threshold_bin_idx)
                final_pos = tl.atomic_add(
                    final_cnt_ptrs_vec_2d,
                    ones_vec_2d,
                    mask=take_eq_final,
                    sem="relaxed",
                    scope="cta",
                )
                tl.store(
                    hist_base_ptr + final_pos,
                    offs.to(tl.int32),
                    mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS),
                )
                tl.store(
                    hist_base_ptr + (FINAL_SORT_ITEMS + final_pos),
                    x_vec.to(tl.int32, bitcast=True),
                    mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS),
                )

        for t in tl.range(0, rem_tiles):
            offs = (n_vec_full * VEC + t) * BLOCK_SIZE + lane
            x = tl.load(row_ptr + offs)
            key = _convert_to_trt_uint32(x)
            if step_idx == 0:
                digit = _convert_to_trt_uint16_hi11(x)
            elif step_idx == 1:
                digit = ((key >> 21) & RADIX11_MASK).to(tl.int32)
            elif step_idx == 2:
                digit = ((key >> 10) & RADIX11_MASK).to(tl.int32)
            else:
                digit = (key & RADIX10_MASK).to(tl.int32)

            if step_idx < 2:
                partial = tl.full([BLOCK_SIZE], True, tl.int1)
            elif step_idx == 2:
                partial = ((key ^ logit_pattern) >> 21) == 0
            else:
                partial = ((key ^ logit_pattern) >> 10) == 0

            take_lt = partial & (digit < threshold_bin_idx)
            out_pos_lt = tl.atomic_add(
                found_ptrs,
                ones,
                mask=take_lt,
                sem="relaxed",
                scope="cta",
            )
            tl.store(
                s_out_indices_ptr + out_pos_lt,
                offs.to(tl.int32),
                mask=take_lt & (out_pos_lt < TOPK),
            )

            if step_idx == 3:
                take_eq = partial & (digit == threshold_bin_idx)
                out_pos_eq = tl.atomic_add(
                    hist_base_ptr + digit,
                    ones,
                    mask=take_eq,
                    sem="relaxed",
                    scope="cta",
                )
                tl.store(
                    s_out_indices_ptr + out_pos_eq,
                    offs.to(tl.int32),
                    mask=take_eq & (out_pos_eq < TOPK),
                )
            elif use_final:
                take_eq_final = partial & (digit == threshold_bin_idx)
                final_pos = tl.atomic_add(
                    final_cnt_ptrs,
                    ones,
                    mask=take_eq_final,
                    sem="relaxed",
                    scope="cta",
                )
                tl.store(
                    hist_base_ptr + final_pos,
                    offs.to(tl.int32),
                    mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS),
                )
                tl.store(
                    hist_base_ptr + (FINAL_SORT_ITEMS + final_pos),
                    x.to(tl.int32, bitcast=True),
                    mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS),
                )
    else:
        for t in tl.range(0, n_tiles):
            offs = t * BLOCK_SIZE + lane
            in_range = (offs < seq_len) & (offs >= row_start) & (offs < row_end)
            x = tl.load(row_ptr + offs * stride_xn, mask=in_range, other=float("-inf"))
            key = _convert_to_trt_uint32(x)
            if step_idx == 0:
                digit = _convert_to_trt_uint16_hi11(x)
            elif step_idx == 1:
                digit = ((key >> 21) & RADIX11_MASK).to(tl.int32)
            elif step_idx == 2:
                digit = ((key >> 10) & RADIX11_MASK).to(tl.int32)
            else:
                digit = (key & RADIX10_MASK).to(tl.int32)

            if step_idx < 2:
                partial = in_range
            elif step_idx == 2:
                partial = in_range & (((key ^ logit_pattern) >> 21) == 0)
            else:
                partial = in_range & (((key ^ logit_pattern) >> 10) == 0)

            take_lt = partial & (digit < threshold_bin_idx)
            out_pos_lt = tl.atomic_add(
                found_ptrs,
                ones,
                mask=take_lt,
                sem="relaxed",
                scope="cta",
            )
            tl.store(
                s_out_indices_ptr + out_pos_lt,
                offs.to(tl.int32),
                mask=take_lt & (out_pos_lt < TOPK),
            )

            if step_idx == 3:
                take_eq = partial & (digit == threshold_bin_idx)
                out_pos_eq = tl.atomic_add(
                    hist_base_ptr + digit,
                    ones,
                    mask=take_eq,
                    sem="relaxed",
                    scope="cta",
                )
                tl.store(
                    s_out_indices_ptr + out_pos_eq,
                    offs.to(tl.int32),
                    mask=take_eq & (out_pos_eq < TOPK),
                )
            elif use_final:
                take_eq_final = partial & (digit == threshold_bin_idx)
                final_pos = tl.atomic_add(
                    final_cnt_ptrs,
                    ones,
                    mask=take_eq_final,
                    sem="relaxed",
                    scope="cta",
                )
                tl.store(
                    hist_base_ptr + final_pos,
                    offs.to(tl.int32),
                    mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS),
                )
                tl.store(
                    hist_base_ptr + (FINAL_SORT_ITEMS + final_pos),
                    x.to(tl.int32, bitcast=True),
                    mask=take_eq_final & (final_pos < FINAL_SORT_ITEMS),
                )

    # Steps 0-2: stop for a final sort or move to the next radix step.
    # Step 3 is the last one; the output is marked complete.
    if step_idx < 3:
        if use_final:
            need_final_sort = True
            continue_to_next_step = False
        else:
            need_final_sort = False
            continue_to_next_step = True
    else:
        tl.store(s_found_topk_values_ptr, TOPK)
        need_final_sort = False
        continue_to_next_step = False

    tl.debug_barrier()
    return continue_to_next_step, need_final_sort, logit_pattern
# Final selection over the staged threshold bin (USE_RADIX_FINAL path).
#
# Candidates staged by _tle_process_histogram_step live at hist[pos] (column
# index) and hist[FINAL_SORT_ITEMS + pos] (raw fp32 bits), pos < final_cnt.
# With remain = min(TOPK - found, final_cnt), a 4 x 8-bit MSD radix select
# finds thr_key, the (remain + 1)-th smallest descending-order key.
# Candidates with key < thr_key are appended first.  Ties (key == thr_key)
# then fill any free slots.  When remain == final_cnt no bin satisfies the
# search, every digit saturates to 0xFF and thr_key = 0xFFFFFFFF, so all
# candidates with a smaller key are taken.  s_found_topk_values is set to TOPK
# at the end.
# Assumes BLOCK_SIZE >= RADIX_SIZE_FINAL (256) so one masked store clears
# every radix counter.
@triton.jit
def _tle_final_select_radix(
    hist_base_ptr,
    s_out_indices_ptr,
    s_final_cnt_ptr,
    s_found_topk_values_ptr,
    TOPK: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    FINAL_SORT_ITEMS: tl.constexpr,
):
    RADIX_BITS_FINAL: tl.constexpr = 8
    RADIX_SIZE_FINAL: tl.constexpr = 1 << RADIX_BITS_FINAL
    RADIX_MASK_FINAL: tl.constexpr = RADIX_SIZE_FINAL - 1
    DIGIT_START: tl.constexpr = 32 - RADIX_BITS_FINAL

    lane = tl.arange(0, BLOCK_SIZE)
    ones = tl.full([BLOCK_SIZE], 1, tl.int32)
    zeros = tl.zeros([BLOCK_SIZE], dtype=tl.int32)
    bins = tl.arange(0, RADIX_SIZE_FINAL)

    # 256 shared-memory counters for the per-digit histogram.
    s_radix_counts = tle.gpu.alloc(
        [RADIX_SIZE_FINAL],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    radix_count_ptr = tle.gpu.local_ptr(s_radix_counts, (0,))
    radix_count_vec_ptr = tle.gpu.local_ptr(s_radix_counts, (bins,))

    base_idx = tl.load(s_found_topk_values_ptr)
    final_cnt = tl.minimum(tl.load(s_final_cnt_ptr), FINAL_SORT_ITEMS)
    remain = tl.minimum(TOPK - base_idx, final_cnt)
    if remain > 0:
        desired = tl.zeros((), dtype=tl.uint32)  # key bits fixed so far
        desired_mask = tl.zeros((), dtype=tl.uint32)  # which bits are fixed
        k_to_find = remain + 1  # 1-based rank of the key being located

        # Digits at bit offsets 24, 16, 8, 0 (most significant first).
        for digit_pos in tl.static_range(DIGIT_START, -1, -RADIX_BITS_FINAL):
            tl.store(radix_count_ptr + lane, 0, mask=lane < RADIX_SIZE_FINAL)
            tl.debug_barrier()

            # Histogram of this digit over candidates matching the prefix.
            cnt_tiles = tl.cdiv(final_cnt, BLOCK_SIZE)
            for t in tl.range(0, cnt_tiles):
                pos = t * BLOCK_SIZE + lane
                valid = pos < final_cnt
                x_bits_i32 = tl.load(
                    hist_base_ptr + (FINAL_SORT_ITEMS + pos),
                    mask=valid,
                    other=0,
                )
                x = x_bits_i32.to(tl.float32, bitcast=True)
                key = _convert_to_trt_uint32(x)
                matches = (key & desired_mask) == desired
                digit = ((key >> digit_pos) & RADIX_MASK_FINAL).to(tl.int32)
                take = valid & matches
                tl.atomic_add(
                    radix_count_ptr + digit,
                    ones,
                    mask=take,
                    sem="relaxed",
                    scope="cta",
                )

            tl.debug_barrier()
            # Bin containing the k_to_find-th key; the lowest such bin wins,
            # defaulting to the last bin when none qualifies.
            counts = tl.load(radix_count_vec_ptr)
            prefix_sum, _ = tle.cumsum(counts, axis=0, reverse=False)
            next_prefix_sum = prefix_sum + counts
            threshold_mask = (prefix_sum < k_to_find) & (next_prefix_sum >= k_to_find)
            threshold_init = tl.full((), RADIX_SIZE_FINAL, dtype=tl.int32)
            threshold_bin = tl.min(
                tl.where(threshold_mask, bins, threshold_init), axis=0
            ).to(tl.int32)
            threshold_bin = tl.where(
                threshold_bin == RADIX_SIZE_FINAL,
                RADIX_SIZE_FINAL - 1,
                threshold_bin,
            )
            # Number of matching keys in bins below the chosen one.
            counts_lt = tl.max(
                tl.where(bins == threshold_bin, prefix_sum, 0),
                axis=0,
            ).to(tl.int32)

            desired = desired | (threshold_bin.to(tl.uint32) << digit_pos)
            desired_mask = desired_mask | (
                tl.full((), RADIX_MASK_FINAL, dtype=tl.uint32) << digit_pos
            )
            k_to_find = k_to_find - counts_lt

        # Append every candidate strictly better (smaller key) than thr_key.
        thr_key = desired
        found_ptrs = s_found_topk_values_ptr + zeros
        cnt_tiles = tl.cdiv(final_cnt, BLOCK_SIZE)
        for t in tl.range(0, cnt_tiles):
            pos = t * BLOCK_SIZE + lane
            valid = pos < final_cnt
            idx = tl.load(hist_base_ptr + pos, mask=valid, other=0)
            x_bits_i32 = tl.load(
                hist_base_ptr + (FINAL_SORT_ITEMS + pos),
                mask=valid,
                other=0,
            )
            x = x_bits_i32.to(tl.float32, bitcast=True)
            key = _convert_to_trt_uint32(x)
            take_lt = valid & (key < thr_key)
            out_pos_gt = tl.atomic_add(
                found_ptrs,
                ones,
                mask=take_lt,
                sem="relaxed",
                scope="cta",
            )
            tl.store(
                s_out_indices_ptr + out_pos_gt,
                idx,
                mask=take_lt & (out_pos_gt < TOPK),
            )

        # Fill any remaining slots with ties on thr_key.
        # NOTE(review): there is no tl.debug_barrier() between the atomics
        # above and these loads of the found counter; verify every thread
        # observes the same `cur`.
        cur = tl.load(s_found_topk_values_ptr)
        if cur < TOPK:
            for t in tl.range(0, cnt_tiles):
                cur = tl.load(s_found_topk_values_ptr)
                if cur < TOPK:
                    pos = t * BLOCK_SIZE + lane
                    valid = pos < final_cnt
                    idx = tl.load(hist_base_ptr + pos, mask=valid, other=0)
                    x_bits_i32 = tl.load(
                        hist_base_ptr + (FINAL_SORT_ITEMS + pos),
                        mask=valid,
                        other=0,
                    )
                    x = x_bits_i32.to(tl.float32, bitcast=True)
                    key = _convert_to_trt_uint32(x)
                    take_eq = valid & (key == thr_key)
                    out_pos_eq = tl.atomic_add(
                        found_ptrs,
                        ones,
                        mask=take_eq,
                        sem="relaxed",
                        scope="cta",
                    )
                    tl.store(
                        s_out_indices_ptr + out_pos_eq,
                        idx,
                        mask=take_eq & (out_pos_eq < TOPK),
                    )

    tl.store(s_found_topk_values_ptr, TOPK)
# TLE top-k kernel: one program per row (tle_bucket_sort_topk launches
# grid=(batch,)).
#
# Rows with at most K valid positions are answered directly: all positions,
# then -1 padding.  Otherwise up to four radix steps
# (_tle_process_histogram_step) run until the threshold bin is small enough
# for a final sort.  The final sort is either the radix select
# (_tle_final_select_radix, when USE_RADIX_FINAL) or an O(n^2)
# rank-by-comparison over the staged candidates.  Results are assembled in
# shared memory (s_out_indices) and flushed to `out`.
# When both tensors have unit inner stride, the row covers [0, seq_len) and
# seq_len is a multiple of BLOCK_SIZE, the steps use the unmasked path.
@triton.jit
def kernel_tle_bucket_sort_topk(
    x_ptr,
    out_ptr,
    starts_ptr,
    ends_ptr,
    stride_xm,
    stride_xn,
    stride_outm,
    stride_outn,
    seq_len,
    K: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    USE_RADIX_FINAL: tl.constexpr,
):
    pid = tl.program_id(0)
    row_start = tl.load(starts_ptr + pid).to(tl.int32)
    row_end = tl.load(ends_ptr + pid).to(tl.int32)

    row_ptr = x_ptr + pid * stride_xm
    out_row = out_ptr + pid * stride_outm
    row_len = row_end - row_start

    auto_aligned = (
        (stride_xn == 1)
        & (stride_outn == 1)
        & (row_start == 0)
        & (row_end == seq_len)
        & (seq_len % BLOCK_SIZE == 0)
    )
    assume_aligned = auto_aligned
    if assume_aligned:
        seq_len = tl.multiple_of(seq_len, BLOCK_SIZE)

    lane = tl.arange(0, BLOCK_SIZE)
    # Short row: every valid position is in the top-k; pad the rest with -1.
    if row_len <= K:
        chunks: tl.constexpr = (K + BLOCK_SIZE - 1) // BLOCK_SIZE
        for chunk_idx in tl.range(0, chunks):
            pos = chunk_idx * BLOCK_SIZE + lane
            take_row = pos < row_len
            tl.store(
                out_row + pos * stride_outn,
                (row_start + pos).to(tl.int32),
                mask=take_row,
            )
            take_pad = (pos >= row_len) & (pos < K)
            tl.store(out_row + pos * stride_outn, -1, mask=take_pad)
        return

    FINAL_SORT_ITEMS: tl.constexpr = 2048
    # Histogram bins; also reused as the final-sort staging area
    # ([0, FINAL_SORT_ITEMS) indices, [FINAL_SORT_ITEMS, HIST_SIZE) fp32 bits).
    HIST_SIZE: tl.constexpr = 4096

    s_histogram = tle.gpu.alloc(
        [HIST_SIZE],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    hist_base_ptr = tle.gpu.local_ptr(s_histogram, (0,))
    # Output indices of this row, flushed to global memory at the end.
    s_out_indices = tle.gpu.alloc(
        [K],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    # Scalar shared-memory state shared by all steps.
    s_final_cnt = tle.gpu.alloc(
        [1],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    s_threshold_bin_idx = tle.gpu.alloc(
        [1],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    s_final_bin_size = tle.gpu.alloc(
        [1],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    s_found_topk_values = tle.gpu.alloc(
        [1],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    # Threshold bin chosen by each of the four radix steps.
    s_step_thresholds = tle.gpu.alloc(
        [4],
        dtype=tl.int32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    s_final_cnt_ptr = tle.gpu.local_ptr(s_final_cnt, (0,))
    s_threshold_bin_idx_ptr = tle.gpu.local_ptr(s_threshold_bin_idx, (0,))
    s_final_bin_size_ptr = tle.gpu.local_ptr(s_final_bin_size, (0,))
    s_found_topk_values_ptr = tle.gpu.local_ptr(s_found_topk_values, (0,))
    s_step_thresholds_ptr = tle.gpu.local_ptr(s_step_thresholds, (0,))
    s_out_indices_ptr = tle.gpu.local_ptr(s_out_indices, (0,))
    tl.store(s_final_cnt_ptr, 0)
    tl.store(s_threshold_bin_idx_ptr, -1)
    tl.store(s_final_bin_size_ptr, 0)
    tl.store(s_found_topk_values_ptr, 0)

    logit_pattern = tl.zeros((), dtype=tl.uint32)
    continue_to_next_step = True
    need_final_sort = False
    init_chunks: tl.constexpr = (K + BLOCK_SIZE - 1) // BLOCK_SIZE
    for init_idx in tl.range(0, init_chunks):
        pos = init_idx * BLOCK_SIZE + lane
        tl.store(tle.gpu.local_ptr(s_out_indices, (pos,)), -1, mask=pos < K)

    # Radix steps 0..3; each step decides whether the next one is needed.
    for step_idx in tl.static_range(0, 4):
        if continue_to_next_step:
            found_topk_values = tl.load(s_found_topk_values_ptr)
            (
                continue_to_next_step,
                step_need_final_sort,
                logit_pattern,
            ) = _tle_process_histogram_step(
                row_ptr,
                stride_xn,
                row_start,
                row_end,
                seq_len,
                step_idx,
                logit_pattern,
                s_step_thresholds_ptr,
                found_topk_values,
                hist_base_ptr,
                s_out_indices_ptr,
                s_final_cnt_ptr,
                s_found_topk_values_ptr,
                s_threshold_bin_idx_ptr,
                s_final_bin_size_ptr,
                assume_aligned,
                TOPK=K,
                BLOCK_SIZE=BLOCK_SIZE,
            )
            need_final_sort = need_final_sort | step_need_final_sort

    if need_final_sort:
        if USE_RADIX_FINAL:
            _tle_final_select_radix(
                hist_base_ptr,
                s_out_indices_ptr,
                s_final_cnt_ptr,
                s_found_topk_values_ptr,
                TOPK=K,
                BLOCK_SIZE=BLOCK_SIZE,
                FINAL_SORT_ITEMS=FINAL_SORT_ITEMS,
            )
        else:
            # Rank-by-comparison: candidate i's rank counts the candidates j
            # with a larger value, or an equal value at a larger staging
            # position; ranks are unique, so each lands in its own slot.
            base_idx = tl.load(s_found_topk_values_ptr)
            final_cnt = tl.minimum(tl.load(s_final_cnt_ptr), FINAL_SORT_ITEMS)
            sort_chunks = tl.cdiv(final_cnt, BLOCK_SIZE)
            for sort_chunk in tl.range(0, sort_chunks):
                pos = sort_chunk * BLOCK_SIZE + lane
                valid = pos < final_cnt
                logit_i_bits = tl.load(
                    tle.gpu.local_ptr(s_histogram, (FINAL_SORT_ITEMS + pos,)),
                    mask=valid,
                    other=0,
                )
                logit_i = logit_i_bits.to(tl.float32, bitcast=True)
                out_rank = tl.zeros([BLOCK_SIZE], dtype=tl.int32)
                for j in tl.range(0, final_cnt):
                    logit_j_bits = tl.load(
                        tle.gpu.local_ptr(s_histogram, (FINAL_SORT_ITEMS + j,))
                    )
                    logit_j = logit_j_bits.to(tl.float32, bitcast=True)
                    better = (logit_i < logit_j) | ((logit_i == logit_j) & (pos < j))
                    out_rank = out_rank + (valid & better).to(tl.int32)
                dst_pos = base_idx + out_rank
                take = valid & (dst_pos < K)
                idx_i = tl.load(
                    tle.gpu.local_ptr(s_histogram, (pos,)),
                    mask=take,
                    other=0,
                )
                tl.store(tle.gpu.local_ptr(s_out_indices, (dst_pos,)), idx_i, mask=take)
            tl.store(s_found_topk_values_ptr, K)

    # Copy the assembled indices to the output row.
    # NOTE(review): no explicit tl.debug_barrier() separates the final-sort
    # writes above from these reads of s_out_indices by other lanes; confirm
    # the compiler inserts one for TLE shared-memory pointers.
    flush_chunks: tl.constexpr = (K + BLOCK_SIZE - 1) // BLOCK_SIZE
    for flush_chunk in tl.static_range(flush_chunks):
        pos = flush_chunk * BLOCK_SIZE + lane
        mask = pos < K
        out_vals = tl.load(
            tle.gpu.local_ptr(s_out_indices, (pos,)), mask=mask, other=-1
        )
        tl.store(out_row + pos * stride_outn, out_vals, mask=mask)
def tle_bucket_sort_topk(
    inputs,
    starts,
    ends,
    topk,
):
    """Per-row top-k column indices computed with the TLE kernel.

    Args:
        inputs: 2-D tensor ``(batch, seq_len)``; converted to float32 if needed.
        starts: 1-D tensor with at least ``batch`` entries; first valid column
            of each row.
        ends: 1-D tensor with at least ``batch`` entries; exclusive end of the
            valid columns of each row.
        topk: number of indices to return per row.

    Returns:
        ``int32`` tensor ``(batch, topk)``; unfilled slots are ``-1``.

    Raises:
        RuntimeError: the TLE extension is unavailable.
        ValueError: inputs/starts/ends have the wrong rank, or starts/ends
            have fewer entries than ``inputs`` has rows.
    """
    if not HAS_TLE:
        raise RuntimeError(
            "TLE is unavailable. bucket_sort_topk TLE kernel requires Triton >= 3.6 with triton.experimental.tle."
        )
    if inputs.ndim != 2:
        raise ValueError("inputs must be a 2D tensor")
    if starts.ndim != 1 or ends.ndim != 1:
        raise ValueError("starts and ends must be 1D tensors")

    x = inputs.float() if inputs.dtype != torch.float32 else inputs
    batch, seq_len = x.shape
    # The kernel reads starts/ends at ``ptr + row`` for every row: they must
    # cover all rows and be unit-stride (no-op when already contiguous).
    if starts.shape[0] < batch or ends.shape[0] < batch:
        raise ValueError("starts and ends must have at least one entry per row of inputs")
    starts = starts.contiguous()
    ends = ends.contiguous()
    out = torch.full((batch, topk), -1, dtype=torch.int32, device=x.device)
    use_radix_final = seq_len >= TLE_RADIX_FINAL_SEQ_LEN_THRESHOLD

    grid = (batch,)
    kernel_tle_bucket_sort_topk[grid](
        x,
        out,
        starts,
        ends,
        x.stride(0),
        x.stride(1),
        out.stride(0),
        out.stride(1),
        seq_len,
        K=topk,
        BLOCK_SIZE=TLE_FIXED_BLOCK_SIZE,
        USE_RADIX_FINAL=use_radix_final,
        num_warps=TLE_FIXED_NUM_WARPS,
        num_stages=TLE_FIXED_NUM_STAGES,
    )
    return out
def _should_use_tle_bucket_sort_topk(inputs, topk):
    """Return whether the TLE kernel should be attempted for ``inputs``.

    Requires the TLE extension (``HAS_TLE``) and a CUDA ``torch.Tensor``.
    ``topk`` is part of the signature but is not consulted.
    """
    return (
        HAS_TLE
        and isinstance(inputs, torch.Tensor)
        and inputs.device.type == "cuda"
    )
def bucket_sort_topk(inputs, starts, ends, topk):
    """Per-row top-k column indices, preferring the TLE kernel.

    Tries ``tle_bucket_sort_topk`` when ``_should_use_tle_bucket_sort_topk``
    allows it.  Any exception from that path triggers a fallback to the legacy
    ``bucket_sort_topk_triton`` kernel; the fallback is reported through a
    ``RuntimeWarning`` instead of being silent, so a broken fast path shows up
    in logs.

    Returns:
        ``int32`` tensor ``(rows, topk)`` of column indices (``-1`` = unfilled).
    """
    if _should_use_tle_bucket_sort_topk(inputs, topk):
        try:
            return tle_bucket_sort_topk(inputs, starts, ends, topk)
        except Exception as exc:  # deliberate best-effort: fall back below
            import warnings

            warnings.warn(
                "TLE bucket_sort_topk failed "
                f"({type(exc).__name__}: {exc}); "
                "falling back to the legacy Triton kernel.",
                RuntimeWarning,
                stacklevel=2,
            )
    return bucket_sort_topk_triton(inputs, starts, ends, topk)