Coverage for src/flag_gems/runtime/backend/_sunrise/ops/div.py: 0%
175 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import pointwise_dynamic
8from flag_gems.utils.pointwise_dynamic import CodeGenConfig
9from flag_gems.utils.triton_lang_extension import div_rn, div_rz, fmod, trunc
# Module logger: a child of the "flag_gems" logger, named after this module
# with any leading dots stripped from __name__.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Grid-size limits (one entry per grid dimension) handed to the code generator.
MAX_GRID_SIZES = (65535, 65535, 65535)
# Code-generation settings passed as ``config=`` to the pointwise_dynamic
# kernels defined below.
config = CodeGenConfig(
    max_tile_size=1024,
    max_grid_size=MAX_GRID_SIZES,
    max_num_warps_per_cta=32,
    prefer_block_pointer=True,
    prefer_1d_tile=True,
)
# Element-wise true division for two tensor operands; integer inputs are
# promoted with the "INT_TO_FLOAT" rule.
@pointwise_dynamic(promotion_methods=[(0, 1, "INT_TO_FLOAT")], config=config)
@triton.jit
def true_div_func(x, y):
    quotient = x / y
    return quotient
# Element-wise true division where the dividend is a tensor and the divisor
# is a non-tensor argument (is_tensor=[True, False]).
@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "INT_TO_FLOAT")],
    config=config,
)
@triton.jit
def true_div_func_tensor_scalar(x, y):
    quotient = x / y
    return quotient
# Element-wise true division where the dividend is a non-tensor argument and
# the divisor is a tensor (is_tensor=[False, True]).
@pointwise_dynamic(
    is_tensor=[False, True],
    promotion_methods=[(0, 1, "INT_TO_FLOAT")],
    config=config,
)
@triton.jit
def true_div_func_scalar_tensor(x, y):
    quotient = x / y
    return quotient
def true_divide(A, B):
    """Divide A by B, dispatching on which operands are tensors.

    Tensor/tensor, tensor/scalar and scalar/tensor inputs go to the matching
    true-division kernel. When neither operand is a tensor, the Python
    quotient ``A / B`` is wrapped in a new tensor.
    """
    logger.debug("GEMS TRUE_DIVIDE")
    lhs_is_tensor = isinstance(A, torch.Tensor)
    rhs_is_tensor = isinstance(B, torch.Tensor)

    if lhs_is_tensor:
        kernel = true_div_func if rhs_is_tensor else true_div_func_tensor_scalar
        return kernel(A, B)
    if rhs_is_tensor:
        return true_div_func_scalar_tensor(A, B)
    # Neither operand is a tensor.
    return torch.tensor(A / B)
def true_divide_out(A, B, out):
    """Divide A by B, forwarding ``out`` to the kernel as ``out0``.

    When neither operand is a tensor, the Python quotient is either wrapped
    in a new tensor (``out is None``) or written with ``out.fill_``.
    """
    logger.debug("GEMS TRUE_DIVIDE OUT")
    lhs_is_tensor = isinstance(A, torch.Tensor)
    rhs_is_tensor = isinstance(B, torch.Tensor)

    if lhs_is_tensor:
        kernel = true_div_func if rhs_is_tensor else true_div_func_tensor_scalar
        return kernel(A, B, out0=out)
    if rhs_is_tensor:
        return true_div_func_scalar_tensor(A, B, out0=out)
    # Neither operand is a tensor.
    if out is None:
        return torch.tensor(A / B)
    return out.fill_(A / B)
def true_divide_(A, B):
    """Divide A by B, passing A itself as the kernel's ``out0`` argument.

    The tensor/tensor kernel is used when B is a tensor, otherwise the
    tensor/scalar kernel. A is passed as the first operand without any
    type check.
    """
    logger.debug("GEMS TRUE_DIVIDE_")
    kernel = (
        true_div_func if isinstance(B, torch.Tensor) else true_div_func_tensor_scalar
    )
    return kernel(A, B, out0=A)
# Element-wise division for two tensors: div_rz followed by trunc, with
# "DEFAULT" type promotion.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")], config=config)
@triton.jit
def trunc_div_func(x, y):
    quotient = div_rz(x, y)
    return trunc(quotient)
# Tensor / non-tensor variant of the div_rz + trunc kernel
# (is_tensor=[True, False]).
@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
    config=config,
)
@triton.jit
def trunc_div_func_tensor_scalar(x, y):
    quotient = div_rz(x, y)
    return trunc(quotient)
# Non-tensor / tensor variant of the div_rz + trunc kernel
# (is_tensor=[False, True]).
@pointwise_dynamic(
    is_tensor=[False, True],
    promotion_methods=[(0, 1, "DEFAULT")],
    config=config,
)
@triton.jit
def trunc_div_func_scalar_tensor(x, y):
    quotient = div_rz(x, y)
    return trunc(quotient)
def trunc_divide(A, B):
    """Divide A by B with truncation, dispatching on which operands are tensors.

    When neither operand is a tensor, ``A / B`` is passed through ``int`` and
    then converted back to ``type(A)`` before being wrapped in a tensor.
    """
    logger.debug("GEMS TRUNC_DIVIDE")
    lhs_is_tensor = isinstance(A, torch.Tensor)
    rhs_is_tensor = isinstance(B, torch.Tensor)

    if lhs_is_tensor:
        kernel = trunc_div_func if rhs_is_tensor else trunc_div_func_tensor_scalar
        return kernel(A, B)
    if rhs_is_tensor:
        return trunc_div_func_scalar_tensor(A, B)
    # Neither operand is a tensor.
    truncated = int(A / B)
    return torch.tensor(type(A)(truncated))
def trunc_divide_(A, B):
    """Truncating division of A by B, passing A as the kernel's ``out0``.

    The tensor/tensor kernel is used when B is a tensor, otherwise the
    tensor/scalar kernel.
    """
    logger.debug("GEMS TRUNC_DIVIDE_")
    kernel = (
        trunc_div_func if isinstance(B, torch.Tensor) else trunc_div_func_tensor_scalar
    )
    return kernel(A, B, out0=A)
@triton.jit
def _int_floordiv(x, y):
    # Integer floor division with PyTorch/NumPy semantics, built from
    # Triton's `//` and `%`.
    #
    # TODO: request Triton to add an integer remainder builtin
    # The semantics of Triton floordiv differ from PyTorch/NumPy.
    # Triton floordiv equates to
    #     (x - np.fmod(x, y)) / y
    # whereas PyTorch floordiv is
    #     (x - np.remainder(x, y)) / y
    # The results show an off-by-one difference when
    #     C1) x and y have opposite signs
    # and C2) x is not a multiple of y.
    # Apart from the above, there's an erroneous case: x // 0 returns -1 in
    # Triton, whereas in PyTorch x // 0 returns -1 if x >= 0 and -2 if x < 0,
    # but this special case is coalesced into the c1 and c2 check, so no
    # extra handling is needed.
    r = x % y
    # c1: the division is inexact (non-zero remainder).
    c1 = r != 0
    # c2: the operands have opposite signs.
    c2 = (x < 0) ^ (y < 0)
    # Subtract one from the quotient only where both conditions hold.
    return tl.where(c1 & c2, x // y - 1, x // y)
# To be consistent with Python, NumPy and torch, we have to implement it in
# the following way.
# CPython
# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
# numpy
# https://github.com/numpy/numpy/blob/a4ad142aa1282a77bbb05acd706cb57c9cc29846/numpy/_core/src/npymath/npy_math_internal.h.src#L532
# torch
# https://github.com/pytorch/pytorch/blob/d6d9183456cd07ca0b361a194b98c2fb196e7c36/c10/util/generic_math.h#L23
@triton.jit
def _float_floordiv(x, y):
    # NOTE: fmod's sign is the same as the dividend
    remainder = fmod(x, y)
    imperfect = remainder != 0.0
    different_sign = (x < 0) ^ (y < 0)

    # NOTE: we have to use div_rn explicitly here
    q = div_rn(x - remainder, y)
    # Step the quotient down by one when the division is inexact and the
    # operands have opposite signs.
    q = tl.where(imperfect & different_sign, q - 1, q)

    # Snap q to a whole number: take the floor, then move up by one when the
    # fractional part exceeds 0.5.
    floor_q = tl.math.floor(q)
    c = q - floor_q > 0.5
    floor_q = tl.where(c, floor_q + 1.0, floor_q)

    # A zero quotient becomes a signed zero: -0.0 when the operand signs
    # differ, 0.0 otherwise.
    q_is_zeros = q == 0.0
    floor_q = tl.where(q_is_zeros, tl.where(different_sign, -0.0, 0.0), floor_q)

    # Where y == 0, return the plain x / y result instead of the floored
    # quotient.
    is_div_by_zero = y == 0.0
    float_division = x / y
    out = tl.where(is_div_by_zero, float_division, floor_q)
    return out
# Element-wise floor division for two tensor operands with "DEFAULT" type
# promotion. The integer path is taken only when BOTH operands have an
# integer scalar type; otherwise the floating-point path is used.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")], config=config)
@triton.jit
def floor_div_func(x, y):
    # The original tested x twice and never looked at y; check each operand.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
# Tensor / non-tensor floor division (is_tensor=[True, False]) with "DEFAULT"
# type promotion. The integer path is taken only when BOTH operands have an
# integer scalar type; otherwise the floating-point path is used.
@pointwise_dynamic(
    is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")], config=config
)
@triton.jit
def floor_div_func_tensor_scalar(x, y):
    # The original tested x twice and never looked at y; check each operand.
    # NOTE(review): assumes the non-tensor argument exposes `.type.scalar`
    # inside the kernel, as the scalar/tensor variant already does for x.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
# Non-tensor / tensor floor division (is_tensor=[False, True]) with "DEFAULT"
# type promotion. The integer path is taken only when BOTH operands have an
# integer scalar type; otherwise the floating-point path is used.
@pointwise_dynamic(
    is_tensor=[False, True], promotion_methods=[(0, 1, "DEFAULT")], config=config
)
@triton.jit
def floor_div_func_scalar_tensor(x, y):
    # The original tested x twice and never looked at y; check each operand.
    if x.type.scalar.is_int() & y.type.scalar.is_int():
        return _int_floordiv(x, y)
    else:
        return _float_floordiv(x, y)
def floor_divide(A, B):
    """Floor-divide A by B, dispatching on which operands are tensors.

    When neither operand is a tensor, the Python result ``A // B`` is wrapped
    in a new tensor.
    """
    logger.debug("GEMS FLOOR_DIVIDE")
    lhs_is_tensor = isinstance(A, torch.Tensor)
    rhs_is_tensor = isinstance(B, torch.Tensor)

    if lhs_is_tensor:
        kernel = floor_div_func if rhs_is_tensor else floor_div_func_tensor_scalar
        return kernel(A, B)
    if rhs_is_tensor:
        return floor_div_func_scalar_tensor(A, B)
    # Neither operand is a tensor.
    return torch.tensor(A // B)
def floor_divide_(A, B):
    """Floor division of A by B, passing A as the kernel's ``out0``.

    The tensor/tensor kernel is used when B is a tensor, otherwise the
    tensor/scalar kernel.
    """
    logger.debug("GEMS FLOOR_DIVIDE_")
    kernel = (
        floor_div_func if isinstance(B, torch.Tensor) else floor_div_func_tensor_scalar
    )
    return kernel(A, B, out0=A)
def div_mode(A, B, rounding_mode=None):
    """Divide A by B using the requested rounding mode.

    ``None`` selects true division, ``"trunc"`` truncating division and
    ``"floor"`` floor division.

    Raises:
        ValueError: if ``rounding_mode`` is none of the above.
    """
    if rounding_mode is None:
        return true_divide(A, B)
    if rounding_mode == "trunc":
        return trunc_divide(A, B)
    if rounding_mode == "floor":
        return floor_divide(A, B)
    raise ValueError(
        f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    )
def div_mode_(A, B, rounding_mode=None):
    """Dispatch to the underscore-suffixed division variant for the rounding mode.

    ``None`` selects ``true_divide_``, ``"trunc"`` selects ``trunc_divide_``
    and ``"floor"`` selects ``floor_divide_``.

    Raises:
        ValueError: if ``rounding_mode`` is none of the above.
    """
    if rounding_mode is None:
        return true_divide_(A, B)
    if rounding_mode == "trunc":
        return trunc_divide_(A, B)
    if rounding_mode == "floor":
        return floor_divide_(A, B)
    raise ValueError(
        f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    )
@triton.jit
def _remainder(x, y):
    # Start from Triton's `%` result and add y wherever that result is
    # non-zero and the operands have opposite signs.
    raw = x % y
    signs_differ = (x < 0) ^ (y < 0)
    needs_shift = (raw != 0) & signs_differ
    return tl.where(needs_shift, raw + y, raw)
# Element-wise remainder for two tensor operands ("DEFAULT" type promotion).
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")], config=config)
@triton.jit
def rem_tt(x, y):
    result = _remainder(x, y)
    return result
# Tensor / non-tensor variant of the remainder kernel (is_tensor=[True, False]).
@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
    config=config,
)
@triton.jit
def rem_ts(x, y):
    result = _remainder(x, y)
    return result
# Non-tensor / tensor variant of the remainder kernel (is_tensor=[False, True]).
@pointwise_dynamic(
    is_tensor=[False, True],
    promotion_methods=[(0, 1, "DEFAULT")],
    config=config,
)
@triton.jit
def rem_st(x, y):
    result = _remainder(x, y)
    return result
def remainder(A, B):
    """Compute the remainder of A by B, dispatching on which operands are tensors.

    When neither operand is a tensor, the Python result ``A % B`` is wrapped
    in a new tensor.
    """
    logger.debug("GEMS REMAINDER")
    lhs_is_tensor = isinstance(A, torch.Tensor)
    rhs_is_tensor = isinstance(B, torch.Tensor)

    if lhs_is_tensor:
        kernel = rem_tt if rhs_is_tensor else rem_ts
        return kernel(A, B)
    if rhs_is_tensor:
        return rem_st(A, B)
    # Neither operand is a tensor.
    return torch.tensor(A % B)
def remainder_(A, B):
    """Remainder of A by B, passing A as the kernel's ``out0``.

    The tensor/tensor kernel is used when B is a tensor, otherwise the
    tensor/scalar kernel.
    """
    logger.debug("GEMS REMAINDER_")
    kernel = rem_tt if isinstance(B, torch.Tensor) else rem_ts
    return kernel(A, B, out0=A)