Coverage for src/flag_gems/fused/FLA/chunk_fused_tail_vblock.py: 45%

65 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

# V-blocked fused tail for the official K = V = BT = 64 chunk_gated_delta_rule path:
# a single kernel computes both the output o and the final state.

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import libentry 

8 

9_FUSED_TAIL_BV = 16 

10 

11 

12def can_use_fused_tail_vblock( 

13 q: torch.Tensor, 

14 k: torch.Tensor, 

15 w: torch.Tensor, 

16 u: torch.Tensor, 

17 g: torch.Tensor, 

18 initial_state: torch.Tensor | None, 

19 output_final_state: bool, 

20 *, 

21 chunk_size: int, 

22 cu_seqlens: torch.Tensor | None, 

23) -> bool: 

24 if cu_seqlens is not None or initial_state is None or not output_final_state: 

25 return False 

26 if q.ndim != 4 or k.ndim != 4 or w.ndim != 4 or u.ndim != 4 or g.ndim != 3: 

27 return False 

28 

29 B, T, Hg, K = q.shape 

30 H, V = u.shape[2], u.shape[3] 

31 if k.shape != (B, T, Hg, K): 

32 return False 

33 if w.shape != (B, T, H, K) or g.shape != (B, T, H): 

34 return False 

35 if initial_state.shape != (B, H, K, V): 

36 return False 

37 if chunk_size != 64 or T % 64 != 0 or (K, V) != (64, 64) or H % Hg != 0: 

38 return False 

39 if q.dtype not in (torch.float16, torch.bfloat16): 

40 return False 

41 if not all(x.dtype == q.dtype for x in (k, w, u, g, initial_state)): 

42 return False 

43 return all(x.is_contiguous() for x in (q, k, w, u, g, initial_state)) 

44 

45 

# Fused tail of chunk_gated_delta_rule for one (batch, head, V-block) triple.
#
# Launch grid: program_id(0) selects a V-block of width BV, program_id(1) a
# flattened (batch, head) pair, i.e. program_id(1) == b * H + h.
# Each program keeps a (K, BV) float32 slice of the state in `h_acc`, walks
# the sequence chunk by chunk (BT time steps per chunk), writes that chunk's
# output to `o`, advances the state, and finally writes the state to `ht`.
#
# Layouts assumed by the flat offsets below (dense, innermost dim stride 1):
#   q, k   : (B, T, Hg, K)
#   w      : (B, T, H, K)
#   u, o   : (B, T, H, V)
#   g      : (B, T, H)
#   h0, ht : (B, H, K, V)
# Only the V axis is masked (v_mask); time-axis loads are unmasked, so T must
# be a multiple of BT.
#
# NOTE: T and scale are tl.constexpr, so each distinct sequence length or
# scale value produces a separate kernel specialization (compilation).
# NOTE(review): offsets are built from int32 program ids and constexpr sizes;
# for very large B * T * H * K they may overflow int32 — confirm expected
# sizes or promote the base offsets to int64.
@libentry()
@triton.jit
def _chunk_gated_delta_rule_fused_tail_vblock_kernel(
    q,
    k,
    w,
    u,
    g,
    h0,
    o,
    ht,
    scale: tl.constexpr,
    T: tl.constexpr,
    H: tl.constexpr,
    Hg: tl.constexpr,
    BT: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BV: tl.constexpr,
):
    i_v = tl.program_id(0)
    i_bh = tl.program_id(1)
    i_b = i_bh // H
    i_h = i_bh % H
    # Grouped heads: each run of H // Hg consecutive heads shares one q/k head.
    i_hg = i_h // (H // Hg)

    offs_t = tl.arange(0, BT)
    offs_k = tl.arange(0, K)
    offs_v = i_v * BV + tl.arange(0, BV)
    # Masks the tail V-block when V is not a multiple of BV.
    v_mask = offs_v < V

    # Load this program's (K, BV) slice of the initial state; accumulate in fp32.
    h0_base = ((i_b * H + i_h) * K) * V
    h_acc = tl.load(
        h0 + h0_base + offs_k[:, None] * V + offs_v[None, :],
        mask=v_mask[None, :],
        other=0.0,
    ).to(tl.float32)

    # Sequential scan over chunks of BT time steps; the state is carried in h_acc.
    for i_t in range(0, tl.cdiv(T, BT)):
        t = i_t * BT + offs_t

        q_block = tl.load(
            q + (((i_b * T + t[:, None]) * Hg + i_hg) * K + offs_k[None, :])
        )
        # k is loaded transposed, as a (K, BT) tile.
        k_t_block = tl.load(
            k + (((i_b * T + t[None, :]) * Hg + i_hg) * K + offs_k[:, None])
        )
        w_block = tl.load(
            w + (((i_b * T + t[:, None]) * H + i_h) * K + offs_k[None, :])
        )
        u_block = tl.load(
            u + (((i_b * T + t[:, None]) * H + i_h) * V + offs_v[None, :]),
            mask=v_mask[None, :],
            other=0.0,
        )
        g_vec = tl.load(g + (i_b * T + t) * H + i_h).to(tl.float32)

        # residual = u - w @ h, using the state from before this chunk.
        # h_acc is cast down to the input dtype for tl.dot.
        residual = u_block.to(tl.float32) - tl.dot(w_block, h_acc.to(w_block.dtype))

        # Inter-chunk term: contribution of the carried state to the output.
        q_h = tl.dot(q_block, h_acc.to(q_block.dtype))
        # Intra-chunk term: causal (i >= j) q k^T scores weighted by
        # exp(g_i - g_j). tl.where sets the i < j entries to 0, discarding
        # whatever exp() yields there.
        qk = tl.dot(q_block, k_t_block).to(tl.float32)
        causal = offs_t[:, None] >= offs_t[None, :]
        qk = tl.where(causal, qk * tl.exp(g_vec[:, None] - g_vec[None, :]), 0.0)
        # o = scale * (q @ h * exp(g) + A @ residual)
        out = (
            q_h * tl.exp(g_vec)[:, None]
            + tl.dot(qk.to(u_block.dtype), residual.to(u_block.dtype))
        ) * scale
        tl.store(
            o + (((i_b * T + t[:, None]) * H + i_h) * V + offs_v[None, :]),
            out,
            mask=v_mask[None, :],
        )

        # g at the last time step of this chunk (in bounds because T % BT == 0).
        g_last = tl.load(g + (i_b * T + ((i_t + 1) * BT - 1)) * H + i_h).to(tl.float32)
        # h <- h * exp(g_last) + k^T @ (residual * exp(g_last - g))
        residual_for_state = residual * tl.exp(g_last - g_vec)[:, None]
        h_acc = h_acc * tl.exp(g_last) + tl.dot(
            k_t_block, residual_for_state.to(k_t_block.dtype)
        )

    # Write this program's final (K, BV) state slice to ht.
    ht_base = ((i_b * H + i_h) * K) * V
    tl.store(
        ht + ht_base + offs_k[:, None] * V + offs_v[None, :],
        h_acc,
        mask=v_mask[None, :],
    )
63 V: tl.constexpr, 

64 BV: tl.constexpr, 

65): 

66 i_v = tl.program_id(0) 

67 i_bh = tl.program_id(1) 

68 i_b = i_bh // H 

69 i_h = i_bh % H 

70 i_hg = i_h // (H // Hg) 

71 

72 offs_t = tl.arange(0, BT) 

73 offs_k = tl.arange(0, K) 

74 offs_v = i_v * BV + tl.arange(0, BV) 

75 v_mask = offs_v < V 

76 

77 h0_base = ((i_b * H + i_h) * K) * V 

78 h_acc = tl.load( 

79 h0 + h0_base + offs_k[:, None] * V + offs_v[None, :], 

80 mask=v_mask[None, :], 

81 other=0.0, 

82 ).to(tl.float32) 

83 

84 for i_t in range(0, tl.cdiv(T, BT)): 

85 t = i_t * BT + offs_t 

86 

87 q_block = tl.load( 

88 q + (((i_b * T + t[:, None]) * Hg + i_hg) * K + offs_k[None, :]) 

89 ) 

90 k_t_block = tl.load( 

91 k + (((i_b * T + t[None, :]) * Hg + i_hg) * K + offs_k[:, None]) 

92 ) 

93 w_block = tl.load( 

94 w + (((i_b * T + t[:, None]) * H + i_h) * K + offs_k[None, :]) 

95 ) 

96 u_block = tl.load( 

97 u + (((i_b * T + t[:, None]) * H + i_h) * V + offs_v[None, :]), 

98 mask=v_mask[None, :], 

99 other=0.0, 

100 ) 

101 g_vec = tl.load(g + (i_b * T + t) * H + i_h).to(tl.float32) 

102 

103 residual = u_block.to(tl.float32) - tl.dot(w_block, h_acc.to(w_block.dtype)) 

104 

105 q_h = tl.dot(q_block, h_acc.to(q_block.dtype)) 

106 qk = tl.dot(q_block, k_t_block).to(tl.float32) 

107 causal = offs_t[:, None] >= offs_t[None, :] 

108 qk = tl.where(causal, qk * tl.exp(g_vec[:, None] - g_vec[None, :]), 0.0) 

109 out = ( 

110 q_h * tl.exp(g_vec)[:, None] 

111 + tl.dot(qk.to(u_block.dtype), residual.to(u_block.dtype)) 

112 ) * scale 

113 tl.store( 

114 o + (((i_b * T + t[:, None]) * H + i_h) * V + offs_v[None, :]), 

115 out, 

116 mask=v_mask[None, :], 

117 ) 

118 

119 g_last = tl.load(g + (i_b * T + ((i_t + 1) * BT - 1)) * H + i_h).to(tl.float32) 

120 residual_for_state = residual * tl.exp(g_last - g_vec)[:, None] 

121 h_acc = h_acc * tl.exp(g_last) + tl.dot( 

122 k_t_block, residual_for_state.to(k_t_block.dtype) 

123 ) 

124 

125 ht_base = ((i_b * H + i_h) * K) * V 

126 tl.store( 

127 ht + ht_base + offs_k[:, None] * V + offs_v[None, :], 

128 h_acc, 

129 mask=v_mask[None, :], 

130 ) 

131 

132 

133def chunk_gated_delta_rule_fused_tail_vblock( 

134 q: torch.Tensor, 

135 k: torch.Tensor, 

136 w: torch.Tensor, 

137 u: torch.Tensor, 

138 g: torch.Tensor, 

139 initial_state: torch.Tensor, 

140 *, 

141 scale: float, 

142) -> tuple[torch.Tensor, torch.Tensor]: 

143 B, T, Hg, K = q.shape 

144 H, V = u.shape[2], u.shape[3] 

145 

146 o = torch.empty_like(u) 

147 final_state = torch.empty(B, H, K, V, device=q.device, dtype=torch.float32) 

148 _chunk_gated_delta_rule_fused_tail_vblock_kernel[ 

149 (triton.cdiv(V, _FUSED_TAIL_BV), B * H) 

150 ]( 

151 q, 

152 k, 

153 w, 

154 u, 

155 g, 

156 initial_state, 

157 o, 

158 final_state, 

159 scale=scale, 

160 T=T, 

161 H=H, 

162 Hg=Hg, 

163 BT=64, 

164 K=64, 

165 V=64, 

166 BV=_FUSED_TAIL_BV, 

167 num_warps=4, 

168 num_stages=3, 

169 ) 

170 return o, final_state