Coverage for src/flag_gems/ops/i0.py: 52%

63 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8import flag_gems 

9 

# Module-level logger; i0 and i0_out emit their debug traces through it.
logger = logging.getLogger(__name__)

11 

12 

@triton.jit
def i0_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Element-wise modified Bessel function of the first kind, order 0.

    Evaluates the Abramowitz & Stegun polynomial approximations 9.8.1
    (|x| <= 3.75) and 9.8.2 (|x| > 3.75) in float32. Each program instance
    handles BLOCK_SIZE consecutive elements of the flat buffers; lanes at or
    beyond ``n_elements`` are masked off for both load and store.

    NOTE(review): inputs/outputs of dtype float64 are still evaluated in
    float32, so results saturate to inf once I0(x) exceeds the float32 range
    (|x| > ~91.9) and carry only ~1e-7 relative accuracy.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    x_f32 = x.to(tl.float32)
    ax = tl.abs(x_f32)

    # Small region (|x| <= 3.75): polynomial in y = (x / 3.75)^2.
    t = x_f32 / 3.75
    y = t * t
    p_small = 1.0 + y * (
        3.5156229
        + y
        * (
            3.0899424
            + y * (1.2067492 + y * (0.2659732 + y * (0.0360768 + y * 0.0045813)))
        )
    )

    # Large region (|x| > 3.75): sqrt(|x|) * exp(-|x|) * I0(x) ~ p_big(3.75 / |x|).
    yb = 3.75 / ax
    p_big = 0.39894228 + yb * (
        0.01328592
        + yb
        * (
            0.00225319
            + yb
            * (
                -0.00157565
                + yb
                * (
                    0.00916281
                    + yb
                    * (
                        -0.02057706
                        + yb * (0.02635537 + yb * (-0.01647633 + yb * 0.00392377))
                    )
                )
            )
        )
    )
    # exp(|x|) alone overflows float32 for |x| > ~88.7 although I0(x) is still
    # representable up to |x| ~ 91.9, so apply the exponential in two halves.
    # At ax == 0 (including masked lanes loaded as 0) this branch yields
    # inf/NaN, which is harmless: tl.where below selects p_small there.
    half_exp = tl.exp(0.5 * ax)
    res_big = (half_exp * (p_big / tl.sqrt(ax))) * half_exp

    use_small = ax <= 3.75
    res = tl.where(use_small, p_small, res_big)
    # For |x| = inf the large branch computes inf * 0 = NaN; I0(+-inf) = +inf.
    res = tl.where(ax == float("inf"), ax, res)

    # Store result; Triton casts the fp32 value to the element type of out_ptr.
    tl.store(out_ptr + offsets, res, mask=mask)

66 

67 

def _launch_i0(out: torch.Tensor, x: torch.Tensor):
    """Compute i0(x) element-wise into ``out`` and return ``out``.

    Elements are paired by flat (row-major) index of the contiguous views of
    ``x`` and ``out``. Non-floating inputs are first converted to
    ``torch.get_default_dtype()``. Floating inputs are passed to the kernel
    unchanged: it upcasts to float32 itself, and the store casts to
    ``out.dtype``.

    Raises:
        ValueError: if either tensor is not on the ``flag_gems.device`` device type.
        AssertionError: if the element counts or devices of ``x`` and ``out`` differ.
    """
    if x.device.type != flag_gems.device or out.device.type != flag_gems.device:
        raise ValueError(f"Input and output must be {flag_gems.device} tensors")
    assert (
        out.numel() == x.numel()
    ), "Input and output must have the same number of elements"
    assert out.device == x.device, "Input and output must be on the same device"

    n_elements = out.numel()
    if n_elements == 0:
        # Nothing to compute: skip the launch and any layout copies.
        return out

    # Ensure floating point input. Do NOT cast to out.dtype here: that would
    # round x to a narrower output dtype (e.g. fp32 -> fp16) before i0 is
    # evaluated, and cost an extra full-tensor copy.
    x_in = x if x.is_floating_point() else x.to(torch.get_default_dtype())
    x_contig = x_in.contiguous()

    # The kernel writes a flat contiguous buffer; stage through a contiguous
    # copy when `out` is strided and copy the result back afterwards.
    out_was_noncontig = not out.is_contiguous()
    out_contig = out.contiguous() if out_was_noncontig else out

    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    i0_kernel[grid](x_contig, out_contig, n_elements, BLOCK_SIZE=BLOCK_SIZE)

    if out_was_noncontig:
        out.copy_(out_contig)
    return out

101 

102 

def i0(x: torch.Tensor):
    """Return the modified Bessel function of the first kind, order 0, of ``x``.

    The result keeps ``x``'s dtype when ``x`` is floating point and uses
    ``torch.get_default_dtype()`` otherwise. The output is allocated with
    ``torch.empty_like``, so it keeps ``x``'s shape, device and memory layout.

    Raises:
        ValueError: if ``x`` is not on the ``flag_gems.device`` device type.
    """
    logger.debug("GEMS I0")
    if x.device.type != flag_gems.device:
        raise ValueError(f"i0: input tensor must be on {flag_gems.device} device")
    # Result dtype follows PyTorch's floating type behavior
    out_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    # empty_like(x, dtype=...) already yields x's layout and device; there is
    # no need to materialize a dtype-converted copy of x as a template.
    out = torch.empty_like(x, dtype=out_dtype)
    _launch_i0(out, x)
    return out

112 

113 

def i0_out(x: torch.Tensor, out: torch.Tensor):
    """Write the order-0 modified Bessel function of ``x`` into ``out``.

    Args:
        x: Input tensor on the ``flag_gems.device`` device type.
        out: Floating-point destination with the same number of elements as
            ``x``, on the same device type.

    Returns:
        ``out``, filled in place.

    Raises:
        ValueError: if either tensor is on the wrong device type, or the
            element counts differ.
        TypeError: if ``out`` is not a floating-point tensor.
    """
    logger.debug("GEMS I0_OUT")
    target = flag_gems.device
    both_on_target = x.device.type == target and out.device.type == target
    if not both_on_target:
        raise ValueError(
            f"i0_out: input and output tensors must be on {flag_gems.device} device"
        )

    if not out.is_floating_point():
        raise TypeError("i0_out: output tensor must be a floating point type")

    sizes_match = x.numel() == out.numel()
    if not sizes_match:
        raise ValueError(
            "i0_out: input and output must have the same number of elements"
        )

    _launch_i0(out, x)
    return out