Coverage for src/flag_gems/runtime/backend/_cambricon/ops/nan_to_num.py: 0%

46 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6from triton.language.extra.mlu.libdevice import isnan as _isnan 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry, libtuner 

10 

11from ..utils import TOTAL_CORE_NUM 

12 

13logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

14 

15 

@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": block}, num_stages=1, num_warps=1)
        for block in (4096, 16384, 65536, 131072)
    ],
    key=["n_elements"],
)
@triton.jit
def nan_to_num_kernel(
    X_ptr,
    OUT_ptr,
    nan_val,
    posinf_val,
    neginf_val,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy X to OUT, substituting NaN, +inf and -inf elements.

    Each program starts at ``program_id * BLOCK_SIZE`` and advances by
    ``num_programs * BLOCK_SIZE`` until ``n_elements`` is reached, so a
    launch with fewer programs than blocks still covers every element.

    Args:
        X_ptr: pointer to the input elements.
        OUT_ptr: pointer to the output elements.
        nan_val: value written where the input is NaN.
        posinf_val: value written where the input equals +inf.
        neginf_val: value written where the input equals -inf.
        n_elements: number of elements to process.
        BLOCK_SIZE: compile-time number of elements handled per iteration.
    """
    program_idx = tl.program_id(0)
    program_count = tl.num_programs(0)
    stride = program_count * BLOCK_SIZE
    # Widen the start offset to 64 bits before it is used as the loop origin.
    first_offset = (program_idx * BLOCK_SIZE).to(tl.int64)
    for base in range(first_offset, n_elements, stride):
        idx = base + tl.arange(0, BLOCK_SIZE)
        in_bounds = idx < n_elements
        values = tl.load(X_ptr + idx, mask=in_bounds)
        # All three masks are taken from the loaded values, never from a
        # partially replaced result.
        is_nan = _isnan(values)
        is_pos_inf = values == float("inf")
        is_neg_inf = values == float("-inf")
        without_nan = tl.where(is_nan, nan_val, values)
        without_pos_inf = tl.where(is_pos_inf, posinf_val, without_nan)
        cleaned = tl.where(is_neg_inf, neginf_val, without_pos_inf)
        tl.store(OUT_ptr + idx, cleaned, mask=in_bounds)

52 

53 

def nan_to_num(A, nan=None, posinf=None, neginf=None):
    """Return a copy of ``A`` with NaN, +inf and -inf replaced.

    Args:
        A: input tensor.
        nan: replacement for NaN elements; ``None`` means ``0.0``.
        posinf: replacement for +inf elements; ``None`` means the largest
            finite value of ``A.dtype``.
        neginf: replacement for -inf elements; ``None`` means the lowest
            finite value of ``A.dtype``.

    Returns:
        A new tensor with the same shape and dtype as ``A``. Tensors whose
        dtype is neither floating-point nor complex are returned as a plain
        copy.
    """
    logger.debug("GEMS_CAMBRICON NAN_TO_NUM")

    # Integer and bool tensors cannot hold NaN or infinities, and
    # torch.finfo raises TypeError for their dtypes, so return a copy
    # instead of falling through to the float-only code below.
    if not (A.is_floating_point() or A.is_complex()):
        return A.clone()

    if posinf is None:
        posinf = torch.finfo(A.dtype).max
    if neginf is None:
        neginf = torch.finfo(A.dtype).min
    if nan is None:
        nan = 0.0

    # The kernel addresses elements by flat offset, so it needs dense storage.
    A = A.contiguous()
    out = torch.empty_like(A)
    N = A.numel()
    if N == 0:
        return out

    def grid_fn(meta):
        # One program per block, capped at the number of available cores;
        # the kernel loops over any remaining blocks itself.
        return (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(A.device):
        nan_to_num_kernel[grid_fn](A, out, nan, posinf, neginf, N)
    return out