Coverage for src/flag_gems/runtime/backend/_iluvatar/ops/div.py: 0%
188 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import pointwise_dynamic, tl_extra_shim
# Module-level logger; every public entry point below emits a debug record.
logger = logging.getLogger(__name__)

# Backend math helpers re-bound as module globals; the @triton.jit kernels in
# this file refer to them by these bare names.
(div_rn, div_rz, fmod, trunc) = (
    tl_extra_shim.div_rn,
    tl_extra_shim.div_rz,
    tl_extra_shim.fmod,
    tl_extra_shim.trunc,
)
@pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func(x, y):
    # Elementwise true division, tensor / tensor. The "INT_TO_FLOAT" promotion
    # method presumably yields a floating result for integer inputs.
    quotient = x / y
    return quotient
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func_tensor_scalar(x, y):
    # Elementwise true division, tensor / scalar.
    quotient = x / y
    return quotient
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func_scalar_tensor(x, y):
    # Elementwise true division, scalar / tensor.
    quotient = x / y
    return quotient
def true_divide(A, B):
    """Divide ``A`` by ``B`` without rounding (``rounding_mode=None``).

    The kernel variant is chosen by which operands are ``torch.Tensor``:
    tensor/tensor, tensor/scalar or scalar/tensor. When neither operand is a
    tensor, the Python quotient ``A / B`` is wrapped in ``torch.tensor``.
    """
    logger.debug("GEMS_ILUVATAR TRUE_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if not (a_is_tensor or b_is_tensor):
        # Both scalar
        return torch.tensor(A / B)
    if a_is_tensor and b_is_tensor:
        return true_div_func(A, B)
    if a_is_tensor:
        return true_div_func_tensor_scalar(A, B)
    return true_div_func_scalar_tensor(A, B)
def true_divide_(A, B):
    """In-place true division: requests the result be written into ``A`` via ``out0=A``.

    ``A`` must be a tensor because it is the output. ``B`` may be a tensor
    (tensor/tensor kernel) or a scalar (tensor/scalar kernel).
    """
    logger.debug("GEMS_ILUVATAR TRUE_DIVIDE_")
    kernel = true_div_func if isinstance(B, torch.Tensor) else true_div_func_tensor_scalar
    return kernel(A, B, out0=A)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func(x, y):
    # Tensor/tensor truncating division on the float path: divide with `/`,
    # then truncate toward zero.
    # NOTE(review): the tensor/scalar and scalar/tensor kernels use div_rz
    # instead of `/`. The two can disagree when x / y rounds onto an integer;
    # confirm which behaviour is intended.
    quotient = x / y
    return trunc(quotient)
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func_tensor_scalar(x, y):
    # Tensor/scalar truncating division: cast the scalar divisor to the
    # tensor's dtype, divide with div_rz (presumably round-toward-zero), then
    # truncate.
    divisor = tl.cast(y, x.dtype)
    return trunc(div_rz(x, divisor))
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func_scalar_tensor(x, y):
    # Scalar/tensor truncating division: cast the scalar dividend to the
    # tensor's dtype, divide with div_rz (presumably round-toward-zero), then
    # truncate.
    dividend = tl.cast(x, y.dtype)
    return trunc(div_rz(dividend, y))
# Integer truncating division: Triton's `//` on integer operands rounds toward
# zero (C semantics), which is exactly rounding_mode="trunc".
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_int_func(x, y):
    # Tensor/tensor integer variant.
    truncated = x // y
    return truncated
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_int_func_tensor_scalar(x, y):
    # Tensor/scalar integer truncating division (Triton `//` truncates for ints).
    truncated = x // y
    return truncated
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_int_func_scalar_tensor(x, y):
    # Scalar/tensor integer truncating division (Triton `//` truncates for ints).
    truncated = x // y
    return truncated
def _is_integral(value):
    """Return True for a non-floating tensor or a Python int/bool scalar."""
    if isinstance(value, torch.Tensor):
        return not value.is_floating_point()
    return isinstance(value, int)


def _trunc_div_scalar(A, B):
    """Divide two Python scalars, round toward zero, and wrap the result in a 0-d tensor.

    Two ints are divided exactly with integer arithmetic, so large values are not
    rounded through float. Any float operand uses ``A / B`` followed by a tensor
    ``trunc``, which keeps inf/nan and signed zero. ``B == 0`` raises
    ``ZeroDivisionError`` (Python semantics), as the previous implementation did.
    """
    if isinstance(A, int) and isinstance(B, int):
        quotient = abs(A) // abs(B)
        return torch.tensor(-quotient if (A < 0) != (B < 0) else quotient)
    return torch.tensor(A / B).trunc()


def trunc_divide(A, B):
    """Divide ``A`` by ``B`` and round the quotient toward zero (``rounding_mode="trunc"``).

    Kernel selection:
      * both operands integral (non-floating tensors or Python ints/bools) ->
        integer kernels, whose Triton ``//`` truncates toward zero;
      * otherwise -> float kernels.
    The tensor/scalar split follows which operands are ``torch.Tensor``. Two
    Python scalars are divided on the host and returned as a 0-d tensor.
    """
    logger.debug("GEMS_ILUVATAR TRUNC_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)

    if not (a_is_tensor or b_is_tensor):
        # Both scalar: must truncate, not perform true division.
        return _trunc_div_scalar(A, B)

    # Integer kernels only when BOTH sides are integral. Deciding from one
    # operand alone sent float operands into the integer-only `//` path. It also
    # passed a tensor A into the scalar slot of the scalar/tensor kernel.
    if _is_integral(A) and _is_integral(B):
        if a_is_tensor and b_is_tensor:
            return trunc_div_int_func(A, B)
        if a_is_tensor:
            return trunc_div_int_func_tensor_scalar(A, B)
        return trunc_div_int_func_scalar_tensor(A, B)

    if a_is_tensor and b_is_tensor:
        return trunc_div_func(A, B)
    if a_is_tensor:
        return trunc_div_func_tensor_scalar(A, B)
    return trunc_div_func_scalar_tensor(A, B)
def trunc_divide_(A, B):
    """In-place truncating division: requests trunc(A / B) be written into ``A`` (``out0=A``).

    The kernel family depends only on ``A``'s dtype. A non-floating ``A`` uses
    the integer ``//`` kernels (Triton ``//`` truncates for integers); a
    floating ``A`` uses the float kernels. ``B`` may be a tensor or a scalar.
    """
    logger.debug("GEMS_ILUVATAR TRUNC_DIVIDE_")
    b_is_tensor = isinstance(B, torch.Tensor)
    if A.is_floating_point():
        kernel = trunc_div_func if b_is_tensor else trunc_div_func_tensor_scalar
    else:
        kernel = trunc_div_int_func if b_is_tensor else trunc_div_int_func_tensor_scalar
    return kernel(A, B, out0=A)
@triton.jit
def _int_floordiv(x, y):
    # Integer floor division with PyTorch/NumPy semantics (round toward -inf).
    #
    # TODO: request Triton to add an integer remainder builtin.
    # Triton's integer floordiv differs from PyTorch/NumPy. Triton computes
    #     x // y == (x - np.fmod(x, y)) / y
    # whereas PyTorch/NumPy floor division is
    #     x // y == (x - np.remainder(x, y)) / y
    # The results are off by one exactly when
    #     (c1) x is not a multiple of y, and
    #     (c2) x and y have opposite signs;
    # in that case the Triton quotient is decremented below.
    #
    # Division by zero: per the original note, Triton's x // 0 returns -1,
    # whereas PyTorch returns -1 if x >= 0 and -2 if x < 0. That case is folded
    # into the c1/c2 correction, so no extra handling is needed.
    # NOTE(review): for x < 0 this relies on x % 0 being nonzero on the target
    # hardware (so c1 holds); confirm.
    r = x % y
    c1 = r != 0
    c2 = (x < 0) ^ (y < 0)
    return tl.where(c1 & c2, x // y - 1, x // y)
# To be consistent with Python, NumPy and PyTorch, float floor division is
# implemented with the same algorithm as their reference implementations:
# CPython
# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
# numpy
# https://github.com/numpy/numpy/blob/a4ad142aa1282a77bbb05acd706cb57c9cc29846/numpy/_core/src/npymath/npy_math_internal.h.src#L532
# torch
# https://github.com/pytorch/pytorch/blob/d6d9183456cd07ca0b361a194b98c2fb196e7c36/c10/util/generic_math.h#L23
@triton.jit
def _float_floordiv(x, y):
    # Floating-point floor division following the reference algorithms above.
    # NOTE: fmod's sign is the same as the dividend
    remainder = fmod(x, y)
    imperfect = remainder != 0.0
    different_sign = (x < 0) ^ (y < 0)

    # NOTE: we have to use div_rn explicitly here
    # Divide the remainder-free dividend by y. When the remainder is nonzero and
    # the operands have opposite signs, step down by one (floor, not truncation).
    q = div_rn(x - remainder, y)
    q = tl.where(imperfect & different_sign, q - 1, q)

    # Snap q to the nearest integer: floor it, then bump up by one if the
    # fractional part exceeds 0.5 (presumably to absorb rounding error in q).
    floor_q = tl.math.floor(q)
    c = q - floor_q > 0.5
    floor_q = tl.where(c, floor_q + 1.0, floor_q)

    # A zero quotient becomes -0.0 when the operands have opposite signs,
    # otherwise +0.0.
    # NOTE(review): for x == -0.0, `x < 0` is False, so this yields +0.0 even
    # for y > 0; check against the references' signed-zero handling.
    q_is_zeros = q == 0.0
    floor_q = tl.where(q_is_zeros, tl.where(different_sign, -0.0, 0.0), floor_q)

    # Division by zero returns the plain quotient x / y (inf/nan) instead.
    is_div_by_zero = y == 0.0
    float_division = x / y
    out = tl.where(is_div_by_zero, float_division, floor_q)
    return out
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func(x, y):
    # Tensor/tensor floor division. The condition inspects the operand types
    # (x.type / y.type), so it is presumably resolved when the kernel is
    # compiled. The integer path is used when both operands are integers; any
    # floating operand selects the float path.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func_tensor_scalar(x, y):
    # Tensor/scalar floor division. Integer path only when both operand types
    # are integers; otherwise the float path.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func_scalar_tensor(x, y):
    # Scalar/tensor floor division. Integer path only when both operand types
    # are integers; otherwise the float path.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
def floor_divide(A, B):
    """Floor-divide ``A`` by ``B`` (``rounding_mode="floor"``).

    The tensor/tensor, tensor/scalar or scalar/tensor kernel is chosen by which
    operands are ``torch.Tensor``. Two Python scalars are floor-divided in
    Python and wrapped with ``torch.tensor``.
    """
    logger.debug("GEMS_ILUVATAR FLOOR_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor:
        kernel = floor_div_func if b_is_tensor else floor_div_func_tensor_scalar
        return kernel(A, B)
    if b_is_tensor:
        return floor_div_func_scalar_tensor(A, B)
    # Both scalar
    return torch.tensor(A // B)
def floor_divide_(A, B):
    """In-place floor division: requests the result be written into tensor ``A`` (``out0=A``).

    ``B`` may be a tensor (tensor/tensor kernel) or a scalar (tensor/scalar kernel).
    """
    logger.debug("GEMS_ILUVATAR FLOOR_DIVIDE_")
    kernel = floor_div_func if isinstance(B, torch.Tensor) else floor_div_func_tensor_scalar
    return kernel(A, B, out0=A)
def div_mode(A, B, rounding_mode=None):
    """Divide ``A`` by ``B`` using the requested rounding mode.

    ``None`` -> true division, ``"trunc"`` -> round toward zero,
    ``"floor"`` -> round toward negative infinity.

    Raises:
        ValueError: if ``rounding_mode`` is anything else.
    """
    if rounding_mode is None:
        return true_divide(A, B)
    if rounding_mode == "trunc":
        return trunc_divide(A, B)
    if rounding_mode == "floor":
        return floor_divide(A, B)
    raise ValueError(
        f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    )
def div_mode_(A, B, rounding_mode=None):
    """In-place counterpart of ``div_mode``; dispatches to the ``*_divide_`` variants.

    Raises:
        ValueError: if ``rounding_mode`` is not None, ``"trunc"`` or ``"floor"``.
    """
    if rounding_mode is None:
        return true_divide_(A, B)
    if rounding_mode == "trunc":
        return trunc_divide_(A, B)
    if rounding_mode == "floor":
        return floor_divide_(A, B)
    raise ValueError(
        f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    )
@triton.jit
def _remainder(x, y):
    # Python/PyTorch-style modulo. Triton's `%` keeps the dividend's sign, so a
    # nonzero remainder whose operands have opposite signs is shifted by y,
    # giving it the divisor's sign.
    rem = x % y
    needs_shift = (rem != 0) & ((x < 0) ^ (y < 0))
    return tl.where(needs_shift, rem + y, rem)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_tt(x, y):
    # Tensor/tensor remainder via the shared _remainder helper.
    result = _remainder(x, y)
    return result
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_ts(x, y):
    # Tensor/scalar remainder via the shared _remainder helper.
    result = _remainder(x, y)
    return result
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_st(x, y):
    # Scalar/tensor remainder via the shared _remainder helper.
    result = _remainder(x, y)
    return result
def remainder(A, B):
    """Elementwise remainder of ``A / B``, using the kernel matching the tensor operands.

    Two Python scalars are reduced with Python's ``%`` and wrapped in
    ``torch.tensor``.
    """
    logger.debug("GEMS_ILUVATAR REMAINDER")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor:
        kernel = rem_tt if b_is_tensor else rem_ts
        return kernel(A, B)
    if b_is_tensor:
        return rem_st(A, B)
    # Both scalar
    return torch.tensor(A % B)
def remainder_(A, B):
    """In-place remainder: requests the result be written into tensor ``A`` (``out0=A``).

    ``B`` may be a tensor (tensor/tensor kernel) or a scalar (tensor/scalar kernel).
    """
    logger.debug("GEMS_ILUVATAR REMAINDER_")
    kernel = rem_tt if isinstance(B, torch.Tensor) else rem_ts
    return kernel(A, B, out0=A)