Coverage for src/flag_gems/ops/greater.py: 74%
31 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems.utils import pointwise_dynamic
# Module-scoped logger; every op entry point below emits a debug record
# through it so dispatch to these kernels can be traced.
logger = logging.getLogger(name=__name__)
@pointwise_dynamic(promotion_methods=[(0, 1, "ALWAYS_BOOL")])
@triton.jit
def greater_func(x, y):
    """Elementwise ``x > y`` for two tensor operands.

    Inputs are upcast to float32 before comparing, except 64-bit floats and
    32/64-bit integers, which are compared in their own dtype. float32 has a
    24-bit significand, so casting those dtypes to float32 would round
    distinct values together, e.g. ``2**24 + 1 > 2**24`` or
    ``1.0 + 1e-10 > 1.0`` (fp64) would evaluate to False.
    """
    # `x.dtype` is a compile-time property in Triton, so these branches are
    # resolved during specialization and only one of them is emitted.
    if x.dtype.is_fp64():
        lhs = x
    elif x.dtype.is_int64():
        lhs = x
    elif x.dtype.is_int32():
        lhs = x
    else:
        # Narrow types round-trip exactly through float32; keep the original
        # upcast for them.
        lhs = x.to(tl.float32)
    # NOTE(review): `y` is not cast; mixing `lhs` with `y` follows Triton's
    # binary-op type promotion rules.
    return lhs > y
def greater(A, B):
    """Return the elementwise ``A > B`` comparison of two operands.

    Emits a debug log record, then dispatches to the ``greater_func``
    pointwise kernel (declared with the ``ALWAYS_BOOL`` promotion method,
    so the result is presumably a boolean tensor).
    """
    logger.debug("GEMS GREATER")
    comparison = greater_func(A, B)
    return comparison
def greater_out(A, B, *, out=None):
    """Compute ``A > B``, optionally writing into a caller-supplied buffer.

    Args:
        A, B: operands forwarded unchanged to ``greater_func``.
        out: optional destination, passed to the kernel as ``out0``.

    Returns:
        ``out`` itself when it was supplied; otherwise whatever
        ``greater_func(A, B)`` returns.
    """
    logger.debug("GEMS GREATER_OUT")
    if out is not None:
        # The kernel's return value is discarded; the caller's buffer is
        # handed back so `out=` calls return the object they passed in.
        greater_func(A, B, out0=out)
        return out
    return greater_func(A, B)
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "ALWAYS_BOOL")])
@triton.jit
def greater_func_scalar(x, y):
    """Elementwise ``x > y`` where ``x`` is a tensor and ``y`` a non-tensor.

    Inputs are upcast to float32 before comparing, except 64-bit floats and
    32/64-bit integers, which are compared in their own dtype. float32 has a
    24-bit significand, so casting those dtypes to float32 would round
    distinct values together, e.g. ``2**24 + 1 > 2**24`` would evaluate to
    False.
    """
    # `x.dtype` is a compile-time property in Triton, so these branches are
    # resolved during specialization and only one of them is emitted.
    if x.dtype.is_fp64():
        lhs = x
    elif x.dtype.is_int64():
        lhs = x
    elif x.dtype.is_int32():
        lhs = x
    else:
        # Narrow types round-trip exactly through float32; keep the original
        # upcast for them.
        lhs = x.to(tl.float32)
    # NOTE(review): the scalar `y` is not cast; its kernel-argument dtype
    # (chosen by the launcher) and Triton's promotion rules decide the
    # comparison type.
    return lhs > y
def greater_scalar(A, B):
    """Return the elementwise ``A > B`` comparison of a tensor and a scalar.

    Emits a debug log record, then dispatches to ``greater_func_scalar``,
    whose ``is_tensor=[True, False]`` declaration treats ``B`` as a
    non-tensor operand.
    """
    logger.debug("GEMS GREATER_SCALAR")
    comparison = greater_func_scalar(A, B)
    return comparison
def greater_scalar_out(A, B, *, out=None):
    """Compute tensor-vs-scalar ``A > B``, optionally into a given buffer.

    Args:
        A: tensor operand.
        B: non-tensor operand (per ``greater_func_scalar``'s ``is_tensor``).
        out: optional destination, passed to the kernel as ``out0``.

    Returns:
        ``out`` itself when it was supplied; otherwise whatever
        ``greater_func_scalar(A, B)`` returns.
    """
    logger.debug("GEMS GREATER_SCALAR_OUT")
    if out is not None:
        # The kernel's return value is discarded; the caller's buffer is
        # handed back so `out=` calls return the object they passed in.
        greater_func_scalar(A, B, out0=out)
        return out
    return greater_func_scalar(A, B)