Coverage for src/flag_gems/ops/i0.py: 52%
63 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8import flag_gems
10logger = logging.getLogger(__name__)
@triton.jit
def i0_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise kernel writing a polynomial approximation of I0(x).

    Each program instance handles one ``BLOCK_SIZE``-wide slice of the flat
    buffers; lanes at or beyond ``n_elements`` are masked out of both the
    load and the store. Loaded values are converted to float32 before any
    arithmetic.

    Two fits are evaluated for every lane and the applicable one is selected
    afterwards: a polynomial in ``(x / 3.75) ** 2`` when ``|x| <= 3.75``, and
    ``exp(|x|) / sqrt(|x|)`` times a polynomial in ``3.75 / |x|`` otherwise.
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    raw = tl.load(x_ptr + idx, mask=in_bounds, other=0)
    xf = raw.to(tl.float32)
    mag = tl.abs(xf)

    # Small-argument fit (|x| <= 3.75): Horner recurrence in u = (x / 3.75)^2,
    # innermost coefficient first.
    s = xf / 3.75
    u = s * s
    small = 0.0360768 + u * 0.0045813
    small = 0.2659732 + u * small
    small = 1.2067492 + u * small
    small = 3.0899424 + u * small
    small = 3.5156229 + u * small
    small = 1.0 + u * small

    # Large-argument fit (|x| > 3.75): Horner recurrence in v = 3.75 / |x|.
    v = 3.75 / mag
    large = -0.01647633 + v * 0.00392377
    large = 0.02635537 + v * large
    large = -0.02057706 + v * large
    large = 0.00916281 + v * large
    large = -0.00157565 + v * large
    large = 0.00225319 + v * large
    large = 0.01328592 + v * large
    large = 0.39894228 + v * large
    # For mag == 0 this expression is not finite, but those lanes always take
    # the small-argument value in the select below.
    large = tl.exp(mag) * large / tl.sqrt(mag)

    result = tl.where(mag <= 3.75, small, large)

    # The store converts the float32 result to out_ptr's element type.
    tl.store(out_ptr + idx, result, mask=in_bounds)
def _launch_i0(out: torch.Tensor, x: torch.Tensor):
    """Run ``i0_kernel`` over ``x`` and write the result into ``out``.

    Args:
        out: Destination tensor. Must live on the ``flag_gems.device`` backend,
            on the same device as ``x``, and hold as many elements as ``x``.
            It may be non-contiguous; the kernel then writes into a contiguous
            scratch buffer that is copied back into ``out``.
        x: Input tensor. Non-floating-point inputs are first converted to
            ``torch.get_default_dtype()``; floating-point inputs are handed to
            the kernel in their own dtype.

    Returns:
        ``out``, filled with the kernel's result.

    Raises:
        ValueError: If either tensor is not on the ``flag_gems.device``
            backend, the element counts differ, or the devices differ.
    """
    if x.device.type != flag_gems.device or out.device.type != flag_gems.device:
        raise ValueError(f"Input and output must be {flag_gems.device} tensors")
    # Explicit raises instead of asserts so the checks survive `python -O`.
    if out.numel() != x.numel():
        raise ValueError("Input and output must have the same number of elements")
    if out.device != x.device:
        raise ValueError("Input and output must be on the same device")

    n_elements = out.numel()
    if n_elements == 0:
        # Nothing to compute; skip the launch entirely.
        return out

    # The kernel converts whatever it loads to float32 and the store converts
    # to out's dtype, so a floating-point input is passed through unchanged.
    # Casting it to out.dtype up front would round the argument (e.g. float32
    # -> float16) before the function is evaluated.
    x_src = x if x.is_floating_point() else x.to(torch.get_default_dtype())
    x_contig = x_src.contiguous()

    # The kernel addresses both buffers linearly. For a non-contiguous `out`,
    # write into fresh contiguous scratch space; its previous contents are
    # irrelevant, so there is no need to copy them first.
    if out.is_contiguous():
        out_buf = out
    else:
        out_buf = torch.empty_like(out, memory_format=torch.contiguous_format)

    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    i0_kernel[grid](x_contig, out_buf, n_elements, BLOCK_SIZE=BLOCK_SIZE)

    if out_buf is not out:
        out.copy_(out_buf)
    return out
def i0(x: torch.Tensor) -> torch.Tensor:
    """Return a new tensor holding I0 evaluated elementwise on ``x``.

    Args:
        x: Input tensor on the ``flag_gems.device`` backend.

    Returns:
        A freshly allocated tensor. Its dtype is ``x.dtype`` when ``x`` is
        floating point and ``torch.get_default_dtype()`` otherwise.

    Raises:
        ValueError: If ``x`` is not on the ``flag_gems.device`` backend.
    """
    logger.debug("GEMS I0")
    if x.device.type != flag_gems.device:
        raise ValueError(f"i0: input tensor must be on {flag_gems.device} device")
    # Result dtype: keep floating inputs as-is, promote everything else to the
    # default floating dtype.
    out_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    # empty_like only needs x's shape, layout and device, so there is no need
    # to materialise a dtype-converted copy of x just to size the output.
    out = torch.empty_like(x, dtype=out_dtype)
    _launch_i0(out, x)
    return out
def i0_out(x: torch.Tensor, out: torch.Tensor):
    """Evaluate I0 elementwise on ``x``, writing into the caller's ``out``.

    Args:
        x: Input tensor on the ``flag_gems.device`` backend.
        out: Floating-point destination tensor on the same backend with the
            same number of elements as ``x``.

    Returns:
        ``out``.

    Raises:
        ValueError: If either tensor is on another backend, or the element
            counts differ.
        TypeError: If ``out`` is not a floating-point tensor.
    """
    logger.debug("GEMS I0_OUT")

    # Backend check comes first, input before output.
    if any(t.device.type != flag_gems.device for t in (x, out)):
        raise ValueError(
            f"i0_out: input and output tensors must be on {flag_gems.device} device"
        )

    # The result is real-valued, so an integer destination is rejected.
    out_is_float = out.is_floating_point()
    if not out_is_float:
        raise TypeError("i0_out: output tensor must be a floating point type")

    same_size = x.numel() == out.numel()
    if not same_size:
        raise ValueError(
            "i0_out: input and output must have the same number of elements"
        )

    _launch_i0(out, x)
    return out