Coverage for src/flag_gems/ops/rsub.py: 89%

18 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2 

3import triton 

4 

5from flag_gems.utils import pointwise_dynamic 

6 

# Module-scoped logger (named after this module); the rsub wrappers in this
# file emit their debug traces through it.
logger = logging.getLogger(name=__name__)

8 

9 

@pointwise_dynamic(
    is_tensor=[True, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def rsub_func(x, y, alpha):
    # Elementwise reverse subtraction: computes y - x * alpha.
    # x and y are flagged as tensor operands and alpha as a non-tensor
    # argument via `is_tensor`. The promotion spec names arguments 0 and 1
    # with "DEFAULT"; presumably the output dtype is derived from x and y
    # only (semantics live in pointwise_dynamic; confirm there).
    # Comments rather than a docstring, since this body is JIT-compiled.
    scaled = x * alpha
    return y - scaled

14 

15 

@pointwise_dynamic(is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def rsub_func_tensor_scalar(x, y, alpha):
    # Same arithmetic as the tensor/tensor kernel: y - x * alpha.
    # Here only x is flagged as a tensor; y and alpha are non-tensor
    # arguments. The promotion spec still lists argument 1 (the scalar y),
    # so presumably y participates in choosing the output dtype
    # (confirm against pointwise_dynamic).
    product = x * alpha
    return y - product

22 

23 

def rsub_tensor(A, B, *, alpha=1):
    """Return ``B - A * alpha`` computed elementwise by ``rsub_func``.

    Tensor/tensor entry point for reverse subtraction (the name suggests it
    backs the ``rsub.Tensor`` overload; confirm against the op registration).

    Args:
        A: Operand that is scaled by ``alpha`` and then subtracted.
        B: Operand that ``A * alpha`` is subtracted from.
        alpha: Keyword-only multiplier applied to ``A``; defaults to 1.

    Returns:
        Whatever ``rsub_func`` produces for these arguments (presumably a
        new tensor; see pointwise_dynamic).
    """
    logger.debug("GEMS RSUB_TENSOR")
    # NOTE(review): alpha is forwarded without any dtype validation; confirm
    # this matches torch's alpha checks for integral inputs.
    result = rsub_func(A, B, alpha)
    return result

27 

28 

def rsub_scalar(A, B, alpha=1):
    """Return ``B - A * alpha`` where ``B`` is passed as a non-tensor scalar.

    Tensor/scalar entry point for reverse subtraction, dispatched to
    ``rsub_func_tensor_scalar``. Unlike ``rsub_tensor``, ``alpha`` may be
    given positionally here.

    Args:
        A: Tensor operand that is scaled by ``alpha`` and then subtracted.
        B: Scalar operand that ``A * alpha`` is subtracted from.
        alpha: Multiplier applied to ``A``; defaults to 1.

    Returns:
        Whatever ``rsub_func_tensor_scalar`` produces (presumably a new
        tensor; see pointwise_dynamic).
    """
    logger.debug("GEMS RSUB_SCALAR")
    # NOTE(review): as in the tensor path, alpha is not validated against the
    # result dtype here; confirm that is intended.
    result = rsub_func_tensor_scalar(A, B, alpha)
    return result