Coverage for src/flag_gems/runtime/backend/_arm/ops/topk.py: 0%
195 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import math
4import numpy as np
5import torch
6import triton
7import triton.language as tl
8import triton.language.core as core
10try:
11 # TODO: Triton 2.1 does not implement _log2.
12 # Remove the try-catch block once all vendors upgrade to a newer version of Triton.
13 from triton.language.standard import _log2, zeros_like
14except ImportError:
15 pass
17from flag_gems.utils import triton_lang_extension as tle
18from flag_gems.utils.limits import get_dtype_max, get_dtype_min
logger = logging.getLogger(__name__)

# Lowest / highest representable values of each dtype (finite ones for the
# floating types), wrapped in ``tl.constexpr`` so ``@triton.jit`` code in this
# module can embed them as compile-time constants. Each pair is built
# lowest-first.
_MIN_FLOAT32_VAL, _MAX_FLOAT32_VAL = (
    tl.constexpr(torch.finfo(torch.float32).min),
    tl.constexpr(torch.finfo(torch.float32).max),
)
_MIN_FLOAT16_VAL, _MAX_FLOAT16_VAL = (
    tl.constexpr(torch.finfo(torch.float16).min),
    tl.constexpr(torch.finfo(torch.float16).max),
)
_MIN_BFLOAT16_VAL, _MAX_BFLOAT16_VAL = (
    tl.constexpr(torch.finfo(torch.bfloat16).min),
    tl.constexpr(torch.finfo(torch.bfloat16).max),
)
_MIN_INT8_VAL, _MAX_INT8_VAL = (
    tl.constexpr(torch.iinfo(torch.int8).min),
    tl.constexpr(torch.iinfo(torch.int8).max),
)
_MIN_INT16_VAL, _MAX_INT16_VAL = (
    tl.constexpr(torch.iinfo(torch.int16).min),
    tl.constexpr(torch.iinfo(torch.int16).max),
)
_MIN_INT32_VAL, _MAX_INT32_VAL = (
    tl.constexpr(torch.iinfo(torch.int32).min),
    tl.constexpr(torch.iinfo(torch.int32).max),
)
_MIN_INT64_VAL, _MAX_INT64_VAL = (
    tl.constexpr(torch.iinfo(torch.int64).min),
    tl.constexpr(torch.iinfo(torch.int64).max),
)
@triton.jit
def _get_finfo_val(
    dtype,
    return_max,
):
    """Return the highest (``return_max`` truthy) or lowest finite value of a
    float32 / float16 / bfloat16 Triton ``dtype``.

    Callers in this module pass compile-time values for both arguments, so
    only the selected branch is traced. Any other dtype falls through every
    branch without an explicit return.
    """
    if dtype is tl.float32:
        return _MAX_FLOAT32_VAL if return_max else _MIN_FLOAT32_VAL
    elif dtype is tl.float16:
        return _MAX_FLOAT16_VAL if return_max else _MIN_FLOAT16_VAL
    elif dtype is tl.bfloat16:
        return _MAX_BFLOAT16_VAL if return_max else _MIN_BFLOAT16_VAL
@triton.jit
def _get_iinfo_val(
    dtype,
    return_max,
):
    """Return ``get_dtype_max(dtype)`` when ``return_max`` is truthy, otherwise
    ``get_dtype_min(dtype)`` (presumably the integer limits of ``dtype``).

    NOTE(review): not called by any kernel in this module's visible code.
    """
    return get_dtype_max(dtype) if return_max else get_dtype_min(dtype)
# @libentry()
@triton.jit
def topk_stage1_kernel(
    y_ptr,
    index_ptr,
    x_ptr,
    k,
    N: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    """Stage 1 of topk: extract the top-k candidates of every chunk of a row.

    Launched on a 2-D grid: axis 0 selects the row (``cur_batch``), axis 1 the
    chunk within that row (``cur_chunk_idx``); ``chunk_num`` is the grid size
    along axis 1. Rows are addressed as ``cur_batch * N``, so the input is
    expected to be contiguous with row length ``N``.

    Each program writes ``k`` (value, column index) pairs into slot
    ``[cur_batch, cur_chunk_idx, :]`` of the flattened
    ``[batch, chunk_num, k]`` outputs ``y_ptr`` / ``index_ptr``. Selection is
    done ``k`` times: take the max (``DESCENDING``) or min of the chunk, store
    it with its global column index, then overwrite that lane with the float32
    lowest/highest finite value so it is not selected again.
    """
    cur_batch = tle.program_id(0)
    cur_chunk_idx = tle.program_id(1)
    chunk_num = tle.num_programs(1)

    # Output slot for this (row, chunk): offset (cur_batch * chunk_num + cur_chunk_idx) * k.
    y_ptr += cur_batch * chunk_num * k + cur_chunk_idx * k
    index_ptr += cur_batch * chunk_num * k + cur_chunk_idx * k

    chunk_offset = cur_chunk_idx * CHUNK_SIZE
    x_ptr += cur_batch * N + chunk_offset

    cols = tl.arange(0, CHUNK_SIZE)
    # The last chunk of a row may extend past column N.
    mask = (chunk_offset + cols) < N

    # Out-of-range lanes are padded with the lowest (DESCENDING) / highest
    # finite value of the input dtype so they lose every comparison against
    # finite data. _get_finfo_val only covers float32/float16/bfloat16.
    # NOTE(review): because the sentinel is finite, an input containing -inf
    # (descending) or +inf (ascending) ranks below the padding, and a tail chunk
    # with fewer than k valid lanes emits padding entries whose index is >= N.
    # Confirm such inputs cannot reach this kernel.
    mask_val = _get_finfo_val(x_ptr.dtype.element_ty, return_max=not DESCENDING)
    x_val = tl.load(x_ptr + cols, mask=mask, other=mask_val).to(tl.float32)
    for k_idx in range(k):
        if DESCENDING:
            chunk_select_val = tl.max(x_val)
            chunk_select_idx = tl.argmax(x_val, axis=0)
        else:
            chunk_select_val = tl.min(x_val)
            chunk_select_idx = tl.argmin(x_val, axis=0)

        # Store the chunk-local winner with its column index relative to the row start.
        tl.store(y_ptr + k_idx, chunk_select_val)
        tl.store(index_ptr + k_idx, chunk_select_idx + chunk_offset)

        # Knock the selected lane out of later rounds.
        # NOTE(review): the replacement is the finite float32 min/max, so a
        # remaining -inf (descending) / +inf (ascending) lane ranks below it and
        # an already-emitted index could be selected again. Verify inputs are finite.
        if DESCENDING:
            x_val = tl.where(
                cols == chunk_select_idx,
                _get_finfo_val(tl.float32, return_max=False),
                x_val,
            )
        else:
            x_val = tl.where(
                cols == chunk_select_idx,
                _get_finfo_val(tl.float32, return_max=True),
                x_val,
            )
121"""
Note(Zhengzekang):
Adapted from the official Triton 2.2 `sort` implementation:
https://github.com/triton-lang/triton/blob/release/2.2.x/python/triton/language/standard.py#L392-L404
The only change is that an index tensor is carried along and permuted together with the values.
126"""
@triton.jit
def _compare_and_swap(x, ids, flip, i: core.constexpr, n_dims: core.constexpr):
    """One compare-and-swap round of a bitonic network, permuting ``ids`` with ``x``.

    ``x`` is viewed as ``n_outer = x.numel >> n_dims`` consecutive runs of
    ``2**n_dims`` elements. Within each run, elements that are
    ``2**(n_dims - i - 1)`` positions apart form a (left, right) pair. Every
    pair for which ``(left > right) ^ flip`` holds is exchanged, and the
    matching entries of ``ids`` are exchanged with it. ``flip`` is either a
    scalar or a tensor shaped like ``x`` that selects the order per position.

    Returns the updated ``(x, ids)`` in their original dtypes. Values whose
    bit width is not 8/16/32/64 hit the ``raise ValueError`` branch.
    """
    n_outer: core.constexpr = x.numel >> n_dims
    # Middle axis of size 2 holds the (left, right) members of each pair.
    shape: core.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)]

    # tl.device_print("shape is: ", shape)
    y = core.reshape(x, shape)
    y_idx = core.reshape(ids, shape)

    # slice left/right with 'stride' 2**(n_dims - i - 1)
    # Multiplying by the 0/1 mask and summing over the pair axis extracts one
    # member of each pair; broadcasting back gives every position both partners.
    # NOTE(review): inf * 0 and NaN * 0 are NaN, so an inf/NaN member would
    # corrupt both elements of its pair. Confirm inputs are finite.
    mask = core.arange(0, 2)[None, :, None]
    left = core.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape).to(x.dtype)
    right = core.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape).to(x.dtype)
    left = core.reshape(left, x.shape)
    right = core.reshape(right, x.shape)

    left_idx = core.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape).to(
        ids.dtype
    )
    right_idx = core.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape).to(
        ids.dtype
    )
    left_idx = core.reshape(left_idx, ids.shape)
    right_idx = core.reshape(right_idx, ids.shape)

    # actual compare-and-swap
    # Pick a same-width integer type so values can be swapped on their raw bits.
    if core.constexpr(x.dtype.primitive_bitwidth) == 8:
        idtype = core.int8
    elif core.constexpr(x.dtype.primitive_bitwidth) == 16:
        idtype = core.int16
    elif core.constexpr(x.dtype.primitive_bitwidth) == 32:
        idtype = core.int32
    elif core.constexpr(x.dtype.primitive_bitwidth) == 64:
        idtype = core.int64
    else:
        raise ValueError("Unsupported dtype")

    ileft = left.to(idtype, bitcast=True)
    iright = right.to(idtype, bitcast=True)
    ix = x.to(idtype, bitcast=True)

    # XOR swap: where cond holds, a position holding `left` becomes
    # left ^ (left ^ right) == right and vice versa. Elsewhere it is XORed with 0.
    cond = (left > right) ^ flip
    ret = ix ^ core.where(cond, ileft ^ iright, zeros_like(ix))

    # Same bit-level swap for the indices, driven by the value comparison `cond`.
    if core.constexpr(ids.dtype.primitive_bitwidth) == 8:
        idx_dtype = core.int8
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 16:
        idx_dtype = core.int16
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 32:
        idx_dtype = core.int32
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 64:
        idx_dtype = core.int64
    else:
        raise ValueError("Unsupported dtype")

    ileft_idx = left_idx.to(idx_dtype, bitcast=True)
    iright_idx = right_idx.to(idx_dtype, bitcast=True)
    ix_idx = ids.to(idx_dtype, bitcast=True)
    ret_idx = ix_idx ^ core.where(cond, ileft_idx ^ iright_idx, zeros_like(ix_idx))

    return ret.to(x.dtype, bitcast=True), ret_idx.to(ids.dtype, bitcast=True)
@triton.jit
def _bitonic_merge(
    x, ids, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr
):
    """Run bitonic merge stage ``stage`` over ``x``, carrying ``ids`` along.

    ``order`` selects the target order of each merged block:
      0 == ascending
      1 == descending
      2 == alternating (consecutive blocks of ``2**stage`` elements alternate
           ascending / descending, as needed by intermediate bitonic stages)

    Performs ``stage`` compare-and-swap rounds with partner distances
    ``2**(stage - 1)``, ..., 2, 1 (round ``j`` passes
    ``i = j + n_dims - stage`` to ``_compare_and_swap``). Returns ``(x, ids)``.
    """
    n_outer: core.constexpr = x.numel >> n_dims
    core.static_assert(stage <= n_dims)
    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
    # descending order.
    # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
    # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
    # a stride of 2) at this stage
    if order == 2:
        # Broadcast [0, 1] so each run of 2**stage positions shares one flip bit.
        shape: core.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage]
        flip = core.reshape(
            core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape
        )
    else:
        flip = order
    # perform `stage` rounds of `compare-and-swap`
    for i in core.static_range(stage):
        x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims)
    return x, ids
@triton.jit
def argsort(x, ids, dim: tl.constexpr, descending: core.constexpr):
    """Bitonic-sort ``x`` and apply the same permutation to ``ids``.

    Despite the name, this returns ``(sorted_x, permuted_ids)`` rather than
    indices computed from scratch. The caller supplies ``ids``, for example the
    original positions. Stages ``1 .. n_dims - 1`` merge in alternating order,
    and the final stage merges in ascending (``descending`` falsy) or
    descending order.

    ``n_dims = _log2(x.shape[dim])``. The length along ``dim`` is presumably
    required to be a power of two; the stage-2 caller pads its block with
    ``triton.next_power_of_2``. ``_log2`` comes from the guarded
    ``triton.language.standard`` import, so it is undefined on Triton versions
    where that import fails.
    """
    # NOTE: unlike the comment originally here suggested, no check is made that
    # `dim` is the most-minor dimension; only x.shape[dim] is read.
    _dim: core.constexpr = dim
    n_dims: core.constexpr = _log2(x.shape[_dim])
    for i in core.static_range(1, n_dims + 1):
        x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)
    return x, ids
# @libentry()
@triton.jit
def topk_stage2_kernel(
    y_ptr,
    index_ptr,
    chunk_x,
    chunk_index,
    sort_dim: tl.constexpr,
    k: tl.constexpr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    """Stage 2 of topk: sort each row's stage-1 candidates and keep the first k.

    One program per row (grid axis 0). ``chunk_x`` / ``chunk_index`` hold
    ``N`` (= chunk_num * k in the host wrapper) candidate values and column
    indices per row. They are loaded into a ``BLOCK_SIZE``-wide block, padded
    with sentinels, bitonic-sorted together, and the first ``k`` entries are
    written to row ``cur_batch`` of ``y_ptr`` / ``index_ptr``.

    NOTE(review): ``sort_dim`` is unused; ``argsort`` is always called with dim 0.
    """
    cur_batch = tle.program_id(0)
    chunk_x += cur_batch * N
    chunk_index += cur_batch * N
    y_ptr += cur_batch * k
    index_ptr += cur_batch * k

    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N

    # Padding lanes get the worst possible value (and an extreme index) so the
    # sort places them after every real candidate.
    mask_val = _get_finfo_val(chunk_x.dtype.element_ty, return_max=not DESCENDING)
    mask_index_val = _MIN_INT32_VAL if DESCENDING else _MAX_INT32_VAL

    # NOTE(review): stage-1 indices are int64 but are narrowed to int32 here, so
    # column indices >= 2**31 would be truncated. Presumably acceptable for the
    # supported row lengths; confirm.
    chunk_x_val = tl.load(chunk_x + cols, mask=mask, other=mask_val).to(tl.float32)
    chunk_index_val = tl.load(chunk_index + cols, mask=mask, other=mask_index_val).to(
        tl.int32
    )

    sorted_chunk_x, sorted_chunk_index = argsort(
        chunk_x_val, chunk_index_val, 0, descending=DESCENDING
    )
    # Only the leading k sorted entries form the result.
    tl.store(y_ptr + cols, sorted_chunk_x, mask=cols < k)
    tl.store(index_ptr + cols, sorted_chunk_index, mask=cols < k)
def topk(x, k, dim=-1, largest=True, sorted=True):
    """Return the ``k`` largest (or smallest) entries of ``x`` along ``dim``.

    Follows the ``torch.topk`` calling convention and returns a
    ``(values, indices)`` tuple with int64 indices.

    CPU tensors go through a NumPy host implementation, which supports any
    ``dim`` and honours ``sorted``. Other devices use the two-stage Triton
    kernels, which only support the last dimension and always return sorted
    results.

    Args:
        x: input tensor.
        k: number of entries to select.
        dim: dimension to select along; negative values count from the end.
        largest: pick the largest entries if True, the smallest otherwise.
        sorted: on the CPU path, order the selected entries (descending for
            ``largest``, ascending otherwise). Ignored by the Triton path.

    Raises:
        RuntimeError: on the CPU path, if ``k`` is outside ``[0, x.shape[dim]]``.
        AssertionError: on the Triton path, if ``dim`` is not the last dimension.
    """
    logger.debug("GEMS TOPK")
    if dim < 0:
        dim = dim + x.ndim

    if x.device.type == "cpu":
        # CPU Triton topk kernel is unstable for some shapes/dtypes. Use a
        # deterministic host implementation for correctness.
        return _topk_cpu(x, k, dim, largest, sorted)
    return _topk_triton(x, k, dim, largest)


def _topk_cpu(x, k, dim, largest, sorted):
    """Host (NumPy) topk for CPU tensors; ``dim`` must already be non-negative.

    Ordering is computed in the tensor's native dtype, and the values are
    gathered from ``x`` itself. This avoids the precision loss of a float32
    round-trip: distinct float64 or large int64 values would otherwise merge
    and be returned rounded.
    """
    size = x.shape[dim]
    if k < 0 or k > size:
        raise RuntimeError("selected index k out of range")

    x_cpu = x.detach().cpu()
    if k == 0:
        # np.argpartition needs kth inside [0, size); there is nothing to select.
        out_shape = list(x_cpu.shape)
        out_shape[dim] = 0
        return (
            torch.empty(out_shape, dtype=x.dtype, device=x.device),
            torch.empty(out_shape, dtype=torch.int64, device=x.device),
        )

    # NumPy has no bfloat16. float32 represents every bfloat16 value exactly,
    # so this conversion does not change the ordering.
    x_host = x_cpu.to(torch.float32) if x_cpu.dtype == torch.bfloat16 else x_cpu
    x_np = x_host.numpy()

    if largest:
        part = np.argpartition(x_np, size - k, axis=dim)
        idx = np.take(part, np.arange(size - k, size), axis=dim)
    else:
        part = np.argpartition(x_np, k - 1, axis=dim)
        idx = np.take(part, np.arange(k), axis=dim)

    if sorted:
        order = np.argsort(np.take_along_axis(x_np, idx, axis=dim), axis=dim)
        if largest:
            order = np.flip(order, axis=dim)
        idx = np.take_along_axis(idx, order, axis=dim)

    idx_t = torch.from_numpy(idx.astype(np.int64, copy=False))
    # Gather from the original tensor so values keep their exact dtype and bits.
    vals_t = torch.gather(x_cpu, dim, idx_t)
    return vals_t.to(device=x.device), idx_t.to(device=x.device)


def _topk_triton(x, k, dim, largest):
    """Two-stage Triton topk along the last dimension; results are always sorted.

    Stage 1 splits each row into chunks and extracts every chunk's top-k
    candidates. Stage 2 sorts the ``chunk_num * k`` candidates of each row and
    keeps the first ``k``.
    """
    assert dim == x.ndim - 1, "Currently only support topk in last dimension"
    # Both kernels address row r as `r * N` with unit column stride, so a
    # strided (e.g. transposed) view must be densified first. This is a no-op
    # for contiguous inputs.
    x = x.contiguous()
    descending = bool(largest)

    topk_elem_cnt = x.shape[dim]
    batch_size = math.prod(x.shape) // topk_elem_cnt

    # Note(Zhengzekang): Maybe we should add a heuristic search in selecting a proper chunk size.
    chunk_size = 256 if topk_elem_cnt < 1024 else 1024
    # Note(Zhengzekang): We should promise chunk_size is larger than k.
    if chunk_size < k:
        chunk_size = triton.next_power_of_2(k)
    chunk_num = triton.cdiv(topk_elem_cnt, chunk_size)

    # Stage-1 candidates, laid out as a flattened [batch_size, chunk_num, k].
    stage1_out = torch.empty(batch_size * chunk_num * k, device=x.device, dtype=x.dtype)
    stage1_out_idx = torch.empty(
        batch_size * chunk_num * k, device=x.device, dtype=torch.int64
    )

    out_shape = x.shape[:-1] + (k,)
    stage2_out = torch.empty(out_shape, device=x.device, dtype=x.dtype)
    stage2_out_idx = torch.empty(out_shape, device=x.device, dtype=torch.int64)

    topk_stage1_kernel[
        batch_size,
        chunk_num,
    ](
        stage1_out,  # pointer to the output
        stage1_out_idx,  # pointer to the output
        x,  # pointer to the input
        k,
        topk_elem_cnt,
        chunk_size,
        descending,
    )

    stage2_elem_cnt = chunk_num * k
    # argsort's bitonic network needs a power-of-two block.
    BLOCK_SIZE = triton.next_power_of_2(stage2_elem_cnt)
    topk_stage2_kernel[batch_size,](
        stage2_out,
        stage2_out_idx,
        stage1_out,
        stage1_out_idx,
        dim,
        k,
        stage2_elem_cnt,
        BLOCK_SIZE,
        descending,
    )

    return (stage2_out, stage2_out_idx)