Coverage for src/flag_gems/runtime/backend/_sunrise/ops/quantile.py: 0%

90 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6from torch import Tensor 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import dim_compress, libentry, tl_extra_shim 

10from flag_gems.utils import triton_lang_extension as ext 

11 

# Child of the package-wide "flag_gems" logger, named after this module.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Interpolation modes accepted by quantile(); each one selects a branch of
# quantile_kernel.
INTERPOLATION_METHOD = [
    "linear",
    "lower",
    "higher",
    "nearest",
    "midpoint",
]

15 

16 

def heur_block_q(args):
    """Pick BLOCK_Q: ceil(Q / 8), capped at 16, rounded up to a power of two."""
    num_quantiles = args["Q"]
    block = triton.cdiv(num_quantiles, 8)
    if block > 16:
        block = 16
    return triton.next_power_of_2(block)

19 

20 

def heur_block_n(args):
    """Pick BLOCK_N (rows handled per program) from the row count N.

    Small row counts use fixed sizes (1, 4 or 32). From 4096 rows upward the
    block grows with N: ceil(N / 128) rounded up to a power of two, or
    ceil(N / 512) once N reaches 65536.
    """
    n_rows = args["N"]
    if n_rows < 32:
        return 1
    if n_rows < 64:
        return 4
    if n_rows < 4096:
        return 32
    divisor = 512 if n_rows >= 65536 else 128
    return triton.next_power_of_2(triton.cdiv(n_rows, divisor))

32 

33 

@libentry()
@triton.heuristics(values={"BLOCK_Q": heur_block_q, "BLOCK_N": heur_block_n})
@triton.jit
def quantile_kernel(
    inp,
    q,
    out,
    N,
    M,
    Q,
    BLOCK_Q: tl.constexpr,
    BLOCK_N: tl.constexpr,
    interpolation: tl.constexpr,
):
    """Read quantiles out of N rows of length M that are already sorted.

    Row ``n`` of ``inp`` starts at ``inp + n * M``. The kernel does not sort,
    so rows must be sorted ascending beforehand for the result to be a
    quantile. For row ``n`` and quantile index ``j`` it writes
    ``out[n * Q + j]`` from the two sorted neighbours around the fractional
    rank ``q[j] * (M - 1)``, combined according to ``interpolation``.
    Program axis 0 tiles the Q quantiles by BLOCK_Q; axis 1 tiles the N rows
    by BLOCK_N.
    """
    pid_Q = ext.program_id(0)
    pid_N = ext.program_id(1)
    ctype = inp.dtype.element_ty

    offsets_Q = pid_Q * BLOCK_Q + tl.arange(0, BLOCK_Q)
    mask_Q = offsets_Q < Q
    q_ptrs = q + offsets_Q

    # Row offsets in int64: offsets_N * M (and * Q below) can exceed the
    # int32 range on large inputs and would silently wrap around.
    offsets_N = (pid_N * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
    mask_N = offsets_N < N

    out_ptrs = out + offsets_N[:, None] * Q + offsets_Q[None, :]
    mask_out = mask_N[:, None] & mask_Q[None, :]

    # Fractional rank q * (M - 1), computed in fp64 for fp64 inputs and in
    # fp32 otherwise. In fp16/bf16 the rank gets rounded (in fp16, 1499.5
    # becomes 1500, losing the interpolation) and overflows once M - 1 leaves
    # the fp16 range (~65504), producing wrong or out-of-bounds indices.
    # Masked-off quantile lanes load q = 0, i.e. rank 0, so their (never
    # stored) gathers below stay in bounds.
    q_vals = tl.load(q_ptrs, mask_Q, 0.0)
    if ctype == tl.float64:
        q_vals = q_vals.to(tl.float64)
    else:
        q_vals = q_vals.to(tl.float32)
    q_block = q_vals * (M - 1)
    q_lower = tl.floor(q_block).to(tl.int32)
    q_upper = tl.ceil(q_block).to(tl.int32)

    inp_lower = tl.load(
        inp + offsets_N[:, None] * M + q_lower[None, :], mask_N[:, None], 0.0
    )
    inp_upper = tl.load(
        inp + offsets_N[:, None] * M + q_upper[None, :], mask_N[:, None], 0.0
    )

    if interpolation == "linear":
        q_frac = q_block - q_lower
        tl.store(out_ptrs, inp_lower + (inp_upper - inp_lower) * q_frac, mask_out)

    elif interpolation == "lower":
        tl.store(out_ptrs, inp_lower, mask_out)

    elif interpolation == "higher":
        tl.store(out_ptrs, inp_upper, mask_out)

    elif interpolation == "nearest":
        # NOTE(review): assumes tl_extra_shim.rint rounds half to even like C
        # rint, which is also how torch breaks ties for "nearest".
        q_round = tl_extra_shim.rint(q_block)
        out_block = tl.where(q_round == q_upper, inp_upper, inp_lower)
        tl.store(out_ptrs, out_block, mask_out)

    elif interpolation == "midpoint":
        tl.store(out_ptrs, (inp_lower + inp_upper) / 2, mask_out)

90 

91 

def quantile(
    inp, q, dim=None, keepdim=False, interpolation="linear", out=None
) -> Tensor:
    """Compute quantiles of ``inp`` along ``dim`` (or over all elements).

    Follows the ``torch.quantile`` layout. The quantile dimension comes first
    and is dropped when ``q`` is a Python float or a 0-d tensor. It is
    followed by the input shape with ``dim`` removed, or kept as size 1 when
    ``keepdim`` is set. With ``dim=None`` and ``keepdim=True`` every input
    dimension is kept as size 1.

    Args:
        inp: Non-empty floating-point tensor.
        q: Python float, or 0-d / 1-D tensor, of fractions in [0, 1].
        dim: Dimension to reduce; ``None`` reduces over all elements.
        keepdim: Keep the reduced dimension(s) with size 1.
        interpolation: One of ``INTERPOLATION_METHOD``.
        out: Optional tensor that receives a copy of the result.

    Returns:
        The quantile tensor, or ``out`` itself when ``out`` is given.

    Raises:
        AssertionError: On invalid arguments (validation uses ``assert``).
    """
    logger.debug("GEMS QUANTILE DIM")
    assert torch.is_floating_point(inp)
    assert dim is None or isinstance(dim, int)
    assert isinstance(q, (float, torch.Tensor))
    assert interpolation in INTERPOLATION_METHOD

    device = inp.device
    orig_ndim = inp.ndim

    if isinstance(q, float):
        # torch.quantile evaluates a Python-float q in the input's precision.
        # Use at least fp32 so half-precision inputs do not round q itself.
        q = torch.tensor(
            q, dtype=torch.promote_types(inp.dtype, torch.float32), device=device
        )
    else:
        # The kernel reads q through a raw pointer, so it must live on the
        # input's device.
        q = q.to(device)
    assert q.ndim <= 1
    # Only a 0-d q (or a Python float) drops the leading quantile dimension;
    # a 1-D q keeps it even when it holds a single value.
    scalar_q = q.ndim == 0
    # The kernel indexes q as a dense 1-D array.
    q = q.reshape(-1).contiguous()
    Q = q.numel()

    assert inp.numel() > 0
    assert Q > 0
    assert torch.all(q >= 0.0) and torch.all(q <= 1.0)
    if dim is not None:
        # Reject out-of-range dims instead of silently wrapping them; a 0-d
        # input accepts dim 0 or -1.
        assert -max(orig_ndim, 1) <= dim < max(orig_ndim, 1)

    # dim=None (and any dim on a 0-d input) reduces over every element.
    reduce_all = dim is None or orig_ndim == 0
    if reduce_all:
        inp = inp.ravel()
        dim = 0
    dim %= inp.ndim

    M = inp.shape[dim]
    # NOTE(review): dim_compress is expected to move `dim` last and return a
    # contiguous tensor; the sort below and the kernel (row n at inp + n * M)
    # rely on that layout.
    inp = dim_compress(inp, dim)
    N = inp.numel() // M

    # [Tag][ZC] Sorting on the device raises an error, so sort on the CPU
    # and copy the sorted rows back.
    inp, _ = inp.cpu().sort()
    inp = inp.to(device)

    output = torch.empty(inp.shape[:-1] + (Q,), dtype=inp.dtype, device=device)

    grid = lambda meta: (
        triton.cdiv(Q, meta["BLOCK_Q"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )

    with torch_device_fn.device(device):
        quantile_kernel[grid](inp, q, output, N, M, Q, interpolation=interpolation)

    # Put the quantile axis first, as torch.quantile does. Like the sort, this
    # goes through the CPU instead of permuting on the device (presumably the
    # same backend limitation - TODO confirm).
    output = output.cpu().permute((-1,) + tuple(range(inp.ndim - 1))).to(device)

    if keepdim:
        if reduce_all:
            # Every input dimension was reduced; keep each one as size 1.
            output = output.reshape((Q,) + (1,) * orig_ndim)
        else:
            output = output.unsqueeze(dim + 1)
    if scalar_q:
        output = output.squeeze(0)

    if out is not None:
        out.copy_(output)
        return out
    return output