Coverage for src/flag_gems/ops/poisson.py: 43%

74 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry, libtuner 

9from flag_gems.utils.random_utils import ( 

10 philox_backend_seed_offset, 

11 uint_to_uniform_float, 

12) 

13from flag_gems.utils.shape_utils import volume 

14 

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

16 

17 

@triton.jit
def _knuth_step(p, k, L, r):
    # One step of Knuth's method: turn the 32-bit word `r` into a uniform,
    # multiply it into the running product `p`, and count the step while the
    # product is still above the threshold `L`. Returns the updated (p, k).
    u = uint_to_uniform_float(r)
    # Guard against an exact 0.0 uniform, as before.
    u = tl.maximum(u, 1e-10)
    p = p * u
    k = tl.where(p > L, k + 1, k)
    return p, k


@triton.jit
def poisson_small_lambda(lam, seed, c0, c1, z, MAX_ITERS: tl.constexpr):
    """
    Sample Poisson(lam) per lane with Knuth's multiplicative method (small lam).

    Uniforms u_1, u_2, ... are multiplied into a running product p, and the
    result is the number of factors for which p stayed above exp(-lam).

    Every lane performs the same fixed number of draws, MAX_ITERS rounded up to
    a multiple of 4. Because p only decreases, a lane whose product has
    dropped below exp(-lam) stops incrementing its count. Samples are therefore
    truncated at that number of draws.

    Each philox call yields four 32-bit words and all four are consumed, so
    only ceil(MAX_ITERS / 4) consecutive counters starting at ``c0`` are used.

    Args:
        lam: per-lane rate values.
        seed: philox seed.
        c0: low counter word per lane; advanced by one per philox call.
        c1: high counter word.
        z: zero block passed as the remaining philox counter words.
        MAX_ITERS: number of uniform draws per lane (rounded up to a multiple of 4).

    Returns:
        The per-lane counts as float32.
    """
    L = tl.exp(-lam)
    # Build the initial state from zeros rather than `lam * 0`, which is NaN
    # for infinite lam.
    k = tl.zeros_like(lam).to(tl.int32)
    p = tl.zeros_like(lam) + 1.0

    for _ in range((MAX_ITERS + 3) // 4):
        r0, r1, r2, r3 = tl.philox(seed, c0, c1, z, z)
        p, k = _knuth_step(p, k, L, r0)
        p, k = _knuth_step(p, k, L, r1)
        p, k = _knuth_step(p, k, L, r2)
        p, k = _knuth_step(p, k, L, r3)
        # Move to the next counter for the next four words.
        c0 = c0 + 1

    return k.to(tl.float32)

45 

46 

@triton.jit
def poisson_large_lambda(lam, seed, c0, c1, z):
    """
    Approximate a Poisson(lam) draw per lane by a rounded Gaussian N(lam, lam).

    A single philox call supplies two uniforms. The cosine branch of the
    Box-Muller transform turns them into a standard normal, which is then
    scaled by sqrt(lam) and shifted by lam. The result is clamped at 0 and
    rounded half-up to a whole number, which is returned as a float.
    """
    word_radius, word_angle, _spare0, _spare1 = tl.philox(seed, c0, c1, z, z)

    # Keep the radius uniform strictly positive so that log() stays finite.
    u_radius = tl.maximum(uint_to_uniform_float(word_radius), 1e-10)
    u_angle = uint_to_uniform_float(word_angle)

    radius = tl.sqrt(-2.0 * tl.log(u_radius))
    std_normal = radius * tl.cos(6.283185307179586 * u_angle)

    approx = tl.maximum(lam + tl.sqrt(lam) * std_normal, 0.0)
    return tl.floor(approx + 0.5)

76 

77 

@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 128}, num_warps=2, num_stages=2),
        triton.Config({"BLOCK": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 512}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK": 1024}, num_warps=8, num_stages=3),
    ],
    key=["N"],
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def poisson_kernel(
    inp_ptr,
    out_ptr,
    N,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
    LAMBDA_THRESHOLD: tl.constexpr,
    MAX_ITERS: tl.constexpr,
):
    """
    Poisson sampling kernel over a flat array of N rates.

    Each program handles BLOCK consecutive elements. Every element's rate is
    loaded as float32 and negative rates are clamped to 0. The sample then
    comes from one of two paths:
      - lambda < LAMBDA_THRESHOLD: Knuth's method (poisson_small_lambda);
      - otherwise: rounded normal approximation (poisson_large_lambda).
    Both paths are evaluated for every lane, and tl.where keeps the relevant
    result. The stored value is cast to out_ptr's element type.

    Random counters: element i owns the philox counter window
    [offset + i * MAX_ITERS, offset + (i + 1) * MAX_ITERS). Both paths draw
    only from the element's own window, so different elements never share
    random words.
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    c0_base = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N

    # Masked-out lanes read rate 0 and are never stored.
    lam = tl.load(inp_ptr + offs, mask=mask, other=0.0).to(tl.float32)
    lam = tl.maximum(lam, 0.0)

    use_small = lam < LAMBDA_THRESHOLD

    # First counter of this element's private window.
    # NOTE(review): this arithmetic is 32-bit and no carry is propagated into
    # c1, so windows wrap once c0_base + N * MAX_ITERS exceeds 2**32. Consider
    # carrying into c1 for very large N.
    c0_elem = c0_base + offs.to(tl.uint32) * MAX_ITERS
    z = c0_elem * 0

    small_result = poisson_small_lambda(lam, philox_seed, c0_elem, c1, z, MAX_ITERS)
    # The large path reuses the same window. Only one of the two results is
    # kept per element, so this is safe. Drawing from `c0_base + offs` instead
    # would land inside neighbouring elements' windows and correlate their
    # samples.
    large_result = poisson_large_lambda(lam, philox_seed, c0_elem, c1, z)

    result = tl.where(use_small, small_result, large_result)

    tl.store(out_ptr + offs, result, mask=mask)

139 

140 

def poisson(input, generator=None):
    """
    Returns a tensor of the same size as input with each element sampled
    from a Poisson distribution with rate parameter given by the corresponding
    element in input.

    Rates below 30 are sampled with Knuth's method, truncated after 64 uniform
    draws. Larger rates use a rounded normal approximation N(lambda, lambda).
    The kernel computes in float32, clamps negative rates to 0, and writes
    results in the input's dtype.

    Args:
        input (Tensor): the input tensor containing the rates of the Poisson distribution
        generator (torch.Generator, optional): a pseudorandom number generator for sampling

    Returns:
        Tensor: output tensor with Poisson samples (same shape and dtype as input)

    Raises:
        AssertionError: if input's dtype is not float16, bfloat16, float32 or float64.
    """
    logger.debug("GEMS POISSON")

    assert input.dtype in (
        torch.float16,
        torch.bfloat16,
        torch.float32,
        torch.float64,
    ), f"Unsupported dtype: {input.dtype}"

    # The kernel indexes the data as a flat array.
    inp = input.contiguous()
    N = volume(inp.shape)

    out = torch.empty_like(inp)
    if N == 0:
        return out

    LAMBDA_THRESHOLD = 30  # rates below this use Knuth's method
    MAX_ITERS = 64  # uniform draws per element for Knuth's method

    grid = lambda meta: (triton.cdiv(N, meta["BLOCK"]),)

    # poisson_kernel gives element i the philox counters
    # [offset + i * MAX_ITERS, offset + (i + 1) * MAX_ITERS), so one launch
    # spans N * MAX_ITERS counters and the generator must advance by all of
    # them. Reserving only N * MAX_ITERS / 4 would let the next call reuse
    # most of these counters and return shifted copies of these samples.
    # NOTE(review): this assumes philox_backend_seed_offset advances the offset
    # by `increment` in the same counter units the kernel adds to it.
    increment = N * MAX_ITERS
    philox_seed, philox_offset = philox_backend_seed_offset(
        increment, generator=generator
    )

    with torch_device_fn.device(inp.device):
        poisson_kernel[grid](
            inp,
            out,
            N,
            philox_seed,
            philox_offset,
            LAMBDA_THRESHOLD=LAMBDA_THRESHOLD,
            MAX_ITERS=MAX_ITERS,
        )

    return out