Coverage for src/flag_gems/runtime/backend/_sunrise/ops/quantile.py: 0%
90 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6from torch import Tensor
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import dim_compress, libentry, tl_extra_shim
10from flag_gems.utils import triton_lang_extension as ext
# Child of the package-wide "flag_gems" logger, named after this module.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Interpolation modes accepted by `quantile` (same names as torch.quantile).
INTERPOLATION_METHOD = [
    "linear",
    "lower",
    "higher",
    "nearest",
    "midpoint",
]
def heur_block_q(args):
    """Choose BLOCK_Q, the number of quantile levels handled per program.

    Takes ceil(Q / 8), caps it at 16, and rounds the result up to a power of
    two so it can be used as a ``tl.arange`` extent.
    """
    levels_per_program = min(triton.cdiv(args["Q"], 8), 16)
    return triton.next_power_of_2(levels_per_program)
def heur_block_n(args):
    """Choose BLOCK_N, the number of rows handled per program, from ``N``.

    Tiers:
      * N < 32            -> 1
      * 32 <= N < 64      -> 4
      * 64 <= N < 4096    -> 32
      * 4096 <= N < 65536 -> next power of two of ceil(N / 128)
      * N >= 65536        -> next power of two of ceil(N / 512)
    """
    rows = args["N"]
    if rows < 32:
        return 1
    if rows < 64:
        return 4
    if rows < 4096:
        return 32
    # Large inputs: scale the block with N, using a coarser divisor past 64K rows.
    divisor = 512 if rows >= 65536 else 128
    return triton.next_power_of_2(triton.cdiv(rows, divisor))
@libentry()
@triton.heuristics(values={"BLOCK_Q": heur_block_q, "BLOCK_N": heur_block_n})
@triton.jit
def quantile_kernel(
    inp,
    q,
    out,
    N,
    M,
    Q,
    BLOCK_Q: tl.constexpr,
    BLOCK_N: tl.constexpr,
    interpolation: tl.constexpr,
):
    """Read quantile values out of rows that are already sorted.

    Each program covers a BLOCK_N x BLOCK_Q tile:
      * BLOCK_N rows of ``inp``, where element (row, col) is at ``inp + row * M + col``;
      * BLOCK_Q quantile levels read contiguously from ``q``.
    Each result is written to ``out + row * Q + q_idx``.

    This kernel does not sort. The caller must pass every row of ``inp``
    sorted in ascending order. ``interpolation`` is a compile-time string that
    selects one of the branches below.
    """
    pid_Q = ext.program_id(0)
    pid_N = ext.program_id(1)
    ctype = inp.dtype.element_ty

    offsets_Q = pid_Q * BLOCK_Q + tl.arange(0, BLOCK_Q)
    mask_Q = offsets_Q < Q
    q_ptrs = q + offsets_Q

    offsets_N = pid_N * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_N = offsets_N < N

    out_ptrs = out + offsets_N[:, None] * Q + offsets_Q[None, :]
    mask_out = mask_N[:, None] & mask_Q[None, :]

    # Fractional position of each quantile level inside a row of M sorted values.
    q_block = tl.load(q_ptrs, mask_Q, 0.0).to(ctype) * (M - 1)
    # Clamp both neighbours to the last valid column. In low-precision float
    # types the product above can round up past M - 1, and the loads below
    # would then read past the end of the row.
    q_lower = tl.minimum(tl.floor(q_block).to(tl.int32), M - 1)
    q_upper = tl.minimum(tl.ceil(q_block).to(tl.int32), M - 1)

    inp_lower = tl.load(
        inp + offsets_N[:, None] * M + q_lower[None, :], mask_N[:, None], 0.0
    )
    inp_upper = tl.load(
        inp + offsets_N[:, None] * M + q_upper[None, :], mask_N[:, None], 0.0
    )

    if interpolation == "linear":
        q_frac = q_block - q_lower
        tl.store(out_ptrs, inp_lower + (inp_upper - inp_lower) * q_frac, mask_out)

    elif interpolation == "lower":
        tl.store(out_ptrs, inp_lower, mask_out)

    elif interpolation == "higher":
        tl.store(out_ptrs, inp_upper, mask_out)

    elif interpolation == "nearest":
        # Round the position (via the rint shim) and select the neighbour it lands on.
        q_round = tl_extra_shim.rint(q_block)
        out_block = tl.where(q_round == q_upper, inp_upper, inp_lower)
        tl.store(out_ptrs, out_block, mask_out)

    elif interpolation == "midpoint":
        tl.store(out_ptrs, (inp_lower + inp_upper) / 2, mask_out)
def quantile(
    inp, q, dim=None, keepdim=False, interpolation="linear", out=None
) -> Tensor:
    """Compute quantiles of ``inp``, with output shapes matching ``torch.quantile``.

    Args:
        inp: floating-point input tensor with at least one element.
        q: quantile level(s) in [0, 1]. Either a Python float or a 0-D/1-D tensor.
        dim: dimension to reduce. ``None`` reduces over all elements.
        keepdim: keep the reduced dimension(s) as size 1.
        interpolation: one of ``INTERPOLATION_METHOD``.
        out: optional tensor that receives a copy of the result.

    Returns:
        A tensor shaped like ``torch.quantile``'s result. A leading quantile
        axis of length ``q.numel()`` is present unless ``q`` is a scalar. When
        ``out`` is given, ``out`` itself is returned.

    Raises:
        AssertionError: on unsupported dtypes, argument types, or values.
    """
    logger.debug("GEMS QUANTILE DIM")
    assert torch.is_floating_point(inp)
    assert dim is None or isinstance(dim, int)
    assert isinstance(q, (float, torch.Tensor))
    assert interpolation in INTERPOLATION_METHOD

    device = inp.device
    if isinstance(q, float):
        q = torch.tensor(q, device=device)
    assert q.dim() <= 1
    # torch.quantile drops the leading quantile axis only for a 0-D q.
    # A 1-D q keeps it, even when its length is 1.
    q_is_scalar = q.dim() == 0
    Q = q.numel()

    assert inp.numel() > 0
    assert Q > 0
    assert torch.all(q >= 0.0) and torch.all(q <= 1.0)

    orig_ndim = inp.ndim
    # A 0-D input has no dimension to reduce. Treat it like dim=None rather
    # than evaluating `dim % 0`.
    reduce_all = dim is None or orig_ndim == 0
    if reduce_all:
        inp = inp.ravel()
        dim = 0

    shape = list(inp.shape)

    dim %= inp.ndim
    inp = dim_compress(inp, dim)
    M = shape[dim]
    N = inp.numel() // M

    # [Tag][ZC] Sorting on the device raises an error on this backend. Sort on
    # the CPU instead, then move the result back to the input's original device.
    inp, _ = inp.cpu().sort()
    inp = inp.to(device)

    # The kernel reads q by raw pointer offsets, so hand it a flat, contiguous copy.
    q_flat = q.reshape(-1).contiguous()

    output = torch.empty(inp.shape[:-1] + (Q,), dtype=inp.dtype, device=inp.device)

    grid = lambda meta: (
        triton.cdiv(Q, meta["BLOCK_Q"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )

    with torch_device_fn.device(inp.device):
        quantile_kernel[grid](inp, q_flat, output, N, M, Q, interpolation=interpolation)

    # Move the quantile axis to the front, as torch.quantile does. The permute
    # goes through the CPU on this backend (kept from the original port).
    output = (
        output.cpu().permute((-1,) + tuple(range(0, inp.ndim - 1))).to(device)
    )

    if keepdim:
        if reduce_all:
            # Every original dimension collapses to size 1.
            output = output.reshape((Q,) + (1,) * orig_ndim)
        else:
            output = output.unsqueeze(dim + 1)
    if q_is_scalar:
        output = output.squeeze(0)

    if out is not None:
        out.copy_(output)
        return out
    return output