Coverage for src/flag_gems/runtime/backend/_sunrise/ops/topk.py: 0%
323 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
7import triton.language.core as core
9try:
10 # TODO: Triton 2.1 does not implement _log2.
11 # Remove the try-catch block once all vendors upgrade to a newer version of Triton.
12 from triton.language.standard import _log2
13except ImportError:
14 pass
16from flag_gems.runtime import torch_device_fn
17from flag_gems.utils import libentry
18from flag_gems.utils import triton_lang_extension as ext
19from flag_gems.utils.limits import get_dtype_max, get_dtype_min
20from flag_gems.utils.triton_version_utils import HAS_TLE
22if HAS_TLE:
23 import triton.experimental.tle.language as tle_gpu
24else:
25 tle_gpu = None
# Module-level logger; topk() emits a debug record through it on every call.
logger = logging.getLogger(__name__)

# Finite limits of the floating-point dtypes, wrapped as tl.constexpr.
# _get_finfo_val returns these values to the kernels, which use them as
# padding / sentinel values.
_MIN_FLOAT32_VAL = tl.constexpr(torch.finfo(torch.float32).min)
_MAX_FLOAT32_VAL = tl.constexpr(torch.finfo(torch.float32).max)
_MIN_FLOAT16_VAL = tl.constexpr(torch.finfo(torch.float16).min)
_MAX_FLOAT16_VAL = tl.constexpr(torch.finfo(torch.float16).max)
_MIN_BFLOAT16_VAL = tl.constexpr(torch.finfo(torch.bfloat16).min)
_MAX_BFLOAT16_VAL = tl.constexpr(torch.finfo(torch.bfloat16).max)

# Limits of the signed integer dtypes, wrapped the same way. In the visible
# part of this file only the int32 pair is referenced (topk_stage2_kernel pads
# its index vector with it); the others are currently unused here.
_MIN_INT8_VAL = tl.constexpr(torch.iinfo(torch.int8).min)
_MAX_INT8_VAL = tl.constexpr(torch.iinfo(torch.int8).max)
_MIN_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).min)
_MAX_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).max)
_MIN_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).min)
_MAX_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).max)
_MIN_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).min)
_MAX_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).max)
@triton.jit
def _get_finfo_val(
    dtype,
    return_max,
):
    """
    Return the largest (return_max truthy) or smallest finite value of a
    floating-point dtype.

    Only tl.float32, tl.float16 and tl.bfloat16 are recognised; for any other
    dtype no branch matches and the function falls through without a value.
    """
    if dtype is tl.float32:
        return _MAX_FLOAT32_VAL if return_max else _MIN_FLOAT32_VAL
    elif dtype is tl.float16:
        return _MAX_FLOAT16_VAL if return_max else _MIN_FLOAT16_VAL
    elif dtype is tl.bfloat16:
        return _MAX_BFLOAT16_VAL if return_max else _MIN_BFLOAT16_VAL
@triton.jit
def _get_iinfo_val(
    dtype,
    return_max,
):
    """
    Return get_dtype_max(dtype) when return_max is truthy, otherwise
    get_dtype_min(dtype).
    """
    return get_dtype_max(dtype) if return_max else get_dtype_min(dtype)
@libentry()
@triton.jit
def topk_stage1_kernel(
    y_ptr,
    index_ptr,
    x_ptr,
    k,
    N: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    """
    Stage 1 of top-k: every program scans one CHUNK_SIZE-wide chunk of one row
    and writes that chunk's k best values and their row-relative positions.

    The launch grid is (rows, chunks per row). Program (row, chunk) writes its
    k results at offset ``row * chunks_per_row * k + chunk * k`` of both
    y_ptr and index_ptr. The row is read starting at ``row * N``, i.e. rows
    are addressed as if stored back to back with unit element stride.

    Selection is done by k rounds of max/argmax (min/argmin when DESCENDING is
    false); after each round the winning lane is overwritten with the float32
    minimum (maximum) so that it is not picked again.

    NOTE(review): because a picked lane is retired by value rather than by a
    separate mask, inputs that already hold values at or beyond that sentinel
    (e.g. -inf when DESCENDING) may be selected incorrectly - confirm.
    """
    row = ext.program_id(0)
    chunk_id = ext.program_id(1)
    chunks_per_row = ext.num_programs(1)

    # Both outputs share the same layout, so compute the slot offset once.
    out_offset = row * chunks_per_row * k + chunk_id * k
    y_ptr += out_offset
    index_ptr += out_offset

    chunk_start = chunk_id * CHUNK_SIZE
    x_ptr += row * N + chunk_start

    lanes = tl.arange(0, CHUNK_SIZE)
    in_bounds = (chunk_start + lanes) < N

    # Out-of-range lanes are filled with the worst value of the input dtype.
    pad_val = _get_finfo_val(x_ptr.dtype.element_ty, return_max=not DESCENDING)
    vals = tl.load(x_ptr + lanes, mask=in_bounds, other=pad_val).to(tl.float32)

    # Value written over a lane once it has been selected.
    retired_val = _get_finfo_val(tl.float32, return_max=not DESCENDING)
    for slot in range(k):
        if DESCENDING:
            best_val = tl.max(vals)
            best_idx = tl.argmax(vals, axis=0)
        else:
            best_val = tl.min(vals)
            best_idx = tl.argmin(vals, axis=0)

        tl.store(y_ptr + slot, best_val)
        # Convert the chunk-local lane into a position within the row.
        tl.store(index_ptr + slot, best_idx + chunk_start)

        vals = tl.where(lanes == best_idx, retired_val, vals)
"""
Note(Zhengzekang):
Adapted from the official Triton 2.2 `sort` implementation:
https://github.com/triton-lang/triton/blob/release/2.2.x/python/triton/language/standard.py#L392-L404
The only change is that an index tensor is carried along and permuted together with the values.
"""
@triton.jit
def _compare_and_swap(x, ids, flip, i: core.constexpr, n_dims: core.constexpr):
    """
    Perform one compare-and-swap round of the bitonic network on `x`, applying
    the same exchanges to `ids`.

    `x` and `ids` are viewed as [n_outer * 2**i, 2, 2**(n_dims - i - 1)]; the
    two entries along the middle axis are the partners of each pair. `flip`
    is either a scalar or a tensor shaped like `x` and selects, per element,
    the direction in which the pair is ordered. Returns the updated (x, ids).

    NOTE(review): partners are extracted by multiplying with a 0/1 mask and
    summing, so non-finite entries in `x` would presumably turn into NaN here
    (inf * 0) - confirm that callers only pass finite values.
    """
    n_outer: core.constexpr = x.numel >> n_dims
    shape: core.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)]

    # tl.device_print("shape is: ", shape)
    y = core.reshape(x, shape)
    y_idx = core.reshape(ids, shape)

    # slice left/right with 'stride' 2**(n_dims - i - 1)
    # mask is 0 for the first element of each pair and 1 for the second, so
    # multiplying by (1 - mask) / mask and summing over the pair axis keeps
    # exactly one of the two partners.
    mask = core.arange(0, 2)[None, :, None]
    left = core.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape).to(x.dtype)
    right = core.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape).to(x.dtype)
    left = core.reshape(left, x.shape)
    right = core.reshape(right, x.shape)

    # Same extraction for the index tensor.
    left_idx = core.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape).to(
        ids.dtype
    )
    right_idx = core.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape).to(
        ids.dtype
    )
    left_idx = core.reshape(left_idx, ids.shape)
    right_idx = core.reshape(right_idx, ids.shape)

    # actual compare-and-swap
    # is_right indicator: 0 for left, 1 for right element in each pair.
    is_right = core.reshape(
        core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape
    )

    # Paired value: for left (is_right=0), the paired is right;
    # for right (is_right=1), the paired is left.
    paired_val = core.where(is_right, left, right)
    paired_idx = core.where(is_right, left_idx, right_idx)

    # Conditional swap following the official Triton pattern:
    # swap if (current > paired) differs from (flip ^ is_right).
    flip_right = (flip ^ is_right) != 0
    cond = (x > paired_val) != flip_right
    # Where cond holds the element takes its partner's value and index.
    x = core.where(cond, paired_val, x)
    ids = core.where(cond, paired_idx, ids)

    return x, ids
@triton.jit
def _bitonic_merge(
    x, ids, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr
):
    """
    Run the `stage` compare-and-swap rounds of one bitonic merge step on `x`,
    permuting `ids` in exactly the same way.

    order 0 == ascending
    order 1 == descending
    order 2 == alternating
    """
    core.static_assert(stage <= n_dims)
    n_outer: core.constexpr = x.numel >> n_dims
    # `flip` tells every compare-and-swap in which direction to order a pair:
    #   flip = 0000... -> every sub-sequence is arranged the same way
    #   flip = 0011... -> neighbouring sub-sequences alternate direction
    if order != 2:
        flip = order
    else:
        flip_shape: core.constexpr = [
            n_outer * 2 ** (n_dims - 1 - stage),
            2,
            2**stage,
        ]
        flip = core.reshape(
            core.broadcast_to(core.arange(0, 2)[None, :, None], flip_shape), x.shape
        )
    # The rounds of this step are numbered n_dims - stage ... n_dims - 1.
    first_round: core.constexpr = n_dims - stage
    for offset in core.static_range(stage):
        x, ids = _compare_and_swap(x, ids, flip, offset + first_round, n_dims)
    return x, ids
@triton.jit
def argsort(x, ids, dim: tl.constexpr, descending: core.constexpr):
    """
    Bitonic-sort `x` and apply the same permutation to `ids`.

    The number of merge steps is _log2(x.shape[dim]). Every step except the
    last uses the alternating order (2); the last step uses `descending`,
    which fixes the final direction. Returns (sorted x, permuted ids).
    """
    n_dims: core.constexpr = _log2(x.shape[dim])
    for stage in core.static_range(1, n_dims + 1):
        if stage < n_dims:
            # Intermediate steps build alternating runs for the next merge.
            x, ids = _bitonic_merge(x, ids, stage, 2, n_dims)
        else:
            x, ids = _bitonic_merge(x, ids, stage, descending, n_dims)
    return x, ids
@libentry()
@triton.jit
def topk_stage2_kernel(
    y_ptr,
    index_ptr,
    chunk_x,
    chunk_index,
    sort_dim: tl.constexpr,
    k: tl.constexpr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    """
    Stage 2 of top-k: one program per row sorts that row's N candidates
    (values from chunk_x, positions from chunk_index) and stores the first k.

    Candidates are read from offset ``row * N`` and results are written at
    offset ``row * k``. Lanes at or beyond N are padded with the worst value
    of the candidate dtype and with the int32 minimum (maximum when
    DESCENDING is false) as index. Values are sorted as float32 and indices
    as int32. `sort_dim` is accepted but not used by the body.
    """
    row = ext.program_id(0)
    in_row_offset = row * N
    chunk_x += in_row_offset
    chunk_index += in_row_offset
    out_row_offset = row * k
    y_ptr += out_row_offset
    index_ptr += out_row_offset

    lanes = tl.arange(0, BLOCK_SIZE)
    valid = lanes < N

    pad_val = _get_finfo_val(chunk_x.dtype.element_ty, return_max=not DESCENDING)
    if DESCENDING:
        pad_idx = _MIN_INT32_VAL
    else:
        pad_idx = _MAX_INT32_VAL

    vals = tl.load(chunk_x + lanes, mask=valid, other=pad_val).to(tl.float32)
    idxs = tl.load(chunk_index + lanes, mask=valid, other=pad_idx).to(tl.int32)

    sorted_vals, sorted_idxs = argsort(vals, idxs, 0, descending=DESCENDING)

    # Only the leading k entries of the sorted row are part of the result.
    keep = lanes < k
    tl.store(y_ptr + lanes, sorted_vals, mask=keep)
    tl.store(index_ptr + lanes, sorted_idxs, mask=keep)
# The radix-select kernel and its helpers use the tle_gpu API, so they are
# only defined when the TLE extension is available.
if HAS_TLE:

    @triton.jit
    def _get_topmask_and_fullmask(x):
        """
        Return two tensors shaped and typed like `x` (which must be an
        unsigned integer tensor): one with only the top bit set and one with
        every bit set.
        """
        tl.static_assert(
            x.dtype.is_int_unsigned(),
            "floating-point value must be passed as bits",
        )
        tm: tl.constexpr = 1 << (-1 + x.dtype.primitive_bitwidth)
        fm: tl.constexpr = (1 << x.dtype.primitive_bitwidth) - 1
        tm_arr = tl.full(x.shape, tm, dtype=x.dtype)
        fm_arr = tl.full(x.shape, fm, dtype=x.dtype)
        return tm_arr, fm_arr

    @triton.jit
    def _fpval_to_key_with_nan(x, x_bits):
        """
        Turn the raw bits of `x` into an unsigned key: bits with the top bit
        set are XOR-ed with all ones, the others with the top bit only.
        Positions where ``x != x`` (NaN) receive the all-ones key, i.e. the
        largest possible key.
        """
        tm, fm = _get_topmask_and_fullmask(x_bits)
        mask = tl.where((x_bits & tm) != 0, fm, tm)
        key = x_bits ^ mask
        return tl.where(x == x, key, fm)

    @triton.jit
    def _key_to_fpval(x):
        """
        Undo the XOR applied by _fpval_to_key_with_nan: keys with the top bit
        set are XOR-ed with the top bit, the others with all ones. The result
        is still an unsigned integer tensor holding raw value bits.
        """
        tm, fm = _get_topmask_and_fullmask(x)
        mask = tl.where((x & tm) != 0, tm, fm)
        return x ^ mask

    @libentry()
    @triton.jit
    def topk_kernel_radix_tle(
        X,
        Yv,
        Yi,
        stride_xm,
        stride_ym,
        n_cols,
        K: tl.constexpr,
        K_PAD: tl.constexpr,
        BLOCK_N: tl.constexpr,
        RADIX_BITS: tl.constexpr,
    ):
        """
        Radix-select top-K for one row per program, largest first.

        Row `pid` of X is read from ``X + pid * stride_xm`` with consecutive
        column offsets; K values go to ``Yv + pid * stride_ym`` and K column
        indices to ``Yi + pid * stride_ym``.

        Steps:
          1. Walk the key digits (RADIX_BITS wide) from the most significant
             one down, each time histogramming the digit over the elements
             that match the prefix found so far, to determine the key
             `thr_key` of the K-th largest element.
          2. Collect every element with key > thr_key, then elements with
             key == thr_key, into a K_PAD-entry buffer (writes beyond K_PAD
             entries are masked off).
          3. Sort the buffer in descending order and store the first K
             entries.

        Each buffer entry packs ``key << 16 | (n_cols - column)``, so the
        column information has to fit into 16 bits.
        """
        pid = tl.program_id(0)
        x_dtype = X.dtype.element_ty
        x_nbits: tl.constexpr = x_dtype.primitive_bitwidth
        # Width of the packed (key, index) word: twice the key width, but at
        # least 32 bits.
        if x_nbits < 16:
            y_nbits: tl.constexpr = 32
        else:
            y_nbits: tl.constexpr = x_nbits * 2
        x_utype = tl.dtype(f"uint{x_nbits}")
        x_ultype = tl.dtype(f"uint{y_nbits}")

        RADIX_SIZE: tl.constexpr = 1 << RADIX_BITS
        RADIX_MASK: tl.constexpr = RADIX_SIZE - 1
        bins = tl.arange(0, RADIX_SIZE)
        one = tl.full([BLOCK_N], 1, tl.int32)

        # desired / desired_mask: key prefix determined so far and the bits
        # of the key it covers. k_to_find: rank still to be located among the
        # elements that match the prefix.
        desired = tl.full((), 0, dtype=x_utype)
        desired_mask = tl.full((), 0, dtype=x_utype)
        k_to_find = tl.full((), K, dtype=tl.int32)
        n_tiles = tl.cdiv(n_cols, BLOCK_N)

        # Per-digit histogram with RADIX_SIZE counters.
        smem_counts = tle_gpu.gpu.alloc(
            [RADIX_SIZE],
            dtype=tl.int32,
            layout=None,
            scope=tle_gpu.gpu.smem,
            nv_mma_shared_layout=False,
        )
        smem_count_ptrs = tle_gpu.gpu.local_ptr(smem_counts, (bins,))

        for digit_pos in tl.static_range(x_nbits - RADIX_BITS, -1, -RADIX_BITS):
            # Reset the histogram, then count the digit at digit_pos for
            # every in-range element whose key matches the current prefix.
            tl.store(smem_count_ptrs, tl.zeros([RADIX_SIZE], dtype=tl.int32))
            for t in tl.range(0, n_tiles):
                offs_n = t * BLOCK_N + tl.arange(0, BLOCK_N)
                mask_n = offs_n < n_cols
                x_ptrs = X + pid * stride_xm + offs_n
                x = tl.load(x_ptrs, mask=mask_n, other=float("-inf"))
                x_bits = x.to(x_utype, bitcast=True)
                x_key = _fpval_to_key_with_nan(x, x_bits)
                matches = (x_key & desired_mask) == desired
                digit = ((x_key >> digit_pos) & RADIX_MASK).to(tl.int32)
                valid = mask_n & matches
                count_addrs = tle_gpu.gpu.local_ptr(smem_counts, (digit,))
                tl.atomic_add(count_addrs, one, mask=valid, sem="relaxed", scope="cta")

            counts = tl.load(smem_count_ptrs)

            # After the reverse cumsum, entry d holds the number of matching
            # elements whose digit is >= d.
            cumsum_desc = tl.cumsum(counts, axis=0, reverse=True)
            tl.store(smem_count_ptrs, cumsum_desc)

            # Scan the digits from the largest down and pick the first d with
            # cum_d >= k_to_find > cum_next.
            selected_scalar = 0
            counts_gt_scalar = 0
            found = 0
            for rev in tl.static_range(RADIX_SIZE):
                d = RADIX_SIZE - 1 - rev
                cum_d = tl.load(tle_gpu.gpu.local_ptr(smem_counts, (d,)))
                if d + 1 < RADIX_SIZE:
                    cum_next = tl.load(tle_gpu.gpu.local_ptr(smem_counts, (d + 1,)))
                else:
                    cum_next = 0
                take = (found == 0) & (cum_d >= k_to_find) & (cum_next < k_to_find)
                selected_scalar = tl.where(take, d, selected_scalar)
                counts_gt_scalar = tl.where(take, cum_next, counts_gt_scalar)
                found = tl.where(take, 1, found)

            # Extend the prefix by the selected digit and discount the
            # elements with a strictly larger digit from the remaining rank.
            selected_u = selected_scalar.to(x_utype)
            desired = desired | (selected_u << digit_pos)
            desired_mask = desired_mask | (
                tl.full((), RADIX_MASK, dtype=x_utype) << digit_pos
            )
            k_to_find = k_to_find - counts_gt_scalar

        # All digits are fixed now: `desired` is the complete threshold key.
        thr_key = desired

        # Packed representation of -inf with a zero index field; used to
        # pre-fill the selection buffer.
        min_val = tl.full((), float("-inf"), tl.float32).to(x_dtype)
        min_bits = min_val.to(x_utype, bitcast=True)
        min_key = _fpval_to_key_with_nan(min_val, min_bits)
        min_packed = min_key.to(x_ultype) << 16
        offs_k = tl.arange(0, K_PAD)

        # Buffer receiving the selected (key, index) words.
        smem_selected = tle_gpu.gpu.alloc(
            [K_PAD],
            dtype=x_ultype,
            layout=None,
            scope=tle_gpu.gpu.smem,
            nv_mma_shared_layout=False,
        )
        smem_selected_ptrs = tle_gpu.gpu.local_ptr(smem_selected, (offs_k,))
        tl.store(smem_selected_ptrs, tl.full([K_PAD], min_packed, dtype=x_ultype))

        # Single counter handing out write positions in smem_selected; every
        # lane of write_count_ptrs points at element 0.
        smem_write_count = tle_gpu.gpu.alloc(
            [1],
            dtype=tl.int32,
            layout=None,
            scope=tle_gpu.gpu.smem,
            nv_mma_shared_layout=False,
        )
        tl.store(tle_gpu.gpu.local_ptr(smem_write_count, (0,)), 0)
        write_count_ptrs = tle_gpu.gpu.local_ptr(
            smem_write_count, (tl.zeros([BLOCK_N], dtype=tl.int32),)
        )

        # Pass 1: append every element whose key is strictly above the
        # threshold.
        for t in tl.range(0, n_tiles):
            offs_n = t * BLOCK_N + tl.arange(0, BLOCK_N)
            mask_n = offs_n < n_cols
            x_ptrs = X + pid * stride_xm + offs_n
            x = tl.load(x_ptrs, mask=mask_n, other=float("-inf"))
            x_bits = x.to(x_utype, bitcast=True)
            x_key = _fpval_to_key_with_nan(x, x_bits)
            # The index is stored as n_cols - column, so for equal keys a
            # smaller column yields a larger packed word.
            idx_key = (n_cols - offs_n).to(x_ultype)
            packed = (x_key.to(x_ultype) << 16) | idx_key
            take_gt = mask_n & (x_key > thr_key)
            pos = tl.atomic_add(
                write_count_ptrs, one, mask=take_gt, sem="relaxed", scope="cta"
            )
            write_mask = take_gt & (pos < K_PAD)
            dst_ptrs = tle_gpu.gpu.local_ptr(smem_selected, (pos.to(tl.int32),))
            tl.store(dst_ptrs, packed, mask=write_mask)

        # Pass 2: append elements whose key equals the threshold; positions
        # at or beyond K_PAD are dropped by write_mask.
        for t in tl.range(0, n_tiles):
            offs_n = t * BLOCK_N + tl.arange(0, BLOCK_N)
            mask_n = offs_n < n_cols
            x_ptrs = X + pid * stride_xm + offs_n
            x = tl.load(x_ptrs, mask=mask_n, other=float("-inf"))
            x_bits = x.to(x_utype, bitcast=True)
            x_key = _fpval_to_key_with_nan(x, x_bits)
            idx_key = (n_cols - offs_n).to(x_ultype)
            packed = (x_key.to(x_ultype) << 16) | idx_key
            take_eq = mask_n & (x_key == thr_key)
            pos = tl.atomic_add(
                write_count_ptrs, one, mask=take_eq, sem="relaxed", scope="cta"
            )
            write_mask = take_eq & (pos < K_PAD)
            dst_ptrs = tle_gpu.gpu.local_ptr(smem_selected, (pos.to(tl.int32),))
            tl.store(dst_ptrs, packed, mask=write_mask)

        selected_packed = tl.load(smem_selected_ptrs)

        # Sort the packed words, then split them back into column index
        # (low 16 bits, stored as n_cols - column) and value key (remaining
        # high bits).
        topk = tl.sort(selected_packed, dim=0, descending=True)
        idx_mask = tl.full(topk.shape, (1 << 16) - 1, dtype=topk.dtype)
        idx_raw = (topk & idx_mask).to(tl.uint32)
        y_indices = (n_cols - idx_raw.to(tl.int32)).to(tl.int32)
        y_values_raw = (topk >> 16).to(x_utype)
        y_values = _key_to_fpval(y_values_raw).to(x_dtype, bitcast=True)

        # Only the first K of the K_PAD sorted entries are written out.
        mask_k = offs_k < K
        yv_ptrs = Yv + pid * stride_ym + offs_k
        yi_ptrs = Yi + pid * stride_ym + offs_k
        tl.store(yv_ptrs, y_values, mask=mask_k)
        tl.store(yi_ptrs, y_indices, mask=mask_k)
def topk(x, k, dim=-1, largest=True, sorted=True):
    """Return the k largest (or smallest) elements along the last dimension.

    Args:
        x: Input tensor. A non-contiguous tensor is copied to a contiguous
            one first, because the kernels address rows as ``row * N + col``.
        k: Number of elements to select per row; ``0 <= k <= x.shape[-1]``.
        dim: Dimension to select along. Only the last dimension (given either
            as ``-1`` or as ``x.ndim - 1``) is supported.
        largest: Select the largest elements when truthy, otherwise the
            smallest ones.
        sorted: Only consulted when deciding whether the TLE radix kernel may
            be used; the two-stage fallback sorts its result regardless.

    Returns:
        A ``(values, indices)`` tuple of shape ``x.shape[:-1] + (k,)``;
        ``values`` has the dtype of ``x`` and ``indices`` is int64.

    Raises:
        AssertionError: If ``dim`` is not the last dimension.
        RuntimeError: If ``k`` is negative or larger than ``x.shape[-1]``.
    """
    logger.debug("GEMS TOPK")
    # Normalize a negative dim so it can be compared with the last axis.
    if dim < 0:
        dim = dim + x.ndim

    assert dim == x.ndim - 1, "Currently only support topk in last dimension"

    # Early return for k=0 to avoid a Triton kernel compilation error:
    # tl.arange(0, BLOCK_SIZE) requires BLOCK_SIZE > 0, and with k=0 the
    # stage-2 element count (and therefore BLOCK_SIZE) would be 0.
    if k == 0:
        out_shape = list(x.shape[:-1]) + [0]
        return (
            torch.empty(out_shape, device=x.device, dtype=x.dtype),
            torch.empty(out_shape, device=x.device, dtype=torch.int64),
        )

    topk_elem_cnt = x.shape[dim]
    # Without this check an oversized k silently returns padding sentinels
    # together with out-of-range indices.
    if k < 0 or k > topk_elem_cnt:
        raise RuntimeError("selected index k out of range")

    descending = bool(largest)
    batch_size = math.prod(x.shape) // topk_elem_cnt

    # All kernels below assume densely packed rows with unit inner stride.
    x = x.contiguous()
    out_shape = x.shape[:-1] + (k,)

    if (
        HAS_TLE
        and sorted
        and descending
        and x.is_cuda
        and x.dtype in (torch.float16, torch.float32, torch.bfloat16)
        and k >= 8
        and topk_elem_cnt <= 65535
        and triton.next_power_of_2(k) <= 1024
    ):
        k_pad = triton.next_power_of_2(k)
        y_vals = torch.empty(out_shape, device=x.device, dtype=x.dtype)
        y_idx = torch.empty(out_shape, device=x.device, dtype=torch.int32)
        block_n_radix = max(k_pad, min(512, triton.next_power_of_2(topk_elem_cnt)))
        block_n_radix = min(block_n_radix, 1024)

        x_2d = x.reshape(batch_size, topk_elem_cnt)
        y_vals_2d = y_vals.reshape(batch_size, k)
        y_idx_2d = y_idx.reshape(batch_size, k)
        with torch_device_fn.device(x.device):
            topk_kernel_radix_tle[(batch_size,)](
                x_2d,
                y_vals_2d,
                y_idx_2d,
                x_2d.stride(0),
                y_vals_2d.stride(0),
                topk_elem_cnt,
                K=k,
                K_PAD=k_pad,
                BLOCK_N=block_n_radix,
                RADIX_BITS=4,
                num_warps=4,
                num_stages=1,
            )
        return (y_vals, y_idx.to(torch.int64))

    # Note(Zhengzekang): Maybe we should add a heuristic search in selecting
    # a proper chunk size.
    chunk_size = 256 if topk_elem_cnt < 1024 else 1024

    # Note(Zhengzekang): We should promise chunk_size is larger than k.
    if chunk_size < k:
        chunk_size = triton.next_power_of_2(k)

    chunk_num = triton.cdiv(topk_elem_cnt, chunk_size)

    stage1_out = torch.empty(batch_size * chunk_num * k, device=x.device, dtype=x.dtype)
    stage1_out_idx = torch.empty(
        batch_size * chunk_num * k, device=x.device, dtype=torch.int64
    )

    stage2_out = torch.empty(out_shape, device=x.device, dtype=x.dtype)
    stage2_out_idx = torch.empty(out_shape, device=x.device, dtype=torch.int64)

    with torch_device_fn.device(x.device):
        topk_stage1_kernel[
            batch_size,
            chunk_num,
        ](
            stage1_out,  # pointer to the output
            stage1_out_idx,  # pointer to the output
            x,  # pointer to the input
            k,
            topk_elem_cnt,
            chunk_size,
            descending,
        )
    stage2_elem_cnt = chunk_num * k

    candidate_vals = stage1_out.view(batch_size, stage2_elem_cnt)
    candidate_indices = stage1_out_idx.view(batch_size, stage2_elem_cnt)

    # [sunrise fix] The stage-2 bitonic sort produces incorrect results once
    # it spills into the multi-warp path (BLOCK_SIZE >= 512). Reduce the
    # candidate set with additional stage-1 passes until the final sort stays
    # within 256 lanes.
    #
    # Background (translated from the original note):
    # 1. Launching topk_stage2_kernel with num_warps=1 sidesteps a bug in the
    #    shared-memory path of the ptpu backend's multi-warp reduction. The
    #    root cause is a wrong linearized shared-memory offset in the
    #    cross-warp reduction path of ReduceOpToLLVM.cpp; the official
    #    tl.sort() shows the same error for N >= 512.
    # 2. The generic inter-warp reduce lowering is not the only problem: in
    #    the full bitonic-sort path of topk at least one more multi-warp
    #    interaction goes wrong. The most promising next step is not to keep
    #    patching the generic ReduceOp but to reproduce one specific stage of
    #    topk_stage2_kernel and inspect the TTIR / LLVM IR of the last few
    #    _compare_and_swap rounds.
    safe_stage2_elem_cnt = 256
    # Every round maps ceil(cnt / chunk) chunks to k survivors each. A chunk
    # of at least 2 * k elements guarantees that the candidate count strictly
    # shrinks; a 256-wide chunk does not for 128 < k < 256, where the loop
    # would otherwise never terminate. The chunk size stays a power of two
    # because it is used as the tl.arange extent.
    reduction_chunk_size = max(safe_stage2_elem_cnt, triton.next_power_of_2(2 * k))
    while k <= safe_stage2_elem_cnt and stage2_elem_cnt > safe_stage2_elem_cnt:
        round_chunk_num = triton.cdiv(stage2_elem_cnt, reduction_chunk_size)
        reduced_elem_cnt = round_chunk_num * k
        if reduced_elem_cnt >= stage2_elem_cnt:
            # Defensive: never spin when a round cannot make progress.
            break

        reduced_vals = torch.empty(
            batch_size * reduced_elem_cnt, device=x.device, dtype=x.dtype
        )
        reduced_local_indices = torch.empty(
            batch_size * reduced_elem_cnt, device=x.device, dtype=torch.int64
        )

        with torch_device_fn.device(x.device):
            topk_stage1_kernel[
                batch_size,
                round_chunk_num,
            ](
                reduced_vals,
                reduced_local_indices,
                candidate_vals,
                k,
                stage2_elem_cnt,
                reduction_chunk_size,
                descending,
            )

        # A chunk holding fewer than k real candidates also emits padded
        # lanes, whose positions lie past the end of the row. Their values
        # are padding sentinels, so clamping only keeps the gather in bounds.
        local_indices = reduced_local_indices.view(batch_size, reduced_elem_cnt)
        local_indices.clamp_(max=stage2_elem_cnt - 1)

        # Translate positions within the candidate list back to positions
        # within the original row.
        candidate_indices = torch.gather(candidate_indices, 1, local_indices).contiguous()
        candidate_vals = reduced_vals.view(batch_size, reduced_elem_cnt)
        stage2_elem_cnt = reduced_elem_cnt

    BLOCK_SIZE = triton.next_power_of_2(stage2_elem_cnt)

    with torch_device_fn.device(x.device):
        topk_stage2_kernel[batch_size,](
            stage2_out,
            stage2_out_idx,
            candidate_vals,
            candidate_indices,
            dim,
            k,
            stage2_elem_cnt,
            BLOCK_SIZE,
            descending,
        )

    return (stage2_out, stage2_out_idx)