Coverage for src/flag_gems/runtime/backend/_iluvatar/ops/div.py: 0%

166 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import pointwise_dynamic, tl_extra_shim 

8 

# Module-level logger, following the standard ``logging.getLogger(__name__)``
# convention so records are attributed to this module.
logger = logging.getLogger(__name__)

# Rounding-aware math primitives provided by the backend shim; the Triton
# kernels in this module call them by these bare names.
div_rn, div_rz, fmod, trunc = (
    tl_extra_shim.div_rn,
    tl_extra_shim.div_rz,
    tl_extra_shim.fmod,
    tl_extra_shim.trunc,
)

15 

16 

@pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func(x, y):
    # Elementwise true division, tensor / tensor.
    # The "INT_TO_FLOAT" promotion method presumably gives integer inputs a
    # floating-point result dtype -- confirm against pointwise_dynamic.
    quotient = x / y
    return quotient

21 

22 

@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func_tensor_scalar(x, y):
    # Elementwise true division, tensor ``x`` / scalar ``y``.
    quotient = x / y
    return quotient

27 

28 

@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func_scalar_tensor(x, y):
    # Elementwise true division, scalar ``x`` / tensor ``y``.
    quotient = x / y
    return quotient

33 

34 

def true_divide(A, B):
    """Divide ``A`` by ``B`` with true (non-rounding) division.

    Selects the tensor/tensor, tensor/scalar or scalar/tensor kernel based on
    which operands are ``torch.Tensor`` instances; any non-tensor operand is
    treated as a scalar.

    Args:
        A: Dividend, a tensor or a scalar.
        B: Divisor, a tensor or a scalar.

    Returns:
        The kernel's output tensor. When neither operand is a tensor, the
        quotient is computed in Python and wrapped with ``torch.tensor``.
    """
    logger.debug("GEMS TRUE_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return true_div_func(A, B)
    if a_is_tensor:
        return true_div_func_tensor_scalar(A, B)
    if b_is_tensor:
        return true_div_func_scalar_tensor(A, B)
    # Neither operand is a tensor: evaluate in Python and wrap the result.
    return torch.tensor(A / B)

46 

47 

def true_divide_(A, B):
    """In-place true division: write ``A / B`` into ``A``.

    Args:
        A: Tensor that is divided and receives the result (passed as
            ``out0``).
        B: Divisor, a tensor or a scalar.

    Returns:
        Whatever the selected kernel returns; presumably ``A`` itself, since
        it is passed as the output buffer.
    """
    logger.debug("GEMS TRUE_DIVIDE_")
    kernel = (
        true_div_func
        if isinstance(B, torch.Tensor)
        else true_div_func_tensor_scalar
    )
    return kernel(A, B, out0=A)

54 

55 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func(x, y):
    # Elementwise division rounded toward zero, tensor / tensor.
    # NOTE(review): the tensor/scalar variants truncate ``div_rz(x, y)``,
    # while this one truncates ``x / y``. A round-to-nearest quotient that
    # lands just below an integer can round up to it before truncation.
    # Confirm whether this difference is intentional for this backend.
    quotient = x / y
    return trunc(quotient)

60 

61 

@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func_tensor_scalar(x, y):
    # Elementwise division rounded toward zero, tensor ``x`` / scalar ``y``.
    # div_rz presumably divides with round-toward-zero, so truncating its
    # result cannot overshoot the exact truncated quotient.
    rz_quotient = div_rz(x, y)
    return trunc(rz_quotient)

66 

67 

@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func_scalar_tensor(x, y):
    # Elementwise division rounded toward zero, scalar ``x`` / tensor ``y``.
    # div_rz presumably divides with round-toward-zero, so truncating its
    # result cannot overshoot the exact truncated quotient.
    rz_quotient = div_rz(x, y)
    return trunc(rz_quotient)

72 

73 

def trunc_divide(A, B):
    """Divide ``A`` by ``B`` and round the quotient toward zero.

    Selects the tensor/tensor, tensor/scalar or scalar/tensor kernel based on
    which operands are ``torch.Tensor`` instances.

    Args:
        A: Dividend, a tensor or a scalar.
        B: Divisor, a tensor or a scalar.

    Returns:
        The kernel's output tensor. When both operands are scalars, a 0-dim
        tensor holding the truncated quotient: an integer tensor for two
        ints, a floating-point tensor otherwise.

    Raises:
        ZeroDivisionError: If both operands are scalars and ``B`` is zero
            (raised by Python's own division).
    """
    logger.debug("GEMS TRUNC_DIVIDE")
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        return trunc_div_func(A, B)
    elif isinstance(A, torch.Tensor):
        return trunc_div_func_tensor_scalar(A, B)
    elif isinstance(B, torch.Tensor):
        return trunc_div_func_scalar_tensor(A, B)
    else:
        # Both scalar. The quotient must be truncated explicitly: plain
        # ``A / B`` is the unrounded true quotient (7 / 2 -> 3.5, not 3).
        if isinstance(A, int) and isinstance(B, int):
            # Exact integer arithmetic. Going through the float ``A / B``
            # would lose precision for integers beyond 2**53.
            magnitude = abs(A) // abs(B)
            return torch.tensor(magnitude if (A < 0) == (B < 0) else -magnitude)
        return torch.trunc(torch.tensor(A / B))

85 

86 

def trunc_divide_(A, B):
    """In-place division rounded toward zero: write the result into ``A``.

    Args:
        A: Tensor that is divided and receives the result (passed as
            ``out0``).
        B: Divisor, a tensor or a scalar.

    Returns:
        Whatever the selected kernel returns; presumably ``A`` itself, since
        it is passed as the output buffer.
    """
    logger.debug("GEMS TRUNC_DIVIDE_")
    kernel = (
        trunc_div_func
        if isinstance(B, torch.Tensor)
        else trunc_div_func_tensor_scalar
    )
    return kernel(A, B, out0=A)

93 

94 

@triton.jit
def _int_floordiv(x, y):
    # Integer floor division with PyTorch/NumPy semantics.
    #
    # TODO: request Triton to add an integer remainder builtin
    # Triton's integer floordiv differs from PyTorch/NumPy:
    #   Triton:  (x - np.fmod(x, y)) / y       -> truncates toward zero
    #   PyTorch: (x - np.remainder(x, y)) / y  -> rounds toward -inf
    # The results differ by exactly one when
    #   C1) x is not a multiple of y, and
    #   C2) x and y have opposite signs,
    # so the truncated quotient is decremented in that case.
    #
    # NOTE: x // 0 is reported to yield -1 in Triton, whereas PyTorch gives
    # -1 for x >= 0 and -2 for x < 0. Assuming x % 0 is nonzero for nonzero
    # x, the C1/C2 correction already produces PyTorch's values, so no extra
    # handling is needed. This is hardware-dependent; confirm on this backend.
    truncated = x // y
    not_multiple = (x % y) != 0
    opposite_signs = (x < 0) ^ (y < 0)
    return tl.where(not_multiple & opposite_signs, truncated - 1, truncated)

114 

115 

# To be consistent with Python, NumPy and PyTorch, float floor division is
# implemented following these reference implementations:
# CPython
# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
# numpy
# https://github.com/numpy/numpy/blob/a4ad142aa1282a77bbb05acd706cb57c9cc29846/numpy/_core/src/npymath/npy_math_internal.h.src#L532
# torch
# https://github.com/pytorch/pytorch/blob/d6d9183456cd07ca0b361a194b98c2fb196e7c36/c10/util/generic_math.h#L23
@triton.jit
def _float_floordiv(x, y):
    # Floating-point floor division with Python/NumPy/PyTorch semantics.
    #
    # NOTE: fmod's result carries the sign of the dividend ``x``.
    rem = fmod(x, y)
    signs_differ = (x < 0) ^ (y < 0)
    inexact = rem != 0.0

    # ``x - rem`` is a multiple of ``y``, so dividing it gives the truncated
    # quotient. div_rn (round-to-nearest) is used explicitly instead of `/`,
    # presumably because `/` may not lower to a correctly rounded division.
    # Confirm for this backend.
    quot = div_rn(x - rem, y)
    # Truncation rounds toward zero. When the quotient is negative (signs
    # differ) and inexact, step down by one to reach the floor.
    quot = tl.where(inexact & signs_differ, quot - 1, quot)

    # Snap to the nearest integer to absorb rounding error from the
    # division: take the floor, then bump it when the fraction exceeds 0.5.
    snapped = tl.math.floor(quot)
    snapped = tl.where(quot - snapped > 0.5, snapped + 1.0, snapped)

    # A zero quotient gets the sign implied by the operands (-0.0 if they differ).
    snapped = tl.where(quot == 0.0, tl.where(signs_differ, -0.0, 0.0), snapped)

    # Division by zero defers to plain true division (±inf, or nan for 0/0).
    return tl.where(y == 0.0, x / y, snapped)

146 

147 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func(x, y):
    # Elementwise floor division, tensor / tensor, with PyTorch semantics.
    # The branch is chosen from the operands' element types, not from their
    # values. Both operands must be integers to take the integer path: a
    # mixed int/float pair has to go through the float path, because
    # Triton's `//` used by _int_floordiv only accepts integer operands.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)

155 

156 

@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func_tensor_scalar(x, y):
    # Elementwise floor division, tensor ``x`` / scalar ``y``, PyTorch semantics.
    # Both operand types are checked: e.g. an int tensor divided by a float
    # scalar (2.5) must take the float path, because Triton's `//` used by
    # _int_floordiv only accepts integer operands.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)

164 

165 

@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func_scalar_tensor(x, y):
    # Elementwise floor division, scalar ``x`` / tensor ``y``, PyTorch semantics.
    # Both operand types are checked: e.g. an int scalar divided by a float
    # tensor must take the float path, because Triton's `//` used by
    # _int_floordiv only accepts integer operands.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)

173 

174 

def floor_divide(A, B):
    """Divide ``A`` by ``B`` and round the quotient toward negative infinity.

    Selects the tensor/tensor, tensor/scalar or scalar/tensor kernel based on
    which operands are ``torch.Tensor`` instances.

    Args:
        A: Dividend, a tensor or a scalar.
        B: Divisor, a tensor or a scalar.

    Returns:
        The kernel's output tensor. When neither operand is a tensor, Python's
        ``A // B`` wrapped with ``torch.tensor``.
    """
    logger.debug("GEMS FLOOR_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return floor_div_func(A, B)
    if a_is_tensor:
        return floor_div_func_tensor_scalar(A, B)
    if b_is_tensor:
        return floor_div_func_scalar_tensor(A, B)
    # Neither operand is a tensor: evaluate in Python and wrap the result.
    return torch.tensor(A // B)

186 

187 

def floor_divide_(A, B):
    """In-place floor division: write ``A // B`` into ``A``.

    Args:
        A: Tensor that is divided and receives the result (passed as
            ``out0``).
        B: Divisor, a tensor or a scalar.

    Returns:
        Whatever the selected kernel returns; presumably ``A`` itself, since
        it is passed as the output buffer.
    """
    logger.debug("GEMS FLOOR_DIVIDE_")
    kernel = (
        floor_div_func
        if isinstance(B, torch.Tensor)
        else floor_div_func_tensor_scalar
    )
    return kernel(A, B, out0=A)

194 

195 

def div_mode(A, B, rounding_mode=None):
    """Divide ``A`` by ``B`` using the requested rounding mode.

    Args:
        A: Dividend, a tensor or a scalar.
        B: Divisor, a tensor or a scalar.
        rounding_mode: ``None`` for true division, ``"trunc"`` to round the
            quotient toward zero, or ``"floor"`` to round it toward negative
            infinity.

    Returns:
        The result of ``true_divide``, ``trunc_divide`` or ``floor_divide``.

    Raises:
        ValueError: If ``rounding_mode`` is none of the supported values.
    """
    if rounding_mode is None:
        return true_divide(A, B)
    if rounding_mode == "trunc":
        return trunc_divide(A, B)
    if rounding_mode == "floor":
        return floor_divide(A, B)
    raise ValueError(
        f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    )

206 

207 

def div_mode_(A, B, rounding_mode=None):
    """In-place division of ``A`` by ``B`` using the requested rounding mode.

    Args:
        A: Tensor that is divided and receives the result.
        B: Divisor, a tensor or a scalar.
        rounding_mode: ``None`` for true division, ``"trunc"`` to round the
            quotient toward zero, or ``"floor"`` to round it toward negative
            infinity.

    Returns:
        The result of ``true_divide_``, ``trunc_divide_`` or
        ``floor_divide_``.

    Raises:
        ValueError: If ``rounding_mode`` is none of the supported values.
    """
    if rounding_mode is None:
        return true_divide_(A, B)
    if rounding_mode == "trunc":
        return trunc_divide_(A, B)
    if rounding_mode == "floor":
        return floor_divide_(A, B)
    raise ValueError(
        f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    )

218 

219 

@triton.jit
def _remainder(x, y):
    # Python/PyTorch-style modulo: a nonzero result takes the sign of ``y``.
    # Triton's `%` is presumably C-style (the result takes the sign of ``x``),
    # as the correction below assumes. So a nonzero remainder whose operands
    # have opposite signs is shifted by one divisor.
    rem = x % y
    needs_shift = (rem != 0) & ((x < 0) ^ (y < 0))
    return tl.where(needs_shift, rem + y, rem)

226 

227 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_tt(x, y):
    # Elementwise Python-style remainder, tensor / tensor.
    result = _remainder(x, y)
    return result

232 

233 

@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_ts(x, y):
    # Elementwise Python-style remainder, tensor ``x`` / scalar ``y``.
    result = _remainder(x, y)
    return result

238 

239 

@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_st(x, y):
    # Elementwise Python-style remainder, scalar ``x`` / tensor ``y``.
    result = _remainder(x, y)
    return result

244 

245 

def remainder(A, B):
    """Compute the Python-style remainder of ``A`` divided by ``B``.

    Selects the tensor/tensor, tensor/scalar or scalar/tensor kernel based on
    which operands are ``torch.Tensor`` instances.

    Args:
        A: Dividend, a tensor or a scalar.
        B: Divisor, a tensor or a scalar.

    Returns:
        The kernel's output tensor. When neither operand is a tensor, Python's
        ``A % B`` wrapped with ``torch.tensor``.
    """
    # Log this op's own name; it previously logged "GEMS FLOOR_DIVIDE",
    # which misattributed remainder calls in debug traces.
    logger.debug("GEMS REMAINDER")
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        return rem_tt(A, B)
    elif isinstance(A, torch.Tensor):
        return rem_ts(A, B)
    elif isinstance(B, torch.Tensor):
        return rem_st(A, B)
    else:
        # Both scalar
        return torch.tensor(A % B)

257 

258 

def remainder_(A, B):
    """In-place Python-style remainder: write ``A % B`` into ``A``.

    Args:
        A: Tensor that is divided and receives the result (passed as
            ``out0``).
        B: Divisor, a tensor or a scalar.

    Returns:
        Whatever the selected kernel returns; presumably ``A`` itself, since
        it is passed as the output buffer.
    """
    logger.debug("GEMS REMAINDER_")
    kernel = rem_tt if isinstance(B, torch.Tensor) else rem_ts
    return kernel(A, B, out0=A)