Coverage for src/flag_gems/runtime/backend/_arm/ops/quantile.py: 0%
86 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6from torch import Tensor
8from flag_gems.utils import dim_compress, tl_extra_shim
9from flag_gems.utils import triton_lang_extension as tle
# Module-level logger; quantile() emits a debug record on every call.
logger = logging.getLogger(__name__)

# Interpolation modes accepted by quantile(). Each name selects one
# compile-time branch inside quantile_kernel.
INTERPOLATION_METHOD = [
    "linear",
    "lower",
    "higher",
    "nearest",
    "midpoint",
]
def heur_block_q(args):
    """Heuristic for the BLOCK_Q tile size of ``quantile_kernel``.

    Takes ``ceil(Q / 8)``, caps it at 16, and rounds the result up to the
    next power of two.

    Args:
        args: Mapping of kernel launch arguments; only ``args["Q"]`` is read.

    Returns:
        The power-of-two block size to use along the quantile axis.
    """
    eighth_of_q = triton.cdiv(args["Q"], 8)
    capped = min(eighth_of_q, 16)
    return triton.next_power_of_2(capped)
def heur_block_n(args):
    """Heuristic for the BLOCK_N tile size of ``quantile_kernel``.

    Small values of ``N`` map to fixed block sizes; from 4096 upward the
    block size scales with ``N`` (divided by 128, or by 512 once ``N``
    reaches 65536) and is rounded up to a power of two.

    Args:
        args: Mapping of kernel launch arguments; only ``args["N"]`` is read.

    Returns:
        The block size to use along the N axis.
    """
    n = args["N"]

    # Fixed sizes for the small ranges, checked from the bottom up.
    if n < 32:
        return 1
    if n < 64:
        return 4
    if n < 4096:
        return 32

    # Large inputs: the divisor grows so the grid does not explode.
    divisor = 512 if n >= 65536 else 128
    return triton.next_power_of_2(triton.cdiv(n, divisor))
# @libentry()
@triton.heuristics(values={"BLOCK_Q": heur_block_q, "BLOCK_N": heur_block_n})
@triton.jit
def quantile_kernel(
    inp,
    q,
    out,
    N,
    M,
    Q,
    BLOCK_Q: tl.constexpr,
    BLOCK_N: tl.constexpr,
    interpolation: tl.constexpr,
):
    """Compute a (BLOCK_N x BLOCK_Q) tile of quantiles.

    ``inp`` is addressed as N rows of M elements (row stride M) and ``out``
    as N rows of Q elements (row stride Q). The host wrapper sorts each row
    before launching, so position ``p`` in a row is its p-th smallest value.

    Args:
        inp: Pointer to the row-sorted input values.
        q: Pointer to Q quantile fractions, read linearly.
        out: Pointer to the N x Q output buffer.
        N: Number of rows.
        M: Number of elements per row.
        Q: Number of quantiles.
        BLOCK_Q: Tile width along the quantile axis (compile-time).
        BLOCK_N: Tile height along the row axis (compile-time).
        interpolation: One of "linear", "lower", "higher", "nearest",
            "midpoint"; selects the branch at compile time.
    """
    q_tile = tle.program_id(0)
    row_tile = tle.program_id(1)
    elem_ty = inp.dtype.element_ty

    # Quantile indices handled by this program, with their bounds mask.
    q_idx = q_tile * BLOCK_Q + tl.arange(0, BLOCK_Q)
    q_valid = q_idx < Q

    # Row indices handled by this program, with their bounds mask.
    row_idx = row_tile * BLOCK_N + tl.arange(0, BLOCK_N)
    row_valid = row_idx < N
    row_valid_col = row_valid[:, None]

    dst = out + row_idx[:, None] * Q + q_idx[None, :]
    dst_mask = row_valid_col & q_valid[None, :]

    # Scale each fraction to a fractional position in [0, M - 1]. Masked-off
    # lanes load 0.0 and therefore point at position 0, which is in bounds.
    pos = tl.load(q + q_idx, mask=q_valid, other=0.0).to(elem_ty) * (M - 1)
    lo_idx = tl.floor(pos).to(tl.int32)
    hi_idx = tl.ceil(pos).to(tl.int32)

    # Gather the two neighbouring sorted values for every (row, quantile).
    row_base = inp + row_idx[:, None] * M
    lo_val = tl.load(row_base + lo_idx[None, :], mask=row_valid_col, other=0.0)
    hi_val = tl.load(row_base + hi_idx[None, :], mask=row_valid_col, other=0.0)

    if interpolation == "linear":
        frac = pos - lo_idx
        tl.store(dst, lo_val + (hi_val - lo_val) * frac, mask=dst_mask)

    elif interpolation == "lower":
        tl.store(dst, lo_val, mask=dst_mask)

    elif interpolation == "higher":
        tl.store(dst, hi_val, mask=dst_mask)

    elif interpolation == "nearest":
        # Take the upper neighbour only when rint() lands on it.
        nearest_pos = tl_extra_shim.rint(pos)
        picked = tl.where(nearest_pos == hi_idx, hi_val, lo_val)
        tl.store(dst, picked, mask=dst_mask)

    elif interpolation == "midpoint":
        tl.store(dst, (lo_val + hi_val) / 2, mask=dst_mask)
def quantile(
    inp, q, dim=None, keepdim=False, interpolation="linear", out=None
) -> Tensor:
    """Compute quantiles of ``inp`` along ``dim``, mirroring ``torch.quantile``.

    The reduced dimension is moved last, sorted with ``torch.sort`` and
    handed to ``quantile_kernel``, which picks or interpolates the values.

    Args:
        inp: Floating-point input tensor with at least one element.
        q: Quantile(s) in ``[0, 1]``; a Python float, a 0-d tensor, or a
            1-D tensor.
        dim: Dimension to reduce. ``None`` flattens the input first.
        keepdim: Keep the reduced dimension(s) with size 1.
        interpolation: One of ``INTERPOLATION_METHOD``.
        out: Optional tensor that receives a copy of the result.

    Returns:
        The quantile tensor. For a 1-D ``q`` the first dimension indexes the
        quantiles (even when ``q`` has a single element); for a scalar ``q``
        that dimension is dropped. When ``out`` is given, ``out`` is returned.

    Raises:
        AssertionError: If any argument fails validation.
    """
    logger.debug("GEMS QUANTILE DIM")
    assert torch.is_floating_point(inp)
    assert dim is None or isinstance(dim, int)
    assert isinstance(q, (float, torch.Tensor))
    assert interpolation in INTERPOLATION_METHOD
    assert inp.numel() > 0

    if isinstance(q, float):
        # Build q in the input dtype so float64 inputs keep full precision.
        q = torch.tensor(q, dtype=inp.dtype, device=inp.device)
    assert q.ndim <= 1
    # Only a scalar q drops the leading quantile dimension; a 1-D q of
    # length 1 keeps it, as torch.quantile does.
    scalar_q = q.ndim == 0
    Q = q.numel()
    assert Q > 0
    assert torch.all(q >= 0.0) and torch.all(q <= 1.0)
    # The kernel reads q linearly, so strided views must be packed first.
    q = q.contiguous()

    flatten = dim is None
    orig_ndim = inp.ndim
    if flatten:
        inp = inp.ravel()
        dim = 0

    dim %= inp.ndim
    M = inp.shape[dim]
    inp = dim_compress(inp, dim)
    N = inp.numel() // M

    inp, _ = inp.sort()  # Sort the input with torch.sort()
    output = torch.empty(inp.shape[:-1] + (Q,), dtype=inp.dtype, device=inp.device)

    def grid(meta):
        return (
            triton.cdiv(Q, meta["BLOCK_Q"]),
            triton.cdiv(N, meta["BLOCK_N"]),
        )

    # with torch_device_fn.device(inp.device):
    quantile_kernel[grid](inp, q, output, N, M, Q, interpolation=interpolation)

    # Move the quantile axis to the front, same as torch.quantile().
    output = output.permute((-1,) + tuple(range(0, inp.ndim - 1)))
    if keepdim:
        if flatten:
            # Every original dimension was reduced: keep each one as size 1.
            output = output.reshape((Q,) + (1,) * orig_ndim)
        else:
            output = output.unsqueeze(dim + 1)
    if scalar_q:
        output = output.squeeze(0)

    if out is not None:
        out.copy_(output)
        return out
    return output