Coverage for src/flag_gems/fused/DSA/sparse_mla.py: 24%

173 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils.triton_version_utils import HAS_TLE 

8 

9if HAS_TLE: 

10 import triton.experimental.tle.language as tle 

11else: 

12 tle = None 

13 

logger = logging.getLogger(__name__)

# Launch configurations explored by ``triton.autotune`` for the sparse-MLA
# forward kernels.
#
# ``num_stages`` is a compiler launch option and therefore goes in the
# ``num_stages=`` keyword of ``triton.Config``.  The first positional dict is
# reserved for kernel meta-parameters; a "num_stages" entry placed there is
# not the documented way to set the pipeline depth and ends up shadowed by (or
# clashing with) the Config's own ``num_stages`` value.
# NOTE(review): confirm the exact failure mode against the pinned Triton build.
spar_mla_fwd_configs = [
    triton.Config({}, num_stages=4, num_warps=8),
    triton.Config({}, num_stages=2, num_warps=4),
]

20 

21 

@triton.autotune(
    configs=spar_mla_fwd_configs,
    key=["K", "is_causal"],
)
@triton.jit
def triton_sparse_mla_fwd(
    q,
    kv,
    indices,
    sm_scale: tl.constexpr,
    output,
    lse,
    stride_qb,
    stride_qh,
    stride_qm,
    stride_qd,
    stride_kvb,
    stride_kvg,
    stride_kvn,
    stride_kvd,
    stride_tb,
    stride_tg,
    stride_tm,
    stride_tt,  # stride of the top-k axis of `indices`
    stride_ob,
    stride_oh,
    stride_om,
    stride_od,
    stride_lb,
    stride_lh,
    stride_lm,
    SQ: tl.constexpr,  # query sequence length
    K: tl.constexpr,  # number of selected KV entries per query (top-k)
    D: tl.constexpr,  # main (value) feature dim
    TD: tl.constexpr,  # tail feature dim, stored right after the first D features
    DP: tl.constexpr,  # block size covering D (masked with offs < D)
    TDP: tl.constexpr,  # block size covering TD (masked with offs < TD)
    G: tl.constexpr,  # query heads per KV group
    BK: tl.constexpr,  # top-k entries processed per loop iteration
    BH: tl.constexpr,  # query heads processed per program
    is_causal: tl.constexpr,
):
    """Sparse attention forward pass over per-query gathered KV rows.

    One program handles one (batch, query position, KV group, head block)
    tuple.  For each chunk of ``BK`` entries of ``indices`` it gathers the
    selected KV rows, computes ``q . k`` over the first ``D`` features plus the
    ``TD`` tail features, and folds the chunk into a running (online) softmax.
    The first ``D`` features of the gathered rows are also used as values.

    Index entries that are negative or greater than ``max_col`` are ignored;
    a chunk with no usable entry is skipped entirely.

    Outputs:
        output: softmax-weighted sum of the gathered value rows.
        lse: log-sum-exp of the scaled scores expressed in base 2
            (i.e. natural-log LSE divided by ln 2); ``-inf`` when no entry
            was usable.

    Scores, softmax statistics and the accumulator are kept in fp16, exactly
    as the dot products are requested with ``out_dtype=tl.float16``.
    """
    i_b, i_sq, i_gbh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    NH = tl.cdiv(G, BH)
    i_g, i_bh = i_gbh // NH, i_gbh % NH

    # First query head handled by this program.  Heads of group `i_g` start at
    # `i_g * G`; using `i_gbh * BH` would only be equivalent when G is a
    # multiple of BH and otherwise addresses heads of the wrong group.
    head_start = i_g * G + i_bh * BH

    q_base = q + i_b * stride_qb + i_sq * stride_qm + head_start * stride_qh
    tq_base = q_base + D * stride_qd
    kv_base = kv + i_b * stride_kvb + i_g * stride_kvg
    tkv_base = kv_base + D * stride_kvd
    t_base = indices + i_b * stride_tb + i_sq * stride_tm + i_g * stride_tg
    o_base = output + i_b * stride_ob + i_sq * stride_om + head_start * stride_oh
    l_base = lse + i_b * stride_lb + i_sq * stride_lm + head_start * stride_lh

    offs_h = tl.arange(0, BH)
    offs_d = tl.arange(0, DP)
    offs_td = tl.arange(0, TDP)
    offs_od = tl.arange(0, DP)
    offs_t = tl.arange(0, BK)
    mask_h = i_bh * BH + offs_h < G
    mask_d = offs_d < D
    mask_td = offs_td < TD
    mask_od = mask_d

    q_ptr = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
    q_msk = mask_h[:, None] & mask_d[None, :]
    q_blk = tl.load(q_ptr, mask=q_msk, other=0.0).to(tl.float16)

    tq_ptr = tq_base + offs_h[:, None] * stride_qh + offs_td[None, :] * stride_qd
    tq_msk = mask_h[:, None] & mask_td[None, :]
    tq_blk = tl.load(tq_ptr, mask=tq_msk, other=0.0).to(tl.float16)

    # Online-softmax state.  `sum_exp` starts at 1 so that a query with no
    # usable entry yields output 0 and lse -inf instead of 0 / 0; the first
    # real chunk multiplies it by alpha == 0 and thereby discards it.
    max_log = tl.full([BH], float("-inf"), dtype=tl.float16)
    sum_exp = tl.full([BH], 1.0, dtype=tl.float16)
    acc = tl.zeros([BH, DP], dtype=tl.float16)

    # Fold log2(e) into the scale so the softmax can use exp2 / log2.
    log_scale: tl.constexpr = sm_scale * 1.44269504

    # Largest KV position a query may attend to.
    # NOTE(review): the non-causal bound uses SQ because the KV length is not
    # passed to the kernel — confirm before enabling is_causal=False.
    max_col = i_sq if is_causal else SQ - 1

    NK = tl.cdiv(K, BK)
    for ck in range(NK):
        # Bound-check the element index, not the strided offset, so the mask
        # stays correct for any stride of the top-k axis.
        offs_k = BK * ck + offs_t
        t_msk = offs_k < K
        kv_ids = tl.load(t_base + offs_k * stride_tt, mask=t_msk, other=-1)
        mask_ids = (kv_ids <= max_col) & (kv_ids >= 0)

        if tl.max(mask_ids, axis=0) > 0:
            kv_ptr = (
                kv_base + offs_d[:, None] * stride_kvd + kv_ids[None, :] * stride_kvn
            )
            kv_msk = mask_d[:, None] & mask_ids[None, :]
            kv_blk = tl.load(kv_ptr, mask=kv_msk, other=0.0).to(tl.float16)  # [DP, BK]

            tkv_ptr = (
                tkv_base + offs_td[:, None] * stride_kvd + kv_ids[None, :] * stride_kvn
            )
            tkv_msk = mask_td[:, None] & mask_ids[None, :]
            tkv_blk = tl.load(tkv_ptr, mask=tkv_msk, other=0.0).to(
                tl.float16
            )  # [TDP, BK]

            qk = tl.dot(q_blk, kv_blk, out_dtype=tl.float16)
            qk = tl.dot(tq_blk, tkv_blk, qk, out_dtype=tl.float16) * log_scale
            qk = tl.where(mask_ids[None, :], qk, float("-inf"))  # [BH, BK]

            new_max = tl.maximum(max_log, tl.max(qk, axis=1))
            # exp2 is evaluated on fp32 inputs (the math ops are not defined
            # for fp16 tensors in upstream Triton) and rounded back to fp16.
            exp_qk = tl.math.exp2((qk - new_max[:, None]).to(tl.float32)).to(
                tl.float16
            )
            sum_qk = tl.sum(exp_qk, axis=1)
            alpha = tl.math.exp2((max_log - new_max).to(tl.float32)).to(tl.float16)
            sum_exp = sum_exp * alpha + sum_qk
            acc = acc * alpha[:, None]
            acc = tl.dot(
                exp_qk, kv_blk.trans(), acc, out_dtype=tl.float16
            )  # [BH, BK] @ [BK, DP] = [BH, DP]

            max_log = new_max.to(tl.float16)

    out_vals = acc / sum_exp[:, None]
    o_ptr = o_base + offs_h[:, None] * stride_oh + offs_od[None, :] * stride_od
    o_msk = mask_h[:, None] & mask_od[None, :]
    tl.store(o_ptr, out_vals.to(output.dtype.element_ty), mask=o_msk)

    # `max_log` is already in base-2 units, so this is lse / ln(2).
    fin_log = max_log + tl.math.log2(sum_exp.to(tl.float32))
    l_ptr = l_base + offs_h * stride_lh
    l_msk = mask_h
    # Cast straight to the LSE element type instead of rounding through fp16.
    tl.store(l_ptr, fin_log.to(lse.dtype.element_ty), mask=l_msk)

157 

158 

if HAS_TLE:

    @triton.autotune(
        configs=spar_mla_fwd_configs,
        key=["K", "is_causal"],
    )
    @triton.jit
    def triton_sparse_mla_fwd_tle(
        q,
        kv,
        indices,
        sm_scale: tl.constexpr,
        output,
        lse,
        stride_qb,
        stride_qh,
        stride_qm,
        stride_qd,
        stride_kvb,
        stride_kvg,
        stride_kvn,
        stride_kvd,
        stride_tb,
        stride_tg,
        stride_tm,
        stride_tt,  # stride of the top-k axis of `indices`
        stride_ob,
        stride_oh,
        stride_om,
        stride_od,
        stride_lb,
        stride_lh,
        stride_lm,
        SQ: tl.constexpr,  # query sequence length
        K: tl.constexpr,  # number of selected KV entries per query (top-k)
        D: tl.constexpr,  # main (value) feature dim
        TD: tl.constexpr,  # tail feature dim, stored right after the first D features
        DP: tl.constexpr,  # block size covering D (masked with offs < D)
        TDP: tl.constexpr,  # block size covering TD (masked with offs < TD)
        G: tl.constexpr,  # query heads per KV group
        BK: tl.constexpr,  # top-k entries processed per loop iteration
        BH: tl.constexpr,  # query heads processed per program
        is_causal: tl.constexpr,
    ):
        """Sparse attention forward pass using the experimental TLE loads.

        Same contract as ``triton_sparse_mla_fwd``: one program per
        (batch, query position, KV group, head block); gathered KV rows are
        scored against the query over ``D`` main plus ``TD`` tail features and
        folded into an fp32 online softmax whose values are the first ``D``
        features of the gathered rows.

        Outputs:
            output: softmax-weighted sum of the gathered value rows.
            lse: log-sum-exp of the scaled scores in base 2; ``-inf`` when no
                entry was usable.
        """
        i_b, i_sq, i_gbh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
        # The launch grid's last axis enumerates (group, head block) pairs, so
        # it must be split by the number of head blocks per group, not by G.
        NH = tl.cdiv(G, BH)
        i_g, i_bh = i_gbh // NH, i_gbh % NH

        # First query head handled by this program; identical to
        # `i_gbh * BH` whenever G is a multiple of BH.
        head_start = i_g * G + i_bh * BH

        q_base = q + i_b * stride_qb + i_sq * stride_qm + head_start * stride_qh
        tq_base = q_base + D * stride_qd
        kv_base = kv + i_b * stride_kvb + i_g * stride_kvg
        tkv_base = kv_base + D * stride_kvd
        t_base = indices + i_b * stride_tb + i_sq * stride_tm + i_g * stride_tg
        o_base = output + i_b * stride_ob + i_sq * stride_om + head_start * stride_oh
        l_base = lse + i_b * stride_lb + i_sq * stride_lm + head_start * stride_lh

        offs_h = tl.arange(0, BH)
        offs_d = tl.arange(0, DP)
        offs_td = tl.arange(0, TDP)
        offs_od = tl.arange(0, DP)
        offs_t = tl.arange(0, BK)
        mask_h = i_bh * BH + offs_h < G
        mask_d = offs_d < D
        mask_td = offs_td < TD
        mask_od = mask_d

        q_ptr = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
        q_msk = mask_h[:, None] & mask_d[None, :]
        q_blk = tl.load(q_ptr, mask=q_msk, other=0.0)

        tq_ptr = tq_base + offs_h[:, None] * stride_qh + offs_td[None, :] * stride_qd
        tq_msk = mask_h[:, None] & mask_td[None, :]
        tq_blk = tl.load(tq_ptr, mask=tq_msk, other=0.0)

        # Online-softmax state.  `sum_exp` starts at 1 so that a query with no
        # usable entry yields output 0 and lse -inf instead of 0 / 0.
        max_prev = tl.full([BH], float("-inf"), dtype=tl.float32)
        sum_exp = tl.full([BH], 1.0, dtype=tl.float32)
        acc = tl.zeros([BH, DP], dtype=tl.float32)

        # Fold log2(e) into the scale so the softmax can use exp2 / log2.
        log_scale: tl.constexpr = sm_scale * 1.44269504

        # Largest KV position a query may attend to.
        # NOTE(review): the non-causal bound uses SQ because the KV length is
        # not passed to the kernel — confirm before enabling is_causal=False.
        max_col = i_sq if is_causal else SQ - 1

        NK = tl.cdiv(K, BK)
        for ck in tl.range(NK, num_stages=0):
            # NOTE(review): this skips chunks by their position in the top-k
            # list, which presumably relies on usable indices being packed at
            # the front of each row — verify against the index producer.
            if ck * BK <= max_col:
                # Bound-check the element index, not the strided offset.
                offs_k = BK * ck + offs_t
                t_msk = offs_k < K
                kv_ids = tl.load(t_base + offs_k * stride_tt, mask=t_msk, other=-1)
                mask_ids = (kv_ids <= max_col) & (kv_ids >= 0)

                kv_ptr = (
                    kv_base
                    + offs_d[:, None] * stride_kvd
                    + kv_ids[None, :] * stride_kvn
                )
                kv_msk = mask_d[:, None] & mask_ids[None, :]
                kv_blk = tle.load(kv_ptr, kv_msk, other=0.0, is_async=True)

                tkv_ptr = (
                    tkv_base
                    + offs_td[:, None] * stride_kvd
                    + kv_ids[None, :] * stride_kvn
                )
                tkv_msk = mask_td[:, None] & mask_ids[None, :]
                tkv_blk = tle.load(tkv_ptr, tkv_msk, other=0.0, is_async=False)

                qk = tl.dot(tq_blk, tkv_blk, out_dtype=tl.float32)
                qk = tl.dot(q_blk, kv_blk, qk, out_dtype=tl.float32)

                qk = tl.where(mask_ids[None, :], qk, float("-inf"))

                new_max = tl.maximum(max_prev, tl.max(qk, axis=1))
                # While no usable entry has been seen the running max is still
                # -inf and `-inf - (-inf)` would poison the state with NaN.
                # Use 0 as the reference in that case and leave the state
                # untouched (alpha == 1, every exp_qk == 0).
                row_live = new_max > float("-inf")
                safe_max = tl.where(row_live, new_max, 0.0)
                alpha = tl.where(
                    row_live,
                    tl.math.exp2((max_prev - safe_max) * log_scale),
                    1.0,
                )
                exp_qk = tl.math.exp2(
                    qk * log_scale - safe_max[:, None] * log_scale
                )
                sum_qk = tl.sum(exp_qk, axis=1)
                sum_exp = sum_exp * alpha + sum_qk
                acc = acc * alpha[:, None]
                # Match the KV element type so both dot operands agree for
                # fp16 as well as bf16 inputs.
                exp_qk = exp_qk.to(kv.dtype.element_ty)
                acc = tl.dot(exp_qk, tl.trans(kv_blk), acc, out_dtype=tl.float32)

                max_prev = new_max

        out_vals = acc / sum_exp[:, None]
        o_ptr = o_base + offs_h[:, None] * stride_oh + offs_od[None, :] * stride_od
        o_msk = mask_h[:, None] & mask_od[None, :]
        tl.store(o_ptr, out_vals.to(q_blk.dtype), mask=o_msk)

        # Convert the running max to base-2 units; the result is lse / ln(2).
        fin_log = max_prev * log_scale + tl.math.log2(sum_exp.to(tl.float32))
        l_ptr = l_base + offs_h * stride_lh
        l_msk = mask_h
        tl.store(l_ptr, fin_log.to(q_blk.dtype), mask=l_msk)

289 

290 

def triton_sparse_mla_fwd_interface(
    q, kv, indices, sm_scale=None, return_p_sum: bool = False, d_v=512
):
    """Launch the sparse-MLA forward kernel and return ``(output, lse)``.

    Args:
        q: contiguous query tensor of shape ``(B, SQ, H, DT)``.
        kv: contiguous key/value tensor of shape ``(B, SKV, VG, DT)``; the
            first ``d_v`` features of each row double as the value.
        indices: contiguous tensor of shape ``(B, SQ, VG, K)`` holding, per
            query and KV group, the KV positions to attend to.  Negative
            entries and entries beyond the causal bound are ignored by the
            kernels.
        sm_scale: softmax scale; defaults to ``DT ** -0.5``.
        return_p_sum: must be ``False`` (forward-only implementation).
        d_v: value dimension ``D``; the remaining ``DT - d_v`` features form
            the tail used only for scoring.

    Returns:
        ``output`` of shape ``(B, SQ, H, d_v)`` and ``lse`` of shape
        ``(B, SQ, H)``, both in ``q``'s dtype and on ``q``'s device.  ``lse``
        is expressed in base 2 (natural-log LSE divided by ln 2).

    Raises:
        AssertionError: if any shape or layout requirement is violated.
    """
    logger.debug("GEMS SPARSE_MLA_FWD_INTERFACE")
    is_causal = True
    assert return_p_sum is False, "This kernel file is for fwd only"
    assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous()
    B, SQ, H, DT = q.shape
    _, _, VG, _ = kv.shape

    D = d_v

    assert kv.shape[-1] == DT
    # A zero-width tail would give TDP == 0 and fail obscurely inside the
    # kernel, so reject it (and any d_v outside the feature dim) up front.
    assert 0 < D < DT, "d_v must satisfy 0 < d_v < q.shape[-1]"
    TD = DT - D
    DP = triton.next_power_of_2(D)
    TDP = triton.next_power_of_2(TD)
    assert kv.shape[0] == B
    _, _, _, K = indices.shape
    assert indices.shape == (B, SQ, VG, K)
    # Query heads are split evenly across KV groups.
    assert VG > 0 and H % VG == 0, "number of heads must be a multiple of KV groups"
    G = H // VG
    if sm_scale is None:
        sm_scale = DT**-0.5
    BH = max(16, min(64, triton.next_power_of_2(G)))
    NH = triton.cdiv(G, BH)
    BK = 32
    output = torch.zeros((B, SQ, H, D), device=q.device, dtype=q.dtype)
    lse = torch.full((B, SQ, H), float("-inf"), device=q.device, dtype=q.dtype)

    # Nothing to compute; a launch grid with a zero-sized axis is invalid.
    if output.numel() == 0:
        return output, lse

    # One program per (batch, query position, KV group x head block).
    grid = (B, SQ, VG * NH)
    kernel_args = (
        q,
        kv,
        indices,
        sm_scale,
        output,
        lse,
        # Strides are passed in (batch, head/group, sequence, feature) order,
        # which is why dims 1 and 2 are swapped below.
        q.stride(0),
        q.stride(2),
        q.stride(1),
        q.stride(3),
        kv.stride(0),
        kv.stride(2),
        kv.stride(1),
        kv.stride(3),
        indices.stride(0),
        indices.stride(2),
        indices.stride(1),
        indices.stride(3),
        output.stride(0),
        output.stride(2),
        output.stride(1),
        output.stride(3),
        lse.stride(0),
        lse.stride(2),
        lse.stride(1),
        SQ,
        K,
        D,
        TD,
        DP,
        TDP,
        G,
        BK,
        BH,
        is_causal,
    )
    if HAS_TLE:
        triton_sparse_mla_fwd_tle[grid](*kernel_args)
    else:
        triton_sparse_mla_fwd[grid](*kernel_args)
    return output, lse