Coverage for src/flag_gems/runtime/backend/_cambricon/ops/div.py: 0%
206 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import tl_extra_shim
9from ..utils.pointwise_dynamic import pointwise_dynamic
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Short module-level aliases for the math helpers exposed by flag_gems'
# `tl_extra_shim`, so the Triton kernels below can call them by bare name.
# (`div_rz` is not referenced in the visible part of this module.)
div_rn, div_rz, fmod, trunc = (
    tl_extra_shim.div_rn,
    tl_extra_shim.div_rz,
    tl_extra_shim.fmod,
    tl_extra_shim.trunc,
)
@pointwise_dynamic(
    is_tensor=[True, True, False],
    promotion_methods=[(0, 1, "INT_TO_FLOAT")],
)
@triton.jit
def true_div_func(x, y, inplace):
    # Elementwise true division of two tensor operands.
    # `inplace` is not read by the kernel body; the in-place wrappers in this
    # module pass True together with `out0`.
    quotient = x / y
    return quotient
@pointwise_dynamic(
    is_tensor=[True, False, False],
    promotion_methods=[(0, 1, "INT_TO_FLOAT")],
)
@triton.jit
def true_div_func_tensor_scalar(x, y, inplace):
    # Tensor `x` divided by scalar `y`. The scalar is first cast to the
    # tensor's element dtype, so both operands share x's dtype.
    # NOTE(review): for an integer `x` this casts a float scalar to int before
    # dividing (e.g. 2.5 -> 2) unless pointwise_dynamic already converted the
    # operands - confirm against torch.div semantics.
    divisor = y.to(x.dtype)
    return x / divisor
@pointwise_dynamic(
    is_tensor=[False, True, False],
    promotion_methods=[(0, 1, "INT_TO_FLOAT")],
)
@triton.jit
def true_div_func_scalar_tensor(x, y, inplace):
    # Scalar `x` divided by tensor `y`. The scalar is first cast to the
    # tensor's element dtype, so both operands share y's dtype.
    # NOTE(review): for an integer `y` this casts a float scalar to int before
    # dividing (e.g. 2.5 -> 2) unless pointwise_dynamic already converted the
    # operands - confirm against torch.div semantics.
    dividend = x.to(y.dtype)
    return dividend / y
def true_divide(A, B):
    """Return ``A / B`` (true division) for any mix of tensors and scalars.

    The Triton kernel is chosen from which operands are tensors. When neither
    operand is a tensor, the quotient is computed by Python and wrapped in a
    new tensor.
    """
    logger.debug("GEMS_CAMBRICON TRUE_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor and b_is_tensor:
        return true_div_func(A, B, False)
    if a_is_tensor:
        return true_div_func_tensor_scalar(A, B, False)
    if b_is_tensor:
        return true_div_func_scalar_tensor(A, B, False)
    # Neither operand is a tensor: divide on the host.
    return torch.tensor(A / B)
def true_divide_out(A, B, out):
    """Compute ``A / B`` into ``out`` (true division, out= variant).

    If at least one operand is a tensor, the matching kernel writes into
    ``out`` via ``out0``. If both are scalars, the host-side quotient fills
    ``out``; when ``out`` is None, a new tensor is returned instead.
    """
    logger.debug("GEMS_CAMBRICON TRUE_DIVIDE OUT")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor or b_is_tensor:
        if not b_is_tensor:
            kernel = true_div_func_tensor_scalar
        elif not a_is_tensor:
            kernel = true_div_func_scalar_tensor
        else:
            kernel = true_div_func
        return kernel(A, B, False, out0=out)
    # Neither operand is a tensor: divide on the host.
    if out is None:
        return torch.tensor(A / B)
    return out.fill_(A / B)
def true_divide_(A, B):
    """In-place true division: ``A /= B``, with ``B`` a tensor or a scalar.

    The result is written back into ``A`` (passed as ``out0``).
    """
    logger.debug("GEMS_CAMBRICON TRUE_DIVIDE_")
    kernel = (
        true_div_func if isinstance(B, torch.Tensor) else true_div_func_tensor_scalar
    )
    return kernel(A, B, True, out0=A)
@pointwise_dynamic(
    is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def trunc_div_func(x, y, inplace):
    # Floating-point division rounded toward zero: take the `div_rn` quotient,
    # then drop its fractional part with `trunc`.
    rounded_quotient = div_rn(x, y)
    return trunc(rounded_quotient)
@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def trunc_div_func_tensor_scalar(x, y, inplace):
    # Tensor `x` divided by scalar `y`, rounded toward zero. The scalar is cast
    # to x's dtype before the `div_rn` division.
    divisor = tl.cast(y, x.dtype)
    rounded_quotient = div_rn(x, divisor)
    return trunc(rounded_quotient)
@pointwise_dynamic(
    is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def trunc_div_func_scalar_tensor(x, y, inplace):
    # Scalar `x` divided by tensor `y`, rounded toward zero. The scalar is cast
    # to y's dtype before the `div_rn` division.
    dividend = tl.cast(x, y.dtype)
    rounded_quotient = div_rn(dividend, y)
    return trunc(rounded_quotient)
# Integer truncating division. Triton's `//` on integers is C-style, i.e. it
# already truncates toward zero, so no correction step is needed.
@pointwise_dynamic(
    is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def trunc_div_int_func(x, y, inplace):
    # tensor // tensor
    quotient = x // y
    return quotient
@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def trunc_div_int_func_tensor_scalar(x, y, inplace):
    # Integer tensor // scalar, truncating toward zero (C-style `//`).
    quotient = x // y
    return quotient
@pointwise_dynamic(
    is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def trunc_div_int_func_scalar_tensor(x, y, inplace):
    # Integer scalar // tensor, truncating toward zero (C-style `//`).
    quotient = x // y
    return quotient
def _is_integral_operand(value):
    # True for non-floating, non-complex tensors and for Python/NumPy integer
    # scalars (bool included): the only inputs the `//` int kernels accept.
    import numbers

    if isinstance(value, torch.Tensor):
        return not (value.is_floating_point() or value.is_complex())
    return isinstance(value, numbers.Integral)


def trunc_divide(A, B):
    """Divide ``A`` by ``B`` and round the quotient toward zero.

    Mirrors ``torch.div(A, B, rounding_mode="trunc")`` for any mix of tensor
    and scalar operands.

    The integer kernels (Triton ``//``, which truncates toward zero) are used
    only when *both* operands are integral. If either operand is floating, any
    integral tensor operand is first promoted to ``torch.result_type(A, B)``.
    The promotion is needed because the floating kernels cast the scalar
    operand to the tensor's dtype, which would otherwise turn a float scalar
    into an integer.

    Args:
        A: Dividend, a tensor or a Python/NumPy scalar.
        B: Divisor, a tensor or a Python/NumPy scalar.

    Returns:
        A tensor holding the truncated quotient. When both operands are
        integer scalars the result is an integer tensor; when both are
        scalars and at least one is floating it is a floating tensor.

    Raises:
        ZeroDivisionError: if both operands are Python scalars and ``B`` is 0.
    """
    logger.debug("GEMS_CAMBRICON TRUNC_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)

    if _is_integral_operand(A) and _is_integral_operand(B):
        if a_is_tensor and b_is_tensor:
            return trunc_div_int_func(A, B, False)
        if a_is_tensor:
            return trunc_div_int_func_tensor_scalar(A, B, False)
        if b_is_tensor:
            return trunc_div_int_func_scalar_tensor(A, B, False)
        # Both integer scalars: exact integer truncation. This avoids float
        # rounding for large values and keeps an integer result.
        magnitude = abs(A) // abs(B)
        return torch.tensor(magnitude if (A < 0) == (B < 0) else -magnitude)

    if not (a_is_tensor or b_is_tensor):
        # Both scalars, at least one floating: truncate the true quotient.
        return torch.tensor(A / B).trunc()

    # Mixed integer/floating operands: lift integral tensors to the promoted
    # floating dtype so the floating kernels never receive integer inputs.
    result_dtype = torch.result_type(A, B)
    if a_is_tensor and _is_integral_operand(A):
        A = A.to(result_dtype)
    if b_is_tensor and _is_integral_operand(B):
        B = B.to(result_dtype)

    if a_is_tensor and b_is_tensor:
        return trunc_div_func(A, B, False)
    if a_is_tensor:
        return trunc_div_func_tensor_scalar(A, B, False)
    return trunc_div_func_scalar_tensor(A, B, False)
def trunc_divide_(A, B):
    """In-place truncating division: ``A`` becomes ``A / B`` rounded toward zero.

    An integer ``A`` uses the ``//`` kernels (Triton's integer ``//`` already
    truncates toward zero); a floating ``A`` uses the ``div_rn`` + ``trunc``
    kernels. The result is written back into ``A``.

    NOTE(review): an integer ``A`` with a floating ``B`` is sent to the ``//``
    kernels here; confirm how that combination behaves on device.
    """
    logger.debug("GEMS_CAMBRICON TRUNC_DIVIDE_")
    b_is_tensor = isinstance(B, torch.Tensor)
    if A.is_floating_point():
        kernel = trunc_div_func if b_is_tensor else trunc_div_func_tensor_scalar
    else:
        kernel = (
            trunc_div_int_func if b_is_tensor else trunc_div_int_func_tensor_scalar
        )
    return kernel(A, B, True, out0=A)
@triton.jit
def _int_floordiv(x, y):
    # Integer floor division with PyTorch/NumPy semantics (round toward -inf),
    # built on Triton's integer `//`, which truncates toward zero.
    #
    # TODO: request Triton to add an integer remainder builtin
    # The semantic of Triton floordiv differs from Pytorch/Numpy.
    # Triton floordiv equates to
    #     (x - np.fmod(x, y)) / y
    # whereas Pytorch floordiv is
    #     (x - np.remainder(x, y)) / y
    # The results differ by one exactly when
    #   - x and y have opposite signs (c2 below), and
    #   - x is not a multiple of y (c1 below),
    # so subtracting c = c1 & c2 corrects the truncated quotient.
    # Apart from the above, there's an erroneous case: x // 0 returns -1,
    # whereas in Pytorch x // 0 returns -1 if x >= 0 and -2 if x < 0.
    # c3 subtracts one more for x < 0 and y == 0 to reach that -2.
    # NOTE(review): this yields -2 only if c (= c1 & c2) is 0 when y == 0,
    # i.e. x % 0 evaluates to 0 on this backend - confirm on device.
    r = x % y
    c1 = r != 0
    c2 = (x < 0) ^ (y < 0)
    c3 = (x < 0) & (y == 0)
    c = c1 & c2
    # int16 operands are widened to int32 for `//` and narrowed back afterwards;
    # presumably a workaround for int16 division on this backend - TODO confirm.
    if x.dtype == tl.int16 and y.dtype == tl.int16:
        return (x.to(tl.int32) // y.to(tl.int32)).cast(tl.int16) - c - c3
    return x // y - c - c3
# To be consistent with Python, NumPy and PyTorch, floating-point floor division
# is implemented with the same algorithm they use:
# CPython
# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
# numpy
# https://github.com/numpy/numpy/blob/a4ad142aa1282a77bbb05acd706cb57c9cc29846/numpy/_core/src/npymath/npy_math_internal.h.src#L532
# torch
# https://github.com/pytorch/pytorch/blob/d6d9183456cd07ca0b361a194b98c2fb196e7c36/c10/util/generic_math.h#L23
@triton.jit
def _float_floordiv(x, y):
    # Floating-point floor division x // y (quotient rounded toward -inf).
    #
    # NOTE: fmod's sign is the same as the dividend
    remainder = fmod(x, y)
    imperfect = remainder != 0.0
    different_sign = (x < 0) ^ (y < 0)

    # x - remainder is mathematically an exact multiple of y, so q is the
    # truncated quotient up to rounding error in the division.
    # NOTE: we have to use div_rn explicitly here
    q = div_rn(x - remainder, y)
    # The truncated quotient rounds toward zero. With a non-zero remainder and
    # operands of opposite sign, step down by one to round toward -inf.
    q = tl.where(imperfect & different_sign, q - 1, q)

    # Snap q to the nearest integer (an exact .5 goes down) to absorb rounding
    # error from the division above.
    floor_q = tl.math.floor(q)
    c = q - floor_q > 0.5
    floor_q = tl.where(c, floor_q + 1.0, floor_q)

    # A zero quotient gets an explicit sign: -0.0 when the operand signs
    # differ, +0.0 otherwise.
    q_is_zeros = q == 0.0
    floor_q = tl.where(q_is_zeros, tl.where(different_sign, -0.0, 0.0), floor_q)

    # Division by zero falls back to plain true division (presumably +/-inf or
    # nan, as for IEEE floats).
    is_div_by_zero = y == 0.0
    float_division = x / y
    out = tl.where(is_div_by_zero, float_division, floor_q)
    return out
@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func(x, y, inplace):
    # Floor division of two tensors (quotient rounded toward -inf).
    # The branch depends only on the operands' scalar types: if both are
    # integers, use the integer-correction path; otherwise use the float path.
    # `inplace` is not read here; the in-place wrapper passes True with `out0`.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def floor_div_func_tensor_scalar(x, y, inplace):
    # Floor division of tensor `x` by scalar `y` (quotient rounded toward -inf).
    # Integer x and integer y use the integer-correction path; any other
    # combination uses the float path.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
@pointwise_dynamic(
    is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def floor_div_func_scalar_tensor(x, y, inplace):
    # Floor division of scalar `x` by tensor `y` (quotient rounded toward -inf).
    # Integer x and integer y use the integer-correction path; any other
    # combination uses the float path.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
def floor_divide(A, B):
    """Return ``A // B`` (floor division) for any mix of tensors and scalars.

    The Triton kernel is chosen from which operands are tensors. When neither
    operand is a tensor, Python's own ``//`` (which already floors) is used
    and the result is wrapped in a new tensor.
    """
    logger.debug("GEMS_CAMBRICON FLOOR_DIVIDE")
    if not isinstance(A, torch.Tensor):
        if isinstance(B, torch.Tensor):
            return floor_div_func_scalar_tensor(A, B, False)
        # Neither operand is a tensor: floor-divide on the host.
        return torch.tensor(A // B)
    if isinstance(B, torch.Tensor):
        return floor_div_func(A, B, False)
    return floor_div_func_tensor_scalar(A, B, False)
def floor_divide_(A, B):
    """In-place floor division: ``A //= B``, with ``B`` a tensor or a scalar.

    The result is written back into ``A`` (passed as ``out0``).
    """
    logger.debug("GEMS_CAMBRICON FLOOR_DIVIDE_")
    if not isinstance(B, torch.Tensor):
        return floor_div_func_tensor_scalar(A, B, True, out0=A)
    return floor_div_func(A, B, True, out0=A)
def div_mode(A, B, rounding_mode=None):
    """Divide ``A`` by ``B`` using the requested rounding mode.

    Args:
        A: Dividend (tensor or scalar).
        B: Divisor (tensor or scalar).
        rounding_mode: None for true division, "trunc" to round toward zero,
            or "floor" to round toward negative infinity.

    Raises:
        ValueError: if ``rounding_mode`` is not one of the values above.
    """
    logger.debug("GEMS_CAMBRICON DIV_MODE")
    if rounding_mode is None:
        return true_divide(A, B)
    if rounding_mode == "trunc":
        return trunc_divide(A, B)
    if rounding_mode == "floor":
        return floor_divide(A, B)
    raise ValueError(
        f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    )
def div_mode_(A, B, rounding_mode=None):
    """In-place division of ``A`` by ``B`` using the requested rounding mode.

    Args:
        A: Tensor that receives the result.
        B: Divisor (tensor or scalar).
        rounding_mode: None for true division, "trunc" to round toward zero,
            or "floor" to round toward negative infinity.

    Raises:
        ValueError: if ``rounding_mode`` is not one of the values above.
    """
    logger.debug("GEMS_CAMBRICON DIV_MODE_")
    if rounding_mode is None:
        return true_divide_(A, B)
    if rounding_mode == "trunc":
        return trunc_divide_(A, B)
    if rounding_mode == "floor":
        return floor_divide_(A, B)
    raise ValueError(
        f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    )
@triton.jit
def _remainder(x, y):
    # Python/PyTorch-style modulo, whose result takes the divisor's sign.
    # Assumes Triton's `%` follows the dividend's sign (C-style). In that case
    # a non-zero remainder of operands with opposite signs must be shifted by y.
    rem = x % y
    wrong_sign = (rem != 0) & ((x < 0) ^ (y < 0))
    return tl.where(wrong_sign, rem + y, rem)
@pointwise_dynamic(
    is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def rem_tt(x, y, inplace):
    # tensor % tensor with the divisor-sign convention of `_remainder`.
    result = _remainder(x, y)
    return result
@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def rem_ts(x, y, inplace):
    # tensor % scalar with the divisor-sign convention of `_remainder`.
    result = _remainder(x, y)
    return result
@pointwise_dynamic(
    is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def rem_st(x, y, inplace):
    # scalar % tensor with the divisor-sign convention of `_remainder`.
    result = _remainder(x, y)
    return result
def remainder(A, B):
    """Return ``A % B`` (result takes the divisor's sign) for tensors/scalars.

    The kernel is looked up from which operands are tensors. When neither
    operand is a tensor, Python's ``%`` is used and the result is wrapped in a
    new tensor.
    """
    logger.debug("GEMS_CAMBRICON REMAINDER")
    kernel = {
        (True, True): rem_tt,
        (True, False): rem_ts,
        (False, True): rem_st,
    }.get((isinstance(A, torch.Tensor), isinstance(B, torch.Tensor)))
    if kernel is None:
        # Neither operand is a tensor: compute on the host.
        return torch.tensor(A % B)
    return kernel(A, B, False)
def remainder_(A, B):
    """In-place remainder: ``A %= B``, with ``B`` a tensor or a scalar.

    The result is written back into ``A`` (passed as ``out0``).
    """
    logger.debug("GEMS_CAMBRICON REMAINDER_")
    kernel = rem_tt if isinstance(B, torch.Tensor) else rem_ts
    return kernel(A, B, True, out0=A)