Coverage for src/flag_gems/ops/randint_like.py: 51%

51 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils.random_utils import ( 

11 philox_backend_seed_offset, 

12 uint_to_uniform_float, 

13) 

14 

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# Number of output elements produced per Philox counter: the kernel unpacks
# four 32-bit words (r0..r3) from every tl.philox call.
UNROLL: int = 4

18 

19 

@triton.heuristics(runtime.get_heuristic_config("rand"))
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "high"])
def randint_kernel(
    out_ptr,
    N,
    high,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    """Fill ``out_ptr[0:N]`` with random integers in ``[0, high)``.

    Each program evaluates ``BLOCK`` Philox counters (one per lane) and writes
    the four resulting words to four consecutive ``BLOCK``-wide chunks starting
    at ``program_id * 4 * BLOCK``. Counter ``i`` only ever writes to offsets
    ``>= i``, so every counter whose output is kept is ``< N``.

    ``high`` is excluded from value specialization, like the Philox arguments.
    Triton may specialize integer arguments on their value (e.g. equal to 1),
    which triggers recompiles and can break the ``high.to(...)`` call below.
    NOTE(review): confirm on the Triton version in use.

    NOTE(review): the float32 scaling cannot reach every integer in
    ``[0, high)`` once ``high`` exceeds 2**24, because float32 has a 24-bit
    significand. Exact coverage of large ranges would need an integer-domain
    mapping.
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit offset into the two low Philox counter words. Only the
    # low word is advanced per lane; a uint32 wrap-around does not carry into c1.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    c0 += i4
    _O = c0 * 0  # zeros shaped like c0, used for the two unused counter words
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)

    # Convert to uniform float in [0, 1) (contract of uint_to_uniform_float)
    u0 = uint_to_uniform_float(r0)
    u1 = uint_to_uniform_float(r1)
    u2 = uint_to_uniform_float(r2)
    u3 = uint_to_uniform_float(r3)

    # Scale to [0, high). Truncate through int64 so the conversion cannot
    # overflow when high lies outside the int32 range; tl.store then casts the
    # value to the element type of out_ptr.
    high_f = high.to(tl.float32)
    i0 = (u0 * high_f).to(tl.int64)
    i1 = (u1 * high_f).to(tl.int64)
    i2 = (u2 * high_f).to(tl.int64)
    i3 = (u3 * high_f).to(tl.int64)

    # Four contiguous BLOCK-wide chunks per program, one per Philox word.
    off_0 = tl.program_id(0) * BLOCK * 4 + tl.arange(0, BLOCK)
    off_1 = off_0 + BLOCK
    off_2 = off_1 + BLOCK
    off_3 = off_2 + BLOCK

    tl.store(out_ptr + off_0, i0, mask=off_0 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_1, i1, mask=off_1 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_2, i2, mask=off_2 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_3, i3, mask=off_3 < N, eviction_policy="evict_first")

61 

62 

def randint_like(
    self,
    high,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    memory_format=None,
):
    """Return a tensor shaped like ``self`` filled with random integers in ``[0, high)``.

    Args:
        self: Reference tensor whose shape (and, by default, dtype and device)
            the result takes.
        high: Exclusive upper bound of the sampled values; must be positive.
        dtype: Result dtype; defaults to ``self.dtype``.
        layout: Accepted for signature compatibility; ignored.
        device: Result device; defaults to ``self.device``.
        pin_memory: Accepted for signature compatibility; ignored.
        memory_format: Forwarded to ``torch.empty_like`` when not ``None``.

    Returns:
        A newly allocated tensor with the same shape as ``self``.

    Raises:
        RuntimeError: if ``high <= 0`` (same message as torch's ``random_``).
    """
    logger.debug("GEMS RANDINT_LIKE")
    if high <= 0:
        raise RuntimeError(
            "random_ expects 'from' to be less than 'to', "
            f"but got from=0 >= to={high}"
        )
    if device is None:
        device = self.device
    if dtype is None:
        dtype = self.dtype
    empty_kwargs = {"device": device, "dtype": dtype}
    if memory_format is not None:
        empty_kwargs["memory_format"] = memory_format
    out = torch.empty_like(self, **empty_kwargs)
    N = out.numel()
    if N == 0:
        # Nothing to sample: skip the launch and leave the generator untouched.
        return out
    grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
    # randint_kernel reads Philox counter i only through outputs at offsets
    # >= i (its lanes are block-strided, not packed four per counter). Every
    # counter whose words are kept therefore lies in [0, N). Advancing by N
    # keeps consecutive calls on disjoint counter ranges; cdiv(N, UNROLL) lets
    # the next call reuse words already emitted by this one.
    # Assumes philox_backend_seed_offset advances the generator offset by at
    # least `increment` -- NOTE(review): confirm.
    increment = N
    philox_seed, philox_offset = philox_backend_seed_offset(increment)
    # Launch on the device that owns `out`, which may differ from self.device
    # when `device=` is given.
    with torch_device_fn.device(out.device):
        randint_kernel[grid_fn](out, N, high, philox_seed, philox_offset)
    return out