Coverage for src/flag_gems/ops/remainder.py: 79%
38 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import pointwise_dynamic
# Module-level logger; each GEMS entry point below emits a debug line through it.
logger = logging.getLogger(name=__name__)
# Element-wise remainder whose nonzero results take the sign of the divisor
# (floored modulo, like Python's % on numbers).
# NOTE(review): the correction below assumes Triton's `%` follows C / fmod
# truncation semantics, where a nonzero result carries the sign of the
# dividend. Under that assumption, a nonzero remainder with operands of
# opposite sign is moved into the divisor's sign range by adding `y`.
@triton.jit
def _remainder(x, y):
    rem = x % y
    # Operands disagree in sign -> truncated and floored results differ.
    signs_differ = (x < 0) ^ (y < 0)
    # A zero remainder never needs adjusting.
    adjust = (rem != 0) & signs_differ
    return tl.where(adjust, rem + y, rem)
# Tensor / tensor variant of the remainder kernel (both operands are tensors).
# Operands 0 and 1 are listed together with the "DEFAULT" promotion method;
# presumably pointwise_dynamic derives the output dtype from both -- confirm
# against pointwise_dynamic's documentation.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_tt(x, y):
    out = _remainder(x, y)
    return out
# Tensor / scalar variant: `x` is a tensor and `y` is a non-tensor scalar
# (is_tensor=[True, False]). Promotion is configured the same way as rem_tt.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")], is_tensor=[True, False])
@triton.jit
def rem_ts(x, y):
    out = _remainder(x, y)
    return out
# Scalar / tensor variant: `x` is a non-tensor scalar and `y` is a tensor
# (is_tensor=[False, True]). Promotion is configured the same way as rem_tt.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")], is_tensor=[False, True])
@triton.jit
def rem_st(x, y):
    out = _remainder(x, y)
    return out
def remainder(A, B):
    """Compute the element-wise remainder of ``A`` divided by ``B``.

    The kernel is chosen by which operands are ``torch.Tensor`` instances:

    * tensor, tensor -> ``rem_tt``
    * tensor, scalar -> ``rem_ts``
    * scalar, tensor -> ``rem_st``
    * scalar, scalar -> Python's ``%`` operator, wrapped with
      ``torch.tensor`` (no dtype or device argument is passed).

    Args:
        A: Dividend, either a tensor or a Python scalar.
        B: Divisor, either a tensor or a Python scalar.

    Returns:
        The result of the selected ``rem_*`` kernel wrapper. In the
        scalar/scalar case, a new tensor built from ``A % B``.
    """
    logger.debug("GEMS REMAINDER")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)

    if not a_is_tensor and not b_is_tensor:
        # No tensor operand at all: Python's own modulo handles it.
        return torch.tensor(A % B)
    if not b_is_tensor:
        # Only A is a tensor.
        return rem_ts(A, B)
    if not a_is_tensor:
        # Only B is a tensor.
        return rem_st(A, B)
    return rem_tt(A, B)
def remainder_(A, B):
    """Compute the remainder of ``A`` by ``B`` and write it into ``A``.

    The selected kernel wrapper is called with ``out0=A``, so ``A`` doubles
    as the output buffer. Presumably ``A`` must be a tensor whose shape and
    dtype can hold the result -- confirm against pointwise_dynamic's
    ``out0`` handling.

    Args:
        A: Dividend; also receives the result.
        B: Divisor, either a tensor (``rem_tt``) or a scalar (``rem_ts``).

    Returns:
        Whatever the selected ``rem_*`` wrapper returns for ``out0=A``.
    """
    logger.debug("GEMS REMAINDER_")
    # A tensor divisor needs the tensor/tensor kernel; otherwise treat B as a scalar.
    kernel = rem_tt if isinstance(B, torch.Tensor) else rem_ts
    return kernel(A, B, out0=A)