Coverage for src/flag_gems/ops/special_i0e.py: 53%

47 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import torch 

3import triton 

4import triton.language as tl 

5 

6import flag_gems 

7 

8 

@triton.jit
def _special_i0e_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise exponentially scaled modified Bessel function of order 0.

    Computes ``i0e(x) = I0(x) * exp(-|x|)`` for the first ``n_elements``
    contiguous elements at ``x_ptr`` and stores the result at ``out_ptr``.
    Each program handles one ``BLOCK_SIZE`` chunk; lanes at or beyond
    ``n_elements`` are masked out of both the load and the store.

    The approximation is the two-piece polynomial used by Numerical Recipes
    (coefficients of Abramowitz & Stegun 9.8.1 / 9.8.2):
      * ``|x| <= 3.75``: ``I0`` as a polynomial in ``(|x| / 3.75)**2``,
        then multiplied by ``exp(-|x|)``;
      * ``|x| > 3.75``: ``i0e`` directly as ``poly(3.75 / |x|) / sqrt(|x|)``,
        which never forms ``I0(x)`` and therefore cannot overflow.
    All arithmetic is done in float32 whatever the input dtype, so float64
    inputs only receive float32-level accuracy.
    """
    # Promote the program id to int64 before scaling. In int32,
    # ``pid * BLOCK_SIZE`` wraps once the tensor has >= 2**31 elements; the
    # resulting negative offsets would still satisfy ``offsets < n_elements``
    # and the kernel would load/store out of bounds.
    pid = tl.program_id(axis=0).to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)

    # Compute in fp32 for accuracy/stability (fp16/bf16 inputs are upcast).
    xf = x.to(tl.float32)
    # Only |x| is used below, so the result is symmetric in x.
    ax = tl.abs(xf)

    # Small region: |x| <= 3.75. Polynomial approximation of I0 in t^2,
    # with t = |x| / 3.75, evaluated in Horner form.
    t_small = ax / 3.75
    t2 = t_small * t_small
    p = 1.0 + t2 * (
        3.5156229
        + t2
        * (
            3.0899424
            + t2 * (1.2067492 + t2 * (0.2659732 + t2 * (0.0360768 + t2 * 0.0045813)))
        )
    )
    small = p * tl.exp(-ax)

    # Large region: |x| > 3.75. Asymptotic form that avoids exp overflow:
    # i0e(x) = I0(x) * exp(-|x|) ~= (1 / sqrt(|x|)) * poly(3.75 / |x|)
    t = 3.75 / ax
    q = 0.39894228 + t * (
        0.01328592
        + t
        * (
            0.00225319
            + t
            * (
                -0.00157565
                + t
                * (
                    0.00916281
                    + t
                    * (
                        -0.02057706
                        + t * (0.02635537 + t * (-0.01647633 + t * 0.00392377))
                    )
                )
            )
        )
    )
    large = q / tl.sqrt(ax)

    # Both branches are evaluated for every lane. The one not selected may be
    # inf/NaN: at x == 0, t = 3.75 / 0 = inf; at |x| == inf, p * exp(-inf)
    # is inf * 0. tl.where discards that lane's value, so it does not leak
    # into the result.
    is_large = ax > 3.75
    y = tl.where(is_large, large, small)

    # Cast back to input dtype for storage.
    y = y.to(x.dtype)
    tl.store(out_ptr + offsets, y, mask=mask)

67 

68 

def _run_special_i0e_kernel(x: torch.Tensor, out: torch.Tensor):
    """Validate arguments and launch ``_special_i0e_kernel`` to fill ``out``.

    Args:
        x: Input tensor on ``flag_gems.device`` with dtype float16, bfloat16,
            float32 or float64. Non-contiguous inputs are made contiguous
            before the launch.
        out: Destination tensor on ``flag_gems.device`` with the same dtype
            and shape as ``x``. If it is non-contiguous, results are computed
            into a contiguous temporary and copied back.

    Returns:
        ``out``.

    Raises:
        ValueError: if either tensor is not on ``flag_gems.device``, or if
            ``x`` and ``out`` have different shapes.
        TypeError: if ``x.dtype`` is unsupported or ``out.dtype != x.dtype``.
    """
    if x.device.type != flag_gems.device or out.device.type != flag_gems.device:
        raise ValueError(f"Tensors must be {flag_gems.device} tensors")
    # Explicit raises rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let e.g. integer tensors run through the
    # float kernel silently.
    if x.dtype not in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
        raise TypeError("Unsupported dtype")
    if out.dtype != x.dtype:
        raise TypeError("Output dtype must match input dtype")
    # The kernel iterates over out's elements and reads x at the same flat
    # offsets, so differing shapes would read past the end of x (or leave
    # part of out unwritten).
    if x.shape != out.shape:
        raise ValueError(
            f"Shape mismatch: input {tuple(x.shape)} vs output {tuple(out.shape)}"
        )

    n_elements = out.numel()
    if n_elements == 0:
        # Nothing to compute; skip the contiguous() copies and the launch.
        return out

    x_c = x.contiguous()
    # contiguous() returns ``out`` itself when it is already contiguous,
    # otherwise a fresh temporary that must be copied back below.
    out_c = out.contiguous()

    block_size = 1024
    grid = (triton.cdiv(n_elements, block_size),)
    _special_i0e_kernel[grid](x_c, out_c, n_elements, BLOCK_SIZE=block_size)

    if out_c is not out:
        out.copy_(out_c)
    return out

93 

94 

def special_i0e(x: torch.Tensor):
    """
    ATen wrapper: special_i0e(Tensor self) -> Tensor

    Return ``I0(x) * exp(-|x|)`` elementwise in a newly allocated tensor with
    the same shape (and, via ``torch.empty_like``, the same layout) as ``x``.

    Integer and bool inputs are promoted to ``torch.get_default_dtype()``,
    matching ``torch.special.i0e``; floating-point inputs keep their dtype.
    """
    # torch.special.i0e promotes integral inputs to the default float dtype,
    # while the kernel launcher only accepts floating dtypes. Complex inputs
    # are deliberately not converted (that would drop the imaginary part);
    # they are rejected by the launcher's dtype check instead.
    if not x.is_floating_point() and not x.is_complex():
        x = x.to(torch.get_default_dtype())
    out = torch.empty_like(x)
    return _run_special_i0e_kernel(x, out)

101 

102 

def special_i0e_out(x: torch.Tensor, out: torch.Tensor):
    """
    ATen wrapper: special_i0e.out(Tensor self, Tensor out) -> Tensor

    Write ``I0(x) * exp(-|x|)`` elementwise into ``out`` and return ``out``.

    Integer and bool inputs are promoted to ``torch.get_default_dtype()``
    first, matching ``torch.special.i0e``. ``out`` must then have that same
    dtype. If ``x.shape`` differs from ``out.shape``, ``x`` is broadcast
    (``Tensor.expand``) to ``out.shape``; a non-broadcastable ``x`` raises
    from ``expand``.
    """
    # Same promotion rule as the functional variant: integral -> default
    # float dtype; complex is left for the launcher's dtype check to reject.
    if not x.is_floating_point() and not x.is_complex():
        x = x.to(torch.get_default_dtype())
    # Broadcast input to out's shape if needed.
    if x.shape != out.shape:
        x = x.expand(out.shape)
    _run_special_i0e_kernel(x, out)
    return out