Coverage for src/flag_gems/ops/randint.py: 53%
59 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils.random_utils import philox_backend_seed_offset
12logger = logging.getLogger(__name__)
# `high` is listed in do_not_specialize alongside the other scalar arguments:
# the kernel calls `.to(...)` on it, which requires a runtime scalar. Without
# this, Triton specializes an integer argument equal to 1 into a compile-time
# constant, so `randint(1, ...)` would fail to compile; it also avoids
# recompiling per divisibility class of the upper bound.
@libentry()
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N", "high"])
def randint_kernel(
    out_ptr,
    N,
    high,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    # Writes N values to `out_ptr`, each a Philox random word reduced modulo
    # `high`. Every program instance produces up to 4 * BLOCK values: one
    # Philox call per lane yields four words, stored to four consecutive
    # BLOCK-sized segments.
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit offset into two 32-bit Philox counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    pid = tl.program_id(0)
    i = pid * BLOCK + tl.arange(0, BLOCK)
    # Give every lane a distinct counter by adding its global lane index.
    # NOTE(review): a wrap of c0 is not carried into c1 -- confirm this is
    # acceptable for the offsets handed out by the host.
    c0 += i
    # Zero block with the same shape as c0, used for the two unused counter words.
    z = c0 * 0
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, z, z)

    # Reduce each random word into [0, high) and cast to the output dtype.
    # NOTE(review): plain modulo is slightly biased unless `high` divides the
    # range of the random words; the words are presumably 32-bit, which would
    # also cap results below 2**32 -- confirm against tl.philox.
    high_val = high.to(tl.uint64)
    r0_mod = (r0 % high_val).to(out_ptr.dtype.element_ty)
    r1_mod = (r1 % high_val).to(out_ptr.dtype.element_ty)
    r2_mod = (r2 % high_val).to(out_ptr.dtype.element_ty)
    r3_mod = (r3 % high_val).to(out_ptr.dtype.element_ty)

    # 64-bit offsets so pid * BLOCK * 4 cannot overflow 32 bits.
    start = pid.to(tl.uint64) * BLOCK * 4
    off0 = start + tl.arange(0, BLOCK)
    off1 = off0 + BLOCK
    off2 = off1 + BLOCK
    off3 = off2 + BLOCK

    # Mask the tail so nothing is written at or beyond element N.
    tl.store(out_ptr + off0, r0_mod, mask=off0 < N)
    tl.store(out_ptr + off1, r1_mod, mask=off1 < N)
    tl.store(out_ptr + off2, r2_mod, mask=off2 < N)
    tl.store(out_ptr + off3, r3_mod, mask=off3 < N)
def randint(
    high,
    size,
    *,
    generator=None,
    out=None,
    dtype=torch.int64,
    layout=None,
    device=None,
    requires_grad=False,
    pin_memory=None,
):
    """Return a tensor of shape ``size`` filled with random integers in ``[0, high)``.

    Values are produced by ``randint_kernel``, which reduces Philox random
    words modulo ``high``.

    Args:
        high: Exclusive upper bound; must be greater than 0.
        size: Iterable of dimension sizes of the result.
        generator: Forwarded to ``philox_backend_seed_offset`` to obtain the
            Philox seed and offset.
        out: Optional destination tensor. When given, the generated values
            are copied into it and ``out`` is returned.
        dtype: Element type of the generated tensor; ``None`` means
            ``torch.int64``.
        layout: Accepted for signature compatibility; not used.
        device: Device of the generated tensor. Defaults to ``out.device``
            when ``out`` is given, otherwise to CPU.
        requires_grad: Accepted for signature compatibility; not used.
        pin_memory: Forwarded to ``torch.empty``; ``None`` means ``False``.

    Returns:
        ``out`` if it was provided, otherwise the newly allocated tensor.

    Raises:
        RuntimeError: If ``high`` is not positive.
    """
    # Reject an empty range up front: the kernel computes `r % high`, so a
    # zero bound would be a modulo by zero and a negative one would wrap to a
    # huge unsigned value. The message mirrors torch.randint.
    if high <= 0:
        raise RuntimeError(
            "random_ expects 'from' to be less than 'to', "
            f"but got from=0 >= to={high}"
        )

    logger.debug("GEMS RANDINT")
    if dtype is None:
        dtype = torch.int64

    if device is None:
        # Generate on the destination's device when one is supplied, so the
        # kernel does not run against a scratch tensor on an unrelated device.
        # NOTE(review): without `out` the default stays CPU, which a Triton
        # launch presumably cannot use -- confirm the intended default.
        device = out.device if out is not None else torch.device("cpu")

    if pin_memory is None:
        pin_memory = False

    # Total number of elements to generate.
    N = 1
    for s in size:
        N *= s

    BLOCK_SIZE = 128  # lanes per program instance
    UNROLL = 4  # one Philox call yields four random words per lane
    grid = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
    # Number of Philox counters consumed: one per group of UNROLL outputs.
    increment = triton.cdiv(N, UNROLL)

    result = torch.empty(size, device=device, dtype=dtype, pin_memory=pin_memory)

    philox_seed, philox_offset = philox_backend_seed_offset(
        increment, generator=generator
    )

    with torch_device_fn.device(device):
        randint_kernel[grid](
            result,
            N,
            high,
            philox_seed,
            philox_offset,
            BLOCK_SIZE,
        )

    if out is not None:
        # copy_ handles any dtype or layout difference between result and out.
        out.copy_(result)
        return out
    return result