Coverage for src/flag_gems/ops/exponential_.py: 32%
202 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import device, torch_device_fn
8from flag_gems.utils import libentry, libtuner
9from flag_gems.utils.random_utils import (
10 philox_backend_seed_offset,
11 uint_to_uniform_float,
12)
# Module-scoped logger; named after this module so it can be filtered per-op.
logger = logging.getLogger(name=__name__)
@triton.jit
def safe_fast_log_f32(x):
    """Approximate natural log for float32 values, clamped to a finite range.

    ``x`` is first clamped to [1.17549435e-38, 0.99999994] (the smallest normal
    float32 and 1 - 2**-24), so the result is always finite.  The clamped value
    is split as ``2**e * m`` by reinterpreting its bits.  ``m`` is then folded
    into [sqrt(1/2), sqrt(2)], and ``log(m)`` is evaluated as ``2 * atanh(s)``
    with ``s = (m - 1) / (m + 1)``.  After folding |s| <= 0.1716, so five terms
    of the atanh series leave a truncation error far below float32 rounding.
    """
    min_normal = (x * 0.0 + 1.17549435e-38).to(tl.float32)
    max_u = x * 0.0 + 0.99999994
    x = tl.minimum(tl.maximum(x, min_normal), max_u)
    bits = x.to(tl.int32, bitcast=True)
    exponent = (bits >> 23) - 127
    # Exact reconstruction of the mantissa in [1, 2).
    mantissa = (bits & 0x7FFFFF).to(tl.float32) * (1.0 / 8388608.0) + 1.0
    # Why fold: a truncated Taylor series of log1p(m - 1) over the whole [1, 2)
    # range is off by ~0.1 near m = 2, which visibly biases the samples.
    fold = mantissa > 1.4142135623730951
    mantissa = tl.where(fold, mantissa * 0.5, mantissa)
    exponent = tl.where(fold, exponent + 1, exponent)
    s = (mantissa - 1.0) / (mantissa + 1.0)
    s2 = s * s
    # log(m) = 2 * atanh(s) = 2s * (1 + s^2/3 + s^4/5 + s^6/7 + s^8/9 + ...)
    poly = 1.0 + s2 * (
        0.3333333333333333
        + s2 * (0.2 + s2 * (0.14285714285714285 + s2 * 0.1111111111111111))
    )
    return 2.0 * s * poly + exponent.to(tl.float32) * 0.6931471805599453
@triton.jit
def safe_fast_log_f64(x):
    """Approximate natural log for float64 values, clamped to a finite range.

    ``x`` is clamped to [2.2250738585072014e-308, 1 - 2**-52] (the lower bound
    is the smallest normal float64), so the result is always finite.  The
    clamped value is split as ``2**e * m`` by reinterpreting its bits.  ``m`` is
    then folded into [sqrt(1/2), sqrt(2)], and ``log(m)`` is evaluated as
    ``2 * atanh(s)`` with ``s = (m - 1) / (m + 1)``.  After folding
    |s| <= 0.1716, and the series through ``s**21`` has a truncation error
    (~2e-19) well below float64 rounding.
    """
    min_normal = x * 0.0 + 2.2250738585072014e-308
    max_u = x * 0.0 + (1.0 - 2.220446049250313e-16)
    x = tl.minimum(tl.maximum(x, min_normal), max_u)
    bits = x.to(tl.int64, bitcast=True)
    exponent = (bits >> 52) - 1023
    # Exact reconstruction of the mantissa in [1, 2).
    mantissa = (bits & 0x000FFFFFFFFFFFFF).to(tl.float64) * (
        1.0 / 4503599627370496.0
    ) + 1.0
    # Why fold: a 4-term Taylor series of log1p(m - 1) over [1, 2) is off by
    # ~0.1 near m = 2, which is nowhere near float64 accuracy.
    fold = mantissa > 1.4142135623730951
    mantissa = tl.where(fold, mantissa * 0.5, mantissa)
    exponent = tl.where(fold, exponent + 1, exponent)
    s = (mantissa - 1.0) / (mantissa + 1.0)
    s2 = s * s
    # log(m) = 2 * atanh(s) = 2s * sum_{k>=0} s^(2k) / (2k + 1), k = 0..10.
    poly = 1.0 + s2 * (
        0.3333333333333333
        + s2
        * (
            0.2
            + s2
            * (
                0.14285714285714285
                + s2
                * (
                    0.1111111111111111
                    + s2
                    * (
                        0.09090909090909091
                        + s2
                        * (
                            0.07692307692307693
                            + s2
                            * (
                                0.06666666666666667
                                + s2
                                * (
                                    0.058823529411764705
                                    + s2
                                    * (
                                        0.05263157894736842
                                        + s2 * 0.047619047619047616
                                    )
                                )
                            )
                        )
                    )
                )
            )
        )
    )
    return 2.0 * s * poly + exponent.to(tl.float64) * 0.6931471805599453
@triton.jit
def paste_u64(hi: tl.uint32, lo: tl.uint32):
    """Concatenate two 32-bit words into one uint64.

    ``hi`` occupies bits 63..32 and ``lo`` bits 31..0 of the result.
    """
    high_word = hi.to(tl.uint64) << 32
    low_word = lo.to(tl.uint64)
    return high_word | low_word
@triton.jit
def transform_exponential_f32_precise(u, inv_lambd, eps_minus):
    """Map a uniform draw ``u`` to an exponential sample ``-log(u) * inv_lambd``.

    Uses ``tl.math.log``.  Any ``u >= 1 + eps_minus`` takes ``eps_minus`` as its
    log instead.  The launcher in this file passes ``eps_minus = -eps/2`` of the
    output dtype, which keeps those samples away from exactly zero.
    """
    near_one = u >= 1.0 + eps_minus
    log_u = tl.where(near_one, eps_minus, tl.math.log(u))
    return -inv_lambd * log_u
@triton.jit
def transform_exponential_f32_fast(u, inv_lambd, eps_minus):
    """Map a uniform draw ``u`` to an exponential sample ``-log(u) * inv_lambd``.

    Same contract as ``transform_exponential_f32_precise``, except that it uses
    the clamped bit-manipulation approximation ``safe_fast_log_f32`` instead of
    ``tl.math.log``.  Any ``u >= 1 + eps_minus`` takes ``eps_minus`` as its log.
    """
    approx_log = safe_fast_log_f32(u)
    log_u = tl.where(u >= 1.0 + eps_minus, eps_minus, approx_log)
    return -inv_lambd * log_u
# Backend-specific choice of the float32 transform: Iluvatar uses the
# tl.math.log-based variant, every other backend uses the fast approximation.
# NOTE(review): the reason Iluvatar is special-cased is not visible here.
transform_exponential_f32 = (
    transform_exponential_f32_precise
    if device.vendor_name == "iluvatar"
    else transform_exponential_f32_fast
)
@triton.jit
def transform_exponential_f64(u, inv_lambd, eps_minus):
    """float64 counterpart of the float32 transforms, built on ``safe_fast_log_f64``.

    Returns ``-inv_lambd * log(u)``.  Any ``u >= 1 + eps_minus`` takes
    ``eps_minus`` as its log.
    """
    approx_log = safe_fast_log_f64(u)
    log_u = tl.where(u >= 1.0 + eps_minus, eps_minus, approx_log)
    return -inv_lambd * log_u
@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK": 512}, num_warps=8, num_stages=3),
    ],
    key=["N"],
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel_f32_unroll8(
    out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr
):
    """Write 8 * BLOCK exponential samples per program, computed in float32.

    Each lane makes two philox calls of four uint32 words each, giving eight
    samples per lane.  The samples are stored as eight contiguous BLOCK-wide
    slabs starting at ``pid * 8 * BLOCK``, masked by ``< N``.

    Program ``pid`` consumes the philox counters
    ``[c0 + 2*BLOCK*pid, c0 + 2*BLOCK*(pid + 1))``, so a launch uses about
    ``N / 4`` counters.  Callers should advance the philox offset by that many.
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Low / high 32-bit words of the 64-bit philox offset.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    pid = tl.program_id(0)
    lane = tl.arange(0, BLOCK)

    # Counters must be disjoint across programs. Spacing them as
    # 4 * (pid * BLOCK + lane) with the second call at +4 * BLOCK made program
    # pid's second call reuse program pid + 1's first counters, which
    # duplicated half of the output.
    c0_first = c0 + pid * (2 * BLOCK) + lane
    c0_second = c0_first + BLOCK
    z = c0_first * 0

    r0_0, r1_0, r2_0, r3_0 = tl.philox(philox_seed, c0_first, c1, z, z)
    r0_1, r1_1, r2_1, r3_1 = tl.philox(philox_seed, c0_second, c1, z, z)

    y0_0 = transform_exponential_f32(uint_to_uniform_float(r0_0), inv_lambd, eps_minus)
    y1_0 = transform_exponential_f32(uint_to_uniform_float(r1_0), inv_lambd, eps_minus)
    y2_0 = transform_exponential_f32(uint_to_uniform_float(r2_0), inv_lambd, eps_minus)
    y3_0 = transform_exponential_f32(uint_to_uniform_float(r3_0), inv_lambd, eps_minus)

    y0_1 = transform_exponential_f32(uint_to_uniform_float(r0_1), inv_lambd, eps_minus)
    y1_1 = transform_exponential_f32(uint_to_uniform_float(r1_1), inv_lambd, eps_minus)
    y2_1 = transform_exponential_f32(uint_to_uniform_float(r2_1), inv_lambd, eps_minus)
    y3_1 = transform_exponential_f32(uint_to_uniform_float(r3_1), inv_lambd, eps_minus)

    # 64-bit output offsets so very large N does not overflow int32.
    base_off = pid.to(tl.uint64) * BLOCK * 8
    off0 = base_off + lane
    off1 = off0 + BLOCK
    off2 = off1 + BLOCK
    off3 = off2 + BLOCK
    off4 = off3 + BLOCK
    off5 = off4 + BLOCK
    off6 = off5 + BLOCK
    off7 = off6 + BLOCK

    tl.store(out_ptr + off0, y0_0, mask=off0 < N)
    tl.store(out_ptr + off1, y1_0, mask=off1 < N)
    tl.store(out_ptr + off2, y2_0, mask=off2 < N)
    tl.store(out_ptr + off3, y3_0, mask=off3 < N)
    tl.store(out_ptr + off4, y0_1, mask=off4 < N)
    tl.store(out_ptr + off5, y1_1, mask=off5 < N)
    tl.store(out_ptr + off6, y2_1, mask=off6 < N)
    tl.store(out_ptr + off7, y3_1, mask=off7 < N)
@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK": 512}, num_warps=8, num_stages=3),
    ],
    key=["N"],
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel_f32(
    out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr
):
    """Write 4 * BLOCK exponential samples per program, computed in float32.

    Lane ``j`` of program ``pid`` uses philox counter ``c0 + pid * BLOCK + j``.
    Its four output words land in four contiguous BLOCK-wide slabs starting at
    ``pid * 4 * BLOCK``, and every store is masked by ``< N``.
    """
    seed = philox_seed.to(tl.int64)
    offset = philox_offset.to(tl.int64)
    # Split the 64-bit offset into the low / high 32-bit counter words.
    ctr_lo = (offset & 0xFFFFFFFF).to(tl.uint32)
    ctr_hi = ((offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    pid = tl.program_id(0)
    lane = tl.arange(0, BLOCK)
    counter = ctr_lo + (pid * BLOCK + lane)
    zeros = counter * 0
    w0, w1, w2, w3 = tl.philox(seed, counter, ctr_hi, zeros, zeros)

    s0 = transform_exponential_f32(uint_to_uniform_float(w0), inv_lambd, eps_minus)
    s1 = transform_exponential_f32(uint_to_uniform_float(w1), inv_lambd, eps_minus)
    s2 = transform_exponential_f32(uint_to_uniform_float(w2), inv_lambd, eps_minus)
    s3 = transform_exponential_f32(uint_to_uniform_float(w3), inv_lambd, eps_minus)

    # 64-bit slab offsets; each slab is BLOCK elements after the previous one.
    slab = pid.to(tl.uint64) * BLOCK * 4 + lane
    tl.store(out_ptr + slab, s0, mask=slab < N)
    slab = slab + BLOCK
    tl.store(out_ptr + slab, s1, mask=slab < N)
    slab = slab + BLOCK
    tl.store(out_ptr + slab, s2, mask=slab < N)
    slab = slab + BLOCK
    tl.store(out_ptr + slab, s3, mask=slab < N)
@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 128}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK": 512}, num_warps=8, num_stages=3),
    ],
    key=["N"],
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel_f32_small(
    out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr
):
    """Write 4 * BLOCK exponential samples per program (float32 math, small N).

    Lane ``j`` of program ``pid`` uses philox counter ``c0 + pid * BLOCK + j``.
    Its four output words land in four contiguous BLOCK-wide slabs starting at
    ``pid * 4 * BLOCK``, masked by ``< N``.  A launch uses about ``N / 4``
    counters.  Offsets are int32, which suits the small-N path this kernel is
    dispatched for.
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Low / high 32-bit words of the 64-bit philox offset.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    pid = tl.program_id(0)
    base_idx = pid * BLOCK * 4
    # The counter must include pid: with only c0 + arange(BLOCK), every program
    # drew the same numbers and the output repeated every 4 * BLOCK elements.
    c0_i = c0 + pid * BLOCK + tl.arange(0, BLOCK)
    z = c0_i * 0

    r0, r1, r2, r3 = tl.philox(philox_seed, c0_i, c1, z, z)

    y0 = transform_exponential_f32(uint_to_uniform_float(r0), inv_lambd, eps_minus)
    y1 = transform_exponential_f32(uint_to_uniform_float(r1), inv_lambd, eps_minus)
    y2 = transform_exponential_f32(uint_to_uniform_float(r2), inv_lambd, eps_minus)
    y3 = transform_exponential_f32(uint_to_uniform_float(r3), inv_lambd, eps_minus)

    off0 = base_idx + tl.arange(0, BLOCK)
    off1 = off0 + BLOCK
    off2 = off1 + BLOCK
    off3 = off2 + BLOCK

    tl.store(out_ptr + off0, y0, mask=off0 < N)
    tl.store(out_ptr + off1, y1, mask=off1 < N)
    tl.store(out_ptr + off2, y2, mask=off2 < N)
    tl.store(out_ptr + off3, y3, mask=off3 < N)
@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK": 512}, num_warps=8, num_stages=3),
    ],
    key=["N"],
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel_f64(
    out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr
):
    """Write 4 * BLOCK float64 exponential samples per program.

    Each lane makes two philox calls.  Within each call, words (r0, r2) and
    (r1, r3) are glued into uint64 values with ``paste_u64``, giving four
    samples per lane.  These are stored as four contiguous BLOCK-wide slabs
    starting at ``pid * 4 * BLOCK``, masked by ``< N``.

    Program ``pid`` consumes the counters
    ``[c0 + 2*BLOCK*pid, c0 + 2*BLOCK*(pid + 1))``, i.e. about ``N / 2`` per
    launch.
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Low / high 32-bit words of the 64-bit philox offset.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    pid = tl.program_id(0)
    block_offset = tl.arange(0, BLOCK)
    # The counter must include pid: with only c0 + arange(BLOCK), every program
    # drew the same numbers and the output repeated every 4 * BLOCK elements.
    c0_base = c0 + pid * (2 * BLOCK) + block_offset
    z = c0_base * 0

    r0_0, r1_0, r2_0, r3_0 = tl.philox(philox_seed, c0_base, c1, z, z)
    r0_1, r1_1, r2_1, r3_1 = tl.philox(philox_seed, c0_base + BLOCK, c1, z, z)

    u0_0 = uint_to_uniform_float(paste_u64(r0_0, r2_0))
    u1_0 = uint_to_uniform_float(paste_u64(r1_0, r3_0))
    u0_1 = uint_to_uniform_float(paste_u64(r0_1, r2_1))
    u1_1 = uint_to_uniform_float(paste_u64(r1_1, r3_1))

    y0_0 = transform_exponential_f64(u0_0, inv_lambd, eps_minus)
    y1_0 = transform_exponential_f64(u1_0, inv_lambd, eps_minus)
    y0_1 = transform_exponential_f64(u0_1, inv_lambd, eps_minus)
    y1_1 = transform_exponential_f64(u1_1, inv_lambd, eps_minus)

    # 64-bit output offsets: pid * BLOCK * 4 in int32 overflows once N >= 2**31.
    off0 = pid.to(tl.uint64) * BLOCK * 4 + block_offset
    off1 = off0 + BLOCK
    off2 = off1 + BLOCK
    off3 = off2 + BLOCK

    tl.store(out_ptr + off0, y0_0, mask=off0 < N)
    tl.store(out_ptr + off1, y1_0, mask=off1 < N)
    tl.store(out_ptr + off2, y0_1, mask=off2 < N)
    tl.store(out_ptr + off3, y1_1, mask=off3 < N)
def exponential_(x, lambd: float = 1.0, *, generator=None):
    """Fill ``x`` in place with exponential samples and return it.

    Follows the signature of ``torch.Tensor.exponential_``.  Each element
    becomes ``-log(u) / lambd`` for a philox-generated uniform ``u``.

    Args:
        x: Tensor to overwrite; float16, bfloat16, float32 or float64.
        lambd: Rate parameter; must be strictly positive.
        generator: Optional generator forwarded to
            ``philox_backend_seed_offset`` to obtain the philox seed/offset.

    Returns:
        ``x`` itself, after being filled.

    Raises:
        AssertionError: if ``x`` has an unsupported dtype.
        RuntimeError: if ``lambd`` is not strictly positive (or is NaN).
    """
    logger.debug("GEMS EXPONENTIAL_")

    dtype = x.dtype
    # Not named `device`: that would shadow the module-level runtime `device`.
    tensor_device = x.device
    assert dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    if not lambd > 0.0:
        raise RuntimeError(
            f"exponential_ expects lambda > 0.0, but found lambda={lambd}"
        )

    N = x.numel()
    if N == 0:
        # Nothing to fill: skip the kernel launch and the philox reservation.
        return x

    inv_lambd = 1.0 / lambd
    eps_minus = -0.5 * torch.finfo(dtype).eps

    # elems_per_lane: outputs written per BLOCK lane, so a program stores
    #   elems_per_lane * BLOCK elements (sizes the grid).
    # values_per_counter: outputs produced per philox counter value, since one
    #   philox call yields 4 uint32 words and float64 glues two words per value
    #   (sizes the offset reservation).
    if dtype is torch.float64:
        kernel, elems_per_lane, values_per_counter = (
            fused_exponential_kernel_f64, 4, 2
        )
    elif dtype in (torch.float16, torch.bfloat16) and N < 65536:
        kernel, elems_per_lane, values_per_counter = (
            fused_exponential_kernel_f32_small, 4, 4
        )
    else:
        kernel, elems_per_lane, values_per_counter = (
            fused_exponential_kernel_f32_unroll8, 8, 4
        )

    def grid(meta):
        return (triton.cdiv(N, meta["BLOCK"] * elems_per_lane),)

    # Reserve the number of philox counters the kernel consumes.
    # NOTE(review): assumes philox_backend_seed_offset advances the generator
    # by this increment so consecutive calls use disjoint counters.
    increment = triton.cdiv(N, values_per_counter)
    philox_seed, philox_offset = philox_backend_seed_offset(
        increment, generator=generator
    )

    # The kernels write a flat buffer of N elements.  Non-contiguous inputs go
    # through an explicitly contiguous scratch tensor and are copied back.
    inplace = x.is_contiguous()
    out = (
        x
        if inplace
        else torch.empty_like(x, memory_format=torch.contiguous_format)
    )
    with torch_device_fn.device(tensor_device):
        kernel[grid](out, N, inv_lambd, eps_minus, philox_seed, philox_offset)

    if not inplace:
        x.copy_(out)
    return x