Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/exponential_.py: 0%
95 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6from triton.language.extra.xpu.libdevice import log2
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils.random_utils import (
10 philox_backend_seed_offset,
11 uint_to_uniform_float,
12)
# Module logger, created as a child of the "flag_gems" logger and named after
# this module (leading dots stripped from __name__).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# NOTE(review): not referenced by any code visible in this file; presumably the
# number of compute clusters on the target device — confirm before relying on it.
CLUSTER_NUM = 12
def heur_block(args):
    """Choose the kernel's ``BLOCK`` size from the element count ``N``.

    ``args`` is the mapping of kernel arguments handed to the heuristic; a
    missing ``"N"`` entry is treated as 0. Returns 256 for N <= 4096, 512 for
    N <= 65536 and 1024 for anything larger.
    """
    numel = args.get("N", 0)
    # Check from the largest bucket down so each test is a single comparison.
    if numel > 65536:
        return 1024
    if numel > 4096:
        return 512
    return 256
def heur_num_warps(args):
    """Choose ``num_warps`` for the kernel from the element count ``N``.

    ``args`` is the mapping of kernel arguments handed to the heuristic; a
    missing ``"N"`` entry is treated as 0. Returns 4 for N <= 4096, 8 for
    N <= 65536 and 16 for anything larger.
    """
    numel = args.get("N", 0)
    # Check from the largest bucket down so each test is a single comparison.
    if numel > 65536:
        return 16
    if numel > 4096:
        return 8
    return 4
# Fused Philox + exponential-transform kernel.
#
# Each program generates BLOCK Philox counters. Every counter produces four
# random words (r0..r3), which become either two samples (is_double: the words
# are paired into 64-bit values) or four samples (otherwise). The samples are
# written to out_ptr as UNROLL consecutive chunks of BLOCK elements, masked
# against N. BLOCK and num_warps are supplied by the heuristics below.
@triton.heuristics(
    {
        "BLOCK": heur_block,
        "num_warps": heur_num_warps,
    }
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel(
    out_ptr,
    N,
    is_double: tl.constexpr,
    lambd,
    eps,
    philox_seed,
    philox_offset,
    BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    lane = tl.arange(0, BLOCK)

    seed = philox_seed.to(tl.int64)
    offset = philox_offset.to(tl.int64)
    # Split the 64-bit offset into the low and high 32-bit counter words.
    ctr_lo = (offset & 0xFFFFFFFF).to(tl.uint32)
    ctr_hi = ((offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
    # Give every lane of every program its own counter value.
    ctr_lo += pid * BLOCK + lane
    # Zero tensor shaped like the counter, used for the two remaining counter words.
    zeros = ctr_lo * 0
    r0, r1, r2, r3 = tl.philox(seed, ctr_lo, ctr_hi, zeros, zeros)

    if is_double:
        # Two 64-bit uniforms per counter: (r0, r2) and (r1, r3).
        u0 = uint_to_uniform_float(paste_u64(r0, r2))
        u1 = uint_to_uniform_float(paste_u64(r1, r3))
        y0 = transform_exponential(u0, lambd, eps)
        y1 = transform_exponential(u1, lambd, eps)

        UNROLL = 2
        base = pid.to(tl.uint64) * BLOCK * UNROLL
        idx0 = base + lane
        idx1 = idx0 + BLOCK
        in_bounds0 = idx0 < N
        in_bounds1 = idx1 < N
        tl.store(out_ptr + idx0, y0, mask=in_bounds0, eviction_policy="evict_first")
        tl.store(out_ptr + idx1, y1, mask=in_bounds1, eviction_policy="evict_first")
    else:
        # Four uniforms per counter, one from each random word.
        u0 = uint_to_uniform_float(r0)
        u1 = uint_to_uniform_float(r1)
        u2 = uint_to_uniform_float(r2)
        u3 = uint_to_uniform_float(r3)
        y0 = transform_exponential(u0, lambd, eps)
        y1 = transform_exponential(u1, lambd, eps)
        y2 = transform_exponential(u2, lambd, eps)
        y3 = transform_exponential(u3, lambd, eps)

        UNROLL = 4
        base = pid.to(tl.uint64) * BLOCK * UNROLL
        idx0 = base + lane
        idx1 = idx0 + BLOCK
        idx2 = idx1 + BLOCK
        idx3 = idx2 + BLOCK
        in_bounds0 = idx0 < N
        in_bounds1 = idx1 < N
        in_bounds2 = idx2 < N
        in_bounds3 = idx3 < N
        tl.store(out_ptr + idx0, y0, mask=in_bounds0, eviction_policy="evict_first")
        tl.store(out_ptr + idx1, y1, mask=in_bounds1, eviction_policy="evict_first")
        tl.store(out_ptr + idx2, y2, mask=in_bounds2, eviction_policy="evict_first")
        tl.store(out_ptr + idx3, y3, mask=in_bounds3, eviction_policy="evict_first")
# Combine two 32-bit words into one 64-bit value: `hi` fills the upper 32 bits
# and `lo` the lower 32 bits.
@triton.jit
def paste_u64(hi: tl.uint32, lo: tl.uint32):
    upper = hi.to(tl.uint64) << 32
    lower = lo.to(tl.uint64)
    return upper | lower
# Map a uniform sample `u` to an exponential sample: -ln(u) / lambd.
# When u is at or above 1 - eps/2, -eps/2 is substituted for the logarithm,
# so the result for such inputs is eps / (2 * lambd) rather than ~0.
@triton.jit
def transform_exponential(u, lambd, eps):
    neg_half_eps = -0.5 * eps
    near_one = u >= 1.0 + neg_half_eps
    # 1.4426950408889634 is log2(e); dividing by it turns log2 into ln.
    ln_per_log2 = 1.0 / 1.4426950408889634
    ln_u = tl.where(near_one, neg_half_eps, log2(u) * ln_per_log2)
    return -1.0 / lambd * ln_u
def exponential_(x, lambd: float = 1.0, *, generator=None):
    """Fill ``x`` in place with exponentially distributed samples.

    Samples are produced by ``fused_exponential_kernel`` as
    ``-ln(u) / lambd`` from Philox-generated uniforms ``u``.

    Args:
        x: Tensor to overwrite. Must be float16, bfloat16, float32 or float64.
        lambd: Divisor applied to ``-ln(u)`` (the rate of the distribution).
        generator: Optional generator forwarded to
            ``philox_backend_seed_offset`` to obtain the seed and offset.

    Returns:
        ``x`` itself, after it has been overwritten.

    Raises:
        AssertionError: If ``x`` has an unsupported dtype.
    """
    logger.debug("GEMS_KUNLUNXIN EXPONENTIAL_")
    dtype = x.dtype
    device = x.device
    assert dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)

    N = x.numel()
    if N == 0:
        # Nothing to sample: skip the RNG bookkeeping and avoid launching a
        # kernel with an empty grid.
        return x

    # The kernel writes a flat, contiguous buffer; a non-contiguous ``x`` is
    # filled through a temporary and copied back afterwards.
    inplace = x.is_contiguous()
    is_double = dtype == torch.float64
    # Samples produced per Philox counter; must match the kernel's UNROLL.
    UNROLL = 2 if is_double else 4

    def grid_fn(meta):
        return (triton.cdiv(N, meta["BLOCK"] * UNROLL),)

    # (TODO) Using Triton autotuner makes kernel parameters opaque to the caller,
    # hence we cannot obtain the per thread offset as in Pytorch.
    increment = triton.cdiv(N, UNROLL)
    philox_seed, philox_offset = philox_backend_seed_offset(
        increment, generator=generator
    )
    eps = torch.finfo(dtype).eps
    x_ = x if inplace else torch.empty(x.size(), dtype=dtype, device=device)
    with torch_device_fn.device(device):
        fused_exponential_kernel[grid_fn](
            x_, N, is_double, lambd, eps, philox_seed, philox_offset
        )
    if not inplace:
        x.copy_(x_)
    return x