Coverage for src/flag_gems/fused/FLA/fused_recurrent.py: 16%
256 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1# This file contains code copied from the flash-linear-attention project.
2# The original source code was licensed under the MIT license and included
3# the following copyright notice:
4# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
5# ruff: noqa: E501
6import logging
8import torch
9import triton
10import triton.language as tl
12from flag_gems.fused.FLA.triton_ops_helper import exp
14logger = logging.getLogger(__name__)
# Fused recurrent gated delta-rule forward kernel (strided q/k/v variant).
#
# Unlike `fused_recurrent_gated_delta_rule_fwd_kernel`, q/k/v are addressed
# through explicit (token, head, dim) strides, so non-contiguous views can be
# consumed without a copy. g, beta and o are still addressed as if they were
# contiguous ([..., HV] for g / scalar beta, [..., HV, V] for headwise beta and o).
#
# This kernel is specialized for the Qwen3-Next model.
# It requires modifications to the calling logic for Qwen3-Next:
# refer to the rearrange_mixed_qkv logic in the benchmark, where setting
# contiguous=False can provide a certain performance boost by avoiding
# unnecessary contiguous operations.
#
# Each program (program ids: i_k, i_v, (sequence, value head)) keeps one
# float32 [BK, BV] tile of the recurrent state and walks its sequence's tokens
# one at a time:
#     h   = h * exp(g_t)
#     v_t = beta_t * (v_t - h^T k_t)
#     h   = h + k_t v_t^T
#     o_t = h^T (scale * q_t)          (q_t, k_t optionally L2-normalized first)
# Per (state slot, value head) the state is addressed as a row-major [K, V]
# matrix (`o_k[:, None] * V + o_v[None, :]`).
#
# Negative entries of `ssm_state_indices` (PAD_SLOT_ID = -1) are never used as
# state offsets. A negative initial-state slot makes the program exit without
# writing anything. A negative final-state slot skips that state store.
@triton.heuristics(
    {
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
        "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
    }
)
@triton.jit(do_not_specialize=["N", "T"])
def fused_recurrent_gated_delta_rule_fwd_sp_for_qwen3_next_kernel(
    q,
    k,
    v,
    g,
    beta,
    o,
    h0,
    ht,
    cu_seqlens,
    ssm_state_indices,
    num_accepted_tokens,
    scale,
    N: tl.int64,
    T: tl.int64,
    # stride_q_b: tl.int64,
    stride_q_t: tl.int64,
    stride_q_h: tl.int64,
    stride_q_k: tl.int64,
    # stride_k_b: tl.int64,
    stride_k_t: tl.int64,
    stride_k_h: tl.int64,
    stride_k_k: tl.int64,
    # stride_v_b: tl.int64,
    stride_v_t: tl.int64,
    stride_v_hv: tl.int64,
    stride_v_v: tl.int64,
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    INPLACE_FINAL_STATE: tl.constexpr,
    IS_BETA_HEADWISE: tl.constexpr,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
):
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    # HV value heads share H query/key heads in equal-sized groups.
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        # Sequence i_n occupies tokens [bos, eos) of one flattened token axis.
        # The incoming T is that axis' total length.
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int64),
            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
        )
        all_tokens = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all_tokens = B * T

    if T == 0:
        # no tokens to process for this sequence
        return

    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    p_q = q + bos * stride_q_t + i_h * stride_q_h + o_k * stride_q_k
    p_k = k + bos * stride_k_t + i_h * stride_k_h + o_k * stride_k_k
    p_v = v + bos * stride_v_t + i_hv * stride_v_hv + o_v * stride_v_v
    if IS_BETA_HEADWISE:
        p_beta = beta + (bos * HV + i_hv) * V + o_v
    else:
        p_beta = beta + bos * HV + i_hv

    p_g = g + bos * HV + i_hv

    p_o = o + ((i_k * all_tokens + bos) * HV + i_hv) * V + o_v

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            # With speculative decoding, start from the state of the last
            # accepted token. Otherwise use the first slot of this sequence.
            if IS_SPEC_DECODING:
                i_slot = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_slot = 0
            state_idx = tl.load(
                ssm_state_indices
                + i_n * stride_indices_seq
                + i_slot * stride_indices_tok
            ).to(tl.int64)
            # PAD_SLOT_ID (-1): there is no valid state slot for this sequence.
            if state_idx < 0:
                return
            p_h0 = h0 + state_idx * stride_init_state_token
        else:
            p_h0 = h0 + bos * HV * K * V
        p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for i_t in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)

        if USE_QK_L2NORM_IN_KERNEL:
            b_q *= tl.rsqrt(tl.sum(b_q * b_q) + 1e-6)
            b_k *= tl.rsqrt(tl.sum(b_k * b_k) + 1e-6)
        b_q *= scale
        # [BK, BV]: decay the state by the scalar gate of this token/head.
        b_g = tl.load(p_g).to(tl.float32)
        b_h *= exp(b_g)
        # [BV]: delta-rule correction v - h^T k.
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta
        # [BK, BV]: rank-1 state update.
        b_h += b_k[:, None] * b_v[None, :]
        # [BV]: output h^T q.
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # keep the states for multi-query tokens
        if INPLACE_FINAL_STATE:
            final_state_idx = tl.load(
                ssm_state_indices
                + i_n * stride_indices_seq
                + i_t * stride_indices_tok
            ).to(tl.int64)
            # Never write through PAD_SLOT_ID (-1).
            if final_state_idx >= 0:
                p_ht = ht + final_state_idx * stride_final_state_token
                p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
                tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
        else:
            p_ht = ht + (bos + i_t) * stride_final_state_token
            p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
            tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)

        p_q += stride_q_t
        p_k += stride_k_t
        p_v += stride_v_t
        p_o += HV * V
        p_g += HV
        p_beta += HV * (V if IS_BETA_HEADWISE else 1)
# Fused recurrent gated delta-rule forward kernel (contiguous q/k/v variant).
#
# q/k are addressed as contiguous [..., H, K] rows (token stride H * K).
# v/o are addressed as contiguous [..., HV, V] rows (token stride HV * V).
# Each program (program ids: i_k, i_v, (sequence, value head)) keeps one
# float32 [BK, BV] tile of the recurrent state and walks its sequence's tokens
# one at a time:
#     h   = h * exp(g_t)
#     v_t = beta_t * (v_t - h^T k_t)
#     h   = h + k_t v_t^T
#     o_t = h^T (scale * q_t)          (q_t, k_t optionally L2-normalized first)
# With IS_KDA the gate is a per-key-channel vector g_t[k] ([..., HV, K]) that
# decays state row k. Otherwise it is one scalar per (token, value head).
# Per (state slot, value head) the state is addressed as a row-major [K, V]
# matrix (`o_k[:, None] * V + o_v[None, :]`).
#
# Negative entries of `ssm_state_indices` (PAD_SLOT_ID = -1) are never used as
# state offsets. A negative initial-state slot makes the program exit without
# writing anything. A negative final-state slot skips that state store.
@triton.heuristics(
    {
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
        "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
    }
)
@triton.jit(do_not_specialize=["N", "T"])
def fused_recurrent_gated_delta_rule_fwd_kernel(
    q,
    k,
    v,
    g,
    beta,
    o,
    h0,
    ht,
    cu_seqlens,
    ssm_state_indices,
    num_accepted_tokens,
    scale,
    N: tl.int64,  # num of sequences
    T: tl.int64,  # num of tokens
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    INPLACE_FINAL_STATE: tl.constexpr,  # whether to store final state inplace
    IS_BETA_HEADWISE: tl.constexpr,  # whether beta is headwise vector or scalar,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
    IS_KDA: tl.constexpr,
):
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    # HV value heads share H query/key heads in equal-sized groups.
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        # Sequence i_n occupies tokens [bos, eos) of one flattened token axis.
        # The incoming T is that axis' total length.
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int64),
            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
        )
        all_tokens = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all_tokens = B * T

    if T == 0:
        # no tokens to process for this sequence
        return

    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    p_q = q + (bos * H + i_h) * K + o_k
    p_k = k + (bos * H + i_h) * K + o_k
    p_v = v + (bos * HV + i_hv) * V + o_v
    if IS_BETA_HEADWISE:
        p_beta = beta + (bos * HV + i_hv) * V + o_v
    else:
        p_beta = beta + bos * HV + i_hv

    if not IS_KDA:
        p_g = g + bos * HV + i_hv
    else:
        p_gk = g + (bos * HV + i_hv) * K + o_k

    p_o = o + ((i_k * all_tokens + bos) * HV + i_hv) * V + o_v

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            # With speculative decoding, start from the state of the last
            # accepted token. Otherwise use the first slot of this sequence.
            if IS_SPEC_DECODING:
                i_slot = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_slot = 0
            state_idx = tl.load(
                ssm_state_indices
                + i_n * stride_indices_seq
                + i_slot * stride_indices_tok
            ).to(tl.int64)
            # PAD_SLOT_ID (-1): there is no valid state slot for this sequence.
            if state_idx < 0:
                return
            p_h0 = h0 + state_idx * stride_init_state_token
        else:
            p_h0 = h0 + bos * HV * K * V
        p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for i_t in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)

        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
            b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
        b_q = b_q * scale
        # [BK, BV]: decay the state.
        if not IS_KDA:
            b_g = tl.load(p_g).to(tl.float32)
            b_h *= exp(b_g)
        else:
            # Mask lanes past K. Otherwise they read the next head's gates,
            # and exp() of a large value would turn the zero padding rows of
            # b_h into inf/NaN, which then leak into b_v via the reduction.
            b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
            b_h *= exp(b_gk[:, None])
        # [BV]: delta-rule correction v - h^T k.
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta
        # [BK, BV]: rank-1 state update.
        b_h += b_k[:, None] * b_v[None, :]
        # [BV]: output h^T q.
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # keep the states for multi-query tokens
        if INPLACE_FINAL_STATE:
            final_state_idx = tl.load(
                ssm_state_indices
                + i_n * stride_indices_seq
                + i_t * stride_indices_tok
            ).to(tl.int64)
            # Never write through PAD_SLOT_ID (-1).
            if final_state_idx >= 0:
                p_ht = ht + final_state_idx * stride_final_state_token
                p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
                tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
        else:
            p_ht = ht + (bos + i_t) * stride_final_state_token
            p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
            tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)

        p_q += H * K
        p_k += H * K
        p_o += HV * V
        p_v += HV * V
        if not IS_KDA:
            p_g += HV
        else:
            p_gk += HV * K
        p_beta += HV * (V if IS_BETA_HEADWISE else 1)
# Fused recurrent gated delta-rule forward kernel used for long sequences.
#
# Same recurrence as `fused_recurrent_gated_delta_rule_fwd_kernel`, with
# contiguous q/k ([..., H, K]) and v/o ([..., HV, V]) addressing. The difference
# is that the state tile is held as [BV, BK], and h0/ht are addressed as a
# row-major [V, K] matrix per (state slot, value head)
# (`o_v[:, None] * K + o_k[None, :]`).
# NOTE(review): that is the transpose of the [K, V] addressing used by the
# other two kernels in this file. Confirm which state layout callers expect.
#
# Negative entries of `ssm_state_indices` (PAD_SLOT_ID = -1) are never used as
# state offsets. A negative initial-state slot makes the program exit without
# writing anything. A negative final-state slot skips that state store.
@triton.heuristics(
    {
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
        "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
    }
)
@triton.jit(do_not_specialize=["N", "T"])
def fused_recurrent_gated_delta_rule_large_t_fwd_kernel(
    q,
    k,
    v,
    g,
    beta,
    o,
    h0,
    ht,
    cu_seqlens,
    ssm_state_indices,
    num_accepted_tokens,
    scale,
    N: tl.int64,  # num of sequences
    T: tl.int64,  # num of tokens
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    INPLACE_FINAL_STATE: tl.constexpr,  # whether to store final state inplace
    IS_BETA_HEADWISE: tl.constexpr,  # whether beta is headwise vector or scalar,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
    IS_KDA: tl.constexpr,
):
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    # HV value heads share H query/key heads in equal-sized groups.
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        # Sequence i_n occupies tokens [bos, eos) of one flattened token axis.
        # The incoming T is that axis' total length.
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int64),
            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
        )
        all_tokens = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all_tokens = B * T

    if T == 0:
        # no tokens to process for this sequence
        return

    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    p_q = q + (bos * H + i_h) * K + o_k
    p_k = k + (bos * H + i_h) * K + o_k
    p_v = v + (bos * HV + i_hv) * V + o_v
    if IS_BETA_HEADWISE:
        p_beta = beta + (bos * HV + i_hv) * V + o_v
    else:
        p_beta = beta + bos * HV + i_hv

    if not IS_KDA:
        p_g = g + bos * HV + i_hv
    else:
        p_gk = g + (bos * HV + i_hv) * K + o_k

    p_o = o + ((i_k * all_tokens + bos) * HV + i_hv) * V + o_v

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_v[:, None] & mask_k[None, :]

    b_h = tl.zeros([BV, BK], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            # With speculative decoding, start from the state of the last
            # accepted token. Otherwise use the first slot of this sequence.
            if IS_SPEC_DECODING:
                i_slot = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_slot = 0
            # Load state index and check for PAD_SLOT_ID (-1)
            state_idx = tl.load(
                ssm_state_indices
                + i_n * stride_indices_seq
                + i_slot * stride_indices_tok
            ).to(tl.int64)
            # Skip if state index is invalid (PAD_SLOT_ID = -1)
            if state_idx < 0:
                return
            p_h0 = h0 + state_idx * stride_init_state_token
        else:
            p_h0 = h0 + bos * HV * V * K
        p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for i_t in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)

        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
            b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
        b_q = b_q * scale
        # [BV, BK]: decay the state.
        if not IS_KDA:
            b_g = tl.load(p_g).to(tl.float32)
            b_h *= exp(b_g)
        else:
            # Mask lanes past K so they cannot read the next head's gates and
            # turn the zero padding columns of b_h into inf/NaN.
            b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
            b_h *= exp(b_gk[None, :])
        # [BV]: delta-rule correction v - h k.
        b_v -= tl.sum(b_h * b_k[None, :], 1)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta
        # [BV, BK]: rank-1 state update.
        b_h += b_v[:, None] * b_k[None, :]
        # [BV]: output h q.
        b_o = tl.sum(b_h * b_q[None, :], 1)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # keep the states for multi-query tokens
        if INPLACE_FINAL_STATE:
            # Load state index and check for PAD_SLOT_ID (-1)
            final_state_idx = tl.load(
                ssm_state_indices
                + i_n * stride_indices_seq
                + i_t * stride_indices_tok
            ).to(tl.int64)
            # Only store if state index is valid (not PAD_SLOT_ID)
            if final_state_idx >= 0:
                p_ht = ht + final_state_idx * stride_final_state_token
                p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
                tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
        else:
            p_ht = ht + (bos + i_t) * stride_final_state_token
            p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
            tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)

        p_q += H * K
        p_k += H * K
        p_o += HV * V
        p_v += HV * V
        if not IS_KDA:
            p_g += HV
        else:
            p_gk += HV * K
        p_beta += HV * (V if IS_BETA_HEADWISE else 1)
def fused_recurrent_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    inplace_final_state: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
    ssm_state_indices: torch.Tensor | None = None,
    num_accepted_tokens: torch.Tensor | None = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the fused recurrent gated delta-rule forward pass.

    Picks one of three Triton kernels:

    * ``fused_recurrent_gated_delta_rule_fwd_kernel`` when the stride test
      ``qkv_contiguous`` below passes;
    * otherwise ``fused_recurrent_gated_delta_rule_fwd_sp_for_qwen3_next_kernel``
      (explicit q/k/v strides) when ``T <= 64``;
    * otherwise ``fused_recurrent_gated_delta_rule_large_t_fwd_kernel`` on
      ``.contiguous()`` copies of q/k/v/g/beta.

    Args:
        q: Queries; unpacked with the same ``[B, T, H, K]`` layout as ``k``.
        k: Keys, ``[B, T, H, K]``.
        v: Values, ``[B, T, HV, V]``.
        g: Gate. The kernels multiply the state by ``exp(g)`` per token and
            value head. They address it as a contiguous ``[..., HV]`` array.
        beta: Delta-rule step size. It is headwise (one value per V channel)
            when ``beta.ndim == v.ndim``, otherwise one scalar per
            (token, value head).
        scale: Multiplier applied to ``q`` inside the kernel.
        initial_state: Recurrent state buffer. It is required: its dtype and
            ``stride(0)`` are read unconditionally. It is indexed by
            ``ssm_state_indices`` when given, otherwise by token position.
        inplace_final_state: If true, final states are written back into
            ``initial_state`` through ``ssm_state_indices``. Otherwise a fresh
            per-token state buffer is allocated and returned.
        cu_seqlens: Cumulative sequence lengths for packed variable-length
            input (presumably with ``B == 1`` — confirm with callers).
        ssm_state_indices: 1-D or 2-D state-slot indices per sequence
            (and per token). Negative entries are treated as PAD_SLOT_ID
            by the kernels.
        num_accepted_tokens: Per-sequence accepted-token counts for
            speculative decoding. The initial state is read from slot
            ``num_accepted_tokens[i] - 1``.
        use_qk_l2norm_in_kernel: L2-normalise q and k inside the kernel.

    Returns:
        ``(o, final_state)``. ``o`` has ``v.shape`` and ``q.dtype``.
        ``final_state`` is ``initial_state`` when ``inplace_final_state``,
        otherwise a new ``[B * T, HV, K, V]`` tensor holding the state after
        every token.

    NOTE(review): the large-T kernel addresses the state as ``[V, K]`` per
    head, while the other two kernels use ``[K, V]``. The same
    ``initial_state`` is therefore interpreted differently depending on ``T``.
    Confirm which layout callers use.
    """
    logger.debug("GEMS FUSED RECURRENT GATED DELTA RULE FWD")
    B, T, H, K, V = *k.shape, v.shape[-1]
    HV = v.shape[2]
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"
    num_stages = 3
    num_warps = 1
    # NOTE(review): for an ordinary contiguous [B, T, H, K] tensor,
    # stride(0) == T*H*K, which generally differs from stride(1) + stride(2)
    # == H*K + K. So this test rarely holds and the strided/large-T kernels
    # are the usual path. Confirm the intended contiguity criterion.
    qkv_contiguous = all(
        t.stride(0) == t.stride(1) + t.stride(2) for t in (q, k, v)
    )

    o = q.new_empty(NK, *v.shape)
    if inplace_final_state:
        final_state = initial_state
    else:
        # The kernels store the state after token i_t of sequence i_n at row
        # bos + i_t, where bos = i_n * T without cu_seqlens. Rows therefore run
        # up to B * T - 1, and only T rows would overflow for B > 1.
        final_state = q.new_empty(B * T, HV, K, V, dtype=initial_state.dtype)

    stride_init_state_token = initial_state.stride(0)
    stride_final_state_token = final_state.stride(0)

    if ssm_state_indices is None:
        stride_indices_seq, stride_indices_tok = 1, 1
    elif ssm_state_indices.ndim == 1:
        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
    else:
        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()

    grid = (NK, NV, N * HV)
    # Launch arguments shared by all three kernels. USE_INITIAL_STATE,
    # IS_VARLEN, IS_CONTINUOUS_BATCHING and IS_SPEC_DECODING are derived from
    # h0 / cu_seqlens / ssm_state_indices / num_accepted_tokens by each
    # kernel's @triton.heuristics decorator, so they are not passed here.
    common_kwargs = dict(
        o=o,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        ssm_state_indices=ssm_state_indices,
        num_accepted_tokens=num_accepted_tokens,
        scale=scale,
        N=N,
        T=T,
        B=B,
        H=H,
        HV=HV,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        stride_init_state_token=stride_init_state_token,
        stride_final_state_token=stride_final_state_token,
        stride_indices_seq=stride_indices_seq,
        stride_indices_tok=stride_indices_tok,
        IS_BETA_HEADWISE=beta.ndim == v.ndim,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        INPLACE_FINAL_STATE=inplace_final_state,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    if qkv_contiguous:
        fused_recurrent_gated_delta_rule_fwd_kernel[grid](
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            IS_KDA=False,
            **common_kwargs,
        )
    else:
        # Guarded so the arguments (including cu_seqlens.shape) are only
        # evaluated when debug logging is on. cu_seqlens may be None here.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "GEMS fused_recurrent_gated_delta_rule_fwd, "
                "[q.shape]: %s, [q.stride]: %s, "
                "[k.shape]: %s, [k.stride]: %s, "
                "[v.shape]: %s, [v.stride]: %s, "
                "[g.shape]: %s, [beta.shape]: %s, [initial_state.shape]: %s, "
                "[cu_seqlens.shape]: %s, N: %s, T: %s, B: %s, H: %s, HV: %s, K: %s, V: %s",
                q.shape,
                q.stride(),
                k.shape,
                k.stride(),
                v.shape,
                v.stride(),
                g.shape,
                beta.shape,
                initial_state.shape,
                None if cu_seqlens is None else cu_seqlens.shape,
                N,
                T,
                B,
                H,
                HV,
                K,
                V,
            )
        if T <= 64:
            fused_recurrent_gated_delta_rule_fwd_sp_for_qwen3_next_kernel[grid](
                q=q,
                k=k,
                v=v,
                g=g,
                beta=beta,
                stride_q_t=q.stride(1),
                stride_q_h=q.stride(2),
                stride_q_k=q.stride(3),
                stride_k_t=k.stride(1),
                stride_k_h=k.stride(2),
                stride_k_k=k.stride(3),
                stride_v_t=v.stride(1),
                stride_v_hv=v.stride(2),
                stride_v_v=v.stride(3),
                **common_kwargs,
            )
        else:
            fused_recurrent_gated_delta_rule_large_t_fwd_kernel[grid](
                q=q.contiguous(),
                k=k.contiguous(),
                v=v.contiguous(),
                g=g.contiguous(),
                beta=beta.contiguous(),
                IS_KDA=False,
                **common_kwargs,
            )
    o = o.squeeze(0)
    return o, final_state