Coverage for src/flag_gems/ops/div.py: 61%
185 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import pointwise_dynamic
8from flag_gems.utils.pointwise_dynamic import ComplexMode
9from flag_gems.utils.triton_lang_extension import div_rn, div_rz, fmod, trunc
11logger = logging.getLogger(__name__)
@pointwise_dynamic(
    is_tensor=[True, True, True, True],
    num_outputs=2,
    promotion_methods=[
        (0, 1, 2, 3, "INT_TO_FLOAT"),
        (0, 1, 2, 3, "INT_TO_FLOAT"),
    ],
)
@triton.jit
def div_complex_kernel(ar, ai, br, bi):
    """Divide the complex value (ar + i*ai) by (br + i*bi), component-wise.

    Uses Smith's method: the quotient is evaluated in two algebraically
    equivalent forms, each scaled by the ratio of the smaller divisor
    component to the larger one, and the form matching the dominant divisor
    component is selected per element.

    Args:
        ar: real part of the dividend.
        ai: imaginary part of the dividend.
        br: real part of the divisor.
        bi: imaginary part of the divisor.

    Returns:
        A ``(real, imag)`` pair holding the two components of the quotient.
    """
    # Candidate for |br| >= |bi|: scale = bi/br, denominator = br + bi*scale.
    # The ratio is forced to 0 where br == 0 so that 0/0 never feeds the result.
    scale_by_re = tl.where(br == 0, 0.0, bi / br)
    denom_by_re = br + bi * scale_by_re
    real_by_re = (ar + ai * scale_by_re) / denom_by_re
    imag_by_re = (ai - ar * scale_by_re) / denom_by_re

    # Candidate for |bi| > |br|: scale = br/bi, denominator = bi + br*scale.
    scale_by_im = tl.where(bi == 0, 0.0, br / bi)
    denom_by_im = bi + br * scale_by_im
    real_by_im = (ar * scale_by_im + ai) / denom_by_im
    imag_by_im = (ai * scale_by_im - ar) / denom_by_im

    # Pick, per element, the candidate built on the larger divisor component.
    re_dominates = tl.abs(br) >= tl.abs(bi)
    real = tl.where(re_dominates, real_by_re, real_by_im)
    imag = tl.where(re_dominates, imag_by_re, imag_by_im)
    return real, imag
@pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func(x, y):
    """Element-wise true division of ``x`` by ``y`` (both tensor operands)."""
    quotient = x / y
    return quotient
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func_tensor_scalar(x, y):
    """Element-wise true division of tensor ``x`` by scalar ``y``."""
    quotient = x / y
    return quotient
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func_scalar_tensor(x, y):
    """Element-wise true division of scalar ``x`` by tensor ``y``."""
    quotient = x / y
    return quotient
# Register complex support.
# The tensor/tensor kernel is given `div_complex_kernel` as its cross kernel;
# it receives the real and imaginary parts of both operands as four inputs.
# NOTE(review): the exact meaning of ComplexMode.CROSS is defined in
# flag_gems.utils.pointwise_dynamic and is not visible here -- confirm there.
true_div_func.register_complex(mode=ComplexMode.CROSS, cross_kernel=div_complex_kernel)
# The scalar variants name `true_div_func` as their fallback target and ask for
# scalars to be tensorized; presumably a complex call is then served by the
# tensor/tensor kernel -- TODO confirm against register_complex.
true_div_func_tensor_scalar.register_complex(
    mode=ComplexMode.CROSS, tensorize_scalars=True, fallback_target=true_div_func
)
true_div_func_scalar_tensor.register_complex(
    mode=ComplexMode.CROSS, tensorize_scalars=True, fallback_target=true_div_func
)
def true_divide(A, B):
    """Compute the true quotient ``A / B``.

    Dispatches on which operands are tensors: tensor/tensor, tensor/scalar
    and scalar/tensor each have a dedicated kernel. When neither operand is
    a tensor the quotient is computed in Python and wrapped in a new tensor.

    Args:
        A: dividend, a ``torch.Tensor`` or a Python scalar.
        B: divisor, a ``torch.Tensor`` or a Python scalar.

    Returns:
        The result of the selected kernel, or ``torch.tensor(A / B)`` when
        both operands are scalars.
    """
    logger.debug("GEMS TRUE_DIVIDE")
    lhs_is_tensor = isinstance(A, torch.Tensor)
    rhs_is_tensor = isinstance(B, torch.Tensor)

    if lhs_is_tensor and rhs_is_tensor:
        return true_div_func(A, B)
    if lhs_is_tensor:
        return true_div_func_tensor_scalar(A, B)
    if rhs_is_tensor:
        return true_div_func_scalar_tensor(A, B)
    # Neither operand is a tensor: plain Python division.
    return torch.tensor(A / B)
def true_divide_out(A, B, out):
    """Compute the true quotient ``A / B`` with an explicit output tensor.

    Args:
        A: dividend, a ``torch.Tensor`` or a Python scalar.
        B: divisor, a ``torch.Tensor`` or a Python scalar.
        out: tensor handed to the kernel as ``out0``; in the scalar/scalar
            case it may be ``None``, in which case a new tensor is created.

    Returns:
        The result of the selected kernel; for two scalars, either a new
        tensor holding ``A / B`` or ``out`` filled with that value.
    """
    logger.debug("GEMS TRUE_DIVIDE OUT")
    lhs_is_tensor = isinstance(A, torch.Tensor)
    rhs_is_tensor = isinstance(B, torch.Tensor)

    if lhs_is_tensor and rhs_is_tensor:
        return true_div_func(A, B, out0=out)
    if lhs_is_tensor:
        return true_div_func_tensor_scalar(A, B, out0=out)
    if rhs_is_tensor:
        return true_div_func_scalar_tensor(A, B, out0=out)

    # Neither operand is a tensor: compute in Python, then wrap or fill.
    quotient = A / B
    if out is None:
        return torch.tensor(quotient)
    return out.fill_(quotient)
def true_divide_(A, B):
    """In-place true division: write ``A / B`` back into ``A``.

    Args:
        A: dividend tensor; also passed to the kernel as ``out0``.
        B: divisor, a ``torch.Tensor`` or a Python scalar.

    Returns:
        Whatever the selected kernel returns when called with ``out0=A``.
    """
    logger.debug("GEMS TRUE_DIVIDE_")
    kernel = (
        true_div_func if isinstance(B, torch.Tensor) else true_div_func_tensor_scalar
    )
    return kernel(A, B, out0=A)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func(x, y):
    """Truncating division of tensor ``x`` by tensor ``y``.

    Applies ``trunc`` to the result of ``div_rz(x, y)``; both helpers come
    from ``flag_gems.utils.triton_lang_extension``.
    """
    quotient = div_rz(x, y)
    return trunc(quotient)
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func_tensor_scalar(x, y):
    """Truncating division of tensor ``x`` by scalar ``y``.

    The scalar is first cast to the dtype of ``x`` so both ``div_rz``
    operands share one type.
    """
    divisor = tl.cast(y, x.dtype)
    quotient = div_rz(x, divisor)
    return trunc(quotient)
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func_scalar_tensor(x, y):
    """Truncating division of scalar ``x`` by tensor ``y``.

    The scalar is first cast to the dtype of ``y`` so both ``div_rz``
    operands share one type.
    """
    dividend = tl.cast(x, y.dtype)
    quotient = div_rz(dividend, y)
    return trunc(quotient)
# Integer truncating division. Triton's `//` on integers follows C semantics
# (the quotient is truncated toward zero), so no correction step is needed.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_int_func(x, y):
    """Integer truncating division of tensor ``x`` by tensor ``y``."""
    quotient = x // y
    return quotient
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_int_func_tensor_scalar(x, y):
    """Integer truncating division of tensor ``x`` by scalar ``y``."""
    quotient = x // y
    return quotient
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_int_func_scalar_tensor(x, y):
    """Integer truncating division of scalar ``x`` by tensor ``y``."""
    quotient = x // y
    return quotient
def _is_integral_operand(value):
    """Return True if *value* is a non-floating tensor or a Python int/bool."""
    if isinstance(value, torch.Tensor):
        # Same criterion the integer fast path has always used for tensors.
        return not value.is_floating_point()
    return isinstance(value, int)


def trunc_divide(A, B):
    """Divide ``A`` by ``B`` and truncate the quotient toward zero.

    The dedicated integer kernels (built on Triton's C-style ``//``) are used
    only when *both* operands are integral. Any pairing that involves a
    floating operand goes through the ``div_rz``/``trunc`` kernels, because
    Triton's ``//`` is not defined for floating operands.

    Args:
        A: dividend, a ``torch.Tensor`` or a Python scalar.
        B: divisor, a ``torch.Tensor`` or a Python scalar.

    Returns:
        The result of the selected kernel. When neither operand is a tensor,
        a new tensor holding the truncated quotient: an integer tensor for
        two Python ints, otherwise a floating tensor.

    Raises:
        ZeroDivisionError: when both operands are Python scalars and ``B`` is
            zero (raised by Python's own division).
    """
    logger.debug("GEMS TRUNC_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    both_integral = _is_integral_operand(A) and _is_integral_operand(B)

    if not (a_is_tensor or b_is_tensor):
        # Both scalar
        if both_integral:
            # Exact integer truncation: floor-divide the magnitudes, then
            # restore the sign (Python's // would round toward -inf instead).
            magnitude = abs(A) // abs(B)
            return torch.tensor(-magnitude if (A < 0) != (B < 0) else magnitude)
        # torch.trunc also copes with inf/nan quotients, unlike math.trunc.
        return torch.trunc(torch.tensor(A / B))

    if both_integral:
        # Integer types: use dedicated int kernels (Triton // is C-style truncation)
        if a_is_tensor and b_is_tensor:
            return trunc_div_int_func(A, B)
        if a_is_tensor:
            return trunc_div_int_func_tensor_scalar(A, B)
        return trunc_div_int_func_scalar_tensor(A, B)

    # At least one floating operand: pick the float kernel by operand kinds,
    # never by dtype, so a tensor is never passed into a scalar slot.
    if a_is_tensor and b_is_tensor:
        return trunc_div_func(A, B)
    if a_is_tensor:
        return trunc_div_func_tensor_scalar(A, B)
    return trunc_div_func_scalar_tensor(A, B)
def trunc_divide_(A, B):
    """In-place truncating division: write ``trunc(A / B)`` back into ``A``.

    The kernel family is chosen from the dtype of ``A`` (integer kernels when
    ``A`` is not floating point) and the variant from whether ``B`` is a
    tensor or a scalar.

    Args:
        A: dividend tensor; also passed to the kernel as ``out0``.
        B: divisor, a ``torch.Tensor`` or a Python scalar.

    Returns:
        Whatever the selected kernel returns when called with ``out0=A``.
    """
    logger.debug("GEMS TRUNC_DIVIDE_")
    if A.is_floating_point():
        tensor_kernel, scalar_kernel = trunc_div_func, trunc_div_func_tensor_scalar
    else:
        # Integer types: dedicated int kernels (Triton // is C-style truncation).
        tensor_kernel, scalar_kernel = (
            trunc_div_int_func,
            trunc_div_int_func_tensor_scalar,
        )
    kernel = tensor_kernel if isinstance(B, torch.Tensor) else scalar_kernel
    return kernel(A, B, out0=A)
@triton.jit
def _int_floordiv(x, y):
    """Integer floor division with PyTorch/NumPy rounding (toward -inf)."""
    # TODO: request Triton to add an integer remainder builtin
    # Triton's integer floordiv does not match PyTorch/NumPy:
    #   Triton : (x - np.fmod(x, y)) / y        (truncates toward zero)
    #   PyTorch: (x - np.remainder(x, y)) / y   (rounds toward -inf)
    # The two differ by exactly one when
    #   1) the division leaves a non-zero remainder, and
    #   2) x and y have opposite signs.
    # Per the original note, x // 0 yields -1 in Triton, while PyTorch yields
    # -1 for x >= 0 and -2 for x < 0; that case is absorbed by the same two
    # conditions, so no extra handling is needed here.
    truncated = x // y
    has_remainder = (x % y) != 0
    signs_differ = (x < 0) ^ (y < 0)
    return tl.where(has_remainder & signs_differ, truncated - 1, truncated)
201# TO be consistent with python, numpy and torch, we have to implement it in the
202# following way.
203# CPython
204# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
205# numpy
206# https://github.com/numpy/numpy/blob/a4ad142aa1282a77bbb05acd706cb57c9cc29846/numpy/_core/src/npymath/npy_math_internal.h.src#L532
207# torch
208# https://github.com/pytorch/pytorch/blob/d6d9183456cd07ca0b361a194b98c2fb196e7c36/c10/util/generic_math.h#L23
@triton.jit
def _float_floordiv(x, y):
    """Floating-point floor division of ``x`` by ``y``.

    Subtracts the ``fmod`` remainder before dividing, corrects the quotient
    by one when the operands have opposite signs, snaps it to an integral
    value, fixes the sign of a zero result, and falls back to plain ``x / y``
    where ``y`` is zero.
    """
    # NOTE: fmod's sign is the same as the dividend
    remainder = fmod(x, y)
    imperfect = remainder != 0.0
    different_sign = (x < 0) ^ (y < 0)

    # NOTE: we have to use div_rn explicitly here
    q = div_rn(x - remainder, y)
    # A non-zero remainder with opposite signs means the quotient is one too
    # large for floor semantics.
    q = tl.where(imperfect & different_sign, q - 1, q)

    # Snap q to an integral value: take the floor, then step up by one when
    # the fractional part exceeds 0.5.
    floor_q = tl.math.floor(q)
    c = q - floor_q > 0.5
    floor_q = tl.where(c, floor_q + 1.0, floor_q)

    # A zero quotient takes its sign from the operands: -0.0 when the signs
    # differ, +0.0 otherwise.
    q_is_zeros = q == 0.0
    floor_q = tl.where(q_is_zeros, tl.where(different_sign, -0.0, 0.0), floor_q)

    # Where the divisor is zero, return the plain float quotient x / y.
    is_div_by_zero = y == 0.0
    float_division = x / y
    out = tl.where(is_div_by_zero, float_division, floor_q)
    return out
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func(x, y):
    """Floor division of tensor ``x`` by tensor ``y``.

    Uses the integer helper only when both operands have integer scalar
    types; every other combination uses the floating-point helper.
    """
    # The condition depends only on the operand types, not on their values
    # (presumably resolved when the kernel is compiled -- TODO confirm).
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func_tensor_scalar(x, y):
    """Floor division of tensor ``x`` by scalar ``y``.

    Uses the integer helper only when both operands have integer scalar
    types; every other combination uses the floating-point helper.
    """
    # The condition depends only on the operand types, not on their values
    # (presumably resolved when the kernel is compiled -- TODO confirm).
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func_scalar_tensor(x, y):
    """Floor division of scalar ``x`` by tensor ``y``.

    Uses the integer helper only when both operands have integer scalar
    types; every other combination uses the floating-point helper.
    """
    # The condition depends only on the operand types, not on their values
    # (presumably resolved when the kernel is compiled -- TODO confirm).
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
def floor_divide(A, B):
    """Compute the floor quotient of ``A`` and ``B``.

    Dispatches on which operands are tensors; when neither is, the quotient
    is computed with Python's ``//`` and wrapped in a new tensor.

    Args:
        A: dividend, a ``torch.Tensor`` or a Python scalar.
        B: divisor, a ``torch.Tensor`` or a Python scalar.

    Returns:
        The result of the selected kernel, or ``torch.tensor(A // B)`` when
        both operands are scalars.
    """
    logger.debug("GEMS FLOOR_DIVIDE")
    lhs_is_tensor = isinstance(A, torch.Tensor)
    rhs_is_tensor = isinstance(B, torch.Tensor)

    if lhs_is_tensor and rhs_is_tensor:
        return floor_div_func(A, B)
    if lhs_is_tensor:
        return floor_div_func_tensor_scalar(A, B)
    if rhs_is_tensor:
        return floor_div_func_scalar_tensor(A, B)
    # Neither operand is a tensor: plain Python floor division.
    return torch.tensor(A // B)
def floor_divide_(A, B):
    """In-place floor division: write the floor quotient back into ``A``.

    Args:
        A: dividend tensor; also passed to the kernel as ``out0``.
        B: divisor, a ``torch.Tensor`` or a Python scalar.

    Returns:
        Whatever the selected kernel returns when called with ``out0=A``.
    """
    logger.debug("GEMS FLOOR_DIVIDE_")
    kernel = (
        floor_div_func if isinstance(B, torch.Tensor) else floor_div_func_tensor_scalar
    )
    return kernel(A, B, out0=A)
def div_mode(A, B, rounding_mode=None):
    """Divide ``A`` by ``B`` using the requested rounding mode.

    Args:
        A: dividend, a ``torch.Tensor`` or a Python scalar.
        B: divisor, a ``torch.Tensor`` or a Python scalar.
        rounding_mode: ``None`` for true division, ``"trunc"`` for truncating
            division, or ``"floor"`` for floor division.

    Returns:
        The result of ``true_divide``, ``trunc_divide`` or ``floor_divide``.

    Raises:
        ValueError: if ``rounding_mode`` is not one of the accepted values.
    """
    logger.debug("GEMS DIV_MODE")
    if rounding_mode is None:
        return true_divide(A, B)
    if rounding_mode == "trunc":
        return trunc_divide(A, B)
    if rounding_mode == "floor":
        return floor_divide(A, B)

    msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    raise ValueError(msg)
def div_mode_(A, B, rounding_mode=None):
    """In-place division of ``A`` by ``B`` using the requested rounding mode.

    Args:
        A: dividend tensor, overwritten with the result by the in-place op.
        B: divisor, a ``torch.Tensor`` or a Python scalar.
        rounding_mode: ``None`` for true division, ``"trunc"`` for truncating
            division, or ``"floor"`` for floor division.

    Returns:
        The result of ``true_divide_``, ``trunc_divide_`` or ``floor_divide_``.

    Raises:
        ValueError: if ``rounding_mode`` is not one of the accepted values.
    """
    logger.debug("GEMS DIV_MODE_")
    if rounding_mode is None:
        return true_divide_(A, B)
    if rounding_mode == "trunc":
        return trunc_divide_(A, B)
    if rounding_mode == "floor":
        return floor_divide_(A, B)

    msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    raise ValueError(msg)