Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/isfinite.py: 0%
33 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import tl_extra_shim
9from ..utils.pointwise_dynamic import pointwise_dynamic
# Child of the package-wide "flag_gems" logger, named after this module
# (leading dots, if any, are stripped from __name__ before use).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Finiteness predicate from the Triton extra-library shim. In this file it is
# only applied to values already cast to tl.float32 (see isfinite_func).
# NOTE(review): presumably a libdevice-style `finitef`; confirm in tl_extra_shim.
_finitef = tl_extra_shim.finitef
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "ALWAYS_BOOL")])
@triton.jit
def isfinite_func_f32(x):
    # IEEE-754 binary32: 1 sign bit, 8 exponent bits (mask 0x7F800000),
    # 23 mantissa bits. inf and NaN are exactly the encodings whose exponent
    # field is all ones, so a value is finite iff the masked exponent differs
    # from the mask itself.
    raw = x.to(tl.uint32, bitcast=True)
    all_ones_exp = tl.full(raw.shape, 0x7F800000, dtype=tl.uint32)
    exponent_field = raw & all_ones_exp
    return exponent_field != all_ones_exp
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "ALWAYS_BOOL")])
@triton.jit
def isfinite_func_f16(x):
    # IEEE-754 binary16: 1 sign bit, 5 exponent bits (mask 0x7C00),
    # 10 mantissa bits. An all-ones exponent encodes inf or NaN; every other
    # exponent value is a finite number (normal, subnormal or zero).
    raw = x.to(tl.uint16, bitcast=True)
    all_ones_exp = tl.full(raw.shape, 0x7C00, dtype=tl.uint16)
    exponent_field = raw & all_ones_exp
    return exponent_field != all_ones_exp
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "ALWAYS_BOOL")])
@triton.jit
def isfinite_func(x):
    # Generic fallback: widen/narrow the input to float32 and apply the shim's
    # float32 finiteness predicate. Only correct for dtypes whose finite range
    # fits inside float32 (e.g. bfloat16); wider inputs such as float64 can
    # overflow to inf during the cast and be misreported as non-finite.
    as_f32 = x.to(tl.float32)
    return _finitef(as_f32)
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "ALWAYS_BOOL")])
@triton.jit
def isfinite_func_f64(x):
    # IEEE-754 binary64: 11 exponent bits, mask 0x7FF0000000000000. float64
    # needs its own bitwise kernel: casting to float32 (as the generic
    # fallback does) turns large finite values such as 1e300 into inf.
    bits = x.to(tl.uint64, bitcast=True)
    exp_mask = tl.full(bits.shape, 0x7FF0000000000000, dtype=tl.uint64)
    return (bits & exp_mask) != exp_mask


def _isfinite_real(A: torch.Tensor) -> torch.Tensor:
    # Dispatch a non-complex tensor to the kernel matching its dtype.
    if not A.is_floating_point():
        # Integer and bool values can never be inf or NaN.
        return torch.full(A.shape, True, dtype=torch.bool, device=A.device)
    if A.dtype == torch.float32:
        return isfinite_func_f32(A)
    if A.dtype == torch.float16:
        return isfinite_func_f16(A)
    if A.dtype == torch.float64:
        return isfinite_func_f64(A)
    # bfloat16 (and other float formats whose finite range fits within
    # float32): the float32 cast in the generic kernel cannot create a
    # spurious inf, so the fallback is exact.
    return isfinite_func(A)


def isfinite(
    A: torch.Tensor,
) -> torch.Tensor:
    """Return a bool tensor that is True where ``A`` is finite (neither inf nor NaN).

    Integer and bool tensors are finite everywhere. A complex element is finite
    only when both its real and imaginary parts are finite, matching
    ``torch.isfinite``.

    Args:
        A: Input tensor of any dtype.

    Returns:
        A ``torch.bool`` tensor with the same shape as ``A``.
    """
    logger.debug("GEMS_KUNLUNXIN ISFINITE")
    if A.is_complex():
        # is_floating_point() is False for complex dtypes, so without this
        # branch complex inf/NaN would be reported as finite.
        return _isfinite_real(A.real) & _isfinite_real(A.imag)
    return _isfinite_real(A)