Coverage for src/flag_gems/runtime/backend/_arm/ops/quantile.py: 0%

86 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6from torch import Tensor 

7 

8from flag_gems.utils import dim_compress, tl_extra_shim 

9from flag_gems.utils import triton_lang_extension as tle 

10 

logger = logging.getLogger(__name__)

# Interpolation modes accepted by ``quantile``; the chosen string is forwarded
# to the kernel as a compile-time constant.
INTERPOLATION_METHOD = [
    "linear",
    "lower",
    "higher",
    "nearest",
    "midpoint",
]

14 

15 

def heur_block_q(args):
    """Choose BLOCK_Q for ``quantile_kernel``.

    Takes ceil(Q / 8), caps it at 16, then rounds it up to a power of two.
    """
    quantiles_per_block = min(triton.cdiv(args["Q"], 8), 16)
    return triton.next_power_of_2(quantiles_per_block)

18 

19 

def heur_block_n(args):
    """Choose BLOCK_N (rows handled per program) for ``quantile_kernel``.

    Large N: the block grows with N, so the grid along N has at most
    about 512 programs (N >= 65536) or 128 programs (4096 <= N < 65536).
    Smaller N uses fixed sizes: 32, 4 or 1.
    """
    n_rows = args["N"]
    for threshold, divisor in ((65536, 512), (4096, 128)):
        if n_rows >= threshold:
            return triton.next_power_of_2(triton.cdiv(n_rows, divisor))
    if n_rows >= 64:
        return 32
    return 4 if n_rows >= 32 else 1

31 

32 

# NOTE(review): libentry() is disabled for this backend; the reason is not
# visible in this file.
# @libentry()
#
# Gather/interpolate quantiles from pre-sorted rows.
#   inp:  N rows of M elements each (row r starts at inp + r * M); the caller
#         sorts every row ascending before launching.
#   q:    Q probabilities, read as Q consecutive elements.
#   out:  written as an (N, Q) row-major matrix (out + row * Q + j).
#   interpolation: compile-time string selecting how the two neighbouring
#         order statistics (floor / ceil of q * (M - 1)) are combined.
# Program (pid_Q, pid_N) handles one BLOCK_Q x BLOCK_N tile of `out`.
@triton.heuristics(values={"BLOCK_Q": heur_block_q, "BLOCK_N": heur_block_n})
@triton.jit
def quantile_kernel(
    inp,
    q,
    out,
    N,
    M,
    Q,
    BLOCK_Q: tl.constexpr,
    BLOCK_N: tl.constexpr,
    interpolation: tl.constexpr,
):
    pid_Q = tle.program_id(0)
    pid_N = tle.program_id(1)
    ctype = inp.dtype.element_ty

    offsets_Q = pid_Q * BLOCK_Q + tl.arange(0, BLOCK_Q)
    mask_Q = offsets_Q < Q
    q_ptrs = q + offsets_Q

    offsets_N = pid_N * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_N = offsets_N < N
    # Row offsets in int64: row * M (and row * Q) overflows int32 once the
    # tensor holds more than 2**31 - 1 elements.
    rows = offsets_N.to(tl.int64)

    out_ptrs = out + rows[:, None] * Q + offsets_Q[None, :]
    mask_out = mask_N[:, None] & mask_Q[None, :]

    # Fractional rank of each quantile inside a row of length M.
    # NOTE(review): the rank is computed in the input dtype; for fp16/bf16
    # inputs with large M this may pick the wrong neighbours — consider
    # float32 (or wider) here if half inputs must be supported.
    q_block = tl.load(q_ptrs, mask_Q, 0.0).to(ctype) * (M - 1)
    q_lower = tl.floor(q_block).to(tl.int32)
    q_upper = tl.ceil(q_block).to(tl.int32)

    # Masked-off quantile lanes loaded q = 0, so they index element 0 of the
    # row, which is always in bounds; only the row mask is needed here.
    row_ptrs = inp + rows[:, None] * M
    inp_lower = tl.load(row_ptrs + q_lower[None, :], mask_N[:, None], 0.0)
    inp_upper = tl.load(row_ptrs + q_upper[None, :], mask_N[:, None], 0.0)

    if interpolation == "linear":
        q_frac = q_block - q_lower
        tl.store(out_ptrs, inp_lower + (inp_upper - inp_lower) * q_frac, mask_out)

    elif interpolation == "lower":
        tl.store(out_ptrs, inp_lower, mask_out)

    elif interpolation == "higher":
        tl.store(out_ptrs, inp_upper, mask_out)

    elif interpolation == "nearest":
        # Take the upper neighbour exactly when the rank rounds (rint) to it.
        q_round = tl_extra_shim.rint(q_block)
        out_block = tl.where(q_round == q_upper, inp_upper, inp_lower)
        tl.store(out_ptrs, out_block, mask_out)

    elif interpolation == "midpoint":
        tl.store(out_ptrs, (inp_lower + inp_upper) / 2, mask_out)

89 

90 

def quantile(
    inp, q, dim=None, keepdim=False, interpolation="linear", out=None
) -> Tensor:
    """Compute the ``q``-th quantiles of ``inp`` along ``dim`` (torch.quantile semantics).

    Args:
        inp: Non-empty floating-point input tensor.
        q: A Python float, or a 0-d / 1-d tensor of probabilities in [0, 1].
        dim: Dimension to reduce; ``None`` reduces over the flattened input.
        keepdim: Keep the reduced dimension(s) as size-1 dimensions. With
            ``dim=None`` every input dimension is kept as size 1.
        interpolation: One of ``INTERPOLATION_METHOD``.
        out: Optional tensor that receives a copy of the result.

    Returns:
        The quantiles. When ``q`` is a 1-d tensor (even of length one), the
        result has a leading dimension of size ``q.numel()``. When ``q`` is a
        float or 0-d tensor, there is no leading dimension. Rows that contain
        NaN produce NaN. If ``out`` is given, the result is copied into it and
        ``out`` is returned.

    Raises:
        AssertionError: on invalid arguments. This covers non-floating input,
            empty input, empty ``q``, ``q`` with more than one dimension,
            ``q`` outside [0, 1], and an unknown interpolation.
    """
    logger.debug("GEMS QUANTILE DIM")
    assert torch.is_floating_point(inp)
    assert dim is None or isinstance(dim, int)
    assert isinstance(q, (float, torch.Tensor))
    assert interpolation in INTERPOLATION_METHOD
    assert inp.numel() > 0

    # torch.quantile drops the leading quantile dimension only for a scalar q
    # (Python float or 0-d tensor); a 1-d q of length one keeps it.
    q_is_scalar = isinstance(q, float) or q.ndim == 0
    if isinstance(q, float):
        # Materialize a float q in the input dtype, as torch.quantile does.
        q = torch.tensor(q, dtype=inp.dtype, device=inp.device)
    assert q.ndim <= 1
    # The kernel reads q as Q consecutive elements.
    q = q.contiguous()
    Q = q.numel()
    assert Q > 0
    assert torch.all(q >= 0.0) and torch.all(q <= 1.0)

    orig_ndim = inp.ndim
    reduce_all = dim is None
    if reduce_all:
        inp = inp.ravel()
        dim = 0

    shape = list(inp.shape)
    dim %= inp.ndim
    # Assumes dim_compress moves `dim` to the last axis of a contiguous
    # tensor; the sort and the row-major indexing in the kernel rely on it.
    inp = dim_compress(inp, dim)
    M = shape[dim]
    N = inp.numel() // M

    inp, _ = inp.sort()  # Sort the input with torch.sort()
    output = torch.empty(inp.shape[:-1] + (Q,), dtype=inp.dtype, device=inp.device)

    def grid(meta):
        return (
            triton.cdiv(Q, meta["BLOCK_Q"]),
            triton.cdiv(N, meta["BLOCK_N"]),
        )

    # NOTE(review): a `torch_device_fn.device(inp.device)` guard around this
    # launch is commented out in this backend; confirm that is intended.
    quantile_kernel[grid](inp, q, output, N, M, Q, interpolation=interpolation)

    # torch.sort places NaN last, so a row contains NaN iff its last sorted
    # element is NaN; torch.quantile returns NaN for every quantile of it.
    output.masked_fill_(inp[..., -1].isnan().unsqueeze(-1), float("nan"))

    # Move the quantile axis to the front, as torch.quantile does.
    output = output.permute((-1,) + tuple(range(0, inp.ndim - 1)))
    if keepdim:
        if reduce_all:
            # A full reduction keeps every original dimension as size 1.
            output = output.reshape((Q,) + (1,) * orig_ndim)
        else:
            output = output.unsqueeze(dim + 1)
    if q_is_scalar:
        output = output.squeeze(0)

    if out is not None:
        out.copy_(output)
        return out
    return output