Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/isfinite.py: 0%

33 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import tl_extra_shim 

8 

9from ..utils.pointwise_dynamic import pointwise_dynamic 

10 

# Module logger: a child of the "flag_gems" logger named after this module
# (any leading dots stripped from __name__).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Finite-check helper taken from tl_extra_shim and bound at import time so the
# jitted fallback kernel in this module can reference it as a global.
# Presumably a backend-specific intrinsic; its exact semantics live in
# tl_extra_shim (not visible here), so TODO confirm.
_finitef = tl_extra_shim.finitef

13 

14 

# Finiteness test for float32 inputs, done on the raw IEEE-754 bits.
# A float32 value is finite exactly when its 8 exponent bits (selected by
# mask 0x7F800000) are not all ones; an all-ones exponent encodes +/-inf or NaN.
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "ALWAYS_BOOL")])
@triton.jit
def isfinite_func_f32(x):
    # Reinterpret the float bits as an unsigned integer (no value conversion).
    raw = x.to(tl.uint32, bitcast=True)
    all_ones_exp = tl.full(raw.shape, 0x7F800000, dtype=tl.uint32)
    exponent = raw & all_ones_exp
    return exponent != all_ones_exp

23 

24 

# Finiteness test for float16 inputs, done on the raw IEEE-754 bits.
# float16 keeps its 5 exponent bits under mask 0x7C00; a value is finite
# exactly when those bits are not all ones (all ones encodes +/-inf or NaN).
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "ALWAYS_BOOL")])
@triton.jit
def isfinite_func_f16(x):
    # Reinterpret the half-precision bits as uint16 (no value conversion).
    raw = x.to(tl.uint16, bitcast=True)
    all_ones_exp = tl.full(raw.shape, 0x7C00, dtype=tl.uint16)
    exponent = raw & all_ones_exp
    return exponent != all_ones_exp

32 

33 

# Generic fallback kernel for floating dtypes without a dedicated bitwise
# check: value-convert the input to float32, then apply `_finitef`
# (presumably True where the value is finite; defined in tl_extra_shim).
# NOTE: the float32 conversion is exact for float16/bfloat16. A float64 value
# far outside the float32 range (e.g. 1e300) overflows to inf, though, and
# would be reported as non-finite, so float64 inputs should not be routed here.
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "ALWAYS_BOOL")])
@triton.jit
def isfinite_func(x):
    as_f32 = x.to(tl.float32)
    return _finitef(as_f32)

38 

39 

# Finiteness test for float64 inputs, done on the raw IEEE-754 bits so no
# narrowing conversion happens. A binary64 value is finite exactly when its
# 11 exponent bits (mask 0x7FF0000000000000) are not all ones.
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "ALWAYS_BOOL")])
@triton.jit
def isfinite_func_f64(x):
    bits = x.to(tl.uint64, bitcast=True)
    exp_mask = tl.full(bits.shape, 0x7FF0000000000000, dtype=tl.uint64)
    return (bits & exp_mask) != exp_mask


def isfinite(
    A: torch.Tensor,
) -> torch.Tensor:
    """Elementwise ``torch.isfinite`` for the Kunlunxin backend.

    Args:
        A: input tensor of any dtype.

    Returns:
        A bool tensor with A's shape that is True where the element is neither
        +/-inf nor NaN. Complex elements are finite only when both the real
        and imaginary parts are finite. Non-floating, non-complex dtypes
        (integers, bool) are always finite.
    """
    logger.debug("GEMS_KUNLUNXIN ISFINITE")
    if A.is_complex():
        # Complex tensors are not `is_floating_point()`; without this branch
        # they would fall through to the "always finite" path below.
        return isfinite(A.real) & isfinite(A.imag)
    if not A.is_floating_point():
        # Integers and bools cannot hold inf/NaN. ones_like keeps A's layout,
        # matching torch's reference implementation.
        return torch.ones_like(A, dtype=torch.bool)
    if A.dtype == torch.float32:
        return isfinite_func_f32(A)
    if A.dtype == torch.float16:
        return isfinite_func_f16(A)
    if A.dtype == torch.float64:
        # Must not go through the float32-upcasting fallback: large finite
        # doubles (e.g. 1e300) would overflow to inf there.
        return isfinite_func_f64(A)
    # bfloat16 and other narrow float types upcast to float32 exactly.
    return isfinite_func(A)