Coverage for src/flag_gems/fused/chunk_gated_delta_rule.py: 77%
107 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import torch
2import triton
3import triton.language as tl
5from flag_gems.fused.FLA import chunk_gated_delta_rule_fwd
6from flag_gems.fused.FLA.chunk_gated_delta_direct import (
7 can_use_chunk_gated_delta_rule_direct,
8 chunk_gated_delta_rule_direct_fwd,
9)
10from flag_gems.utils import libentry
@libentry()
@triton.jit
def _l2_normalize_last_dim_kernel(
    x,
    out,
    n_rows: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    stride_x_b: tl.constexpr,
    stride_x_t: tl.constexpr,
    stride_x_h: tl.constexpr,
    stride_x_k: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Each program L2-normalizes one length-K vector of `x`.
    #
    # `x` is addressed through the four explicit strides, while `out` is
    # addressed as densely packed rows of K elements (offset `pid * K`).
    # `n_rows` is the extent of the second (t) index and `H` the extent of the
    # third (h) index; BLOCK_K lanes are used per vector, masked down to K.
    pid = tl.program_id(0)
    k_offs = tl.arange(0, BLOCK_K)
    k_mask = k_offs < K

    # Split the flat program id into (b, t, h), with h varying fastest.
    bt_idx = pid // H
    h_idx = pid % H
    b_idx = bt_idx // n_rows
    t_idx = bt_idx % n_rows

    src = x + b_idx * stride_x_b + t_idx * stride_x_t + h_idx * stride_x_h
    # Masked-off lanes load 0.0 so they do not contribute to the sum of squares.
    vec = tl.load(src + k_offs * stride_x_k, mask=k_mask, other=0.0).to(tl.float32)

    # Clamp the norm from below at 1e-6 to avoid dividing by (near) zero.
    norm = tl.maximum(tl.sqrt(tl.sum(vec * vec, axis=0)), 1e-6)
    inv_norm = 1.0 / norm
    tl.store(out + pid * K + k_offs, vec * inv_norm, mask=k_mask)
def _as_seq_first(
    x: torch.Tensor,
    *,
    name: str,
    head_first: bool,
    expected_ndim: int,
) -> torch.Tensor:
    """Type/rank-check ``x`` and return it with dims 1 and 2 in sequence-first order.

    Args:
        x: Tensor to check and (optionally) re-layout.
        name: Argument name used in error messages.
        head_first: When true, dims 1 and 2 of ``x`` are swapped in the result.
        expected_ndim: Required number of dimensions of ``x``.

    Returns:
        ``x.transpose(1, 2)`` (a view, no copy) when ``head_first`` is true,
        otherwise ``x`` itself.

    Raises:
        TypeError: If ``x`` is not a ``torch.Tensor``.
        ValueError: If ``x.ndim`` differs from ``expected_ndim``.
    """
    if not isinstance(x, torch.Tensor):
        raise TypeError(f"{name} must be a torch.Tensor")

    rank_matches = x.ndim == expected_ndim
    if not rank_matches:
        raise ValueError(f"{name} must be {expected_ndim}D, got shape {tuple(x.shape)}")

    return x.transpose(1, 2) if head_first else x
57def _validate_inputs(
58 q: torch.Tensor,
59 k: torch.Tensor,
60 v: torch.Tensor,
61 beta: torch.Tensor,
62 g: torch.Tensor,
63 initial_state: torch.Tensor | None,
64 cu_seqlens: torch.Tensor | None,
65) -> None:
66 B, T, Hg, K = q.shape
67 Bk, Tk, Hk, Kk = k.shape
68 Bv, Tv, H, V = v.shape
70 tensors = {"k": k, "v": v, "beta": beta, "g": g}
71 for name, tensor in tensors.items():
72 if tensor.device != q.device:
73 raise ValueError(f"{name} must be on the same device as q")
74 if tensor.dtype != q.dtype:
75 raise ValueError(f"{name} must have the same dtype as q")
77 if (Bk, Tk, Hk, Kk) != (B, T, Hg, K):
78 raise ValueError(
79 "q and k must have matching [B, T, Hq, K] shapes after layout conversion"
80 )
81 if (Bv, Tv) != (B, T):
82 raise ValueError("v must have matching B and T dimensions with q/k")
83 if H % Hg != 0:
84 raise ValueError("the q/k head count must divide the v head count")
85 if beta.shape != (B, T, H):
86 raise ValueError(
87 f"beta must have shape {(B, T, H)} after layout conversion, got {tuple(beta.shape)}"
88 )
89 if g.shape != (B, T, H):
90 raise ValueError(
91 f"g must have shape {(B, T, H)} after layout conversion, got {tuple(g.shape)}"
92 )
93 if cu_seqlens is not None:
94 if not isinstance(cu_seqlens, torch.Tensor):
95 raise TypeError("cu_seqlens must be a torch.Tensor")
96 if cu_seqlens.ndim != 1:
97 raise ValueError("cu_seqlens must be a 1D tensor")
98 if cu_seqlens.dtype != torch.long:
99 raise ValueError("cu_seqlens must have dtype torch.long")
100 if cu_seqlens.device != q.device:
101 raise ValueError("cu_seqlens must be on the same device as q")
102 if B != 1:
103 raise ValueError("cu_seqlens packed varlen inputs must use B=1")
105 if initial_state is not None:
106 if initial_state.device != q.device:
107 raise ValueError("initial_state must be on the same device as q")
108 if initial_state.dtype != q.dtype:
109 raise ValueError("initial_state must have the same dtype as q")
110 expected_n = B if cu_seqlens is None else cu_seqlens.numel() - 1
111 expected_shape = (expected_n, H, K, V)
112 if initial_state.shape != expected_shape:
113 raise ValueError(
114 f"initial_state must have shape {expected_shape}, got {tuple(initial_state.shape)}"
115 )
def _direct_contiguous(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` unchanged if it is already contiguous, else a contiguous copy."""
    if x.is_contiguous():
        return x
    return x.contiguous()
def _l2_normalize_last_dim(x: torch.Tensor) -> torch.Tensor:
    """L2-normalize a 4D tensor along its last dimension with the Triton kernel.

    Args:
        x: 4D tensor; it may be non-contiguous because its strides are
            forwarded to the kernel.

    Returns:
        A new contiguous tensor with the same shape and dtype as ``x``.
    """
    batch, seq_len, heads, dim = x.shape
    normalized = torch.empty_like(x, memory_format=torch.contiguous_format)

    # One program per vector along the last dimension.
    grid = (batch * seq_len * heads,)
    stride_b, stride_t, stride_h, stride_k = x.stride()

    _l2_normalize_last_dim_kernel[grid](
        x=x,
        out=normalized,
        n_rows=seq_len,
        H=heads,
        K=dim,
        stride_x_b=stride_b,
        stride_x_t=stride_t,
        stride_x_h=stride_h,
        stride_x_k=stride_k,
        # The kernel's lane range must be a power of two; it masks lanes >= K.
        BLOCK_K=triton.next_power_of_2(dim),
    )
    return normalized
141def chunk_gated_delta_rule(
142 q: torch.Tensor,
143 k: torch.Tensor,
144 v: torch.Tensor,
145 beta: torch.Tensor,
146 g: torch.Tensor,
147 BT: int = 64,
148 initial_state: torch.Tensor | None = None,
149 output_final_state: bool = False,
150 cu_seqlens: torch.Tensor | None = None,
151 head_first: bool = True,
152 scale: float | None = None,
153 use_qk_l2norm_in_kernel: bool = False,
154) -> tuple[torch.Tensor, torch.Tensor | None]:
155 """Public wrapper for the chunk gated delta rule forward operator.
157 Inputs follow common FLA layouts:
158 - ``head_first=True``: q/k/v are ``[B, H, T, D]`` and beta/g are ``[B, H, T]``.
159 - ``head_first=False``: q/k/v are ``[B, T, H, D]`` and beta/g are ``[B, T, H]``.
161 q/k may use fewer heads than v when the q/k head count divides the v head count.
162 """
163 if BT != 64:
164 raise ValueError("chunk_gated_delta_rule currently supports only BT=64")
166 q_seq = _as_seq_first(q, name="q", head_first=head_first, expected_ndim=4)
167 k_seq = _as_seq_first(k, name="k", head_first=head_first, expected_ndim=4)
168 v_seq = _as_seq_first(v, name="v", head_first=head_first, expected_ndim=4)
169 beta_seq = _as_seq_first(beta, name="beta", head_first=head_first, expected_ndim=3)
170 g_seq = _as_seq_first(g, name="g", head_first=head_first, expected_ndim=3)
172 _validate_inputs(q_seq, k_seq, v_seq, beta_seq, g_seq, initial_state, cu_seqlens)
174 if scale is None:
175 scale = k_seq.shape[-1] ** -0.5
177 B, T, Hg, K = q_seq.shape
178 H, V = v_seq.shape[2], v_seq.shape[3]
179 if (
180 initial_state is None
181 and cu_seqlens is None
182 and T <= 128
183 and K <= 128
184 and V <= 128
185 and H % Hg == 0
186 ):
187 q_direct = _direct_contiguous(q_seq)
188 k_direct = _direct_contiguous(k_seq)
189 v_direct = _direct_contiguous(v_seq)
190 g_direct = _direct_contiguous(g_seq)
191 beta_direct = _direct_contiguous(beta_seq)
192 if can_use_chunk_gated_delta_rule_direct(
193 q=q_direct,
194 k=k_direct,
195 v=v_direct,
196 g=g_direct,
197 beta=beta_direct,
198 initial_state=None,
199 cu_seqlens=None,
200 ):
201 o, final_state = chunk_gated_delta_rule_direct_fwd(
202 q=q_direct,
203 k=k_direct,
204 v=v_direct,
205 g=g_direct,
206 beta=beta_direct,
207 scale=float(scale),
208 initial_state=None,
209 output_final_state=output_final_state,
210 use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
211 )
212 if head_first:
213 o = o.transpose(1, 2)
214 return o, final_state
216 if use_qk_l2norm_in_kernel:
217 q_seq = _l2_normalize_last_dim(q_seq)
218 k_seq = _l2_normalize_last_dim(k_seq)
220 _, o, _, final_state, _, _, _ = chunk_gated_delta_rule_fwd(
221 q=q_seq,
222 k=k_seq,
223 v=v_seq,
224 g=g_seq,
225 beta=beta_seq,
226 scale=float(scale),
227 initial_state=initial_state,
228 output_final_state=output_final_state,
229 cu_seqlens=cu_seqlens,
230 )
232 if head_first:
233 o = o.transpose(1, 2)
234 return o, final_state