Coverage for src/flag_gems/fused/FLA/chunk_gated_delta_direct.py: 45%
75 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1# This file contains a guarded direct forward path for small chunk_gated_delta_rule
2# shapes. It follows the recurrent definition directly and falls back to the
3# chunk decomposition for unsupported cases.
5from __future__ import annotations
7import torch
8import triton
9import triton.language as tl
11from flag_gems.fused.FLA.triton_ops_helper import exp
12from flag_gems.utils import libentry
# Eligibility limits for the direct path. can_use_chunk_gated_delta_rule_direct
# only accepts inputs whose T (number of recurrence steps, q.shape[1]), K (key
# head dim, q.shape[3]) and V (value head dim, v.shape[3]) do not exceed these.
_DIRECT_MAX_T = 128
_DIRECT_MAX_K = 128
_DIRECT_MAX_V = 128
# Upper bound on the value-tile width (BV) picked by
# chunk_gated_delta_rule_direct_fwd when it does not use the one-warp config.
_DIRECT_BV = 32
@libentry()
@triton.heuristics(
    {
        # Both flags are derived from whether the optional state pointer was passed.
        "USE_INITIAL_STATE": lambda args: args["initial_state"] is not None,
        "STORE_FINAL_STATE": lambda args: args["final_state"] is not None,
    }
)
@triton.jit
def _chunk_gated_delta_rule_direct_fwd_kernel(
    q,
    k,
    v,
    g,
    beta,
    o,
    initial_state,
    final_state,
    scale,
    T: tl.constexpr,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
):
    """Step-by-step gated delta rule forward for one (batch, head, value tile).

    Program axis 0 selects a BV-wide tile of the value dimension; program axis 1
    is the flattened (batch, head) index, decoded with H. Each program carries a
    [BK, BV] float32 state tile ``h`` through all T steps:

        h   <- h * exp(g_t)
        u   <- (v_t - sum_k(h * k_t)) * beta_t
        h   <- h + outer(k_t, u)
        o_t <- sum_k(h * (q_t * scale))

    The pointer arithmetic addresses the buffers as densely packed row-major
    arrays: q/k as [B, T, Hg, K], v/o as [B, T, H, V], g/beta as [B, T, H] and
    initial_state/final_state as [B, H, K, V]. The whole key dimension is
    covered by a single BK-wide block, so BK must be >= K.

    Args:
        q, k, v, g, beta: Input pointers (layouts as above).
        o: Output pointer, written once per step.
        initial_state: Optional starting state; the state starts at zero when
            it is None.
        final_state: Optional destination for the state after the last step.
        scale: Multiplier applied to q before the output reduction.
        T, H, Hg, K, V: Problem sizes (compile-time constants).
        BK, BV: Tile sizes along the key and value dimensions.
        USE_INITIAL_STATE, STORE_FINAL_STATE: Set by the heuristics above.
        USE_QK_L2NORM_IN_KERNEL: When true, q_t and k_t are L2-normalised
            inside the kernel before use.
    """
    i_v = tl.program_id(0)
    i_bh = tl.program_id(1)
    i_b = i_bh // H
    i_h = i_bh % H
    # H // Hg consecutive value heads share the same query/key head.
    i_hg = i_h // (H // Hg)

    o_k = tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)
    # Masks cover the padding introduced by rounding K up to BK and V up to a
    # multiple of BV.
    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    # Running state tile, accumulated in float32 regardless of the input dtype.
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = (
            initial_state + ((i_b * H + i_h) * K * V) + o_k[:, None] * V + o_v[None, :]
        )
        b_h += tl.load(p_h0, mask=mask_h, other=0.0).to(tl.float32)

    # Base pointers at step 0 for this (batch, head); one step further along T
    # is Hg * K elements for q/k, H * V for v/o and H for g/beta.
    q_base = q + ((i_b * T * Hg + i_hg) * K)
    k_base = k + ((i_b * T * Hg + i_hg) * K)
    v_base = v + ((i_b * T * H + i_h) * V)
    o_base = o + ((i_b * T * H + i_h) * V)
    g_base = g + i_b * T * H + i_h
    beta_base = beta + i_b * T * H + i_h
    for i_t in range(0, T):
        b_q = tl.load(q_base + i_t * Hg * K + o_k, mask=mask_k, other=0.0).to(
            tl.float32
        )
        b_k = tl.load(k_base + i_t * Hg * K + o_k, mask=mask_k, other=0.0).to(
            tl.float32
        )
        if USE_QK_L2NORM_IN_KERNEL:
            # The norm is clamped from below at 1e-6 to avoid dividing by zero.
            b_q = b_q / tl.maximum(tl.sqrt(tl.sum(b_q * b_q)), 1e-6)
            b_k = b_k / tl.maximum(tl.sqrt(tl.sum(b_k * b_k)), 1e-6)
        b_v = tl.load(v_base + i_t * H * V + o_v, mask=mask_v, other=0.0).to(tl.float32)
        # g and beta are one scalar per (batch, step, head).
        b_g = tl.load(g_base + i_t * H).to(tl.float32)
        b_beta = tl.load(beta_base + i_t * H).to(tl.float32)

        # Decay the state by the gate first; the correction below is computed
        # against the already-decayed state.
        b_h *= exp(b_g)
        # Delta-rule correction: value minus what the state returns for k_t,
        # scaled by beta.
        b_v = (b_v - tl.sum(b_h * b_k[:, None], axis=0)) * b_beta
        # Rank-1 update of the state with the corrected value.
        b_h += b_k[:, None] * b_v[None, :]
        # The output reads the state after this step's update.
        b_o = tl.sum(b_h * (b_q * scale)[:, None], axis=0)
        tl.store(
            o_base + i_t * H * V + o_v,
            b_o.to(o.dtype.element_ty),
            mask=mask_v,
        )

    if STORE_FINAL_STATE:
        p_ht = final_state + ((i_b * H + i_h) * K * V) + o_k[:, None] * V + o_v[None, :]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
def can_use_chunk_gated_delta_rule_direct(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    initial_state: torch.Tensor | None,
    cu_seqlens: torch.LongTensor | None,
) -> bool:
    """Return True when the inputs qualify for the direct forward kernel.

    The predicate never raises on malformed shapes; anything it cannot vouch
    for is rejected so the caller can take its fallback path instead.

    Args:
        q: Queries, expected as [B, T, Hg, K].
        k: Keys; must have exactly the shape of ``q``.
        v: Values, expected as [B, T, H, V].
        g: Per-step gates; must be [B, T, H].
        beta: Per-step betas; must be [B, T, H].
        initial_state: Must be None for the direct path to be selected.
        cu_seqlens: Must be None (variable-length batches are not supported).

    Returns:
        True only if all tensors are contiguous with the layouts above, H is a
        positive multiple of Hg, T/K/V are within the ``_DIRECT_MAX_*`` limits
        and ``q`` is float16, bfloat16 or float32.
    """
    if cu_seqlens is not None or initial_state is not None:
        return False

    # Rank must be checked before unpacking; a non-4-D tensor would otherwise
    # raise ValueError out of what is supposed to be a yes/no question.
    if q.dim() != 4 or v.dim() != 4:
        return False

    # The kernel addresses k with q's strides and g/beta as [B, T, H]; any
    # other shape would make it read the wrong elements or run out of bounds.
    if k.shape != q.shape:
        return False
    if g.shape != v.shape[:3] or beta.shape != v.shape[:3]:
        return False

    if not (q.is_contiguous() and k.is_contiguous() and v.is_contiguous()):
        return False
    if not (g.is_contiguous() and beta.is_contiguous()):
        return False

    B, T, Hg, K = q.shape
    Bv, Tv, H, V = v.shape
    if B != Bv or T != Tv:
        return False
    # Hg > 0 guards the modulo below against ZeroDivisionError on empty inputs.
    if Hg <= 0 or H % Hg != 0:
        return False
    if q.dtype not in (torch.float16, torch.bfloat16, torch.float32):
        return False

    return (
        0 < T <= _DIRECT_MAX_T
        and 0 < K <= _DIRECT_MAX_K
        and 0 < V <= _DIRECT_MAX_V
    )
def chunk_gated_delta_rule_direct_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor | None,
    output_final_state: bool,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Launch the direct gated delta rule forward kernel.

    No validation is done here; the tensors are handed to the kernel as they
    are, so callers should check eligibility first.

    Args:
        q: Queries; the shape is unpacked as (B, T, Hg, K).
        k: Keys.
        v: Values; H and V are taken from dimensions 2 and 3.
        g: Gates.
        beta: Betas.
        scale: Query scale, converted to a Python float before launch.
        initial_state: Optional starting state forwarded to the kernel.
        output_final_state: When truthy, a float32 (B, H, K, V) buffer is
            allocated and filled with the state after the last step.
        use_qk_l2norm_in_kernel: Forwarded as the kernel's L2-norm switch.

    Returns:
        ``(o, final_state)`` where ``o`` is allocated like ``v`` and
        ``final_state`` is None unless ``output_final_state`` was set.
    """
    batch, steps, qk_heads, key_dim = q.shape
    heads = v.shape[2]
    value_dim = v.shape[3]

    # Launch configuration: very small tiles (or small float32 tiles) run on
    # a single warp with a narrower value tile and fewer pipeline stages.
    tiny_tile = key_dim <= 16 and value_dim <= 16
    single_warp = tiny_tile or (
        q.dtype == torch.float32 and key_dim <= 32 and value_dim <= 32
    )
    if tiny_tile:
        warps, stages = 1, 1
    elif single_warp:
        warps, stages = 1, 2
    else:
        warps, stages = (4 if key_dim >= 128 else 2), 3

    block_k = triton.next_power_of_2(key_dim)
    block_v_cap = 16 if single_warp else _DIRECT_BV
    block_v = min(triton.next_power_of_2(value_dim), block_v_cap)

    out = torch.empty_like(v)
    last_state = None
    if output_final_state:
        last_state = torch.empty(
            batch, heads, key_dim, value_dim, device=v.device, dtype=torch.float32
        )

    # One program per (value tile, batch * head).
    launch_grid = lambda meta: (  # noqa: E731
        triton.cdiv(value_dim, meta["BV"]),
        batch * heads,
    )

    _chunk_gated_delta_rule_direct_fwd_kernel[launch_grid](
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        o=out,
        initial_state=initial_state,
        final_state=last_state,
        scale=float(scale),
        T=steps,
        H=heads,
        Hg=qk_heads,
        K=key_dim,
        V=value_dim,
        BK=block_k,
        BV=block_v,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        num_warps=warps,
        num_stages=stages,
    )
    return out, last_state