Coverage for src/flag_gems/fused/FLA/chunk.py: 67%
49 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1# This file contains code copied from the flash-linear-attention project.
2# The original source code was licensed under the MIT license and included
3# the following copyright notice:
4# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
5# ruff: noqa: E501
7import logging
9import torch
11from flag_gems.fused.FLA.chunk_delta_h import chunk_gated_delta_rule_fwd_h
12from flag_gems.fused.FLA.chunk_fused_tail_vblock import (
13 can_use_fused_tail_vblock,
14 chunk_gated_delta_rule_fused_tail_vblock,
15)
16from flag_gems.fused.FLA.chunk_o import chunk_fwd_o
17from flag_gems.fused.FLA.fused_cumsum_kkt_solve_tril import (
18 chunk_gated_delta_rule_fused_cumsum_kkt_solve_tril,
19)
20from flag_gems.fused.FLA.utils import SUPPRESS_LEVEL
21from flag_gems.fused.FLA.wy_fast import recompute_w_u_fwd
23logger = logging.getLogger(__name__)
26def _chunk_size_for_sequence(T: int, is_varlen: bool) -> int:
27 if is_varlen:
28 return 64
29 return min(64, max(16, 1 << (T - 1).bit_length()))
def chunk_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
):
    """Run the forward pass of the chunked gated delta rule.

    The inputs are made contiguous where needed, a chunk size is derived from
    ``q.shape[1]`` (passed on as the sequence length), and the helper kernels
    are then called in order: fused cumsum / KKT / solve-tril, ``w``/``u``
    recomputation, and finally either the fused tail path or the separate
    state (``h``) and output (``o``) kernels.

    Args:
        q, k, v, g, beta: Input tensors; non-contiguous ones are copied.
        scale: Scaling factor forwarded to the output kernels.
        initial_state: Optional initial state; ``None`` is accepted.
        output_final_state: Forwarded to the state kernel and to the
            fused-path eligibility check.
        cu_seqlens: Optional cumulative sequence lengths. When given, the
            variable-length chunk size is selected.

    Returns:
        A 7-tuple ``(g, o, A, final_state, w, h, v_new)`` where ``g`` is the
        gate tensor returned by the fused cumsum kernel. The last three
        entries are ``None`` when ``SUPPRESS_LEVEL < 3`` (which includes
        every call that takes the fused tail path).
    """
    logger.debug("GEMS CHUNK GATED DELTA RULE FWD")

    # Copy only those inputs that are not already contiguous.
    q, k, v, g, beta = [
        t if t.is_contiguous() else t.contiguous() for t in (q, k, v, g, beta)
    ]
    if initial_state is not None and not initial_state.is_contiguous():
        initial_state = initial_state.contiguous()
    if cu_seqlens is not None and not cu_seqlens.is_contiguous():
        cu_seqlens = cu_seqlens.contiguous()

    is_varlen = cu_seqlens is not None
    chunk_size = _chunk_size_for_sequence(q.shape[1], is_varlen)

    # From here on `g` is the tensor returned by the fused cumsum kernel.
    g, A = chunk_gated_delta_rule_fused_cumsum_kkt_solve_tril(
        g=g,
        k=k,
        beta=beta,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
        output_dtype=k.dtype,
    )
    w, u = recompute_w_u_fwd(
        k=k, v=v, beta=beta, A=A, g_cumsum=g, cu_seqlens=cu_seqlens
    )

    # The eligibility check and the state kernel take the same keyword set
    # (the check additionally receives `q`), so build it once.
    state_kwargs = dict(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        chunk_size=chunk_size,
        cu_seqlens=cu_seqlens,
    )

    if SUPPRESS_LEVEL < 3 and can_use_fused_tail_vblock(q=q, **state_kwargs):
        # Fused path: produces the output and final state in one call.
        o, final_state = chunk_gated_delta_rule_fused_tail_vblock(
            q=q,
            k=k,
            w=w,
            u=u,
            g=g,
            initial_state=initial_state,
            scale=scale,
        )
        return g, o, A, final_state, None, None, None

    # General path: state kernel first, then the output kernel.
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(**state_kwargs)
    o = chunk_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )

    if SUPPRESS_LEVEL < 3:
        # Intermediates are not returned at this suppress level.
        return g, o, A, final_state, None, None, None
    if SUPPRESS_LEVEL >= 3:
        return g, o, A, final_state, w, h, v_new