Coverage for src/flag_gems/ops/round.py: 57%

115 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8 

9logger = logging.getLogger(__name__) 

10 

11 

@triton.jit
def round_half_to_even_impl(x):
    """Round each element to the nearest integer value, ties going to even.

    The caller is expected to pass fp32 data.
    """
    lower = tl.floor(x)
    # Distance above the floor; x - floor(x) always lies in [0, 1).
    frac = x - lower

    # Parity of the floor: lower - 2 * floor(lower / 2) is 1 for odd values
    # and 0 for even values, so thresholding at 0.5 separates the two.
    parity = lower - 2.0 * tl.floor(lower / 2.0)
    lower_is_odd = tl.abs(parity) > 0.5

    # Above the midpoint the result is always the next integer up. Exactly at
    # the midpoint it is bumped only when the floor is odd, so the result
    # lands on the even neighbour. Everything else stays at the floor.
    above_half = frac > 0.5
    at_half = tl.abs(frac - 0.5) < 1e-10
    bump = above_half | (at_half & lower_is_odd)
    return tl.where(bump, lower + 1.0, lower)

28 

29 

@triton.jit
def round_kernel(
    x_ptr,
    out_ptr,
    n_elements,
    decimals: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    IS_FP32: tl.constexpr,
    IS_FP16: tl.constexpr,
    IS_BF16: tl.constexpr,
):
    """Element-wise round-half-to-even over a flat buffer.

    Each program instance handles ``BLOCK_SIZE`` consecutive elements starting
    at ``program_id * BLOCK_SIZE``; offsets at or beyond ``n_elements`` are
    masked out of both the load and the store.

    ``decimals == 0`` rounds to integer values. Any other value scales by
    ``10.0**decimals``, rounds, and divides the scale back out. fp16 and bf16
    inputs are widened to fp32 for the arithmetic and narrowed again before
    the store. When none of the three dtype flags is set, the loaded values
    are stored back unchanged.
    """
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    values = tl.load(x_ptr + offsets, mask=mask)

    # Every condition below is a compile-time constant, so only one arm is
    # ever materialised for a given specialisation of the kernel.
    if IS_FP32:
        if decimals == 0:
            result = round_half_to_even_impl(values)
        else:
            factor = 10.0**decimals
            result = round_half_to_even_impl(values * factor) / factor
    elif IS_FP16:
        widened = tl.cast(values, tl.float32)
        if decimals == 0:
            result = tl.cast(round_half_to_even_impl(widened), tl.float16)
        else:
            factor = 10.0**decimals
            result = tl.cast(
                round_half_to_even_impl(widened * factor) / factor, tl.float16
            )
    elif IS_BF16:
        widened = tl.cast(values, tl.float32)
        if decimals == 0:
            result = tl.cast(round_half_to_even_impl(widened), tl.bfloat16)
        else:
            factor = 10.0**decimals
            result = tl.cast(
                round_half_to_even_impl(widened * factor) / factor, tl.bfloat16
            )
    else:
        # Unrecognised element type: pass the data through untouched.
        result = values

    tl.store(out_ptr + offsets, result, mask=mask)

77 

78 

def round_func(input, decimals=0):
    """Return a new tensor holding ``input`` rounded half-to-even.

    Args:
        input: Real-valued, contiguous ``torch.Tensor`` to round.
        decimals: Number of decimal places; forwarded to ``round_kernel`` as
            a ``tl.constexpr`` argument.

    Returns:
        A freshly allocated tensor with the shape and dtype of ``input``.
        Non-floating-point inputs (integers, bool) are returned as a plain
        copy, since rounding cannot change them.

    Raises:
        TypeError: ``input`` is not a tensor, or is complex.
        ValueError: ``input`` is a non-contiguous floating-point tensor.

    NOTE(review): floating dtypes other than float32/float16/bfloat16
    (e.g. float64) reach the kernel's pass-through branch and come back
    unrounded.
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError("round expects a torch.Tensor.")

    if input.is_complex():
        raise TypeError("round is not supported for complex tensors.")

    # Rounding is the identity on every non-floating dtype, so skip the
    # kernel launch altogether. This also covers uint8 and bool, which the
    # kernel would merely copy.
    if not input.is_floating_point():
        return input.clone()

    if not input.is_contiguous():
        raise ValueError(
            "round Triton kernel currently supports only contiguous tensors."
        )

    n_elements = input.numel()
    if n_elements == 0:
        # Still hand back a distinct tensor: this is the out-of-place entry
        # point and must never alias its argument.
        return torch.empty_like(input)

    dtype = input.dtype
    output = torch.empty_like(input)

    BLOCK_SIZE = 1024

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    with torch_device_fn.device(input.device):
        round_kernel[grid](
            input,
            output,
            n_elements,
            decimals,
            BLOCK_SIZE=BLOCK_SIZE,
            IS_FP32=dtype == torch.float32,
            IS_FP16=dtype == torch.float16,
            IS_BF16=dtype == torch.bfloat16,
        )
    return output

121 

122 

def round(input, decimals=0):
    """Out-of-place rounding entry point: log the call, defer to ``round_func``."""
    logger.debug("GEMS ROUND")
    rounded = round_func(input, decimals=decimals)
    return rounded

126 

127 

def round_out(input, *, decimals=0, out=None):
    """Round ``input`` half-to-even, writing the result into ``out``.

    Args:
        input: Real-valued, contiguous ``torch.Tensor`` to round.
        decimals: Number of decimal places; forwarded to ``round_kernel`` as
            a ``tl.constexpr`` argument.
        out: Destination tensor. When ``None`` the call is delegated to
            ``round_func`` and a new tensor is returned instead.

    Returns:
        ``out`` (or the new tensor produced by ``round_func`` when ``out`` is
        ``None``).

    Raises:
        TypeError: ``input`` is not a tensor, or is complex.
        ValueError: ``input`` or ``out`` is non-contiguous on the kernel path.

    NOTE(review): floating dtypes other than float32/float16/bfloat16
    (e.g. float64) reach the kernel's pass-through branch and are copied
    into ``out`` unrounded.
    """
    logger.debug("GEMS ROUND_OUT")
    if out is None:
        return round_func(input, decimals=decimals)

    if not isinstance(input, torch.Tensor):
        raise TypeError("round expects a torch.Tensor.")

    if input.is_complex():
        raise TypeError("round is not supported for complex tensors.")

    # Rounding is the identity on integers: just copy the data across.
    if input.dtype in [torch.int32, torch.int64, torch.int16, torch.int8]:
        out.copy_(input)
        return out

    if not input.is_contiguous():
        raise ValueError(
            "round Triton kernel currently supports only contiguous tensors."
        )

    # The kernel writes input.numel() elements linearly through the out
    # pointer. Make sure the destination really has that extent and layout,
    # otherwise the launch would write out of bounds or scramble the result.
    if out.shape != input.shape:
        out.resize_(input.shape)
    if not out.is_contiguous():
        raise ValueError(
            "round Triton kernel currently supports only contiguous output tensors."
        )

    n_elements = input.numel()
    if n_elements == 0:
        return out

    dtype = input.dtype

    BLOCK_SIZE = 1024

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    with torch_device_fn.device(input.device):
        round_kernel[grid](
            input,
            out,
            n_elements,
            decimals,
            BLOCK_SIZE=BLOCK_SIZE,
            IS_FP32=dtype == torch.float32,
            IS_FP16=dtype == torch.float16,
            IS_BF16=dtype == torch.bfloat16,
        )
    return out

173 

174 

def round_(input, *, decimals=0):
    """Round ``input`` half-to-even in place and return it.

    Args:
        input: Real-valued, contiguous ``torch.Tensor``; overwritten with the
            rounded values.
        decimals: Number of decimal places; forwarded to ``round_kernel`` as
            a ``tl.constexpr`` argument.

    Returns:
        ``input`` itself. Tensors of the listed signed integer dtypes and
        empty tensors are returned without launching the kernel.

    Raises:
        TypeError: ``input`` is not a tensor, or is complex.
        ValueError: ``input`` is non-contiguous (and not one of the listed
            integer dtypes).
    """
    logger.debug("GEMS ROUND_")
    if not isinstance(input, torch.Tensor):
        raise TypeError("round expects a torch.Tensor.")
    if input.is_complex():
        raise TypeError("round is not supported for complex tensors.")

    # Integers are already "rounded": nothing to do.
    integer_dtypes = (torch.int32, torch.int64, torch.int16, torch.int8)
    if input.dtype in integer_dtypes:
        return input

    if not input.is_contiguous():
        raise ValueError(
            "round Triton kernel currently supports only contiguous tensors."
        )

    total = input.numel()
    if total == 0:
        return input

    elem_dtype = input.dtype

    def launch_grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    # The same tensor is passed as source and destination so the kernel
    # overwrites the data it reads.
    with torch_device_fn.device(input.device):
        round_kernel[launch_grid](
            input,
            input,
            total,
            decimals,
            BLOCK_SIZE=1024,
            IS_FP32=elem_dtype == torch.float32,
            IS_FP16=elem_dtype == torch.float16,
            IS_BF16=elem_dtype == torch.bfloat16,
        )
    return input