Coverage for src/flag_gems/ops/cauchy.py: 53%
62 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils.random_utils import (
9 philox_backend_seed_offset,
10 uint_to_uniform_float,
11)
12from flag_gems.utils.shape_utils import volume
# Module-scoped logger; the samplers below emit a debug line per call.
logger = logging.getLogger(name=__name__)
# pi, wrapped in tl.constexpr so @triton.jit code (uniform_to_cauchy) can
# reference this module-level global. Same double as 3.14159265358979323846.
PI = tl.constexpr(3.141592653589793)
@triton.jit
def uniform_to_cauchy(u, median, sigma):
    # Map uniform samples u in [0, 1) to Cauchy(median, sigma) samples using
    # the inverse CDF:  X = median + sigma * tan(PI * (u - 0.5)).
    #
    # u is first clamped into [1e-7, 1 - 1e-7]: at u == 0 the angle would be
    # exactly -PI/2, where tan is undefined. As a consequence the extreme
    # tails are truncated (|X - median| stays on the order of 1e6 * sigma).
    u_clamped = tl.minimum(tl.maximum(u, 1.0e-7), 1.0 - 1.0e-7)
    theta = PI * (u_clamped - 0.5)
    # tan(theta) expressed as sin / cos.
    tan_theta = tl.sin(theta) / tl.cos(theta)
    return median + sigma * tan_theta
# Autotuning search space for cauchy_kernel, listed as
# (BLOCK, num_warps, num_stages). Tuning is done with @triton.autotune on the
# kernel rather than runtime.get_heuristic_config("cauchy").
configs = [
    triton.Config({"BLOCK": block}, num_warps=warps, num_stages=stages)
    for block, warps, stages in (
        (256, 8, 2),
        (512, 4, 2),
        (512, 8, 3),
        (1024, 4, 2),
        (1024, 8, 3),
        (1024, 8, 4),
    )
]
@triton.autotune(configs=configs, key=["N"])
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "median", "sigma"])
def cauchy_kernel(
    out_ptr,
    N,
    median,
    sigma,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    # Fill out_ptr[0:N] with Cauchy(median, sigma) samples.
    #
    # Each program draws one philox call per lane (BLOCK lanes), gets four
    # random words per lane, and stores 4 * BLOCK consecutive outputs. The
    # launcher must therefore use a grid of cdiv(N, 4 * BLOCK) programs.
    pid = tl.program_id(0)
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit philox offset into low / high 32-bit counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    # Give every lane its own counter by adding its global lane index.
    i4 = pid * BLOCK + tl.arange(0, BLOCK)
    c0 += i4
    _O = c0 * 0
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
    y0 = uniform_to_cauchy(uint_to_uniform_float(r0), median, sigma)
    y1 = uniform_to_cauchy(uint_to_uniform_float(r1), median, sigma)
    y2 = uniform_to_cauchy(uint_to_uniform_float(r2), median, sigma)
    y3 = uniform_to_cauchy(uint_to_uniform_float(r3), median, sigma)
    # Output offsets are computed in int64: with int32 arithmetic
    # pid * BLOCK * 4 wraps once N exceeds 2**31 elements.
    off_0 = pid.to(tl.int64) * (BLOCK * 4) + tl.arange(0, BLOCK)
    off_1 = off_0 + BLOCK
    off_2 = off_1 + BLOCK
    off_3 = off_2 + BLOCK

    # Outputs are written once and not re-read here, hence evict_first.
    tl.store(out_ptr + off_0, y0, mask=off_0 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_1, y1, mask=off_1 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_2, y2, mask=off_2 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_3, y3, mask=off_3 < N, eviction_policy="evict_first")
# Samples produced per kernel lane. This must equal the number of tl.store
# calls in cauchy_kernel (four), since the launch grid and the philox offset
# increment in cauchy_ are both derived from it.
UNROLL: int = 4
def cauchy_(self, median=0, sigma=1, *, generator=None):
    """
    In-place Cauchy distribution sampler.

    Fills self with elements drawn from the Cauchy distribution:
        f(x) = 1 / (π * sigma * (1 + ((x - median) / sigma)^2))

    Uses inverse transform sampling: X = median + sigma * tan(π * (U - 0.5))
    where U ~ Uniform(0, 1).

    Args:
        self: Floating-point tensor to overwrite. Non-contiguous tensors are
            supported: samples are drawn into a contiguous scratch tensor and
            copied back with ``copy_``.
        median: Location parameter of the distribution.
        sigma: Scale parameter; must be strictly positive.
        generator: Optional generator forwarded to
            ``philox_backend_seed_offset`` to derive the philox seed/offset.

    Returns:
        ``self``, filled in place.

    Raises:
        RuntimeError: If ``self`` is not a floating-point tensor, or if
            ``sigma`` is not > 0 (mirroring ``torch.Tensor.cauchy_``).
    """
    # Validate before logging / launching so bad arguments never reach the kernel.
    if not torch.is_floating_point(self):
        raise RuntimeError(
            "Cauchy distribution is a continuous probability distribution. "
            f"dtype must be a floating point but you specified {self.dtype}"
        )
    # Written as `not (sigma > 0.0)` so that a NaN sigma is rejected as well.
    if not (sigma > 0.0):
        raise RuntimeError(f"cauchy_ expects sigma > 0.0, but found sigma={sigma}")
    logger.debug("GEMS CAUCHY_")
    device = self.device
    N = volume(self.shape)
    if N == 0:
        return self
    # cauchy_kernel writes N consecutive elements starting at the data pointer,
    # which only matches the tensor's elements when it is contiguous. For a
    # strided view (e.g. x[:, ::2]) a direct write would clobber memory outside
    # the view and leave part of the view untouched.
    inplace = self.is_contiguous()
    out = (
        self
        if inplace
        else torch.empty(self.shape, dtype=self.dtype, device=device)
    )
    grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
    # Each kernel lane uses one philox counter value and yields UNROLL samples,
    # so the generator offset advances by cdiv(N, UNROLL).
    increment = triton.cdiv(N, UNROLL)
    philox_seed, philox_offset = philox_backend_seed_offset(
        increment, generator=generator
    )
    with torch_device_fn.device(device):
        # float() gives median/sigma a consistent floating-point scalar type
        # whether callers pass ints (as the defaults are) or floats.
        cauchy_kernel[grid_fn](
            out, N, float(median), float(sigma), philox_seed, philox_offset
        )
    if not inplace:
        self.copy_(out)
    return self
def cauchy(self, median=0, sigma=1, *, generator=None):
    """
    Out-of-place Cauchy distribution sampler.

    Allocates a fresh tensor via ``torch.empty_like(self)`` (same shape, dtype
    and device as ``self``), fills it through the in-place sampler ``cauchy_``
    with the given ``median``, ``sigma`` and ``generator``, and returns it.
    ``self`` itself is left untouched.
    """
    logger.debug("GEMS CAUCHY")
    result = torch.empty_like(self)
    cauchy_(result, median=median, sigma=sigma, generator=generator)
    return result