Coverage for src/flag_gems/fused/FLA/chunk_fused_tail_vblock.py: 45%
65 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1# V-blocked fused tail for the official K=V=BT=64 chunk_gated_delta_rule path.
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import libentry
# Width (BV) of the V-dimension tile handled by one kernel program. The
# wrapper launches cdiv(V, _FUSED_TAIL_BV) programs along grid axis 0, so the
# supported V == 64 is split into four 16-wide tiles.
_FUSED_TAIL_BV = 16
12def can_use_fused_tail_vblock(
13 q: torch.Tensor,
14 k: torch.Tensor,
15 w: torch.Tensor,
16 u: torch.Tensor,
17 g: torch.Tensor,
18 initial_state: torch.Tensor | None,
19 output_final_state: bool,
20 *,
21 chunk_size: int,
22 cu_seqlens: torch.Tensor | None,
23) -> bool:
24 if cu_seqlens is not None or initial_state is None or not output_final_state:
25 return False
26 if q.ndim != 4 or k.ndim != 4 or w.ndim != 4 or u.ndim != 4 or g.ndim != 3:
27 return False
29 B, T, Hg, K = q.shape
30 H, V = u.shape[2], u.shape[3]
31 if k.shape != (B, T, Hg, K):
32 return False
33 if w.shape != (B, T, H, K) or g.shape != (B, T, H):
34 return False
35 if initial_state.shape != (B, H, K, V):
36 return False
37 if chunk_size != 64 or T % 64 != 0 or (K, V) != (64, 64) or H % Hg != 0:
38 return False
39 if q.dtype not in (torch.float16, torch.bfloat16):
40 return False
41 if not all(x.dtype == q.dtype for x in (k, w, u, g, initial_state)):
42 return False
43 return all(x.is_contiguous() for x in (q, k, w, u, g, initial_state))
@libentry()
@triton.jit
def _chunk_gated_delta_rule_fused_tail_vblock_kernel(
    q,
    k,
    w,
    u,
    g,
    h0,
    o,
    ht,
    scale: tl.constexpr,
    T: tl.constexpr,
    H: tl.constexpr,
    Hg: tl.constexpr,
    BT: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BV: tl.constexpr,
):
    """Fused output + state recurrence for one (batch, head, V-tile).

    Grid: axis 0 indexes V tiles of width BV; axis 1 indexes flattened
    batch * H + head. q/k heads are shared by H // Hg consecutive w/u heads.

    Memory layouts implied by the address arithmetic (contiguous):
        q, k: (B, T, Hg, K)   w: (B, T, H, K)   u, o: (B, T, H, V)
        g: (B, T, H)          h0, ht: (B, H, K, V)

    For each BT-long chunk, with h the float32 (K, BV) state tile:
        residual = u - w @ h
        o        = scale * (exp(g) * (q @ h)
                            + causal(q @ k^T * exp(g_i - g_j)) @ residual)
        h        = exp(g_last) * h + k^T @ (residual * exp(g_last - g))
    NOTE(review): g is only consumed through exp(), which suggests it is a
    within-chunk cumulative log-decay; confirm with the caller.
    T must be a multiple of BT: chunk loads along time are unmasked.
    """
    i_v = tl.program_id(0)
    i_bh = tl.program_id(1)
    # Promote to int64 so all flattened offsets below are 64-bit; with int32
    # program ids, e.g. ((i_b * T + t) * H + i_h) * K wraps once a tensor
    # holds more than 2**31 elements.
    i_b = (i_bh // H).to(tl.int64)
    i_h = i_bh % H
    # Grouped heads: H // Hg value heads read the same q/k head.
    i_hg = i_h // (H // Hg)

    offs_t = tl.arange(0, BT)
    offs_k = tl.arange(0, K)
    offs_v = i_v * BV + tl.arange(0, BV)
    v_mask = offs_v < V

    # Load this program's (K, BV) slice of the initial state in float32.
    h0_base = ((i_b * H + i_h) * K) * V
    h_acc = tl.load(
        h0 + h0_base + offs_k[:, None] * V + offs_v[None, :],
        mask=v_mask[None, :],
        other=0.0,
    ).to(tl.float32)

    for i_t in range(0, tl.cdiv(T, BT)):
        t = i_t * BT + offs_t

        q_block = tl.load(
            q + (((i_b * T + t[:, None]) * Hg + i_hg) * K + offs_k[None, :])
        )
        # k is loaded transposed, (K, BT), so it feeds q @ k^T and k^T @ x.
        k_t_block = tl.load(
            k + (((i_b * T + t[None, :]) * Hg + i_hg) * K + offs_k[:, None])
        )
        w_block = tl.load(
            w + (((i_b * T + t[:, None]) * H + i_h) * K + offs_k[None, :])
        )
        u_block = tl.load(
            u + (((i_b * T + t[:, None]) * H + i_h) * V + offs_v[None, :]),
            mask=v_mask[None, :],
            other=0.0,
        )
        g_vec = tl.load(g + (i_b * T + t) * H + i_h).to(tl.float32)

        # Part of u not explained by the current state.
        residual = u_block.to(tl.float32) - tl.dot(w_block, h_acc.to(w_block.dtype))

        # Inter-chunk term (through the state) and intra-chunk causal term.
        q_h = tl.dot(q_block, h_acc.to(q_block.dtype))
        qk = tl.dot(q_block, k_t_block).to(tl.float32)
        causal = offs_t[:, None] >= offs_t[None, :]
        # tl.where discards the (possibly inf) upper-triangle exp values.
        qk = tl.where(causal, qk * tl.exp(g_vec[:, None] - g_vec[None, :]), 0.0)
        out = (
            q_h * tl.exp(g_vec)[:, None]
            + tl.dot(qk.to(u_block.dtype), residual.to(u_block.dtype))
        ) * scale
        tl.store(
            o + (((i_b * T + t[:, None]) * H + i_h) * V + offs_v[None, :]),
            out,
            mask=v_mask[None, :],
        )

        # Decay the state to the chunk's last step and add this chunk's update.
        g_last = tl.load(g + (i_b * T + ((i_t + 1) * BT - 1)) * H + i_h).to(tl.float32)
        residual_for_state = residual * tl.exp(g_last - g_vec)[:, None]
        h_acc = h_acc * tl.exp(g_last) + tl.dot(
            k_t_block, residual_for_state.to(k_t_block.dtype)
        )

    ht_base = ((i_b * H + i_h) * K) * V
    tl.store(
        ht + ht_base + offs_k[:, None] * V + offs_v[None, :],
        h_acc,
        mask=v_mask[None, :],
    )
def chunk_gated_delta_rule_fused_tail_vblock(
    q: torch.Tensor,
    k: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    g: torch.Tensor,
    initial_state: torch.Tensor,
    *,
    scale: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the V-blocked fused tail kernel.

    Callers are expected to gate on the eligibility check first. This function
    re-validates the shape and layout assumptions the kernel bakes in, because
    violating them makes the kernel access memory out of bounds.

    Args:
        q, k: (B, T, Hg, K) query/key tensors.
        w: (B, T, H, K).
        u: (B, T, H, V).
        g: (B, T, H) gate values.
        initial_state: (B, H, K, V) starting state.
        scale: multiplier applied to the output.

    Returns:
        (o, final_state): o has u's shape and dtype; final_state is a float32
        (B, H, K, V) tensor.

    Raises:
        ValueError: if K or V is not 64, T is not a multiple of 64, H is not a
            multiple of Hg, or any input is non-contiguous.
    """
    block_t = 64  # chunk length (BT) the kernel is specialised for
    B, T, Hg, K = q.shape
    H, V = u.shape[2], u.shape[3]
    # The kernel's time-axis loads are unmasked and its K/V extents are
    # compile-time 64, so other shapes would access memory out of bounds.
    if (K, V) != (64, 64) or T % block_t != 0 or Hg == 0 or H % Hg != 0:
        raise ValueError(
            "fused tail vblock requires K == V == 64, T % 64 == 0 and H % Hg == 0; "
            f"got q.shape={tuple(q.shape)}, u.shape={tuple(u.shape)}"
        )
    # Addresses are computed with flat contiguous strides.
    if not all(x.is_contiguous() for x in (q, k, w, u, g, initial_state)):
        raise ValueError("fused tail vblock requires contiguous inputs")

    o = torch.empty_like(u)
    final_state = torch.empty(B, H, K, V, device=q.device, dtype=torch.float32)
    # One program per (V tile, batch * head).
    grid = (triton.cdiv(V, _FUSED_TAIL_BV), B * H)
    _chunk_gated_delta_rule_fused_tail_vblock_kernel[grid](
        q,
        k,
        w,
        u,
        g,
        initial_state,
        o,
        final_state,
        scale=scale,
        T=T,
        H=H,
        Hg=Hg,
        BT=block_t,
        K=K,
        V=V,
        BV=_FUSED_TAIL_BV,
        num_warps=4,
        num_stages=3,
    )
    return o, final_state