Coverage for src/flag_gems/runtime/backend/_cambricon/ops/topk.py: 0%
319 statements
coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.ops.topk import topk_stage1_kernel, topk_stage2_kernel
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry, libtuner
11from flag_gems.utils.triton_version_utils import HAS_TLE
13if HAS_TLE:
14 import triton.experimental.tle.language as tle_gpu
15else:
16 tle_gpu = None
18from ..utils import TOTAL_CORE_NUM
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Extreme finite values of each supported dtype, wrapped in tl.constexpr so the
# @triton.jit helpers _get_finfo_val / _get_iinfo_val can return them as
# compile-time constants.
# Floating-point extremes (torch.finfo).
_MIN_FLOAT32_VAL = tl.constexpr(torch.finfo(torch.float32).min)
_MAX_FLOAT32_VAL = tl.constexpr(torch.finfo(torch.float32).max)
_MIN_FLOAT16_VAL = tl.constexpr(torch.finfo(torch.float16).min)
_MAX_FLOAT16_VAL = tl.constexpr(torch.finfo(torch.float16).max)
_MIN_BFLOAT16_VAL = tl.constexpr(torch.finfo(torch.bfloat16).min)
_MAX_BFLOAT16_VAL = tl.constexpr(torch.finfo(torch.bfloat16).max)
# Integer extremes (torch.iinfo).
_MIN_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).min)
_MAX_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).max)
_MIN_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).min)
_MAX_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).max)
_MIN_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).min)
_MAX_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).max)
@triton.jit
def _get_finfo_val(
    dtype,
    return_max,
):
    """Return the highest (``return_max``) or lowest finite value of ``dtype``.

    Handles tl.float32, tl.float16 and tl.bfloat16. Any other dtype matches
    none of the branches, so nothing is returned.
    """
    if dtype is tl.float32:
        return _MAX_FLOAT32_VAL if return_max else _MIN_FLOAT32_VAL
    elif dtype is tl.float16:
        return _MAX_FLOAT16_VAL if return_max else _MIN_FLOAT16_VAL
    elif dtype is tl.bfloat16:
        return _MAX_BFLOAT16_VAL if return_max else _MIN_BFLOAT16_VAL
@triton.jit
def _get_iinfo_val(
    dtype,
    return_max,
):
    """Return the largest (``return_max``) or smallest value of an int dtype.

    Handles tl.int16, tl.int32 and tl.int64. Any other dtype matches none of
    the branches, so nothing is returned.
    """
    if dtype is tl.int16:
        return _MAX_INT16_VAL if return_max else _MIN_INT16_VAL
    elif dtype is tl.int32:
        return _MAX_INT32_VAL if return_max else _MIN_INT32_VAL
    elif dtype is tl.int64:
        return _MAX_INT64_VAL if return_max else _MIN_INT64_VAL
@triton.jit
def get_topk_bubble_res(
    buffer, buffer_ind, k, axis, mask_val, DESCENDING, BLOCK_M, BLOCK_N
):
    """Select the top ``k`` entries of each row of ``buffer`` by repeated max/min.

    Each of the ``k`` iterations reduces ``buffer`` along ``axis`` with
    ``tl.max`` (DESCENDING) or ``tl.min``, records the winning value and the
    paired entry of ``buffer_ind``, then overwrites the winner with
    ``mask_val`` so it is not chosen again.

    Args:
        buffer: Candidate values; rows of width ``BLOCK_N`` (callers in this
            file pass a ``[TILE_M, TILE_N]`` or ``[TILE_M, TILE_N_NUM * k]`` tile).
        buffer_ind: Tensor paired element-wise with ``buffer``; the entries at
            the selected positions are returned as the indices.
        k: Number of entries selected per row.
        axis: Reduction axis for ``tl.max``/``tl.min`` (callers pass 1).
        mask_val: Value written over a selected entry.
        DESCENDING: Select maxima when true, minima otherwise.
        BLOCK_M: Number of rows in ``buffer``.
        BLOCK_N: Width of ``buffer``.

    Returns:
        ``(ret, ret_ind)``, each ``[BLOCK_M, k]``: selected values in
        selection order and the matching entries of ``buffer_ind``.

    NOTE(review): if every remaining entry of a row compares worse than
    ``mask_val`` in the selection direction (e.g. -inf inputs when DESCENDING
    and ``mask_val`` is the finite float minimum), the reduction lands on an
    already-overwritten slot again and that slot is returned twice — confirm
    whether such inputs must be supported.
    """
    kep_buffer_n = buffer
    topk_buffer_index_n = buffer_ind
    ret = tl.empty([BLOCK_M, k], dtype=buffer.dtype)
    ret_ind = tl.empty([BLOCK_M, k], dtype=buffer_ind.dtype)
    for k_ind in tl.range(0, k):
        if DESCENDING:
            sel_val, sel_index = tl.max(kep_buffer_n, axis=axis, return_indices=True)
        else:
            sel_val, sel_index = tl.min(kep_buffer_n, axis=axis, return_indices=True)

        if BLOCK_M > 1:
            # One-hot mask of the winning column in every row; used both to
            # gather the paired index and to knock the winner out below.
            mask_sel = tl.arange(0, BLOCK_N)[None, :] == sel_index[:, None]
            # Non-winning positions become 0, so the row max is the winner's
            # paired index (assumes paired indices are non-negative).
            tep_sel_index_buffer = tl.where(mask_sel, topk_buffer_index_n, 0)
            sel_index_res = tl.max(tep_sel_index_buffer, axis=axis)
            sel_val_res = sel_val
            ret[:, k_ind] = sel_val_res
            ret_ind[:, k_ind] = sel_index_res

            # Update buffer.
            kep_buffer_n = tl.where(mask_sel, mask_val, kep_buffer_n)
        else:
            # Single row: the winner position is a scalar, so index directly.
            indices = sel_index[0]
            ret[:, k_ind] = sel_val
            ret_ind[:, k_ind] = topk_buffer_index_n[:, indices]
            # Update buffer.
            kep_buffer_n[:, indices] = mask_val
    return ret, ret_ind
# Candidate TILE_M values (rows per tile) for topk_bubble_kernel autotuning.
BLOCK_BATCH = [1, 16]
# Candidate TILE_N values (columns per tile); topk_config_prune also checks
# whether N is one of these to decide if exact-width configs must be added.
BLOCK_N = [128, 512, 1024, 2048]
def topk_cfggen():
    """Build the autotuning search space for ``topk_bubble_kernel``.

    Produces one single-warp ``triton.Config`` per combination of TILE_M in
    ``BLOCK_BATCH``, TILE_N in ``BLOCK_N`` and ``num_stages`` in (1, 3),
    ordered with TILE_M outermost and ``num_stages`` innermost. The derived
    TILE_M_NUM / TILE_N_NUM values are added later by ``topk_config_prune``.
    """
    stage_choices = (1, 3)
    search_space = []
    for tile_m in BLOCK_BATCH:
        for tile_n in BLOCK_N:
            for stages in stage_choices:
                search_space.append(
                    triton.Config(
                        {"TILE_M": tile_m, "TILE_N": tile_n},
                        num_warps=1,
                        num_stages=stages,
                    )
                )
    return search_space
def topk_config_prune(configs, named_args, **kwargs):
    """Early-prune hook for the ``topk_bubble_kernel`` autotuner.

    Keeps only configs usable for the current launch and fills in the derived
    ``TILE_M_NUM`` / ``TILE_N_NUM`` constexprs:

    * drop configs whose ``TILE_N`` is smaller than ``k`` or whose ``TILE_M``
      exceeds the per-program row block ``BLOCK_M``;
    * once the most recently kept config has the same ``TILE_M`` and a
      ``TILE_N >= N``, skip further configs with ``TILE_N > N``;
    * when ``N`` is not one of ``BLOCK_N`` but is at most ``max(BLOCK_N)``,
      add exact-width (``TILE_N == N``) configs.

    Args:
        configs: Candidate ``triton.Config`` objects (from ``topk_cfggen``).
        named_args: Launch arguments; ``k``, ``N`` and ``BLOCK_M`` are read.
        **kwargs: Unused.

    Returns:
        A new list of ``triton.Config`` objects. The incoming configs are
        never modified: the autotuner passes the same objects for every
        tuning key and caches the winner by reference, so writing
        ``TILE_M_NUM``/``TILE_N_NUM`` into them would silently change a
        config already cached for a different ``N``/``BLOCK_M``.
    """
    import copy

    k = named_args["k"]
    N = named_args["N"]
    block_m = named_args["BLOCK_M"]
    new_configs = []

    for config in configs:
        tile_n = config.kwargs["TILE_N"]
        tile_m = config.kwargs["TILE_M"]
        # A tile narrower than k cannot yield k candidates; a row tile taller
        # than BLOCK_M would reach into the next program's rows.
        if tile_n < k or tile_m > block_m:
            continue
        if new_configs:
            last_tn = new_configs[-1].kwargs["TILE_N"]
            last_tm = new_configs[-1].kwargs["TILE_M"]
            # A wider tile only adds padding once one already spans the row.
            if tile_n > N and last_tn >= N and last_tm == tile_m:
                continue
        # Shallow-copy so every other Config attribute is preserved, then
        # give the copy its own kwargs dict with the derived values.
        pruned = copy.copy(config)
        pruned.kwargs = {
            **config.kwargs,
            "TILE_M_NUM": triton.cdiv(block_m, tile_m),
            "TILE_N_NUM": triton.cdiv(N, tile_n),
        }
        new_configs.append(pruned)

    if (N not in BLOCK_N) and (N <= max(BLOCK_N)):
        for tm in BLOCK_BATCH:
            # Same row-tile constraint as the main loop.
            if tm > block_m:
                continue
            new_configs.append(
                triton.Config(
                    {
                        "TILE_M": tm,
                        "TILE_N": N,
                        "TILE_M_NUM": triton.cdiv(block_m, tm),
                        "TILE_N_NUM": 1,
                    },
                    num_warps=1,
                    num_stages=3,
                )
            )
    return new_configs
@libentry()
@libtuner(
    configs=topk_cfggen(),
    key=["k", "N", "M", "BLOCK_M"],
    prune_configs_by={"early_config_prune": topk_config_prune},
)
@triton.jit
def topk_bubble_kernel(
    inp_ptr,
    out_ptr,
    out_index_ptr,
    k: tl.constexpr,
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_M: tl.constexpr,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_M_NUM: tl.constexpr,
    TILE_N_NUM: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    """Per-row top-k by repeated selection over a row-major ``(M, N)`` input.

    Program ``pid`` owns rows ``[pid * BLOCK_M, (pid + 1) * BLOCK_M)`` and
    walks them ``TILE_M`` rows at a time (``TILE_M_NUM`` tiles). Each row is
    scanned in ``TILE_N_NUM`` column tiles of width ``TILE_N``; the top ``k``
    of every tile go into a ``(TILE_M, TILE_N_NUM * k)`` buffer, and if there
    is more than one column tile a second selection picks the overall top
    ``k``. Values and column indices are stored to row-major ``(M, k)``
    outputs ``out_ptr`` / ``out_index_ptr``.
    """
    pid = tl.program_id(0)
    m_st = pid * BLOCK_M

    # Sentinel for padding and for knocking out selected elements: the
    # dtype's lowest value when picking the largest, highest otherwise.
    # Integer inputs need the iinfo table; _get_finfo_val has no integer
    # branch, so using it unconditionally broke integer inputs.
    inp_dtype = inp_ptr.dtype.element_ty
    if inp_dtype.is_floating():
        mask_val = _get_finfo_val(inp_dtype, return_max=not DESCENDING)
    else:
        mask_val = _get_iinfo_val(inp_dtype, return_max=not DESCENDING)
    mask_val = mask_val.to(inp_dtype)

    for m_block_ind in tl.range(0, TILE_M_NUM):
        m_iter_st = m_block_ind * TILE_M + m_st
        m_offset_val = m_iter_st + tl.arange(0, TILE_M)
        m_offset = m_offset_val[:, None]
        # Stay inside both the tensor and this program's row block, so a tile
        # that overhangs BLOCK_M does not redo the next program's rows.
        m_offset_mask = (m_offset < M) & (m_offset < m_st + BLOCK_M)

        topk_buffer_n = tl.full(
            [TILE_M, TILE_N_NUM * k], value=mask_val, dtype=inp_dtype
        )
        topk_buffer_index_n = tl.full(
            [TILE_M, TILE_N_NUM * k], value=0, dtype=out_index_ptr.dtype.element_ty
        )
        for n_block_ind in tl.range(0, TILE_N_NUM):
            n_st = n_block_ind * TILE_N
            n_offset = n_st + tl.arange(0, TILE_N)[None, :]
            n_offset_mask = n_offset < N

            inp_mask = m_offset_mask & n_offset_mask
            inp_ptrs = inp_ptr + m_offset * N + n_offset
            # Out-of-range lanes read the sentinel so they lose every comparison.
            block_inp_val = tl.load(inp_ptrs, mask=inp_mask, other=mask_val)

            # Top-k of this column tile, paired with absolute column indices.
            local_buffer, local_buffer_ind = get_topk_bubble_res(
                block_inp_val,
                n_offset.to(out_index_ptr.dtype.element_ty),
                k,
                1,
                mask_val,
                DESCENDING,
                TILE_M,
                TILE_N,
            )
            tep_index = n_block_ind * k
            topk_buffer_n[:, tep_index : tep_index + k] = local_buffer
            topk_buffer_index_n[:, tep_index : tep_index + k] = local_buffer_ind

        if TILE_N_NUM > 1:
            # Merge the per-tile candidates into the row's overall top-k.
            global_res, global_res_ind = get_topk_bubble_res(
                topk_buffer_n,
                topk_buffer_index_n,
                k,
                1,
                mask_val,
                DESCENDING,
                TILE_M,
                TILE_N_NUM * k,
            )
        else:
            global_res = topk_buffer_n
            global_res_ind = topk_buffer_index_n

        # Store topk.
        store_ptrs = m_offset * k + tl.arange(0, k)[None, :]
        store_mask = m_offset_mask
        tl.store(store_ptrs + out_ptr, global_res, store_mask)
        tl.store(store_ptrs + out_index_ptr, global_res_ind, store_mask)
# Radix-select top-k kernel and its helpers; only defined when the Triton TLE
# extension is importable (tle_gpu is None otherwise).
if HAS_TLE:

    @triton.jit
    def _get_topmask_and_fullmask(x):
        """Return ``(tm, fm)`` tensors shaped like ``x``.

        ``tm`` has only the most significant bit set and ``fm`` has every bit
        set, in the dtype of ``x``. ``x`` must be an unsigned integer tensor
        (static assert), i.e. float data has to be bitcast first.
        """
        tl.static_assert(
            x.dtype.is_int_unsigned(),
            "floating-point value must be passed as bits",
        )
        tm: tl.constexpr = 1 << (-1 + x.dtype.primitive_bitwidth)
        fm: tl.constexpr = (1 << x.dtype.primitive_bitwidth) - 1
        tm_arr = tl.full(x.shape, tm, dtype=x.dtype)
        fm_arr = tl.full(x.shape, fm, dtype=x.dtype)
        return tm_arr, fm_arr

    @triton.jit
    def _fpval_to_key_with_nan(x, x_bits):
        """Map float values to unsigned keys that order like the floats.

        ``x_bits`` is ``x`` bitcast to the same-width unsigned type. Values
        with the sign bit set have every bit flipped; the others only have the
        sign bit flipped, so comparing keys as unsigned integers follows the
        float order. NaN (``x != x``) maps to the all-ones key, ranking above
        every other value.
        """
        tm, fm = _get_topmask_and_fullmask(x_bits)
        mask = tl.where((x_bits & tm) != 0, fm, tm)
        key = x_bits ^ mask
        return tl.where(x == x, key, fm)

    @triton.jit
    def _key_to_fpval(x):
        """Invert the key transform, returning float bits as an unsigned tensor.

        Keys with the top bit set came from non-negative values and only need
        that bit cleared; the others are fully inverted. Callers bitcast the
        result back to the float dtype.
        """
        tm, fm = _get_topmask_and_fullmask(x)
        mask = tl.where((x & tm) != 0, tm, fm)
        return x ^ mask

    @libentry()
    @triton.jit
    def topk_kernel_radix_tle(
        X,
        Yv,
        Yi,
        stride_xm,
        stride_ym,
        n_cols,
        K: tl.constexpr,
        K_PAD: tl.constexpr,
        BLOCK_N: tl.constexpr,
        RADIX_BITS: tl.constexpr,
    ):
        """Largest-``K`` of each row, sorted descending, via radix select.

        Program ``pid`` reads ``n_cols`` contiguous elements at
        ``X + pid * stride_xm`` and writes ``K`` values / int32 column indices
        at ``Yv + pid * stride_ym`` / ``Yi + pid * stride_ym``.

        1. Radix select: values become order-preserving unsigned keys. For
           each ``RADIX_BITS``-wide digit, from the most significant down, a
           histogram of that digit over keys matching the already-fixed
           prefix picks the digit holding the K-th largest key. At the end
           ``desired`` is that key (the threshold).
        2. Gather: keys > threshold, then keys == threshold, are appended to a
           ``K_PAD``-slot buffer packed as ``(key << 16) | (n_cols - col)``.
           The 16-bit index field is why the caller limits ``n_cols`` to
           65535; storing ``n_cols - col`` makes the smaller column win among
           equal keys in a descending sort.
        3. Sort the packed buffer descending, decode values and indices, and
           store the first ``K``.
        """
        pid = tl.program_id(0)
        x_dtype = X.dtype.element_ty
        x_nbits: tl.constexpr = x_dtype.primitive_bitwidth
        # Packed (key, index) words need room for the key plus a 16-bit index
        # field: twice the value width, at least 32 bits.
        if x_nbits < 16:
            y_nbits: tl.constexpr = 32
        else:
            y_nbits: tl.constexpr = x_nbits * 2
        x_utype = tl.dtype(f"uint{x_nbits}")
        x_ultype = tl.dtype(f"uint{y_nbits}")

        RADIX_SIZE: tl.constexpr = 1 << RADIX_BITS
        RADIX_MASK: tl.constexpr = RADIX_SIZE - 1
        bins = tl.arange(0, RADIX_SIZE)
        one = tl.full([BLOCK_N], 1, tl.int32)

        # Key prefix fixed so far, the bits of it that are fixed, and how many
        # of the top-K still lie inside the current prefix.
        desired = tl.full((), 0, dtype=x_utype)
        desired_mask = tl.full((), 0, dtype=x_utype)
        k_to_find = tl.full((), K, dtype=tl.int32)
        n_tiles = tl.cdiv(n_cols, BLOCK_N)

        # Per-digit histogram, allocated with scope=tle_gpu.gpu.smem
        # (presumably block-shared memory — confirm against TLE docs).
        smem_counts = tle_gpu.gpu.alloc(
            [RADIX_SIZE],
            dtype=tl.int32,
            layout=None,
            scope=tle_gpu.gpu.smem,
            nv_mma_shared_layout=False,
        )
        smem_count_ptrs = tle_gpu.gpu.local_ptr(smem_counts, (bins,))

        for digit_pos in tl.static_range(x_nbits - RADIX_BITS, -1, -RADIX_BITS):
            tl.store(smem_count_ptrs, tl.zeros([RADIX_SIZE], dtype=tl.int32))
            for t in tl.range(0, n_tiles):
                offs_n = t * BLOCK_N + tl.arange(0, BLOCK_N)
                mask_n = offs_n < n_cols
                x_ptrs = X + pid * stride_xm + offs_n
                x = tl.load(x_ptrs, mask=mask_n, other=float("-inf"))
                x_bits = x.to(x_utype, bitcast=True)
                x_key = _fpval_to_key_with_nan(x, x_bits)
                # Only keys sharing the already-chosen high digits count.
                matches = (x_key & desired_mask) == desired
                digit = ((x_key >> digit_pos) & RADIX_MASK).to(tl.int32)
                valid = mask_n & matches
                count_addrs = tle_gpu.gpu.local_ptr(smem_counts, (digit,))
                tl.atomic_add(count_addrs, one, mask=valid, sem="relaxed", scope="cta")

            # NOTE(review): no explicit barrier is visible between the atomics
            # above and this load; ordering relies on TLE/Triton semantics.
            counts = tl.load(smem_count_ptrs)

            # cumsum_desc[d] = number of matching keys whose digit is >= d.
            cumsum_desc = tl.cumsum(counts, axis=0, reverse=True)
            tl.store(smem_count_ptrs, cumsum_desc)

            # Find the highest digit d with cum[d] >= k_to_find > cum[d + 1]:
            # the K-th largest key has digit d at this position.
            selected_scalar = 0
            counts_gt_scalar = 0
            found = 0
            for rev in tl.static_range(RADIX_SIZE):
                d = RADIX_SIZE - 1 - rev
                cum_d = tl.load(tle_gpu.gpu.local_ptr(smem_counts, (d,)))
                if d + 1 < RADIX_SIZE:
                    cum_next = tl.load(tle_gpu.gpu.local_ptr(smem_counts, (d + 1,)))
                else:
                    cum_next = 0
                take = (found == 0) & (cum_d >= k_to_find) & (cum_next < k_to_find)
                selected_scalar = tl.where(take, d, selected_scalar)
                counts_gt_scalar = tl.where(take, cum_next, counts_gt_scalar)
                found = tl.where(take, 1, found)

            # Fix this digit in the prefix; keys with a larger digit are
            # already in the top-K, so discount them.
            selected_u = selected_scalar.to(x_utype)
            desired = desired | (selected_u << digit_pos)
            desired_mask = desired_mask | (
                tl.full((), RADIX_MASK, dtype=x_utype) << digit_pos
            )
            k_to_find = k_to_find - counts_gt_scalar

        thr_key = desired

        # Placeholder packed word (key of -inf, index field 0) for unused slots.
        min_val = tl.full((), float("-inf"), tl.float32).to(x_dtype)
        min_bits = min_val.to(x_utype, bitcast=True)
        min_key = _fpval_to_key_with_nan(min_val, min_bits)
        min_packed = min_key.to(x_ultype) << 16
        offs_k = tl.arange(0, K_PAD)

        smem_selected = tle_gpu.gpu.alloc(
            [K_PAD],
            dtype=x_ultype,
            layout=None,
            scope=tle_gpu.gpu.smem,
            nv_mma_shared_layout=False,
        )
        smem_selected_ptrs = tle_gpu.gpu.local_ptr(smem_selected, (offs_k,))
        tl.store(smem_selected_ptrs, tl.full([K_PAD], min_packed, dtype=x_ultype))

        # Single write cursor into smem_selected.
        smem_write_count = tle_gpu.gpu.alloc(
            [1],
            dtype=tl.int32,
            layout=None,
            scope=tle_gpu.gpu.smem,
            nv_mma_shared_layout=False,
        )
        tl.store(tle_gpu.gpu.local_ptr(smem_write_count, (0,)), 0)
        # BLOCK_N pointers that all address the cursor, so each lane's
        # atomic_add returns the slot it reserved.
        write_count_ptrs = tle_gpu.gpu.local_ptr(
            smem_write_count, (tl.zeros([BLOCK_N], dtype=tl.int32),)
        )

        # Pass 1: every key strictly above the threshold.
        for t in tl.range(0, n_tiles):
            offs_n = t * BLOCK_N + tl.arange(0, BLOCK_N)
            mask_n = offs_n < n_cols
            x_ptrs = X + pid * stride_xm + offs_n
            x = tl.load(x_ptrs, mask=mask_n, other=float("-inf"))
            x_bits = x.to(x_utype, bitcast=True)
            x_key = _fpval_to_key_with_nan(x, x_bits)
            idx_key = (n_cols - offs_n).to(x_ultype)
            packed = (x_key.to(x_ultype) << 16) | idx_key
            take_gt = mask_n & (x_key > thr_key)
            pos = tl.atomic_add(
                write_count_ptrs, one, mask=take_gt, sem="relaxed", scope="cta"
            )
            write_mask = take_gt & (pos < K_PAD)
            dst_ptrs = tle_gpu.gpu.local_ptr(smem_selected, (pos.to(tl.int32),))
            tl.store(dst_ptrs, packed, mask=write_mask)

        # Pass 2: keys equal to the threshold fill the remaining slots (up to
        # K_PAD, not K; the sort below keeps the best K).
        for t in tl.range(0, n_tiles):
            offs_n = t * BLOCK_N + tl.arange(0, BLOCK_N)
            mask_n = offs_n < n_cols
            x_ptrs = X + pid * stride_xm + offs_n
            x = tl.load(x_ptrs, mask=mask_n, other=float("-inf"))
            x_bits = x.to(x_utype, bitcast=True)
            x_key = _fpval_to_key_with_nan(x, x_bits)
            idx_key = (n_cols - offs_n).to(x_ultype)
            packed = (x_key.to(x_ultype) << 16) | idx_key
            take_eq = mask_n & (x_key == thr_key)
            pos = tl.atomic_add(
                write_count_ptrs, one, mask=take_eq, sem="relaxed", scope="cta"
            )
            write_mask = take_eq & (pos < K_PAD)
            dst_ptrs = tle_gpu.gpu.local_ptr(smem_selected, (pos.to(tl.int32),))
            tl.store(dst_ptrs, packed, mask=write_mask)

        selected_packed = tl.load(smem_selected_ptrs)

        # Descending sort of packed words: by key, then by n_cols - col.
        topk = tl.sort(selected_packed, dim=0, descending=True)
        idx_mask = tl.full(topk.shape, (1 << 16) - 1, dtype=topk.dtype)
        idx_raw = (topk & idx_mask).to(tl.uint32)
        y_indices = (n_cols - idx_raw.to(tl.int32)).to(tl.int32)
        y_values_raw = (topk >> 16).to(x_utype)
        y_values = _key_to_fpval(y_values_raw).to(x_dtype, bitcast=True)

        mask_k = offs_k < K
        yv_ptrs = Yv + pid * stride_ym + offs_k
        yi_ptrs = Yi + pid * stride_ym + offs_k
        tl.store(yv_ptrs, y_values, mask=mask_k)
        tl.store(yi_ptrs, y_indices, mask=mask_k)
def topk(x, k, dim=-1, largest=True, sorted=True):
    """Return the ``k`` largest (or smallest) elements of ``x`` along ``dim``.

    Cambricon backend implementation of ``torch.topk``. Only the last
    dimension and ``sorted=True`` are supported. Strategy:

    * TLE radix-select kernel when Triton TLE is available, the input is a
      CUDA float16/float32/bfloat16 tensor, ``largest=True``, the reduced
      dimension has at most 65535 elements and ``next_power_of_2(k) <= 1024``;
    * otherwise the bubble-selection kernel when ``k <= log2(N)``;
    * otherwise a two-stage chunked top-k (``topk_stage1_kernel`` per chunk,
      then ``topk_stage2_kernel`` over the chunk results).

    Args:
        x: Input tensor. It is made contiguous first because every kernel
            addresses element ``(row, col)`` as ``row * N + col``.
        k: Number of elements to select, ``0 <= k <= N`` with ``N = x.shape[dim]``.
        dim: Dimension to reduce; must resolve to the last dimension.
        largest: Select the largest elements if true, the smallest otherwise.
        sorted: Must be true.

    Returns:
        ``(values, indices)`` of shape ``x.shape[:-1] + (k,)``; ``values`` has
        ``x.dtype`` and ``indices`` is int64.

    Raises:
        AssertionError: if ``dim`` is not the last dimension or ``sorted`` is false.
        RuntimeError: if ``k`` is outside ``[0, N]``.
    """
    logger.debug("GEMS_CAMBRICON TOPK")
    # Normalize a negative dim to its non-negative equivalent.
    if dim < 0:
        dim = dim + x.ndim

    assert dim == x.ndim - 1, "Currently only support topk in last dimension"
    assert sorted, "Currently only support sorted == True"

    # Early return for k=0 to avoid Triton kernel compilation error.
    # Triton's tl.arange(0, BLOCK_SIZE) requires BLOCK_SIZE > 0.
    # When k=0, stage2_elem_cnt becomes 0, leading to BLOCK_SIZE=0.
    if k == 0:
        out_shape = list(x.shape[:-1]) + [0]
        return (
            torch.empty(out_shape, device=x.device, dtype=x.dtype),
            torch.empty(out_shape, device=x.device, dtype=torch.int64),
        )

    topk_elem_cnt = x.shape[dim]
    # Same contract as torch.topk; an out-of-range k would otherwise read
    # past the row or select negative counts inside the kernels.
    if k < 0 or k > topk_elem_cnt:
        raise RuntimeError("selected index k out of range")

    descending = bool(largest)
    out_shape = x.shape[:-1] + (k,)
    batch_size = x.numel() // topk_elem_cnt

    # No rows: nothing to compute, and a zero-size launch would fail.
    if batch_size == 0:
        return (
            torch.empty(out_shape, device=x.device, dtype=x.dtype),
            torch.empty(out_shape, device=x.device, dtype=torch.int64),
        )

    x = x.contiguous()

    if (
        HAS_TLE
        and sorted
        and descending
        and x.is_cuda
        and x.dtype in (torch.float16, torch.float32, torch.bfloat16)
        and topk_elem_cnt <= 65535
        and triton.next_power_of_2(k) <= 1024
    ):
        return _topk_radix_tle(x, k, batch_size, topk_elem_cnt, out_shape)

    if k <= math.log2(topk_elem_cnt):
        return _topk_bubble(x, k, batch_size, topk_elem_cnt, descending, out_shape)
    return _topk_sort(x, k, dim, batch_size, topk_elem_cnt, descending, out_shape)


def _topk_radix_tle(x, k, batch_size, n_cols, out_shape):
    """Largest-k via ``topk_kernel_radix_tle`` on a contiguous ``x``.

    Outputs are allocated as 2-D ``(batch_size, k)`` buffers so the row stride
    handed to the kernel is ``k`` for any input rank (``stride(0)`` of an
    N-D output is not the per-row stride), then viewed as ``out_shape``.
    """
    k_pad = triton.next_power_of_2(k)
    y_vals = torch.empty((batch_size, k), device=x.device, dtype=x.dtype)
    y_idx = torch.empty((batch_size, k), device=x.device, dtype=torch.int32)
    # Column tile: up to 512 (or the padded row length), at least k_pad,
    # never above 1024.
    block_n_radix = max(k_pad, min(512, triton.next_power_of_2(n_cols)))
    block_n_radix = min(block_n_radix, 1024)

    x_2d = x.reshape(batch_size, n_cols)
    with torch_device_fn.device(x.device):
        topk_kernel_radix_tle[(batch_size,)](
            x_2d,
            y_vals,
            y_idx,
            x_2d.stride(0),
            y_vals.stride(0),
            n_cols,
            K=k,
            K_PAD=k_pad,
            BLOCK_N=block_n_radix,
            RADIX_BITS=4,
            num_warps=4,
            num_stages=1,
        )
    return (y_vals.view(out_shape), y_idx.to(torch.int64).view(out_shape))


def _topk_bubble(x, k, batch_size, n_cols, descending, out_shape):
    """Top-k via ``topk_bubble_kernel`` (used when k <= log2(N)).

    Rows are split into blocks of ``block_m = cdiv(batch_size,
    TOTAL_CORE_NUM)`` consecutive rows, one program per block.
    """
    logger.debug("GEMS_CAMBRICON TOPK USING BUBBLE")
    topk_out = torch.empty(out_shape, device=x.device, dtype=x.dtype)
    topk_out_idx = torch.empty(out_shape, device=x.device, dtype=torch.int64)

    block_m = triton.cdiv(batch_size, TOTAL_CORE_NUM)
    # Exactly enough programs to cover every row once; never more than
    # min(batch_size, TOTAL_CORE_NUM).
    num_programs = triton.cdiv(batch_size, block_m)

    def grid_fn(meta):
        return (num_programs,)

    with torch_device_fn.device(x.device):
        topk_bubble_kernel[grid_fn](
            x,
            topk_out,
            topk_out_idx,
            k,
            batch_size,
            n_cols,
            block_m,
            DESCENDING=descending,
        )
    return (topk_out, topk_out_idx)


def _topk_sort(x, k, dim, batch_size, n_cols, descending, out_shape):
    """Two-stage top-k: top-k of each chunk, then top-k of the chunk results."""
    logger.debug("GEMS_CAMBRICON TOPK USING SORT")
    # Note(Zhengzekang): Maybe we should add a heuristic search in selecting a proper chunk size.
    chunk_size = 256 if n_cols < 1024 else 1024

    # Note(Zhengzekang): We should promise chunk_size is larger than k.
    if chunk_size < k:
        chunk_size = triton.next_power_of_2(k)

    chunk_num = triton.cdiv(n_cols, chunk_size)

    stage1_out = torch.empty(
        batch_size * chunk_num * k, device=x.device, dtype=x.dtype
    )
    stage1_out_idx = torch.empty(
        batch_size * chunk_num * k, device=x.device, dtype=torch.int64
    )

    stage2_out = torch.empty(out_shape, device=x.device, dtype=x.dtype)
    stage2_out_idx = torch.empty(out_shape, device=x.device, dtype=torch.int64)

    stage2_elem_cnt = chunk_num * k
    BLOCK_SIZE = triton.next_power_of_2(stage2_elem_cnt)

    with torch_device_fn.device(x.device):
        topk_stage1_kernel[batch_size, chunk_num](
            stage1_out,  # pointer to the output
            stage1_out_idx,  # pointer to the output
            x,  # pointer to the input
            k,
            n_cols,
            chunk_size,
            descending,
        )
        topk_stage2_kernel[batch_size,](
            stage2_out,
            stage2_out_idx,
            stage1_out,
            stage1_out_idx,
            dim,
            k,
            stage2_elem_cnt,
            BLOCK_SIZE,
            descending,
        )

    return (stage2_out, stage2_out_idx)