Coverage for src/flag_gems/runtime/backend/_cambricon/ops/quantile.py: 0%
224 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6import triton.language.core as core
7from torch import Tensor
9try:
10 # TODO: Triton 2.1 does not implement _log2.
11 # Remove the try-catch block once all vendors upgrade to a newer version of Triton.
12 from triton.language.standard import _log2, zeros_like
13except ImportError:
14 pass
15from flag_gems.runtime import torch_device_fn
16from flag_gems.utils import libentry, tl_extra_shim
17from flag_gems.utils import triton_lang_extension as ext
19from ..utils import MAX_GRID_SIZE_X
20from .topk import _get_finfo_val
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Interpolation modes accepted by `quantile`; the kernels specialize on the
# chosen mode as a constexpr.
INTERPOLATION_METHOD = [
    "linear",
    "lower",
    "higher",
    "nearest",
    "midpoint",
]

# Rows with at most this many elements are sorted inside the kernel with a
# bitonic network; longer rows are sorted with torch.sort before gathering.
MAX_BITONIC_M = 2**10

# The bitonic helpers below are adapted from Triton 2.2's `sort`, extended to
# carry an index tensor through the same permutation as the values.
"""
Note(Zhengzekang):
Refer from triton2.2 official `sort` implementation:
https://github.com/triton-lang/triton/blob/release/2.2.x/python/triton/language/standard.py#L392-L404
Just add indices to sort with values.
"""
@triton.jit
def _compare_and_swap(x, ids, flip, i: core.constexpr, n_dims: core.constexpr):
    """One compare-and-swap round of a bitonic network, permuting ``ids`` alike.

    ``x`` is viewed as ``[n_outer * 2**i, 2, 2**(n_dims - i - 1)]``, so every
    element is paired with the element ``2**(n_dims - i - 1)`` positions away.
    A pair is swapped (in both ``x`` and ``ids``) when
    ``(left > right) ^ flip`` is true.

    Args:
        x: values being sorted; its element count must be a multiple of
            ``2**n_dims`` for the reshape below to be valid.
        ids: companion tensor with the same shape as ``x``.
        flip: 0/1 scalar, or a tensor shaped like ``x``, selecting the
            descending direction per element.
        i: round index; selects the pairing stride.
        n_dims: log2 of the length of the sorted dimension.

    Returns:
        ``(x, ids)`` after this round.
    """
    n_outer: core.constexpr = x.numel >> n_dims
    shape: core.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)]

    # Same-width signed integer types for the values and for the ids. All data
    # movement below is done on these bit patterns, so it is exact.
    if core.constexpr(x.dtype.primitive_bitwidth) == 8:
        idtype = core.int8
    elif core.constexpr(x.dtype.primitive_bitwidth) == 16:
        idtype = core.int16
    elif core.constexpr(x.dtype.primitive_bitwidth) == 32:
        idtype = core.int32
    elif core.constexpr(x.dtype.primitive_bitwidth) == 64:
        idtype = core.int64
    else:
        raise ValueError("Unsupported dtype")

    if core.constexpr(ids.dtype.primitive_bitwidth) == 8:
        idx_dtype = core.int8
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 16:
        idx_dtype = core.int16
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 32:
        idx_dtype = core.int32
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 64:
        idx_dtype = core.int64
    else:
        raise ValueError("Unsupported dtype")

    # mask == 1 selects the right-hand element of each pair (middle axis).
    mask = core.arange(0, 2)[None, :, None]

    # Broadcast each pair's left/right partner over the pair. This is done on
    # integer bit patterns: multiplying float values by the 0/1 mask turns
    # +/-inf into NaN (inf * 0) and drops the sign of -0.0. Since the swap below
    # XORs the recovered patterns into x, either case corrupted the values.
    ix = x.to(idtype, bitcast=True)
    iy = core.reshape(ix, shape)
    ileft = core.broadcast_to(tl.sum(iy * (1 - mask), 1)[:, None, :], shape).to(
        idtype
    )
    iright = core.broadcast_to(tl.sum(iy * mask, 1)[:, None, :], shape).to(idtype)
    ileft = core.reshape(ileft, x.shape)
    iright = core.reshape(iright, x.shape)
    left = ileft.to(x.dtype, bitcast=True)
    right = iright.to(x.dtype, bitcast=True)

    # Swap out-of-order pairs: x ^ (l ^ r) turns l into r and r into l.
    cond = (left > right) ^ flip
    ret = ix ^ core.where(cond, ileft ^ iright, zeros_like(ix))

    # Apply the identical permutation to the ids.
    iids = ids.to(idx_dtype, bitcast=True)
    iy_idx = core.reshape(iids, shape)
    ileft_idx = core.broadcast_to(
        tl.sum(iy_idx * (1 - mask), 1)[:, None, :], shape
    ).to(idx_dtype)
    iright_idx = core.broadcast_to(tl.sum(iy_idx * mask, 1)[:, None, :], shape).to(
        idx_dtype
    )
    ileft_idx = core.reshape(ileft_idx, ids.shape)
    iright_idx = core.reshape(iright_idx, ids.shape)
    ret_idx = iids ^ core.where(cond, ileft_idx ^ iright_idx, zeros_like(iids))

    return ret.to(x.dtype, bitcast=True), ret_idx.to(ids.dtype, bitcast=True)
@triton.jit
def _bitonic_merge(
    x, ids, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr
):
    """Run one bitonic merge stage on ``x``, permuting ``ids`` identically.

    Args:
        x: values being sorted.
        ids: companion tensor permuted alongside ``x``.
        stage: merge stage (must be <= ``n_dims``); performs ``stage``
            compare-and-swap rounds with pair strides ``2**(stage-1)`` down to 1.
        order: 0 == ascending, 1 == descending, 2 == alternating (direction
            toggles every ``2**stage`` elements).
        n_dims: log2 of the length of the sorted dimension.

    Returns:
        ``(x, ids)`` after the stage.
    """
    core.static_assert(stage <= n_dims)
    if order != 2:
        # A single direction for every pair in this stage.
        direction = order
    else:
        # Alternating direction: blocks of 2**stage elements get 0, 1, 0, 1, ...
        outer: core.constexpr = x.numel >> n_dims
        flip_shape: core.constexpr = [outer * 2 ** (n_dims - 1 - stage), 2, 2**stage]
        alternating = core.broadcast_to(core.arange(0, 2)[None, :, None], flip_shape)
        direction = core.reshape(alternating, x.shape)
    first_round: core.constexpr = n_dims - stage
    for step in core.static_range(stage):
        x, ids = _compare_and_swap(x, ids, direction, first_round + step, n_dims)
    return x, ids
@triton.jit
def argsort(x, ids, dim: tl.constexpr, descending: core.constexpr):
    """Bitonic-sort ``x`` along ``dim`` and apply the same permutation to ``ids``.

    Only the minor-most dimension is supported: the compare-and-swap helper
    reshapes the whole tensor, so sorting any other axis would silently
    produce wrong results. ``x.shape[dim]`` is expected to be a power of two.

    Args:
        x: values to sort.
        ids: companion tensor with the same shape as ``x``.
        dim: dimension to sort; must be the last one.
        descending: constexpr 0/False for ascending, 1/True for descending.

    Returns:
        ``(sorted_x, permuted_ids)``.
    """
    _dim: core.constexpr = dim
    # Restores the upstream check this port had dropped (the old comment
    # promised it but nothing enforced it).
    core.static_assert(
        _dim == len(x.shape) - 1, "only minor dimension is currently supported"
    )
    n_dims: core.constexpr = _log2(x.shape[_dim])
    for i in core.static_range(1, n_dims + 1):
        # Every stage but the last builds alternating bitonic runs; the final
        # stage merges in the requested direction.
        x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)
    return x, ids
def heur_block_q(args):
    """Pick ``BLOCK_Q`` for ``quantile_kernel`` from the quantile count ``args["Q"]``.

    Roughly one eighth of the quantile count (rounded up), capped at 16, then
    rounded up to a power of two.
    """
    eighth = triton.cdiv(args["Q"], 8)
    capped = eighth if eighth < 16 else 16
    return triton.next_power_of_2(capped)
def heur_block_n(args):
    """Pick ``BLOCK_N`` (rows per program) for ``quantile_kernel`` from ``args["N"]``.

    For N >= 4096 the block grows with N so the grid's row axis has at most
    512 programs (N >= 65536) or 128 programs (4096 <= N < 65536). Smaller N
    use fixed sizes: 32 for N >= 64, 4 for N >= 32, otherwise 1.
    """
    n_rows = args["N"]
    if n_rows >= 4096:
        max_programs = 512 if n_rows >= 65536 else 128
        return triton.next_power_of_2(triton.cdiv(n_rows, max_programs))
    if n_rows >= 64:
        return 32
    return 4 if n_rows >= 32 else 1
@libentry()
@triton.heuristics(values={"BLOCK_Q": heur_block_q, "BLOCK_N": heur_block_n})
@triton.jit
def quantile_kernel(
    inp,
    q,
    out,
    N,
    M,
    Q,
    BLOCK_Q: tl.constexpr,
    BLOCK_N: tl.constexpr,
    interpolation: tl.constexpr,
):
    """Gather quantiles from rows that are already sorted.

    ``inp`` is addressed as ``N`` contiguous rows of length ``M`` (expected to
    be sorted ascending by the caller), ``q`` holds ``Q`` quantile levels, and
    ``out`` is written as an ``N x Q`` row-major matrix. Grid axis 0 tiles the
    quantile levels by ``BLOCK_Q``; axis 1 tiles the rows by ``BLOCK_N``.
    """
    pid_Q = ext.program_id(0)
    pid_N = ext.program_id(1)
    ctype = inp.dtype.element_ty

    offsets_Q = pid_Q * BLOCK_Q + tl.arange(0, BLOCK_Q)
    mask_Q = offsets_Q < Q
    q_ptrs = q + offsets_Q

    offsets_N = pid_N * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_N = offsets_N < N

    out_ptrs = out + offsets_N[:, None] * Q + offsets_Q[None, :]
    mask_out = mask_N[:, None] & mask_Q[None, :]

    # Fractional rank q * (M - 1), computed in at least fp32. Doing this in
    # fp16/bf16 rounds the rank onto a coarse grid for long rows (bf16 spacing
    # is already 2.0 above 256), which selects the wrong elements.
    q_block = tl.load(q_ptrs, mask_Q, 0.0)
    if ctype.is_fp64():
        q_block = q_block.to(tl.float64) * (M - 1)
    else:
        q_block = q_block.to(tl.float32) * (M - 1)
    q_lower = tl.floor(q_block).to(tl.int32)
    q_upper = tl.ceil(q_block).to(tl.int32)

    inp_lower = tl.load(
        inp + offsets_N[:, None] * M + q_lower[None, :], mask_N[:, None], 0.0
    )
    inp_upper = tl.load(
        inp + offsets_N[:, None] * M + q_upper[None, :], mask_N[:, None], 0.0
    )

    # tl.store casts each result back to the element type of `out`.
    if interpolation == "linear":
        q_frac = q_block - q_lower
        tl.store(out_ptrs, inp_lower + (inp_upper - inp_lower) * q_frac, mask_out)

    elif interpolation == "lower":
        tl.store(out_ptrs, inp_lower, mask_out)

    elif interpolation == "higher":
        tl.store(out_ptrs, inp_upper, mask_out)

    elif interpolation == "nearest":
        q_round = tl_extra_shim.rint(q_block)
        out_block = tl.where(q_round == q_upper, inp_upper, inp_lower)
        tl.store(out_ptrs, out_block, mask_out)

    elif interpolation == "midpoint":
        tl.store(out_ptrs, (inp_lower + inp_upper) / 2, mask_out)
@libentry()
@triton.jit
def quantile_bitonic_kernel(
    inp,
    q,
    out,
    N,
    M,
    Q,
    BLOCK_Q: tl.constexpr,
    BLOCK_M: tl.constexpr,
    interpolation: tl.constexpr,
):
    """Sort each row in-kernel with a bitonic network and gather its quantiles.

    ``inp`` is addressed as ``N`` contiguous rows of length ``M`` (with
    ``M <= BLOCK_M``), ``q`` holds ``Q`` quantile levels, and ``out`` is
    written as an ``N x Q`` row-major matrix. Programs step through rows with
    a stride equal to the number of programs. Quantile levels are processed
    ``BLOCK_Q`` at a time, so every one of the ``Q`` outputs is written.
    """
    pid = ext.program_id(0)
    grid_0 = tl.num_programs(0)
    ctype = inp.dtype.element_ty

    # Row-invariant setup.
    cols = tl.arange(0, BLOCK_M)
    mask_M = cols < M
    # Fill value for lanes >= M so they sort to the end of the row.
    # NOTE(review): presumably the largest finite value of ctype (see topk);
    # if so, +inf inputs would sort after the padding.
    mask_val = _get_finfo_val(ctype, return_max=True)

    while pid < N:
        row_ptr = inp + pid * M
        vals = tl.load(row_ptr + cols, mask=mask_M, other=mask_val)
        # Sort in fp32 unless the input is already fp64.
        vals = tl.where(vals.dtype.is_fp64(), vals, vals.to(tl.float32))
        ids = tl.arange(0, BLOCK_M)
        sorted_vals, _ = argsort(vals, ids, 0, descending=False)

        # Walk the quantile levels tile by tile; previously only the first
        # BLOCK_Q levels were computed and the rest of the row stayed unwritten.
        for q_start in range(0, Q, BLOCK_Q):
            offsets_Q = q_start + tl.arange(0, BLOCK_Q)
            mask_Q = offsets_Q < Q
            q_vals = tl.load(q + offsets_Q, mask=mask_Q, other=0.0).to(tl.float32)
            q_scaled = q_vals * (M - 1)
            q_lower = tl.floor(q_scaled).to(tl.int32)
            q_upper = tl.ceil(q_scaled).to(tl.int32)

            # One-hot gather of sorted_vals[q_lower] / sorted_vals[q_upper].
            # tl.where instead of multiplying by a 0/1 mask, so an inf anywhere
            # in the row does not turn every output into NaN (inf * 0).
            idx = cols[:, None]
            lower_vals = tl.sum(
                tl.where(idx == q_lower[None, :], sorted_vals[:, None], 0.0), axis=0
            )
            upper_vals = tl.sum(
                tl.where(idx == q_upper[None, :], sorted_vals[:, None], 0.0), axis=0
            )

            if interpolation == "linear":
                q_frac = q_scaled - q_lower
                out_vals = lower_vals + (upper_vals - lower_vals) * q_frac
            elif interpolation == "lower":
                out_vals = lower_vals
            elif interpolation == "higher":
                out_vals = upper_vals
            elif interpolation == "nearest":
                q_round = tl_extra_shim.rint(q_scaled).to(tl.int32)
                out_vals = tl.where(q_round == q_upper, upper_vals, lower_vals)
            elif interpolation == "midpoint":
                out_vals = (lower_vals + upper_vals) * 0.5

            out_ptr = out + pid * Q + offsets_Q
            tl.store(out_ptr, out_vals.to(ctype), mask=mask_Q)
        pid += grid_0
def _quantile_last_dim(inp, q, Q, interpolation):
    """Compute ``Q`` quantiles of every row of ``inp`` along its last dimension.

    ``inp`` must be contiguous. Returns a tensor of shape ``inp.shape[:-1] + (Q,)``
    in ``inp``'s dtype.
    """
    M = inp.size(-1)
    N = inp.numel() // M
    output = torch.empty(inp.shape[:-1] + (Q,), dtype=inp.dtype, device=inp.device)
    if Q == 0:
        # Empty q: nothing to compute (and BLOCK_Q would be 0).
        return output

    if M <= MAX_BITONIC_M:
        # Short rows: sort in-kernel with a bitonic network.
        BLOCK_M = triton.next_power_of_2(M)
        BLOCK_Q = triton.next_power_of_2(min(Q, 16))
        grid = min(N, MAX_GRID_SIZE_X // 4)
        with torch_device_fn.device(inp.device):
            quantile_bitonic_kernel[(grid,)](
                inp,
                q,
                output,
                N,
                M,
                Q,
                BLOCK_Q=BLOCK_Q,
                BLOCK_M=BLOCK_M,
                interpolation=interpolation,
            )
    else:
        # Long rows: sort with torch, then gather from the sorted rows.
        sorted_vals, _ = inp.sort(dim=-1)
        grid = lambda meta: (
            triton.cdiv(Q, meta["BLOCK_Q"]),
            triton.cdiv(N, meta["BLOCK_N"]),
        )
        with torch_device_fn.device(inp.device):
            quantile_kernel[grid](
                sorted_vals, q, output, N, M, Q, interpolation=interpolation
            )
    return output


def quantile(
    inp, q, dim=None, keepdim=False, interpolation="linear", out=None
) -> Tensor:
    """Cambricon implementation of ``torch.quantile``.

    Args:
        inp: floating-point input tensor (must be non-empty).
        q: quantile level(s) in [0, 1]: a float, a 0-d tensor, or a 1-D tensor.
        dim: dimension to reduce; ``None`` reduces over the flattened input.
        keepdim: keep reduced dimension(s) as size-1 axes. With ``dim=None``,
            every original axis is kept as size 1, as in ``torch.quantile``.
        interpolation: one of ``INTERPOLATION_METHOD``.
        out: optional tensor that receives the result via ``copy_``.

    Returns:
        The quantiles. A scalar ``q`` yields the reduced shape; a 1-D ``q``
        prepends an axis of length ``q.numel()``. If ``out`` is given, ``out``
        is returned.
    """
    logger.debug("GEMS_CAMBRICON QUANTILE DIM")
    assert torch.is_floating_point(inp)
    assert dim is None or isinstance(dim, int)
    assert isinstance(q, (float, torch.Tensor))
    assert interpolation in INTERPOLATION_METHOD
    # Previously an empty input reached `numel // M` and raised ZeroDivisionError.
    assert inp.numel() > 0, "quantile() input tensor must be non-empty"

    orig_ndim = inp.ndim
    reduce_all = dim is None
    if reduce_all:
        inp = inp.ravel()
        dim = 0
    if dim < 0:
        dim = dim + inp.ndim

    # Keep q in at least fp32: converting it to fp16/bf16 would round the
    # quantile levels themselves before any rank is computed.
    q_dtype = torch.float64 if inp.dtype == torch.float64 else torch.float32
    q_all_ones = False
    q_all_zeros = False
    if isinstance(q, float):
        q_all_ones = q == 1.0
        q_all_zeros = q == 0.0
        q = torch.tensor(q, device=inp.device, dtype=q_dtype)
    else:
        assert q.dim() <= 1, "quantile() q must be a scalar or 1D tensor"
        # The kernels read q with a unit stride.
        q = q.to(device=inp.device, dtype=q_dtype).contiguous()
    q_is_scalar = q.dim() == 0
    Q = q.numel()

    assert torch.all(q >= 0.0) and torch.all(q <= 1.0)

    if q_all_ones or q_all_zeros:
        # Fast path for a float q of exactly 0.0 / 1.0: min / max, no sort.
        reduce_fn = torch.amax if q_all_ones else torch.amin
        output = reduce_fn(inp, dim=dim, keepdim=keepdim)
        if keepdim and reduce_all:
            output = output.reshape((1,) * orig_ndim)
    else:
        inp = torch.movedim(inp, dim, -1).contiguous()
        output = _quantile_last_dim(inp, q, Q, interpolation)

        # (..., Q) -> torch layout: drop the q axis for scalar q, else lead with it.
        n_lead = 0 if q_is_scalar else 1
        output = output.squeeze(-1) if q_is_scalar else output.movedim(-1, 0)
        if keepdim:
            if reduce_all:
                output = output.reshape(output.shape[:n_lead] + (1,) * orig_ndim)
            else:
                output = output.unsqueeze(dim + n_lead)

    if out is not None:
        out.copy_(output)
        return out
    return output