Coverage for src/flag_gems/ops/signbit.py: 53%

30 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.utils import pointwise_dynamic 

7 

# Logger named after this module; the op entry points emit one debug record per call.
logger = logging.getLogger(name=__name__)

9 

10 

# Element-wise sign-bit test, compiled as a Triton pointwise kernel.
# ``promotion_methods=[(0, "ALWAYS_BOOL")]`` asks pointwise_dynamic to give the
# result a boolean dtype derived from argument 0 (presumably torch.bool -- confirm
# against flag_gems.utils.pointwise_dynamic).
@pointwise_dynamic(promotion_methods=[(0, "ALWAYS_BOOL")])
@triton.jit
def signbit_func(x):
    # Floating-point inputs are reinterpreted (bitcast: same bits, no value
    # conversion) as a signed integer of the same width. The IEEE sign bit then
    # becomes the integer's most significant bit, so "int < 0" is exactly
    # "sign bit set". Unlike a plain ``x < 0``, this also reports True for -0.0
    # and for NaNs whose sign bit is set.
    # The dtype checks are wrapped in tl.constexpr so the branch can be selected
    # at compile time for each input dtype.
    if tl.constexpr(x.dtype.is_fp32()):
        xi32 = x.to(tl.int32, bitcast=True)
        return xi32 < 0
    elif tl.constexpr(x.dtype.is_fp16()):
        xi16 = x.to(tl.int16, bitcast=True)
        return xi16 < 0
    elif tl.constexpr(x.dtype.is_bf16()):
        # bf16 is also 16 bits wide with the sign in the top bit, hence int16.
        xi16 = x.to(tl.int16, bitcast=True)
        return xi16 < 0
    elif tl.constexpr(x.dtype.is_fp64()):
        xi64 = x.to(tl.int64, bitcast=True)
        return xi64 < 0
    else:
        # Non-float dtypes use a plain comparison: for signed integers "sign bit
        # set" is the same as "negative".
        # NOTE(review): any floating dtype not matched above (e.g. fp8 variants,
        # if they can reach this kernel) would also land here and miss -0.0 and
        # negative NaN -- confirm which dtypes callers can pass.
        return x < 0

28 

29 

def signbit(A):
    """Return the element-wise sign-bit test of ``A``.

    Delegates to the ``signbit_func`` pointwise kernel without an explicit
    output, so the kernel produces and returns a new result (boolean-promoted
    per the kernel's ``ALWAYS_BOOL`` promotion rule).
    """
    logger.debug("GEMS SIGNBIT")
    mask = signbit_func(A)
    return mask

33 

34 

def signbit_out(A, *, out=None):
    """Element-wise sign-bit test of ``A`` with an optional destination.

    When ``out`` is supplied, the kernel writes into it (passed as the kernel's
    ``out0`` keyword) and ``out`` itself is returned. When ``out`` is None, the
    kernel's freshly produced result is returned instead.
    """
    logger.debug("GEMS SIGNBIT_OUT")
    if out is not None:
        # Caller-provided buffer: fill it in place and hand the same object back.
        signbit_func(A, out0=out)
        return out
    return signbit_func(A)