Coverage for src/flag_gems/ops/bernoulli_.py: 47%
45 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems import runtime
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils.random_utils import (
9 philox_backend_seed_offset,
10 uint_to_uniform_float,
11)
12from flag_gems.utils.shape_utils import volume
# Module-scoped logger; bernoulli_ emits a debug line on every call.
logger = logging.getLogger(name=__name__)
@triton.heuristics(runtime.get_heuristic_config("uniform"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "p"])
def bernoulli_kernel(
    out_ptr,
    N,
    p,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    # Each program writes 4 * BLOCK Bernoulli(p) samples. One Philox call per
    # lane yields four uint32 values; value k goes to the k-th of four
    # consecutive BLOCK-sized slices of the output.
    pid = tl.program_id(0)

    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit offset into the low/high 32-bit Philox counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    # One distinct counter per lane. This arithmetic is deliberately left in
    # 32 bits: tl.philox selects its variant from the counter's bit width, so
    # widening it would change the random stream.
    i4 = pid * BLOCK + tl.arange(0, BLOCK)
    c0 += i4
    _O = c0 * 0
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)

    # Convert random uint32 to uniform float in [0, 1)
    u0 = uint_to_uniform_float(r0)
    u1 = uint_to_uniform_float(r1)
    u2 = uint_to_uniform_float(r2)
    u3 = uint_to_uniform_float(r3)

    # Bernoulli sampling: output 1.0 if random < p, else 0.0
    y0 = tl.where(u0 < p, 1.0, 0.0)
    y1 = tl.where(u1 < p, 1.0, 0.0)
    y2 = tl.where(u2 < p, 1.0, 0.0)
    y3 = tl.where(u3 < p, 1.0, 0.0)

    # Output offsets are computed in int64. In int32, pid * BLOCK * 4 overflows
    # once N exceeds 2**31, and a wrapped (negative) offset would still pass
    # the `< N` mask, so the store could land out of bounds.
    off_0 = pid.to(tl.int64) * BLOCK * 4 + tl.arange(0, BLOCK)
    off_1 = off_0 + BLOCK
    off_2 = off_1 + BLOCK
    off_3 = off_2 + BLOCK

    tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_2, y2, mask=off_2 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_3, y3, mask=off_3 < N, eviction_policy="evict_first")
# Random values produced per Philox counter: one kernel program covers
# BLOCK * UNROLL elements. bernoulli_kernel hard-codes this factor (4) in its
# output offsets, so the two must stay in sync.
UNROLL = 4


def bernoulli_(self, p=0.5, *, generator=None):
    """Fill ``self`` in place with Bernoulli(p) samples (1 with probability p, else 0).

    Args:
        self: Tensor to overwrite. The kernel addresses memory linearly from
            the data pointer. A tensor that is not contiguous is therefore
            sampled into a contiguous scratch buffer and copied back, so
            memory outside the view is never touched.
        p: Success probability. It is coerced with ``float()`` and must lie in
            [0, 1], mirroring the check performed by torch's ``bernoulli_``.
        generator: Optional generator, forwarded to
            ``philox_backend_seed_offset``.

    Returns:
        ``self``, modified in place.

    Raises:
        RuntimeError: if ``p`` is outside [0, 1] or is NaN.
    """
    logger.debug("GEMS BERNOULLI_")
    p = float(p)
    # Written as a positive range test so that NaN is rejected as well.
    if not 0.0 <= p <= 1.0:
        raise RuntimeError(f"bernoulli_ expects p to be in [0, 1], but got p={p}")

    N = volume(self.shape)
    # Writing element i at data_ptr + i is only valid for contiguous storage.
    # For strided views (e.g. x[:, ::2]) it would clobber the skipped elements.
    out = self if self.is_contiguous() else self.new_empty(self.shape)

    grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
    # Each Philox counter yields UNROLL values, so advance the RNG offset by
    # the number of counters this launch consumes.
    increment = triton.cdiv(N, UNROLL)
    philox_seed, philox_offset = philox_backend_seed_offset(
        increment, generator=generator
    )
    with torch_device_fn.device(self.device):
        bernoulli_kernel[grid_fn](out, N, p, philox_seed, philox_offset)

    if out is not self:
        self.copy_(out)
    return self