Coverage for src/flag_gems/ops/randint_like.py: 51%
51 statements
coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils.random_utils import (
11 philox_backend_seed_offset,
12 uint_to_uniform_float,
13)
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# Number of outputs produced per Philox call: `tl.philox` returns four
# 32-bit words per counter, and each program writes BLOCK * UNROLL elements.
UNROLL: int = 4
@triton.heuristics(runtime.get_heuristic_config("rand"))
# `high` is excluded from value specialization too: it is a runtime bound, and
# Triton otherwise specializes integer arguments on value properties (e.g.
# divisibility by 16), compiling a separate variant per class of `high`.
# NOTE(review): some older Triton versions also turned an argument equal to 1
# into a constexpr, on which `high.to(...)` below is not available.
@triton.jit(do_not_specialize=["high", "philox_seed", "philox_offset"])
def randint_kernel(
    out_ptr,
    N,
    high,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    """Fill ``out_ptr[0:N]`` with random integers in ``[0, high)``.

    Each program covers ``BLOCK * 4`` consecutive elements. Every lane runs one
    Philox call that yields four 32-bit words; word ``k`` of lane ``j`` in
    program ``p`` is written to ``p * BLOCK * 4 + k * BLOCK + j``.

    Args:
        out_ptr: output buffer; ``tl.store`` casts the computed values to its
            element type.
        N: number of elements to write (stores past ``N`` are masked off).
        high: exclusive upper bound; assumed > 0 (checked by the caller).
        philox_seed: Philox key.
        philox_offset: 64-bit starting Philox counter.
        BLOCK: lanes per program, supplied by the "rand" heuristic.
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit starting offset into the first two 32-bit counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    # Lane j of program p uses counter offset + p * BLOCK + j. The addition is
    # done in uint32, so it does not carry into c1.
    i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    c0 += i4
    _O = c0 * 0
    r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)

    # Convert to uniform float in [0, 1)
    u0 = uint_to_uniform_float(r0)
    u1 = uint_to_uniform_float(r1)
    u2 = uint_to_uniform_float(r2)
    u3 = uint_to_uniform_float(r3)

    # Scale to [0, high) and truncate toward zero. The integer cast is int64
    # (not int32) so bounds above 2**31 - 1 do not overflow; tl.store then
    # narrows/converts to the output element type.
    # NOTE: float32 has a 24-bit significand, so for high > 2**24 not every
    # integer in [0, high) is reachable.
    high_f = high.to(tl.float32)
    i0 = (u0 * high_f).to(tl.int64)
    i1 = (u1 * high_f).to(tl.int64)
    i2 = (u2 * high_f).to(tl.int64)
    i3 = (u3 * high_f).to(tl.int64)

    # Word k of each lane lands BLOCK elements after word k - 1, keeping each
    # store contiguous across the lanes of a program.
    off_0 = tl.program_id(0) * BLOCK * 4 + tl.arange(0, BLOCK)
    off_1 = off_0 + BLOCK
    off_2 = off_1 + BLOCK
    off_3 = off_2 + BLOCK

    tl.store(out_ptr + off_0, i0, mask=off_0 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_1, i1, mask=off_1 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_2, i2, mask=off_2 < N, eviction_policy="evict_first")
    tl.store(out_ptr + off_3, i3, mask=off_3 < N, eviction_policy="evict_first")
def randint_like(
    self,
    high,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    memory_format=None,
):
    """Return a tensor shaped like ``self`` filled with random integers in ``[0, high)``.

    ``dtype`` and ``device`` default to those of ``self``; ``memory_format`` is
    forwarded to ``torch.empty_like`` (``None`` keeps its default).
    ``layout`` and ``pin_memory`` are accepted for signature compatibility
    and are ignored.

    Raises:
        RuntimeError: if ``high <= 0``, or if ``high - 1`` does not fit in the
            (non-floating, non-complex) output ``dtype``.
    """
    if device is None:
        device = self.device
    if dtype is None:
        dtype = self.dtype
    high = int(high)

    # Reject bounds the kernel cannot honor instead of silently producing
    # out-of-range or truncated values.
    if high <= 0:
        raise RuntimeError(
            "random_ expects 'from' to be less than 'to', "
            f"but got from=0 >= to={high}"
        )
    if not (dtype.is_floating_point or dtype.is_complex):
        max_value = 1 if dtype == torch.bool else torch.iinfo(dtype).max
        if high - 1 > max_value:
            raise RuntimeError(f"to - 1 is out of bounds for {dtype}")

    logger.debug("GEMS RANDINT_LIKE")
    # empty_like always yields a dense tensor, so the kernel's flat writes over
    # numel() elements cover it exactly whatever the memory format.
    out = torch.empty_like(
        self, device=device, dtype=dtype, memory_format=memory_format
    )
    N = out.numel()
    if N == 0:
        return out

    grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
    # Each lane with an in-bounds first store consumes one Philox counter, and
    # its counter index (p * BLOCK + j) never exceeds that store's index
    # (p * BLOCK * 4 + j) < N. Reserving N counters is therefore always
    # sufficient; cdiv(N, UNROLL) undercounts when the last program is only
    # partially filled, letting the next call reuse counters.
    increment = N
    # Launch on the device that owns `out` (it may differ from self.device).
    with torch_device_fn.device(out.device):
        # Queried inside the device context so that, if the helper reads the
        # current device's generator, it is the output device's generator.
        philox_seed, philox_offset = philox_backend_seed_offset(increment)
        randint_kernel[grid_fn](out, N, high, philox_seed, philox_offset)
    return out