Coverage for src/flag_gems/fused/flashmla_sparse.py: 11%

501 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import os 

2from typing import Optional, Tuple 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.utils.triton_version_utils import has_triton_tle 

9 

# Optional support for the experimental TLE language module
# (`triton.experimental.tle`).  `tle` stays None and the feature flag stays
# False unless the version gate passes AND the module actually imports.
# NOTE(review): has_triton_tle(3, 6, 0) presumably checks for Triton >= 3.6.0
# with TLE available — confirm against flag_gems.utils.triton_version_utils.
tle = None
HAS_TLE_FLASHMLA_SPARSE = False
if has_triton_tle(3, 6, 0):
    try:
        import triton.experimental.tle.language as tle
    except ImportError:
        # Best effort: without the module the TLE kernels are simply not
        # defined and callers fall back to the plain Triton kernel.
        tle = None
    else:
        HAS_TLE_FLASHMLA_SPARSE = True


# Tile configuration for the TLE prefill path.
# _can_use_tle_flash_mla_sparse_fwd requires h_q to be a multiple of BH.
# NOTE(review): BK / PAIR_BLOCKS presumably give the 64-token index block and
# the two blocks processed per pipeline step (64 * 2 = 128 matches the
# `TOPK % 128` requirement); WORKER_NUM_WARPS is consumed outside this view.
TLE_FLASHMLA_PREFILL_BK = 64
TLE_FLASHMLA_PREFILL_BH = 64
TLE_FLASHMLA_PREFILL_PAIR_BLOCKS = 2
TLE_FLASHMLA_PREFILL_WORKER_NUM_WARPS = 4

27 

28 

@triton.autotune(
    configs=[
        triton.Config({"BK": 64, "BH": 64}, num_warps=8, num_stages=2),
        triton.Config({"BK": 64, "BH": 64}, num_warps=8, num_stages=4),
    ],
    key=["SQ", "HQ", "DQK", "SKV", "TOPK", "HAVE_ATTN_SINK", "HAVE_TOPK_LENGTH"],
)
@triton.jit
def triton_flash_mla_sparse_fwd(
    q,
    kv,
    indices,
    attn_sink,
    topk_length,
    sm_scale: tl.constexpr,
    output,
    max_logits,
    lse,
    stride_qh,
    stride_qm,
    stride_kvg,
    stride_kvn,
    stride_tg,
    stride_tm,
    stride_oh,
    stride_om,
    stride_mm,
    stride_lm,
    SQ,  # s_q
    HQ: tl.constexpr,  # h_q=64 or 128
    DQK: tl.constexpr,  # d_qk=512 or 576
    SKV,  # s_kv
    TOPK: tl.constexpr,  # topk
    HAVE_ATTN_SINK: tl.constexpr,
    HAVE_TOPK_LENGTH: tl.constexpr,
    BK: tl.constexpr,
    BH: tl.constexpr,
):
    """Plain-Triton sparse MLA prefill attention kernel.

    Each program handles one query token and one block of BH query heads
    (pid = i_sq * ceil(HQ / BH) + head_block). It walks the token's top-k
    index list in chunks of BK, gathers the referenced KV rows and runs an
    online softmax. The first 512 columns of every gathered KV row act as both
    the key (for QK^T) and the value (for P @ V). When DQK == 576, the extra 64
    columns contribute only to the scores.

    Writes, per (token, head):
      * output: 512 bf16 values; all zeros when no index was valid.
      * max_logits: the maximum of the scaled scores (-inf when no index was valid).
      * lse: the natural-log sum-exp of the scaled scores (+inf when no index was valid).

    Layout notes:
      * Heads are not masked, so HQ must be a multiple of BH.
      * Only kv head 0 is addressed (stride_kvg is unused).
        NOTE(review): this assumes h_kv == 1.
      * The last dimensions of q, kv and output are treated as unit-stride,
        because column offsets are added without a stride.
      * SQ is used only as an autotune key; stride_tg is unused.
      * Probabilities are cast to bf16 for the P @ V dot, and the output is stored
        as bf16. NOTE(review): this assumes bf16 kv.
    """
    num_head_blocks: tl.constexpr = (HQ + BH - 1) // BH
    pid = tl.program_id(0)
    i_sq = pid // num_head_blocks
    i_sq = i_sq.to(tl.int64)  # prevent mul overflow
    i_gbh = pid % num_head_blocks
    gbh_base = i_gbh * BH
    DP: tl.constexpr = 512
    BDP: tl.constexpr = 256

    q_base = q + i_sq * stride_qm + gbh_base * stride_qh
    kv_base = kv
    tkv_base = kv + DP  # score-only tail columns [512, 576)
    t_base = indices + i_sq * stride_tm
    attn_sink_ptr = attn_sink + gbh_base if HAVE_ATTN_SINK else 0
    topk_length_ptr = topk_length + i_sq if HAVE_TOPK_LENGTH else 0
    o_base = output + i_sq * stride_om + gbh_base * stride_oh
    max_log_base = max_logits + i_sq * stride_mm + gbh_base
    l_base = lse + i_sq * stride_lm + gbh_base

    offs_h = tl.arange(0, BH)
    offs_d = tl.arange(0, BDP)
    if DQK == 576:
        offs_td = tl.arange(0, 64)
    offs_t = tl.arange(0, BK)

    # `[BH, 256] x 2` delivers better performance than `[BH, 512]` when BH=64
    q_ptr = q_base + offs_h[:, None] * stride_qh + offs_d[None, :]
    q_blk0 = tl.load(q_ptr, eviction_policy="evict_first")
    q_blk1 = tl.load(q_ptr + BDP, eviction_policy="evict_first")
    if DQK == 576:
        tq_ptr = q_base + DP + offs_h[:, None] * stride_qh + offs_td[None, :]
        tq_blk = tl.load(tq_ptr, eviction_policy="evict_first")

    max_log = tl.full([BH], float("-inf"), dtype=tl.float32)
    sum_exp = tl.full([BH], 0.0, dtype=tl.float32)
    acc0 = tl.zeros([BH, BDP], dtype=tl.float32)
    acc1 = tl.zeros([BH, BDP], dtype=tl.float32)

    topk_len = tl.load(topk_length_ptr) if HAVE_TOPK_LENGTH else TOPK
    NK = tl.cdiv(topk_len, BK)
    for ck in range(NK):
        # step1: load indices. Positions past topk_len read as -1, which makes them invalid.
        t_ptr = BK * ck + offs_t  # [BK]
        t_msk = t_ptr < topk_len
        t_ptr += t_base
        kv_ids = tl.load(t_ptr, t_msk, other=-1)
        mask_ids = (kv_ids < SKV) & (kv_ids >= 0)
        # filter invalid index that may cause overflow in mul
        kv_ids = tl.where(mask_ids, kv_ids, 0)

        # step2: gather kv with indices
        kv_ptr = kv_base + offs_d[:, None] + kv_ids[None, :] * stride_kvn
        kv_blk0 = tl.load(kv_ptr, cache_modifier=".cg")  # [BDP, BK]
        kv_blk1 = tl.load(kv_ptr + BDP, cache_modifier=".cg")  # [BDP, BK]
        # step3: (q @ kv) * sm_scale
        qk = tl.dot(
            q_blk0, kv_blk0, out_dtype=tl.float32
        )  # [BH, BDP]@[BDP, BK] -> [BH, BK]
        qk = tl.dot(q_blk1, kv_blk1, qk, out_dtype=tl.float32)
        if DQK == 576:
            tkv_ptr = tkv_base + offs_td[:, None] + kv_ids[None, :] * stride_kvn
            tkv_blk = tl.load(tkv_ptr, cache_modifier=".cg")  # [TDP, BK]
            qk = tl.dot(tq_blk, tkv_blk, qk, out_dtype=tl.float32)
        qk *= sm_scale

        # step4: invalid indices must not contribute to the softmax
        qk = tl.where(mask_ids[None, :], qk, float("-inf"))  # [BH, BK]
        # step5: online logsumexp update
        new_max = tl.maximum(max_log, tl.max(qk, axis=1))  # [BH]
        exp_qk = tl.math.exp(qk - new_max[:, None])  # [BH, BK]
        sum_qk = tl.sum(exp_qk, axis=1)  # [BH]
        alpha = tl.math.exp(max_log - new_max)  # [BH]
        sum_exp = sum_exp * alpha + sum_qk  # [BH]
        # step6: acc = acc * alpha + exp(qk - max) @ gathered_kv^T
        acc0 = tl.dot(
            exp_qk.to(tl.bfloat16),
            kv_blk0.trans(),
            acc0 * alpha[:, None],
            out_dtype=tl.float32,
        )  # [BH, BK]@[BK, BDP]->[BH, BDP]
        acc1 = tl.dot(
            exp_qk.to(tl.bfloat16),
            kv_blk1.trans(),
            acc1 * alpha[:, None],
            out_dtype=tl.float32,
        )  # [BH, BK]@[BK, BDP]->[BH, BDP]
        max_log = new_max

    # step7: store max_logits (already -inf for rows that saw no valid index)
    valid_mask = max_log != float("-inf")
    tl.store(max_log_base + offs_h, max_log)  # [BH], float32

    # step8: final logsumexp; +inf marks rows with no valid index
    orig_lse = max_log + tl.math.log(sum_exp)
    lse_out = tl.where(valid_mask, orig_lse, float("inf"))
    tl.store(l_base + offs_h, lse_out)  # [BH], float32

    # step9: normalization factor for the accumulated P @ V
    if HAVE_ATTN_SINK:
        # The desired output is acc / sum_exp * exp(lse) / (exp(lse) + exp(sink)).
        # Rewriting it as acc / sum_exp / (1 + exp(sink - lse)) avoids computing
        # exp(lse) or exp(max_log) directly. Those terms overflow fp32 once
        # lse > ~88.7 and would yield 0 or inf/inf = NaN. sink = -inf leaves
        # the output unchanged, and sink = +inf zeroes it, as documented.
        sink = tl.load(attn_sink_ptr + offs_h)  # [BH]
        sink_scale = 1.0 / (1.0 + tl.math.exp(sink - orig_lse))
        factor = sink_scale / sum_exp
    else:
        factor = 1.0 / sum_exp

    out_vals0 = tl.where(valid_mask[:, None], acc0 * factor[:, None], 0.0)
    out_vals1 = tl.where(valid_mask[:, None], acc1 * factor[:, None], 0.0)
    # step10: store output
    o_ptr = o_base + offs_h[:, None] * stride_oh + offs_d[None, :]  # [BH, BDP]
    tl.store(o_ptr, out_vals0.to(tl.bfloat16))
    tl.store(o_ptr + BDP, out_vals1.to(tl.bfloat16))

180 

181 

if HAS_TLE_FLASHMLA_SPARSE:
    # Warp-specialized sparse MLA prefill built on `triton.experimental.tle`.
    # It has three partitions that communicate through shared-memory pipes:
    #   * producer:  gathers the KV rows named by `indices` into shared memory;
    #   * consumer0: computes the scores of index block 0 of every pair and
    #                owns output columns [0, DPH);
    #   * consumer1: computes the scores of index block 1 of every pair, owns
    #                output columns [DPH, 2*DPH) and writes max_logits / lse.
    # The first D columns of each gathered KV row serve as both key and value.

    @triton.jit
    def _tle_flashmla_prefill_producer(
        k0_l_writer,
        k0_r_writer,
        k1_l_writer,
        k1_r_writer,
        valid_writer,
        kv_base,
        tkv_base,
        t_base,
        topk_len_ptr,
        D: tl.constexpr,
        TD: tl.constexpr,
        DPH: tl.constexpr,
        TDP: tl.constexpr,
        VG: tl.constexpr,
        SKV,
        TOPK: tl.constexpr,
        HAVE_TOPK_LENGTH: tl.constexpr,
        HAVE_TAIL: tl.constexpr,
        BK: tl.constexpr,
    ):
        """Producer partition: gather the sparse KV rows for each pair of index blocks.

        For pair p it handles index blocks ck0 = 2p and ck1 = 2p + 1, each of
        BK entries. An index is valid when its position is below topk_len and
        its value lies in [0, SKV). The gathered data is published in this
        order: k0 left half, k1 right half (+ tail), k0 right half (+ tail),
        k1 left half, then the per-pair validity flags. consumer0 waits on k0
        left first and consumer1 waits on k1 right first, so both can start
        their first dot while the remaining halves are still being gathered.
        """
        topk_len = tl.load(topk_len_ptr) if HAVE_TOPK_LENGTH else TOPK
        max_col = SKV - 1
        # Row stride of kv in elements; assumes a contiguous [SKV, VG, TD + D] kv.
        stride_kvn: tl.constexpr = VG * (TD + D)
        NK = tl.cdiv(topk_len, BK)
        NPAIRS = tl.cdiv(NK, 2)
        offs_t = tl.arange(0, BK)
        offs_tile = tl.arange(0, 64)
        kv_tile_rows = tl.broadcast_to(offs_t[:, None], (BK, 64))
        for pair in tl.range(NPAIRS):
            ck0 = pair * 2
            ck1 = ck0 + 1
            t_offs0 = BK * ck0 + offs_t
            t_msk0 = t_offs0 < topk_len
            kv_ids0 = tl.load(t_base + t_offs0, t_msk0, other=-1)
            valid0 = t_msk0 & (kv_ids0 <= max_col) & (kv_ids0 >= 0)
            # Invalid ids are clamped to row 0 before the int64 multiply.
            kv_offsets0 = tl.where(valid0, kv_ids0, 0).to(tl.int64) * stride_kvn

            t_offs1 = BK * ck1 + offs_t
            t_msk1 = t_offs1 < topk_len
            kv_ids1 = tl.load(t_base + t_offs1, t_msk1, other=-1)
            valid1 = t_msk1 & (kv_ids1 <= max_col) & (kv_ids1 >= 0)
            kv_offsets1 = tl.where(valid1, kv_ids1, 0).to(tl.int64) * stride_kvn

            # NOTE(review): the stores below are masked by validity, so rows of
            # invalid ids keep whatever the shared buffer held before. This
            # includes uninitialized memory on the first pair. Scores of those
            # rows are forced to -inf downstream, so their probabilities are 0.
            # However, 0 * NaN in the P @ V dot would still be NaN if the stale
            # bits are non-finite. Confirm whether storing the zero-filled
            # loads unconditionally is intended.

            # Block 0, columns [0, DPH), copied in 64-column tiles.
            k0_l_slot = k0_l_writer.acquire(pair)
            for tile in tl.static_range(0, DPH, 64):
                k_cols = tile + offs_tile
                k_cols_b = tl.broadcast_to(k_cols[None, :], (BK, 64))
                k0_l_ptr = kv_base + kv_offsets0[:, None] + k_cols[None, :]
                k0_l_msk = valid0[:, None] & (k_cols < D)[None, :]
                k0_l_blk = tl.load(
                    k0_l_ptr,
                    mask=k0_l_msk,
                    other=0.0,
                    eviction_policy="evict_last",
                )
                tl.store(
                    tle.gpu.local_ptr(k0_l_slot.sK, (kv_tile_rows, k_cols_b)),
                    k0_l_blk,
                    mask=k0_l_msk,
                )
            k0_l_writer.commit(pair)

            # Block 1, columns [DPH, 2*DPH), plus the score-only tail columns.
            k1_r_slot = k1_r_writer.acquire(pair)
            for tile in tl.static_range(0, DPH, 64):
                k_cols = DPH + tile + offs_tile
                k_cols_b = tl.broadcast_to(k_cols[None, :], (BK, 64))
                k1_r_ptr = kv_base + kv_offsets1[:, None] + k_cols[None, :]
                k1_r_msk = valid1[:, None] & (k_cols < D)[None, :]
                k1_r_blk = tl.load(
                    k1_r_ptr,
                    mask=k1_r_msk,
                    other=0.0,
                    eviction_policy="evict_last",
                )
                tl.store(
                    tle.gpu.local_ptr(k1_r_slot.sK, (kv_tile_rows, k_cols_b)),
                    k1_r_blk,
                    mask=k1_r_msk,
                )
            if HAVE_TAIL:
                offs_td = tl.arange(0, TDP)
                k1_r_tail_ptr = tkv_base + kv_offsets1[:, None] + offs_td[None, :]
                k1_r_tail_msk = valid1[:, None] & (offs_td < TD)[None, :]
                k1_r_tail_blk = tl.load(
                    k1_r_tail_ptr,
                    mask=k1_r_tail_msk,
                    other=0.0,
                    eviction_policy="evict_last",
                )
                tl.store(
                    tle.gpu.local_ptr(k1_r_slot.sK_tail),
                    k1_r_tail_blk,
                    mask=k1_r_tail_msk,
                )
            k1_r_writer.commit(pair)

            # Block 0, columns [DPH, 2*DPH), plus the tail. With HAVE_TAIL this tail
            # buffer doubles as the sS0 score buffer (see _tle_flashmla_prefill_fwd).
            k0_r_slot = k0_r_writer.acquire(pair)
            for tile in tl.static_range(0, DPH, 64):
                k_cols = DPH + tile + offs_tile
                k_cols_b = tl.broadcast_to(k_cols[None, :], (BK, 64))
                k0_r_ptr = kv_base + kv_offsets0[:, None] + k_cols[None, :]
                k0_r_msk = valid0[:, None] & (k_cols < D)[None, :]
                k0_r_blk = tl.load(
                    k0_r_ptr,
                    mask=k0_r_msk,
                    other=0.0,
                    eviction_policy="evict_last",
                )
                tl.store(
                    tle.gpu.local_ptr(k0_r_slot.sK, (kv_tile_rows, k_cols_b)),
                    k0_r_blk,
                    mask=k0_r_msk,
                )
            if HAVE_TAIL:
                offs_td = tl.arange(0, TDP)
                k0_r_tail_ptr = tkv_base + kv_offsets0[:, None] + offs_td[None, :]
                k0_r_tail_msk = valid0[:, None] & (offs_td < TD)[None, :]
                k0_r_tail_blk = tl.load(
                    k0_r_tail_ptr,
                    mask=k0_r_tail_msk,
                    other=0.0,
                    eviction_policy="evict_last",
                )
                tl.store(
                    tle.gpu.local_ptr(k0_r_slot.sK_tail),
                    k0_r_tail_blk,
                    mask=k0_r_tail_msk,
                )
            k0_r_writer.commit(pair)

            # Block 1, columns [0, DPH).
            k1_l_slot = k1_l_writer.acquire(pair)
            for tile in tl.static_range(0, DPH, 64):
                k_cols = tile + offs_tile
                k_cols_b = tl.broadcast_to(k_cols[None, :], (BK, 64))
                k1_l_ptr = kv_base + kv_offsets1[:, None] + k_cols[None, :]
                k1_l_msk = valid1[:, None] & (k_cols < D)[None, :]
                k1_l_blk = tl.load(
                    k1_l_ptr,
                    mask=k1_l_msk,
                    other=0.0,
                    eviction_policy="evict_last",
                )
                tl.store(
                    tle.gpu.local_ptr(k1_l_slot.sK, (kv_tile_rows, k_cols_b)),
                    k1_l_blk,
                    mask=k1_l_msk,
                )
            k1_l_writer.commit(pair)

            # Validity flags: row 0 is for block 0 (consumer0), row 1 is for block 1 (consumer1).
            valid_slot = valid_writer.acquire(pair)
            valid_row0 = tl.full([BK], 0, dtype=tl.int32)
            valid_row1 = tl.full([BK], 1, dtype=tl.int32)
            valid_ptr0 = tle.gpu.local_ptr(valid_slot.is_kv_valid, (valid_row0, offs_t))
            valid_ptr1 = tle.gpu.local_ptr(valid_slot.is_kv_valid, (valid_row1, offs_t))
            tl.store(valid_ptr0, valid0.to(tl.int8))
            tl.store(valid_ptr1, valid1.to(tl.int8))
            valid_writer.commit(pair)

    @triton.jit
    def _tle_flashmla_prefill_consumer0(
        q_writer,
        q_reader,
        q_desc,
        tq_desc,
        k0_l_reader,
        k0_r_qk_reader,
        k1_l_remote_reader,
        valid_reader,
        sM_wg0_writer,
        sM_wg1_reader,
        sS0_writer,
        sS1_reader,
        sL_wg0_writer,
        sL_wg1_reader,
        output_desc,
        output_row,
        h_base,
        topk_len_ptr,
        attn_sink_base,
        log_scale: tl.constexpr,
        D: tl.constexpr,
        TD: tl.constexpr,
        OUT_DTYPE: tl.constexpr,
        HAVE_ATTN_SINK: tl.constexpr,
        TOPK: tl.constexpr,
        HAVE_TOPK_LENGTH: tl.constexpr,
        HAVE_TAIL: tl.constexpr,
        BK: tl.constexpr,
        BH: tl.constexpr,
        DPH: tl.constexpr,
        TDP: tl.constexpr,
        G: tl.constexpr,
    ):
        """Consumer 0: scores of index block 0 and output columns [0, DPH).

        It first loads Q (left half, right half and optional tail) into shared
        memory for both consumers. Then, for every pair it:
          1. computes qk0 = Q @ K0^T over the full QK dimension and masks
             invalid ids to -inf;
          2. publishes its provisional row max via sM_wg0 and accumulates
             P0 @ V0[:, :DPH];
          3. receives the pair's joint row max from consumer1 via sM_wg1 and
             rescales its state to it;
          4. hands the rescaled P0 to consumer1 via sS0 and accumulates
             consumer1's P1 @ V1[:, :DPH] from sS1.
        Finally, it exchanges row sums with consumer1 through sL and writes its
        half of the output. Softmax uses exp2 with log_scale = sm_scale * log2(e).
        """
        topk_len = tl.load(topk_len_ptr) if HAVE_TOPK_LENGTH else TOPK
        offs_h = tl.arange(0, BH)
        offs_dh = tl.arange(0, DPH)
        mask_h = h_base + offs_h < G
        mask_od_l = offs_dh < D
        kv_rows = tl.broadcast_to(tl.arange(0, BK)[:, None], (BK, DPH))
        kv_cols_l = tl.broadcast_to(offs_dh[None, :], (BK, DPH))
        kv_cols_r = tl.broadcast_to((DPH + offs_dh)[None, :], (BK, DPH))

        # Descriptor copies of this program's Q rows into the shared Q buffers.
        q_write_slot = q_writer.acquire(0)
        tle.gpu.copy(q_desc, q_write_slot.sQ_l, [BH, DPH], [output_row, 0])
        tle.gpu.copy(q_desc, q_write_slot.sQ_r, [BH, DPH], [output_row, DPH])
        if HAVE_TAIL:
            tle.gpu.copy(tq_desc, q_write_slot.sQ_tail, [BH, TDP], [output_row, D])
        q_writer.commit(0)

        q_slot = q_reader.wait(0).slot
        q_l_smem_ptr = tle.gpu.local_ptr(q_slot.sQ_l)
        q_r_smem_ptr = tle.gpu.local_ptr(q_slot.sQ_r)
        # A finite "minus infinity" keeps (max_prev - local_max) finite even when
        # every score so far is -inf. Using -inf would give (-inf) - (-inf) = NaN.
        max_prev = tl.full([BH], -1.0e30, dtype=tl.float32)
        sum_exp = tl.full([BH], 0.0, dtype=tl.float32)
        acc_l = tl.zeros([BH, DPH], dtype=tl.float32)

        NK = tl.cdiv(topk_len, BK)
        NPAIRS = tl.cdiv(NK, 2)

        for pair in tl.range(NPAIRS):
            k0_l_wait = k0_l_reader.wait(pair)
            k0_l_slot = k0_l_wait.slot

            q_l_blk = tl.load(q_l_smem_ptr)
            q_r_blk = tl.load(q_r_smem_ptr)
            k0_l_blk = tl.load(tle.gpu.local_ptr(k0_l_slot.sK, (kv_rows, kv_cols_l)))

            qk0 = tl.full([BH, BK], 0.0, dtype=tl.float32)
            qk0 = tl.dot(q_l_blk, tl.trans(k0_l_blk), qk0, out_dtype=tl.float32)

            k0_r_wait = k0_r_qk_reader.wait(pair)
            k0_r_slot = k0_r_wait.slot
            k0_r_blk = tl.load(tle.gpu.local_ptr(k0_r_slot.sK, (kv_rows, kv_cols_r)))
            qk0 = tl.dot(q_r_blk, tl.trans(k0_r_blk), qk0, out_dtype=tl.float32)
            if HAVE_TAIL:
                q_tail_blk = tl.load(tle.gpu.local_ptr(q_slot.sQ_tail))
                k0_t_blk = tl.load(tle.gpu.local_ptr(k0_r_slot.sK_tail))
                qk0 = tl.dot(q_tail_blk, tl.trans(k0_t_blk), qk0, out_dtype=tl.float32)

            # Row 0 of the validity flags belongs to index block 0.
            valid_wait = valid_reader.wait(pair)
            row0 = tl.full([BK], 0, dtype=tl.int32)
            valid0 = (
                tl.load(
                    tle.gpu.local_ptr(
                        valid_wait.slot.is_kv_valid, (row0, tl.arange(0, BK))
                    )
                )
                != 0
            )
            qk0 = tl.where(valid0[None, :], qk0, float("-inf"))
            valid_reader.release(pair)

            # Online softmax relative to this block's provisional max (raw scores,
            # scaled inside exp2 by log_scale).
            local_max = tl.maximum(max_prev, tl.max(qk0, axis=1))
            alpha = tl.math.exp2((max_prev - local_max) * log_scale)
            prob0 = tl.math.exp2(qk0 * log_scale - local_max[:, None] * log_scale)
            sum_exp = sum_exp * alpha + tl.sum(prob0, axis=1)
            acc_l = acc_l * alpha[:, None]
            prob0_b = prob0.to(OUT_DTYPE)

            sM_wg0_slot = sM_wg0_writer.acquire(pair)
            tl.store(tle.gpu.local_ptr(sM_wg0_slot.sM), local_max)
            sM_wg0_writer.commit(pair)

            # K0 is reloaded from shared memory rather than kept live across the
            # softmax. NOTE(review): presumably this is to limit register pressure.
            k0_l_blk = tl.load(tle.gpu.local_ptr(k0_l_slot.sK, (kv_rows, kv_cols_l)))
            acc_l = tl.dot(prob0_b, k0_l_blk, acc_l, out_dtype=tl.float32)
            k0_l_reader.release(pair)
            k0_r_qk_reader.release(pair)

            # consumer1 publishes max(local_max, its own block-1 max).
            sM_wg1_wait = sM_wg1_reader.wait(pair)
            max_next = tl.load(tle.gpu.local_ptr(sM_wg1_wait.slot.sM))
            sM_wg1_reader.release(pair)

            # After this rescale, sum_exp / acc_l are relative to the same running
            # max as consumer1's state. That is why the two row sums can simply be
            # added at the end.
            final_scale = tl.math.exp2((local_max - max_next) * log_scale)
            sum_exp = sum_exp * final_scale
            acc_l = acc_l * final_scale[:, None]

            prob0_scaled = prob0 * final_scale[:, None]
            sS0_slot = sS0_writer.acquire(pair)
            tl.store(tle.gpu.local_ptr(sS0_slot.sS0), prob0_scaled.to(OUT_DTYPE))
            sS0_writer.commit(pair)

            # consumer1's block-1 probabilities (already relative to max_next)
            # times the left half of V1.
            sS1_wait = sS1_reader.wait(pair)
            prob1 = tl.load(tle.gpu.local_ptr(sS1_wait.slot.sS1))
            k1_l_wait = k1_l_remote_reader.wait(pair)
            k1_l_blk = tl.load(
                tle.gpu.local_ptr(k1_l_wait.slot.sK, (kv_rows, kv_cols_l))
            )
            acc_l = tl.dot(prob1, k1_l_blk, acc_l, out_dtype=tl.float32)
            sS1_reader.release(pair)
            k1_l_remote_reader.release(pair)

            max_prev = max_next

        # Exchange per-partition row sums; wg0 uses pipe index 0, wg1 index 1.
        sL_wg0_slot = sL_wg0_writer.acquire(0)
        tl.store(tle.gpu.local_ptr(sL_wg0_slot.sL), sum_exp)
        sL_wg0_writer.commit(0)
        sL_wg1_wait = sL_wg1_reader.wait(1)
        peer_sum = tl.load(tle.gpu.local_ptr(sL_wg1_wait.slot.sL))
        total_sum = sum_exp + peer_sum
        sL_wg1_reader.release(1)

        is_no_valid_tokens = total_sum == 0.0
        inv_total_sum = tl.fdiv(1.0, total_sum)
        out_l_vals = acc_l * inv_total_sum[:, None]
        if HAVE_ATTN_SINK:
            # Natural-log lse = ln(2) * (log2-scaled max + log2(sum)).
            fin_log = (
                max_prev * log_scale + tl.math.log2(total_sum)
            ) * 0.6931471805599453
            sink = tl.load(attn_sink_base + h_base + offs_h, mask_h, other=0.0)
            # exp(lse) / (exp(lse) + exp(sink)) in an overflow-free form.
            sink_scale = tl.fdiv(1.0, 1.0 + tl.math.exp(sink - fin_log))
            out_l_vals = out_l_vals * sink_scale[:, None]
        out_l_vals = tl.where(is_no_valid_tokens[:, None], 0.0, out_l_vals)
        o_l_msk = mask_h[:, None] & mask_od_l[None, :]
        # The output is staged in the Q-left buffer. consumer1 only commits sL_wg1
        # after its loop, so it no longer reads sQ_l at this point.
        tl.store(q_l_smem_ptr, out_l_vals.to(OUT_DTYPE), o_l_msk)
        tle.gpu.copy(q_slot.sQ_l, output_desc, [BH, DPH], [output_row, 0])

    @triton.jit
    def _tle_flashmla_prefill_consumer1(
        q_reader,
        k1_r_reader,
        k1_l_qk_reader,
        k0_r_remote_reader,
        valid_reader,
        sM_wg1_writer,
        sM_wg0_reader,
        sS1_writer,
        sS0_reader,
        sL_wg1_writer,
        sL_wg0_reader,
        final_max_logits_smem,
        final_lse_smem,
        output_desc,
        output_row,
        max_logits_base,
        l_base,
        h_base,
        topk_len_ptr,
        attn_sink_base,
        log_scale: tl.constexpr,
        D: tl.constexpr,
        TD: tl.constexpr,
        OUT_DTYPE: tl.constexpr,
        HAVE_ATTN_SINK: tl.constexpr,
        TOPK: tl.constexpr,
        HAVE_TOPK_LENGTH: tl.constexpr,
        HAVE_TAIL: tl.constexpr,
        BK: tl.constexpr,
        BH: tl.constexpr,
        DPH: tl.constexpr,
        TDP: tl.constexpr,
        G: tl.constexpr,
    ):
        """Consumer 1: scores of index block 1, output columns [DPH, 2*DPH), max_logits and lse.

        For every pair it:
          1. computes qk1 = Q @ K1^T over the full QK dimension and masks
             invalid ids;
          2. takes consumer0's provisional max from sM_wg0, forms the joint
             max and publishes it via sM_wg1;
          3. accumulates P1 @ V1[:, DPH:] and hands P1 to consumer0 via sS1;
          4. accumulates consumer0's rescaled P0 @ V0[:, DPH:] from sS0.
        Finally, it exchanges row sums through sL, writes its output half and
        stores max_logits (-inf) and lse (+inf) for rows with no valid ids.
        """
        topk_len = tl.load(topk_len_ptr) if HAVE_TOPK_LENGTH else TOPK
        offs_h = tl.arange(0, BH)
        offs_dh = tl.arange(0, DPH)
        mask_h = h_base + offs_h < G
        mask_od_r = DPH + offs_dh < D
        kv_rows = tl.broadcast_to(tl.arange(0, BK)[:, None], (BK, DPH))
        kv_cols_l = tl.broadcast_to(offs_dh[None, :], (BK, DPH))
        kv_cols_r = tl.broadcast_to((DPH + offs_dh)[None, :], (BK, DPH))
        # Q is loaded into shared memory by consumer0.
        q_slot = q_reader.wait(0).slot
        q_l_smem_ptr = tle.gpu.local_ptr(q_slot.sQ_l)
        q_r_smem_ptr = tle.gpu.local_ptr(q_slot.sQ_r)
        # Finite sentinel; see the matching comment in consumer0.
        max_prev = tl.full([BH], -1.0e30, dtype=tl.float32)
        sum_exp = tl.full([BH], 0.0, dtype=tl.float32)
        acc_r = tl.zeros([BH, DPH], dtype=tl.float32)

        NK = tl.cdiv(topk_len, BK)
        NPAIRS = tl.cdiv(NK, 2)
        for pair in tl.range(NPAIRS):
            k1_r_wait = k1_r_reader.wait(pair)
            k1_r_slot = k1_r_wait.slot

            q_l_blk = tl.load(q_l_smem_ptr)
            q_r_blk = tl.load(q_r_smem_ptr)
            k1_r_blk = tl.load(tle.gpu.local_ptr(k1_r_slot.sK, (kv_rows, kv_cols_r)))

            qk1 = tl.full([BH, BK], 0.0, dtype=tl.float32)
            qk1 = tl.dot(q_r_blk, tl.trans(k1_r_blk), qk1, out_dtype=tl.float32)
            if HAVE_TAIL:
                q_tail_blk = tl.load(tle.gpu.local_ptr(q_slot.sQ_tail))
                k1_t_blk = tl.load(tle.gpu.local_ptr(k1_r_slot.sK_tail))
                qk1 = tl.dot(q_tail_blk, tl.trans(k1_t_blk), qk1, out_dtype=tl.float32)
            k1_l_wait = k1_l_qk_reader.wait(pair)
            k1_l_slot = k1_l_wait.slot
            k1_l_blk = tl.load(tle.gpu.local_ptr(k1_l_slot.sK, (kv_rows, kv_cols_l)))
            qk1 = tl.dot(q_l_blk, tl.trans(k1_l_blk), qk1, out_dtype=tl.float32)

            # Row 1 of the validity flags belongs to index block 1.
            valid_wait = valid_reader.wait(pair)
            row1 = tl.full([BK], 1, dtype=tl.int32)
            valid1 = (
                tl.load(
                    tle.gpu.local_ptr(
                        valid_wait.slot.is_kv_valid, (row1, tl.arange(0, BK))
                    )
                )
                != 0
            )
            qk1 = tl.where(valid1[None, :], qk1, float("-inf"))
            valid_reader.release(pair)

            sM_wg0_wait = sM_wg0_reader.wait(pair)
            candidate0 = tl.load(tle.gpu.local_ptr(sM_wg0_wait.slot.sM))
            sM_wg0_reader.release(pair)

            # Joint max over both blocks of this pair. It is written into the same sM
            # buffer that consumer0 just released (both pipes wrap sM_smem).
            candidate1 = tl.maximum(max_prev, tl.max(qk1, axis=1))
            max_next = tl.maximum(candidate1, candidate0)
            sM_wg1_slot = sM_wg1_writer.acquire(pair)
            tl.store(tle.gpu.local_ptr(sM_wg1_slot.sM), max_next)
            sM_wg1_writer.commit(pair)

            alpha = tl.math.exp2((max_prev - max_next) * log_scale)
            prob1 = tl.math.exp2(qk1 * log_scale - max_next[:, None] * log_scale)
            sum_exp = sum_exp * alpha + tl.sum(prob1, axis=1)
            acc_r = acc_r * alpha[:, None]
            prob1_b = prob1.to(OUT_DTYPE)

            k1_l_qk_reader.release(pair)

            acc_r = tl.dot(prob1_b, k1_r_blk, acc_r, out_dtype=tl.float32)

            sS1_slot = sS1_writer.acquire(pair)
            tl.store(tle.gpu.local_ptr(sS1_slot.sS1), prob1_b)
            sS1_writer.commit(pair)

            # consumer0's block-0 probabilities times the right half of V0.
            sS0_wait = sS0_reader.wait(pair)
            prob0 = tl.load(tle.gpu.local_ptr(sS0_wait.slot.sS0))
            k0_r_wait = k0_r_remote_reader.wait(pair)
            k0_r_blk = tl.load(
                tle.gpu.local_ptr(k0_r_wait.slot.sK, (kv_rows, kv_cols_r))
            )
            acc_r = tl.dot(prob0, k0_r_blk, acc_r, out_dtype=tl.float32)
            k1_r_reader.release(pair)
            sS0_reader.release(pair)
            k0_r_remote_reader.release(pair)
            max_prev = max_next

        # Exchange per-partition row sums; wg1 uses pipe index 1, wg0 index 0.
        sL_wg1_slot = sL_wg1_writer.acquire(1)
        tl.store(tle.gpu.local_ptr(sL_wg1_slot.sL), sum_exp)
        sL_wg1_writer.commit(1)
        sL_wg0_wait = sL_wg0_reader.wait(0)
        peer_sum = tl.load(tle.gpu.local_ptr(sL_wg0_wait.slot.sL))
        total_sum = sum_exp + peer_sum
        sL_wg0_reader.release(0)

        is_no_valid_tokens = total_sum == 0.0
        inv_total_sum = tl.fdiv(1.0, total_sum)
        out_r_vals = acc_r * inv_total_sum[:, None]
        # Convert the log2-domain results back to natural log (ln 2 = 0.693...).
        final_max_logits_log2 = max_prev * log_scale
        final_max_logits = final_max_logits_log2 * 0.6931471805599453
        fin_log = (final_max_logits_log2 + tl.math.log2(total_sum)) * 0.6931471805599453
        if HAVE_ATTN_SINK:
            sink = tl.load(attn_sink_base + h_base + offs_h, mask_h, other=0.0)
            sink_scale = tl.fdiv(1.0, 1.0 + tl.math.exp(sink - fin_log))
            out_r_vals = out_r_vals * sink_scale[:, None]
        out_r_vals = tl.where(is_no_valid_tokens[:, None], 0.0, out_r_vals)
        o_r_msk = mask_h[:, None] & mask_od_r[None, :]
        # The output is staged in the Q-right buffer. consumer0 only commits sL_wg0
        # after its loop, so it no longer reads sQ_r at this point.
        tl.store(q_r_smem_ptr, out_r_vals.to(OUT_DTYPE), o_r_msk)
        tle.gpu.copy(q_slot.sQ_r, output_desc, [BH, DPH], [output_row, DPH])

        # Rows without any valid id report max_logits = -inf and lse = +inf.
        final_max_logits = tl.where(is_no_valid_tokens, float("-inf"), final_max_logits)
        fin_log = tl.where(is_no_valid_tokens, float("inf"), fin_log)
        # Round-trip through shared memory before the global store.
        # NOTE(review): presumably this is a layout conversion; the reason is not
        # visible here.
        tl.store(tle.gpu.local_ptr(final_max_logits_smem), final_max_logits, mask_h)
        tl.store(tle.gpu.local_ptr(final_lse_smem), fin_log, mask_h)
        final_max_logits = tl.load(
            tle.gpu.local_ptr(final_max_logits_smem), mask_h, other=float("-inf")
        )
        fin_log = tl.load(tle.gpu.local_ptr(final_lse_smem), mask_h, other=float("inf"))
        tl.store(max_logits_base + offs_h, final_max_logits, mask_h)
        tl.store(l_base + offs_h, fin_log, mask_h)

    @triton.jit
    def _tle_flashmla_prefill_fwd(
        q_desc,
        tq_desc,
        output_desc,
        kv,
        indices,
        attn_sink,
        topk_length,
        sm_scale: tl.constexpr,
        output,
        max_logits,
        lse,
        SQ,
        H: tl.constexpr,
        DQK: tl.constexpr,
        SKV,
        TOPK: tl.constexpr,
        HAVE_ATTN_SINK: tl.constexpr,
        HAVE_TOPK_LENGTH: tl.constexpr,
        D: tl.constexpr,
        TD: tl.constexpr,
        DP: tl.constexpr,
        TDP: tl.constexpr,
        G: tl.constexpr,
        VG: tl.constexpr,
        RH: tl.constexpr,
        HAVE_TAIL: tl.constexpr,
        BK: tl.constexpr,
        BH: tl.constexpr,
        PAIR_BLOCKS: tl.constexpr,
    ):
        """Entry kernel for the warp-specialized TLE prefill.

        One program per (query token, kv group, head block):
        pid = (i_sq * VG + i_g) * RH + i_rh. The kernel allocates the shared
        buffers and pipes, then launches consumer0, consumer1 and the producer
        via tle.gpu.warp_specialize.

        Strides are derived from shapes rather than passed in. NOTE(review): this
        assumes a contiguous indices tensor of shape [SQ, VG, TOPK], contiguous
        max_logits / lse of shape [SQ, H], and kv rows laid out as VG groups of
        (D + TD) elements. Q and the output are addressed through the
        q_desc / tq_desc / output_desc descriptors.
        """
        DPH: tl.constexpr = DP // 2
        stride_kvg: tl.constexpr = TD + D
        stride_tg = TOPK
        stride_tm = VG * stride_tg
        stride_lm = H
        stride_mm = H

        pid = tl.program_id(0)
        programs_per_q: tl.constexpr = VG * RH
        i_sq = pid // programs_per_q
        i_grh = pid % programs_per_q
        i_g = i_grh // RH
        i_rh = i_grh % RH
        h_base = i_rh * BH
        q_head_base = i_g * G + h_base
        # int64 copies keep the pointer arithmetic below from overflowing int32.
        i_sq64 = i_sq.to(tl.int64)
        i_g64 = i_g.to(tl.int64)
        q_head_base64 = q_head_base.to(tl.int64)
        kv_base = kv + i_g64 * stride_kvg
        tkv_base = kv_base + D
        t_base = indices + i_sq64 * stride_tm + i_g64 * stride_tg
        # Placeholder pointers when a feature is off. They are never dereferenced,
        # because every load is gated by HAVE_TOPK_LENGTH / HAVE_ATTN_SINK.
        topk_len_ptr = topk_length + i_sq64 if HAVE_TOPK_LENGTH else indices
        attn_sink_base = attn_sink if HAVE_ATTN_SINK else max_logits
        max_logits_base = max_logits + i_sq64 * stride_mm + q_head_base64
        l_base = lse + i_sq64 * stride_lm + q_head_base64
        q_row = i_sq * H + q_head_base
        # Unused parameters; assigned to `_` to mark them as intentionally unused.
        _ = output
        _ = SQ
        _ = DQK

        # Shared Q buffers, written once by consumer0 and read by both consumers.
        sQ_l_smem = tle.gpu.alloc(
            [1, BH, DPH], dtype=kv.dtype.element_ty, layout=None, scope=tle.gpu.smem
        )
        sQ_r_smem = tle.gpu.alloc(
            [1, BH, DPH], dtype=kv.dtype.element_ty, layout=None, scope=tle.gpu.smem
        )
        if HAVE_TAIL:
            sQ_tail_smem = tle.gpu.alloc(
                [1, BH, TDP],
                dtype=kv.dtype.element_ty,
                layout=None,
                scope=tle.gpu.smem,
            )
            q_pipe = tle.pipe(
                capacity=1,
                scope="cta",
                name="flashmla_sQ",
                readers=("wg0", "wg1"),
                one_shot=True,
                sQ_l=sQ_l_smem,
                sQ_r=sQ_r_smem,
                sQ_tail=sQ_tail_smem,
            )
        else:
            q_pipe = tle.pipe(
                capacity=1,
                scope="cta",
                name="flashmla_sQ",
                readers=("wg0", "wg1"),
                one_shot=True,
                sQ_l=sQ_l_smem,
                sQ_r=sQ_r_smem,
            )

        # One [BK, DP] KV buffer per index block of a pair. The left and right
        # halves of each buffer are published through separate pipes.
        sK0_smem = tle.gpu.alloc(
            [1, BK, DP], dtype=kv.dtype.element_ty, layout=None, scope=tle.gpu.smem
        )
        sK1_smem = tle.gpu.alloc(
            [1, BK, DP], dtype=kv.dtype.element_ty, layout=None, scope=tle.gpu.smem
        )
        if HAVE_TAIL:
            sK0_tail_smem = tle.gpu.alloc(
                [1, BK, TDP],
                dtype=kv.dtype.element_ty,
                layout=None,
                scope=tle.gpu.smem,
            )
            sK1_tail_smem = tle.gpu.alloc(
                [1, BK, TDP],
                dtype=kv.dtype.element_ty,
                layout=None,
                scope=tle.gpu.smem,
            )
            # The sS0 score tile reuses the K0 tail buffer instead of a separate
            # allocation. consumer0 reads the K0 tail before writing sS0, and the
            # producer cannot refill it until consumer1 has read sS0 and released
            # k0_r. NOTE(review): this relies on [BK, TDP] matching [BH, BK]
            # (BK = BH = TDP = 64) — confirm if tile sizes change.
            sS0_smem = sK0_tail_smem
        else:
            sS0_smem = tle.gpu.alloc(
                [1, BH, BK],
                dtype=kv.dtype.element_ty,
                layout=None,
                scope=tle.gpu.smem,
            )
        is_kv_valid_smem = tle.gpu.alloc(
            [1, PAIR_BLOCKS, BK],
            dtype=tl.int8,
            layout=None,
            scope=tle.gpu.smem,
            nv_mma_shared_layout=False,
        )
        k0_l_pipe = tle.pipe(
            capacity=1, scope="cta", name="flashmla_sK0_l", sK=sK0_smem
        )
        if HAVE_TAIL:
            k0_r_pipe = tle.pipe(
                capacity=1,
                scope="cta",
                name="flashmla_sK0_r",
                readers=("qk", "remote"),
                sK=sK0_smem,
                sK_tail=sK0_tail_smem,
            )
        else:
            k0_r_pipe = tle.pipe(
                capacity=1,
                scope="cta",
                name="flashmla_sK0_r",
                readers=("qk", "remote"),
                sK=sK0_smem,
            )
        k1_l_pipe = tle.pipe(
            capacity=1,
            scope="cta",
            name="flashmla_sK1_l",
            readers=("qk", "remote"),
            sK=sK1_smem,
        )
        if HAVE_TAIL:
            k1_r_pipe = tle.pipe(
                capacity=1,
                scope="cta",
                name="flashmla_sK1_r",
                sK=sK1_smem,
                sK_tail=sK1_tail_smem,
            )
        else:
            k1_r_pipe = tle.pipe(
                capacity=1,
                scope="cta",
                name="flashmla_sK1_r",
                sK=sK1_smem,
            )
        is_kv_valid_pipe = tle.pipe(
            capacity=1,
            scope="cta",
            name="flashmla_is_kv_valid_ready",
            readers=("wg0", "wg1"),
            is_kv_valid=is_kv_valid_smem,
        )

        # Buffers for the consumer0 <-> consumer1 exchange. sM_smem is shared by
        # both sM pipes, and sL_smem by both sL pipes; the pipe protocol orders
        # the accesses.
        sM_smem = tle.gpu.alloc(
            [1, BH],
            dtype=tl.float32,
            layout=None,
            scope=tle.gpu.smem,
            nv_mma_shared_layout=False,
        )
        sS1_smem = tle.gpu.alloc(
            [1, BH, BK], dtype=kv.dtype.element_ty, layout=None, scope=tle.gpu.smem
        )
        # Two rows: wg0 writes through pipe index 0 and wg1 through index 1.
        # NOTE(review): presumably index i maps to row i — confirm.
        sL_smem = tle.gpu.alloc(
            [2, BH],
            dtype=tl.float32,
            layout=None,
            scope=tle.gpu.smem,
            nv_mma_shared_layout=False,
        )
        final_max_logits_smem = tle.gpu.alloc(
            [BH],
            dtype=tl.float32,
            layout=None,
            scope=tle.gpu.smem,
            nv_mma_shared_layout=False,
        )
        final_lse_smem = tle.gpu.alloc(
            [BH],
            dtype=tl.float32,
            layout=None,
            scope=tle.gpu.smem,
            nv_mma_shared_layout=False,
        )
        sM_wg0_pipe = tle.pipe(
            capacity=1, scope="cta", name="flashmla_wg0_bunch_0_ready", sM=sM_smem
        )
        sM_wg1_pipe = tle.pipe(
            capacity=1, scope="cta", name="flashmla_wg1_bunch_0_ready", sM=sM_smem
        )
        sS0_pipe = tle.pipe(capacity=1, scope="cta", name="flashmla_sS0", sS0=sS0_smem)
        sS1_pipe = tle.pipe(capacity=1, scope="cta", name="flashmla_sS1", sS1=sS1_smem)
        sL_wg0_pipe = tle.pipe(
            capacity=2, scope="cta", name="flashmla_sL_wg0", sL=sL_smem
        )
        sL_wg1_pipe = tle.pipe(
            capacity=2, scope="cta", name="flashmla_sL_wg1", sL=sL_smem
        )

        # Consumers evaluate exp(x * sm_scale) as exp2(x * log_scale).
        log_scale: tl.constexpr = sm_scale * 1.4426950408889634

        # NOTE(review): the trailing [4, 4] and [216, 72] lists are interpreted
        # by the TLE warp_specialize API (presumably per-worker warp counts and
        # register budgets) — confirm against its documentation.
        tle.gpu.warp_specialize(
            [
                (
                    _tle_flashmla_prefill_consumer0,
                    (
                        q_pipe.writer(),
                        q_pipe.reader("wg0"),
                        q_desc,
                        tq_desc,
                        k0_l_pipe.reader(),
                        k0_r_pipe.reader("qk"),
                        k1_l_pipe.reader("remote", fields=("sK",)),
                        is_kv_valid_pipe.reader("wg0"),
                        sM_wg0_pipe.writer(),
                        sM_wg1_pipe.reader(),
                        sS0_pipe.writer(),
                        sS1_pipe.reader(),
                        sL_wg0_pipe.writer(),
                        sL_wg1_pipe.reader(),
                        output_desc,
                        q_row,
                        h_base,
                        topk_len_ptr,
                        attn_sink_base,
                        log_scale,
                        D,
                        TD,
                        kv.dtype.element_ty,
                        HAVE_ATTN_SINK,
                        TOPK,
                        HAVE_TOPK_LENGTH,
                        HAVE_TAIL,
                        BK,
                        BH,
                        DPH,
                        TDP,
                        G,
                    ),
                ),
                (
                    _tle_flashmla_prefill_consumer1,
                    (
                        q_pipe.reader("wg1"),
                        k1_r_pipe.reader(),
                        k1_l_pipe.reader("qk"),
                        k0_r_pipe.reader("remote", fields=("sK",)),
                        is_kv_valid_pipe.reader("wg1"),
                        sM_wg1_pipe.writer(),
                        sM_wg0_pipe.reader(),
                        sS1_pipe.writer(),
                        sS0_pipe.reader(),
                        sL_wg1_pipe.writer(),
                        sL_wg0_pipe.reader(),
                        final_max_logits_smem,
                        final_lse_smem,
                        output_desc,
                        q_row,
                        max_logits_base,
                        l_base,
                        h_base,
                        topk_len_ptr,
                        attn_sink_base,
                        log_scale,
                        D,
                        TD,
                        kv.dtype.element_ty,
                        HAVE_ATTN_SINK,
                        TOPK,
                        HAVE_TOPK_LENGTH,
                        HAVE_TAIL,
                        BK,
                        BH,
                        DPH,
                        TDP,
                        G,
                    ),
                ),
                (
                    _tle_flashmla_prefill_producer,
                    (
                        k0_l_pipe.writer(),
                        k0_r_pipe.writer(),
                        k1_l_pipe.writer(),
                        k1_r_pipe.writer(),
                        is_kv_valid_pipe.writer(),
                        kv_base,
                        tkv_base,
                        t_base,
                        topk_len_ptr,
                        D,
                        TD,
                        DPH,
                        TDP,
                        VG,
                        SKV,
                        TOPK,
                        HAVE_TOPK_LENGTH,
                        HAVE_TAIL,
                        BK,
                    ),
                ),
            ],
            [4, 4],
            [216, 72],
        )

990 

991 

992def _flash_mla_sparse_tle_enabled() -> bool: 

993 value = os.environ.get("FLAGGEMS_FLASHMLA_SPARSE_TLE", "1").lower() 

994 return value not in {"0", "false", "off", "no"} 

995 

996 

997def _can_use_tle_flash_mla_sparse_fwd( 

998 q: torch.Tensor, 

999 kv: torch.Tensor, 

1000 indices: torch.Tensor, 

1001 d_v: int, 

1002 topk_length: Optional[torch.Tensor] = None, 

1003) -> bool: 

1004 if not (HAS_TLE_FLASHMLA_SPARSE and _flash_mla_sparse_tle_enabled()): 

1005 return False 

1006 if q.device.type != "cuda": 

1007 return False 

1008 SQ, HQ, DQK = q.shape 

1009 _ = SQ 

1010 HKV = kv.shape[1] 

1011 TOPK = indices.shape[-1] 

1012 return ( 

1013 d_v == 512 

1014 and HKV == 1 

1015 and DQK in (512, 576) 

1016 and HQ % TLE_FLASHMLA_PREFILL_BH == 0 

1017 and TOPK > 0 

1018 and TOPK % 128 == 0 

1019 ) 

1020 

1021 

def _set_triton_descriptor_allocator(device: torch.device) -> None:
    """Register an allocator with ``triton.set_allocator`` for ``device``.

    The registered callable receives ``(size, align, stream)``. It returns
    a fresh ``torch.empty`` buffer of ``size`` int8 elements on ``device``.
    ``align`` and ``stream`` are ignored.
    """

    def _alloc_int8_scratch(size: int, align: int, stream):
        del align, stream  # intentionally ignored; see docstring
        return torch.empty(size, dtype=torch.int8, device=device)

    triton.set_allocator(_alloc_int8_scratch)

1029 

1030 

def flash_mla_sparse_fwd(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
    attn_sink: Optional[torch.Tensor] = None,
    topk_length: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Sparse attention prefill kernel.

    Validates the inputs and allocates ``output``, ``max_logits`` and ``lse``.
    It then launches one of two Triton kernels:

    * ``_tle_flashmla_prefill_fwd`` when ``_can_use_tle_flash_mla_sparse_fwd``
      accepts the inputs (TLE available and enabled, CUDA device,
      ``topk % 128 == 0``, ...). This path also registers a Triton allocator
      via ``_set_triton_descriptor_allocator(q.device)`` and builds host-side
      ``TensorDescriptor`` objects for ``q`` and ``output``.
    * otherwise the autotuned ``triton_flash_mla_sparse_fwd`` kernel.

    Args:
        q: [s_q, h_q, d_qk], bfloat16, contiguous.
        kv: [s_kv, h_kv, d_qk], bfloat16, contiguous.
        indices: [s_q, h_kv, topk], int32, contiguous. Invalid indices should
            be set to -1 or numbers >= s_kv.
        sm_scale: float softmax scale.
        d_v: The dimension of value vectors. Can only be 512.
        attn_sink: optional, [h_q], float32, contiguous.
            If attn_sink is provided, when computing output, output will be
            additionally multiplied by exp(lse) / (exp(lse) + exp(attn_sink)).
            +-inf in attn_sink will be handled normally (i.e., -inf has no
            effect, +inf will make corresponding output all zeros).
            This argument has no effect on lse and max_logits.
        topk_length: optional, [s_q], int32, contiguous.
            If provided, the i-th q token will only attend to k tokens
            specified by indices[i, :, :topk_length[i]], ignoring later k/v
            tokens (even if provided in indices). In extremely rare cases
            (topk_length provided, there is a valid topk index between
            topk_length[i] ~ s_kv, and that topk index points to a k token
            containing NaN), operator output will contain NaN, so please avoid
            this situation.

    Returns:
        (output, max_logits, lse)
        Please refer to tests/ref.py (FlashMLA reference) for the precise
        definitions of these values.
        - output: [s_q, h_q, d_v], same dtype as q (bfloat16)
        - max_logits: [s_q, h_q], float32
        - lse: [s_q, h_q], float32, log-sum-exp of attention scores

    Raises:
        AssertionError: if any input check fails (contiguity, dtype, shape,
            h_kv != 1, h_q not in {64, 128}, d_qk not in {512, 576},
            d_v != 512). These are ``assert`` statements, so they are skipped
            when Python runs with ``-O``.
    """
    # --- input validation -------------------------------------------------
    assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous()
    assert (
        q.dtype == torch.bfloat16
        and kv.dtype == torch.bfloat16
        and indices.dtype == torch.int32
    )
    SQ, HQ, DQK = q.shape
    SKV, HKV, _ = kv.shape

    assert d_v == 512, "Unsupported d_v"
    DV = d_v

    assert kv.shape[-1] == DQK
    _, _, TOPK = indices.shape
    assert indices.shape == (SQ, HKV, TOPK)
    if attn_sink is not None:
        assert attn_sink.is_contiguous()
        assert attn_sink.dtype == torch.float32
        assert attn_sink.shape == (HQ,), "attn_sink error shape"
    if topk_length is not None:
        assert topk_length.is_contiguous()
        assert topk_length.dtype == torch.int32
        assert topk_length.shape == (SQ,), "topk_length error shape"

    # Shape restrictions mirrored from FlashMLA.
    assert HKV == 1, "h_kv is expected to be 1"
    assert HQ == 64 or HQ == 128, "Unsupported h_q"
    assert DQK == 576 or DQK == 512, "Unsupported d_qk"

    # --- launch constants (consumed by the TLE kernel) ---------------------
    _ = SKV
    # The head dim is split into a main part of width D (= d_v, which is also
    # the output width) and a tail of width TD = d_qk - d_v (64 for
    # d_qk == 576, 0 for d_qk == 512).
    # NOTE(review): presumably the first D columns of kv also serve as V
    # (MLA layout) — confirm against the kernel.
    D = DV
    TD = DQK - D
    DP = triton.next_power_of_2(D)
    HAVE_TAIL = TD > 0
    # Padded tail width; 1 is only a placeholder when there is no tail.
    TDP = triton.next_power_of_2(TD) if HAVE_TAIL else 1
    G = HQ // HKV  # query heads per kv head
    BH = TLE_FLASHMLA_PREFILL_BH
    RH = G // BH  # number of BH-sized head blocks per kv head
    BK = TLE_FLASHMLA_PREFILL_BK
    output = torch.empty((SQ, HQ, DV), device=q.device, dtype=q.dtype)
    max_logits = torch.empty((SQ, HQ), device=q.device, dtype=torch.float32)
    lse = torch.empty((SQ, HQ), device=q.device, dtype=torch.float32)

    def triton_grid(META):
        # 1-D grid: (number of META["BH"]-sized head blocks) x SQ programs.
        return (triton.cdiv(HQ, META["BH"]) * SQ,)

    if _can_use_tle_flash_mla_sparse_fwd(q, kv, indices, d_v, topk_length):
        # Imported here rather than at module top, so only this path needs it.
        from triton.tools.tensor_descriptor import TensorDescriptor

        # NOTE(review): presumably the TLE kernel needs Triton-managed scratch
        # memory, hence the allocator registration — confirm.
        _set_triton_descriptor_allocator(q.device)
        # q viewed as a 2-D [SQ * HQ, DQK] row-major matrix; tiles are BH rows
        # (heads) by DP // 2 columns of the main head dim.
        q_desc = TensorDescriptor(
            q, shape=[SQ * HQ, DQK], strides=[DQK, 1], block_shape=[BH, DP // 2]
        )
        if HAVE_TAIL:
            # Same view of q, but with TDP-wide tiles for the tail columns.
            tq_desc = TensorDescriptor(
                q, shape=[SQ * HQ, DQK], strides=[DQK, 1], block_shape=[BH, TDP]
            )
        else:
            # No tail: reuse q_desc as a placeholder argument (presumably
            # unused by the kernel when HAVE_TAIL is False).
            tq_desc = q_desc
        # output viewed as [SQ * HQ, D]; tiles match q_desc's main-dim tiles.
        output_desc = TensorDescriptor(
            output, shape=[SQ * HQ, D], strides=[D, 1], block_shape=[BH, DP // 2]
        )
        # Positional arguments: order must match the kernel signature.
        _tle_flashmla_prefill_fwd[triton_grid](
            q_desc,
            tq_desc,
            output_desc,
            kv,
            indices,
            attn_sink,
            topk_length,
            sm_scale,
            output,
            max_logits,
            lse,
            SQ,
            HQ,
            DQK,
            SKV,
            TOPK,
            attn_sink is not None,
            topk_length is not None,
            D,
            TD,
            DP,
            TDP,
            G,
            HKV,
            RH,
            HAVE_TAIL,
            BK,
            BH,
            TLE_FLASHMLA_PREFILL_PAIR_BLOCKS,
            num_warps=TLE_FLASHMLA_PREFILL_WORKER_NUM_WARPS,
            num_stages=1,
        )
        return output, max_logits, lse

    # Fallback: autotuned kernel; BK/BH come from its autotune configs.
    triton_flash_mla_sparse_fwd[triton_grid](
        q,
        kv,
        indices,
        attn_sink,
        topk_length,
        sm_scale,
        output,
        max_logits,
        lse,
        q.stride(1),  # stride_qh
        q.stride(0),  # stride_qm
        kv.stride(1),  # stride_kvg
        kv.stride(0),  # stride_kvn
        indices.stride(1),  # stride_tg
        indices.stride(0),  # stride_tm
        output.stride(1),  # stride_oh
        output.stride(0),  # stride_om
        max_logits.stride(0),  # stride_mm
        lse.stride(0),  # stride_lm
        SQ,
        HQ,
        DQK,
        SKV,
        TOPK,
        attn_sink is not None,
        topk_length is not None,
    )
    return output, max_logits, lse