Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/div.py: 0%

201 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6from _kunlunxin.utils.codegen_config_utils import CodeGenConfig 

7 

8from flag_gems.utils import tl_extra_shim 

9 

10from ..utils.pointwise_dynamic import pointwise_dynamic 

11 

# Module logger, created as a child of the package-wide "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Short module-level aliases for the math helpers exposed by tl_extra_shim,
# so the Triton kernels below can call them by bare name.
div_rn = tl_extra_shim.div_rn
div_rz = tl_extra_shim.div_rz
fmod = tl_extra_shim.fmod
trunc = tl_extra_shim.trunc
# use it if we need to cmp result with xpu
xpu_trunc_div = tl_extra_shim.xpu_trunc_div

18 

# Code-generation settings passed to `pointwise_dynamic` for `true_div_func`.
# NOTE(review): the meaning of the four positional arguments is defined by
# CodeGenConfig, which is not visible from this file - confirm before editing.
config_ = CodeGenConfig(
    512,
    (65536, 65536, 65536),
    32,
    True,
    prefer_1d_tile=True,
    buffer_size_limit=4096,
    isCloseVectorization=True,
    unroll_num=8,
)

29 

30 

# True-division kernel used by the wrappers below when both operands are
# tensors. It is the only true-division kernel built with `config_`.
@pointwise_dynamic(
    promotion_methods=[(0, 1, "INT_TO_FLOAT")],
    config=config_,
)
@triton.jit
def true_div_func(x, y):
    quotient = x / y
    return quotient

35 

36 

# True-division kernel for a tensor dividend and a non-tensor divisor
# (is_tensor=[True, False]).
@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "INT_TO_FLOAT")],
)
@triton.jit
def true_div_func_tensor_scalar(x, y):
    quotient = x / y
    return quotient

41 

42 

# True-division kernel for a non-tensor dividend and a tensor divisor
# (is_tensor=[False, True]).
@pointwise_dynamic(
    is_tensor=[False, True],
    promotion_methods=[(0, 1, "INT_TO_FLOAT")],
)
@triton.jit
def true_div_func_scalar_tensor(x, y):
    quotient = x / y
    return quotient

47 

48 

def true_divide(A, B):
    """Divide ``A`` by ``B`` without rounding.

    The kernel is chosen by which operands are ``torch.Tensor`` instances.
    When neither is a tensor, the Python quotient ``A / B`` is wrapped in a
    new tensor.
    """
    logger.debug("GEMS_KUNLUNXIN TRUE_DIVIDE")
    lhs_is_tensor = isinstance(A, torch.Tensor)
    rhs_is_tensor = isinstance(B, torch.Tensor)

    if lhs_is_tensor and rhs_is_tensor:
        return true_div_func(A, B)
    if lhs_is_tensor:
        return true_div_func_tensor_scalar(A, B)
    if rhs_is_tensor:
        return true_div_func_scalar_tensor(A, B)
    # Both scalar
    return torch.tensor(A / B)

60 

61 

def true_divide_out(A, B, out):
    """Divide ``A`` by ``B`` without rounding, writing into ``out``.

    When at least one operand is a tensor, ``out`` is forwarded to the
    selected kernel as ``out0``. When both operands are scalars, ``out`` is
    filled with ``A / B``, or a new tensor is returned if ``out`` is None.
    """
    logger.debug("GEMS_KUNLUNXIN TRUE_DIVIDE_OUT")
    lhs_is_tensor = isinstance(A, torch.Tensor)
    rhs_is_tensor = isinstance(B, torch.Tensor)

    if lhs_is_tensor and rhs_is_tensor:
        return true_div_func(A, B, out0=out)
    if lhs_is_tensor:
        return true_div_func_tensor_scalar(A, B, out0=out)
    if rhs_is_tensor:
        return true_div_func_scalar_tensor(A, B, out0=out)
    # Both scalar
    if out is None:
        return torch.tensor(A / B)
    return out.fill_(A / B)

73 

74 

def true_divide_(A, B):
    """In-place true division: ``A`` is passed to the kernel as ``out0``."""
    logger.debug("GEMS_KUNLUNXIN TRUE_DIVIDE_")
    # Only the divisor decides the kernel; A is used as the output buffer.
    if isinstance(B, torch.Tensor):
        kernel = true_div_func
    else:
        kernel = true_div_func_tensor_scalar
    return kernel(A, B, out0=A)

81 

82 

# Truncating-division kernel used when both operands are tensors; the
# computation is delegated to the `xpu_trunc_div` shim.
@pointwise_dynamic(
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def trunc_div_func(x, y):
    truncated = xpu_trunc_div(x, y)
    return truncated

87 

88 

# Truncating-division kernel for a tensor dividend and a non-tensor divisor
# (is_tensor=[True, False]); delegates to the `xpu_trunc_div` shim.
@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def trunc_div_func_tensor_scalar(x, y):
    truncated = xpu_trunc_div(x, y)
    return truncated

93 

94 

# Truncating-division kernel for a non-tensor dividend and a tensor divisor
# (is_tensor=[False, True]); delegates to the `xpu_trunc_div` shim.
@pointwise_dynamic(
    is_tensor=[False, True],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def trunc_div_func_scalar_tensor(x, y):
    truncated = xpu_trunc_div(x, y)
    return truncated

99 

100 

# Integer truncation division: Triton's // on integers is C-style
# (truncates toward zero), so no sign correction is applied here.
@pointwise_dynamic(
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def trunc_div_int_func(x, y):
    quotient = x // y
    return quotient

106 

107 

# Integer truncation division for a tensor dividend and a non-tensor divisor
# (is_tensor=[True, False]), using Triton's `//` directly.
@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def trunc_div_int_func_tensor_scalar(x, y):
    quotient = x // y
    return quotient

112 

113 

# Integer truncation division for a non-tensor dividend and a tensor divisor
# (is_tensor=[False, True]), using Triton's `//` directly.
@pointwise_dynamic(
    is_tensor=[False, True],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def trunc_div_int_func_scalar_tensor(x, y):
    quotient = x // y
    return quotient

118 

119 

def trunc_divide(A, B):
    """Divide ``A`` by ``B`` and round the quotient toward zero.

    Kernel selection depends on which operands are ``torch.Tensor`` instances
    and on whether any operand is floating point:

    * the integer ``//`` kernels are used only when neither operand is
      floating point (Triton's ``//`` on integers truncates toward zero);
    * every other tensor combination uses the ``xpu_trunc_div`` kernels;
    * two scalars are divided in Python and the truncated quotient is
      wrapped in a new tensor.

    Raises:
        ZeroDivisionError: when both operands are Python scalars and ``B``
            is zero.
    """
    logger.debug("GEMS_KUNLUNXIN TRUNC_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)

    if not a_is_tensor and not b_is_tensor:
        # Both scalar. The quotient must be truncated, not returned as-is.
        if isinstance(A, float) or isinstance(B, float):
            return torch.trunc(torch.tensor(A / B))
        # Exact integer truncation: floor the magnitudes, then restore sign.
        magnitude = abs(A) // abs(B)
        return torch.tensor(-magnitude if (A < 0) != (B < 0) else magnitude)

    a_is_float = A.is_floating_point() if a_is_tensor else isinstance(A, float)
    b_is_float = B.is_floating_point() if b_is_tensor else isinstance(B, float)
    # The integer kernels rely on `//`, which is only valid when both
    # operands are integral; any floating operand needs the float kernels.
    use_int_kernels = not (a_is_float or b_is_float)

    if a_is_tensor and b_is_tensor:
        kernel = trunc_div_int_func if use_int_kernels else trunc_div_func
    elif a_is_tensor:
        kernel = (
            trunc_div_int_func_tensor_scalar
            if use_int_kernels
            else trunc_div_func_tensor_scalar
        )
    else:
        kernel = (
            trunc_div_int_func_scalar_tensor
            if use_int_kernels
            else trunc_div_func_scalar_tensor
        )
    return kernel(A, B)

139 

140 

def trunc_divide_(A, B):
    """In-place truncating division: ``A`` is passed to the kernel as ``out0``.

    ``A`` must provide ``is_floating_point()``; it is called unconditionally.
    """
    logger.debug("GEMS_KUNLUNXIN TRUNC_DIVIDE_")
    divisor_is_tensor = isinstance(B, torch.Tensor)

    # Integer types: use dedicated int kernels (Triton // is C-style truncation)
    if not A.is_floating_point():
        if divisor_is_tensor:
            return trunc_div_int_func(A, B, out0=A)
        return trunc_div_int_func_tensor_scalar(A, B, out0=A)

    if divisor_is_tensor:
        return trunc_div_func(A, B, out0=A)
    return trunc_div_func_tensor_scalar(A, B, out0=A)

153 

154 

@triton.jit
def _int_floordiv(x, y):
    # TODO: request Triton to add an integer remainder builtin
    # The semantics of Triton floordiv differ from PyTorch/NumPy.
    # Triton floordiv equates to
    #     (x - np.fmod(x, y)) / y
    # whereas PyTorch floordiv is
    #     (x - np.remainder(x, y)) / y
    # The results are off by one when
    #     C1) x and y have opposite signs
    # and C2) x is not a multiple of y.
    # Apart from the above, there is an erroneous case: x // 0 returns -1,
    # whereas in PyTorch x // 0 returns -1 if x >= 0 and -2 if x < 0.
    # That special case is covered by the same two checks, so it needs no
    # extra handling.
    rem = x % y
    inexact = rem != 0
    signs_differ = (x < 0) ^ (y < 0)
    quotient = x // y
    return tl.where(inexact & signs_differ, quotient - 1, quotient)

174 

175 

# To be consistent with Python, NumPy and torch, floating-point floor division
# has to be implemented the following way.
# CPython
# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
# numpy
# https://github.com/numpy/numpy/blob/a4ad142aa1282a77bbb05acd706cb57c9cc29846/numpy/_core/src/npymath/npy_math_internal.h.src#L532
# torch
# https://github.com/pytorch/pytorch/blob/d6d9183456cd07ca0b361a194b98c2fb196e7c36/c10/util/generic_math.h#L23
@triton.jit
def _float_floordiv(x, y):
    # NOTE: fmod's sign is the same as the dividend
    rem = fmod(x, y)
    inexact = rem != 0.0
    signs_differ = (x < 0) ^ (y < 0)

    # NOTE: we have to use div_rn explicitly here
    quot = div_rn(x - rem, y)
    # Step one down when the division is inexact and the signs differ.
    quot = tl.where(inexact & signs_differ, quot - 1, quot)

    # Snap the quotient to an integer: floor it, then bump by one when the
    # fractional part left over exceeds one half.
    snapped = tl.math.floor(quot)
    bump = quot - snapped > 0.5
    snapped = tl.where(bump, snapped + 1.0, snapped)

    # A zero quotient takes its sign from the operands' signs.
    signed_zero = tl.where(signs_differ, -0.0, 0.0)
    snapped = tl.where(quot == 0.0, signed_zero, snapped)

    # For a zero divisor, return the plain division result instead.
    return tl.where(y == 0.0, x / y, snapped)

206 

207 

# Floor-division kernel used when both operands are tensors.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func(x, y):
    # Use the integer path only when BOTH operands are integers; the original
    # condition tested x twice and never looked at y.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)

215 

216 

# Floor-division kernel for a tensor dividend and a non-tensor divisor
# (is_tensor=[True, False]).
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func_tensor_scalar(x, y):
    # Use the integer path only when BOTH operands are integers; the original
    # condition tested x twice and never looked at y.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)

224 

225 

# Floor-division kernel for a non-tensor dividend and a tensor divisor
# (is_tensor=[False, True]).
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func_scalar_tensor(x, y):
    # Use the integer path only when BOTH operands are integers; the original
    # condition tested x twice and never looked at y.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)

233 

234 

def floor_divide(A, B):
    """Divide ``A`` by ``B`` with floor rounding.

    The kernel is chosen by which operands are ``torch.Tensor`` instances.
    When neither is a tensor, Python's ``A // B`` is wrapped in a new tensor.
    """
    logger.debug("GEMS_KUNLUNXIN FLOOR_DIVIDE")
    lhs_is_tensor = isinstance(A, torch.Tensor)
    rhs_is_tensor = isinstance(B, torch.Tensor)

    if lhs_is_tensor and rhs_is_tensor:
        return floor_div_func(A, B)
    if lhs_is_tensor:
        return floor_div_func_tensor_scalar(A, B)
    if rhs_is_tensor:
        return floor_div_func_scalar_tensor(A, B)
    # Both scalar
    return torch.tensor(A // B)

246 

247 

def floor_divide_(A, B):
    """In-place floor division: ``A`` is passed to the kernel as ``out0``."""
    logger.debug("GEMS_KUNLUNXIN FLOOR_DIVIDE_")
    # Only the divisor decides the kernel; A is used as the output buffer.
    if isinstance(B, torch.Tensor):
        kernel = floor_div_func
    else:
        kernel = floor_div_func_tensor_scalar
    return kernel(A, B, out0=A)

254 

255 

def div_mode(A, B, rounding_mode=None):
    """Divide ``A`` by ``B`` using the requested rounding mode.

    ``None`` selects true division, ``"trunc"`` truncating division and
    ``"floor"`` floor division.

    Raises:
        ValueError: if ``rounding_mode`` is none of the three accepted values.
    """
    if rounding_mode is None:
        return true_divide(A, B)
    if rounding_mode == "trunc":
        return trunc_divide(A, B)
    if rounding_mode == "floor":
        return floor_divide(A, B)

    msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    raise ValueError(msg)

266 

267 

def div_mode_(A, B, rounding_mode=None):
    """In-place counterpart of ``div_mode``.

    ``None`` selects ``true_divide_``, ``"trunc"`` selects ``trunc_divide_``
    and ``"floor"`` selects ``floor_divide_``.

    Raises:
        ValueError: if ``rounding_mode`` is none of the three accepted values.
    """
    if rounding_mode is None:
        return true_divide_(A, B)
    if rounding_mode == "trunc":
        return trunc_divide_(A, B)
    if rounding_mode == "floor":
        return floor_divide_(A, B)

    msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    raise ValueError(msg)

278 

279 

@triton.jit
def _remainder(x, y):
    # Start from Triton's `%`, then add the divisor back whenever the result
    # is non-zero and the operands have opposite signs.
    rem = x % y
    nonzero = rem != 0
    signs_differ = (x < 0) ^ (y < 0)
    return tl.where(nonzero & signs_differ, rem + y, rem)

286 

287 

# Remainder kernel used when both operands are tensors.
@pointwise_dynamic(
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def rem_tt(x, y):
    result = _remainder(x, y)
    return result

292 

293 

# Remainder kernel for a tensor dividend and a non-tensor divisor
# (is_tensor=[True, False]).
@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def rem_ts(x, y):
    result = _remainder(x, y)
    return result

298 

299 

# Remainder kernel for a non-tensor dividend and a tensor divisor
# (is_tensor=[False, True]).
@pointwise_dynamic(
    is_tensor=[False, True],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def rem_st(x, y):
    result = _remainder(x, y)
    return result

304 

305 

def remainder(A, B):
    """Compute the remainder of ``A`` divided by ``B``.

    The kernel is chosen by which operands are ``torch.Tensor`` instances.
    When neither is a tensor, Python's ``A % B`` is wrapped in a new tensor.
    """
    # The original logged "FLOOR_DIVIDE" here (copy-paste error), which made
    # remainder calls indistinguishable from floor_divide calls in debug logs.
    logger.debug("GEMS_KUNLUNXIN REMAINDER")
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        return rem_tt(A, B)
    elif isinstance(A, torch.Tensor):
        return rem_ts(A, B)
    elif isinstance(B, torch.Tensor):
        return rem_st(A, B)
    else:
        # Both scalar
        return torch.tensor(A % B)

317 

318 

def remainder_(A, B):
    """In-place remainder: ``A`` is passed to the kernel as ``out0``."""
    logger.debug("GEMS_KUNLUNXIN REMAINDER_")
    # Only the divisor decides the kernel; A is used as the output buffer.
    if isinstance(B, torch.Tensor):
        kernel = rem_tt
    else:
        kernel = rem_ts
    return kernel(A, B, out0=A)