Coverage for src/flag_gems/ops/special_i0e.py: 53%
47 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import torch
3import triton
4import triton.language as tl
6import flag_gems
@triton.jit
def _special_i0e_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Element-wise ``i0e(x) = I0(x) * exp(-|x|)``.

    Uses the two-piece polynomial approximation from Numerical Recipes:
      * ``|x| <= 3.75``: a polynomial in ``(|x| / 3.75)**2`` approximating
        ``I0``, multiplied by ``exp(-|x|)``;
      * ``|x| > 3.75``: an asymptotic polynomial in ``3.75 / |x|`` divided by
        ``sqrt(|x|)``, which never forms ``exp(|x|)`` and so cannot overflow.

    Arithmetic is done in fp32 and the result is cast back to the element type
    of the loaded input. Each program instance handles ``BLOCK_SIZE``
    consecutive elements; lanes at or beyond ``n_elements`` are masked off.
    """
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    raw = tl.load(x_ptr + idx, mask=in_bounds)
    # fp32 arithmetic for accuracy/stability independent of the storage dtype.
    mag = tl.abs(raw.to(tl.float32))

    # Region |x| <= 3.75: Horner evaluation of the I0 polynomial in u = (|x|/3.75)^2.
    r = mag / 3.75
    u = r * r
    inner = 0.0360768 + u * 0.0045813
    inner = 0.2659732 + u * inner
    inner = 1.2067492 + u * inner
    inner = 3.0899424 + u * inner
    inner = 3.5156229 + u * inner
    i0_poly = 1.0 + u * inner
    near = i0_poly * tl.exp(-mag)

    # Region |x| > 3.75: Horner evaluation of the asymptotic polynomial in s = 3.75/|x|.
    # At x == 0 this division yields inf, but tl.where below discards that lane's
    # value in favour of the small-region result.
    s = 3.75 / mag
    acc = -0.01647633 + s * 0.00392377
    acc = 0.02635537 + s * acc
    acc = -0.02057706 + s * acc
    acc = 0.00916281 + s * acc
    acc = -0.00157565 + s * acc
    acc = 0.00225319 + s * acc
    acc = 0.01328592 + s * acc
    acc = 0.39894228 + s * acc
    far = acc / tl.sqrt(mag)

    result = tl.where(mag > 3.75, far, near)
    # Store in the same dtype the input was loaded with.
    tl.store(out_ptr + idx, result.to(raw.dtype), mask=in_bounds)
def _run_special_i0e_kernel(x: torch.Tensor, out: torch.Tensor):
    """Launch ``_special_i0e_kernel`` to write ``i0e(x)`` into ``out``.

    Args:
        x: Input tensor on ``flag_gems.device``; float16, bfloat16, float32 or
            float64. Made contiguous before the launch.
        out: Destination tensor on ``flag_gems.device`` with the same dtype and
            shape as ``x``. If it is not contiguous, the kernel writes into a
            contiguous temporary that is then copied back into ``out``.

    Returns:
        ``out``, filled in place.

    Raises:
        ValueError: if either tensor is not on ``flag_gems.device``, or if the
            shapes of ``x`` and ``out`` differ.
        AssertionError: on an unsupported dtype or a dtype mismatch.
    """
    if x.device.type != flag_gems.device or out.device.type != flag_gems.device:
        raise ValueError(f"Tensors must be {flag_gems.device} tensors")
    assert x.dtype in (
        torch.float16,
        torch.bfloat16,
        torch.float32,
        torch.float64,
    ), "Unsupported dtype"
    assert out.dtype == x.dtype, "Output dtype must match input dtype"
    # The kernel indexes input and output with the same flat offsets up to
    # out.numel(); a smaller input would be read out of bounds, a larger one
    # would be silently truncated.
    if x.shape != out.shape:
        raise ValueError(
            f"Input shape {tuple(x.shape)} must match output shape "
            f"{tuple(out.shape)}"
        )

    x_c = x.contiguous()
    out_c = out.contiguous()

    n_elements = out_c.numel()
    if n_elements == 0:
        return out

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _special_i0e_kernel[grid](x_c, out_c, n_elements, BLOCK_SIZE=1024)

    # out_c is a separate buffer only when `out` was not contiguous.
    if out_c.data_ptr() != out.data_ptr():
        out.copy_(out_c)
    return out
def special_i0e(x: torch.Tensor):
    """Return ``i0e(x)`` (``I0(x) * exp(-|x|)``) as a newly allocated tensor.

    Implements the ATen overload ``special_i0e(Tensor self) -> Tensor``. The
    output is allocated with ``torch.empty_like(x)`` and filled by the Triton
    kernel.
    """
    return _run_special_i0e_kernel(x, torch.empty_like(x))
def special_i0e_out(x: torch.Tensor, out: torch.Tensor):
    """Compute ``i0e(x)`` into the caller-provided ``out`` tensor.

    Implements the ATen overload
    ``special_i0e.out(Tensor self, Tensor out) -> Tensor``.

    When ``x`` and ``out`` have different shapes:
      * an ``out`` with no elements (e.g. ``torch.empty(0)``) is resized to
        ``x.shape``;
      * otherwise ``x`` is broadcast to ``out.shape`` via ``Tensor.expand``,
        which raises if the shapes are not broadcast-compatible.

    Returns:
        ``out``, filled in place.
    """
    if x.shape != out.shape:
        if out.numel() == 0:
            # ATen resizes an empty out= tensor to the result shape. Expanding x
            # to an empty shape would raise for most inputs (e.g. (3,) -> (0,)).
            out.resize_(x.shape)
        else:
            x = x.expand(out.shape)
    _run_special_i0e_kernel(x, out)
    return out