Coverage for src/flag_gems/fused/chunk_gated_delta_rule.py: 77%

107 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5from flag_gems.fused.FLA import chunk_gated_delta_rule_fwd 

6from flag_gems.fused.FLA.chunk_gated_delta_direct import ( 

7 can_use_chunk_gated_delta_rule_direct, 

8 chunk_gated_delta_rule_direct_fwd, 

9) 

10from flag_gems.utils import libentry 

11 

12 

@libentry()
@triton.jit
def _l2_normalize_last_dim_kernel(
    x,
    out,
    n_rows: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    stride_x_b: tl.constexpr,
    stride_x_t: tl.constexpr,
    stride_x_h: tl.constexpr,
    stride_x_k: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Each program handles one row of K elements. The flat program id is
    # decomposed as (batch * n_rows + time) * H + head.
    pid = tl.program_id(0)
    h_idx = pid % H
    bt_idx = pid // H
    t_idx = bt_idx % n_rows
    b_idx = bt_idx // n_rows

    # BLOCK_K lanes are launched; lanes at or beyond K are masked off.
    cols = tl.arange(0, BLOCK_K)
    in_bounds = cols < K

    # The input row is addressed through the caller-provided strides.
    row_ptr = x + b_idx * stride_x_b + t_idx * stride_x_t + h_idx * stride_x_h
    vals = tl.load(row_ptr + cols * stride_x_k, mask=in_bounds, other=0.0)
    vals = vals.to(tl.float32)

    # The norm is clamped from below at 1e-6 before taking the reciprocal,
    # so an all-zero row does not divide by zero.
    sq_sum = tl.sum(vals * vals, axis=0)
    norm = tl.maximum(tl.sqrt(sq_sum), 1e-6)
    inv_norm = 1.0 / norm

    # The output is written at flat offset pid * K + col, i.e. `out` is
    # treated as densely packed with a row pitch of K.
    tl.store(out + pid * K + cols, vals * inv_norm, mask=in_bounds)

39 

40 

41def _as_seq_first( 

42 x: torch.Tensor, 

43 *, 

44 name: str, 

45 head_first: bool, 

46 expected_ndim: int, 

47) -> torch.Tensor: 

48 if not isinstance(x, torch.Tensor): 

49 raise TypeError(f"{name} must be a torch.Tensor") 

50 if x.ndim != expected_ndim: 

51 raise ValueError(f"{name} must be {expected_ndim}D, got shape {tuple(x.shape)}") 

52 if head_first: 

53 return x.transpose(1, 2) 

54 return x 

55 

56 

57def _validate_inputs( 

58 q: torch.Tensor, 

59 k: torch.Tensor, 

60 v: torch.Tensor, 

61 beta: torch.Tensor, 

62 g: torch.Tensor, 

63 initial_state: torch.Tensor | None, 

64 cu_seqlens: torch.Tensor | None, 

65) -> None: 

66 B, T, Hg, K = q.shape 

67 Bk, Tk, Hk, Kk = k.shape 

68 Bv, Tv, H, V = v.shape 

69 

70 tensors = {"k": k, "v": v, "beta": beta, "g": g} 

71 for name, tensor in tensors.items(): 

72 if tensor.device != q.device: 

73 raise ValueError(f"{name} must be on the same device as q") 

74 if tensor.dtype != q.dtype: 

75 raise ValueError(f"{name} must have the same dtype as q") 

76 

77 if (Bk, Tk, Hk, Kk) != (B, T, Hg, K): 

78 raise ValueError( 

79 "q and k must have matching [B, T, Hq, K] shapes after layout conversion" 

80 ) 

81 if (Bv, Tv) != (B, T): 

82 raise ValueError("v must have matching B and T dimensions with q/k") 

83 if H % Hg != 0: 

84 raise ValueError("the q/k head count must divide the v head count") 

85 if beta.shape != (B, T, H): 

86 raise ValueError( 

87 f"beta must have shape {(B, T, H)} after layout conversion, got {tuple(beta.shape)}" 

88 ) 

89 if g.shape != (B, T, H): 

90 raise ValueError( 

91 f"g must have shape {(B, T, H)} after layout conversion, got {tuple(g.shape)}" 

92 ) 

93 if cu_seqlens is not None: 

94 if not isinstance(cu_seqlens, torch.Tensor): 

95 raise TypeError("cu_seqlens must be a torch.Tensor") 

96 if cu_seqlens.ndim != 1: 

97 raise ValueError("cu_seqlens must be a 1D tensor") 

98 if cu_seqlens.dtype != torch.long: 

99 raise ValueError("cu_seqlens must have dtype torch.long") 

100 if cu_seqlens.device != q.device: 

101 raise ValueError("cu_seqlens must be on the same device as q") 

102 if B != 1: 

103 raise ValueError("cu_seqlens packed varlen inputs must use B=1") 

104 

105 if initial_state is not None: 

106 if initial_state.device != q.device: 

107 raise ValueError("initial_state must be on the same device as q") 

108 if initial_state.dtype != q.dtype: 

109 raise ValueError("initial_state must have the same dtype as q") 

110 expected_n = B if cu_seqlens is None else cu_seqlens.numel() - 1 

111 expected_shape = (expected_n, H, K, V) 

112 if initial_state.shape != expected_shape: 

113 raise ValueError( 

114 f"initial_state must have shape {expected_shape}, got {tuple(initial_state.shape)}" 

115 ) 

116 

117 

118def _direct_contiguous(x: torch.Tensor) -> torch.Tensor: 

119 return x if x.is_contiguous() else x.contiguous() 

120 

121 

def _l2_normalize_last_dim(x: torch.Tensor) -> torch.Tensor:
    """L2-normalize a 4D tensor along its last dimension with the Triton kernel.

    Args:
        x: 4D tensor, unpacked as ``[B, T, H, K]``. It may be non-contiguous;
            its strides are forwarded to the kernel.

    Returns:
        A newly allocated contiguous tensor shaped like ``x``, filled by
        ``_l2_normalize_last_dim_kernel``.
    """
    batch, seq_len, heads, dim = x.shape
    stride_b, stride_t, stride_h, stride_k = x.stride()

    # The output is always allocated contiguous, whatever the layout of x.
    result = torch.empty_like(x, memory_format=torch.contiguous_format)

    # One program per (batch, time, head) row.
    grid = (batch * seq_len * heads,)
    _l2_normalize_last_dim_kernel[grid](
        x=x,
        out=result,
        n_rows=seq_len,
        H=heads,
        K=dim,
        stride_x_b=stride_b,
        stride_x_t=stride_t,
        stride_x_h=stride_h,
        stride_x_k=stride_k,
        BLOCK_K=triton.next_power_of_2(dim),
    )
    return result

139 

140 

def chunk_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    g: torch.Tensor,
    BT: int = 64,
    initial_state: torch.Tensor | None = None,
    output_final_state: bool = False,
    cu_seqlens: torch.Tensor | None = None,
    head_first: bool = True,
    scale: float | None = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Run the chunk gated delta rule forward pass.

    Accepted layouts:
    - ``head_first=True``: q/k/v are ``[B, H, T, D]``, beta/g are ``[B, H, T]``.
    - ``head_first=False``: q/k/v are ``[B, T, H, D]``, beta/g are ``[B, T, H]``.

    q/k may carry fewer heads than v, provided the q/k head count divides the
    v head count.

    Args:
        q: Query tensor.
        k: Key tensor, same shape as ``q``.
        v: Value tensor.
        beta: Per-position, per-head tensor matching v's batch/time/head dims.
        g: Per-position, per-head tensor matching v's batch/time/head dims.
        BT: Chunk size; only 64 is accepted.
        initial_state: Optional starting state forwarded to the backend.
        output_final_state: Forwarded to the backend unchanged.
        cu_seqlens: Optional cumulative sequence offsets for packed input.
        head_first: Selects the input layout and the layout of the output.
        scale: Forwarded to the backend as a float; defaults to ``K ** -0.5``
            where ``K`` is the last dimension of ``k``.
        use_qk_l2norm_in_kernel: Passed through on the direct path; on the
            general path q and k are L2-normalized here before the call.

    Returns:
        ``(o, final_state)`` as produced by the selected backend, with ``o``
        transposed back to head-first layout when ``head_first`` is true.

    Raises:
        ValueError: If ``BT`` is not 64, or if input validation fails.
        TypeError: If an operand fails the tensor type checks.
    """
    if BT != 64:
        raise ValueError("chunk_gated_delta_rule currently supports only BT=64")

    # Bring every operand into seq-first layout; q, k, v are 4D, beta, g are 3D.
    q_seq, k_seq, v_seq = [
        _as_seq_first(tensor, name=label, head_first=head_first, expected_ndim=4)
        for label, tensor in (("q", q), ("k", k), ("v", v))
    ]
    beta_seq, g_seq = [
        _as_seq_first(tensor, name=label, head_first=head_first, expected_ndim=3)
        for label, tensor in (("beta", beta), ("g", g))
    ]

    _validate_inputs(q_seq, k_seq, v_seq, beta_seq, g_seq, initial_state, cu_seqlens)

    if scale is None:
        scale = k_seq.shape[-1] ** -0.5

    B, T, Hg, K = q_seq.shape
    H, V = v_seq.shape[2], v_seq.shape[3]

    # The direct kernel is only considered for stateless, non-packed inputs
    # whose T, K and V are all at most 128.
    stateless = initial_state is None and cu_seqlens is None
    direct_eligible = stateless and max(T, K, V) <= 128 and H % Hg == 0

    took_direct_path = False
    if direct_eligible:
        direct_inputs = {
            "q": _direct_contiguous(q_seq),
            "k": _direct_contiguous(k_seq),
            "v": _direct_contiguous(v_seq),
            "g": _direct_contiguous(g_seq),
            "beta": _direct_contiguous(beta_seq),
        }
        if can_use_chunk_gated_delta_rule_direct(
            **direct_inputs,
            initial_state=None,
            cu_seqlens=None,
        ):
            o, final_state = chunk_gated_delta_rule_direct_fwd(
                **direct_inputs,
                scale=float(scale),
                initial_state=None,
                output_final_state=output_final_state,
                use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
            )
            took_direct_path = True

    if not took_direct_path:
        # The general backend is not given the l2norm flag, so q and k are
        # normalized up front when it is requested.
        if use_qk_l2norm_in_kernel:
            q_seq = _l2_normalize_last_dim(q_seq)
            k_seq = _l2_normalize_last_dim(k_seq)

        _, o, _, final_state, _, _, _ = chunk_gated_delta_rule_fwd(
            q=q_seq,
            k=k_seq,
            v=v_seq,
            g=g_seq,
            beta=beta_seq,
            scale=float(scale),
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
        )

    # Both backends produce seq-first output; restore the caller's layout.
    if head_first:
        o = o.transpose(1, 2)
    return o, final_state