Coverage for src/flag_gems/ops/randint.py: 53%

59 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils.random_utils import philox_backend_seed_offset 

11 

12logger = logging.getLogger(__name__) 

13 

14 

# Philox-based randint kernel.
#
# Program ``pid`` draws from the BLOCK counters
# ``philox_offset + pid * BLOCK + [0, BLOCK)`` (low 32 bits in c0, high 32 bits
# in c1). Each counter yields four 32-bit words r0..r3, which are stored BLOCK
# elements apart. One program therefore fills the contiguous output span
# ``[pid * 4 * BLOCK, (pid + 1) * 4 * BLOCK)``, masked against N.
# A launch with ``ceil(N / (4 * BLOCK))`` programs thus touches
# ``ceil(N / (4 * BLOCK)) * BLOCK`` counters.
#
# NOTE(review): each output is derived from a single 32-bit word, so values
# never exceed 2**32 - 1 even when ``high`` is larger. ``r % high`` also
# carries a small modulo bias when ``high`` is not a power of two.
@libentry()
@triton.jit(
    # ``high`` is kept unspecialized as well. Triton may otherwise specialize
    # integer arguments on their value (e.g. ``high == 1``), recompiling per
    # bound and turning ``high`` into a compile-time constant that the
    # ``high.to(...)`` call below may not accept.
    do_not_specialize=["philox_seed", "philox_offset", "N", "high"]
)
def randint_kernel(
    out_ptr,
    N,
    high,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit Philox offset across the first two counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    pid = tl.program_id(0)
    i = pid * BLOCK + tl.arange(0, BLOCK)
    c0 += i
    z = c0 * 0
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, z, z)

    # Widen to 64 bits before the modulo so any positive ``high`` is accepted.
    high_val = high.to(tl.uint64)
    r0_mod = (r0 % high_val).to(out_ptr.dtype.element_ty)
    r1_mod = (r1 % high_val).to(out_ptr.dtype.element_ty)
    r2_mod = (r2 % high_val).to(out_ptr.dtype.element_ty)
    r3_mod = (r3 % high_val).to(out_ptr.dtype.element_ty)

    # 64-bit base offset so pid * 4 * BLOCK does not overflow for large outputs.
    start = pid.to(tl.uint64) * BLOCK * 4
    off0 = start + tl.arange(0, BLOCK)
    off1 = off0 + BLOCK
    off2 = off1 + BLOCK
    off3 = off2 + BLOCK

    tl.store(out_ptr + off0, r0_mod, mask=off0 < N)
    tl.store(out_ptr + off1, r1_mod, mask=off1 < N)
    tl.store(out_ptr + off2, r2_mod, mask=off2 < N)
    tl.store(out_ptr + off3, r3_mod, mask=off3 < N)

52 

53 

54def randint( 

55 high, 

56 size, 

57 *, 

58 generator=None, 

59 out=None, 

60 dtype=torch.int64, 

61 layout=None, 

62 device=None, 

63 requires_grad=False, 

64 pin_memory=None, 

65): 

66 logger.debug("GEMS RANDINT") 

67 if dtype is None: 

68 dtype = torch.int64 

69 

70 if device is None: 

71 device = torch.device("cpu") 

72 

73 if pin_memory is None: 

74 pin_memory = False 

75 

76 if layout is None: 

77 layout = torch.strided 

78 

79 N = 1 

80 for s in size: 

81 N *= s 

82 

83 BLOCK_SIZE = 128 # matches philox 4-wide output for efficient random generation 

84 UNROLL = 4 

85 grid = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),) 

86 increment = triton.cdiv(N, UNROLL) 

87 

88 result = torch.empty(size, device=device, dtype=dtype, pin_memory=pin_memory) 

89 

90 philox_seed, philox_offset = philox_backend_seed_offset( 

91 increment, generator=generator 

92 ) 

93 

94 with torch_device_fn.device(device): 

95 randint_kernel[grid]( 

96 result, 

97 N, 

98 high, 

99 philox_seed, 

100 philox_offset, 

101 BLOCK_SIZE, 

102 ) 

103 

104 if out is not None: 

105 out.copy_(result) 

106 return out 

107 return result