Coverage for src/flag_gems/ops/rsub.py: 89%
18 statements
Report generated by coverage.py v7.6.9 at 2026-05-26 06:59 +0800
1import logging
3import triton
5from flag_gems.utils import pointwise_dynamic
# Module-level logger; the rsub entry points below emit a debug record on
# every call so the dispatched GEMS variant can be traced.
logger = logging.getLogger(name=__name__)
# Element-wise body for reverse subtraction with two tensor operands:
# computes ``y - x * alpha``.
#   x     -- tensor operand that is scaled and subtracted
#   y     -- tensor operand that is subtracted from (the minuend)
#   alpha -- non-tensor multiplier applied to ``x`` (is_tensor[2] is False)
# promotion_methods presumably derives the result dtype from arguments 0 and 1
# using the "DEFAULT" rule; the exact semantics live in pointwise_dynamic.
@pointwise_dynamic(
    is_tensor=[True, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def rsub_func(x, y, alpha):
    # Scale the subtrahend first, then take it away from ``y``.
    scaled_x = x * alpha
    return y - scaled_x
# Element-wise body for reverse subtraction where only the first operand is a
# tensor: computes ``y - x * alpha``.
#   x     -- tensor operand that is scaled and subtracted
#   y     -- non-tensor value subtracted from (is_tensor[1] is False)
#   alpha -- non-tensor multiplier applied to ``x`` (is_tensor[2] is False)
# promotion_methods presumably derives the result dtype from arguments 0 and 1
# (tensor and scalar) using the "DEFAULT" rule; see pointwise_dynamic.
@pointwise_dynamic(
    is_tensor=[True, False, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def rsub_func_tensor_scalar(x, y, alpha):
    # Same arithmetic as the tensor/tensor kernel; only the operand kinds differ.
    scaled_x = x * alpha
    return y - scaled_x
def rsub_tensor(A, B, *, alpha=1):
    """Reverse-subtract two tensors element-wise: ``B - A * alpha``.

    Args:
        A: Tensor operand that is multiplied by ``alpha`` and subtracted.
        B: Tensor operand that ``A * alpha`` is subtracted from.
        alpha: Keyword-only multiplier applied to ``A`` (default ``1``).

    Returns:
        Whatever ``rsub_func`` produces for these operands (presumably a new
        tensor with the promoted dtype).

    NOTE(review): no validation of bool inputs or of a floating ``alpha`` with
    integral tensors happens here; confirm whether a caller handles it.
    """
    logger.debug("GEMS RSUB_TENSOR")
    result = rsub_func(A, B, alpha)
    return result
def rsub_scalar(A, B, alpha=1):
    """Reverse-subtract a tensor from a scalar element-wise: ``B - A * alpha``.

    Args:
        A: Tensor operand that is multiplied by ``alpha`` and subtracted.
        B: Non-tensor value that ``A * alpha`` is subtracted from.
        alpha: Multiplier applied to ``A`` (default ``1``).

    Returns:
        Whatever ``rsub_func_tensor_scalar`` produces for these operands
        (presumably a new tensor with the promoted dtype).

    NOTE(review): no validation of bool inputs or of a floating ``alpha`` with
    integral tensors happens here; confirm whether a caller handles it.
    """
    logger.debug("GEMS RSUB_SCALAR")
    result = rsub_func_tensor_scalar(A, B, alpha)
    return result