Coverage for src/flag_gems/ops/special_i1.py: 37%
76 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8import flag_gems
9from flag_gems.runtime import torch_device_fn
11logger = logging.getLogger(__name__)
@triton.jit
def special_i1_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise modified Bessel function of the first kind, order one.

    Each program instance handles one run of ``BLOCK_SIZE`` consecutive
    elements of a flat buffer. Values are widened to float32, evaluated with
    two polynomial fits split at ``|x| = 3.75``, and cast to the element type
    of ``out_ptr`` when stored.

    Args:
        x_ptr: Pointer to the input elements.
        out_ptr: Pointer to the output elements.
        n_elements: Number of valid elements; lanes at or past it are masked.
        BLOCK_SIZE: Compile-time number of elements per program instance.
    """
    pid = tl.program_id(axis=0)
    # Widen before multiplying so the element offset cannot wrap around
    # 32 bits when the buffer holds more than 2**31 elements.
    block_start = pid.to(tl.int64) * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    # All arithmetic below is done in float32, whatever the input dtype.
    x_f32 = x.to(tl.float32)
    ax = tl.abs(x_f32)

    # Small region, |x| <= 3.75: I1(x) = x * poly(y^2) with y = x / 3.75.
    # The coefficients are those of Abramowitz & Stegun formula 9.8.3.
    y = x_f32 / 3.75
    y2 = y * y
    # Horner evaluation, highest-order coefficient first.
    p = 0.00032411
    p = 0.00301532 + y2 * p
    p = 0.02658733 + y2 * p
    p = 0.15084934 + y2 * p
    p = 0.51498869 + y2 * p
    p = 0.87890594 + y2 * p
    p = 0.5 + y2 * p
    ans_small = x_f32 * p

    # Large region, |x| > 3.75:
    #   I1(|x|) = exp(|x|) / sqrt(|x|) * poly(3.75 / |x|)
    # The coefficients are those of Abramowitz & Stegun formula 9.8.4.
    # The clamp only avoids a division by zero; for small |x| this branch may
    # still overflow, but its value is discarded by the final select.
    t = 3.75 / tl.maximum(ax, 1e-20)
    q = -0.00420059
    q = 0.01787654 + t * q
    q = -0.02895312 + t * q
    q = 0.02282967 + t * q
    q = -0.01031555 + t * q
    q = 0.00163801 + t * q
    q = -0.00362018 + t * q
    q = -0.03988024 + t * q
    q = 0.39894228 + t * q
    pref = tl.exp(ax) / tl.sqrt(tl.maximum(ax, 1e-20))
    ans_large = pref * q
    # I1 is odd: the expansion was evaluated on |x|, so restore the sign.
    ans_large = tl.where(x_f32 < 0, -ans_large, ans_large)

    is_small = ax <= 3.75
    ans = tl.where(is_small, ans_small, ans_large)

    # Cast to the output buffer's element type rather than the input's, so
    # the store is well-typed even if the two buffers ever differ in dtype.
    tl.store(out_ptr + offsets, ans.to(out_ptr.dtype.element_ty), mask=mask)
def _launch_special_i1(x: torch.Tensor, out: torch.Tensor):
    """Run ``special_i1_kernel`` over ``x`` and write the result into ``out``.

    Both tensors are handed to the kernel as flat buffers of ``x.numel()``
    elements, so callers are expected to pass contiguous tensors.

    Args:
        x: Input tensor on the ``flag_gems.device`` device type.
        out: Output tensor with the same dtype and element count as ``x``.

    Raises:
        ValueError: If either tensor is not on the ``flag_gems.device`` type.
        AssertionError: If the element counts or dtypes differ.
    """
    if x.device.type != flag_gems.device or out.device.type != flag_gems.device:
        raise ValueError(f"Tensors must be {flag_gems.device} tensors")
    # Raised explicitly instead of via ``assert`` so the checks survive
    # ``python -O``; a size mismatch would let the kernel write past ``out``.
    # The exception type and messages are unchanged.
    if x.numel() != out.numel():
        raise AssertionError(
            "Input and output must have the same number of elements"
        )
    if x.dtype != out.dtype:
        raise AssertionError("Input and output must have the same dtype")

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to compute; skip the launch (the grid would be empty).
        return

    BLOCK_SIZE = 1024
    # One program instance per BLOCK_SIZE-element chunk, rounded up.
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    with torch_device_fn.device(x.device):
        special_i1_kernel[grid](x, out, n_elements, BLOCK_SIZE=BLOCK_SIZE)
def special_i1(self: torch.Tensor):
    """Return the order-one modified Bessel function of ``self``, elementwise.

    Args:
        self: Input tensor. Integer and bool tensors are first converted to
            the default floating-point dtype; other dtypes are used as-is.

    Returns:
        A new contiguous tensor with the shape of ``self``.
    """
    logger.debug("GEMS SPECIAL_I1")
    x = self
    # Integer/bool inputs must yield a floating-point result. Without this
    # conversion the output would be allocated with the integer dtype and the
    # kernel would truncate every value on store.
    if not (x.dtype.is_floating_point or x.dtype.is_complex):
        x = x.to(torch.get_default_dtype())
    x_c = x.contiguous()
    out = torch.empty_like(x_c)
    _launch_special_i1(x_c, out)
    # ``out`` already has the shape of ``x`` (``contiguous`` preserves shape),
    # so no reshaping is needed for non-contiguous inputs.
    return out
def special_i1_out(self: torch.Tensor, out: torch.Tensor):
    """Compute the order-one modified Bessel function of ``self`` into ``out``.

    Args:
        self: Input tensor.
        out: Destination tensor; must share dtype and device with ``self``.

    Returns:
        ``out``, after it has been filled with the result.

    Raises:
        TypeError: If ``out`` differs from ``self`` in dtype or device.
    """
    logger.debug("GEMS SPECIAL_I1_OUT")
    x = self
    if out.dtype != x.dtype:
        raise TypeError("out dtype must match input dtype")
    if out.device != x.device:
        raise TypeError("out device must match input device")

    x_c = x.contiguous()
    if out.is_contiguous():
        # The kernel can write straight into the caller's buffer.
        _launch_special_i1(x_c, out)
    else:
        # The kernel addresses its output as a flat buffer, so compute into a
        # contiguous scratch tensor and scatter it back. The scratch tensor is
        # left uninitialised: every element is overwritten, so copying the old
        # contents of ``out`` first would be wasted work.
        scratch = torch.empty_like(out, memory_format=torch.contiguous_format)
        _launch_special_i1(x_c, scratch)
        out.copy_(scratch)
    return out