Coverage for src/flag_gems/ops/i0_.py: 47%
47 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8import flag_gems
10logger = logging.getLogger(__name__)
@triton.jit
def i0_kernel_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Overwrite a flat buffer with I0(x), the order-zero modified Bessel function.

    Each program instance handles one run of ``BLOCK_SIZE`` consecutive
    elements: it loads them from ``x_ptr``, evaluates a piecewise polynomial
    approximation of I0 and stores the result back to the same addresses.

    Args:
        x_ptr: Pointer to the first element of the buffer; read and written.
        n_elements: Number of valid elements reachable from ``x_ptr``.
        BLOCK_SIZE: Compile-time number of elements per program instance.

    Notes:
        * All arithmetic is done in float32 and cast back to the element type
          on store, so float64 buffers only get float32 accuracy and range.
        * The coefficients appear to be the Abramowitz & Stegun 9.8.1 / 9.8.2
          polynomial fits (split at |x| = 3.75) -- confirm against the
          reference before changing them.
    """
    # tl.program_id yields a 32-bit value; widen it before scaling so that
    # offsets do not wrap for buffers with 2**31 or more elements.
    pid = tl.program_id(axis=0).to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    xf = tl.cast(x, tl.float32)
    # Only |x| is used below, so the result is the same for x and -x.
    ax = tl.abs(xf)

    # Branch for |x| <= 3.75: polynomial in (|x| / 3.75)**2, Horner form.
    t_small = ax / 3.75
    y_small = t_small * t_small
    poly_small = 0.0360768 + y_small * 0.0045813
    poly_small = 0.2659732 + y_small * poly_small
    poly_small = 1.2067492 + y_small * poly_small
    poly_small = 3.0899424 + y_small * poly_small
    poly_small = 3.5156229 + y_small * poly_small
    poly_small = 1.0 + y_small * poly_small

    # Branch for |x| > 3.75: exp(|x|) / sqrt(|x|) times a polynomial in
    # 3.75 / |x|. It is evaluated for every lane, including lanes with
    # |x| == 0 where the division produces non-finite values; those lanes
    # are discarded by the tl.where below.
    y_large = 3.75 / ax
    poly_large = -0.01647633 + y_large * 0.00392377
    poly_large = 0.02635537 + y_large * poly_large
    poly_large = -0.02057706 + y_large * poly_large
    poly_large = 0.00916281 + y_large * poly_large
    poly_large = -0.00157565 + y_large * poly_large
    poly_large = 0.00225319 + y_large * poly_large
    poly_large = 0.01328592 + y_large * poly_large
    poly_large = 0.39894228 + y_large * poly_large
    val_large = tl.exp(ax) * poly_large / tl.sqrt(ax)

    result = tl.where(ax <= 3.75, poly_small, val_large)

    # Cast back to the buffer's element type and write in place.
    result_cast = tl.cast(result, x.dtype)
    tl.store(x_ptr + offsets, result_cast, mask=mask)
def i0_(*args, **kwargs):
    """Apply the order-zero modified Bessel function to a tensor in place.

    The tensor is taken from the first positional argument when any
    positional arguments are given; otherwise from the first of the keywords
    ``input``, ``self`` or ``x`` that is present. Any further arguments are
    ignored.

    Returns:
        The same tensor object that was passed in, after ``i0_kernel_`` has
        been launched over its elements (or untouched if it is empty).

    Raises:
        ValueError: If no tensor was supplied, or the supplied value is None.
        AssertionError: If the tensor is not on the ``flag_gems.device``
            device type, is not contiguous, or has a dtype other than
            float16, bfloat16, float32 or float64.
    """
    logger.debug("GEMS I0_")

    # Positional input wins; otherwise fall back to the usual keyword names.
    if args:
        x = args[0]
    else:
        x = next(
            (kwargs[name] for name in ("input", "self", "x") if name in kwargs),
            None,
        )
    if x is None:
        raise ValueError(
            "i0_ expects a tensor as the first positional argument or in keyword 'input'/'self'/'x'."
        )

    # Validation order is significant: it decides which error is reported.
    if x.device.type != flag_gems.device:
        raise AssertionError(f"Input tensor must be on a {flag_gems.device} device.")
    if not x.is_contiguous():
        raise AssertionError("Input tensor must be contiguous.")
    supported_dtypes = (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    if x.dtype not in supported_dtypes:
        raise AssertionError(
            "Unsupported dtype for i0_. Supported: float16, bfloat16, float32, float64."
        )

    total = x.numel()
    if total == 0:
        # Nothing to compute; skip the kernel launch.
        return x

    # One program instance per block of 1024 elements.
    block_size = 1024
    launch_grid = (triton.cdiv(total, block_size),)
    i0_kernel_[launch_grid](x, total, BLOCK_SIZE=block_size)
    return x