Coverage for src/flag_gems/ops/remainder.py: 79%
38 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8from flag_gems.utils import pointwise_dynamic
# Module-level logger named after this module, so messages from the ops in
# this file can be filtered through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
@triton.jit
def _remainder(x, y):
    # Element-wise floored modulo: a non-zero result takes the sign of the
    # divisor ``y`` (the convention of Python's ``%`` and torch.remainder).
    #
    # Triton's ``%`` yields a remainder with the sign of the dividend ``x``.
    # So whenever that remainder is non-zero and ``x`` and ``y`` have
    # opposite signs, it sits on the wrong side of zero. Shifting it by one
    # divisor fixes that. Testing the sign of ``x`` is equivalent to testing
    # the sign of ``trunc_rem`` here, because a non-zero ``trunc_rem``
    # shares the sign of ``x``.
    # ``trunc_rem + y`` cannot overflow: the two have opposite signs and
    # |trunc_rem| < |y|.
    trunc_rem = x % y
    signs_differ = (x < 0) ^ (y < 0)
    needs_shift = (trunc_rem != 0) & signs_differ
    return tl.where(needs_shift, trunc_rem + y, trunc_rem)
# Tensor % tensor kernel. Both operands are tensors, and the output dtype is
# chosen by "DEFAULT" promotion over arguments 0 and 1 (as requested through
# pointwise_dynamic's promotion_methods).
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_tt(x, y):
    out = _remainder(x, y)
    return out
# Tensor % scalar kernel: ``x`` is a tensor and ``y`` a non-tensor scalar
# (is_tensor=[True, False]). "DEFAULT" promotion is applied over both
# arguments.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")], is_tensor=[True, False])
@triton.jit
def rem_ts(x, y):
    out = _remainder(x, y)
    return out
# Scalar % tensor kernel: ``x`` is a non-tensor scalar and ``y`` a tensor
# (is_tensor=[False, True]). "DEFAULT" promotion is applied over both
# arguments.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")], is_tensor=[False, True])
@triton.jit
def rem_st(x, y):
    out = _remainder(x, y)
    return out
def remainder(A, B):
    """Compute the element-wise floored remainder of ``A`` divided by ``B``.

    Picks the tensor/tensor, tensor/scalar or scalar/tensor kernel according
    to which operands are ``torch.Tensor`` instances.

    Args:
        A: Dividend. Either a tensor or a non-tensor scalar (presumably a
            Python number; verify against callers).
        B: Divisor. Either a tensor or a non-tensor scalar.

    Returns:
        The result of the selected kernel. When neither operand is a tensor,
        ``A % B`` is evaluated in Python and wrapped with ``torch.tensor``.
        No device is passed, so the result is created on torch's default
        device.
    """
    logger.debug("GEMS REMAINDER")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor:
        return rem_tt(A, B) if b_is_tensor else rem_ts(A, B)
    if b_is_tensor:
        return rem_st(A, B)
    # Neither operand is a tensor; Python's ``%`` is already a floored modulo.
    return torch.tensor(A % B)
def remainder_(A, B):
    """In-place floored remainder: overwrite ``A`` with ``remainder(A, B)``.

    The result is written back into ``A`` via ``out0=A``. As with PyTorch's
    in-place ops, ``A`` must keep its dtype and shape. The promoted result
    dtype must therefore be castable to ``A.dtype``, and a tensor ``B`` must
    broadcast to exactly ``A.shape``.

    Args:
        A: Tensor that is modified in place.
        B: Divisor. Either a tensor or a non-tensor scalar.

    Returns:
        ``A``, as returned by the underlying kernel call with ``out0=A``.

    Raises:
        RuntimeError: If the promoted result dtype cannot be cast to
            ``A.dtype``, or if ``B`` does not broadcast to ``A.shape``
            (including shapes that are not broadcast-compatible at all).
    """
    logger.debug("GEMS REMAINDER_")
    # Writing a promoted result (e.g. float) into a narrower ``A`` (e.g.
    # int) would silently truncate; PyTorch rejects this for in-place ops.
    result_dtype = torch.result_type(A, B)
    if not torch.can_cast(result_dtype, A.dtype):
        raise RuntimeError(
            f"result type {result_dtype} can't be cast to the desired "
            f"output type {A.dtype}"
        )
    if isinstance(B, torch.Tensor):
        # The output buffer is ``A`` itself, so broadcasting must not grow it.
        broadcast_shape = torch.broadcast_shapes(A.shape, B.shape)
        if broadcast_shape != A.shape:
            raise RuntimeError(
                f"output with shape {list(A.shape)} doesn't match the "
                f"broadcast shape {list(broadcast_shape)}"
            )
        return rem_tt(A, B, out0=A)
    return rem_ts(A, B, out0=A)