Coverage for src/flag_gems/runtime/backend/_cambricon/ops/nan_to_num.py: 0%
46 statements
coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6from triton.language.extra.mlu.libdevice import isnan as _isnan
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry, libtuner
11from ..utils import TOTAL_CORE_NUM
13logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 65536}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 131072}, num_stages=1, num_warps=1),
    ],
    key=["n_elements"],
)
@triton.jit
def nan_to_num_kernel(
    X_ptr,
    OUT_ptr,
    nan_val,
    posinf_val,
    neginf_val,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Elementwise NaN / +inf / -inf replacement over a flat buffer of
    # n_elements values, writing the result to OUT_ptr.
    #
    # Program p starts at chunk p * BLOCK_SIZE and then advances by
    # num_programs * BLOCK_SIZE, so a launch with fewer programs than chunks
    # still visits every element. The starting offset is widened to int64
    # so element offsets are computed in 64-bit arithmetic.
    stride = tl.num_programs(0) * BLOCK_SIZE
    first_chunk = (tl.program_id(0) * BLOCK_SIZE).to(tl.int64)
    for chunk_start in range(first_chunk, n_elements, stride):
        idx = chunk_start + tl.arange(0, BLOCK_SIZE)
        in_bounds = idx < n_elements
        vals = tl.load(X_ptr + idx, mask=in_bounds)

        # All three masks are derived from the loaded values, so they are
        # mutually exclusive (NaN compares unequal to both infinities).
        is_nan = _isnan(vals)
        is_pos_inf = vals == float("inf")
        is_neg_inf = vals == float("-inf")

        cleaned = tl.where(is_nan, nan_val, vals)
        cleaned = tl.where(is_pos_inf, posinf_val, cleaned)
        cleaned = tl.where(is_neg_inf, neginf_val, cleaned)
        # tl.store converts the result to OUT_ptr's element type.
        tl.store(OUT_ptr + idx, cleaned, mask=in_bounds)
def nan_to_num(A, nan=None, posinf=None, neginf=None):
    """Replace NaN, +inf and -inf entries of ``A`` with finite values.

    Cambricon backend implementation of ``torch.nan_to_num``.

    Args:
        A: input tensor.
        nan: value written where ``A`` is NaN; ``None`` means ``0.0``.
        posinf: value written where ``A`` is +inf; ``None`` means
            ``torch.finfo(A.dtype).max``.
        neginf: value written where ``A`` is -inf; ``None`` means
            ``torch.finfo(A.dtype).min``.

    Returns:
        For integer/bool inputs, which cannot contain NaN or inf, a copy of
        ``A`` (``A.clone()``), as ``torch.nan_to_num`` does. Otherwise a new
        contiguous tensor with the same shape and dtype as ``A``.
    """
    logger.debug("GEMS_CAMBRICON NAN_TO_NUM")
    if not (A.is_floating_point() or A.is_complex()):
        # torch.finfo() raises TypeError for integral/bool dtypes, and such
        # values are never NaN/inf, so the result is simply a copy.
        return A.clone()

    finfo = torch.finfo(A.dtype)
    if posinf is None:
        posinf = finfo.max
    if neginf is None:
        neginf = finfo.min
    if nan is None:
        nan = 0.0

    # The kernel indexes a flat buffer, so it needs a contiguous input.
    A = A.contiguous()
    out = torch.empty_like(A)
    N = A.numel()
    if N == 0:
        return out

    def grid_fn(meta):
        # One program per BLOCK_SIZE chunk, capped at TOTAL_CORE_NUM; the
        # kernel loops so capped programs still cover the remaining chunks.
        return (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(A.device):
        nan_to_num_kernel[grid_fn](A, out, nan, posinf, neginf, N)
    return out