Coverage for src/flag_gems/runtime/backend/_cambricon/ops/quantile.py: 0%

224 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6import triton.language.core as core 

7from torch import Tensor 

8 

9try: 

10 # TODO: Triton 2.1 does not implement _log2. 

11 # Remove the try-catch block once all vendors upgrade to a newer version of Triton. 

12 from triton.language.standard import _log2, zeros_like 

13except ImportError: 

14 pass 

15from flag_gems.runtime import torch_device_fn 

16from flag_gems.utils import libentry, tl_extra_shim 

17from flag_gems.utils import triton_lang_extension as ext 

18 

19from ..utils import MAX_GRID_SIZE_X 

20from .topk import _get_finfo_val 

21 

22logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

23 

24INTERPOLATION_METHOD = ["linear", "lower", "higher", "nearest", "midpoint"] 

25MAX_BITONIC_M = 1024 

26 

27""" 

28Note(Zhengzekang): 

29Refer from triton2.2 official `sort` implementation: 

30https://github.com/triton-lang/triton/blob/release/2.2.x/python/triton/language/standard.py#L392-L404 

31Just add indices to sort with values. 

32""" 

33 

34 

35@triton.jit 

36def _compare_and_swap(x, ids, flip, i: core.constexpr, n_dims: core.constexpr): 

37 n_outer: core.constexpr = x.numel >> n_dims 

38 shape: core.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)] 

39 

40 # tl.device_print("shape is: ", shape) 

41 y = core.reshape(x, shape) 

42 y_idx = core.reshape(ids, shape) 

43 

44 # slice left/right with 'stride' 2**(n_dims - i - 1) 

45 mask = core.arange(0, 2)[None, :, None] 

46 left = core.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape).to(x.dtype) 

47 right = core.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape).to(x.dtype) 

48 left = core.reshape(left, x.shape) 

49 right = core.reshape(right, x.shape) 

50 

51 left_idx = core.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape).to( 

52 ids.dtype 

53 ) 

54 right_idx = core.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape).to( 

55 ids.dtype 

56 ) 

57 left_idx = core.reshape(left_idx, ids.shape) 

58 right_idx = core.reshape(right_idx, ids.shape) 

59 

60 # actual compare-and-swap 

61 if core.constexpr(x.dtype.primitive_bitwidth) == 8: 

62 idtype = core.int8 

63 elif core.constexpr(x.dtype.primitive_bitwidth) == 16: 

64 idtype = core.int16 

65 elif core.constexpr(x.dtype.primitive_bitwidth) == 32: 

66 idtype = core.int32 

67 elif core.constexpr(x.dtype.primitive_bitwidth) == 64: 

68 idtype = core.int64 

69 else: 

70 raise ValueError("Unsupported dtype") 

71 

72 ileft = left.to(idtype, bitcast=True) 

73 iright = right.to(idtype, bitcast=True) 

74 ix = x.to(idtype, bitcast=True) 

75 

76 cond = (left > right) ^ flip 

77 ret = ix ^ core.where(cond, ileft ^ iright, zeros_like(ix)) 

78 

79 if core.constexpr(ids.dtype.primitive_bitwidth) == 8: 

80 idx_dtype = core.int8 

81 elif core.constexpr(ids.dtype.primitive_bitwidth) == 16: 

82 idx_dtype = core.int16 

83 elif core.constexpr(ids.dtype.primitive_bitwidth) == 32: 

84 idx_dtype = core.int32 

85 elif core.constexpr(ids.dtype.primitive_bitwidth) == 64: 

86 idx_dtype = core.int64 

87 else: 

88 raise ValueError("Unsupported dtype") 

89 

90 ileft_idx = left_idx.to(idx_dtype, bitcast=True) 

91 iright_idx = right_idx.to(idx_dtype, bitcast=True) 

92 ix_idx = ids.to(idx_dtype, bitcast=True) 

93 ret_idx = ix_idx ^ core.where(cond, ileft_idx ^ iright_idx, zeros_like(ix_idx)) 

94 

95 return ret.to(x.dtype, bitcast=True), ret_idx.to(ids.dtype, bitcast=True) 

96 

97 

98@triton.jit 

99def _bitonic_merge( 

100 x, ids, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr 

101): 

102 """ 

103 order_type 0 == ascending 

104 order_type 1 == descending 

105 order_type 2 == alternating 

106 """ 

107 n_outer: core.constexpr = x.numel >> n_dims 

108 core.static_assert(stage <= n_dims) 

109 # flip denotes whether to re-arrange sub-sequences of elements in ascending or 

110 # descending order. 

111 # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage 

112 # if flip = 00110011... then all the elements will be re-arranged alternatingly (with 

113 # a stride of 2) at this stage 

114 if order == 2: 

115 shape: core.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage] 

116 flip = core.reshape( 

117 core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape 

118 ) 

119 else: 

120 flip = order 

121 # perform `stage` rounds of `compare-and-swap` 

122 for i in core.static_range(stage): 

123 x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims) 

124 return x, ids 

125 

126 

127@triton.jit 

128def argsort(x, ids, dim: tl.constexpr, descending: core.constexpr): 

129 # handle default dimension or check that it is the most minor dim 

130 _dim: core.constexpr = dim 

131 n_dims: core.constexpr = _log2(x.shape[_dim]) 

132 for i in core.static_range(1, n_dims + 1): 

133 x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims) 

134 return x, ids 

135 

136 

137def heur_block_q(args): 

138 return triton.next_power_of_2(min(triton.cdiv(args["Q"], 8), 16)) 

139 

140 

141def heur_block_n(args): 

142 if args["N"] >= 65536: 

143 return triton.next_power_of_2(triton.cdiv(args["N"], 512)) 

144 elif args["N"] >= 4096: 

145 return triton.next_power_of_2(triton.cdiv(args["N"], 128)) 

146 elif args["N"] >= 64: 

147 return 32 

148 elif args["N"] >= 32: 

149 return 4 

150 else: 

151 return 1 

152 

153 

154@libentry() 

155@triton.heuristics(values={"BLOCK_Q": heur_block_q, "BLOCK_N": heur_block_n}) 

156@triton.jit 

157def quantile_kernel( 

158 inp, 

159 q, 

160 out, 

161 N, 

162 M, 

163 Q, 

164 BLOCK_Q: tl.constexpr, 

165 BLOCK_N: tl.constexpr, 

166 interpolation: tl.constexpr, 

167): 

168 pid_Q = ext.program_id(0) 

169 pid_N = ext.program_id(1) 

170 ctype = inp.dtype.element_ty 

171 

172 offsets_Q = pid_Q * BLOCK_Q + tl.arange(0, BLOCK_Q) 

173 mask_Q = offsets_Q < Q 

174 q_ptrs = q + offsets_Q 

175 

176 offsets_N = pid_N * BLOCK_N + tl.arange(0, BLOCK_N) 

177 mask_N = offsets_N < N 

178 

179 out_ptrs = out + offsets_N[:, None] * Q + offsets_Q[None, :] 

180 mask_out = mask_N[:, None] & mask_Q[None, :] 

181 

182 q_block = tl.load(q_ptrs, mask_Q, 0.0).to(ctype) * (M - 1) 

183 q_lower = tl.floor(q_block).to(tl.int32) 

184 q_upper = tl.ceil(q_block).to(tl.int32) 

185 

186 inp_lower = tl.load( 

187 inp + offsets_N[:, None] * M + q_lower[None, :], mask_N[:, None], 0.0 

188 ) 

189 inp_upper = tl.load( 

190 inp + offsets_N[:, None] * M + q_upper[None, :], mask_N[:, None], 0.0 

191 ) 

192 

193 if interpolation == "linear": 

194 q_frac = q_block - q_lower 

195 tl.store(out_ptrs, inp_lower + (inp_upper - inp_lower) * q_frac, mask_out) 

196 

197 elif interpolation == "lower": 

198 tl.store(out_ptrs, inp_lower, mask_out) 

199 

200 elif interpolation == "higher": 

201 tl.store(out_ptrs, inp_upper, mask_out) 

202 

203 elif interpolation == "nearest": 

204 q_round = tl_extra_shim.rint(q_block) 

205 out_block = tl.where(q_round == q_upper, inp_upper, inp_lower) 

206 tl.store(out_ptrs, out_block, mask_out) 

207 

208 elif interpolation == "midpoint": 

209 tl.store(out_ptrs, (inp_lower + inp_upper) / 2, mask_out) 

210 

211 

212@libentry() 

213@triton.jit 

214def quantile_bitonic_kernel( 

215 inp, 

216 q, 

217 out, 

218 N, 

219 M, 

220 Q, 

221 BLOCK_Q: tl.constexpr, 

222 BLOCK_M: tl.constexpr, 

223 interpolation: tl.constexpr, 

224): 

225 pid = ext.program_id(0) 

226 grid_0 = tl.num_programs(0) 

227 ctype = inp.dtype.element_ty 

228 

229 while pid < N: 

230 cols = tl.arange(0, BLOCK_M) 

231 mask_M = cols < M 

232 row_ptr = inp + pid * M 

233 mask_val = _get_finfo_val(ctype, return_max=True) 

234 vals = tl.load(row_ptr + cols, mask=mask_M, other=mask_val) 

235 vals = tl.where(vals.dtype.is_fp64(), vals, vals.to(tl.float32)) 

236 ids = tl.arange(0, BLOCK_M) 

237 sorted_vals, _ = argsort(vals, ids, 0, descending=False) 

238 

239 offsets_Q = tl.arange(0, BLOCK_Q) 

240 mask_Q = offsets_Q < Q 

241 q_vals = tl.load(q + offsets_Q, mask=mask_Q, other=0.0).to(tl.float32) 

242 q_scaled = q_vals * (M - 1) 

243 q_lower = tl.floor(q_scaled).to(tl.int32) 

244 q_upper = tl.ceil(q_scaled).to(tl.int32) 

245 

246 idx = tl.arange(0, BLOCK_M)[:, None] 

247 mask_lower = idx == q_lower[None, :] 

248 mask_upper = idx == q_upper[None, :] 

249 mask_lower_f = mask_lower.to(tl.float32) 

250 mask_upper_f = mask_upper.to(tl.float32) 

251 lower_vals = tl.sum(sorted_vals[:, None] * mask_lower_f, axis=0) 

252 upper_vals = tl.sum(sorted_vals[:, None] * mask_upper_f, axis=0) 

253 

254 if interpolation == "linear": 

255 q_frac = q_scaled - q_lower 

256 out_vals = lower_vals + (upper_vals - lower_vals) * q_frac 

257 elif interpolation == "lower": 

258 out_vals = lower_vals 

259 elif interpolation == "higher": 

260 out_vals = upper_vals 

261 elif interpolation == "nearest": 

262 q_round = tl_extra_shim.rint(q_scaled).to(tl.int32) 

263 out_vals = tl.where(q_round == q_upper, upper_vals, lower_vals) 

264 elif interpolation == "midpoint": 

265 out_vals = (lower_vals + upper_vals) * 0.5 

266 

267 out_ptr = out + pid * Q + offsets_Q 

268 tl.store(out_ptr, out_vals.to(ctype), mask=mask_Q) 

269 pid += grid_0 

270 

271 

272def quantile( 

273 inp, q, dim=None, keepdim=False, interpolation="linear", out=None 

274) -> Tensor: 

275 logger.debug("GEMS_CAMBRICON QUANTILE DIM") 

276 assert torch.is_floating_point(inp) 

277 assert dim is None or isinstance(dim, int) 

278 assert isinstance(q, (float, torch.Tensor)) 

279 assert interpolation in INTERPOLATION_METHOD 

280 

281 # Handle dim 

282 if dim is None: 

283 inp = inp.ravel() 

284 dim = 0 

285 if dim < 0: 

286 dim = dim + inp.ndim 

287 

288 # Handle q 

289 q_all_ones = False 

290 q_all_zeros = False 

291 if isinstance(q, float): 

292 q_all_ones = q == 1.0 

293 q_all_zeros = q == 0.0 

294 q = torch.tensor(q, device=inp.device, dtype=inp.dtype) 

295 Q = 1 

296 else: 

297 q = q.to(device=inp.device, dtype=inp.dtype) 

298 Q = 1 if q.numel() == 1 else len(q) 

299 

300 assert torch.all(q >= 0.0) and torch.all(q <= 1.0) 

301 

302 # Fast path: q == 0.0 -> min, q == 1.0 -> max (no sort needed) 

303 if q_all_ones or q_all_zeros: 

304 reduce_fn = torch.amax if q_all_ones else torch.amin 

305 if out is not None and Q == 1: 

306 reduce_fn(inp, dim=dim, keepdim=keepdim, out=out) 

307 return out 

308 output = reduce_fn(inp, dim=dim, keepdim=keepdim) 

309 if Q > 1: 

310 output = output.unsqueeze(0).expand(Q, *output.shape) 

311 if out is not None: 

312 out.copy_(output) 

313 return out 

314 return output 

315 

316 # handle input tensor 

317 if dim != inp.ndim - 1: 

318 inp = torch.movedim(inp, dim, -1).contiguous() 

319 else: 

320 inp = inp.contiguous() 

321 

322 M = inp.size(-1) 

323 N = inp.numel() // M 

324 

325 output = torch.empty(inp.shape[:-1] + (Q,), dtype=inp.dtype, device=inp.device) 

326 if M <= MAX_BITONIC_M: 

327 BLOCK_M = triton.next_power_of_2(M) 

328 BLOCK_Q = triton.next_power_of_2(min(Q, 16)) 

329 grid = min(N, MAX_GRID_SIZE_X // 4) 

330 with torch_device_fn.device(inp.device): 

331 quantile_bitonic_kernel[(grid,)]( 

332 inp, 

333 q, 

334 output, 

335 N, 

336 M, 

337 Q, 

338 BLOCK_Q=BLOCK_Q, 

339 BLOCK_M=BLOCK_M, 

340 interpolation=interpolation, 

341 ) 

342 else: 

343 sorted_vals, _ = inp.sort(dim=-1) 

344 grid = lambda meta: ( 

345 triton.cdiv(Q, meta["BLOCK_Q"]), 

346 triton.cdiv(N, meta["BLOCK_N"]), 

347 ) 

348 with torch_device_fn.device(inp.device): 

349 quantile_kernel[grid]( 

350 sorted_vals, q, output, N, M, Q, interpolation=interpolation 

351 ) 

352 

353 if Q == 1: 

354 output = output.squeeze(-1) 

355 else: 

356 output = output.movedim(-1, 0) 

357 if keepdim: 

358 output = output.unsqueeze(dim + (1 if Q != 1 else 0)) 

359 

360 if out is not None: 

361 out.copy_(output) 

362 return output