Coverage for src/flag_gems/ops/special_i1.py: 37%

76 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8import flag_gems 

9from flag_gems.runtime import torch_device_fn 

10 

# Module logger; the special_i1 entry points below emit debug-level traces through it.
logger = logging.getLogger(__name__)

12 

13 

@triton.jit
def special_i1_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise modified Bessel function of the first kind, order one.

    Each program handles ``BLOCK_SIZE`` consecutive elements of the flat
    buffers; lanes past ``n_elements`` are masked off. Arithmetic is done in
    float32 and the result is cast back to the input element type before it
    is stored, so float64 inputs are only evaluated to float32 precision.

    The approximation is the classic two-region polynomial fit
    (Abramowitz & Stegun 9.8.3 / 9.8.4, also used by Numerical Recipes'
    ``bessi1``):

    * ``|x| <= 3.75``: ``I1(x) = x * P((x / 3.75) ** 2)``
    * ``|x| >  3.75``: ``I1(|x|) = exp(|x|) / sqrt(|x|) * Q(3.75 / |x|)``,
      with the sign restored afterwards because ``I1`` is odd.

    Special values: ``I1(+inf) = +inf``, ``I1(-inf) = -inf``, NaN propagates.

    Args:
        x_ptr: pointer to the input elements.
        out_ptr: pointer to the output elements (same element type as input).
        n_elements: number of elements to process.
        BLOCK_SIZE: elements handled per program (compile-time constant).
    """
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    xf = x.to(tl.float32)
    ax = tl.abs(xf)

    # Small region, |x| <= 3.75: I1(x) = x * P(y2) with y2 = (x / 3.75)^2,
    # evaluated in Horner form.
    y = xf / 3.75
    y2 = y * y
    p = 0.00032411
    p = 0.00301532 + y2 * p
    p = 0.02658733 + y2 * p
    p = 0.15084934 + y2 * p
    p = 0.51498869 + y2 * p
    p = 0.87890594 + y2 * p
    p = 0.5 + y2 * p
    ans_small = xf * p

    # Large region, |x| > 3.75: I1(|x|) = exp(|x|) / sqrt(|x|) * Q(3.75 / |x|).
    # Every lane evaluates both branches, so clamp |x| away from zero to keep
    # lanes that end up using the small branch (e.g. x == 0) finite.
    ax_safe = tl.maximum(ax, 1e-20)
    t = 3.75 / ax_safe
    q = -0.00420059
    q = 0.01787654 + t * q
    q = -0.02895312 + t * q
    q = 0.02282967 + t * q
    q = -0.01031555 + t * q
    q = 0.00163801 + t * q
    q = -0.00362018 + t * q
    q = -0.03988024 + t * q
    q = 0.39894228 + t * q
    # Split exp(|x|) into two halves so the intermediate stays finite whenever
    # the final result is representable: exp(|x|) alone overflows float32 for
    # |x| > ~88.7, while I1(|x|) stays finite up to |x| ~ 91.9.
    half = tl.exp(0.5 * ax)
    ans_large = (half * (q / tl.sqrt(ax_safe))) * half
    # At |x| = inf the expression above is inf * 0 = nan; I1(inf) is +inf.
    ans_large = tl.where(ax == float("inf"), float("inf"), ans_large)
    # I1 is odd: restore the sign of x.
    ans_large = tl.where(xf < 0, -ans_large, ans_large)

    ans = tl.where(ax <= 3.75, ans_small, ans_large)

    # Cast back to the input element type and store.
    tl.store(out_ptr + offsets, ans.to(x.dtype), mask=mask)

61 

62 

def _launch_special_i1(x: torch.Tensor, out: torch.Tensor):
    """Check ``x``/``out`` and run ``special_i1_kernel`` from ``x`` into ``out``.

    The kernel addresses both tensors as flat, densely packed buffers, so
    callers are expected to pass contiguous tensors.

    Args:
        x: input tensor on the FlagGems device.
        out: output tensor on the FlagGems device with the same element count
            and dtype as ``x``.

    Raises:
        ValueError: if either tensor is not on ``flag_gems.device``.
        AssertionError: if element counts or dtypes differ (plain ``assert``
            statements, so these checks are skipped under ``python -O``).
    """
    target = flag_gems.device
    both_on_target = x.device.type == target and out.device.type == target
    if not both_on_target:
        raise ValueError(f"Tensors must be {target} tensors")

    total = x.numel()
    assert (
        total == out.numel()
    ), "Input and output must have the same number of elements"
    assert x.dtype == out.dtype, "Input and output must have the same dtype"

    # Nothing to compute for empty tensors; skip the launch entirely.
    if not total:
        return

    block = 1024

    def grid(meta):
        # One program per BLOCK_SIZE-sized chunk of the flat input.
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    with torch_device_fn.device(x.device):
        special_i1_kernel[grid](x, out, total, BLOCK_SIZE=block)

79 

80 

def special_i1(self: torch.Tensor):
    """Compute the order-1 modified Bessel function of the first kind, elementwise.

    FlagGems implementation of ``torch.special.i1``. The input is made
    contiguous and the result is written into a freshly allocated contiguous
    tensor of the same shape.

    Integer and boolean inputs are first converted to the default floating
    dtype (``torch.get_default_dtype()``), matching PyTorch's int-to-float
    promotion for this op, so the result is never truncated to an integer.

    Args:
        self: input tensor on the FlagGems device.

    Returns:
        A new tensor holding ``I1(self)``; floating inputs keep their dtype.

    Raises:
        ValueError: if ``self`` is not on the FlagGems device (raised by
            ``_launch_special_i1``).
    """
    logger.debug("GEMS SPECIAL_I1")
    x = self
    if not (x.is_floating_point() or x.is_complex()):
        # torch.special.i1 promotes integral/bool inputs to the default float
        # dtype; computing in the integer dtype would truncate the result.
        x = x.to(torch.get_default_dtype())
    x_c = x.contiguous()
    out = torch.empty_like(x_c)
    _launch_special_i1(x_c, out)
    # out already has x's shape (it mirrors a contiguous copy of x), so no
    # reshaping/view is needed regardless of x's original strides.
    return out

91 

92 

def special_i1_out(self: torch.Tensor, out: torch.Tensor):
    """Compute ``I1(self)`` elementwise into ``out``.

    FlagGems implementation of ``torch.special.i1(self, out=out)``.

    Args:
        self: input tensor.
        out: destination tensor on the same device as ``self``. Its dtype
            must equal ``self``'s, except that an integer/bool ``self`` is
            accepted together with a floating ``out`` (``self`` is converted
            to ``out.dtype`` first). If ``out``'s shape differs from
            ``self``'s, ``out`` is resized in place with ``Tensor.resize_``,
            as PyTorch does for ``out=`` arguments.

    Returns:
        ``out``, filled with the result.

    Raises:
        TypeError: if the dtypes are incompatible or the devices differ.
    """
    logger.debug("GEMS SPECIAL_I1_OUT")
    x = self
    if out.dtype != x.dtype:
        integral_input = not (x.is_floating_point() or x.is_complex())
        if not (integral_input and out.is_floating_point()):
            raise TypeError("out dtype must match input dtype")
    if out.device != x.device:
        raise TypeError("out device must match input device")

    if out.dtype != x.dtype:
        # Integral/bool input promoted to the floating destination dtype.
        x = x.to(out.dtype)
    if out.shape != x.shape:
        # Without this, a same-numel/different-shape `out` would be filled in
        # flat order and a different-numel one would fail the launcher check.
        out.resize_(x.shape)

    x_c = x.contiguous()
    out_c = out.contiguous()
    _launch_special_i1(x_c, out_c)
    if out_c is not out:
        # out is non-contiguous: the kernel wrote into a contiguous temporary,
        # so copy the result back into the caller's tensor.
        out.copy_(out_c)
    return out