Coverage for src/flag_gems/ops/signbit.py: 53%
30 statements
coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems.utils import pointwise_dynamic
# Module-level logger named after this module; the signbit wrappers below
# emit their debug trace messages through it.
logger = logging.getLogger(name=__name__)
@pointwise_dynamic(promotion_methods=[(0, "ALWAYS_BOOL")])
@triton.jit
def signbit_func(x):
    # Elementwise signbit: True where the sign bit of `x` is set.
    #
    # Floating-point inputs are reinterpreted (bitcast, not value-converted)
    # as a signed integer of the same bit width. The float's sign bit is then
    # the integer's most significant bit, so "integer < 0" means exactly
    # "sign bit set". Unlike a plain `x < 0`, this also reports True for -0.0
    # and for NaNs whose sign bit is set.
    #
    # The dtype checks are wrapped in tl.constexpr, presumably so Triton
    # picks a single branch at compile time.
    # Result dtype follows the "ALWAYS_BOOL" promotion rule on argument 0
    # (presumably a bool tensor; confirm in pointwise_dynamic).
    if tl.constexpr(x.dtype.is_fp64()):
        return x.to(tl.int64, bitcast=True) < 0
    elif tl.constexpr(x.dtype.is_fp32()):
        return x.to(tl.int32, bitcast=True) < 0
    elif tl.constexpr(x.dtype.is_fp16()):
        return x.to(tl.int16, bitcast=True) < 0
    elif tl.constexpr(x.dtype.is_bf16()):
        # bf16 is also 16 bits wide, so it shares fp16's int16 reinterpretation.
        return x.to(tl.int16, bitcast=True) < 0
    else:
        # Every other dtype (e.g. integer types) uses an ordinary comparison.
        return x < 0
def signbit(A):
    """Return, elementwise, whether the sign bit of ``A`` is set.

    Logs the dispatch, then delegates to the ``signbit_func`` kernel wrapper.
    ``A`` is presumably a floating-point or integer tensor; the dtype handling
    lives in ``signbit_func``.
    """
    logger.debug("GEMS SIGNBIT")
    result = signbit_func(A)
    return result
def signbit_out(A, *, out=None):
    """Out-variant of ``signbit``.

    Args:
        A: Input, presumably a tensor (forwarded unchanged to
            ``signbit_func``).
        out: Optional destination. When given, results are written into it
            through ``signbit_func``'s ``out0`` keyword.

    Returns:
        ``out`` itself when a destination was supplied. Otherwise, whatever
        ``signbit_func(A)`` returns.
    """
    logger.debug("GEMS SIGNBIT_OUT")
    if out is not None:
        # Fill the caller-supplied buffer and hand that same object back.
        signbit_func(A, out0=out)
        return out
    return signbit_func(A)