Coverage for src/flag_gems/ops/i0_.py: 47%

47 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8import flag_gems 

9 

10logger = logging.getLogger(__name__) 

11 

12 

@triton.jit
def i0_kernel_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Overwrite a flat buffer with I0(x), element by element.

    I0 (modified Bessel function of the first kind, order zero) is
    approximated with two polynomial fits selected by ``|x|``; the
    coefficients are those of Abramowitz & Stegun 9.8.1 / 9.8.2:

    * ``|x| <= 3.75``: a polynomial in ``(|x| / 3.75) ** 2``.
    * ``|x| >  3.75``: ``exp(|x|) / sqrt(|x|)`` times a polynomial in
      ``3.75 / |x|``.

    Args:
        x_ptr: pointer to the first element; data is addressed linearly
            as ``x_ptr + offset`` and is both read and written.
        n_elements: number of valid elements; lanes at or beyond this
            index are masked out of both the load and the store.
        BLOCK_SIZE: number of elements handled by one program instance.

    float64 inputs are evaluated in float64 so that ``exp(|x|)`` does not
    overflow at the float32 limit (|x| around 88.7); every other dtype is
    evaluated in float32. The result is cast back to the input dtype.
    The polynomial fit itself limits accuracy (A&S quote errors on the
    order of 1e-7), independent of the compute dtype.
    """
    # Widen the program id before scaling it: an int32 product wraps for
    # buffers with 2**31 or more elements.
    pid = tl.program_id(axis=0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)

    # The dtype comparison is resolved at compile time, so only one of the
    # two branches is ever traced for a given specialization.
    if x.dtype == tl.float64:
        xf = x
    else:
        xf = tl.cast(x, tl.float32)
    ax = tl.abs(xf)

    # Small-argument fit, evaluated in Horner form.
    t_small = ax / 3.75
    y_small = t_small * t_small
    poly_small = 1.0 + y_small * (
        3.5156229
        + y_small
        * (
            3.0899424
            + y_small
            * (
                1.2067492
                + y_small * (0.2659732 + y_small * (0.0360768 + y_small * 0.0045813))
            )
        )
    )

    # Large-argument fit. At ax == 0 this divides by zero, but that lane's
    # value is discarded by the tl.where below (0 <= 3.75).
    y_large = 3.75 / ax
    poly_large = 0.39894228 + y_large * (
        0.01328592
        + y_large
        * (
            0.00225319
            + y_large
            * (
                -0.00157565
                + y_large
                * (
                    0.00916281
                    + y_large
                    * (
                        -0.02057706
                        + y_large
                        * (0.02635537 + y_large * (-0.01647633 + y_large * 0.00392377))
                    )
                )
            )
        )
    )
    val_large = tl.exp(ax) * poly_large / tl.sqrt(ax)

    result = tl.where(ax <= 3.75, poly_small, val_large)

    # Write back in place using the storage dtype of the input.
    result_cast = tl.cast(result, x.dtype)
    tl.store(x_ptr + offsets, result_cast, mask=mask)

67 

68 

def i0_(*args, **kwargs):
    """Apply the I0 kernel to a tensor in place and return that tensor.

    The tensor is taken from the first positional argument. When no
    positional argument is given, the keywords ``input``, ``self`` and
    ``x`` are tried in that order and the first one present is used.

    Non-contiguous tensors are supported: the kernel runs on a dense copy
    and the result is copied back into the caller's tensor, so the
    returned object is always the tensor that was passed in.

    Returns:
        The input tensor, modified in place.

    Raises:
        ValueError: no tensor was supplied (or the supplied value is None).
        AssertionError: the tensor is not on ``flag_gems.device`` or its
            dtype is not float16, bfloat16, float32 or float64.
    """
    logger.debug("GEMS I0_")

    if args:
        x = args[0]
    else:
        # First matching keyword wins, even if its value is None.
        x = next((kwargs[k] for k in ("input", "self", "x") if k in kwargs), None)
    if x is None:
        raise ValueError(
            "i0_ expects a tensor as the first positional argument or in keyword 'input'/'self'/'x'."
        )

    if x.device.type != flag_gems.device:
        raise AssertionError(f"Input tensor must be on a {flag_gems.device} device.")
    if x.dtype not in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
        raise AssertionError(
            "Unsupported dtype for i0_. Supported: float16, bfloat16, float32, float64."
        )

    n_elements = x.numel()
    if n_elements == 0:
        return x

    # The kernel addresses memory linearly, so strided inputs are processed
    # through a dense scratch copy and written back afterwards.
    dense = x if x.is_contiguous() else x.contiguous()

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    # NOTE(review): the launch is not wrapped in a device guard; confirm the
    # behaviour when x lives on a device other than the current one.
    i0_kernel_[grid](dense, n_elements, BLOCK_SIZE=1024)

    if dense is not x:
        x.copy_(dense)
    return x