Coverage for src/flag_gems/ops/topk.py: 24%
327 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
7import triton.language.core as core
9try:
10 # TODO: Triton 2.1 does not implement _log2.
11 # Remove the try-catch block once all vendors upgrade to a newer version of Triton.
12 from triton.language.standard import _log2, zeros_like
13except ImportError:
14 pass
16from flag_gems.runtime import torch_device_fn
17from flag_gems.utils import libentry
18from flag_gems.utils import triton_lang_extension as ext
19from flag_gems.utils.limits import get_dtype_max, get_dtype_min
20from flag_gems.utils.triton_version_utils import HAS_TLE
22if HAS_TLE:
23 import triton.experimental.tle.language as tle_gpu
24else:
25 tle_gpu = None
logger = logging.getLogger(__name__)

# Finite extremes of the dtypes handled by this module, computed on the host via
# torch.finfo / torch.iinfo. They are wrapped in `tl.constexpr` because they are
# read from inside @triton.jit functions (`_get_finfo_val` returns the float ones;
# `topk_stage2_kernel` uses the INT32 ones as index padding).
# NOTE(review): the INT8/INT16/INT64 constants are not referenced anywhere in the
# code visible in this module; presumably kept for symmetry or external users —
# confirm before removing.
_MIN_FLOAT32_VAL = tl.constexpr(torch.finfo(torch.float32).min)
_MAX_FLOAT32_VAL = tl.constexpr(torch.finfo(torch.float32).max)
_MIN_FLOAT16_VAL = tl.constexpr(torch.finfo(torch.float16).min)
_MAX_FLOAT16_VAL = tl.constexpr(torch.finfo(torch.float16).max)
_MIN_BFLOAT16_VAL = tl.constexpr(torch.finfo(torch.bfloat16).min)
_MAX_BFLOAT16_VAL = tl.constexpr(torch.finfo(torch.bfloat16).max)
_MIN_INT8_VAL = tl.constexpr(torch.iinfo(torch.int8).min)
_MAX_INT8_VAL = tl.constexpr(torch.iinfo(torch.int8).max)
_MIN_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).min)
_MAX_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).max)
_MIN_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).min)
_MAX_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).max)
_MIN_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).min)
_MAX_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).max)
@triton.jit
def _get_finfo_val(
    dtype,
    return_max,
):
    """Return the largest or most negative finite value of a float Triton dtype.

    Args:
        dtype: Triton dtype, matched by identity against tl.float32,
            tl.float16 and tl.bfloat16.
        return_max: if true, return the dtype's finite max; otherwise its
            finite min (most negative value).

    NOTE(review): there is no branch for any other dtype (integers, fp8, ...).
    `topk_stage1_kernel` and `topk_stage2_kernel` pass the input pointer's
    element dtype here, so a non-float input would reach this gap — confirm the
    host only routes float tensors to those kernels.
    """
    if dtype is tl.float32:
        if return_max:
            return _MAX_FLOAT32_VAL
        else:
            return _MIN_FLOAT32_VAL
    elif dtype is tl.float16:
        if return_max:
            return _MAX_FLOAT16_VAL
        else:
            return _MIN_FLOAT16_VAL
    elif dtype is tl.bfloat16:
        if return_max:
            return _MAX_BFLOAT16_VAL
        else:
            return _MIN_BFLOAT16_VAL
@triton.jit
def _get_iinfo_val(
    dtype,
    return_max,
):
    """Return ``get_dtype_max(dtype)`` if ``return_max`` else ``get_dtype_min(dtype)``.

    Both helpers come from ``flag_gems.utils.limits``. Presumably the integer
    counterpart of ``_get_finfo_val`` (the name mirrors ``torch.iinfo``) —
    confirm. It is not called anywhere in the code visible in this module.
    """
    if return_max:
        return get_dtype_max(dtype)
    else:
        return get_dtype_min(dtype)
@libentry()
@triton.jit
def topk_stage1_kernel(
    y_ptr,
    index_ptr,
    x_ptr,
    k,
    N: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    """Stage 1 of topk: select k candidates from one chunk of one row.

    Grid axis 0 is the row (``cur_batch``), axis 1 the chunk within the row.
    Each program loads CHUNK_SIZE elements starting at column
    ``cur_chunk_idx * CHUNK_SIZE`` of row ``cur_batch`` and picks k elements by
    repeated max/argmax (DESCENDING) or min/argmin (ascending), in float32.
    The values and their column indices within the row are written to
    ``y_ptr`` / ``index_ptr`` at offset ``(cur_batch * chunk_num + cur_chunk_idx) * k``.

    Rows are addressed as ``cur_batch * N``, so ``x_ptr`` must point to a
    contiguous (rows, N) buffer.

    NOTE(review): an already-selected slot is overwritten with the float32
    finfo min (DESCENDING) or max (ascending), which is finite. If the chunk
    holds -inf (DESCENDING) or +inf (ascending) values, that slot still ranks
    above them and can be selected again, producing duplicate indices. The
    masked tail (loaded with the input dtype's finite extreme) likewise
    outranks real infinities. When a chunk has fewer than k valid elements,
    indices >= N carrying sentinel values are emitted and stage 2 is relied
    on to rank them last. Confirm behaviour for inputs containing infinities.
    """
    cur_batch = ext.program_id(0)
    cur_chunk_idx = ext.program_id(1)
    chunk_num = ext.num_programs(1)

    # Output slot for this (row, chunk): k consecutive entries.
    y_ptr += cur_batch * chunk_num * k + cur_chunk_idx * k
    index_ptr += cur_batch * chunk_num * k + cur_chunk_idx * k

    chunk_offset = cur_chunk_idx * CHUNK_SIZE
    x_ptr += cur_batch * N + chunk_offset

    cols = tl.arange(0, CHUNK_SIZE)
    mask = (chunk_offset + cols) < N

    # Out-of-range lanes get the "worst" finite value for the requested order.
    mask_val = _get_finfo_val(x_ptr.dtype.element_ty, return_max=not DESCENDING)
    x_val = tl.load(x_ptr + cols, mask=mask, other=mask_val).to(tl.float32)
    for k_idx in range(k):
        if DESCENDING:
            chunk_select_val = tl.max(x_val)
            chunk_select_idx = tl.argmax(x_val, axis=0)
        else:
            chunk_select_val = tl.min(x_val)
            chunk_select_idx = tl.argmin(x_val, axis=0)

        # Store the value and its column index within the full row.
        tl.store(y_ptr + k_idx, chunk_select_val)
        tl.store(index_ptr + k_idx, chunk_select_idx + chunk_offset)

        # Knock the selected lane out of the next reduction.
        if DESCENDING:
            x_val = tl.where(
                cols == chunk_select_idx,
                _get_finfo_val(tl.float32, return_max=False),
                x_val,
            )
        else:
            x_val = tl.where(
                cols == chunk_select_idx,
                _get_finfo_val(tl.float32, return_max=True),
                x_val,
            )
128"""
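Note(Zhengzekang):
Adapted from the official Triton 2.2 `sort` implementation:
https://github.com/triton-lang/triton/blob/release/2.2.x/python/triton/language/standard.py#L392-L404
The only change is that an index tensor is carried alongside the values and
permuted together with them, so the sort also reports each element's original position.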
133"""
@triton.jit
def _compare_and_swap(x, ids, flip, i: core.constexpr, n_dims: core.constexpr):
    """Compare-and-swap elements paired at distance 2**(n_dims - i - 1).

    ``x`` is viewed as ``[n_outer * 2**i, 2, 2**(n_dims - i - 1)]``, so each
    element is paired with the element in the other slot of the middle axis.
    A pair is swapped when ``(left > right) ^ flip`` holds. Pairs therefore end
    up ascending where ``flip`` is 0 and descending where it is 1. ``ids`` is
    permuted with exactly the same decisions.

    Returns:
        The updated ``(x, ids)``, with the same dtypes and shapes as the inputs.

    NOTE(review): left/right are extracted by multiplying with a 0/1 mask and
    summing. Under IEEE arithmetic ``inf * 0`` is NaN, so if a pair contains
    ±inf (or NaN) the extracted values become NaN. The comparison and the
    XOR-based swap below then no longer reproduce the original elements.
    Confirm before relying on this with non-finite inputs.
    """
    n_outer: core.constexpr = x.numel >> n_dims
    shape: core.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)]

    # tl.device_print("shape is: ", shape)
    y = core.reshape(x, shape)
    y_idx = core.reshape(ids, shape)

    # slice left/right with 'stride' 2**(n_dims - i - 1); each result is
    # broadcast back so every element sees its pair's left and right value.
    mask = core.arange(0, 2)[None, :, None]
    left = core.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape).to(x.dtype)
    right = core.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape).to(x.dtype)
    left = core.reshape(left, x.shape)
    right = core.reshape(right, x.shape)

    left_idx = core.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape).to(
        ids.dtype
    )
    right_idx = core.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape).to(
        ids.dtype
    )
    left_idx = core.reshape(left_idx, ids.shape)
    right_idx = core.reshape(right_idx, ids.shape)

    # actual compare-and-swap, done on same-width integer bit patterns
    if core.constexpr(x.dtype.primitive_bitwidth) == 8:
        idtype = core.int8
    elif core.constexpr(x.dtype.primitive_bitwidth) == 16:
        idtype = core.int16
    elif core.constexpr(x.dtype.primitive_bitwidth) == 32:
        idtype = core.int32
    elif core.constexpr(x.dtype.primitive_bitwidth) == 64:
        idtype = core.int64
    else:
        raise ValueError("Unsupported dtype")

    ileft = left.to(idtype, bitcast=True)
    iright = right.to(idtype, bitcast=True)
    ix = x.to(idtype, bitcast=True)

    # Each element equals either its pair's left or right value, so
    # ix ^ (ileft ^ iright) yields the partner's value: XOR-ing that in where
    # `cond` holds swaps the pair, and XOR-ing zero keeps it.
    cond = (left > right) ^ flip
    ret = ix ^ core.where(cond, ileft ^ iright, zeros_like(ix))

    if core.constexpr(ids.dtype.primitive_bitwidth) == 8:
        idx_dtype = core.int8
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 16:
        idx_dtype = core.int16
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 32:
        idx_dtype = core.int32
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 64:
        idx_dtype = core.int64
    else:
        raise ValueError("Unsupported dtype")

    # Apply the same swap decision to the indices.
    ileft_idx = left_idx.to(idx_dtype, bitcast=True)
    iright_idx = right_idx.to(idx_dtype, bitcast=True)
    ix_idx = ids.to(idx_dtype, bitcast=True)
    ret_idx = ix_idx ^ core.where(cond, ileft_idx ^ iright_idx, zeros_like(ix_idx))

    return ret.to(x.dtype, bitcast=True), ret_idx.to(ids.dtype, bitcast=True)
@triton.jit
def _bitonic_merge(
    x, ids, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr
):
    """
    Run one bitonic merge stage on ``x``, permuting ``ids`` identically.

    order == 0: ascending
    order == 1: descending
    order == 2: alternating -- the direction flips every 2**stage elements
        (used for the intermediate stages of the bitonic sort)

    Performs ``stage`` rounds of ``_compare_and_swap`` and returns the updated
    ``(x, ids)``.
    """
    n_outer: core.constexpr = x.numel >> n_dims
    core.static_assert(stage <= n_dims)
    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
    # descending order.
    # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
    # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
    # a stride of 2) at this stage
    if order == 2:
        shape: core.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage]
        flip = core.reshape(
            core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape
        )
    else:
        flip = order
    # perform `stage` rounds of `compare-and-swap`
    for i in core.static_range(stage):
        x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims)
    return x, ids
@triton.jit
def argsort(x, ids, dim: tl.constexpr, descending: core.constexpr):
    """Bitonic-sort ``x`` and apply the same permutation to ``ids``.

    Runs ``n_dims = _log2(x.shape[dim])`` merge stages; all but the last use
    alternating order, and the last uses ``descending`` (0/False ascending,
    1/True descending). ``x.shape[dim]`` is assumed to be a power of two.

    Returns:
        The sorted ``(x, ids)``.

    NOTE(review): ``_log2`` comes from the optional import at the top of the
    module (absent on Triton 2.1), so compiling this there presumably fails on
    an undefined name.
    """
    # `dim` is used as given; no check that it is the innermost dimension is done here.
    _dim: core.constexpr = dim
    n_dims: core.constexpr = _log2(x.shape[_dim])
    for i in core.static_range(1, n_dims + 1):
        x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)
    return x, ids
@libentry()
@triton.jit
def topk_stage2_kernel(
    y_ptr,
    index_ptr,
    chunk_x,
    chunk_index,
    sort_dim: tl.constexpr,
    k: tl.constexpr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    """Stage 2 of topk: merge the per-chunk candidates of one row.

    Program ``cur_batch`` reads the N candidates written for its row (values
    from ``chunk_x``, indices from ``chunk_index``, row offset ``cur_batch * N``).
    It pads them to BLOCK_SIZE with the value dtype's finite extreme and an
    INT32 min/max index, bitonic-sorts them (descending if DESCENDING), and
    stores the first k into ``y_ptr`` / ``index_ptr`` (row offset ``cur_batch * k``).

    ``sort_dim`` is accepted but unused: ``argsort`` is always called with
    dim 0. BLOCK_SIZE must be a power of two for the bitonic sort.

    NOTE(review): indices are narrowed to int32 before sorting. The padding
    value is finite, so with -inf (DESCENDING) or +inf (ascending) candidates
    a padding slot can outrank real values and its INT32 sentinel index can
    reach the output. In addition, the mask-multiply extraction in
    ``_compare_and_swap`` is not safe for non-finite values (inf * 0 = NaN).
    """
    cur_batch = ext.program_id(0)
    chunk_x += cur_batch * N
    chunk_index += cur_batch * N
    y_ptr += cur_batch * k
    index_ptr += cur_batch * k

    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N

    # Padding lanes (N <= col < BLOCK_SIZE) get the "worst" value for the order.
    mask_val = _get_finfo_val(chunk_x.dtype.element_ty, return_max=not DESCENDING)
    mask_index_val = _MIN_INT32_VAL if DESCENDING else _MAX_INT32_VAL

    chunk_x_val = tl.load(chunk_x + cols, mask=mask, other=mask_val).to(tl.float32)
    chunk_index_val = tl.load(chunk_index + cols, mask=mask, other=mask_index_val).to(
        tl.int32
    )

    sorted_chunk_x, sorted_chunk_index = argsort(
        chunk_x_val, chunk_index_val, 0, descending=DESCENDING
    )
    # Only the leading k sorted entries are the result.
    tl.store(y_ptr + cols, sorted_chunk_x, mask=cols < k)
    tl.store(index_ptr + cols, sorted_chunk_index, mask=cols < k)
if HAS_TLE:

    @triton.jit
    def _get_topmask_and_fullmask(x):
        """Return (top-bit mask, all-ones mask) tensors shaped and typed like ``x``.

        ``x`` must be an unsigned integer tensor: floats are expected to be
        bitcast to their same-width unsigned type first.
        """
        tl.static_assert(
            x.dtype.is_int_unsigned(),
            "floating-point value must be passed as bits",
        )
        tm: tl.constexpr = 1 << (-1 + x.dtype.primitive_bitwidth)
        fm: tl.constexpr = (1 << x.dtype.primitive_bitwidth) - 1
        tm_arr = tl.full(x.shape, tm, dtype=x.dtype)
        fm_arr = tl.full(x.shape, fm, dtype=x.dtype)
        return tm_arr, fm_arr

    @triton.jit
    def _fpval_to_key_with_nan(x, x_bits):
        """Map float values to unsigned keys whose unsigned order matches float order.

        Negative values (top bit set) have all bits flipped, and non-negative
        values get the top bit set. NaNs (``x != x``) map to the all-ones key,
        i.e. they rank above +inf. ``x_bits`` is ``x`` bitcast to its unsigned type.
        """
        tm, fm = _get_topmask_and_fullmask(x_bits)
        mask = tl.where((x_bits & tm) != 0, fm, tm)
        key = x_bits ^ mask
        return tl.where(x == x, key, fm)

    @triton.jit
    def _key_to_fpval(x):
        """Invert ``_fpval_to_key_with_nan`` for non-NaN keys, returning float bits."""
        tm, fm = _get_topmask_and_fullmask(x)
        mask = tl.where((x & tm) != 0, tm, fm)
        return x ^ mask

    @libentry()
    @triton.jit
    def topk_kernel_radix_tle(
        X,
        Yv,
        Yi,
        stride_xm,
        stride_ym,
        n_cols,
        K: tl.constexpr,
        K_PAD: tl.constexpr,
        BLOCK_N: tl.constexpr,
        RADIX_BITS: tl.constexpr,
    ):
        """Per-row top-K (largest first) via radix select on order-preserving keys.

        Program ``pid`` handles row ``pid`` of X. Elements are at
        ``X + pid * stride_xm + col``, i.e. unit stride within a row. It writes
        K values and int indices to Yv / Yi at ``pid * stride_ym``.

        1. Radix select: for each RADIX_BITS-wide digit, from the most
           significant down, histogram that digit over elements whose higher
           digits match ``desired``. A reverse cumsum gives, per digit d, the
           count with digit >= d. The loop picks the d where that count first
           reaches ``k_to_find``. Afterwards ``desired`` is the key of the K-th
           largest element (``thr_key``).
        2. Gather into a K_PAD-slot buffer (``tle_gpu.gpu.alloc`` with
           ``scope=tle_gpu.gpu.smem``): first all keys > thr_key, then keys
           == thr_key. Each entry is packed as ``(key << 16) | (n_cols - col)``,
           so, among equal keys, the descending sort puts the lower column
           first. The 16-bit index field is why the host limits n_cols to
           65535.
        3. Sort the buffer descending and unpack the first K entries.

        NOTE(review): if more elements equal thr_key than free slots remain,
        the extras are dropped in atomic arrival order (``pos >= K_PAD``). The
        surviving tied columns are then not necessarily the lowest ones.
        There is also no explicit barrier between zeroing the histogram, the
        atomic increments and the following load; this presumably relies on
        tle/Triton inserting CTA synchronization — confirm.
        """
        pid = tl.program_id(0)
        x_dtype = X.dtype.element_ty
        x_nbits: tl.constexpr = x_dtype.primitive_bitwidth
        # Packed words hold the key plus a 16-bit index, so use at least 32 bits.
        if x_nbits < 16:
            y_nbits: tl.constexpr = 32
        else:
            y_nbits: tl.constexpr = x_nbits * 2
        x_utype = tl.dtype(f"uint{x_nbits}")
        x_ultype = tl.dtype(f"uint{y_nbits}")

        RADIX_SIZE: tl.constexpr = 1 << RADIX_BITS
        RADIX_MASK: tl.constexpr = RADIX_SIZE - 1
        bins = tl.arange(0, RADIX_SIZE)
        one = tl.full([BLOCK_N], 1, tl.int32)

        # Key prefix fixed so far, which bits of it are fixed, and how many
        # elements are still to be located inside that prefix.
        desired = tl.full((), 0, dtype=x_utype)
        desired_mask = tl.full((), 0, dtype=x_utype)
        k_to_find = tl.full((), K, dtype=tl.int32)
        n_tiles = tl.cdiv(n_cols, BLOCK_N)

        smem_counts = tle_gpu.gpu.alloc(
            [RADIX_SIZE],
            dtype=tl.int32,
            layout=None,
            scope=tle_gpu.gpu.smem,
            nv_mma_shared_layout=False,
        )
        smem_count_ptrs = tle_gpu.gpu.local_ptr(smem_counts, (bins,))

        for digit_pos in tl.static_range(x_nbits - RADIX_BITS, -1, -RADIX_BITS):
            # Histogram the current digit over elements matching the prefix.
            tl.store(smem_count_ptrs, tl.zeros([RADIX_SIZE], dtype=tl.int32))
            for t in tl.range(0, n_tiles):
                offs_n = t * BLOCK_N + tl.arange(0, BLOCK_N)
                mask_n = offs_n < n_cols
                x_ptrs = X + pid * stride_xm + offs_n
                x = tl.load(x_ptrs, mask=mask_n, other=float("-inf"))
                x_bits = x.to(x_utype, bitcast=True)
                x_key = _fpval_to_key_with_nan(x, x_bits)
                matches = (x_key & desired_mask) == desired
                digit = ((x_key >> digit_pos) & RADIX_MASK).to(tl.int32)
                valid = mask_n & matches
                count_addrs = tle_gpu.gpu.local_ptr(smem_counts, (digit,))
                tl.atomic_add(count_addrs, one, mask=valid, sem="relaxed", scope="cta")

            counts = tl.load(smem_count_ptrs)

            # cumsum_desc[d] = number of matching elements whose digit is >= d.
            cumsum_desc = tl.cumsum(counts, axis=0, reverse=True)
            tl.store(smem_count_ptrs, cumsum_desc)

            # Find the digit d with cum[d] >= k_to_find > cum[d + 1].
            selected_scalar = 0
            counts_gt_scalar = 0
            found = 0
            for rev in tl.static_range(RADIX_SIZE):
                d = RADIX_SIZE - 1 - rev
                cum_d = tl.load(tle_gpu.gpu.local_ptr(smem_counts, (d,)))
                if d + 1 < RADIX_SIZE:
                    cum_next = tl.load(tle_gpu.gpu.local_ptr(smem_counts, (d + 1,)))
                else:
                    cum_next = 0
                take = (found == 0) & (cum_d >= k_to_find) & (cum_next < k_to_find)
                selected_scalar = tl.where(take, d, selected_scalar)
                counts_gt_scalar = tl.where(take, cum_next, counts_gt_scalar)
                found = tl.where(take, 1, found)

            # Fix this digit in the prefix; elements with a larger digit are
            # already known to be in the top K.
            selected_u = selected_scalar.to(x_utype)
            desired = desired | (selected_u << digit_pos)
            desired_mask = desired_mask | (
                tl.full((), RADIX_MASK, dtype=x_utype) << digit_pos
            )
            k_to_find = k_to_find - counts_gt_scalar

        thr_key = desired

        # Fill the selection buffer with the packed key of -inf and index field 0.
        min_val = tl.full((), float("-inf"), tl.float32).to(x_dtype)
        min_bits = min_val.to(x_utype, bitcast=True)
        min_key = _fpval_to_key_with_nan(min_val, min_bits)
        min_packed = min_key.to(x_ultype) << 16
        offs_k = tl.arange(0, K_PAD)

        smem_selected = tle_gpu.gpu.alloc(
            [K_PAD],
            dtype=x_ultype,
            layout=None,
            scope=tle_gpu.gpu.smem,
            nv_mma_shared_layout=False,
        )
        smem_selected_ptrs = tle_gpu.gpu.local_ptr(smem_selected, (offs_k,))
        tl.store(smem_selected_ptrs, tl.full([K_PAD], min_packed, dtype=x_ultype))

        # Single write cursor; every lane points at the same counter.
        smem_write_count = tle_gpu.gpu.alloc(
            [1],
            dtype=tl.int32,
            layout=None,
            scope=tle_gpu.gpu.smem,
            nv_mma_shared_layout=False,
        )
        tl.store(tle_gpu.gpu.local_ptr(smem_write_count, (0,)), 0)
        write_count_ptrs = tle_gpu.gpu.local_ptr(
            smem_write_count, (tl.zeros([BLOCK_N], dtype=tl.int32),)
        )

        # Pass 1: append every element strictly above the threshold.
        for t in tl.range(0, n_tiles):
            offs_n = t * BLOCK_N + tl.arange(0, BLOCK_N)
            mask_n = offs_n < n_cols
            x_ptrs = X + pid * stride_xm + offs_n
            x = tl.load(x_ptrs, mask=mask_n, other=float("-inf"))
            x_bits = x.to(x_utype, bitcast=True)
            x_key = _fpval_to_key_with_nan(x, x_bits)
            idx_key = (n_cols - offs_n).to(x_ultype)
            packed = (x_key.to(x_ultype) << 16) | idx_key
            take_gt = mask_n & (x_key > thr_key)
            pos = tl.atomic_add(
                write_count_ptrs, one, mask=take_gt, sem="relaxed", scope="cta"
            )
            write_mask = take_gt & (pos < K_PAD)
            dst_ptrs = tle_gpu.gpu.local_ptr(smem_selected, (pos.to(tl.int32),))
            tl.store(dst_ptrs, packed, mask=write_mask)

        # Pass 2: fill the remaining slots with elements equal to the threshold.
        for t in tl.range(0, n_tiles):
            offs_n = t * BLOCK_N + tl.arange(0, BLOCK_N)
            mask_n = offs_n < n_cols
            x_ptrs = X + pid * stride_xm + offs_n
            x = tl.load(x_ptrs, mask=mask_n, other=float("-inf"))
            x_bits = x.to(x_utype, bitcast=True)
            x_key = _fpval_to_key_with_nan(x, x_bits)
            idx_key = (n_cols - offs_n).to(x_ultype)
            packed = (x_key.to(x_ultype) << 16) | idx_key
            take_eq = mask_n & (x_key == thr_key)
            pos = tl.atomic_add(
                write_count_ptrs, one, mask=take_eq, sem="relaxed", scope="cta"
            )
            write_mask = take_eq & (pos < K_PAD)
            dst_ptrs = tle_gpu.gpu.local_ptr(smem_selected, (pos.to(tl.int32),))
            tl.store(dst_ptrs, packed, mask=write_mask)

        selected_packed = tl.load(smem_selected_ptrs)

        # Sort packed words descending, then split them back into index and value.
        topk = tl.sort(selected_packed, dim=0, descending=True)
        idx_mask = tl.full(topk.shape, (1 << 16) - 1, dtype=topk.dtype)
        idx_raw = (topk & idx_mask).to(tl.uint32)
        y_indices = (n_cols - idx_raw.to(tl.int32)).to(tl.int32)
        y_values_raw = (topk >> 16).to(x_utype)
        y_values = _key_to_fpval(y_values_raw).to(x_dtype, bitcast=True)

        mask_k = offs_k < K
        yv_ptrs = Yv + pid * stride_ym + offs_k
        yi_ptrs = Yi + pid * stride_ym + offs_k
        tl.store(yv_ptrs, y_values, mask=mask_k)
        tl.store(yi_ptrs, y_indices, mask=mask_k)
def topk(x, k, dim=-1, largest=True, sorted=True):
    """Return the ``k`` largest (or smallest) elements of ``x`` along its last dim.

    Mirrors ``torch.topk``: returns a ``(values, indices)`` tuple, where
    ``values`` has ``x``'s dtype and ``indices`` is ``torch.int64``. Both have
    the shape of ``x`` with the last dimension replaced by ``k``.

    Args:
        x: input tensor. Non-contiguous inputs are made contiguous first.
        k: number of elements to select; must satisfy ``0 <= k <= x.shape[dim]``.
        dim: dimension to reduce. Only the last dimension is supported.
        largest: select the largest elements (descending order) if true,
            otherwise the smallest (ascending order).
        sorted: accepted for API compatibility. Results are always returned
            sorted; the flag only gates the TLE radix fast path.

    Raises:
        AssertionError: if ``dim`` is not the last dimension.
        RuntimeError: if ``k`` is negative or larger than ``x.shape[dim]``.
    """
    logger.debug("GEMS TOPK")
    # If dim equals to last dim, we set it to -1.
    if dim < 0:
        dim = dim + x.ndim

    assert dim == x.ndim - 1, "Currently only support topk in last dimension"

    # Early return for k=0 to avoid Triton kernel compilation error.
    # Triton's tl.arange(0, BLOCK_SIZE) requires BLOCK_SIZE > 0.
    # When k=0, stage2_elem_cnt becomes 0, leading to BLOCK_SIZE=0.
    if k == 0:
        out_shape = list(x.shape[:-1]) + [0]
        return (
            torch.empty(out_shape, device=x.device, dtype=x.dtype),
            torch.empty(out_shape, device=x.device, dtype=torch.int64),
        )

    topk_elem_cnt = x.shape[dim]
    # Without this check the kernels silently emit sentinel values and
    # out-of-range indices when k exceeds the dimension size.
    if k < 0 or k > topk_elem_cnt:
        raise RuntimeError(
            f"selected index k out of range: k={k}, "
            f"size of dim {dim} is {topk_elem_cnt}"
        )

    # Every kernel below addresses elements as `row * topk_elem_cnt + col`
    # (or `row * stride + col`), i.e. unit stride along the last dim and
    # densely packed rows; strided views would be read incorrectly.
    x = x.contiguous()

    descending = bool(largest)
    batch_size = math.prod(x.shape) // topk_elem_cnt

    if (
        HAS_TLE
        and sorted
        and descending
        and x.is_cuda
        and x.dtype in (torch.float16, torch.float32, torch.bfloat16)
        and k >= 8
        and topk_elem_cnt <= 65535  # indices are packed into 16 bits in the kernel
        and triton.next_power_of_2(k) <= 1024
    ):
        k_pad = triton.next_power_of_2(k)
        out_shape = x.shape[:-1] + (k,)
        y_vals = torch.empty(out_shape, device=x.device, dtype=x.dtype)
        y_idx = torch.empty(out_shape, device=x.device, dtype=torch.int32)
        block_n_radix = max(k_pad, min(512, triton.next_power_of_2(topk_elem_cnt)))
        block_n_radix = min(block_n_radix, 1024)

        x_2d = x.reshape(batch_size, topk_elem_cnt)
        y_vals_2d = y_vals.reshape(batch_size, k)
        y_idx_2d = y_idx.reshape(batch_size, k)
        with torch_device_fn.device(x.device):
            topk_kernel_radix_tle[(batch_size,)](
                x_2d,
                y_vals_2d,
                y_idx_2d,
                x_2d.stride(0),
                y_vals_2d.stride(0),
                topk_elem_cnt,
                K=k,
                K_PAD=k_pad,
                BLOCK_N=block_n_radix,
                RADIX_BITS=4,
                num_warps=4,
                num_stages=1,
            )
        return (y_vals, y_idx.to(torch.int64))

    # Note(Zhengzekang): Maybe we should add a heuristic search in selecting a proper chunk size.
    if topk_elem_cnt < 1024:
        chunk_size = 256
    else:
        chunk_size = 1024

    # Note(Zhengzekang): We should promise chunk_size is larger than k.
    if chunk_size < k:
        chunk_size = triton.next_power_of_2(k)

    chunk_num = triton.cdiv(topk_elem_cnt, chunk_size)

    # Stage 1 writes k candidates per (row, chunk); stage 2 reduces them to k per row.
    stage1_out = torch.empty(batch_size * chunk_num * k, device=x.device, dtype=x.dtype)
    stage1_out_idx = torch.empty(
        batch_size * chunk_num * k, device=x.device, dtype=torch.int64
    )

    out_shape = x.shape[:-1] + (k,)
    stage2_out = torch.empty(out_shape, device=x.device, dtype=x.dtype)
    stage2_out_idx = torch.empty(out_shape, device=x.device, dtype=torch.int64)

    with torch_device_fn.device(x.device):
        topk_stage1_kernel[
            batch_size,
            chunk_num,
        ](
            stage1_out,  # pointer to the output
            stage1_out_idx,  # pointer to the output
            x,  # pointer to the input
            k,
            topk_elem_cnt,
            chunk_size,
            descending,
        )
    stage2_elem_cnt = chunk_num * k
    # The bitonic sort in stage 2 needs a power-of-two block.
    BLOCK_SIZE = triton.next_power_of_2(stage2_elem_cnt)

    with torch_device_fn.device(x.device):
        topk_stage2_kernel[batch_size,](
            stage2_out,
            stage2_out_idx,
            stage1_out,
            stage1_out_idx,
            dim,
            k,
            stage2_elem_cnt,
            BLOCK_SIZE,
            descending,
        )

    return (stage2_out, stage2_out_idx)