Coverage for src/flag_gems/fused/FLA/chunk_gated_delta_direct.py: 45%

75 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1# This file contains a guarded direct forward path for small chunk_gated_delta_rule 

2# shapes. It follows the recurrent definition directly and falls back to the 

3# chunk decomposition for unsupported cases. 

4 

5from __future__ import annotations 

6 

7import torch 

8import triton 

9import triton.language as tl 

10 

11from flag_gems.fused.FLA.triton_ops_helper import exp 

12from flag_gems.utils import libentry 

13 

# Eligibility limits for the direct path. can_use_chunk_gated_delta_rule_direct
# only accepts inputs whose T, K and V (taken from q.shape / v.shape) are
# positive and do not exceed these bounds.
_DIRECT_MAX_T = 128
_DIRECT_MAX_K = 128
_DIRECT_MAX_V = 128
# Upper bound on the V-tile width (BV) chosen by
# chunk_gated_delta_rule_direct_fwd when it does not use the one-warp config.
_DIRECT_BV = 32

18 

19 

@libentry()
@triton.heuristics(
    {
        "USE_INITIAL_STATE": lambda args: args["initial_state"] is not None,
        "STORE_FINAL_STATE": lambda args: args["final_state"] is not None,
    }
)
@triton.jit
def _chunk_gated_delta_rule_direct_fwd_kernel(
    q,
    k,
    v,
    g,
    beta,
    o,
    initial_state,
    final_state,
    scale,
    T: tl.constexpr,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
):
    """Evaluate the gated delta-rule recurrence one time step at a time.

    Grid: program axis 0 selects a BV-wide tile of the V dimension; program
    axis 1 is the flattened (batch, head) index with H heads per batch entry.

    Layouts implied by the pointer arithmetic below (densely packed):
        q, k:                       [B, T, Hg, K]
        v, o:                       [B, T, H, V]
        g, beta:                    [B, T, H]
        initial_state, final_state: [B, H, K, V]

    For every step t the [BK, BV] float32 state tile h is updated as
        h   <- h * exp(g_t)
        u   <- (v_t - sum_k(h * k_t)) * beta_t
        h   <- h + outer(k_t, u)
        o_t <- sum_k(h * (q_t * scale))

    USE_INITIAL_STATE and STORE_FINAL_STATE are filled in by the heuristics
    above from whether the corresponding pointer argument is None.
    """
    i_v = tl.program_id(0)
    i_bh = tl.program_id(1)
    # Split the flattened program index into batch and head.
    i_b = i_bh // H
    i_h = i_bh % H
    # Each run of H // Hg consecutive heads reads the same q/k head.
    i_hg = i_h // (H // Hg)

    o_k = tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)
    # BK is allowed to exceed K and the last V tile may overrun V, so every
    # load/store on these axes is masked.
    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    # Running state tile, kept in float32 regardless of the input dtype.
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = (
            initial_state + ((i_b * H + i_h) * K * V) + o_k[:, None] * V + o_v[None, :]
        )
        b_h += tl.load(p_h0, mask=mask_h, other=0.0).to(tl.float32)

    # Base pointers for time step 0 of this (batch, head); the loop advances
    # them by one time step's stride (Hg * K, H * V or H elements).
    q_base = q + ((i_b * T * Hg + i_hg) * K)
    k_base = k + ((i_b * T * Hg + i_hg) * K)
    v_base = v + ((i_b * T * H + i_h) * V)
    o_base = o + ((i_b * T * H + i_h) * V)
    g_base = g + i_b * T * H + i_h
    beta_base = beta + i_b * T * H + i_h
    for i_t in range(0, T):
        b_q = tl.load(q_base + i_t * Hg * K + o_k, mask=mask_k, other=0.0).to(
            tl.float32
        )
        b_k = tl.load(k_base + i_t * Hg * K + o_k, mask=mask_k, other=0.0).to(
            tl.float32
        )
        if USE_QK_L2NORM_IN_KERNEL:
            # L2-normalise q and k; the norm is clamped from below at 1e-6 so
            # an all-zero vector does not divide by zero.
            b_q = b_q / tl.maximum(tl.sqrt(tl.sum(b_q * b_q)), 1e-6)
            b_k = b_k / tl.maximum(tl.sqrt(tl.sum(b_k * b_k)), 1e-6)
        b_v = tl.load(v_base + i_t * H * V + o_v, mask=mask_v, other=0.0).to(tl.float32)
        # g and beta are one scalar per (batch, step, head).
        b_g = tl.load(g_base + i_t * H).to(tl.float32)
        b_beta = tl.load(beta_base + i_t * H).to(tl.float32)

        # Scale the state by exp(g) (project helper imported from
        # triton_ops_helper; presumably the elementwise exponential, i.e. g is
        # a log-space gate - confirm against the helper).
        b_h *= exp(b_g)
        # Subtract the state's current read-out for k from v, then scale the
        # residual by beta.
        b_v = (b_v - tl.sum(b_h * b_k[:, None], axis=0)) * b_beta
        # Rank-1 update of the state with the residual.
        b_h += b_k[:, None] * b_v[None, :]
        # The output reads the already-updated state with the scaled query.
        b_o = tl.sum(b_h * (b_q * scale)[:, None], axis=0)
        tl.store(
            o_base + i_t * H * V + o_v,
            b_o.to(o.dtype.element_ty),
            mask=mask_v,
        )

    if STORE_FINAL_STATE:
        # Write the state after the last step, cast to final_state's dtype.
        p_ht = final_state + ((i_b * H + i_h) * K * V) + o_k[:, None] * V + o_v[None, :]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)

101 

102 

def can_use_chunk_gated_delta_rule_direct(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    initial_state: torch.Tensor | None,
    cu_seqlens: torch.LongTensor | None,
) -> bool:
    """Return True when the inputs are eligible for the direct forward path.

    The direct kernel addresses its operands with hard-coded dense strides,
    so this guard must reject every input whose layout the kernel would
    misread. It never raises for malformed shapes; it returns False so the
    caller can fall back to the chunked implementation.

    Args:
        q: Query tensor, expected as a contiguous ``[B, T, Hg, K]`` tensor.
        k: Key tensor; must have exactly the same shape as ``q``.
        v: Value tensor, expected as a contiguous ``[B, T, H, V]`` tensor.
        g: Per-step gate, expected as a contiguous ``[B, T, H]`` tensor.
        beta: Per-step beta, expected as a contiguous ``[B, T, H]`` tensor.
        initial_state: Must be ``None``; a provided state disables the
            direct path.
        cu_seqlens: Must be ``None``; variable-length batches are not
            supported by the direct path.

    Returns:
        ``True`` only if all layout, shape, size-limit and dtype checks pass.
    """
    if cu_seqlens is not None or initial_state is not None:
        return False
    if not all(t.is_contiguous() for t in (q, k, v, g, beta)):
        return False
    # Unpacking a non-4-D shape would raise ValueError; an unsupported rank
    # is just another reason to fall back.
    if q.dim() != 4 or v.dim() != 4:
        return False
    B, T, Hg, K = q.shape
    Bv, Tv, H, V = v.shape
    if B != Bv or T != Tv:
        return False
    # The kernel walks k with q's Hg/K strides and reads g/beta as one scalar
    # per (batch, step, head); any other shape would be read incorrectly.
    if k.shape != q.shape:
        return False
    if g.shape != (B, T, H) or beta.shape != (B, T, H):
        return False
    return (
        0 < T <= _DIRECT_MAX_T
        and 0 < K <= _DIRECT_MAX_K
        and 0 < V <= _DIRECT_MAX_V
        # Hg > 0 keeps the modulo below from dividing by zero.
        and Hg > 0
        and H % Hg == 0
        and q.dtype in (torch.float16, torch.bfloat16, torch.float32)
    )

131 

132 

def chunk_gated_delta_rule_direct_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor | None,
    output_final_state: bool,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Launch the direct gated delta-rule forward kernel.

    Args:
        q: Query tensor; its shape is unpacked as ``(B, T, Hg, K)``.
        k: Key tensor, passed to the kernel unchanged.
        v: Value tensor; dimensions 2 and 3 are taken as ``H`` and ``V``.
        g: Gate tensor, passed to the kernel unchanged.
        beta: Beta tensor, passed to the kernel unchanged.
        scale: Query scaling factor; converted with ``float()`` before launch.
        initial_state: Optional starting state forwarded to the kernel.
        output_final_state: When truthy, a float32 ``(B, H, K, V)`` buffer is
            allocated for the kernel to write the final state into.
        use_qk_l2norm_in_kernel: Forwarded as ``USE_QK_L2NORM_IN_KERNEL``.

    Returns:
        ``(o, final_state)`` where ``o`` has the shape and dtype of ``v`` and
        ``final_state`` is ``None`` unless ``output_final_state`` is truthy.
    """
    batch, seq_len, qk_heads, key_dim = q.shape
    num_heads = v.shape[2]
    value_dim = v.shape[3]

    # Launch configuration: very small heads (or small float32 heads) run on
    # a single warp with a narrower V tile.
    tiny_heads = key_dim <= 16 and value_dim <= 16
    small_fp32_heads = q.dtype == torch.float32 and key_dim <= 32 and value_dim <= 32
    single_warp = tiny_heads or small_fp32_heads

    block_k = triton.next_power_of_2(key_dim)
    block_v_cap = 16 if single_warp else _DIRECT_BV
    block_v = min(triton.next_power_of_2(value_dim), block_v_cap)

    if single_warp:
        warps = 1
    elif key_dim >= 128:
        warps = 4
    else:
        warps = 2

    if tiny_heads:
        stages = 1
    elif single_warp:
        stages = 2
    else:
        stages = 3

    out = torch.empty_like(v)
    last_state = None
    if output_final_state:
        last_state = torch.empty(
            batch, num_heads, key_dim, value_dim, device=v.device, dtype=torch.float32
        )

    # One program per (V tile, batch * head).
    launch_grid = lambda meta: (  # noqa: E731
        triton.cdiv(value_dim, meta["BV"]),
        batch * num_heads,
    )

    _chunk_gated_delta_rule_direct_fwd_kernel[launch_grid](
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        o=out,
        initial_state=initial_state,
        final_state=last_state,
        scale=float(scale),
        T=seq_len,
        H=num_heads,
        Hg=qk_heads,
        K=key_dim,
        V=value_dim,
        BK=block_k,
        BV=block_v,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        num_warps=warps,
        num_stages=stages,
    )
    return out, last_state