Coverage for src/flag_gems/utils/random_utils.py: 45%

78 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5import flag_gems 

6from flag_gems.runtime import torch_device_fn 

7 

# CPU generator that stands in for the device RNG on the "spacemit" backend.
# It stays None until the first caller that needs it creates it lazily.
_SPACEMIT_CPU_GENERATOR: "torch.Generator | None" = None

9 

# Prefer Triton's own implementation. Triton releases that do not expose
# `tl.uint_to_uniform_float` raise AttributeError here, and we fall back to
# the local copy below.
try:
    uint_to_uniform_float = tl.uint_to_uniform_float
except AttributeError:
    # Copied from triton.language package for compatibility
    @triton.jit
    def uint_to_uniform_float(x):
        """
        Numerically stable function to convert a random uint into a random float uniformly sampled in [0, 1).

        Accepts 32-bit (uint32/int32) or 64-bit (uint64/int64) integer tensors;
        any other dtype fails the static assertion at compile time.
        """
        # TODO: fix frontend issues and cleanup
        # conditions can be simplified
        # scale is ((2**23 - 1) / 2**23) * 2**(N_BITS - 1)
        if tl.constexpr(x.dtype == tl.uint32) or tl.constexpr(x.dtype == tl.int32):
            # maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
            # Reinterpret the bits as signed so the negative-folding below works.
            x = x.to(tl.int32, bitcast=True)
            scale = 4.6566127342e-10
        else:
            tl.static_assert(
                tl.constexpr(x.dtype == tl.uint64) or tl.constexpr(x.dtype == tl.int64)
            )
            x = x.to(tl.int64, bitcast=True)
            scale = 1.0842020432385337e-19
        # Fold negative values onto the non-negative range (-1 -> 0, MIN -> MAX),
        # so the scaled result is never negative.
        x = tl.where(x < 0, -x - 1, x)
        return x * scale

34 

35 

# This function is roughly a python wrapper of CUDAGeneratorImpl::philox_cuda_state in Pytorch.
# https://github.com/pytorch/pytorch/blob/8a4597980c2692b73f35fb3c7145eaeaf2273e77/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp#L452
# It returns the current state of the default Philox RNG in seed and offset and
# updates the next offset by adding `increment`.
def philox_backend_seed_offset(increment, generator=None):
    """Return the generator's ``(seed, offset)`` and advance its stored offset.

    Args:
        increment: How far to advance the stored offset. It is rounded up to
            the next multiple of 4 before it is added.
        generator: Generator to read and advance. When ``None``:
            - on ``spacemit``, a module-level CPU generator is used (created
              on first use);
            - otherwise, ``torch_device_fn.default_generators[current_device]``
              is used.

    Returns:
        ``(seed, offset)`` as Python ints, taken *before* the offset was
        advanced.
    """
    global _SPACEMIT_CPU_GENERATOR

    if generator is None:
        device = torch_device_fn.current_device()
        # SPACEMIT uses CPU generator
        if flag_gems.vendor_name == "spacemit":
            if _SPACEMIT_CPU_GENERATOR is None:
                _SPACEMIT_CPU_GENERATOR = torch.Generator(device="cpu")
            generator = _SPACEMIT_CPU_GENERATOR
        else:
            generator = torch_device_fn.default_generators[device]

    # get_state returns a new tensor, so we patch this copy and write it back.
    state_copy = generator.get_state()
    state_view = state_copy.view(torch.int64)
    # TODO[kunlunxin]: we will upgrade torch version in 2025.04
    if flag_gems.vendor_name in ("kunlunxin", "aipu", "spacemit"):
        # Seed/offset live in the last two int64 slots of the state. Keep them
        # as 0-d tensor views (no .item()) so the in-place add below reaches
        # state_copy. With Python ints, the offset would never advance.
        # NOTE(review): for spacemit this relies on the CPU generator's
        # set_state/get_state round-tripping these two slots; confirm on the
        # target torch version.
        c0, c1 = state_view[-2], state_view[-1]
    else:
        # A (seed, offset) state; unpacking raises if it has another length.
        c0, c1 = state_view

    seed, offset = int(c0), int(c1)
    increment = (increment + 3) // 4 * 4
    # In-place on a view: this mutates state_copy, not just a local name.
    c1 += increment
    generator.set_state(state_copy)
    return seed, offset

69 

70 

def set_philox_state(seed, offset, device=None):
    """Overwrite the seed and offset stored in the backend's default generator.

    Args:
        seed: Seed value to store.
        offset: Philox offset to store. Must be a multiple of 4.
        device: Device index whose default generator is updated.
            - Ignored on ``spacemit``, which uses a module-level CPU generator.
            - Elsewhere, ``None`` means the current device. An explicit ``0``
              is honoured.
    """
    global _SPACEMIT_CPU_GENERATOR
    assert offset % 4 == 0, f"philox offset must be a multiple of 4, got {offset}"

    if flag_gems.vendor_name == "spacemit":
        # SPACEMIT uses a lazily created CPU generator instead of a device one.
        if _SPACEMIT_CPU_GENERATOR is None:
            _SPACEMIT_CPU_GENERATOR = torch.Generator(device="cpu")
        gen = _SPACEMIT_CPU_GENERATOR
    else:
        # Compare against None explicitly: `device or ...` treated index 0
        # as "unset" and silently picked the current device instead.
        if device is None:
            device = torch_device_fn.current_device()
        gen = torch_device_fn.default_generators[device]

    # get_state returns a copy, so patch it and write it back with set_state.
    state_copy = gen.get_state()
    state_view = state_copy.view(torch.int64)
    # Write seed/offset into the last two int64 slots. This is where
    # philox_backend_seed_offset reads them on kunlunxin/aipu/spacemit.
    # For a two-slot (seed, offset) state, these are slots 0 and 1.
    state_view[-2] = seed
    state_view[-1] = offset
    gen.set_state(state_copy)

93 

94 

def per_thread_offset(N, num_blocks, num_warps, warp_threads=32):
    """Return how many of the ``N`` elements each launched thread must cover.

    The launch has ``num_blocks * num_warps * warp_threads`` threads in
    total. The result is ``N`` divided by that count, rounded up.
    """
    total_threads = num_blocks * (num_warps * warp_threads)
    return (N + total_threads - 1) // total_threads

100 

101 

@triton.jit
def uniform(seed, philox_offset, offset):
    """Draw four uniform floats per element from a Philox counter.

    Args:
        seed: Philox key. It is cast to int64 before use.
        philox_offset: 64-bit base counter. Presumably this is the offset
            returned by the host-side seed/offset helper; TODO confirm at the
            call sites.
        offset: Per-element counter increment, added to the low 32-bit word.

    Returns:
        ``(r0, r1, r2, r3)``: the four Philox outputs, each converted with
        ``uint_to_uniform_float``, i.e. floats in [0, 1).
    """
    seed = seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit base offset into low/high 32-bit counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    i4 = offset
    # Only the low word is bumped; nothing carries into c1.
    # NOTE(review): confirm that philox_offset_low + offset cannot overflow
    # 32 bits for the sizes callers use.
    c0 += i4
    # Zero tensor with c0's shape/dtype, used for the unused counter words.
    _O = c0 * 0
    r0, r1, r2, r3 = tl.philox(seed, c0, c1, _O, _O)
    r0 = uint_to_uniform_float(r0)
    r1 = uint_to_uniform_float(r1)
    r2 = uint_to_uniform_float(r2)
    r3 = uint_to_uniform_float(r3)
    return r0, r1, r2, r3