Coverage for src/flag_gems/utils/random_utils.py: 45%
78 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import torch
2import triton
3import triton.language as tl
5import flag_gems
6from flag_gems.runtime import torch_device_fn
# CPU ``torch.Generator`` that stands in for the device RNG when
# ``flag_gems.vendor_name == "spacemit"``. It starts as ``None`` and is created
# lazily (``torch.Generator(device="cpu")``) by ``philox_backend_seed_offset`` or
# ``set_philox_state`` the first time either needs it; both then reuse it.
_SPACEMIT_CPU_GENERATOR = None
# Prefer Triton's built-in helper; older Triton releases lack it, in which case
# fall back to a local copy of the same implementation.
try:
    uint_to_uniform_float = tl.uint_to_uniform_float
except AttributeError:
    # Copied from triton.language package for compatibility
    @triton.jit
    def uint_to_uniform_float(x):
        """
        Numerically stable function to convert a random uint into a random float uniformly sampled in [0, 1).

        Accepts 32-bit (uint32/int32) or 64-bit (uint64/int64) integer inputs;
        any other dtype fails the ``tl.static_assert`` below. The bits are
        reinterpreted as a signed integer, negative values are folded onto the
        non-negative range via ``-x - 1``, and the result is multiplied by a
        dtype-specific ``scale`` constant.
        """
        # TODO: fix frontend issues and cleanup
        # conditions can be simplified
        # scale is ((2**23 - 1) / 2**23) * 2**(N_BITS - 1)
        # dtype dispatch is wrapped in tl.constexpr so it is resolved when the
        # kernel is compiled, not per element.
        if tl.constexpr(x.dtype == tl.uint32) or tl.constexpr(x.dtype == tl.int32):
            # maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
            x = x.to(tl.int32, bitcast=True)
            scale = 4.6566127342e-10
        else:
            # Only 64-bit integers are accepted on this path.
            tl.static_assert(
                tl.constexpr(x.dtype == tl.uint64) or tl.constexpr(x.dtype == tl.int64)
            )
            x = x.to(tl.int64, bitcast=True)
            scale = 1.0842020432385337e-19
        # Map negative values to non-negative ones (-1 -> 0, MIN_INT -> MAX_INT)
        # so the product below never goes below 0.
        x = tl.where(x < 0, -x - 1, x)
        return x * scale
# This function is roughly a python wrapper of CUDAGeneratorImpl::philox_cuda_state in Pytorch.
# https://github.com/pytorch/pytorch/blob/8a4597980c2692b73f35fb3c7145eaeaf2273e77/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp#L452
# It returns the current state of the default Philox RNG in seed and offset and
# updates the next offset by adding `increment`.
def philox_backend_seed_offset(increment, generator=None):
    """Return the Philox ``(seed, offset)`` of a generator and advance its offset.

    Args:
        increment: Amount to advance the stored offset by. It is rounded up to
            the next multiple of 4 before being added.
        generator: Generator whose state is read and advanced. When ``None``,
            the backend's default generator for the current device is used; for
            the ``spacemit`` vendor a lazily created module-level CPU generator
            is used instead.

    Returns:
        ``(seed, offset)`` as Python ints, taken *before* the offset is advanced.

    Raises:
        ValueError: for vendors whose state is expected to be exactly two int64
            words ``(seed, offset)`` when it has a different size.
    """
    global _SPACEMIT_CPU_GENERATOR

    if generator is None:
        if flag_gems.vendor_name == "spacemit":
            # SPACEMIT uses CPU generator
            if _SPACEMIT_CPU_GENERATOR is None:
                _SPACEMIT_CPU_GENERATOR = torch.Generator(device="cpu")
            generator = _SPACEMIT_CPU_GENERATOR
        else:
            device = torch_device_fn.current_device()
            generator = torch_device_fn.default_generators[device]

    # get_state returns a copy; all edits go through this int64 view of it.
    state_copy = generator.get_state()
    state_view = state_copy.view(torch.int64)

    # TODO[kunlunxin]: we will upgrade torch version in 2025.04
    if flag_gems.vendor_name in ("kunlunxin", "aipu", "spacemit"):
        # These backends keep seed/offset in the last two int64 slots.
        seed_idx, offset_idx = -2, -1
    else:
        # The state is expected to be exactly (seed, offset).
        if state_view.numel() != 2:
            raise ValueError(
                f"expected generator state of 2 int64 words, got {state_view.numel()}"
            )
        seed_idx, offset_idx = 0, 1

    seed = int(state_view[seed_idx])
    offset = int(state_view[offset_idx])

    # Write the advanced offset back into the state explicitly. Reading the
    # offset as a Python int and incrementing only that local (as the spacemit
    # path previously did) never advanced the generator, so repeated calls
    # returned the same (seed, offset).
    state_view[offset_idx] = offset + (increment + 3) // 4 * 4

    # get_state returns a new tensor, so it needs set_state to update the actual generator state.
    generator.set_state(state_copy)
    return seed, offset
def set_philox_state(seed, offset, device=None):
    """Overwrite the seed and offset stored in the backend's Philox RNG state.

    The values are written to the same int64 slots that
    ``philox_backend_seed_offset`` reads them from, so a subsequent read
    returns exactly ``(seed, offset)``.

    Args:
        seed: Seed value to store.
        offset: Offset value to store; must be a multiple of 4.
        device: Index into ``torch_device_fn.default_generators``. ``None``
            means the current device. Ignored for the ``spacemit`` vendor,
            which uses a module-level CPU generator.
    """
    global _SPACEMIT_CPU_GENERATOR
    assert offset % 4 == 0

    if flag_gems.vendor_name == "spacemit":
        if _SPACEMIT_CPU_GENERATOR is None:
            _SPACEMIT_CPU_GENERATOR = torch.Generator(device="cpu")
        gen = _SPACEMIT_CPU_GENERATOR
    else:
        # Test against None explicitly: `device or ...` would silently replace
        # an explicit device 0 with the current device.
        if device is None:
            device = torch_device_fn.current_device()
        gen = torch_device_fn.default_generators[device]

    state_copy = gen.get_state()
    state_view = state_copy.view(torch.int64)
    # Mirror the slot selection used when reading the state, so backends that
    # keep seed/offset in the last two words (kunlunxin, aipu, spacemit) are
    # written where they are read.
    if flag_gems.vendor_name in ("kunlunxin", "aipu", "spacemit"):
        seed_idx, offset_idx = -2, -1
    else:
        seed_idx, offset_idx = 0, 1
    state_view[seed_idx] = seed
    state_view[offset_idx] = offset
    # get_state returns a copy, so push the edited state back.
    gen.set_state(state_copy)
def per_thread_offset(N, num_blocks, num_warps, warp_threads=32):
    """Return the number of elements each thread handles when N items are spread over a launch.

    The launch consists of ``num_blocks`` blocks, each with
    ``num_warps * warp_threads`` threads. The result is the integer ceiling
    of ``N`` divided by that total thread count.
    """
    total = num_blocks * (num_warps * warp_threads)
    return (N - 1 + total) // total
@triton.jit
def uniform(seed, philox_offset, offset):
    """Draw four uniform floats in [0, 1) per element from the Philox generator.

    The 64-bit ``philox_offset`` is split into two 32-bit counter words
    (low -> ``c0``, high -> ``c1``). The per-element ``offset`` is added to the
    low word, and ``tl.philox`` is called with the remaining two counter words
    set to zero. Each of the four returned uint outputs is converted with
    ``uint_to_uniform_float``.

    NOTE(review): presumably ``seed``/``philox_offset`` are the values returned
    by ``philox_backend_seed_offset`` and ``offset`` is a per-element index
    tensor — confirm against callers.
    """
    seed = seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit offset into low/high 32-bit Philox counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    i4 = offset
    # NOTE(review): the per-element offset is added only to the low word; if
    # this addition overflows 32 bits no carry is propagated into c1 — confirm
    # that is acceptable for the offsets used here.
    c0 += i4
    # Zeros with c0's shape, used for the two unused counter words.
    _O = c0 * 0
    r0, r1, r2, r3 = tl.philox(seed, c0, c1, _O, _O)
    r0 = uint_to_uniform_float(r0)
    r1 = uint_to_uniform_float(r1)
    r2 = uint_to_uniform_float(r2)
    r3 = uint_to_uniform_float(r3)
    return r0, r1, r2, r3