Coverage for src/flag_gems/runtime/backend/_sunrise/ops/div.py: 0%

175 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import pointwise_dynamic 

8from flag_gems.utils.pointwise_dynamic import CodeGenConfig 

9from flag_gems.utils.triton_lang_extension import div_rn, div_rz, fmod, trunc 

10 

# Module logger, created as a child of the "flag_gems" logger and named after
# this module (leading dots stripped from __name__).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


# Three-element grid-size limit handed to CodeGenConfig as max_grid_size.
# NOTE(review): presumably one limit per launch-grid axis -- confirm against
# CodeGenConfig.
MAX_GRID_SIZES = (65535, 65535, 65535)
# Code-generation settings passed to every pointwise_dynamic decorator in this
# module.
config = CodeGenConfig(
    max_tile_size=1024,
    max_grid_size=MAX_GRID_SIZES,
    max_num_warps_per_cta=32,
    prefer_block_pointer=True,
    prefer_1d_tile=True,
)

22 

23 

@pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")], config=config)
@triton.jit
def true_div_func(x, y):
    # Elementwise true division for two tensor operands; type promotion is
    # declared as INT_TO_FLOAT in the decorator.
    quotient = x / y
    return quotient

28 

29 

@pointwise_dynamic(
    is_tensor=[True, False], promotion_methods=[(0, 1, "INT_TO_FLOAT")], config=config
)
@triton.jit
def true_div_func_tensor_scalar(x, y):
    # Elementwise true division where x is declared a tensor and y a
    # non-tensor (see is_tensor in the decorator).
    quotient = x / y
    return quotient

36 

37 

@pointwise_dynamic(
    is_tensor=[False, True], promotion_methods=[(0, 1, "INT_TO_FLOAT")], config=config
)
@triton.jit
def true_div_func_scalar_tensor(x, y):
    # Elementwise true division where x is declared a non-tensor and y a
    # tensor (see is_tensor in the decorator).
    quotient = x / y
    return quotient

44 

45 

def true_divide(A, B):
    """Divide ``A`` by ``B`` without rounding.

    Dispatches to the kernel that matches which operands are
    ``torch.Tensor`` instances. When neither is a tensor, the Python
    quotient ``A / B`` is wrapped in a new tensor.
    """
    logger.debug("GEMS TRUE_DIVIDE")
    lhs_is_tensor = isinstance(A, torch.Tensor)
    rhs_is_tensor = isinstance(B, torch.Tensor)

    if lhs_is_tensor and rhs_is_tensor:
        return true_div_func(A, B)
    if lhs_is_tensor:
        return true_div_func_tensor_scalar(A, B)
    if rhs_is_tensor:
        return true_div_func_scalar_tensor(A, B)

    # Neither operand is a tensor: compute on the host.
    return torch.tensor(A / B)

57 

58 

def true_divide_out(A, B, out):
    """Divide ``A`` by ``B`` without rounding, writing into ``out``.

    Tensor cases forward ``out`` to the matching kernel as ``out0``. When
    neither operand is a tensor, ``out`` is filled with ``A / B`` (or a new
    tensor is returned if ``out`` is ``None``).
    """
    logger.debug("GEMS TRUE_DIVIDE OUT")
    lhs_is_tensor = isinstance(A, torch.Tensor)
    rhs_is_tensor = isinstance(B, torch.Tensor)

    if lhs_is_tensor and rhs_is_tensor:
        return true_div_func(A, B, out0=out)
    if lhs_is_tensor:
        return true_div_func_tensor_scalar(A, B, out0=out)
    if rhs_is_tensor:
        return true_div_func_scalar_tensor(A, B, out0=out)

    # Neither operand is a tensor: compute on the host.
    quotient = A / B
    if out is None:
        return torch.tensor(quotient)
    return out.fill_(quotient)

70 

71 

def true_divide_(A, B):
    """In-place true division: the result is written back into ``A``.

    ``A`` is passed as the kernel's ``out0``; the kernel is chosen by whether
    ``B`` is a ``torch.Tensor``.
    """
    logger.debug("GEMS TRUE_DIVIDE_")
    kernel = (
        true_div_func if isinstance(B, torch.Tensor) else true_div_func_tensor_scalar
    )
    return kernel(A, B, out0=A)

78 

79 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")], config=config)
@triton.jit
def trunc_div_func(x, y):
    # Elementwise division of two tensor operands: div_rz followed by trunc.
    quotient = div_rz(x, y)
    return trunc(quotient)

84 

85 

@pointwise_dynamic(
    is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")], config=config
)
@triton.jit
def trunc_div_func_tensor_scalar(x, y):
    # Truncating division where x is declared a tensor and y a non-tensor:
    # div_rz followed by trunc.
    quotient = div_rz(x, y)
    return trunc(quotient)

92 

93 

@pointwise_dynamic(
    is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")], config=config
)
@triton.jit
def trunc_div_func_scalar_tensor(x, y):
    # Truncating division where x is declared a non-tensor and y a tensor:
    # div_rz followed by trunc.
    quotient = div_rz(x, y)
    return trunc(quotient)

100 

101 

def trunc_divide(A, B):
    """Divide ``A`` by ``B`` and truncate the quotient.

    Dispatches to the kernel that matches which operands are
    ``torch.Tensor`` instances. When neither is a tensor, ``A / B`` is
    truncated with ``int`` and converted back to ``type(A)`` before being
    wrapped in a tensor.
    """
    logger.debug("GEMS TRUNC_DIVIDE")
    lhs_is_tensor = isinstance(A, torch.Tensor)
    rhs_is_tensor = isinstance(B, torch.Tensor)

    if lhs_is_tensor and rhs_is_tensor:
        return trunc_div_func(A, B)
    if lhs_is_tensor:
        return trunc_div_func_tensor_scalar(A, B)
    if rhs_is_tensor:
        return trunc_div_func_scalar_tensor(A, B)

    # Neither operand is a tensor: compute on the host.
    truncated = int(A / B)
    return torch.tensor(type(A)(truncated))

113 

114 

def trunc_divide_(A, B):
    """In-place truncating division: the result is written back into ``A``.

    ``A`` is passed as the kernel's ``out0``; the kernel is chosen by whether
    ``B`` is a ``torch.Tensor``.
    """
    logger.debug("GEMS TRUNC_DIVIDE_")
    kernel = (
        trunc_div_func if isinstance(B, torch.Tensor) else trunc_div_func_tensor_scalar
    )
    return kernel(A, B, out0=A)

121 

122 

@triton.jit
def _int_floordiv(x, y):
    # TODO: request Triton to add an integer remainder builtin
    # Triton's floordiv semantics differ from PyTorch/NumPy.
    # Triton floordiv is equivalent to
    #     (x - np.fmod(x, y)) / y
    # whereas PyTorch floordiv is
    #     (x - np.remainder(x, y)) / y
    # The two differ by one exactly when
    #   C1) x is not a multiple of y, and
    #   C2) x and y have opposite signs.
    # Original author's note on division by zero: Triton's x // 0 returns -1,
    # whereas PyTorch returns -1 if x >= 0 and -2 if x < 0. That case is
    # covered by the same C1/C2 check, so it gets no dedicated branch here.
    quotient = x // y
    inexact = (x % y) != 0
    signs_differ = (x < 0) ^ (y < 0)
    return tl.where(inexact & signs_differ, quotient - 1, quotient)

142 

143 

# To be consistent with Python, NumPy and torch, floating-point floor division
# has to be implemented the following way.
# CPython
# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
# numpy
# https://github.com/numpy/numpy/blob/a4ad142aa1282a77bbb05acd706cb57c9cc29846/numpy/_core/src/npymath/npy_math_internal.h.src#L532
# torch
# https://github.com/pytorch/pytorch/blob/d6d9183456cd07ca0b361a194b98c2fb196e7c36/c10/util/generic_math.h#L23
@triton.jit
def _float_floordiv(x, y):
    # NOTE: fmod's result carries the sign of the dividend
    rem = fmod(x, y)
    inexact = rem != 0.0
    signs_differ = (x < 0) ^ (y < 0)

    # NOTE: div_rn has to be used explicitly here
    quot = div_rn(x - rem, y)
    # Step the quotient down by one when the division is inexact and the
    # operands have opposite signs.
    quot = tl.where(inexact & signs_differ, quot - 1, quot)

    # Snap the quotient to a whole number: take the floor, then bump it by one
    # when the quotient sits more than 0.5 above that floor.
    snapped = tl.math.floor(quot)
    round_up = quot - snapped > 0.5
    snapped = tl.where(round_up, snapped + 1.0, snapped)

    # A zero quotient is replaced by a zero whose sign follows signs_differ.
    quot_is_zero = quot == 0.0
    snapped = tl.where(quot_is_zero, tl.where(signs_differ, -0.0, 0.0), snapped)

    # Where the divisor is zero, return the plain x / y result instead.
    divisor_is_zero = y == 0.0
    plain_quotient = x / y
    result = tl.where(divisor_is_zero, plain_quotient, snapped)
    return result

174 

175 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")], config=config)
@triton.jit
def floor_div_func(x, y):
    # Elementwise floor division of two tensor operands.
    # The integer helper is only selected when BOTH operands have an integer
    # scalar type. The previous condition tested x twice, so an integer x
    # paired with a floating-point y was routed to the integer helper.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)

183 

184 

@pointwise_dynamic(
    is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")], config=config
)
@triton.jit
def floor_div_func_tensor_scalar(x, y):
    # Elementwise floor division where x is declared a tensor and y a
    # non-tensor (see is_tensor in the decorator).
    # The integer helper is only selected when BOTH operands have an integer
    # scalar type. The previous condition tested x twice, so an integer x
    # paired with a floating-point y was routed to the integer helper.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)

194 

195 

@pointwise_dynamic(
    is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")], config=config
)
@triton.jit
def floor_div_func_scalar_tensor(x, y):
    # Elementwise floor division where x is declared a non-tensor and y a
    # tensor (see is_tensor in the decorator).
    # The integer helper is only selected when BOTH operands have an integer
    # scalar type. The previous condition tested x twice, so an integer x
    # paired with a floating-point y was routed to the integer helper.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)

205 

206 

def floor_divide(A, B):
    """Divide ``A`` by ``B`` and round the quotient toward negative infinity.

    Dispatches to the kernel that matches which operands are
    ``torch.Tensor`` instances. When neither is a tensor, Python's ``A // B``
    is wrapped in a new tensor.
    """
    logger.debug("GEMS FLOOR_DIVIDE")
    lhs_is_tensor = isinstance(A, torch.Tensor)
    rhs_is_tensor = isinstance(B, torch.Tensor)

    if lhs_is_tensor and rhs_is_tensor:
        return floor_div_func(A, B)
    if lhs_is_tensor:
        return floor_div_func_tensor_scalar(A, B)
    if rhs_is_tensor:
        return floor_div_func_scalar_tensor(A, B)

    # Neither operand is a tensor: compute on the host.
    return torch.tensor(A // B)

218 

219 

def floor_divide_(A, B):
    """In-place floor division: the result is written back into ``A``.

    ``A`` is passed as the kernel's ``out0``; the kernel is chosen by whether
    ``B`` is a ``torch.Tensor``.
    """
    logger.debug("GEMS FLOOR_DIVIDE_")
    kernel = (
        floor_div_func if isinstance(B, torch.Tensor) else floor_div_func_tensor_scalar
    )
    return kernel(A, B, out0=A)

226 

227 

def div_mode(A, B, rounding_mode=None):
    """Divide ``A`` by ``B`` using the requested rounding mode.

    Args:
        A: Dividend, forwarded unchanged to the selected implementation.
        B: Divisor, forwarded unchanged to the selected implementation.
        rounding_mode: ``None`` for true division, ``"trunc"`` for truncating
            division, or ``"floor"`` for floor division.

    Raises:
        ValueError: If ``rounding_mode`` is none of the accepted values.
    """
    if rounding_mode is None:
        return true_divide(A, B)
    if rounding_mode == "trunc":
        return trunc_divide(A, B)
    if rounding_mode == "floor":
        return floor_divide(A, B)

    raise ValueError(
        f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    )

238 

239 

def div_mode_(A, B, rounding_mode=None):
    """In-place counterpart of ``div_mode``.

    Args:
        A: Dividend, forwarded unchanged to the selected in-place
            implementation.
        B: Divisor, forwarded unchanged to the selected in-place
            implementation.
        rounding_mode: ``None`` for true division, ``"trunc"`` for truncating
            division, or ``"floor"`` for floor division.

    Raises:
        ValueError: If ``rounding_mode`` is none of the accepted values.
    """
    if rounding_mode is None:
        return true_divide_(A, B)
    if rounding_mode == "trunc":
        return trunc_divide_(A, B)
    if rounding_mode == "floor":
        return floor_divide_(A, B)

    raise ValueError(
        f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    )

250 

251 

@triton.jit
def _remainder(x, y):
    # Start from Triton's `%` result, then add the divisor back when that
    # result is non-zero and the operands have opposite signs.
    raw = x % y
    nonzero = raw != 0
    signs_differ = (x < 0) ^ (y < 0)
    return tl.where(nonzero & signs_differ, raw + y, raw)

258 

259 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")], config=config)
@triton.jit
def rem_tt(x, y):
    # Elementwise remainder for two tensor operands; delegates to _remainder.
    result = _remainder(x, y)
    return result

264 

265 

@pointwise_dynamic(
    is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")], config=config
)
@triton.jit
def rem_ts(x, y):
    # Elementwise remainder where x is declared a tensor and y a non-tensor;
    # delegates to _remainder.
    result = _remainder(x, y)
    return result

272 

273 

@pointwise_dynamic(
    is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")], config=config
)
@triton.jit
def rem_st(x, y):
    # Elementwise remainder where x is declared a non-tensor and y a tensor;
    # delegates to _remainder.
    result = _remainder(x, y)
    return result

280 

281 

def remainder(A, B):
    """Compute the remainder of ``A`` divided by ``B``.

    Dispatches to the kernel that matches which operands are
    ``torch.Tensor`` instances. When neither is a tensor, Python's ``A % B``
    is wrapped in a new tensor.
    """
    logger.debug("GEMS REMAINDER")
    lhs_is_tensor = isinstance(A, torch.Tensor)
    rhs_is_tensor = isinstance(B, torch.Tensor)

    if lhs_is_tensor and rhs_is_tensor:
        return rem_tt(A, B)
    if lhs_is_tensor:
        return rem_ts(A, B)
    if rhs_is_tensor:
        return rem_st(A, B)

    # Neither operand is a tensor: compute on the host.
    return torch.tensor(A % B)

293 

294 

def remainder_(A, B):
    """In-place remainder: the result is written back into ``A``.

    ``A`` is passed as the kernel's ``out0``; the kernel is chosen by whether
    ``B`` is a ``torch.Tensor``.
    """
    logger.debug("GEMS REMAINDER_")
    kernel = rem_tt if isinstance(B, torch.Tensor) else rem_ts
    return kernel(A, B, out0=A)