Coverage for src/flag_gems/runtime/backend/_iluvatar/ops/div.py: 0%

188 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import pointwise_dynamic, tl_extra_shim 

8 

# TODO: Check if this logger instantiation is good
# Module-level logger named after this module; each public wrapper in this
# file emits one debug record through it on entry.
logger = logging.getLogger(__name__)
# Short aliases for helpers provided by tl_extra_shim so the Triton kernels in
# this file can call them by bare name.
# NOTE(review): the _rn / _rz suffixes presumably mean round-to-nearest and
# round-toward-zero division -- confirm against tl_extra_shim.
div_rn = tl_extra_shim.div_rn
div_rz = tl_extra_shim.div_rz
fmod = tl_extra_shim.fmod
trunc = tl_extra_shim.trunc

15 

16 

# Element-wise true division for two tensor operands. The result dtype is
# governed by the "INT_TO_FLOAT" promotion rule named in the decorator.
@pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func(x, y):
    quotient = x / y
    return quotient

21 

22 

# Element-wise true division where the dividend is a tensor and the divisor
# is a scalar (is_tensor=[True, False]).
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func_tensor_scalar(x, y):
    quotient = x / y
    return quotient

27 

28 

# Element-wise true division where the dividend is a scalar and the divisor
# is a tensor (is_tensor=[False, True]).
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func_scalar_tensor(x, y):
    quotient = x / y
    return quotient

33 

34 

def true_divide(A, B):
    """Return the true quotient ``A / B``.

    The kernel is chosen by which operands are ``torch.Tensor`` instances:
    tensor/tensor, tensor/scalar or scalar/tensor. When neither operand is a
    tensor the quotient is computed in Python and wrapped with
    ``torch.tensor``.
    """
    logger.debug("GEMS_ILUVATAR TRUE_DIVIDE")
    lhs_is_tensor = isinstance(A, torch.Tensor)
    rhs_is_tensor = isinstance(B, torch.Tensor)

    if lhs_is_tensor:
        kernel = true_div_func if rhs_is_tensor else true_div_func_tensor_scalar
        return kernel(A, B)
    if rhs_is_tensor:
        return true_div_func_scalar_tensor(A, B)
    # Neither operand is a tensor: plain Python division.
    return torch.tensor(A / B)

46 

47 

def true_divide_(A, B):
    """In-place true division: write ``A / B`` into ``A`` via ``out0=A``.

    ``B`` may be a tensor or a scalar; the matching kernel is selected
    accordingly and its return value is passed back to the caller.
    """
    logger.debug("GEMS_ILUVATAR TRUE_DIVIDE_")
    kernel = (
        true_div_func if isinstance(B, torch.Tensor) else true_div_func_tensor_scalar
    )
    return kernel(A, B, out0=A)

54 

55 

# Tensor / tensor division with the quotient passed through trunc().
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func(x, y):
    quotient = x / y
    return trunc(quotient)

60 

61 

# Tensor / scalar truncating division. The scalar divisor is first cast to the
# tensor's dtype, then divided with div_rz and passed through trunc().
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func_tensor_scalar(x, y):
    divisor = tl.cast(y, x.dtype)
    quotient = div_rz(x, divisor)
    return trunc(quotient)

66 

67 

# Scalar / tensor truncating division. The scalar dividend is first cast to
# the tensor's dtype, then divided with div_rz and passed through trunc().
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func_scalar_tensor(x, y):
    dividend = tl.cast(x, y.dtype)
    quotient = div_rz(dividend, y)
    return trunc(quotient)

72 

73 

# Integer truncating division, tensor / tensor. Triton's `//` on integers is
# C-style (it truncates toward zero), so it can be used directly.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_int_func(x, y):
    quotient = x // y
    return quotient

79 

80 

# Integer truncating division, tensor / scalar, using Triton's `//` directly.
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_int_func_tensor_scalar(x, y):
    quotient = x // y
    return quotient

85 

86 

# Integer truncating division, scalar / tensor, using Triton's `//` directly.
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_int_func_scalar_tensor(x, y):
    quotient = x // y
    return quotient

91 

92 

def trunc_divide(A, B):
    """Divide ``A`` by ``B`` and round the quotient toward zero.

    Dispatch depends on which operands are tensors:

    * tensor / tensor: the integer ``//`` kernel is used only when neither
      tensor is floating point; any floating-point operand selects
      ``trunc_div_func``.
    * tensor / scalar and scalar / tensor: the tensor operand's
      floating-point-ness selects between the integer and the float kernel.
    * scalar / scalar: computed on the host and wrapped with ``torch.tensor``.

    Args:
        A: Dividend, a ``torch.Tensor`` or a Python number.
        B: Divisor, a ``torch.Tensor`` or a Python number.

    Returns:
        The selected kernel's result, or for two Python scalars a tensor
        holding the truncated quotient.

    Raises:
        ZeroDivisionError: If both operands are Python scalars and ``B`` is
            zero (raised by Python's own division).
    """
    logger.debug("GEMS_ILUVATAR TRUNC_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)

    if a_is_tensor and b_is_tensor:
        # The integer kernel relies on Triton's C-style integer `//`, so it is
        # only valid when both tensors are non-floating. A float tensor paired
        # with an integer tensor must take the float kernel; previously such a
        # pair could reach an integer kernel, or even the scalar/tensor kernel
        # with a tensor passed as its scalar argument.
        if A.is_floating_point() or B.is_floating_point():
            return trunc_div_func(A, B)
        return trunc_div_int_func(A, B)

    if a_is_tensor:
        if A.is_floating_point():
            return trunc_div_func_tensor_scalar(A, B)
        # NOTE(review): an integer tensor with a Python float divisor still
        # reaches the integer kernel here -- confirm whether that mix needs
        # to be supported.
        return trunc_div_int_func_tensor_scalar(A, B)

    if b_is_tensor:
        if B.is_floating_point():
            return trunc_div_func_scalar_tensor(A, B)
        # NOTE(review): a Python float dividend with an integer tensor still
        # reaches the integer kernel here -- confirm whether that mix needs
        # to be supported.
        return trunc_div_int_func_scalar_tensor(A, B)

    # Both scalar: the quotient has to be truncated, not returned as A / B.
    if isinstance(A, int) and isinstance(B, int):
        # Exact integer arithmetic: floor-divide the magnitudes (which equals
        # truncation for non-negative values) and then restore the sign.
        magnitude = abs(A) // abs(B)
        return torch.tensor(magnitude if (A < 0) == (B < 0) else -magnitude)
    return torch.trunc(torch.tensor(A / B))

112 

113 

def trunc_divide_(A, B):
    """In-place truncating division: write the result into ``A`` via ``out0=A``.

    A non-floating-point ``A`` uses the dedicated integer kernels (Triton
    ``//`` is C-style truncation); a floating-point ``A`` uses the float
    kernels. Within each family, ``B`` being a tensor or a scalar picks the
    tensor/tensor or tensor/scalar variant.
    """
    logger.debug("GEMS_ILUVATAR TRUNC_DIVIDE_")
    if A.is_floating_point():
        tensor_kernel = trunc_div_func
        scalar_kernel = trunc_div_func_tensor_scalar
    else:
        tensor_kernel = trunc_div_int_func
        scalar_kernel = trunc_div_int_func_tensor_scalar

    kernel = tensor_kernel if isinstance(B, torch.Tensor) else scalar_kernel
    return kernel(A, B, out0=A)

126 

127 

@triton.jit
def _int_floordiv(x, y):
    # TODO: request Triton to add an integer remainder builtin
    # Triton's floordiv semantics differ from PyTorch/NumPy.
    # Triton floordiv corresponds to
    #     (x - np.fmod(x, y)) / y
    # whereas PyTorch floordiv is
    #     (x - np.remainder(x, y)) / y
    # The two differ by one exactly when
    #     C1) x and y have opposite signs, and
    #     C2) x is not a multiple of y.
    # Per the original author's note, x // 0 yields -1 in Triton while PyTorch
    # yields -1 for x >= 0 and -2 for x < 0; that case is absorbed by the same
    # C1/C2 test, so it has no separate branch here.
    rem = x % y
    quotient = x // y
    signs_differ = (x < 0) ^ (y < 0)
    needs_fixup = (rem != 0) & signs_differ
    return tl.where(needs_fixup, quotient - 1, quotient)

147 

148 

# To stay consistent with Python, NumPy and PyTorch, float floor division has
# to be implemented the way their reference implementations do it.
# CPython
# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
# numpy
# https://github.com/numpy/numpy/blob/a4ad142aa1282a77bbb05acd706cb57c9cc29846/numpy/_core/src/npymath/npy_math_internal.h.src#L532
# torch
# https://github.com/pytorch/pytorch/blob/d6d9183456cd07ca0b361a194b98c2fb196e7c36/c10/util/generic_math.h#L23
@triton.jit
def _float_floordiv(x, y):
    # NOTE: fmod's result carries the sign of the dividend.
    rem = fmod(x, y)
    inexact = rem != 0.0
    signs_differ = (x < 0) ^ (y < 0)

    # NOTE: div_rn has to be called explicitly for this division.
    quot = div_rn(x - rem, y)
    # Step down by one when the division is inexact and the signs differ.
    quot = tl.where(inexact & signs_differ, quot - 1, quot)

    # Snap the quotient to an integer value: take its floor and bump it by one
    # when the quotient sits more than 0.5 above that floor.
    snapped = tl.math.floor(quot)
    round_up = quot - snapped > 0.5
    snapped = tl.where(round_up, snapped + 1.0, snapped)

    # A zero quotient gets an explicitly signed zero: -0.0 when signs differ.
    zero_quot = quot == 0.0
    snapped = tl.where(zero_quot, tl.where(signs_differ, -0.0, 0.0), snapped)

    # Division by zero bypasses the floor logic and yields plain x / y.
    plain_quotient = x / y
    divisor_is_zero = y == 0.0
    return tl.where(divisor_is_zero, plain_quotient, snapped)

179 

180 

# Tensor / tensor floor division. The branch below tests the operands' scalar
# types: when both are integer types the integer helper is used, otherwise the
# float helper.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func(x, y):
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)

188 

189 

# Tensor / scalar floor division (is_tensor=[True, False]). Uses the integer
# helper only when both operands have integer scalar types, otherwise the
# float helper.
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func_tensor_scalar(x, y):
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)

197 

198 

# Scalar / tensor floor division (is_tensor=[False, True]). Uses the integer
# helper only when both operands have integer scalar types, otherwise the
# float helper.
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func_scalar_tensor(x, y):
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)

206 

207 

def floor_divide(A, B):
    """Return the floor-divided quotient of ``A`` and ``B``.

    The kernel is chosen by which operands are ``torch.Tensor`` instances.
    When neither is a tensor, Python's ``//`` is applied and the result is
    wrapped with ``torch.tensor``.
    """
    logger.debug("GEMS_ILUVATAR FLOOR_DIVIDE")
    lhs_is_tensor = isinstance(A, torch.Tensor)
    rhs_is_tensor = isinstance(B, torch.Tensor)

    if lhs_is_tensor:
        kernel = floor_div_func if rhs_is_tensor else floor_div_func_tensor_scalar
        return kernel(A, B)
    if rhs_is_tensor:
        return floor_div_func_scalar_tensor(A, B)
    # Neither operand is a tensor: plain Python floor division.
    return torch.tensor(A // B)

219 

220 

def floor_divide_(A, B):
    """In-place floor division: write the result into ``A`` via ``out0=A``.

    ``B`` may be a tensor or a scalar; the matching kernel is selected
    accordingly and its return value is passed back to the caller.
    """
    logger.debug("GEMS_ILUVATAR FLOOR_DIVIDE_")
    kernel = (
        floor_div_func if isinstance(B, torch.Tensor) else floor_div_func_tensor_scalar
    )
    return kernel(A, B, out0=A)

227 

228 

def div_mode(A, B, rounding_mode=None):
    """Divide ``A`` by ``B`` using the requested rounding mode.

    Args:
        A: Dividend.
        B: Divisor.
        rounding_mode: ``None`` for true division, ``"trunc"`` for truncating
            division, ``"floor"`` for floor division.

    Raises:
        ValueError: If ``rounding_mode`` is none of the values above.
    """
    if rounding_mode is None:
        return true_divide(A, B)
    if rounding_mode == "trunc":
        return trunc_divide(A, B)
    if rounding_mode == "floor":
        return floor_divide(A, B)
    raise ValueError(
        f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    )

239 

240 

def div_mode_(A, B, rounding_mode=None):
    """In-place counterpart of ``div_mode``: dispatch on ``rounding_mode``.

    Args:
        A: Dividend, handed to the selected in-place wrapper.
        B: Divisor.
        rounding_mode: ``None`` for true division, ``"trunc"`` for truncating
            division, ``"floor"`` for floor division.

    Raises:
        ValueError: If ``rounding_mode`` is none of the values above.
    """
    if rounding_mode is None:
        return true_divide_(A, B)
    if rounding_mode == "trunc":
        return trunc_divide_(A, B)
    if rounding_mode == "floor":
        return floor_divide_(A, B)
    raise ValueError(
        f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    )

251 

252 

@triton.jit
def _remainder(x, y):
    # Start from Triton's `%`, then add the divisor back whenever the raw
    # remainder is non-zero and the operands have opposite signs.
    rem = x % y
    signs_differ = (x < 0) ^ (y < 0)
    needs_shift = (rem != 0) & signs_differ
    return tl.where(needs_shift, rem + y, rem)

259 

260 

# Tensor % tensor remainder kernel; the arithmetic lives in _remainder.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_tt(x, y):
    result = _remainder(x, y)
    return result

265 

266 

# Tensor % scalar remainder kernel; the arithmetic lives in _remainder.
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_ts(x, y):
    result = _remainder(x, y)
    return result

271 

272 

# Scalar % tensor remainder kernel; the arithmetic lives in _remainder.
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_st(x, y):
    result = _remainder(x, y)
    return result

277 

278 

def remainder(A, B):
    """Return the remainder of ``A`` divided by ``B``.

    The kernel is chosen by which operands are ``torch.Tensor`` instances.
    When neither is a tensor, Python's ``%`` is applied and the result is
    wrapped with ``torch.tensor``.
    """
    logger.debug("GEMS_ILUVATAR REMAINDER")
    lhs_is_tensor = isinstance(A, torch.Tensor)
    rhs_is_tensor = isinstance(B, torch.Tensor)

    if lhs_is_tensor:
        kernel = rem_tt if rhs_is_tensor else rem_ts
        return kernel(A, B)
    if rhs_is_tensor:
        return rem_st(A, B)
    # Neither operand is a tensor: plain Python modulo.
    return torch.tensor(A % B)

290 

291 

def remainder_(A, B):
    """In-place remainder: write ``A % B`` into ``A`` via ``out0=A``.

    ``B`` may be a tensor or a scalar; the matching kernel is selected
    accordingly and its return value is passed back to the caller.
    """
    logger.debug("GEMS_ILUVATAR REMAINDER_")
    kernel = rem_tt if isinstance(B, torch.Tensor) else rem_ts
    return kernel(A, B, out0=A)