Coverage for src/flag_gems/ops/div.py: 62%

217 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import pointwise_dynamic 

8from flag_gems.utils.pointwise_dynamic import ComplexMode 

9from flag_gems.utils.triton_lang_extension import div_rn, div_rz, fmod, trunc 

10 

# Module-level logger; each public entry point below emits a "GEMS ..." debug
# record when it is called.
logger = logging.getLogger(name=__name__)

12 

13 

@pointwise_dynamic(
    is_tensor=[True, True, True, True],
    num_outputs=2,
    promotion_methods=[
        (0, 1, 2, 3, "INT_TO_FLOAT"),
        (0, 1, 2, 3, "INT_TO_FLOAT"),
    ],
)
@triton.jit
def div_complex_kernel(ar, ai, br, bi):
    # Complex quotient (ar + i*ai) / (br + i*bi), returned as separate
    # (real, imag) outputs.
    # Smith's method: divide through by the larger-magnitude component of the
    # denominator to avoid overflow in the intermediates.
    # Both candidate branches are evaluated element-wise and tl.where picks one;
    # each ratio is guarded so a zero component does not produce inf/nan in the
    # branch that ends up discarded.
    real_dominant = tl.abs(br) >= tl.abs(bi)

    # Branch A, taken when |br| >= |bi|: t = bi / br, denominator br + bi*t.
    t_a = tl.where(br == 0, 0.0, bi / br)
    den_a = br + bi * t_a
    re_a = (ar + ai * t_a) / den_a
    im_a = (ai - ar * t_a) / den_a

    # Branch B, taken when |bi| > |br|: t = br / bi, denominator bi + br*t.
    t_b = tl.where(bi == 0, 0.0, br / bi)
    den_b = bi + br * t_b
    re_b = (ar * t_b + ai) / den_b
    im_b = (ai * t_b - ar) / den_b

    out_re = tl.where(real_dominant, re_a, re_b)
    out_im = tl.where(real_dominant, im_a, im_b)
    return out_re, out_im

44 

45 

@pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func(x, y):
    # Element-wise true division of two tensors. The output dtype follows the
    # INT_TO_FLOAT promotion method (presumably integer inputs yield a floating
    # result, matching torch.true_divide).
    quotient = x / y
    return quotient

50 

51 

@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func_tensor_scalar(x, y):
    # True division of a tensor `x` by a scalar `y` (declared non-tensor above).
    quotient = x / y
    return quotient

56 

57 

@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func_scalar_tensor(x, y):
    # True division of a scalar `x` (declared non-tensor above) by a tensor `y`.
    quotient = x / y
    return quotient

62 

63 

# Complex-number support for true division.
# The tensor/tensor kernel uses div_complex_kernel as its CROSS-mode kernel.
# The two scalar variants are registered with tensorize_scalars=True and
# fallback_target=true_div_func. Judging by those names, complex scalar operands
# are presumably wrapped into tensors and handled by true_div_func; confirm
# against pointwise_dynamic's register_complex.
true_div_func.register_complex(
    cross_kernel=div_complex_kernel,
    mode=ComplexMode.CROSS,
)
true_div_func_tensor_scalar.register_complex(
    fallback_target=true_div_func,
    tensorize_scalars=True,
    mode=ComplexMode.CROSS,
)
true_div_func_scalar_tensor.register_complex(
    fallback_target=true_div_func,
    tensorize_scalars=True,
    mode=ComplexMode.CROSS,
)

72 

73 

def true_divide(A, B):
    """Compute ``A / B`` with true (non-rounding) division.

    Either operand may be a tensor or a Python scalar; the matching Triton
    kernel is selected from which of them are tensors.

    Args:
        A: Dividend (tensor or Python number).
        B: Divisor (tensor or Python number).

    Returns:
        A tensor with the quotient. When both operands are Python scalars the
        quotient is computed in Python and wrapped with ``torch.tensor``.
    """
    logger.debug("GEMS TRUE_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if not (a_is_tensor or b_is_tensor):
        # Both scalar
        return torch.tensor(A / B)
    if not b_is_tensor:
        return true_div_func_tensor_scalar(A, B)
    if not a_is_tensor:
        return true_div_func_scalar_tensor(A, B)
    return true_div_func(A, B)

85 

86 

def true_divide_out(A, B, out):
    """Out-variant of true division: compute ``A / B`` into ``out``.

    Tensor operands are dispatched to the matching Triton kernel with
    ``out0=out``. When both operands are Python scalars, the quotient is wrapped
    in a new tensor if ``out`` is ``None``; otherwise it is written into ``out``
    with ``fill_``.

    Args:
        A: Dividend (tensor or Python number).
        B: Divisor (tensor or Python number).
        out: Destination tensor, or ``None``.

    Returns:
        The kernel result, the filled ``out``, or a new scalar tensor.
    """
    logger.debug("GEMS TRUE_DIVIDE OUT")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        kernel = true_div_func
    elif a_is_tensor:
        kernel = true_div_func_tensor_scalar
    elif b_is_tensor:
        kernel = true_div_func_scalar_tensor
    else:
        # Both scalar
        if out is None:
            return torch.tensor(A / B)
        return out.fill_(A / B)
    return kernel(A, B, out0=out)

98 

99 

def true_divide_(A, B):
    """In-place true division: write ``A / B`` back into tensor ``A``.

    ``B`` may be a tensor or a Python scalar. The matching kernel is chosen and
    ``A`` itself is passed as the output buffer (``out0=A``).

    Returns:
        Whatever the selected kernel returns for ``out0=A``.
    """
    logger.debug("GEMS TRUE_DIVIDE_")
    kernel = true_div_func if isinstance(B, torch.Tensor) else true_div_func_tensor_scalar
    return kernel(A, B, out0=A)

106 

107 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func(x, y):
    # Truncating division of two tensors on the floating path: divide with
    # round-toward-zero (div_rz), then drop any fractional part with trunc.
    quotient_rz = div_rz(x, y)
    return trunc(quotient_rz)

112 

113 

@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func_tensor_scalar(x, y):
    # Tensor / scalar variant of trunc_div_func. The scalar divisor is cast to
    # the tensor's dtype before dividing.
    divisor = tl.cast(y, x.dtype)
    quotient_rz = div_rz(x, divisor)
    return trunc(quotient_rz)

118 

119 

@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func_scalar_tensor(x, y):
    # Scalar / tensor variant of trunc_div_func. The scalar dividend is cast to
    # the tensor's dtype before dividing.
    dividend = tl.cast(x, y.dtype)
    quotient_rz = div_rz(dividend, y)
    return trunc(quotient_rz)

124 

125 

# Integer truncation division: Triton's // on integers is C-style (truncates
# toward zero), which is exactly the rounding_mode="trunc" result.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_int_func(x, y):
    # Tensor // tensor for integral operands.
    quotient = x // y
    return quotient

131 

132 

@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_int_func_tensor_scalar(x, y):
    # Integer tensor // integer scalar, truncating toward zero (C-style //).
    quotient = x // y
    return quotient

137 

138 

@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_int_func_scalar_tensor(x, y):
    # Integer scalar // integer tensor, truncating toward zero (C-style //).
    quotient = x // y
    return quotient

143 

144 

def trunc_divide(A, B):
    """Divide ``A`` by ``B``, rounding the quotient toward zero.

    This implements ``torch.div(A, B, rounding_mode="trunc")``. The kernel is
    chosen from the dtype PyTorch promotes the operands to
    (``torch.result_type``), not from ``A``'s dtype alone:

    * non-floating result: integer kernels, whose Triton ``//`` truncates
      toward zero;
    * floating result: ``trunc(div_rz(...))`` kernels.

    Args:
        A: Dividend (tensor or Python number).
        B: Divisor (tensor or Python number).

    Returns:
        A tensor holding the truncated quotient.

    Raises:
        ZeroDivisionError: If both operands are Python numbers and ``B == 0``.
    """
    logger.debug("GEMS TRUNC_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)

    if not (a_is_tensor or b_is_tensor):
        return _trunc_divide_scalars(A, B)

    # Dispatch on the promoted dtype. An integer operand paired with a floating
    # one must take the floating path (Triton's // is integer-only here). A
    # tensor must also never reach a kernel parameter declared as a scalar
    # (e.g. float tensor A with integer tensor B).
    result_dtype = torch.result_type(A, B)

    if not result_dtype.is_floating_point:
        # Integer types: Triton's // on integers is C-style truncation.
        if a_is_tensor and b_is_tensor:
            return trunc_div_int_func(A, B)
        if a_is_tensor:
            return trunc_div_int_func_tensor_scalar(A, B)
        return trunc_div_int_func_scalar_tensor(A, B)

    if a_is_tensor and b_is_tensor:
        return trunc_div_func(A, B)
    if a_is_tensor:
        # trunc_div_func_tensor_scalar casts the scalar to the tensor's dtype.
        # An integral tensor is therefore promoted first; otherwise a float
        # divisor such as 2.5 would be truncated to 2. `.to` is a no-op when
        # the dtype already matches.
        return trunc_div_func_tensor_scalar(A.to(result_dtype), B)
    # Mirror image: trunc_div_func_scalar_tensor casts the scalar to B's dtype.
    return trunc_div_func_scalar_tensor(A, B.to(result_dtype))


def _trunc_divide_scalars(A, B):
    """Truncating division of two Python numbers, returned as a 0-dim tensor."""
    if isinstance(A, int) and isinstance(B, int):
        # Exact integer arithmetic, with no float round-trip (bool is an int
        # subclass). abs(B) == 0 raises ZeroDivisionError as Python does.
        quotient = abs(A) // abs(B)
        return torch.tensor(quotient if (A < 0) == (B < 0) else -quotient)
    return torch.trunc(torch.tensor(A / B))

164 

165 

def trunc_divide_(A, B):
    """In-place truncating division: write ``trunc(A / B)`` back into tensor ``A``.

    Integral ``A`` uses the integer kernels (Triton's integer ``//`` truncates
    toward zero). Floating ``A`` uses the ``trunc(div_rz(...))`` kernels. In
    both cases ``A`` is the output buffer (``out0=A``).
    """
    logger.debug("GEMS TRUNC_DIVIDE_")
    b_is_tensor = isinstance(B, torch.Tensor)
    if A.is_floating_point():
        kernel = trunc_div_func if b_is_tensor else trunc_div_func_tensor_scalar
    else:
        # Integer types: use dedicated int kernels (Triton // is C-style truncation)
        kernel = trunc_div_int_func if b_is_tensor else trunc_div_int_func_tensor_scalar
    return kernel(A, B, out0=A)

178 

179 

@triton.jit
def _int_floordiv(x, y):
    # Integer floor division with PyTorch/NumPy semantics.
    #
    # TODO: request Triton to add an integer remainder builtin.
    #
    # Triton's integer `//` truncates toward zero, i.e. it computes
    #     (x - np.fmod(x, y)) / y
    # whereas PyTorch's floor division computes
    #     (x - np.remainder(x, y)) / y
    # The two results differ by exactly one when
    #   C1) x is not a multiple of y (the C-style remainder is non-zero), and
    #   C2) x and y have opposite signs,
    # in which case the truncated quotient is decremented.
    #
    # Division by zero: per the original analysis, Triton yields x // 0 == -1,
    # while PyTorch returns -1 for x >= 0 and -2 for x < 0. That case falls out
    # of the C1/C2 correction, so no extra handling is needed.
    # NOTE(review): integer division by zero is hardware-dependent; confirm on
    # each target backend.
    truncated = x // y
    inexact = (x % y) != 0
    opposite_signs = (x < 0) ^ (y < 0)
    return tl.where(inexact & opposite_signs, truncated - 1, truncated)

199 

200 

# To be consistent with Python, NumPy and PyTorch, floating-point floor division
# is implemented following their reference implementations:
# CPython
# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
# numpy
# https://github.com/numpy/numpy/blob/a4ad142aa1282a77bbb05acd706cb57c9cc29846/numpy/_core/src/npymath/npy_math_internal.h.src#L532
# torch
# https://github.com/pytorch/pytorch/blob/d6d9183456cd07ca0b361a194b98c2fb196e7c36/c10/util/generic_math.h#L23
@triton.jit
def _float_floordiv(x: "tl.tensor", y: "tl.tensor"):
    # Floor division x // y for floating-point operands, element-wise.
    # Returns the floored quotient, or the plain IEEE quotient x / y wherever
    # y == 0.

    # NOTE: fmod's sign is the same as the dividend
    remainder = fmod(x, y)
    imperfect = remainder != 0.0
    different_sign = (x < 0) ^ (y < 0)

    # (x - remainder) is mathematically a multiple of y, so q is the quotient
    # truncated toward zero. If the division was inexact and the signs differ,
    # step down by one so truncation becomes flooring.
    # NOTE: we have to use div_rn explicitly here (presumably to force a
    # correctly rounded division rather than an approximate one).
    q = div_rn(x - remainder, y)
    q = tl.where(imperfect & different_sign, q - 1, q)

    # Snap q to the nearest integer: floor it, then round up if the dropped
    # fraction exceeds 0.5. Presumably this guards against q landing just below
    # an integer because of rounding in the division.
    floor_q = tl.math.floor(q)
    c = q - floor_q > 0.5
    floor_q = tl.where(c, floor_q + 1.0, floor_q)

    # A zero quotient gets an explicit sign: -0.0 when x and y have opposite
    # signs, +0.0 otherwise.
    # NOTE(review): the linked PyTorch code appears to use copysign(0, x / y).
    # Deriving the sign from (x < 0) ^ (y < 0) differs when x is -0.0; e.g.
    # -0.0 // 1.0 yields +0.0 here. Confirm whether matching PyTorch's
    # signed zero matters.
    q_is_zeros = q == 0.0
    floor_q = tl.where(q_is_zeros, tl.where(different_sign, -0.0, 0.0), floor_q)

    # Division by zero: fall back to the IEEE result of x / y.
    is_div_by_zero = y == 0.0
    float_division = x / y
    out = tl.where(is_div_by_zero, float_division, floor_q)
    return out

231 

232 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func(x, y):
    # Element-wise floor division (PyTorch semantics) of two tensors.
    # The condition is built from Python bools derived from the operand dtypes,
    # so the branch is presumably resolved when the kernel is specialized, and
    # only one helper is instantiated per dtype pair. Keep the expression
    # inline in the `if`: binding it to a local first may turn it into a
    # runtime value and force both branches to compile.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        # Both operands integral: sign-corrected integer path.
        return _int_floordiv(x, y)
    else:
        # At least one floating operand: float path.
        return _float_floordiv(x, y)

240 

241 

@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func_tensor_scalar(x, y):
    # Tensor // scalar with PyTorch floor-division semantics. Same dtype-driven
    # dispatch as the tensor/tensor variant; keep the condition inline in the
    # `if` so it stays a compile-time value.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        # Both operands integral: sign-corrected integer path.
        return _int_floordiv(x, y)
    else:
        # At least one floating operand: float path.
        return _float_floordiv(x, y)

249 

250 

@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func_scalar_tensor(x, y):
    # Scalar // tensor with PyTorch floor-division semantics. Same dtype-driven
    # dispatch as the tensor/tensor variant; keep the condition inline in the
    # `if` so it stays a compile-time value.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        # Both operands integral: sign-corrected integer path.
        return _int_floordiv(x, y)
    else:
        # At least one floating operand: float path.
        return _float_floordiv(x, y)

258 

259 

def floor_divide(A, B):
    """Compute ``A // B`` with floor-division semantics.

    Either operand may be a tensor or a Python scalar; the matching Triton
    kernel is selected from which of them are tensors.

    Args:
        A: Dividend (tensor or Python number).
        B: Divisor (tensor or Python number).

    Returns:
        A tensor with the floored quotient. When both operands are Python
        scalars, Python's ``//`` is used and the result is wrapped with
        ``torch.tensor``.
    """
    logger.debug("GEMS FLOOR_DIVIDE")
    if not isinstance(A, torch.Tensor):
        if not isinstance(B, torch.Tensor):
            # Both scalar
            return torch.tensor(A // B)
        return floor_div_func_scalar_tensor(A, B)
    if isinstance(B, torch.Tensor):
        return floor_div_func(A, B)
    return floor_div_func_tensor_scalar(A, B)

271 

272 

def floor_divide_(A, B):
    """In-place floor division: write ``A // B`` back into tensor ``A``.

    ``B`` may be a tensor or a Python scalar; ``A`` is passed as the output
    buffer (``out0=A``) of the selected kernel.
    """
    logger.debug("GEMS FLOOR_DIVIDE_")
    kernel = floor_div_func if isinstance(B, torch.Tensor) else floor_div_func_tensor_scalar
    return kernel(A, B, out0=A)

279 

280 

def div_mode(A, B, rounding_mode=None):
    """Out-of-place ``torch.div`` with an optional ``rounding_mode``.

    Args:
        A: Dividend (tensor or Python number).
        B: Divisor (tensor or Python number).
        rounding_mode: ``None`` for true division, ``"trunc"`` to dispatch to
            trunc_divide, or ``"floor"`` to dispatch to floor_divide.

    Returns:
        The result of the selected division routine.

    Raises:
        ValueError: If ``rounding_mode`` is not one of the supported values.
    """
    logger.debug("GEMS DIV_MODE")
    if rounding_mode is None:
        return true_divide(A, B)
    if rounding_mode == "trunc":
        return trunc_divide(A, B)
    if rounding_mode == "floor":
        return floor_divide(A, B)
    raise ValueError(
        f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    )

292 

293 

def div_mode_(A, B, rounding_mode=None):
    """In-place ``torch.div_`` with an optional ``rounding_mode``.

    Args:
        A: Tensor to divide in place.
        B: Divisor (tensor or Python number).
        rounding_mode: ``None`` for true division, ``"trunc"`` to dispatch to
            trunc_divide_, or ``"floor"`` to dispatch to floor_divide_.

    Returns:
        The result of the selected in-place division routine.

    Raises:
        ValueError: If ``rounding_mode`` is not one of the supported values.
    """
    logger.debug("GEMS DIV_MODE_")
    if rounding_mode is None:
        return true_divide_(A, B)
    if rounding_mode == "trunc":
        return trunc_divide_(A, B)
    if rounding_mode == "floor":
        return floor_divide_(A, B)
    raise ValueError(
        f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    )

305 

306 

@triton.jit
def _remainder(x, y):
    # Python/PyTorch-style remainder, whose result takes the sign of the divisor.
    # The correction assumes Triton's % returns a dividend-signed (C-style)
    # result. A non-zero remainder is therefore shifted by one divisor whenever
    # x and y have opposite signs.
    rem = x % y
    needs_shift = (rem != 0) & ((x < 0) ^ (y < 0))
    return tl.where(needs_shift, rem + y, rem)

313 

314 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_tt(x, y):
    # Tensor % tensor with divisor-signed (Python-style) results.
    result = _remainder(x, y)
    return result

319 

320 

@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_ts(x, y):
    # Tensor % scalar with divisor-signed (Python-style) results.
    result = _remainder(x, y)
    return result

325 

326 

@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_st(x, y):
    # Scalar % tensor with divisor-signed (Python-style) results.
    result = _remainder(x, y)
    return result

331 

332 

def remainder(A, B):
    """Compute ``A % B`` with the sign of the divisor (Python semantics).

    Either operand may be a tensor or a Python scalar; the matching Triton
    kernel is selected from which of them are tensors.

    Args:
        A: Dividend (tensor or Python number).
        B: Divisor (tensor or Python number).

    Returns:
        A tensor with the remainder. When both operands are Python scalars,
        Python's ``%`` is used and the result is wrapped with ``torch.tensor``.
    """
    logger.debug("GEMS REMAINDER")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor:
        return rem_tt(A, B) if b_is_tensor else rem_ts(A, B)
    if b_is_tensor:
        return rem_st(A, B)
    # Both scalar
    return torch.tensor(A % B)

344 

345 

def remainder_(A, B):
    """In-place remainder: write ``A % B`` (divisor-signed) back into tensor ``A``.

    ``B`` may be a tensor or a Python scalar; ``A`` is passed as the output
    buffer (``out0=A``) of the selected kernel.
    """
    logger.debug("GEMS REMAINDER_")
    kernel = rem_tt if isinstance(B, torch.Tensor) else rem_ts
    return kernel(A, B, out0=A)