Coverage for src/flag_gems/fused/FLA/chunk.py: 67%

49 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1# This file contains code copied from the flash-linear-attention project. 

2# The original source code was licensed under the MIT license and included 

3# the following copyright notice: 

4# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang 

5# ruff: noqa: E501 

6 

7import logging 

8 

9import torch 

10 

11from flag_gems.fused.FLA.chunk_delta_h import chunk_gated_delta_rule_fwd_h 

12from flag_gems.fused.FLA.chunk_fused_tail_vblock import ( 

13 can_use_fused_tail_vblock, 

14 chunk_gated_delta_rule_fused_tail_vblock, 

15) 

16from flag_gems.fused.FLA.chunk_o import chunk_fwd_o 

17from flag_gems.fused.FLA.fused_cumsum_kkt_solve_tril import ( 

18 chunk_gated_delta_rule_fused_cumsum_kkt_solve_tril, 

19) 

20from flag_gems.fused.FLA.utils import SUPPRESS_LEVEL 

21from flag_gems.fused.FLA.wy_fast import recompute_w_u_fwd 

22 

23logger = logging.getLogger(__name__) 

24 

25 

26def _chunk_size_for_sequence(T: int, is_varlen: bool) -> int: 

27 if is_varlen: 

28 return 64 

29 return min(64, max(16, 1 << (T - 1).bit_length())) 

30 

31 

def chunk_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor | None,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
):
    """Run the chunked gated delta rule forward pass.

    The pipeline is:

    1. Make every tensor input contiguous (copying only those that are not).
    2. Choose a chunk size from ``q.shape[1]`` and whether ``cu_seqlens``
       was supplied.
    3. Run the fused cumsum / KKT / solve-tril helper to obtain the updated
       ``g`` and the matrix ``A``.
    4. Recompute ``w`` and ``u`` from ``k``, ``v``, ``beta``, ``A`` and ``g``.
    5. If ``SUPPRESS_LEVEL < 3`` and the fused tail kernel reports that it
       can handle the inputs, produce ``o`` and ``final_state`` with it;
       otherwise run ``chunk_gated_delta_rule_fwd_h`` followed by
       ``chunk_fwd_o``.

    Args:
        q: Query tensor; ``q.shape[1]`` is used as the sequence length when
            selecting the chunk size (presumably laid out as
            ``(B, T, H, K)`` -- TODO confirm against callers).
        k: Key tensor; its dtype is used as the output dtype of ``A``.
        v: Value tensor.
        g: Gate tensor passed to the fused cumsum helper.
        beta: Beta tensor consumed by the solve and by the ``w``/``u``
            recomputation.
        scale: Scaling factor forwarded to the output kernels.
        initial_state: Optional initial recurrent state, or ``None``.
        output_final_state: Whether the state kernels should produce a
            final state.
        cu_seqlens: Optional cumulative sequence lengths; when given, the
            variable-length chunk size (64) is used.

    Returns:
        A 7-tuple ``(g, o, A, final_state, w, h, v_new)``. ``g`` is the
        value returned by the fused cumsum helper, not the input gate.
        The last three entries are ``None`` unless ``SUPPRESS_LEVEL >= 3``,
        in which case the intermediates of the unfused path are returned.
    """
    logger.debug("GEMS CHUNK GATED DELTA RULE FWD")

    # Normalise memory layout. Each tensor is checked first so that inputs
    # which are already contiguous incur no extra call or copy.
    if not q.is_contiguous():
        q = q.contiguous()
    if not k.is_contiguous():
        k = k.contiguous()
    if not v.is_contiguous():
        v = v.contiguous()
    if not g.is_contiguous():
        g = g.contiguous()
    if not beta.is_contiguous():
        beta = beta.contiguous()
    if initial_state is not None and not initial_state.is_contiguous():
        initial_state = initial_state.contiguous()
    if cu_seqlens is not None and not cu_seqlens.is_contiguous():
        cu_seqlens = cu_seqlens.contiguous()

    chunk_size = _chunk_size_for_sequence(q.shape[1], cu_seqlens is not None)

    # From here on `g` is the helper's output, not the raw input gate.
    g, A = chunk_gated_delta_rule_fused_cumsum_kkt_solve_tril(
        g=g,
        k=k,
        beta=beta,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
        output_dtype=k.dtype,
    )
    w, u = recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g_cumsum=g,
        cu_seqlens=cu_seqlens,
    )

    # Fused tail path: only taken when intermediates need not be returned
    # (SUPPRESS_LEVEL < 3) and the kernel accepts this configuration.
    if SUPPRESS_LEVEL < 3 and can_use_fused_tail_vblock(
        q=q,
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        chunk_size=chunk_size,
        cu_seqlens=cu_seqlens,
    ):
        o, final_state = chunk_gated_delta_rule_fused_tail_vblock(
            q=q,
            k=k,
            w=w,
            u=u,
            g=g,
            initial_state=initial_state,
            scale=scale,
        )
        return g, o, A, final_state, None, None, None

    # Unfused path: compute the per-chunk states, then the outputs.
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        chunk_size=chunk_size,
        cu_seqlens=cu_seqlens,
    )
    o = chunk_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    if SUPPRESS_LEVEL < 3:
        return g, o, A, final_state, None, None, None
    # Unconditional final return: no code path can fall through and
    # implicitly return None instead of the 7-tuple.
    return g, o, A, final_state, w, h, v_new