Coverage for src/flag_gems/ops/remainder.py: 79%

38 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.utils import pointwise_dynamic 

9 

10logger = logging.getLogger(__name__) 

11 

12 

@triton.jit
def _remainder(x, y):
    """Floor-style remainder whose non-zero result follows the sign of ``y``.

    ``x % y`` is treated as a truncated remainder (sign of the dividend ``x``).
    When that value is non-zero and ``x`` and ``y`` have opposite signs, one
    extra ``y`` is added so the result lands on the divisor's side of zero.
    Otherwise the truncated value is returned unchanged.
    """
    trunc_rem = x % y
    # Opposite operand signs are the only case where truncation and flooring
    # disagree, and only when the truncated remainder is non-zero.
    signs_differ = (x < 0) ^ (y < 0)
    needs_wrap = (trunc_rem != 0) & signs_differ
    return tl.where(needs_wrap, trunc_rem + y, trunc_rem)

19 

20 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rem_tt(x, y):
    """Pointwise remainder where both ``x`` and ``y`` are tensors.

    The output dtype comes from DEFAULT promotion of operands 0 and 1.
    """
    result = _remainder(x, y)
    return result

25 

26 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")], is_tensor=[True, False])
@triton.jit
def rem_ts(x, y):
    """Pointwise remainder of tensor ``x`` by non-tensor (scalar) ``y``.

    The output dtype comes from DEFAULT promotion of operands 0 and 1.
    """
    result = _remainder(x, y)
    return result

31 

32 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")], is_tensor=[False, True])
@triton.jit
def rem_st(x, y):
    """Pointwise remainder of non-tensor (scalar) ``x`` by tensor ``y``.

    The output dtype comes from DEFAULT promotion of operands 0 and 1.
    """
    result = _remainder(x, y)
    return result

37 

38 

def remainder(A, B):
    """Compute the elementwise remainder of ``A`` divided by ``B``.

    The Triton kernel is chosen according to which operands are tensors:

    * tensor / tensor -> ``rem_tt``
    * tensor / scalar -> ``rem_ts``
    * scalar / tensor -> ``rem_st``

    When neither operand is a tensor, Python's ``%`` is applied and the value
    is wrapped with ``torch.tensor`` (a 0-dim tensor on the default device).

    Args:
        A: Dividend, a tensor or a Python scalar.
        B: Divisor, a tensor or a Python scalar.

    Returns:
        A tensor holding the remainder.
    """
    logger.debug("GEMS REMAINDER")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if a_is_tensor:
        return rem_tt(A, B) if b_is_tensor else rem_ts(A, B)
    if b_is_tensor:
        return rem_st(A, B)
    # Neither side is a tensor, so no kernel launch is needed.
    return torch.tensor(A % B)

50 

51 

def remainder_(A, B):
    """Overwrite ``A`` in place with the remainder of ``A`` divided by ``B``.

    Args:
        A: Destination tensor, which is also the dividend.
        B: Divisor, a tensor or a Python scalar.

    Returns:
        Whatever the kernel returns when launched with ``out0=A``
        (presumably ``A`` itself; depends on ``pointwise_dynamic``).

    Raises:
        RuntimeError: If the dtype that ``A`` and ``B`` promote to cannot be
            cast back to ``A.dtype`` (e.g. an integer ``A`` with a float
            ``B``). This matches PyTorch's rule for in-place ops.
    """
    logger.debug("GEMS REMAINDER_")
    # The kernels compute in the DEFAULT-promoted dtype but write into A.
    # This op replaces the native kernel that would normally enforce the
    # in-place casting rule, so the check is repeated here.
    result_dtype = torch.result_type(A, B)
    if not torch.can_cast(result_dtype, A.dtype):
        raise RuntimeError(
            f"result type {result_dtype} can't be cast to the desired "
            f"output type {A.dtype}"
        )
    if isinstance(B, torch.Tensor):
        return rem_tt(A, B, out0=A)
    return rem_ts(A, B, out0=A)