Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/randn.py: 0%
78 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import device, get_heuristic_config, torch_device_fn
8from flag_gems.utils import libentry
9from flag_gems.utils.random_utils import (
10 philox_backend_seed_offset,
11 uint_to_uniform_float,
12)
13from flag_gems.utils.shape_utils import volume
@triton.jit
def high_precision_fast_sin_cos(x):
    """Return ``(sin(x), cos(x))`` using range reduction plus polynomials.

    ``x`` is first wrapped into [-pi, pi) and then folded into [-pi/2, pi/2]
    with the identities

        sin(pi - x)  = sin(x),   cos(pi - x)  = -cos(x)    (x >  pi/2)
        sin(-pi - x) = sin(x),   cos(-pi - x) = -cos(x)    (x < -pi/2)

    The coefficients below are (to float precision) the Taylor series of sin
    through x**11 and of cos through x**10. Evaluated directly on [-pi, pi]
    their truncation error grows to roughly 5e-4 (sin) and 2e-3 (cos) near
    |x| = pi; on the folded interval [-pi/2, pi/2] it is roughly 6e-8 (sin)
    and 5e-7 (cos), which is why the fold is performed.
    """
    pi = 3.141592653589793
    half_pi = 1.5707963267948966
    two_pi = 6.283185307179586

    # Wrap to [-pi, pi).
    x = x - two_pi * tl.floor(x / two_pi + 0.5)
    # Fold the two outer quadrants onto [-pi/2, pi/2]; sin is unchanged by the
    # fold, cos flips sign (applied after evaluation below).
    outer = tl.abs(x) > half_pi
    x = tl.where(outer, tl.where(x > 0, pi - x, -pi - x), x)
    x2 = x * x

    # --- SIN: x * P(x^2), Taylor coefficients through x^11 ---
    s_c0 = 0.99999999999999999999
    s_c1 = -0.16666666666666666654
    s_c2 = 0.00833333333333332876
    s_c3 = -0.00019841269841269616
    s_c4 = 2.755731922398589e-6
    s_c5 = -2.505210838544172e-8

    sin_x = x * (
        s_c0 + x2 * (s_c1 + x2 * (s_c2 + x2 * (s_c3 + x2 * (s_c4 + x2 * s_c5))))
    )

    # --- COS: Q(x^2), Taylor coefficients through x^10 ---
    c_c0 = 1.0
    c_c1 = -0.49999999999999999983
    c_c2 = 0.04166666666666666636
    c_c3 = -0.00138888888888888742
    c_c4 = 2.4801587301587299e-5
    c_c5 = -2.755731922398581e-7

    cos_x = c_c0 + x2 * (c_c1 + x2 * (c_c2 + x2 * (c_c3 + x2 * (c_c4 + x2 * c_c5))))
    # Undo the sign change introduced by folding the outer quadrants.
    cos_x = tl.where(outer, -cos_x, cos_x)

    return sin_x, cos_x
@triton.jit
def pair_uniform_to_normal_fast(u1, u2):
    """Box-Muller transform: turn two uniform samples into two normal samples.

    Args:
        u1: uniform sample used for the radius; clamped below at 1e-7 so that
            ``log(u1)`` stays finite.
        u2: uniform sample used for the angle ``2 * pi * u2``.

    Returns:
        ``(radius * cos(angle), radius * sin(angle))``.
    """
    safe_u1 = tl.maximum(1.0e-7, u1)
    radius = tl.sqrt(-2.0 * tl.log(safe_u1))
    angle = 6.283185307179586 * u2
    sin_a, cos_a = high_precision_fast_sin_cos(angle)
    return radius * cos_a, radius * sin_a
# Module logger for this backend's randn implementation.
logger = logging.getLogger(__name__)
# Module-level alias for the runtime `device` object: `randn` below has a
# parameter named `device`, which would otherwise shadow the import.
device_ = device
@libentry()
# Autotuning through libtuner is currently disabled (kept below for reference);
# BLOCK is supplied by the "randn" heuristic config instead.
# @libtuner(
#     configs = [
#         triton.Config(kwargs={"BLOCK": 256}, num_stages=1, num_warps=1),
#         triton.Config(kwargs={"BLOCK": 512}, num_stages=1, num_warps=1),
#         triton.Config(kwargs={"BLOCK": 1024}, num_stages=1, num_warps=1),
#         triton.Config(kwargs={"BLOCK": 4096}, num_stages=1, num_warps=1),
#         triton.Config(kwargs={"BLOCK": 16384}, num_stages=1, num_warps=1),
#         triton.Config(kwargs={"BLOCK": 32768}, num_stages=1, num_warps=1),
#     ],
#     key=["N"],
#     strategy=["log"],
# )
@triton.heuristics(get_heuristic_config("randn"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
def randn_kernel(
    out_ptr,
    N,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    """Write N normally distributed samples to ``out_ptr`` from a Philox stream.

    Each program owns BLOCK Philox counters. Every counter yields four uint32
    values, which become two Box-Muller pairs (four normal samples). Those are
    stored into four consecutive BLOCK-sized chunks, so one program covers
    ``4 * BLOCK`` output elements; stores at indices >= N are masked off.

    Args:
        out_ptr: output buffer; ``tl.store`` converts values to its element type.
        N: number of output elements.
        philox_seed: Philox key. Excluded from value specialization, so new
            seeds do not force a recompile.
        philox_offset: 64-bit starting counter. Its low 32 bits are the base
            of counter word c0, its high 32 bits become counter word c1.
        BLOCK: counters per program (compile-time constant).
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit offset into two 32-bit counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    # One counter per lane: global counter index = pid * BLOCK + lane.
    i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    # NOTE(review): a carry out of c0 is not propagated into c1, so counters
    # wrap if low32(philox_offset) + i4 overflows 32 bits.
    c0 += i4
    # Zeros with c0's shape/dtype for the unused counter words c2 and c3.
    _O = c0 * 0
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)
    # Presumably maps each uint32 to a uniform float in [0, 1) — see
    # flag_gems.utils.random_utils.uint_to_uniform_float.
    r0 = uint_to_uniform_float(r0)
    r1 = uint_to_uniform_float(r1)
    r2 = uint_to_uniform_float(r2)
    r3 = uint_to_uniform_float(r3)
    n0, n1 = pair_uniform_to_normal_fast(r0, r1)
    n2, n3 = pair_uniform_to_normal_fast(r2, r3)
    # Output chunk k of this program starts at pid * 4 * BLOCK + k * BLOCK.
    # NOTE(review): these offsets are int32 arithmetic; confirm N < 2**31.
    off_0 = tl.program_id(0) * BLOCK * 4 + tl.arange(0, BLOCK)
    off_1 = off_0 + BLOCK
    off_2 = off_1 + BLOCK
    off_3 = off_2 + BLOCK

    # evict_first: cache hint for write-once output data.
    tl.store(out_ptr + off_0, n0, mask=off_0 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_1, n1, mask=off_1 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_2, n2, mask=off_2 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_3, n3, mask=off_3 < N, eviction_policy="evict_first")
# Outputs per Philox counter in randn_kernel (four BLOCK-sized chunks per
# program), so each program covers BLOCK * UNROLL elements.
UNROLL: int = 4
def randn(size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return a new tensor of shape ``size`` filled with N(0, 1) samples.

    Follows the ``torch.randn`` calling convention. ``layout`` and
    ``pin_memory`` are accepted for signature compatibility but are not used.

    Args:
        size: Output shape, forwarded to ``torch.empty``.
        dtype: Element type; defaults to ``torch.get_default_dtype()``.
        layout: Ignored.
        device: Target device; defaults to the backend runtime device.
        pin_memory: Ignored.

    Returns:
        The freshly allocated output tensor.
    """
    logger.debug("GEMS_TSINGMICRO RANDN")
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device(device_.name)
    out = torch.empty(size, device=device, dtype=dtype)
    N = volume(size)
    if N == 0:
        # Nothing to fill: skip the zero-sized kernel launch.
        return out

    def grid_fn(meta):
        # Each program writes BLOCK * UNROLL output elements.
        return (triton.cdiv(N, meta["BLOCK"] * UNROLL),)

    # Reserve Philox counters for this call. BLOCK comes from the kernel's
    # heuristic config, so the number of counters actually consumed
    # (grid * BLOCK) is not known here. In randn_kernel, counter
    # `pid * BLOCK + j` stores its first sample at `pid * 4 * BLOCK + j`, and a
    # counter only produces output if that index is < N; since
    # pid * BLOCK + j <= pid * 4 * BLOCK + j, every counter that contributes
    # output has index < N, so reserving N counters is always sufficient.
    # The former reservation, cdiv(N, UNROLL), undercounts whenever the last
    # program is partially filled (e.g. N=10 with BLOCK >= 10 uses counters
    # 0..9), which — assuming philox_backend_seed_offset advances the
    # generator offset by about `increment` — let the next call repeat samples.
    increment = N
    with torch_device_fn.device(device):
        # Fetched inside the device context so that a helper which resolves
        # the generator from the current device uses the target device's one.
        # NOTE(review): confirm against philox_backend_seed_offset.
        philox_seed, philox_offset = philox_backend_seed_offset(increment)
        randn_kernel[grid_fn](out, N, philox_seed, philox_offset)
    return out