Coverage for src/flag_gems/runtime/backend/_cambricon/ops/div.py: 0%

206 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import tl_extra_shim 

8 

9from ..utils.pointwise_dynamic import pointwise_dynamic 

10 

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Math helpers re-exported from tl_extra_shim so the Triton kernels below can
# refer to them by short module-level names.
div_rn, div_rz, fmod, trunc = (
    tl_extra_shim.div_rn,
    tl_extra_shim.div_rz,
    tl_extra_shim.fmod,
    tl_extra_shim.trunc,
)

16 

17 

@pointwise_dynamic(
    is_tensor=[True, True, False], promotion_methods=[(0, 1, "INT_TO_FLOAT")]
)
@triton.jit
def true_div_func(x, y, inplace):
    """Elementwise true division of two tensors, ``x / y``.

    Output dtype follows the "INT_TO_FLOAT" promotion method.
    ``inplace`` is not read by the body.
    """
    quotient = x / y
    return quotient

24 

25 

@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, "INT_TO_FLOAT")]
)
@triton.jit
def true_div_func_tensor_scalar(x, y, inplace):
    """True division of a tensor ``x`` by a scalar ``y``.

    For a floating ``x`` the scalar is cast to ``x.dtype`` so the division
    runs in the tensor's precision. For an integer/bool ``x`` the scalar is
    left as-is: casting e.g. ``2.5`` to int32 would silently truncate it to
    ``2`` and produce ``5 / 2 == 2.5`` instead of ``5 / 2.5 == 2.0``.
    ``inplace`` is not read by the body.
    """
    # Compile-time branch on the element type of x.
    if x.type.scalar.is_floating():
        y = y.to(x.dtype)
    return x / y

33 

34 

@pointwise_dynamic(
    is_tensor=[False, True, False], promotion_methods=[(0, 1, "INT_TO_FLOAT")]
)
@triton.jit
def true_div_func_scalar_tensor(x, y, inplace):
    """True division of a scalar ``x`` by a tensor ``y``.

    For a floating ``y`` the scalar is cast to ``y.dtype`` so the division
    runs in the tensor's precision. For an integer/bool ``y`` the scalar is
    left as-is: casting e.g. ``2.5`` to int32 would silently truncate it to
    ``2`` and produce ``2 / 2 == 1.0`` instead of ``2.5 / 2 == 1.25``.
    ``inplace`` is not read by the body.
    """
    # Compile-time branch on the element type of y.
    if y.type.scalar.is_floating():
        x = x.to(y.dtype)
    return x / y

42 

43 

def true_divide(A, B):
    """Return ``A / B`` (true division); each operand may be a tensor or a scalar.

    The kernel variant is picked from which operands are tensors. When both
    are Python scalars the quotient is computed in Python and wrapped with
    ``torch.tensor``.
    """
    logger.debug("GEMS_CAMBRICON TRUE_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return true_div_func(A, B, False)
    if a_is_tensor:
        return true_div_func_tensor_scalar(A, B, False)
    if b_is_tensor:
        return true_div_func_scalar_tensor(A, B, False)
    # Both scalar
    return torch.tensor(A / B)

55 

56 

def true_divide_out(A, B, out):
    """Compute ``A / B`` (true division), writing the result into ``out``.

    Tensor/scalar combinations are forwarded to the matching kernel with
    ``out0=out``. For two Python scalars the quotient is computed in Python:
    a fresh ``torch.tensor`` is returned when ``out`` is None, otherwise
    ``out`` is filled with it.
    """
    logger.debug("GEMS_CAMBRICON TRUE_DIVIDE OUT")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return true_div_func(A, B, False, out0=out)
    if a_is_tensor:
        return true_div_func_tensor_scalar(A, B, False, out0=out)
    if b_is_tensor:
        return true_div_func_scalar_tensor(A, B, False, out0=out)
    # Both scalar
    if out is None:
        return torch.tensor(A / B)
    return out.fill_(A / B)

68 

69 

def true_divide_(A, B):
    """In-place true division ``A /= B``; ``B`` may be a tensor or a scalar."""
    logger.debug("GEMS_CAMBRICON TRUE_DIVIDE_")
    kernel = (
        true_div_func if isinstance(B, torch.Tensor) else true_div_func_tensor_scalar
    )
    return kernel(A, B, True, out0=A)

76 

77 

@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func(x, y, inplace):
    """Truncating division of two tensors, computed as ``trunc(div_rn(x, y))``.

    ``inplace`` is not read by the body.
    """
    quotient = div_rn(x, y)
    return trunc(quotient)

82 

83 

@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def trunc_div_func_tensor_scalar(x, y, inplace):
    """Truncating division of tensor ``x`` by scalar ``y``.

    The scalar is cast to ``x.dtype`` before ``trunc(div_rn(...))``.
    ``inplace`` is not read by the body.
    """
    divisor = tl.cast(y, x.dtype)
    quotient = div_rn(x, divisor)
    return trunc(quotient)

90 

91 

@pointwise_dynamic(
    is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def trunc_div_func_scalar_tensor(x, y, inplace):
    """Truncating division of scalar ``x`` by tensor ``y``.

    The scalar is cast to ``y.dtype`` before ``trunc(div_rn(...))``.
    ``inplace`` is not read by the body.
    """
    dividend = tl.cast(x, y.dtype)
    quotient = div_rn(dividend, y)
    return trunc(quotient)

98 

99 

# Integer truncation division: Triton's // on integers is C-style (truncates toward zero)
@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_int_func(x, y, inplace):
    """Integer truncating division of two tensors via Triton's ``//``.

    ``inplace`` is not read by the body.
    """
    truncated = x // y
    return truncated

105 

106 

@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def trunc_div_int_func_tensor_scalar(x, y, inplace):
    """Integer truncating division of tensor ``x`` by scalar ``y``.

    Relies on Triton's integer ``//`` being C-style (truncating toward zero).
    ``inplace`` is not read by the body.
    """
    truncated = x // y
    return truncated

113 

114 

@pointwise_dynamic(
    is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def trunc_div_int_func_scalar_tensor(x, y, inplace):
    """Integer truncating division of scalar ``x`` by tensor ``y``.

    Relies on Triton's integer ``//`` being C-style (truncating toward zero).
    ``inplace`` is not read by the body.
    """
    truncated = x // y
    return truncated

121 

122 

123def trunc_divide(A, B): 

124 logger.debug("GEMS_CAMBRICON TRUNC_DIVIDE") 

125 # Integer types: use dedicated int kernels (Triton // is C-style truncation) 

126 if isinstance(A, torch.Tensor) and not A.is_floating_point(): 

127 if isinstance(B, torch.Tensor): 

128 return trunc_div_int_func(A, B, False) 

129 else: 

130 return trunc_div_int_func_tensor_scalar(A, B, False) 

131 if isinstance(B, torch.Tensor) and not B.is_floating_point(): 

132 return trunc_div_int_func_scalar_tensor(A, B, False) 

133 if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor): 

134 return trunc_div_func(A, B, False) 

135 elif isinstance(A, torch.Tensor): 

136 return trunc_div_func_tensor_scalar(A, B, False) 

137 elif isinstance(B, torch.Tensor): 

138 return trunc_div_func_scalar_tensor(A, B, False) 

139 else: 

140 # Both scalar 

141 return torch.tensor(A / B) 

142 

143 

def trunc_divide_(A, B):
    """In-place truncating division of tensor ``A`` by ``B`` (tensor or scalar).

    Non-floating ``A`` uses the integer kernels, whose Triton ``//`` truncates
    toward zero; floating ``A`` uses the ``trunc(div_rn(...))`` kernels.
    """
    logger.debug("GEMS_CAMBRICON TRUNC_DIVIDE_")
    b_is_tensor = isinstance(B, torch.Tensor)
    if A.is_floating_point():
        kernel = trunc_div_func if b_is_tensor else trunc_div_func_tensor_scalar
    else:
        kernel = (
            trunc_div_int_func if b_is_tensor else trunc_div_int_func_tensor_scalar
        )
    return kernel(A, B, True, out0=A)

156 

157 

@triton.jit
def _int_floordiv(x, y):
    """Integer floor division intended to match PyTorch/NumPy ``//`` semantics.

    Starts from Triton's ``//`` and subtracts the correction terms ``c`` and
    ``c3`` computed below.
    """
    # TODO: request Triton to add an integer remainder builtin
    # The semantic of Triton floordiv differs from PyTorch/NumPy.
    # Triton floordiv equates to
    #     (x - np.fmod(x, y)) / y
    # whereas PyTorch floordiv is
    #     (x - np.remainder(x, y)) / y
    # The results differ by one when
    #     C1) x and y have opposite signs          (c2 below)
    # and C2) x is not a multiple of y.           (c1 below)
    # Apart from the above, x // 0 is an erroneous case: Triton returns -1,
    # whereas PyTorch returns -1 if x >= 0 and -2 if x < 0. c3 subtracts an
    # extra 1 when x < 0 and y == 0.
    # NOTE(review): when y == 0, c2 is also true for x < 0, so c also depends
    # on what x % 0 yields on this backend. Confirm the combined result is -2
    # for negative x.
    r = x % y
    c1 = r != 0
    c2 = (x < 0) ^ (y < 0)
    c3 = (x < 0) & (y == 0)
    c = c1 & c2
    # int16 / int16: widen to int32 for the division, then narrow back.
    # Presumably int16 division is not handled directly on this backend (confirm).
    if x.dtype == tl.int16 and y.dtype == tl.int16:
        return (x.to(tl.int32) // y.to(tl.int32)).cast(tl.int16) - c - c3
    return x // y - c - c3

181 

182 

# To be consistent with Python, NumPy and torch, floating-point floor division
# is implemented following their reference implementations:
# CPython
# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
# numpy
# https://github.com/numpy/numpy/blob/a4ad142aa1282a77bbb05acd706cb57c9cc29846/numpy/_core/src/npymath/npy_math_internal.h.src#L532
# torch
# https://github.com/pytorch/pytorch/blob/d6d9183456cd07ca0b361a194b98c2fb196e7c36/c10/util/generic_math.h#L23
@triton.jit
def _float_floordiv(x, y):
    """Floating-point floor division following the references linked above.

    Where ``y == 0`` the result is ``x / y``. Otherwise it is the quotient
    rounded toward negative infinity, with a signed zero for zero quotients.
    """
    # NOTE: fmod's sign is the same as the dividend
    remainder = fmod(x, y)
    # True where x is not an exact multiple of y.
    imperfect = remainder != 0.0
    different_sign = (x < 0) ^ (y < 0)

    # NOTE: we have to use div_rn explicitly here
    # (x - remainder) is a multiple of y, so q is the truncated quotient up to
    # rounding error in the division.
    q = div_rn(x - remainder, y)
    # Step from truncation toward -inf when the division is inexact and the
    # operands have opposite signs.
    q = tl.where(imperfect & different_sign, q - 1, q)

    # Snap q to the nearest integer to absorb rounding error: floor, then bump
    # up by one when the fractional part exceeds 0.5.
    floor_q = tl.math.floor(q)
    c = q - floor_q > 0.5
    floor_q = tl.where(c, floor_q + 1.0, floor_q)

    # A zero quotient gets -0.0 when the operand signs differ, else +0.0.
    q_is_zeros = q == 0.0
    floor_q = tl.where(q_is_zeros, tl.where(different_sign, -0.0, 0.0), floor_q)

    # Division by zero: fall back to plain true division of x by y.
    is_div_by_zero = y == 0.0
    float_division = x / y
    out = tl.where(is_div_by_zero, float_division, floor_q)
    return out

213 

214 

@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func(x, y, inplace):
    """Elementwise floor division of two tensors.

    Integer/integer element types go to ``_int_floordiv``; anything else goes
    to ``_float_floordiv``. ``inplace`` is not read by the body.
    """
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        result = _int_floordiv(x, y)
    else:
        result = _float_floordiv(x, y)
    return result

222 

223 

@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def floor_div_func_tensor_scalar(x, y, inplace):
    """Floor division of tensor ``x`` by scalar ``y``.

    Integer/integer element types go to ``_int_floordiv``; anything else goes
    to ``_float_floordiv``. ``inplace`` is not read by the body.
    """
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        result = _int_floordiv(x, y)
    else:
        result = _float_floordiv(x, y)
    return result

233 

234 

@pointwise_dynamic(
    is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def floor_div_func_scalar_tensor(x, y, inplace):
    """Floor division of scalar ``x`` by tensor ``y``.

    Integer/integer element types go to ``_int_floordiv``; anything else goes
    to ``_float_floordiv``. ``inplace`` is not read by the body.
    """
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        result = _int_floordiv(x, y)
    else:
        result = _float_floordiv(x, y)
    return result

244 

245 

def floor_divide(A, B):
    """Return ``A // B`` (floor division); each operand may be a tensor or a scalar.

    Two Python scalars are floor-divided in Python and wrapped with
    ``torch.tensor``.
    """
    logger.debug("GEMS_CAMBRICON FLOOR_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return floor_div_func(A, B, False)
    if a_is_tensor:
        return floor_div_func_tensor_scalar(A, B, False)
    if b_is_tensor:
        return floor_div_func_scalar_tensor(A, B, False)
    # Both scalar
    return torch.tensor(A // B)

257 

258 

def floor_divide_(A, B):
    """In-place floor division ``A //= B``; ``B`` may be a tensor or a scalar."""
    logger.debug("GEMS_CAMBRICON FLOOR_DIVIDE_")
    kernel = (
        floor_div_func if isinstance(B, torch.Tensor) else floor_div_func_tensor_scalar
    )
    return kernel(A, B, True, out0=A)

265 

266 

def div_mode(A, B, rounding_mode=None):
    """Divide ``A`` by ``B`` using ``rounding_mode`` (None, "trunc" or "floor").

    Raises:
        ValueError: if ``rounding_mode`` is not one of the supported values.
    """
    logger.debug("GEMS_CAMBRICON DIV_MODE")
    if rounding_mode is None:
        return true_divide(A, B)
    if rounding_mode == "trunc":
        return trunc_divide(A, B)
    if rounding_mode == "floor":
        return floor_divide(A, B)
    raise ValueError(
        f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    )

278 

279 

def div_mode_(A, B, rounding_mode=None):
    """In-place division of ``A`` by ``B`` using ``rounding_mode`` (None, "trunc" or "floor").

    Raises:
        ValueError: if ``rounding_mode`` is not one of the supported values.
    """
    logger.debug("GEMS_CAMBRICON DIV_MODE_")
    if rounding_mode is None:
        return true_divide_(A, B)
    if rounding_mode == "trunc":
        return trunc_divide_(A, B)
    if rounding_mode == "floor":
        return floor_divide_(A, B)
    raise ValueError(
        f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    )

291 

292 

@triton.jit
def _remainder(x, y):
    """Remainder of ``x`` by ``y``, corrected from Triton's ``%``.

    When the raw remainder is nonzero and ``x`` and ``y`` have opposite signs,
    ``y`` is added to it (presumably to give ``torch.remainder`` semantics).
    """
    rem = x % y
    signs_differ = (x < 0) ^ (y < 0)
    needs_fix = (rem != 0) & signs_differ
    return tl.where(needs_fix, rem + y, rem)

299 

300 

@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_tt(x, y, inplace):
    """Remainder of two tensors via ``_remainder``; ``inplace`` is not read."""
    result = _remainder(x, y)
    return result

305 

306 

@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def rem_ts(x, y, inplace):
    """Remainder of tensor ``x`` by scalar ``y`` via ``_remainder``; ``inplace`` is not read."""
    result = _remainder(x, y)
    return result

313 

314 

@pointwise_dynamic(
    is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def rem_st(x, y, inplace):
    """Remainder of scalar ``x`` by tensor ``y`` via ``_remainder``; ``inplace`` is not read."""
    result = _remainder(x, y)
    return result

321 

322 

def remainder(A, B):
    """Return ``A % B``; each operand may be a tensor or a Python scalar.

    Two Python scalars use Python's ``%`` and are wrapped with ``torch.tensor``.
    """
    logger.debug("GEMS_CAMBRICON REMAINDER")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return rem_tt(A, B, False)
    if a_is_tensor:
        return rem_ts(A, B, False)
    if b_is_tensor:
        return rem_st(A, B, False)
    # Both scalar
    return torch.tensor(A % B)

334 

335 

def remainder_(A, B):
    """In-place remainder ``A %= B``; ``B`` may be a tensor or a scalar."""
    logger.debug("GEMS_CAMBRICON REMAINDER_")
    kernel = rem_tt if isinstance(B, torch.Tensor) else rem_ts
    return kernel(A, B, True, out0=A)