Coverage for src/flag_gems/ops/isneginf.py: 84%

19 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.utils import pointwise_dynamic, tl_extra_shim 

7 

8logger = logging.getLogger(__name__) 

9 

10 

# Elementwise negative-infinity predicate.
# Promotion method "ALWAYS_BOOL" is declared for argument 0.
@pointwise_dynamic(promotion_methods=[(0, "ALWAYS_BOOL")])
@triton.jit
def isneginf_func(x):
    # The infinity check and the sign check both run on a float32 copy of x.
    # NOTE(review): a float64 input whose magnitude exceeds the float32 range
    # would presumably become infinite in this cast - confirm whether float64
    # inputs are expected here.
    as_f32 = x.to(tl.float32)
    is_infinite = tl_extra_shim.isinf(as_f32)
    is_negative = as_f32 < 0
    # An element qualifies only when it is infinite AND below zero.
    return is_infinite & is_negative

16 

17 

def isneginf(A):
    """Log the call at debug level and return ``isneginf_func(A)``.

    Args:
        A: Input forwarded unchanged to ``isneginf_func``.

    Returns:
        Whatever ``isneginf_func`` returns for ``A``.
    """
    logger.debug("GEMS ISNEGINF")
    result = isneginf_func(A)
    return result

21 

22 

def isneginf_out(A, *, out=None):
    """Out-variant of ``isneginf``: optionally write into a caller-supplied ``out``.

    Args:
        A: Input forwarded unchanged to ``isneginf_func``.
        out: Keyword-only destination. When given, it is passed to
            ``isneginf_func`` as ``out0`` and then returned.

    Returns:
        ``out`` when one was supplied; otherwise the value returned by
        ``isneginf_func(A)``.
    """
    logger.debug("GEMS ISNEGINF_OUT")
    if out is not None:
        # The kernel call's own return value is ignored; the caller gets
        # back the same object it passed in.
        isneginf_func(A, out0=out)
        return out
    # No destination supplied: fall back to the allocating path.
    return isneginf_func(A)