Coverage for src/flag_gems/runtime/backend/_arm/ops/exponential_.py: 0%
76 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.utils.random_utils import (
9 philox_backend_seed_offset,
10 uint_to_uniform_float,
11)
@triton.heuristics(runtime.get_heuristic_config("exponential_"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel(
    out_ptr,
    N,
    is_double,
    lambd,
    eps,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    """Write exponentially transformed Philox random values to ``out_ptr``.

    Each lane makes one ``tl.philox`` call, which yields four random words.
    When ``is_double`` is true the words are paired into two 64-bit values
    (two outputs per lane); otherwise each word gives one output (four
    outputs per lane). A program therefore stores ``BLOCK * UNROLL``
    consecutive elements, masked against ``N``.

    Args:
        out_ptr: Base pointer of the output buffer.
        N: Number of elements to write; stores at offsets ``>= N`` are masked.
        is_double: Selects the 64-bit (two outputs per lane) path.
        lambd: Divisor applied in ``transform_exponential``.
        eps: Epsilon forwarded to ``transform_exponential``.
        philox_seed: Seed passed to ``tl.philox`` (cast to int64 first).
        philox_offset: Offset whose low/high 32 bits seed the Philox counter.
        BLOCK: Compile-time lane count per program. The host call in this
            file does not pass it, so it is presumably supplied by the
            ``exponential_`` heuristics config -- TODO confirm.
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit offset into two 32-bit counter words (low, high).
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    # One counter value per lane: the lane's index across the whole launch.
    i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    # NOTE(review): only the low word is advanced; nothing here carries into c1.
    c0 += i4
    # Zero block shaped like c0, used for the two remaining counter words.
    _O = c0 * 0
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
    if is_double:
        # Pair the words into 64-bit values: r0/r1 are the high halves,
        # r2/r3 the low halves.
        d0 = uint_to_uniform_float(paste_u64(r0, r2))
        d1 = uint_to_uniform_float(paste_u64(r1, r3))
        y0 = transform_exponential(d0, lambd, eps)
        y1 = transform_exponential(d1, lambd, eps)
        UNROLL = 2
        # Store offsets are computed in uint64 (presumably to avoid 32-bit
        # overflow on large outputs -- TODO confirm).
        start = tl.program_id(0).to(tl.uint64) * BLOCK * UNROLL
        off_0 = start + tl.arange(0, BLOCK)
        off_1 = off_0 + BLOCK
        tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
    else:
        # Each 32-bit word becomes its own uniform sample.
        f0 = uint_to_uniform_float(r0)
        f1 = uint_to_uniform_float(r1)
        f2 = uint_to_uniform_float(r2)
        f3 = uint_to_uniform_float(r3)
        y0 = transform_exponential(f0, lambd, eps)
        y1 = transform_exponential(f1, lambd, eps)
        y2 = transform_exponential(f2, lambd, eps)
        y3 = transform_exponential(f3, lambd, eps)
        UNROLL = 4
        # The program's output tile is UNROLL chunks of BLOCK elements,
        # laid out back to back starting at `start`.
        start = tl.program_id(0).to(tl.uint64) * BLOCK * UNROLL
        off_0 = start + tl.arange(0, BLOCK)
        off_1 = off_0 + BLOCK
        off_2 = off_1 + BLOCK
        off_3 = off_2 + BLOCK
        tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_2, y2, mask=off_2 < N, eviction_policy="evict_first")
        tl.store(out_ptr + off_3, y3, mask=off_3 < N, eviction_policy="evict_first")
@triton.jit
def paste_u64(hi: tl.uint32, lo: tl.uint32):
    """Join two 32-bit words into one 64-bit word.

    ``hi`` is widened to uint64 and shifted into the upper 32 bits; ``lo`` is
    widened to uint64 and OR-ed into the lower 32 bits.
    """
    upper_half = hi.to(tl.uint64) << 32
    lower_half = lo.to(tl.uint64)
    return upper_half | lower_half
@triton.jit
def transform_exponential(u, lambd, eps):
    """Map a uniform sample ``u`` to ``-log(u) / lambd``.

    When ``u >= 1 - eps / 2`` the logarithm is replaced by ``-eps / 2``, so
    the result for such inputs is ``eps / (2 * lambd)`` instead of whatever
    ``log(u)`` would give that close to one.
    """
    neg_half_eps = -0.5 * eps
    near_one = u >= 1.0 + neg_half_eps
    clamped_log = tl.where(near_one, neg_half_eps, tl.math.log(u))
    return -1.0 / lambd * clamped_log
def exponential_(x, lambd: float = 1.0, *, gen=None):
    """Fill ``x`` in place with values ``-log(u) / lambd`` and return ``x``.

    The uniforms ``u`` are produced inside ``fused_exponential_kernel`` from
    a Philox seed/offset obtained via ``philox_backend_seed_offset``.

    Args:
        x: Tensor to overwrite. Must be float16, bfloat16, float32 or float64.
            A non-contiguous ``x`` is filled through a temporary buffer that
            is then copied back into it.
        lambd: Divisor applied to ``-log(u)`` by the kernel.
        gen: Accepted for signature compatibility.
            NOTE(review): this argument is not used anywhere in this function;
            the seed/offset always come from
            ``philox_backend_seed_offset(increment)``. Confirm whether that
            helper can take a generator before wiring it through.

    Returns:
        The same tensor object ``x``.

    Raises:
        AssertionError: If ``x`` has an unsupported dtype.
    """
    logging.debug("GEMS EXPONENTIAL_")
    dtype = x.dtype
    device = x.device
    inplace = x.is_contiguous()
    assert dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    N = x.numel()
    if N == 0:
        # Nothing to generate: avoid requesting RNG state and launching the
        # kernel with a zero-sized grid.
        return x
    is_double = dtype in (torch.float64,)
    # Outputs produced per kernel lane: 2 on the 64-bit path, 4 otherwise.
    UNROLL = 2 if is_double else 4

    def grid_fn(meta):
        # Each program writes BLOCK * UNROLL elements.
        return (triton.cdiv(N, meta["BLOCK"] * UNROLL),)

    # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
    # hence we cannot obtain the per thread offset as in Pytorch.
    increment = triton.cdiv(N, UNROLL)
    philox_seed, philox_offset = philox_backend_seed_offset(increment)
    eps = torch.finfo(dtype).eps
    # The kernel writes to consecutive offsets, so a non-contiguous tensor is
    # generated into a fresh buffer and copied back afterwards.
    x_ = x if inplace else torch.empty(x.size(), dtype=dtype, device=device)
    # with torch_device_fn.device(device):
    fused_exponential_kernel[grid_fn](
        x_, N, is_double, lambd, eps, philox_seed, philox_offset
    )
    if not inplace:
        x.copy_(x_)
    return x