Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/div.py: 0%
201 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6from _kunlunxin.utils.codegen_config_utils import CodeGenConfig
8from flag_gems.utils import tl_extra_shim
10from ..utils.pointwise_dynamic import pointwise_dynamic
# Module logger, created as a child of the shared "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Short module-level aliases for helpers provided by tl_extra_shim so the
# kernels below can call them by bare name.
# NOTE(review): div_rn / div_rz presumably mean round-to-nearest /
# round-toward-zero division -- confirm against tl_extra_shim.
div_rn = tl_extra_shim.div_rn
div_rz = tl_extra_shim.div_rz
fmod = tl_extra_shim.fmod
trunc = tl_extra_shim.trunc
xpu_trunc_div = tl_extra_shim.xpu_trunc_div  # use it if we need to compare results with xpu
# Code-generation settings passed to pointwise_dynamic for true_div_func.
# NOTE(review): the meaning of the four positional values is defined by
# CodeGenConfig, whose signature is not visible from this file -- confirm
# against it before changing any of them.
config_ = CodeGenConfig(
    512,
    (65536, 65536, 65536),
    32,
    True,
    prefer_1d_tile=True,
    buffer_size_limit=4096,
    isCloseVectorization=True,
    unroll_num=8,
)
# Element-wise true division of two tensor operands. Uses the "INT_TO_FLOAT"
# promotion rule and the module-level ``config_`` code-generation settings.
@pointwise_dynamic(config=config_, promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func(x, y):
    quotient = x / y
    return quotient
# Element-wise true division where the dividend is a tensor and the divisor
# is a non-tensor scalar (is_tensor=[True, False]).
@pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")], is_tensor=[True, False])
@triton.jit
def true_div_func_tensor_scalar(x, y):
    quotient = x / y
    return quotient
# Element-wise true division where the dividend is a non-tensor scalar and
# the divisor is a tensor (is_tensor=[False, True]).
@pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")], is_tensor=[False, True])
@triton.jit
def true_div_func_scalar_tensor(x, y):
    quotient = x / y
    return quotient
def true_divide(A, B):
    """Divide ``A`` by ``B`` with true-division semantics.

    Picks the kernel that matches which operands are tensors. When neither
    operand is a tensor, the Python quotient is wrapped in a new tensor.
    """
    logger.debug("GEMS_KUNLUNXIN TRUE_DIVIDE")
    lhs_is_tensor = isinstance(A, torch.Tensor)
    rhs_is_tensor = isinstance(B, torch.Tensor)

    if lhs_is_tensor and rhs_is_tensor:
        return true_div_func(A, B)
    if lhs_is_tensor:
        return true_div_func_tensor_scalar(A, B)
    if rhs_is_tensor:
        return true_div_func_scalar_tensor(A, B)
    # Neither operand is a tensor.
    return torch.tensor(A / B)
def true_divide_out(A, B, out):
    """Divide ``A`` by ``B`` with true-division semantics, writing into ``out``.

    Tensor cases forward ``out`` to the matching kernel as ``out0``. When both
    operands are scalars, ``out`` is filled with the quotient, or a new tensor
    is returned if ``out`` is None.
    """
    logger.debug("GEMS_KUNLUNXIN TRUE_DIVIDE_OUT")
    lhs_is_tensor = isinstance(A, torch.Tensor)
    rhs_is_tensor = isinstance(B, torch.Tensor)

    if lhs_is_tensor and rhs_is_tensor:
        return true_div_func(A, B, out0=out)
    if lhs_is_tensor:
        return true_div_func_tensor_scalar(A, B, out0=out)
    if rhs_is_tensor:
        return true_div_func_scalar_tensor(A, B, out0=out)

    # Neither operand is a tensor.
    quotient = A / B
    if out is None:
        return torch.tensor(quotient)
    return out.fill_(quotient)
def true_divide_(A, B):
    """In-place true division: the result is written back into ``A`` (``out0=A``).

    ``B`` may be a tensor or a non-tensor scalar.
    """
    logger.debug("GEMS_KUNLUNXIN TRUE_DIVIDE_")
    kernel = true_div_func if isinstance(B, torch.Tensor) else true_div_func_tensor_scalar
    return kernel(A, B, out0=A)
# Element-wise truncating division of two tensors, delegated to the
# xpu_trunc_div helper.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func(x, y):
    truncated = xpu_trunc_div(x, y)
    return truncated
# Truncating division, tensor dividend and non-tensor scalar divisor
# (is_tensor=[True, False]); delegated to xpu_trunc_div.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")], is_tensor=[True, False])
@triton.jit
def trunc_div_func_tensor_scalar(x, y):
    truncated = xpu_trunc_div(x, y)
    return truncated
# Truncating division, non-tensor scalar dividend and tensor divisor
# (is_tensor=[False, True]); delegated to xpu_trunc_div.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")], is_tensor=[False, True])
@triton.jit
def trunc_div_func_scalar_tensor(x, y):
    truncated = xpu_trunc_div(x, y)
    return truncated
# Integer truncating division for two tensors. Triton's ``//`` on integers is
# C-style, i.e. it truncates toward zero, so no helper is needed.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_int_func(x, y):
    truncated = x // y
    return truncated
# Integer truncating division, tensor dividend and non-tensor scalar divisor
# (is_tensor=[True, False]).
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")], is_tensor=[True, False])
@triton.jit
def trunc_div_int_func_tensor_scalar(x, y):
    truncated = x // y
    return truncated
# Integer truncating division, non-tensor scalar dividend and tensor divisor
# (is_tensor=[False, True]).
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")], is_tensor=[False, True])
@triton.jit
def trunc_div_int_func_scalar_tensor(x, y):
    truncated = x // y
    return truncated
def _is_floating_operand(value):
    """Return True when ``value`` is a floating-point tensor or a Python float."""
    if isinstance(value, torch.Tensor):
        return value.is_floating_point()
    return isinstance(value, float)


def trunc_divide(A, B):
    """Divide ``A`` by ``B``, rounding the quotient toward zero.

    The kernel is chosen on two independent axes:

    * which operands are tensors (tensor/tensor, tensor/scalar, scalar/tensor);
    * whether the division is purely integral. The integer kernels rely on
      ``//`` and are only used when neither operand is floating point;
      otherwise the ``xpu_trunc_div`` based kernels are used.

    When neither operand is a tensor, the Python quotient is truncated and
    returned as a new tensor.
    """
    logger.debug("GEMS_KUNLUNXIN TRUNC_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)

    if not (a_is_tensor or b_is_tensor):
        # Both scalar: A / B alone would be true division, so truncate it.
        return torch.trunc(torch.tensor(A / B))

    # A single floating operand makes the whole computation floating point,
    # so the integer kernels are only valid when both operands are integral.
    integral = not (_is_floating_operand(A) or _is_floating_operand(B))

    if a_is_tensor and b_is_tensor:
        kernel = trunc_div_int_func if integral else trunc_div_func
    elif a_is_tensor:
        kernel = trunc_div_int_func_tensor_scalar if integral else trunc_div_func_tensor_scalar
    else:
        kernel = trunc_div_int_func_scalar_tensor if integral else trunc_div_func_scalar_tensor
    return kernel(A, B)
def trunc_divide_(A, B):
    """In-place truncating division: the result is written back into ``A``.

    Integer ``A`` uses the dedicated integer kernels (Triton ``//`` is C-style
    truncation); floating-point ``A`` uses the ``xpu_trunc_div`` based kernels.
    ``B`` may be a tensor or a non-tensor scalar.
    """
    logger.debug("GEMS_KUNLUNXIN TRUNC_DIVIDE_")
    use_float_kernel = A.is_floating_point()
    rhs_is_tensor = isinstance(B, torch.Tensor)

    if use_float_kernel:
        kernel = trunc_div_func if rhs_is_tensor else trunc_div_func_tensor_scalar
    else:
        kernel = trunc_div_int_func if rhs_is_tensor else trunc_div_int_func_tensor_scalar
    return kernel(A, B, out0=A)
@triton.jit
def _int_floordiv(x, y):
    # TODO: request Triton to add an integer remainder builtin
    # Triton's integer floordiv differs from PyTorch/NumPy. Triton computes
    #     (x - np.fmod(x, y)) / y
    # whereas PyTorch computes
    #     (x - np.remainder(x, y)) / y
    # The results are off by one exactly when
    #   C1) x is not a multiple of y, and
    #   C2) x and y have opposite signs.
    # Division by zero: Triton's x // 0 returns -1, whereas PyTorch returns -1
    # for x >= 0 and -2 for x < 0. That case is absorbed by the C1/C2
    # correction below, so it needs no separate handling.
    quotient = x // y
    remainder = x % y
    inexact = remainder != 0
    opposite_signs = (x < 0) ^ (y < 0)
    return tl.where(inexact & opposite_signs, quotient - 1, quotient)
# To be consistent with Python, NumPy and torch, we have to implement it in the
# following way.
# CPython
# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
# numpy
# https://github.com/numpy/numpy/blob/a4ad142aa1282a77bbb05acd706cb57c9cc29846/numpy/_core/src/npymath/npy_math_internal.h.src#L532
# torch
# https://github.com/pytorch/pytorch/blob/d6d9183456cd07ca0b361a194b98c2fb196e7c36/c10/util/generic_math.h#L23
@triton.jit
def _float_floordiv(x, y):
    # NOTE: fmod's sign is the same as the dividend
    remainder = fmod(x, y)
    imperfect = remainder != 0.0
    different_sign = (x < 0) ^ (y < 0)

    # NOTE: we have to use div_rn explicitly here
    q = div_rn(x - remainder, y)
    # Step the quotient down by one when the division is inexact and the
    # operands have opposite signs.
    q = tl.where(imperfect & different_sign, q - 1, q)

    # Snap q to an integer: take the floor, then bump it up by one when the
    # fractional part exceeds 0.5.
    floor_q = tl.math.floor(q)
    c = q - floor_q > 0.5
    floor_q = tl.where(c, floor_q + 1.0, floor_q)

    # A zero quotient gets an explicit sign: -0.0 when the operand signs
    # differ, +0.0 otherwise.
    q_is_zeros = q == 0.0
    floor_q = tl.where(q_is_zeros, tl.where(different_sign, -0.0, 0.0), floor_q)

    # For a zero divisor, return the plain floating-point x / y instead of the
    # value computed above.
    is_div_by_zero = y == 0.0
    float_division = x / y
    out = tl.where(is_div_by_zero, float_division, floor_q)
    return out
# Element-wise floor division of two tensors.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func(x, y):
    # Take the integer path only when BOTH operands are integers; the check
    # previously tested x twice and never looked at y.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
# Floor division, tensor dividend and non-tensor scalar divisor
# (is_tensor=[True, False]).
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func_tensor_scalar(x, y):
    # Take the integer path only when BOTH operands are integers; the check
    # previously tested x twice and never looked at y.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
# Floor division, non-tensor scalar dividend and tensor divisor
# (is_tensor=[False, True]).
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func_scalar_tensor(x, y):
    # Take the integer path only when BOTH operands are integers; the check
    # previously tested x twice and never looked at y.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
def floor_divide(A, B):
    """Divide ``A`` by ``B`` with floor-division semantics.

    Picks the kernel that matches which operands are tensors. When neither
    operand is a tensor, Python's ``//`` result is wrapped in a new tensor.
    """
    logger.debug("GEMS_KUNLUNXIN FLOOR_DIVIDE")
    lhs_is_tensor = isinstance(A, torch.Tensor)
    rhs_is_tensor = isinstance(B, torch.Tensor)

    if lhs_is_tensor and rhs_is_tensor:
        return floor_div_func(A, B)
    if lhs_is_tensor:
        return floor_div_func_tensor_scalar(A, B)
    if rhs_is_tensor:
        return floor_div_func_scalar_tensor(A, B)
    # Neither operand is a tensor.
    return torch.tensor(A // B)
def floor_divide_(A, B):
    """In-place floor division: the result is written back into ``A`` (``out0=A``).

    ``B`` may be a tensor or a non-tensor scalar.
    """
    logger.debug("GEMS_KUNLUNXIN FLOOR_DIVIDE_")
    kernel = floor_div_func if isinstance(B, torch.Tensor) else floor_div_func_tensor_scalar
    return kernel(A, B, out0=A)
def div_mode(A, B, rounding_mode=None):
    """Divide ``A`` by ``B`` using the requested rounding mode.

    Args:
        A: Dividend, forwarded unchanged to the selected implementation.
        B: Divisor, forwarded unchanged to the selected implementation.
        rounding_mode: ``None`` for true division, ``"trunc"`` for truncating
            division, or ``"floor"`` for floor division.

    Raises:
        ValueError: If ``rounding_mode`` is none of the values above.
    """
    if rounding_mode is None:
        return true_divide(A, B)
    if rounding_mode == "trunc":
        return trunc_divide(A, B)
    if rounding_mode == "floor":
        return floor_divide(A, B)
    raise ValueError(
        f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    )
def div_mode_(A, B, rounding_mode=None):
    """In-place counterpart of ``div_mode``: dispatch on ``rounding_mode``.

    Args:
        A: Dividend, forwarded unchanged to the selected in-place implementation.
        B: Divisor, forwarded unchanged to the selected in-place implementation.
        rounding_mode: ``None`` for true division, ``"trunc"`` for truncating
            division, or ``"floor"`` for floor division.

    Raises:
        ValueError: If ``rounding_mode`` is none of the values above.
    """
    if rounding_mode is None:
        return true_divide_(A, B)
    if rounding_mode == "trunc":
        return trunc_divide_(A, B)
    if rounding_mode == "floor":
        return floor_divide_(A, B)
    raise ValueError(
        f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    )
@triton.jit
def _remainder(x, y):
    # Start from ``x % y`` and add the divisor back whenever the result is
    # non-zero and the operands have opposite signs.
    rem = x % y
    opposite_signs = (x < 0) ^ (y < 0)
    needs_shift = (rem != 0) & opposite_signs
    return tl.where(needs_shift, rem + y, rem)
# Element-wise remainder of two tensors; delegates to _remainder.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_tt(x, y):
    result = _remainder(x, y)
    return result
# Remainder with a tensor dividend and a non-tensor scalar divisor
# (is_tensor=[True, False]); delegates to _remainder.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")], is_tensor=[True, False])
@triton.jit
def rem_ts(x, y):
    result = _remainder(x, y)
    return result
# Remainder with a non-tensor scalar dividend and a tensor divisor
# (is_tensor=[False, True]); delegates to _remainder.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")], is_tensor=[False, True])
@triton.jit
def rem_st(x, y):
    result = _remainder(x, y)
    return result
def remainder(A, B):
    """Compute the remainder of ``A`` divided by ``B``.

    Picks the kernel that matches which operands are tensors. When neither
    operand is a tensor, Python's ``%`` result is wrapped in a new tensor.
    """
    # The message used to read "FLOOR_DIVIDE" (copy-paste slip), which made
    # debug traces attribute remainder calls to floor_divide.
    logger.debug("GEMS_KUNLUNXIN REMAINDER")
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        return rem_tt(A, B)
    elif isinstance(A, torch.Tensor):
        return rem_ts(A, B)
    elif isinstance(B, torch.Tensor):
        return rem_st(A, B)
    else:
        # Both scalar
        return torch.tensor(A % B)
def remainder_(A, B):
    """In-place remainder: the result is written back into ``A`` (``out0=A``).

    ``B`` may be a tensor or a non-tensor scalar.
    """
    logger.debug("GEMS_KUNLUNXIN REMAINDER_")
    kernel = rem_tt if isinstance(B, torch.Tensor) else rem_ts
    return kernel(A, B, out0=A)