Coverage for src/flag_gems/runtime/backend/_iluvatar/ops/div.py: 0%
166 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import pointwise_dynamic, tl_extra_shim
# TODO: Check if this logger instantiation is good
logger = logging.getLogger(__name__)

# Bind the tl_extra_shim math helpers to short module-level names in a single
# unpacking assignment; the kernels in this module refer to them by these
# names.
div_rn, div_rz, fmod, trunc = (
    tl_extra_shim.div_rn,
    tl_extra_shim.div_rz,
    tl_extra_shim.fmod,
    tl_extra_shim.trunc,
)
@pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func(x, y):
    """Return ``x / y``.

    Registered through ``pointwise_dynamic`` with the ``INT_TO_FLOAT``
    promotion method over arguments 0 and 1.
    """
    quotient = x / y
    return quotient
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func_tensor_scalar(x, y):
    """Return ``x / y`` for the ``is_tensor=[True, False]`` registration.

    Uses the ``INT_TO_FLOAT`` promotion method over arguments 0 and 1.
    """
    quotient = x / y
    return quotient
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def true_div_func_scalar_tensor(x, y):
    """Return ``x / y`` for the ``is_tensor=[False, True]`` registration.

    Uses the ``INT_TO_FLOAT`` promotion method over arguments 0 and 1.
    """
    quotient = x / y
    return quotient
def true_divide(A, B):
    """Divide ``A`` by ``B`` without rounding.

    Picks the kernel variant matching which operands are ``torch.Tensor``
    instances. When neither is a tensor, the Python quotient ``A / B`` is
    wrapped in a new tensor.
    """
    logger.debug("GEMS TRUE_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)

    if a_is_tensor and b_is_tensor:
        return true_div_func(A, B)
    if a_is_tensor:
        return true_div_func_tensor_scalar(A, B)
    if b_is_tensor:
        return true_div_func_scalar_tensor(A, B)
    # Neither operand is a tensor: compute on the host.
    return torch.tensor(A / B)
def true_divide_(A, B):
    """In-place flavour of true division: ``A`` is also passed as ``out0``.

    ``B`` may be a tensor or a non-tensor value; the kernel variant is chosen
    accordingly. Presumably ``out0=A`` makes the kernel write its result into
    ``A`` -- confirm against ``pointwise_dynamic``.
    """
    logger.debug("GEMS TRUE_DIVIDE_")
    kernel = (
        true_div_func if isinstance(B, torch.Tensor) else true_div_func_tensor_scalar
    )
    return kernel(A, B, out0=A)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func(x, y):
    """Return ``trunc(x / y)``.

    Registered through ``pointwise_dynamic`` with the ``DEFAULT`` promotion
    method over arguments 0 and 1.
    """
    # NOTE(review): the tensor/scalar variants of this kernel divide with
    # div_rz instead of the plain `/` operator; confirm whether the difference
    # here is intentional for this backend.
    quotient = x / y
    return trunc(quotient)
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func_tensor_scalar(x, y):
    """Return ``trunc(div_rz(x, y))`` for the ``is_tensor=[True, False]`` registration.

    Uses the ``DEFAULT`` promotion method over arguments 0 and 1.
    """
    quotient = div_rz(x, y)
    return trunc(quotient)
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def trunc_div_func_scalar_tensor(x, y):
    """Return ``trunc(div_rz(x, y))`` for the ``is_tensor=[False, True]`` registration.

    Uses the ``DEFAULT`` promotion method over arguments 0 and 1.
    """
    quotient = div_rz(x, y)
    return trunc(quotient)
def trunc_divide(A, B):
    """Divide ``A`` by ``B`` and truncate the quotient.

    Picks the kernel variant matching which operands are ``torch.Tensor``
    instances.

    Args:
        A: dividend, a ``torch.Tensor`` or a Python number.
        B: divisor, a ``torch.Tensor`` or a Python number.

    Returns:
        The result of the selected kernel, or, when neither operand is a
        tensor, a new tensor holding the truncated quotient.
    """
    logger.debug("GEMS TRUNC_DIVIDE")
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        return trunc_div_func(A, B)
    elif isinstance(A, torch.Tensor):
        return trunc_div_func_tensor_scalar(A, B)
    elif isinstance(B, torch.Tensor):
        return trunc_div_func_scalar_tensor(A, B)
    else:
        # Both scalar: the quotient must still be truncated, otherwise
        # e.g. trunc_divide(7, 2) would yield 3.5 instead of 3.
        return torch.trunc(torch.tensor(A / B))
def trunc_divide_(A, B):
    """In-place flavour of truncating division: ``A`` is also passed as ``out0``.

    ``B`` may be a tensor or a non-tensor value; the kernel variant is chosen
    accordingly. Presumably ``out0=A`` makes the kernel write its result into
    ``A`` -- confirm against ``pointwise_dynamic``.
    """
    logger.debug("GEMS TRUNC_DIVIDE_")
    kernel = (
        trunc_div_func if isinstance(B, torch.Tensor) else trunc_div_func_tensor_scalar
    )
    return kernel(A, B, out0=A)
@triton.jit
def _int_floordiv(x, y):
    # TODO: request Triton to add an integer remainder builtin
    #
    # Triton's floordiv does not follow the PyTorch/NumPy semantics.
    # Triton computes
    #     (x - np.fmod(x, y)) / y
    # while PyTorch computes
    #     (x - np.remainder(x, y)) / y
    # The two differ by exactly one when
    #     C1) x is not a multiple of y, and
    #     C2) x and y have opposite signs.
    # Division by zero is also off: Triton's x // 0 yields -1, whereas
    # PyTorch yields -1 for x >= 0 and -2 for x < 0. That case is already
    # covered by the C1/C2 correction below, so it needs no extra handling.
    quotient = x // y
    has_remainder = (x % y) != 0
    signs_differ = (x < 0) ^ (y < 0)
    return tl.where(has_remainder & signs_differ, quotient - 1, quotient)
# To be consistent with Python, NumPy and torch, floating-point floor division
# has to be implemented the following way.
# CPython
# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
# numpy
# https://github.com/numpy/numpy/blob/a4ad142aa1282a77bbb05acd706cb57c9cc29846/numpy/_core/src/npymath/npy_math_internal.h.src#L532
# torch
# https://github.com/pytorch/pytorch/blob/d6d9183456cd07ca0b361a194b98c2fb196e7c36/c10/util/generic_math.h#L23
@triton.jit
def _float_floordiv(x, y):
    # NOTE: fmod's result carries the sign of the dividend.
    rem = fmod(x, y)
    signs_differ = (x < 0) ^ (y < 0)
    inexact = rem != 0.0

    # NOTE: div_rn has to be used explicitly for this division.
    quot = div_rn(x - rem, y)
    quot = tl.where(inexact & signs_differ, quot - 1, quot)

    # Snap the quotient to an integer: floor it, then bump by one when the
    # fractional part exceeds one half.
    rounded = tl.math.floor(quot)
    rounded = tl.where((quot - rounded) > 0.5, rounded + 1.0, rounded)

    # A zero quotient takes its sign from the operand signs (-0.0 vs 0.0).
    signed_zero = tl.where(signs_differ, -0.0, 0.0)
    rounded = tl.where(quot == 0.0, signed_zero, rounded)

    # For a zero divisor, return the plain float division result instead.
    return tl.where(y == 0.0, x / y, rounded)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func(x, y):
    """Floor-divide ``x`` by ``y``.

    Dispatches to ``_int_floordiv`` only when BOTH operands have an integer
    scalar type; every other combination goes through ``_float_floordiv``.
    Registered through ``pointwise_dynamic`` with the ``DEFAULT`` promotion
    method over arguments 0 and 1.
    """
    # The integer path is only valid when the divisor is an integer as well,
    # so the second check must inspect y (the original tested x twice).
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func_tensor_scalar(x, y):
    """Floor-divide ``x`` by ``y`` for the ``is_tensor=[True, False]`` registration.

    Dispatches to ``_int_floordiv`` only when BOTH operands have an integer
    scalar type; every other combination goes through ``_float_floordiv``.
    Uses the ``DEFAULT`` promotion method over arguments 0 and 1.
    """
    # The integer path is only valid when the divisor is an integer as well,
    # so the second check must inspect y (the original tested x twice).
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def floor_div_func_scalar_tensor(x, y):
    """Floor-divide ``x`` by ``y`` for the ``is_tensor=[False, True]`` registration.

    Dispatches to ``_int_floordiv`` only when BOTH operands have an integer
    scalar type; every other combination goes through ``_float_floordiv``.
    Uses the ``DEFAULT`` promotion method over arguments 0 and 1.
    """
    # The integer path is only valid when the divisor is an integer as well,
    # so the second check must inspect y (the original tested x twice).
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
def floor_divide(A, B):
    """Floor-divide ``A`` by ``B``.

    Picks the kernel variant matching which operands are ``torch.Tensor``
    instances. When neither is a tensor, the Python result ``A // B`` is
    wrapped in a new tensor.
    """
    logger.debug("GEMS FLOOR_DIVIDE")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)

    if a_is_tensor and b_is_tensor:
        return floor_div_func(A, B)
    if a_is_tensor:
        return floor_div_func_tensor_scalar(A, B)
    if b_is_tensor:
        return floor_div_func_scalar_tensor(A, B)
    # Neither operand is a tensor: compute on the host.
    return torch.tensor(A // B)
def floor_divide_(A, B):
    """In-place flavour of floor division: ``A`` is also passed as ``out0``.

    ``B`` may be a tensor or a non-tensor value; the kernel variant is chosen
    accordingly. Presumably ``out0=A`` makes the kernel write its result into
    ``A`` -- confirm against ``pointwise_dynamic``.
    """
    logger.debug("GEMS FLOOR_DIVIDE_")
    kernel = (
        floor_div_func if isinstance(B, torch.Tensor) else floor_div_func_tensor_scalar
    )
    return kernel(A, B, out0=A)
def div_mode(A, B, rounding_mode=None):
    """Divide ``A`` by ``B`` using the requested rounding mode.

    ``rounding_mode`` selects the implementation: ``None`` -> ``true_divide``,
    ``"trunc"`` -> ``trunc_divide``, ``"floor"`` -> ``floor_divide``.

    Raises:
        ValueError: if ``rounding_mode`` is none of the values above.
    """
    if rounding_mode is None:
        return true_divide(A, B)
    if rounding_mode == "trunc":
        return trunc_divide(A, B)
    if rounding_mode == "floor":
        return floor_divide(A, B)
    raise ValueError(
        f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    )
def div_mode_(A, B, rounding_mode=None):
    """In-place counterpart of ``div_mode``.

    ``rounding_mode`` selects the implementation: ``None`` -> ``true_divide_``,
    ``"trunc"`` -> ``trunc_divide_``, ``"floor"`` -> ``floor_divide_``.

    Raises:
        ValueError: if ``rounding_mode`` is none of the values above.
    """
    if rounding_mode is None:
        return true_divide_(A, B)
    if rounding_mode == "trunc":
        return trunc_divide_(A, B)
    if rounding_mode == "floor":
        return floor_divide_(A, B)
    raise ValueError(
        f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    )
@triton.jit
def _remainder(x, y):
    """Return ``x % y`` shifted by ``y`` when it is non-zero and the operand signs differ."""
    rem = x % y
    # Shift only when there is a non-zero remainder and x, y have opposite signs.
    needs_shift = (rem != 0) & ((x < 0) ^ (y < 0))
    return tl.where(needs_shift, rem + y, rem)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_tt(x, y):
    """Return ``_remainder(x, y)``.

    Registered through ``pointwise_dynamic`` with the ``DEFAULT`` promotion
    method over arguments 0 and 1.
    """
    result = _remainder(x, y)
    return result
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_ts(x, y):
    """Return ``_remainder(x, y)`` for the ``is_tensor=[True, False]`` registration.

    Uses the ``DEFAULT`` promotion method over arguments 0 and 1.
    """
    result = _remainder(x, y)
    return result
@pointwise_dynamic(is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_st(x, y):
    """Return ``_remainder(x, y)`` for the ``is_tensor=[False, True]`` registration.

    Uses the ``DEFAULT`` promotion method over arguments 0 and 1.
    """
    result = _remainder(x, y)
    return result
def remainder(A, B):
    """Compute the remainder of ``A`` divided by ``B``.

    Picks the kernel variant matching which operands are ``torch.Tensor``
    instances.

    Args:
        A: dividend, a ``torch.Tensor`` or a Python number.
        B: divisor, a ``torch.Tensor`` or a Python number.

    Returns:
        The result of the selected kernel, or, when neither operand is a
        tensor, a new tensor holding the Python result ``A % B``.
    """
    # The original logged "GEMS FLOOR_DIVIDE" here (copy-paste slip), which
    # made remainder calls indistinguishable from floor_divide in debug logs.
    logger.debug("GEMS REMAINDER")
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        return rem_tt(A, B)
    elif isinstance(A, torch.Tensor):
        return rem_ts(A, B)
    elif isinstance(B, torch.Tensor):
        return rem_st(A, B)
    else:
        # Both scalar
        return torch.tensor(A % B)
def remainder_(A, B):
    """In-place flavour of ``remainder``: ``A`` is also passed as ``out0``.

    ``B`` may be a tensor or a non-tensor value; the kernel variant is chosen
    accordingly. Presumably ``out0=A`` makes the kernel write its result into
    ``A`` -- confirm against ``pointwise_dynamic``.
    """
    logger.debug("GEMS REMAINDER_")
    kernel = rem_tt if isinstance(B, torch.Tensor) else rem_ts
    return kernel(A, B, out0=A)