Coverage for src/flag_gems/runtime/backend/_sunrise/ops/topk.py: 0%
179 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
7import triton.language.core as core
9try:
10 # TODO: Triton 2.1 does not implement _log2.
11 # Remove the try-catch block once all vendors upgrade to a newer version of Triton.
12 from triton.language.standard import _log2
13except ImportError:
14 pass
16from flag_gems.runtime import torch_device_fn
17from flag_gems.utils import libentry
18from flag_gems.utils import triton_lang_extension as ext
19from flag_gems.utils.limits import get_dtype_max, get_dtype_min
logger = logging.getLogger(__name__)


def _constexpr_limits(info):
    """Return ``(info.min, info.max)`` of a ``torch.finfo``/``torch.iinfo`` as ``tl.constexpr``.

    The limits are wrapped as ``tl.constexpr`` so the jitted helpers in this
    module can use them as compile-time constants.
    """
    return tl.constexpr(info.min), tl.constexpr(info.max)


# (min, max) pairs for every dtype referenced by the kernels below.
_MIN_FLOAT32_VAL, _MAX_FLOAT32_VAL = _constexpr_limits(torch.finfo(torch.float32))
_MIN_FLOAT16_VAL, _MAX_FLOAT16_VAL = _constexpr_limits(torch.finfo(torch.float16))
_MIN_BFLOAT16_VAL, _MAX_BFLOAT16_VAL = _constexpr_limits(torch.finfo(torch.bfloat16))
_MIN_INT8_VAL, _MAX_INT8_VAL = _constexpr_limits(torch.iinfo(torch.int8))
_MIN_INT16_VAL, _MAX_INT16_VAL = _constexpr_limits(torch.iinfo(torch.int16))
_MIN_INT32_VAL, _MAX_INT32_VAL = _constexpr_limits(torch.iinfo(torch.int32))
_MIN_INT64_VAL, _MAX_INT64_VAL = _constexpr_limits(torch.iinfo(torch.int64))
@triton.jit
def _get_finfo_val(
    dtype,
    return_max,
):
    """Return the largest or smallest finite value of a floating-point dtype.

    ``return_max`` selects the maximum when truthy and the minimum otherwise.
    Only ``tl.bfloat16``, ``tl.float16`` and ``tl.float32`` are recognised.
    For any other dtype no branch matches and nothing is returned.
    NOTE(review): integer element types therefore presumably yield ``None``
    (e.g. as a ``tl.load`` ``other=``); confirm that is intended.
    """
    if dtype is tl.bfloat16:
        if not return_max:
            return _MIN_BFLOAT16_VAL
        else:
            return _MAX_BFLOAT16_VAL
    elif dtype is tl.float16:
        if not return_max:
            return _MIN_FLOAT16_VAL
        else:
            return _MAX_FLOAT16_VAL
    elif dtype is tl.float32:
        if not return_max:
            return _MIN_FLOAT32_VAL
        else:
            return _MAX_FLOAT32_VAL
@triton.jit
def _get_iinfo_val(
    dtype,
    return_max,
):
    """Return ``get_dtype_max(dtype)`` if ``return_max`` else ``get_dtype_min(dtype)``.

    Both helpers come from ``flag_gems.utils.limits`` and presumably give the
    integer limits of ``dtype``. This helper is not called anywhere in the
    code visible in this file.
    """
    if not return_max:
        return get_dtype_min(dtype)
    else:
        return get_dtype_max(dtype)
@libentry()
@triton.jit
def topk_stage1_kernel(
    y_ptr,
    index_ptr,
    x_ptr,
    k,
    N: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    """Write the top-``k`` values and indices of each chunk of each row.

    Program axis 0 picks the row and axis 1 picks a ``CHUNK_SIZE``-wide chunk
    of that row. Row ``r`` is read from ``x_ptr + r * N``, so rows are assumed
    densely packed with stride ``N``. Each program writes ``k`` consecutive
    entries to ``y_ptr``/``index_ptr`` starting at
    ``(row * num_chunks + chunk) * k``. The stored indices are positions
    relative to the start of the row.

    Selection runs ``k`` rounds of max/argmax (``DESCENDING``) or min/argmin
    on a float32 copy of the chunk.
    """
    row = ext.program_id(0)
    chunk = ext.program_id(1)
    num_chunks = ext.num_programs(1)

    # Each (row, chunk) program owns k consecutive output slots.
    out_offset = (row * num_chunks + chunk) * k
    y_ptr += out_offset
    index_ptr += out_offset

    chunk_start = chunk * CHUNK_SIZE
    x_ptr += row * N + chunk_start

    lanes = tl.arange(0, CHUNK_SIZE)
    in_row = (chunk_start + lanes) < N

    # Lanes past the end of the row are padded with the input dtype's extreme
    # finite value in the losing direction.
    pad_val = _get_finfo_val(x_ptr.dtype.element_ty, return_max=not DESCENDING)
    vals = tl.load(x_ptr + lanes, mask=in_row, other=pad_val).to(tl.float32)

    # Lanes already emitted are overwritten with the float32 extreme so later
    # rounds skip them. For float32 input this coincides with the padding.
    # For fp16/bf16 it is strictly worse than the padding.
    used_val = _get_finfo_val(tl.float32, return_max=not DESCENDING)
    for slot in range(k):
        if DESCENDING:
            best_val = tl.max(vals)
            best_lane = tl.argmax(vals, axis=0)
        else:
            best_val = tl.min(vals)
            best_lane = tl.argmin(vals, axis=0)

        tl.store(y_ptr + slot, best_val)
        tl.store(index_ptr + slot, best_lane + chunk_start)

        vals = tl.where(lanes == best_lane, used_val, vals)
"""
Note(Zhengzekang):
Adapted from the official Triton 2.2 `sort` implementation:
https://github.com/triton-lang/triton/blob/release/2.2.x/python/triton/language/standard.py#L392-L404
The main change is that an index tensor is carried along and permuted together with the values.
"""
@triton.jit
def _compare_and_swap(x, ids, flip, i: core.constexpr, n_dims: core.constexpr):
    """Run one compare-and-swap round of the bitonic network, carrying indices.

    ``x`` is viewed as ``x.numel >> n_dims`` outer slices of ``2**n_dims``
    elements. Round ``i`` pairs every element with the one
    ``2 ** (n_dims - i - 1)`` positions away. Within each pair the smaller
    value ends up in the left lane, unless ``flip`` is set for that pair, in
    which case the larger one does. ``ids`` is permuted exactly like ``x``.

    Args:
        x: values being sorted.
        ids: indices travelling with ``x`` (same shape).
        flip: scalar order (0/False ascending, 1/True descending) or a 0/1
            tensor shaped like ``x``.
        i: compare distance exponent, as described above.
        n_dims: log2 of the per-slice sort length.

    Returns:
        ``(x, ids)`` after the round.
    """
    n_outer: core.constexpr = x.numel >> n_dims
    shape: core.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)]

    # View as [groups, 2, stride]: axis 1 separates the two members of a pair.
    y = core.reshape(x, shape)
    y_idx = core.reshape(ids, shape)

    # Extract each pair's left/right member and broadcast it over both lanes.
    # Selecting with `where` rather than multiplying by a 0/1 mask keeps
    # +/-inf intact: inf * 0 is NaN and would corrupt the sort.
    is_left = core.arange(0, 2)[None, :, None] == 0
    left = core.broadcast_to(
        tl.sum(core.where(is_left, y, 0), 1)[:, None, :], shape
    ).to(x.dtype)
    right = core.broadcast_to(
        tl.sum(core.where(is_left, 0, y), 1)[:, None, :], shape
    ).to(x.dtype)
    left = core.reshape(left, x.shape)
    right = core.reshape(right, x.shape)

    left_idx = core.broadcast_to(
        tl.sum(core.where(is_left, y_idx, 0), 1)[:, None, :], shape
    ).to(ids.dtype)
    right_idx = core.broadcast_to(
        tl.sum(core.where(is_left, 0, y_idx), 1)[:, None, :], shape
    ).to(ids.dtype)
    left_idx = core.reshape(left_idx, ids.shape)
    right_idx = core.reshape(right_idx, ids.shape)

    # is_right: 0 for the left member of each pair, 1 for the right member.
    is_right = core.reshape(
        core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape
    )

    # Each lane's partner: the right member for a left lane and vice versa.
    paired_val = core.where(is_right, left, right)
    paired_idx = core.where(is_right, left_idx, right_idx)

    # For a strictly ordered pair both lanes evaluate to (left > right) ^ flip,
    # so they swap together. A tied pair must not swap at all. Without the
    # `x != paired_val` guard, the right lane of a tie would take the left
    # lane's index while the left lane kept its own, duplicating one index
    # and dropping the other.
    flip_right = (flip ^ is_right) != 0
    cond = ((x > paired_val) != flip_right) & (x != paired_val)
    x = core.where(cond, paired_val, x)
    ids = core.where(cond, paired_idx, ids)

    return x, ids
@triton.jit
def _bitonic_merge(
    x, ids, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr
):
    """Run one merge stage of the bitonic sort on ``x``, permuting ``ids`` alongside.

    Applies ``stage`` rounds of ``_compare_and_swap`` with distance exponents
    ``n_dims - stage`` .. ``n_dims - 1``. The compare distance therefore
    shrinks from ``2**(stage - 1)`` down to 1.

    order_type 0 == ascending
    order_type 1 == descending
    order_type 2 == alternating

    ``argsort`` also passes its ``descending`` flag (a bool) as ``order`` for
    the final stage. That flag is used directly as ``flip``. With
    ``order == 2`` the direction alternates between consecutive runs of
    ``2**stage`` elements.

    Returns:
        ``(x, ids)`` after the stage.
    """
    n_outer: core.constexpr = x.numel >> n_dims
    core.static_assert(stage <= n_dims)
    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
    # descending order.
    # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
    # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
    # a stride of 2) at this stage
    if order == 2:
        # One 0/1 flag per run of 2**stage elements, alternating run by run.
        shape: core.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage]
        flip = core.reshape(
            core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape
        )
    else:
        flip = order
    # perform `stage` rounds of `compare-and-swap`
    for i in core.static_range(stage):
        x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims)
    return x, ids
@triton.jit
def argsort(x, ids, dim: tl.constexpr, descending: core.constexpr):
    """Bitonic-sort ``x`` and apply the same permutation to ``ids``.

    Args:
        x: values to sort. The number of stages is ``_log2(x.shape[dim])``, so
            that length is presumably expected to be a power of two.
            NOTE(review): the helpers reshape the whole tensor, so this
            presumably expects a 1-D ``x`` (as ``topk_stage2_kernel`` passes).
            Confirm before other uses.
        ids: indices carried along with ``x`` (same shape).
        dim: axis whose length determines the number of stages.
        descending: order of the final merge stage.

    Returns:
        ``(sorted_x, permuted_ids)``.

    NOTE(review): ``_log2`` comes from an optional import at the top of the
    file. On Triton versions without it, this function cannot compile.
    """
    # `dim` is used as given; it is not validated.
    _dim: core.constexpr = dim
    n_dims: core.constexpr = _log2(x.shape[_dim])
    # Stages 1..n_dims-1 build alternating runs; the last stage applies the
    # requested order to the whole sequence.
    for i in core.static_range(1, n_dims + 1):
        x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)
    return x, ids
@libentry()
@triton.jit
def topk_stage2_kernel(
    y_ptr,
    index_ptr,
    chunk_x,
    chunk_index,
    sort_dim: tl.constexpr,
    k: tl.constexpr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    """Sort each row's ``N`` candidates and emit the leading ``k``.

    Program axis 0 picks the row. The row's candidate values and indices are
    read from ``chunk_x``/``chunk_index`` at offset ``row * N`` and padded to
    ``BLOCK_SIZE`` lanes. They are then sorted together with ``argsort``
    (values as float32, indices as int32). The first ``k`` lanes are written
    to ``y_ptr``/``index_ptr`` at offset ``row * k``. ``sort_dim`` is accepted
    but unused; the sort always runs along axis 0.
    """
    row = ext.program_id(0)
    vals_ptr = chunk_x + row * N
    idx_ptr = chunk_index + row * N
    out_vals_ptr = y_ptr + row * k
    out_idx_ptr = index_ptr + row * k

    lanes = tl.arange(0, BLOCK_SIZE)
    in_row = lanes < N

    # Padding lanes take the losing extreme for both the value and the index.
    pad_val = _get_finfo_val(vals_ptr.dtype.element_ty, return_max=not DESCENDING)
    pad_idx = _MIN_INT32_VAL if DESCENDING else _MAX_INT32_VAL

    vals = tl.load(vals_ptr + lanes, mask=in_row, other=pad_val).to(tl.float32)
    idxs = tl.load(idx_ptr + lanes, mask=in_row, other=pad_idx).to(tl.int32)

    vals, idxs = argsort(vals, idxs, 0, descending=DESCENDING)

    keep = lanes < k
    tl.store(out_vals_ptr + lanes, vals, mask=keep)
    tl.store(out_idx_ptr + lanes, idxs, mask=keep)
def topk(x, k, dim=-1, largest=True, sorted=True):
    """Return the ``k`` largest (or smallest) entries of ``x`` along ``dim``.

    Args:
        x: input tensor. Only the last dimension is supported.
        k: number of entries to select, ``0 <= k <= x.shape[dim]``.
        dim: dimension to reduce; negative values count from the end.
        largest: select the largest entries (descending) when true, the
            smallest (ascending) otherwise.
        sorted: accepted for ``torch.topk`` compatibility. Results are always
            returned in sorted order.

    Returns:
        ``(values, indices)`` shaped like ``x`` with the last dimension
        replaced by ``k``. ``values`` has ``x.dtype``; ``indices`` is
        ``torch.int64``.

    Raises:
        AssertionError: if ``dim`` is not the last dimension.
        RuntimeError: if ``k`` is negative or larger than ``x.shape[dim]``.

    Pipeline: ``topk_stage1_kernel`` selects the top ``k`` of every chunk of
    each row. On this backend the bitonic sort in ``topk_stage2_kernel`` is
    only trusted up to 256 lanes, so when ``k <= 256`` the chunk candidates
    are first shrunk with extra stage-1 passes. Stage 2 then sorts the
    remaining candidates and keeps the first ``k``.
    """
    logger.debug("GEMS TOPK")
    # Normalise a negative dim so the last-dimension check below works.
    if dim < 0:
        dim = dim + x.ndim

    assert dim == x.ndim - 1, "Currently only support topk in last dimension"

    # Early return for k=0 to avoid a Triton compilation error:
    # tl.arange(0, BLOCK_SIZE) requires BLOCK_SIZE > 0, and k=0 would leave
    # stage 2 with zero candidates.
    if k == 0:
        out_shape = list(x.shape[:-1]) + [0]
        return (
            torch.empty(out_shape, device=x.device, dtype=x.dtype),
            torch.empty(out_shape, device=x.device, dtype=torch.int64),
        )

    topk_elem_cnt = x.shape[dim]
    # Same contract as torch.topk. An oversized k would otherwise pull padding
    # lanes into the result (or divide by zero for an empty last dimension).
    if k < 0 or k > topk_elem_cnt:
        raise RuntimeError("selected index k out of range")

    # Both kernels address element (row, col) as row * N + col, so the input
    # must be densely packed. contiguous() returns x itself when it already is.
    x = x.contiguous()

    descending = bool(largest)
    batch_size = x.numel() // topk_elem_cnt

    out_shape = x.shape[:-1] + (k,)
    stage2_out = torch.empty(out_shape, device=x.device, dtype=x.dtype)
    stage2_out_idx = torch.empty(out_shape, device=x.device, dtype=torch.int64)
    if batch_size == 0:
        # A leading dimension is empty: nothing to select, so skip the launches.
        return (stage2_out, stage2_out_idx)

    # Note(Zhengzekang): Maybe we should add a heuristic search in selecting a proper chunk size.
    chunk_size = 256 if topk_elem_cnt < 1024 else 1024
    # Every chunk must be at least k wide so it can supply k candidates.
    if chunk_size < k:
        chunk_size = triton.next_power_of_2(k)
    chunk_num = triton.cdiv(topk_elem_cnt, chunk_size)

    stage1_out = torch.empty(batch_size * chunk_num * k, device=x.device, dtype=x.dtype)
    stage1_out_idx = torch.empty(
        batch_size * chunk_num * k, device=x.device, dtype=torch.int64
    )

    with torch_device_fn.device(x.device):
        topk_stage1_kernel[
            batch_size,
            chunk_num,
        ](
            stage1_out,  # pointer to the output
            stage1_out_idx,  # pointer to the output
            x,  # pointer to the input
            k,
            topk_elem_cnt,
            chunk_size,
            descending,
        )
    stage2_elem_cnt = chunk_num * k

    candidate_vals = stage1_out.view(batch_size, stage2_elem_cnt)
    candidate_indices = stage1_out_idx.view(batch_size, stage2_elem_cnt)

    # [sunrise fix] On this backend the stage-2 bitonic sort returns incorrect
    # results once it spills into the multi-warp path (BLOCK_SIZE >= 512).
    # Shrink the candidate set with additional stage-1 passes until the final
    # sort fits in 256 lanes. For k > 256 this is impossible, and the final
    # sort still uses more than 256 lanes.
    #
    # Investigation notes (translated from the original):
    # 1. Setting num_warps=1 on topk_stage2_kernel works around a
    #    shared-memory path bug in the ptpu backend's multi-warp reduction.
    #    The root cause is an incorrect linearized shared-memory offset in the
    #    cross-warp reduction path of ReduceOpToLLVM.cpp. The official
    #    tl.sort() shows the same error for N >= 512.
    # 2. That is not the only problem: at least the full bitonic-sort path of
    #    topk has other failing multi-warp interactions. The most likely next
    #    step is not to keep patching the generic ReduceOp, but to build a
    #    precise repro for a specific stage of topk_stage2_kernel and inspect
    #    the TTIR/LLVM IR of the later _compare_and_swap rounds.
    safe_stage2_elem_cnt = 256
    # Reduction chunks hold at least 2*k candidates, so every round strictly
    # shrinks the candidate set. 256-wide chunks with 128 < k < 256 can make
    # no progress (e.g. k=200 with 400 candidates: two chunks yield
    # 2 * 200 = 400 candidates again) and would loop forever.
    reduction_chunk_size = max(safe_stage2_elem_cnt, triton.next_power_of_2(2 * k))
    while k <= safe_stage2_elem_cnt and stage2_elem_cnt > safe_stage2_elem_cnt:
        # tl.arange needs a power-of-two chunk; use one chunk once all
        # candidates fit.
        round_chunk_size = min(
            triton.next_power_of_2(stage2_elem_cnt), reduction_chunk_size
        )
        round_chunk_num = triton.cdiv(stage2_elem_cnt, round_chunk_size)
        reduced_elem_cnt = round_chunk_num * k

        reduced_vals = torch.empty(
            batch_size * reduced_elem_cnt, device=x.device, dtype=x.dtype
        )
        reduced_local_indices = torch.empty(
            batch_size * reduced_elem_cnt, device=x.device, dtype=torch.int64
        )

        with torch_device_fn.device(x.device):
            topk_stage1_kernel[
                batch_size,
                round_chunk_num,
            ](
                reduced_vals,
                reduced_local_indices,
                candidate_vals,
                k,
                stage2_elem_cnt,
                round_chunk_size,
                descending,
            )

        # Stage 1 reports positions within the candidate row. For fp16/bf16,
        # a last chunk holding fewer than k candidates can report padded lanes
        # (position >= stage2_elem_cnt), because the kernel's dtype padding
        # outranks its float32 "already taken" marker. Clamp so the gather
        # stays in bounds. Such entries carry the padding value, i.e. the
        # dtype's extreme finite value.
        local_indices = reduced_local_indices.view(
            batch_size, reduced_elem_cnt
        ).clamp_(max=stage2_elem_cnt - 1)
        candidate_indices = torch.gather(
            candidate_indices, 1, local_indices
        ).contiguous()
        candidate_vals = reduced_vals.view(batch_size, reduced_elem_cnt)
        stage2_elem_cnt = reduced_elem_cnt

    BLOCK_SIZE = triton.next_power_of_2(stage2_elem_cnt)

    with torch_device_fn.device(x.device):
        topk_stage2_kernel[batch_size,](
            stage2_out,
            stage2_out_idx,
            candidate_vals,
            candidate_indices,
            dim,
            k,
            stage2_elem_cnt,
            BLOCK_SIZE,
            descending,
        )

    return (stage2_out, stage2_out_idx)