Coverage for src/flag_gems/runtime/backend/_sunrise/fused/flash_mla.py: 0%

94 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import device, torch_device_fn 

9from flag_gems.utils import triton_lang_extension as ext 

10 

# Read the vendor name before the next statement: `device` is about to be
# rebound from the imported runtime object to that object's `name` attribute,
# after which the original object is no longer reachable through this name.
vendor_name = device.vendor_name
# From here on, `device` in this module is the value of `device.name`
# (presumably a device-type string such as "cuda" - TODO confirm per backend).
device = device.name
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

14 

15 

16# @triton.autotune( 

17# configs=[ 

18# triton.Config({"BLOCK_H": h, "BLOCK_N": n}, num_warps=w, num_stages=s) 

19# for h in [32, 64, 128] 

20# for n in [32, 64, 128] 

21# for w in [4, 8] 

22# for s in [1, 2] 

23# ], 

24# key=["head_num"] 

25# ) 

@triton.heuristics(
    values={
        # EVEN_H is true when head_num is an exact multiple of BLOCK_H; the
        # kernel then takes the unmasked load/store path along the head axis.
        "EVEN_H": lambda META: META["head_num"] % META["BLOCK_H"] == 0,
    }
)
@triton.jit
def flash_mla_attn_kernel(
    Q_ptr,
    Kv_cache,
    Req_to_tokens,
    B_seq_len,
    O,
    sm_scale,
    head_num,
    stride_q_bs,
    stride_q_h,
    stride_kv_bs,
    stride_req_to_tokens_bs,
    stride_o_b,
    stride_o_h,
    stride_o_s,
    BLOCK_H: tl.constexpr,
    BLOCK_N: tl.constexpr,
    EVEN_H: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    HEAD_DIM_V: tl.constexpr,
    HEAD_DIM: tl.constexpr,
):
    # Attention over a paged KV cache for one (head tile, batch entry) pair.
    #
    # Each program handles BLOCK_H consecutive heads of one batch entry and
    # walks that entry's cached tokens in tiles of BLOCK_N, keeping a running
    # maximum / running sum so the softmax is normalised once at the end.
    #
    # Pointer arguments:
    #   Q_ptr          queries, addressed as batch * stride_q_bs
    #                  + head * stride_q_h + channel
    #   Kv_cache       cache rows, addressed as token_slot * stride_kv_bs
    #                  + channel; channels [0, HEAD_DIM_V) are used both as
    #                  keys and as values, channels [HEAD_DIM_V, HEAD_DIM) are
    #                  used only in the score computation
    #   Req_to_tokens  per-batch table of page numbers (row stride
    #                  stride_req_to_tokens_bs)
    #   B_seq_len      number of cached tokens per batch entry
    #   O              output, addressed as batch * stride_o_b
    #                  + head * stride_o_h + channel
    #
    # The channel offset is always added without a stride, i.e. the innermost
    # dimension of Q, Kv_cache and O is treated as unit-stride. `stride_o_s`
    # is accepted but never read.
    # NOTE(review): tl.arange needs power-of-two extents, so HEAD_DIM_V and
    # HEAD_DIM - HEAD_DIM_V presumably have to be powers of two - confirm.
    cur_head_id = ext.program_id(0)
    cur_batch_id = ext.program_id(1)
    # Move the page-table pointer to this batch entry's row.
    Req_to_tokens += stride_req_to_tokens_bs * cur_batch_id

    # Head indices covered by this program.
    cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H)

    # Offsets of the first HEAD_DIM_V query channels, shape [BLOCK_H, HEAD_DIM_V].
    offs_d_ckv = tl.arange(0, HEAD_DIM_V)
    offs_q_nope = (
        cur_batch_id * stride_q_bs
        + cur_head[:, None] * stride_q_h
        + offs_d_ckv[None, :]
    )

    # Offsets of the remaining query channels [HEAD_DIM_V, HEAD_DIM).
    offs_d_kpe = tl.arange(HEAD_DIM_V, HEAD_DIM)
    offs_q_pe = (
        cur_batch_id * stride_q_bs
        + cur_head[:, None] * stride_q_h
        + offs_d_kpe[None, :]
    )

    if EVEN_H:
        q_nope = tl.load(Q_ptr + offs_q_nope)
        q_pe = tl.load(Q_ptr + offs_q_pe)
    else:
        # Last head tile may run past head_num; mask the out-of-range rows.
        # NOTE(review): no `other=` is given, so the masked rows of q hold
        # unspecified values; they are only ever dropped by the masked store.
        mask_head = cur_head < head_num
        q_nope = tl.load(Q_ptr + offs_q_nope, mask=mask_head[:, None])
        q_pe = tl.load(Q_ptr + offs_q_pe, mask=mask_head[:, None])

    # Running softmax state per head: maximum score, sum of exponentials, and
    # the un-normalised weighted sum of values (all kept in float32).
    e_max = tl.full([BLOCK_H], value=float("-inf"), dtype=tl.float32)
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, HEAD_DIM_V], dtype=tl.float32)

    cur_batch_seq_len = tl.load(B_seq_len + cur_batch_id)
    # Full tiles are handled unmasked in the loop; the partial tail (if any)
    # is handled once, with masking, after the loop.
    loop_time = cur_batch_seq_len // BLOCK_N
    remainder = cur_batch_seq_len % BLOCK_N
    offs_n = tl.arange(0, BLOCK_N)
    for i in range(0, loop_time):
        # Translate logical token positions to cache slots:
        # page number from the table, then position within the page.
        kv_page_number = tl.load(Req_to_tokens + offs_n // PAGE_SIZE)
        kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
        # [BLOCK_N, HEAD_DIM_V] tile, reused as values and (transposed) as keys.
        offs_v_c = kv_loc[:, None] * stride_kv_bs + offs_d_ckv[None, :]
        v_c = tl.load(Kv_cache + offs_v_c)
        k_c = tl.trans(v_c)

        qk = tl.dot(q_nope, k_c)  # qk_nope

        # Remaining key channels, loaded directly in transposed layout
        # [HEAD_DIM - HEAD_DIM_V, BLOCK_N].
        offs_k_pe = kv_loc[None, :] * stride_kv_bs + offs_d_kpe[:, None]
        k_pe = tl.load(Kv_cache + offs_k_pe)

        qk = tl.dot(q_pe, k_pe, acc=qk)  # qk_rope
        qk *= sm_scale

        # Running-max softmax update: rescale what has been accumulated so far
        # to the new maximum before adding this tile's contribution.
        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        acc *= re_scale[:, None]
        # Probabilities are cast to the cache dtype for the second matmul.
        acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc)

        e_sum = e_sum * re_scale + tl.sum(p, 1)
        e_max = n_e_max
        offs_n += BLOCK_N

    if remainder:
        # Tail tile: positions at or beyond the sequence length are masked.
        mask_kvsplit = offs_n < cur_batch_seq_len
        kv_page_number = tl.load(
            Req_to_tokens + offs_n // PAGE_SIZE,
            mask=mask_kvsplit,
            other=0,
        )
        kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
        offs_v_c = kv_loc[:, None] * stride_kv_bs + offs_d_ckv[None, :]
        v_c = tl.load(Kv_cache + offs_v_c, mask=mask_kvsplit[:, None], other=0.0)
        k_c = tl.trans(v_c)

        qk = tl.dot(q_nope, k_c)  # qk_nope

        offs_k_pe = kv_loc[None, :] * stride_kv_bs + offs_d_kpe[:, None]
        k_pe = tl.load(Kv_cache + offs_k_pe, mask=mask_kvsplit[None, :], other=0.0)

        qk = tl.dot(q_pe, k_pe, acc=qk)  # qk_rope
        qk *= sm_scale

        # Force masked positions to -inf so they contribute exp(-inf) = 0.
        qk = tl.where(mask_kvsplit[None, :], qk, float("-inf"))

        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        acc *= re_scale[:, None]
        acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc)

        # e_max is not updated here because no further tile follows.
        e_sum = e_sum * re_scale + tl.sum(p, 1)

    # Normalise and write out the [BLOCK_H, HEAD_DIM_V] result tile.
    # NOTE(review): when cur_batch_seq_len is 0 neither branch above runs,
    # e_sum stays 0 and the division below is 0 / 0 - confirm callers never
    # pass an empty sequence.
    offs_o = (
        cur_batch_id * stride_o_b + cur_head[:, None] * stride_o_h + offs_d_ckv[None, :]
    )
    if EVEN_H:
        tl.store(
            O + offs_o,
            acc / e_sum[:, None],
        )
    else:
        tl.store(O + offs_o, acc / e_sum[:, None], mask=mask_head[:, None])

155 

156 

def flash_mla(
    q,
    block_table,
    blocked_k,
    max_seqlen_pad,
    block_size,
    b,
    s_q,
    cache_seqlens,
    h_q,
    h_kv,
    d,
    dv,
    causal,
):
    """Run the Triton MLA attention kernel over a paged KV cache.

    Args:
        q: Query tensor; must be 4-D, its shape is unpacked as
            ``(batch, s_q, head_num, d)``.
        block_table: Per-batch page table handed to the kernel as
            ``Req_to_tokens``; its ``stride(0)`` is used as the row stride.
        blocked_k: KV cache; flattened to ``[-1, d]`` before launch.
        max_seqlen_pad: Unused by this implementation.
        block_size: Page size of the cache (passed as ``PAGE_SIZE``).
        b: Batch size used for the output shape.
        s_q: Ignored on input; overwritten by ``q.shape[1]``.
        cache_seqlens: Number of cached tokens per batch entry.
        h_q: Number of query heads used for the output shape.
        h_kv: Unused by this implementation.
        d: Full head dimension; checked against ``dv`` and then overwritten
            by ``q.shape[3]``.
        dv: Value head dimension (passed as ``HEAD_DIM_V``).
        causal: Must be truthy.

    Returns:
        Tensor of shape ``[b, s_q, h_q, dv]`` with ``q``'s dtype, allocated on
        ``q``'s device.

    Raises:
        AssertionError: If ``causal`` is falsy or ``d <= dv``.

    NOTE(review): the launch grid has ``q.shape[0]`` programs along the batch
    axis while ``q`` is flattened to ``q.shape[0] * s_q`` rows, so only the
    ``s_q == 1`` case has every output row written by the kernel.
    """
    logger.debug("GEMS FLASH MLA")
    assert causal, "causal False not supported"
    assert d > dv, "mla with rope dim should be larger than no rope dim"

    batch_size, s_q, head_num, d = list(q.shape)
    q = q.view([-1, head_num, d]).contiguous()
    blocked_k = blocked_k.view([-1, d]).contiguous()
    block_table = block_table.contiguous()
    cache_seqlens = cache_seqlens.contiguous()

    sm_scale = 1 / math.sqrt(d)

    # Allocate on the device that actually holds q. A bare device-type name
    # resolves to the *current* device, which differs from q's device whenever
    # q lives on a non-default device index.
    target_device = q.device
    o = torch.empty([b * s_q, h_q, dv], dtype=q.dtype, device=target_device)

    BLOCK_H = 16
    num_stages = 1

    BLOCK_N = 16
    grid = (
        triton.cdiv(head_num, BLOCK_H),
        batch_size,
    )
    # Launch under q's device for the same reason as the allocation above.
    with torch_device_fn.device(target_device):
        flash_mla_attn_kernel[grid](
            q,
            blocked_k,
            block_table,
            cache_seqlens,
            o,
            sm_scale,
            head_num,
            # stride
            q.stride(0),
            q.stride(1),
            blocked_k.stride(-2),
            block_table.stride(0),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            BLOCK_H=BLOCK_H,
            BLOCK_N=BLOCK_N,
            PAGE_SIZE=block_size,
            HEAD_DIM_V=dv,
            HEAD_DIM=d,
            num_warps=16,
            num_stages=num_stages,
        )

    return o.view([b, s_q, h_q, dv])