Coverage for src/flag_gems/ops/poisson.py: 43%
74 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry, libtuner
9from flag_gems.utils.random_utils import (
10 philox_backend_seed_offset,
11 uint_to_uniform_float,
12)
13from flag_gems.utils.shape_utils import volume
# Module-level logger named after this module (used for debug tracing of op dispatch).
logger = logging.getLogger(name=__name__)
@triton.jit
def _knuth_step(p, k, L, r):
    """
    Apply one multiplication step of Knuth's Poisson method.

    Converts the raw Philox word ``r`` to a uniform float, multiplies it into
    the running product ``p`` and increments ``k`` on every lane whose product
    is still above ``L`` (= exp(-lambda)).

    Returns:
        (p, k): the updated running product and counter.
    """
    # Keep u strictly positive so the running product never becomes exactly 0.
    # Uniforms are presumably in [0, 1) (see uint_to_uniform_float — confirm),
    # so the product never grows and a lane's k freezes once p <= L.
    u = tl.maximum(uint_to_uniform_float(r), 1e-10)
    p = p * u
    k = tl.where(p > L, k + 1, k)
    return p, k


@triton.jit
def poisson_small_lambda(lam, seed, c0, c1, z, MAX_ITERS: tl.constexpr):
    """
    Knuth's multiplication method for Poisson sampling with small lambda.

    Multiplies up to ``MAX_ITERS`` uniforms into a running product and returns,
    as float32, how many partial products stayed above exp(-lam). The result is
    therefore capped at ``MAX_ITERS``.

    Every ``tl.philox`` call yields four 32-bit words and all four are used, so
    this function consumes exactly ``MAX_ITERS // 4`` consecutive counters
    starting at ``c0`` (high word ``c1``). ``MAX_ITERS`` must be a multiple of 4.
    """
    tl.static_assert(MAX_ITERS % 4 == 0, "MAX_ITERS must be a multiple of 4")

    L = tl.exp(-lam)
    k = (lam * 0).to(tl.int32)  # per-lane counter, starts at 0
    p = lam * 0.0 + 1.0  # per-lane running product, starts at 1.0

    for _ in range(MAX_ITERS // 4):
        r0, r1, r2, r3 = tl.philox(seed, c0, c1, z, z)
        p, k = _knuth_step(p, k, L, r0)
        p, k = _knuth_step(p, k, L, r1)
        p, k = _knuth_step(p, k, L, r2)
        p, k = _knuth_step(p, k, L, r3)
        # Advance to the next counter of this element's block.
        c0 = c0 + 1

    return k.to(tl.float32)


@triton.jit
def poisson_large_lambda(lam, seed, c0, c1, z):
    """
    Normal approximation for Poisson sampling with large lambda.

    Approximates Poisson(lam) by N(lam, lam). A standard normal is drawn with
    the Box-Muller transform from the first two words of a single Philox call
    at counter ``c0``; it is scaled by sqrt(lam), shifted by lam, clamped at 0
    and rounded to the nearest integer.
    """
    r0, r1, r2, r3 = tl.philox(seed, c0, c1, z, z)
    u1 = uint_to_uniform_float(r0)
    u2 = uint_to_uniform_float(r1)

    # Avoid log(0) in the Box-Muller radius.
    u1 = tl.maximum(u1, 1e-10)

    two_pi = 6.283185307179586
    radius = tl.sqrt(-2.0 * tl.log(u1))
    theta = two_pi * u2
    normal_sample = radius * tl.cos(theta)

    # mean = lam, std = sqrt(lam)
    result = lam + tl.sqrt(lam) * normal_sample

    # A Poisson sample is a non-negative integer: clamp, then round to nearest.
    result = tl.maximum(result, 0.0)
    result = tl.floor(result + 0.5)
    return result


@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK": 1024}, num_warps=8, num_stages=3),
    ],
    key=["N"],
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def poisson_kernel(
    inp_ptr,
    out_ptr,
    N,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
    LAMBDA_THRESHOLD: tl.constexpr,
    MAX_ITERS: tl.constexpr,
):
    """
    Poisson sampling kernel; each program handles ``BLOCK`` consecutive elements.

    For each (clamped to >= 0) input rate lambda:
      - lambda < LAMBDA_THRESHOLD: Knuth's method (result capped at MAX_ITERS)
      - otherwise: rounded normal approximation
    Both paths are evaluated on every lane and one is selected with tl.where.

    Random-number layout: element ``i`` owns the Philox counters
    ``c0_base + i * (MAX_ITERS // 4)`` .. ``c0_base + (i + 1) * (MAX_ITERS // 4) - 1``
    (low word; high word fixed to ``c1``). A launch therefore consumes
    ``N * MAX_ITERS // 4`` counters, matching the host-side reservation of
    ``cdiv(N * MAX_ITERS, 4)`` so consecutive calls never reuse a stream.
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    c0_base = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N

    lam = tl.load(inp_ptr + offs, mask=mask, other=0.0).to(tl.float32)
    # Negative rates are clamped to 0.
    lam = tl.maximum(lam, 0.0)
    use_small = lam < LAMBDA_THRESHOLD

    # First counter of this element's private block of MAX_ITERS // 4 counters.
    # NOTE(review): this is 32-bit arithmetic with no carry into c1, so for very
    # large N (c0_base + N * MAX_ITERS // 4 >= 2**32) counters wrap and streams
    # would repeat.
    c0 = c0_base + offs.to(tl.uint32) * (MAX_ITERS // 4)
    z = c0 * 0

    small_result = poisson_small_lambda(lam, philox_seed, c0, c1, z, MAX_ITERS)
    # The large path reuses the element's first counter. This is safe because
    # each element keeps only one of the two results.
    large_result = poisson_large_lambda(lam, philox_seed, c0, c1, z)

    result = tl.where(use_small, small_result, large_result)
    tl.store(out_ptr + offs, result, mask=mask)
def poisson(input, generator=None):
    """
    Sample one Poisson-distributed value per element of ``input``.

    Each element of ``input`` is the rate (lambda) of its own Poisson
    distribution.

    Args:
        input (Tensor): Poisson rates; dtype must be float16, bfloat16, float32
            or float64.
        generator (torch.Generator, optional): forwarded to
            ``philox_backend_seed_offset`` to obtain the Philox seed/offset.

    Returns:
        Tensor: a new tensor with the same shape and dtype as ``input`` holding
        the samples.
    """
    logger.debug("GEMS POISSON")

    allowed_dtypes = (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    assert input.dtype in allowed_dtypes, f"Unsupported dtype: {input.dtype}"

    rates = input.contiguous()
    numel = volume(rates.shape)
    samples = torch.empty_like(rates)
    if numel == 0:
        return samples

    # Rates below this use Knuth's method; the rest use a normal approximation.
    lambda_threshold = 30
    # Upper bound on uniforms drawn (and hence on the sample) in Knuth's method.
    max_iters = 64

    def grid(meta):
        # One program per BLOCK elements; BLOCK is picked by the autotuner.
        return (triton.cdiv(numel, meta["BLOCK"]),)

    # Reserve Philox counters for this launch.
    # NOTE(review): cdiv(N * MAX_ITERS, 4) assumes the kernel uses all four
    # words of every Philox call (MAX_ITERS uniforms per element); keep this in
    # sync with poisson_kernel's counter layout.
    philox_seed, philox_offset = philox_backend_seed_offset(
        triton.cdiv(numel * max_iters, 4), generator=generator
    )

    with torch_device_fn.device(rates.device):
        poisson_kernel[grid](
            rates,
            samples,
            numel,
            philox_seed,
            philox_offset,
            LAMBDA_THRESHOLD=lambda_threshold,
            MAX_ITERS=max_iters,
        )

    return samples