Coverage for src/flag_gems/fused/flashmla_sparse.py: 11%
501 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import os
2from typing import Optional, Tuple
4import torch
5import triton
6import triton.language as tl
8from flag_gems.utils.triton_version_utils import has_triton_tle
# Optional warp-specialized path: `tle` is bound to
# `triton.experimental.tle.language` only when `has_triton_tle(3, 6, 0)`
# reports support AND the import succeeds; otherwise `tle` stays None and
# HAS_TLE_FLASHMLA_SPARSE stays False, so only the plain Triton kernel is used.
tle = None
HAS_TLE_FLASHMLA_SPARSE = False
if has_triton_tle(3, 6, 0):
    try:
        import triton.experimental.tle.language as tle
    except ImportError:
        pass
    else:
        HAS_TLE_FLASHMLA_SPARSE = True

# Launch parameters of the TLE prefill kernel.
TLE_FLASHMLA_PREFILL_BK = 64  # KV indices per block (a "pair" covers 2 blocks)
TLE_FLASHMLA_PREFILL_BH = 64  # query heads handled by one program
TLE_FLASHMLA_PREFILL_PAIR_BLOCKS = 2  # rows of the per-pair validity buffer
TLE_FLASHMLA_PREFILL_WORKER_NUM_WARPS = 4  # passed as num_warps at launch
# Portable Triton kernel for sparse MLA prefill (used when the TLE path is not
# available). One program handles one query token and a block of BH query
# heads; the launch grid is cdiv(HQ, BH) * SQ.
#
# For its token the program walks the token's top-k KV indices in chunks of
# BK, gathers those KV rows, and runs an online softmax over
# (q . kv^T) * sm_scale. The first DP = 512 columns of each gathered KV row
# also serve as the value vector. Per (token, head) it writes:
#   output     [512] bf16 : softmax-weighted sum of the gathered value rows
#                          (optionally scaled by the attention sink factor)
#   max_logits float32    : largest scaled logit (-inf if no valid index)
#   lse        float32    : log-sum-exp of the scaled logits (+inf if none)
# Heads that see no valid index get an all-zero output row.
@triton.autotune(
    configs=[
        triton.Config({"BK": 64, "BH": 64}, num_warps=8, num_stages=2),
        triton.Config({"BK": 64, "BH": 64}, num_warps=8, num_stages=4),
    ],
    key=["SQ", "HQ", "DQK", "SKV", "TOPK", "HAVE_ATTN_SINK", "HAVE_TOPK_LENGTH"],
)
@triton.jit
def triton_flash_mla_sparse_fwd(
    q,
    kv,
    indices,
    attn_sink,
    topk_length,
    sm_scale: tl.constexpr,
    output,
    max_logits,
    lse,
    stride_qh,
    stride_qm,
    stride_kvg,
    stride_kvn,
    stride_tg,
    stride_tm,
    stride_oh,
    stride_om,
    stride_mm,
    stride_lm,
    SQ,  # s_q
    HQ: tl.constexpr,  # h_q=64 or 128
    DQK: tl.constexpr,  # d_qk=512 or 576
    SKV,  # s_kv
    TOPK: tl.constexpr,  # topk
    HAVE_ATTN_SINK: tl.constexpr,
    HAVE_TOPK_LENGTH: tl.constexpr,
    BK: tl.constexpr,
    BH: tl.constexpr,
):
    num_head_blocks: tl.constexpr = (HQ + BH - 1) // BH
    pid = tl.program_id(0)
    i_sq = pid // num_head_blocks
    i_sq = i_sq.to(tl.int64)  # prevent mul overflow
    i_gbh = pid % num_head_blocks
    gbh_base = i_gbh * BH
    # Width of the main part of q/kv (also the value width); when DQK == 576
    # the remaining 64 columns form a tail that only enters q . k.
    DP: tl.constexpr = 512
    BDP: tl.constexpr = 256  # DP is processed as two halves of BDP columns

    q_base = q + i_sq * stride_qm + gbh_base * stride_qh
    kv_base = kv
    tkv_base = kv + DP
    t_base = indices + i_sq * stride_tm
    attn_sink_ptr = attn_sink + gbh_base if HAVE_ATTN_SINK else 0
    topk_length_ptr = topk_length + i_sq if HAVE_TOPK_LENGTH else 0
    o_base = output + i_sq * stride_om + gbh_base * stride_oh
    max_log_base = max_logits + i_sq * stride_mm + gbh_base
    l_base = lse + i_sq * stride_lm + gbh_base

    offs_h = tl.arange(0, BH)
    offs_d = tl.arange(0, BDP)
    if DQK == 576:
        offs_td = tl.arange(0, 64)
    offs_t = tl.arange(0, BK)

    # `[BH, 256] x 2` delivers better performance than `[BH, 512]` when BH=64
    q_ptr = q_base + offs_h[:, None] * stride_qh + offs_d[None, :]
    q_blk0 = tl.load(q_ptr, eviction_policy="evict_first")
    q_blk1 = tl.load(q_ptr + BDP, eviction_policy="evict_first")
    if DQK == 576:
        tq_ptr = q_base + DP + offs_h[:, None] * stride_qh + offs_td[None, :]
        tq_blk = tl.load(tq_ptr, eviction_policy="evict_first")

    max_log = tl.full([BH], float("-inf"), dtype=tl.float32)
    sum_exp = tl.full([BH], 0.0, dtype=tl.float32)
    acc0 = tl.zeros([BH, BDP], dtype=tl.float32)
    acc1 = tl.zeros([BH, BDP], dtype=tl.float32)

    if HAVE_TOPK_LENGTH:
        # Clamp so a topk_length entry larger than TOPK cannot read past the
        # end of this token's index row.
        topk_len = tl.minimum(tl.load(topk_length_ptr), TOPK)
    else:
        topk_len = TOPK
    NK = tl.cdiv(topk_len, BK)
    for ck in range(NK):
        # step1: load indices; positions at or past topk_len read as -1
        t_offs = BK * ck + offs_t  # [BK]
        t_msk = t_offs < topk_len
        kv_ids = tl.load(t_base + t_offs, t_msk, other=-1)
        mask_ids = (kv_ids < SKV) & (kv_ids >= 0)
        # Point invalid indices at row 0 so the gather stays in bounds, and
        # form the row offset in int64: kv_ids * stride_kvn overflows int32
        # once s_kv * stride_kvn exceeds 2**31.
        kv_offs = tl.where(mask_ids, kv_ids, 0).to(tl.int64) * stride_kvn

        # step2: gather kv with indices
        kv_ptr = kv_base + offs_d[:, None] + kv_offs[None, :]
        kv_blk0 = tl.load(kv_ptr, cache_modifier=".cg")  # [BDP, BK]
        kv_blk1 = tl.load(kv_ptr + BDP, cache_modifier=".cg")  # [BDP, BK]
        # step3: (q @ kv) * sm_scale
        qk = tl.dot(
            q_blk0, kv_blk0, out_dtype=tl.float32
        )  # [BH, BDP]@[BDP, BK] -> [BH, BK]
        qk = tl.dot(q_blk1, kv_blk1, qk, out_dtype=tl.float32)
        if DQK == 576:
            tkv_ptr = tkv_base + offs_td[:, None] + kv_offs[None, :]
            tkv_blk = tl.load(tkv_ptr, cache_modifier=".cg")  # [TDP, BK]
            qk = tl.dot(tq_blk, tkv_blk, qk, out_dtype=tl.float32)
        qk *= sm_scale

        # step4: invalid indices contribute nothing to the softmax
        qk = tl.where(mask_ids[None, :], qk, float("-inf"))  # [BH, BK]
        # step5: lse=logsumexp(qk), loop part
        new_max = tl.maximum(max_log, tl.max(qk, axis=1))  # [BH]
        # While a head has not seen any valid index, new_max is -inf and the
        # differences below would be -inf - (-inf) = NaN, poisoning sum_exp
        # and the accumulators for every later chunk. Use 0 as the reference
        # then; all terms evaluate to exp(-inf) = 0 as intended.
        ref_max = tl.where(new_max == float("-inf"), 0.0, new_max)
        exp_qk = tl.math.exp(qk - ref_max[:, None])  # [BH, BK]
        sum_qk = tl.sum(exp_qk, axis=1)  # [BH]
        alpha = tl.math.exp(max_log - ref_max)  # [BH]
        sum_exp = sum_exp * alpha + sum_qk  # [BH]
        # step6: exp(qk-lse) @ gathered_kv.trans(), loop part
        acc0 = tl.dot(
            exp_qk.to(tl.bfloat16),
            kv_blk0.trans(),
            acc0 * alpha[:, None],
            out_dtype=tl.float32,
        )  # [BH, BK]@[BK, BDP]->[BH, BDP]
        acc1 = tl.dot(
            exp_qk.to(tl.bfloat16),
            kv_blk1.trans(),
            acc1 * alpha[:, None],
            out_dtype=tl.float32,
        )  # [BH, BK]@[BK, BDP]->[BH, BDP]
        max_log = new_max

    # step7: store max_logits (-inf for heads without any valid index)
    valid_mask = max_log != float("-inf")
    tl.store(max_log_base + offs_h, max_log)  # [BH], float32

    # step8: lse=logsumexp(qk) final part, store lse (+inf when no valid index)
    orig_lse = max_log + tl.math.log(sum_exp)
    lse_out = tl.where(valid_mask, orig_lse, float("inf"))
    tl.store(l_base + offs_h, lse_out)  # [BH], float32

    # step9: normalization factor for the accumulators
    if HAVE_ATTN_SINK:
        # step10: attn_sink scales the output by
        # exp(lse) / (exp(lse) + exp(sink)) == 1 / (1 + exp(sink - lse)).
        # The second form never builds exp(lse) or exp(max_log), which
        # overflow float32 for large logits and would turn the output into NaN.
        sink = tl.load(attn_sink_ptr + offs_h)  # [BH]
        factor = (1.0 / sum_exp) / (1.0 + tl.math.exp(sink - orig_lse))
    else:
        factor = 1.0 / sum_exp

    out_vals0 = tl.where(valid_mask[:, None], acc0 * factor[:, None], 0.0)
    out_vals1 = tl.where(valid_mask[:, None], acc1 * factor[:, None], 0.0)
    # step11: store output
    o_ptr = o_base + offs_h[:, None] * stride_oh + offs_d[None, :]  # [BH, BDP]
    tl.store(o_ptr, out_vals0.to(tl.bfloat16))
    tl.store(o_ptr + BDP, out_vals1.to(tl.bfloat16))
if HAS_TLE_FLASHMLA_SPARSE:
    # The warp-specialized kernels below use the optional `tle` module and are
    # only defined when HAS_TLE_FLASHMLA_SPARSE is true.

    @triton.jit
    def _tle_flashmla_prefill_producer(
        k0_l_writer,
        k0_r_writer,
        k1_l_writer,
        k1_r_writer,
        valid_writer,
        kv_base,
        tkv_base,
        t_base,
        topk_len_ptr,
        D: tl.constexpr,
        TD: tl.constexpr,
        DPH: tl.constexpr,
        TDP: tl.constexpr,
        VG: tl.constexpr,
        SKV,
        TOPK: tl.constexpr,
        HAVE_TOPK_LENGTH: tl.constexpr,
        HAVE_TAIL: tl.constexpr,
        BK: tl.constexpr,
    ):
        # Producer partition of the TLE prefill kernel.
        #
        # Walks this query token's index row two BK-wide blocks at a time
        # (block ck0 = 2 * pair and block ck1 = 2 * pair + 1). It gathers the
        # selected KV rows from global memory into the shared-memory slots
        # handed out by the pipe writers. Each block is split into a left half
        # (columns [0, DPH)) and a right half (columns [DPH, 2 * DPH)); with
        # HAVE_TAIL the right-half slot also receives the TD tail columns
        # (read from tkv_base). Slots are published in the order
        # K0-left, K1-right, K0-right, K1-left, then the per-index validity
        # flags of both blocks (row 0 -> ck0, row 1 -> ck1).
        #
        # Args:
        #   *_writer: pipe writer handles (acquire/commit per pair).
        #   kv_base, tkv_base: KV pointer for this KV head group, and the same
        #     pointer advanced by D (start of the tail columns).
        #   t_base: start of this token's index row.
        #   topk_len_ptr: per-token length, read only when HAVE_TOPK_LENGTH.
        #   SKV: number of KV rows; indices outside [0, SKV) are invalid.
        topk_len = tl.load(topk_len_ptr) if HAVE_TOPK_LENGTH else TOPK
        max_col = SKV - 1
        # Elements between consecutive KV rows (VG groups of TD + D values).
        stride_kvn: tl.constexpr = VG * (TD + D)
        NK = tl.cdiv(topk_len, BK)
        NPAIRS = tl.cdiv(NK, 2)
        offs_t = tl.arange(0, BK)
        offs_tile = tl.arange(0, 64)
        kv_tile_rows = tl.broadcast_to(offs_t[:, None], (BK, 64))
        # NOTE on shared-memory stores below: the global loads are masked by
        # index validity (other=0.0), but the shared-memory stores are masked
        # by column only. Skipping invalid rows would leave stale or
        # uninitialized data in the slot (e.g. the whole K1 slot when
        # topk_len <= BK). The consumers multiply that data by zero
        # probabilities, so any inf/NaN bit pattern there would turn the
        # output into NaN.
        for pair in tl.range(NPAIRS):
            ck0 = pair * 2
            ck1 = ck0 + 1
            t_offs0 = BK * ck0 + offs_t
            t_msk0 = t_offs0 < topk_len
            kv_ids0 = tl.load(t_base + t_offs0, t_msk0, other=-1)
            valid0 = t_msk0 & (kv_ids0 <= max_col) & (kv_ids0 >= 0)
            kv_offsets0 = tl.where(valid0, kv_ids0, 0).to(tl.int64) * stride_kvn

            t_offs1 = BK * ck1 + offs_t
            t_msk1 = t_offs1 < topk_len
            kv_ids1 = tl.load(t_base + t_offs1, t_msk1, other=-1)
            valid1 = t_msk1 & (kv_ids1 <= max_col) & (kv_ids1 >= 0)
            kv_offsets1 = tl.where(valid1, kv_ids1, 0).to(tl.int64) * stride_kvn

            # Block ck0, left half.
            k0_l_slot = k0_l_writer.acquire(pair)
            for tile in tl.static_range(0, DPH, 64):
                k_cols = tile + offs_tile
                k_cols_b = tl.broadcast_to(k_cols[None, :], (BK, 64))
                k0_l_ptr = kv_base + kv_offsets0[:, None] + k_cols[None, :]
                k0_l_msk = valid0[:, None] & (k_cols < D)[None, :]
                k0_l_blk = tl.load(
                    k0_l_ptr,
                    mask=k0_l_msk,
                    other=0.0,
                    eviction_policy="evict_last",
                )
                tl.store(
                    tle.gpu.local_ptr(k0_l_slot.sK, (kv_tile_rows, k_cols_b)),
                    k0_l_blk,
                    mask=k_cols_b < D,
                )
            k0_l_writer.commit(pair)

            # Block ck1, right half (+ tail).
            k1_r_slot = k1_r_writer.acquire(pair)
            for tile in tl.static_range(0, DPH, 64):
                k_cols = DPH + tile + offs_tile
                k_cols_b = tl.broadcast_to(k_cols[None, :], (BK, 64))
                k1_r_ptr = kv_base + kv_offsets1[:, None] + k_cols[None, :]
                k1_r_msk = valid1[:, None] & (k_cols < D)[None, :]
                k1_r_blk = tl.load(
                    k1_r_ptr,
                    mask=k1_r_msk,
                    other=0.0,
                    eviction_policy="evict_last",
                )
                tl.store(
                    tle.gpu.local_ptr(k1_r_slot.sK, (kv_tile_rows, k_cols_b)),
                    k1_r_blk,
                    mask=k_cols_b < D,
                )
            if HAVE_TAIL:
                offs_td = tl.arange(0, TDP)
                k1_r_tail_ptr = tkv_base + kv_offsets1[:, None] + offs_td[None, :]
                k1_r_tail_msk = valid1[:, None] & (offs_td < TD)[None, :]
                k1_r_tail_blk = tl.load(
                    k1_r_tail_ptr,
                    mask=k1_r_tail_msk,
                    other=0.0,
                    eviction_policy="evict_last",
                )
                tl.store(
                    tle.gpu.local_ptr(k1_r_slot.sK_tail),
                    k1_r_tail_blk,
                    mask=tl.broadcast_to((offs_td < TD)[None, :], (BK, TDP)),
                )
            k1_r_writer.commit(pair)

            # Block ck0, right half (+ tail).
            k0_r_slot = k0_r_writer.acquire(pair)
            for tile in tl.static_range(0, DPH, 64):
                k_cols = DPH + tile + offs_tile
                k_cols_b = tl.broadcast_to(k_cols[None, :], (BK, 64))
                k0_r_ptr = kv_base + kv_offsets0[:, None] + k_cols[None, :]
                k0_r_msk = valid0[:, None] & (k_cols < D)[None, :]
                k0_r_blk = tl.load(
                    k0_r_ptr,
                    mask=k0_r_msk,
                    other=0.0,
                    eviction_policy="evict_last",
                )
                tl.store(
                    tle.gpu.local_ptr(k0_r_slot.sK, (kv_tile_rows, k_cols_b)),
                    k0_r_blk,
                    mask=k_cols_b < D,
                )
            if HAVE_TAIL:
                offs_td = tl.arange(0, TDP)
                k0_r_tail_ptr = tkv_base + kv_offsets0[:, None] + offs_td[None, :]
                k0_r_tail_msk = valid0[:, None] & (offs_td < TD)[None, :]
                k0_r_tail_blk = tl.load(
                    k0_r_tail_ptr,
                    mask=k0_r_tail_msk,
                    other=0.0,
                    eviction_policy="evict_last",
                )
                tl.store(
                    tle.gpu.local_ptr(k0_r_slot.sK_tail),
                    k0_r_tail_blk,
                    mask=tl.broadcast_to((offs_td < TD)[None, :], (BK, TDP)),
                )
            k0_r_writer.commit(pair)

            # Block ck1, left half.
            k1_l_slot = k1_l_writer.acquire(pair)
            for tile in tl.static_range(0, DPH, 64):
                k_cols = tile + offs_tile
                k_cols_b = tl.broadcast_to(k_cols[None, :], (BK, 64))
                k1_l_ptr = kv_base + kv_offsets1[:, None] + k_cols[None, :]
                k1_l_msk = valid1[:, None] & (k_cols < D)[None, :]
                k1_l_blk = tl.load(
                    k1_l_ptr,
                    mask=k1_l_msk,
                    other=0.0,
                    eviction_policy="evict_last",
                )
                tl.store(
                    tle.gpu.local_ptr(k1_l_slot.sK, (kv_tile_rows, k_cols_b)),
                    k1_l_blk,
                    mask=k_cols_b < D,
                )
            k1_l_writer.commit(pair)

            # Validity flags: row 0 for block ck0, row 1 for block ck1.
            valid_slot = valid_writer.acquire(pair)
            valid_row0 = tl.full([BK], 0, dtype=tl.int32)
            valid_row1 = tl.full([BK], 1, dtype=tl.int32)
            valid_ptr0 = tle.gpu.local_ptr(valid_slot.is_kv_valid, (valid_row0, offs_t))
            valid_ptr1 = tle.gpu.local_ptr(valid_slot.is_kv_valid, (valid_row1, offs_t))
            tl.store(valid_ptr0, valid0.to(tl.int8))
            tl.store(valid_ptr1, valid1.to(tl.int8))
            valid_writer.commit(pair)
    @triton.jit
    def _tle_flashmla_prefill_consumer0(
        q_writer,
        q_reader,
        q_desc,
        tq_desc,
        k0_l_reader,
        k0_r_qk_reader,
        k1_l_remote_reader,
        valid_reader,
        sM_wg0_writer,
        sM_wg1_reader,
        sS0_writer,
        sS1_reader,
        sL_wg0_writer,
        sL_wg1_reader,
        output_desc,
        output_row,
        h_base,
        topk_len_ptr,
        attn_sink_base,
        log_scale: tl.constexpr,
        D: tl.constexpr,
        TD: tl.constexpr,
        OUT_DTYPE: tl.constexpr,
        HAVE_ATTN_SINK: tl.constexpr,
        TOPK: tl.constexpr,
        HAVE_TOPK_LENGTH: tl.constexpr,
        HAVE_TAIL: tl.constexpr,
        BK: tl.constexpr,
        BH: tl.constexpr,
        DPH: tl.constexpr,
        TDP: tl.constexpr,
        G: tl.constexpr,
    ):
        # Consumer partition "wg0" of the TLE prefill kernel.
        #
        # Owns block ck0 = 2 * pair of every pair and the left half
        # (columns [0, DPH)) of the output:
        #   1. Copies the BH query rows (left half, right half and, with
        #      HAVE_TAIL, the TDP tail) from q_desc into shared memory once and
        #      publishes them to both consumers through the one-shot q pipe.
        #   2. Per pair: qk0 = Q . K0^T over all columns; invalid KV indices
        #      become -inf. Online softmax is done in base 2
        #      (log_scale = sm_scale * log2(e)). The running row max is sent
        #      to wg1 (sM_wg0), and the combined max of both blocks is read
        #      back (sM_wg1).
        #   3. acc_l accumulates P0 @ K0[:, :DPH] + P1 @ K1[:, :DPH]; the KV
        #      rows double as value rows. P0 (rescaled to the combined max)
        #      goes to wg1 via sS0; P1 arrives from wg1 via sS1.
        #   4. After the loop, exchanges softmax denominators with wg1 (sL),
        #      normalises, applies the optional attention sink, and writes the
        #      left output half through the sQ_l buffer and output_desc.
        topk_len = tl.load(topk_len_ptr) if HAVE_TOPK_LENGTH else TOPK
        offs_h = tl.arange(0, BH)
        offs_dh = tl.arange(0, DPH)
        mask_h = h_base + offs_h < G
        mask_od_l = offs_dh < D
        kv_rows = tl.broadcast_to(tl.arange(0, BK)[:, None], (BK, DPH))
        kv_cols_l = tl.broadcast_to(offs_dh[None, :], (BK, DPH))
        kv_cols_r = tl.broadcast_to((DPH + offs_dh)[None, :], (BK, DPH))

        # Stage Q in shared memory (this partition is the q pipe's writer).
        q_write_slot = q_writer.acquire(0)
        tle.gpu.copy(q_desc, q_write_slot.sQ_l, [BH, DPH], [output_row, 0])
        tle.gpu.copy(q_desc, q_write_slot.sQ_r, [BH, DPH], [output_row, DPH])
        if HAVE_TAIL:
            tle.gpu.copy(tq_desc, q_write_slot.sQ_tail, [BH, TDP], [output_row, D])
        q_writer.commit(0)

        q_slot = q_reader.wait(0).slot
        q_l_smem_ptr = tle.gpu.local_ptr(q_slot.sQ_l)
        q_r_smem_ptr = tle.gpu.local_ptr(q_slot.sQ_r)
        # Finite "minus infinity": keeps (max_prev - local_max) finite while a
        # row has seen only invalid (-inf) scores, avoiding -inf - (-inf).
        max_prev = tl.full([BH], -1.0e30, dtype=tl.float32)
        sum_exp = tl.full([BH], 0.0, dtype=tl.float32)
        acc_l = tl.zeros([BH, DPH], dtype=tl.float32)

        NK = tl.cdiv(topk_len, BK)
        NPAIRS = tl.cdiv(NK, 2)

        for pair in tl.range(NPAIRS):
            k0_l_wait = k0_l_reader.wait(pair)
            k0_l_slot = k0_l_wait.slot

            q_l_blk = tl.load(q_l_smem_ptr)
            q_r_blk = tl.load(q_r_smem_ptr)
            k0_l_blk = tl.load(tle.gpu.local_ptr(k0_l_slot.sK, (kv_rows, kv_cols_l)))

            # Scores of block ck0: left half, right half, optional tail.
            qk0 = tl.full([BH, BK], 0.0, dtype=tl.float32)
            qk0 = tl.dot(q_l_blk, tl.trans(k0_l_blk), qk0, out_dtype=tl.float32)

            k0_r_wait = k0_r_qk_reader.wait(pair)
            k0_r_slot = k0_r_wait.slot
            k0_r_blk = tl.load(tle.gpu.local_ptr(k0_r_slot.sK, (kv_rows, kv_cols_r)))
            qk0 = tl.dot(q_r_blk, tl.trans(k0_r_blk), qk0, out_dtype=tl.float32)
            if HAVE_TAIL:
                q_tail_blk = tl.load(tle.gpu.local_ptr(q_slot.sQ_tail))
                k0_t_blk = tl.load(tle.gpu.local_ptr(k0_r_slot.sK_tail))
                qk0 = tl.dot(q_tail_blk, tl.trans(k0_t_blk), qk0, out_dtype=tl.float32)

            # Row 0 of the validity buffer belongs to block ck0.
            valid_wait = valid_reader.wait(pair)
            row0 = tl.full([BK], 0, dtype=tl.int32)
            valid0 = (
                tl.load(
                    tle.gpu.local_ptr(
                        valid_wait.slot.is_kv_valid, (row0, tl.arange(0, BK))
                    )
                )
                != 0
            )
            qk0 = tl.where(valid0[None, :], qk0, float("-inf"))
            valid_reader.release(pair)

            # Online softmax update against this partition's own running max.
            local_max = tl.maximum(max_prev, tl.max(qk0, axis=1))
            alpha = tl.math.exp2((max_prev - local_max) * log_scale)
            prob0 = tl.math.exp2(qk0 * log_scale - local_max[:, None] * log_scale)
            sum_exp = sum_exp * alpha + tl.sum(prob0, axis=1)
            acc_l = acc_l * alpha[:, None]
            prob0_b = prob0.to(OUT_DTYPE)

            # Publish the running max so wg1 can form the combined max.
            sM_wg0_slot = sM_wg0_writer.acquire(pair)
            tl.store(tle.gpu.local_ptr(sM_wg0_slot.sM), local_max)
            sM_wg0_writer.commit(pair)

            # K0 left half is re-read from shared memory as the value operand.
            k0_l_blk = tl.load(tle.gpu.local_ptr(k0_l_slot.sK, (kv_rows, kv_cols_l)))
            acc_l = tl.dot(prob0_b, k0_l_blk, acc_l, out_dtype=tl.float32)
            k0_l_reader.release(pair)
            k0_r_qk_reader.release(pair)

            # Combined max of both blocks, computed by wg1.
            sM_wg1_wait = sM_wg1_reader.wait(pair)
            max_next = tl.load(tle.gpu.local_ptr(sM_wg1_wait.slot.sM))
            sM_wg1_reader.release(pair)

            # Rescale everything from local_max to the combined max.
            final_scale = tl.math.exp2((local_max - max_next) * log_scale)
            sum_exp = sum_exp * final_scale
            acc_l = acc_l * final_scale[:, None]

            # Hand P0 (relative to the combined max) to wg1 for its right half.
            prob0_scaled = prob0 * final_scale[:, None]
            sS0_slot = sS0_writer.acquire(pair)
            tl.store(tle.gpu.local_ptr(sS0_slot.sS0), prob0_scaled.to(OUT_DTYPE))
            sS0_writer.commit(pair)

            # P1 from wg1 times the left half of K1.
            sS1_wait = sS1_reader.wait(pair)
            prob1 = tl.load(tle.gpu.local_ptr(sS1_wait.slot.sS1))
            k1_l_wait = k1_l_remote_reader.wait(pair)
            k1_l_blk = tl.load(
                tle.gpu.local_ptr(k1_l_wait.slot.sK, (kv_rows, kv_cols_l))
            )
            acc_l = tl.dot(prob1, k1_l_blk, acc_l, out_dtype=tl.float32)
            sS1_reader.release(pair)
            k1_l_remote_reader.release(pair)

            max_prev = max_next

        # Exchange partial denominators: wg0 uses sL slot 0, wg1 slot 1.
        sL_wg0_slot = sL_wg0_writer.acquire(0)
        tl.store(tle.gpu.local_ptr(sL_wg0_slot.sL), sum_exp)
        sL_wg0_writer.commit(0)
        sL_wg1_wait = sL_wg1_reader.wait(1)
        peer_sum = tl.load(tle.gpu.local_ptr(sL_wg1_wait.slot.sL))
        total_sum = sum_exp + peer_sum
        sL_wg1_reader.release(1)

        # A zero denominator means no valid index for that head; those rows
        # are written as zeros below.
        is_no_valid_tokens = total_sum == 0.0
        inv_total_sum = tl.fdiv(1.0, total_sum)
        out_l_vals = acc_l * inv_total_sum[:, None]
        if HAVE_ATTN_SINK:
            # Natural-log lse (0.693... = ln 2), then scale by
            # 1 / (1 + exp(sink - lse)).
            fin_log = (
                max_prev * log_scale + tl.math.log2(total_sum)
            ) * 0.6931471805599453
            sink = tl.load(attn_sink_base + h_base + offs_h, mask_h, other=0.0)
            sink_scale = tl.fdiv(1.0, 1.0 + tl.math.exp(sink - fin_log))
            out_l_vals = out_l_vals * sink_scale[:, None]
        out_l_vals = tl.where(is_no_valid_tokens[:, None], 0.0, out_l_vals)
        o_l_msk = mask_h[:, None] & mask_od_l[None, :]
        # The sQ_l buffer is reused as the staging area for the output copy;
        # wg1 has already published its sL value, i.e. it has left its main
        # loop, which is the only place it reads Q.
        tl.store(q_l_smem_ptr, out_l_vals.to(OUT_DTYPE), o_l_msk)
        tle.gpu.copy(q_slot.sQ_l, output_desc, [BH, DPH], [output_row, 0])
    @triton.jit
    def _tle_flashmla_prefill_consumer1(
        q_reader,
        k1_r_reader,
        k1_l_qk_reader,
        k0_r_remote_reader,
        valid_reader,
        sM_wg1_writer,
        sM_wg0_reader,
        sS1_writer,
        sS0_reader,
        sL_wg1_writer,
        sL_wg0_reader,
        final_max_logits_smem,
        final_lse_smem,
        output_desc,
        output_row,
        max_logits_base,
        l_base,
        h_base,
        topk_len_ptr,
        attn_sink_base,
        log_scale: tl.constexpr,
        D: tl.constexpr,
        TD: tl.constexpr,
        OUT_DTYPE: tl.constexpr,
        HAVE_ATTN_SINK: tl.constexpr,
        TOPK: tl.constexpr,
        HAVE_TOPK_LENGTH: tl.constexpr,
        HAVE_TAIL: tl.constexpr,
        BK: tl.constexpr,
        BH: tl.constexpr,
        DPH: tl.constexpr,
        TDP: tl.constexpr,
        G: tl.constexpr,
    ):
        # Consumer partition "wg1" of the TLE prefill kernel.
        #
        # Owns block ck1 = 2 * pair + 1 of every pair and the right half
        # (columns [DPH, 2 * DPH)) of the output; it also writes max_logits and
        # lse for the BH heads of this program.
        #   1. Per pair: qk1 = Q . K1^T over all columns; invalid KV indices
        #      become -inf. Combines its running max with wg0's (sM_wg0) into
        #      max_next and publishes it back (sM_wg1), so both partitions
        #      share one softmax reference.
        #   2. acc_r accumulates P1 @ K1[:, DPH:] + P0 @ K0[:, DPH:]; P1 goes
        #      to wg0 via sS1, and P0 arrives from wg0 via sS0.
        #   3. After the loop, exchanges denominators with wg0 (sL), normalises,
        #      applies the optional attention sink, writes the right output
        #      half through the sQ_r buffer, then stores max_logits and lse
        #      (natural log; -inf / +inf for heads without a valid index).
        topk_len = tl.load(topk_len_ptr) if HAVE_TOPK_LENGTH else TOPK
        offs_h = tl.arange(0, BH)
        offs_dh = tl.arange(0, DPH)
        mask_h = h_base + offs_h < G
        mask_od_r = DPH + offs_dh < D
        kv_rows = tl.broadcast_to(tl.arange(0, BK)[:, None], (BK, DPH))
        kv_cols_l = tl.broadcast_to(offs_dh[None, :], (BK, DPH))
        kv_cols_r = tl.broadcast_to((DPH + offs_dh)[None, :], (BK, DPH))
        # Q is staged by wg0; only read it here.
        q_slot = q_reader.wait(0).slot
        q_l_smem_ptr = tle.gpu.local_ptr(q_slot.sQ_l)
        q_r_smem_ptr = tle.gpu.local_ptr(q_slot.sQ_r)
        # Finite sentinel (same as wg0) so exp2 arguments never become NaN.
        max_prev = tl.full([BH], -1.0e30, dtype=tl.float32)
        sum_exp = tl.full([BH], 0.0, dtype=tl.float32)
        acc_r = tl.zeros([BH, DPH], dtype=tl.float32)

        NK = tl.cdiv(topk_len, BK)
        NPAIRS = tl.cdiv(NK, 2)
        for pair in tl.range(NPAIRS):
            k1_r_wait = k1_r_reader.wait(pair)
            k1_r_slot = k1_r_wait.slot

            q_l_blk = tl.load(q_l_smem_ptr)
            q_r_blk = tl.load(q_r_smem_ptr)
            # Kept in registers: used for the scores and later as values.
            k1_r_blk = tl.load(tle.gpu.local_ptr(k1_r_slot.sK, (kv_rows, kv_cols_r)))

            # Scores of block ck1: right half, optional tail, left half.
            qk1 = tl.full([BH, BK], 0.0, dtype=tl.float32)
            qk1 = tl.dot(q_r_blk, tl.trans(k1_r_blk), qk1, out_dtype=tl.float32)
            if HAVE_TAIL:
                q_tail_blk = tl.load(tle.gpu.local_ptr(q_slot.sQ_tail))
                k1_t_blk = tl.load(tle.gpu.local_ptr(k1_r_slot.sK_tail))
                qk1 = tl.dot(q_tail_blk, tl.trans(k1_t_blk), qk1, out_dtype=tl.float32)
            k1_l_wait = k1_l_qk_reader.wait(pair)
            k1_l_slot = k1_l_wait.slot
            k1_l_blk = tl.load(tle.gpu.local_ptr(k1_l_slot.sK, (kv_rows, kv_cols_l)))
            qk1 = tl.dot(q_l_blk, tl.trans(k1_l_blk), qk1, out_dtype=tl.float32)

            # Row 1 of the validity buffer belongs to block ck1.
            valid_wait = valid_reader.wait(pair)
            row1 = tl.full([BK], 1, dtype=tl.int32)
            valid1 = (
                tl.load(
                    tle.gpu.local_ptr(
                        valid_wait.slot.is_kv_valid, (row1, tl.arange(0, BK))
                    )
                )
                != 0
            )
            qk1 = tl.where(valid1[None, :], qk1, float("-inf"))
            valid_reader.release(pair)

            # wg0's running max for this pair.
            sM_wg0_wait = sM_wg0_reader.wait(pair)
            candidate0 = tl.load(tle.gpu.local_ptr(sM_wg0_wait.slot.sM))
            sM_wg0_reader.release(pair)

            # Combined max of both blocks, shared back with wg0.
            candidate1 = tl.maximum(max_prev, tl.max(qk1, axis=1))
            max_next = tl.maximum(candidate1, candidate0)
            sM_wg1_slot = sM_wg1_writer.acquire(pair)
            tl.store(tle.gpu.local_ptr(sM_wg1_slot.sM), max_next)
            sM_wg1_writer.commit(pair)

            # Online softmax update directly against the combined max.
            alpha = tl.math.exp2((max_prev - max_next) * log_scale)
            prob1 = tl.math.exp2(qk1 * log_scale - max_next[:, None] * log_scale)
            sum_exp = sum_exp * alpha + tl.sum(prob1, axis=1)
            acc_r = acc_r * alpha[:, None]
            prob1_b = prob1.to(OUT_DTYPE)

            k1_l_qk_reader.release(pair)

            acc_r = tl.dot(prob1_b, k1_r_blk, acc_r, out_dtype=tl.float32)

            # Hand P1 to wg0 for its left half.
            sS1_slot = sS1_writer.acquire(pair)
            tl.store(tle.gpu.local_ptr(sS1_slot.sS1), prob1_b)
            sS1_writer.commit(pair)

            # P0 from wg0 times the right half of K0.
            sS0_wait = sS0_reader.wait(pair)
            prob0 = tl.load(tle.gpu.local_ptr(sS0_wait.slot.sS0))
            k0_r_wait = k0_r_remote_reader.wait(pair)
            k0_r_blk = tl.load(
                tle.gpu.local_ptr(k0_r_wait.slot.sK, (kv_rows, kv_cols_r))
            )
            acc_r = tl.dot(prob0, k0_r_blk, acc_r, out_dtype=tl.float32)
            k1_r_reader.release(pair)
            sS0_reader.release(pair)
            k0_r_remote_reader.release(pair)
            max_prev = max_next

        # Exchange partial denominators: wg1 uses sL slot 1, wg0 slot 0.
        sL_wg1_slot = sL_wg1_writer.acquire(1)
        tl.store(tle.gpu.local_ptr(sL_wg1_slot.sL), sum_exp)
        sL_wg1_writer.commit(1)
        sL_wg0_wait = sL_wg0_reader.wait(0)
        peer_sum = tl.load(tle.gpu.local_ptr(sL_wg0_wait.slot.sL))
        total_sum = sum_exp + peer_sum
        sL_wg0_reader.release(0)

        is_no_valid_tokens = total_sum == 0.0
        inv_total_sum = tl.fdiv(1.0, total_sum)
        out_r_vals = acc_r * inv_total_sum[:, None]
        # Convert from base-2 units back to natural log (0.693... = ln 2);
        # final_max_logits equals max_prev * sm_scale.
        final_max_logits_log2 = max_prev * log_scale
        final_max_logits = final_max_logits_log2 * 0.6931471805599453
        fin_log = (final_max_logits_log2 + tl.math.log2(total_sum)) * 0.6931471805599453
        if HAVE_ATTN_SINK:
            sink = tl.load(attn_sink_base + h_base + offs_h, mask_h, other=0.0)
            sink_scale = tl.fdiv(1.0, 1.0 + tl.math.exp(sink - fin_log))
            out_r_vals = out_r_vals * sink_scale[:, None]
        out_r_vals = tl.where(is_no_valid_tokens[:, None], 0.0, out_r_vals)
        o_r_msk = mask_h[:, None] & mask_od_r[None, :]
        # Reuse the sQ_r buffer as the staging area for the output copy; wg0
        # has already published its sL value, i.e. finished reading Q.
        tl.store(q_r_smem_ptr, out_r_vals.to(OUT_DTYPE), o_r_msk)
        tle.gpu.copy(q_slot.sQ_r, output_desc, [BH, DPH], [output_row, DPH])

        # Heads without any valid index report max_logits = -inf, lse = +inf.
        final_max_logits = tl.where(is_no_valid_tokens, float("-inf"), final_max_logits)
        fin_log = tl.where(is_no_valid_tokens, float("inf"), fin_log)
        # Round-trip through shared memory before the global stores
        # (NOTE(review): presumably a layout conversion; confirm intent).
        tl.store(tle.gpu.local_ptr(final_max_logits_smem), final_max_logits, mask_h)
        tl.store(tle.gpu.local_ptr(final_lse_smem), fin_log, mask_h)
        final_max_logits = tl.load(
            tle.gpu.local_ptr(final_max_logits_smem), mask_h, other=float("-inf")
        )
        fin_log = tl.load(tle.gpu.local_ptr(final_lse_smem), mask_h, other=float("inf"))
        tl.store(max_logits_base + offs_h, final_max_logits, mask_h)
        tl.store(l_base + offs_h, fin_log, mask_h)
    @triton.jit
    def _tle_flashmla_prefill_fwd(
        q_desc,
        tq_desc,
        output_desc,
        kv,
        indices,
        attn_sink,
        topk_length,
        sm_scale: tl.constexpr,
        output,
        max_logits,
        lse,
        SQ,
        H: tl.constexpr,
        DQK: tl.constexpr,
        SKV,
        TOPK: tl.constexpr,
        HAVE_ATTN_SINK: tl.constexpr,
        HAVE_TOPK_LENGTH: tl.constexpr,
        D: tl.constexpr,
        TD: tl.constexpr,
        DP: tl.constexpr,
        TDP: tl.constexpr,
        G: tl.constexpr,
        VG: tl.constexpr,
        RH: tl.constexpr,
        HAVE_TAIL: tl.constexpr,
        BK: tl.constexpr,
        BH: tl.constexpr,
        PAIR_BLOCKS: tl.constexpr,
    ):
        # Warp-specialized sparse MLA prefill (TLE path).
        #
        # One program per (query token i_sq, KV head group i_g, block i_rh of
        # BH query heads): pid = (i_sq * VG + i_g) * RH + i_rh. The program
        # allocates the shared-memory buffers and wraps them in TLE pipes. It
        # then runs three partitions: consumer0 ("wg0", left output half),
        # consumer1 ("wg1", right output half + max_logits/lse) and the
        # producer (gathers the indexed KV rows).
        #
        # q_desc / tq_desc / output_desc are tensor descriptors over q and
        # output viewed as [SQ * H, ...] rows. D/TD are the main and tail
        # widths of a KV row, DP/TDP their padded sizes, G query heads per KV
        # group, VG KV groups, RH = G // BH head blocks per group.
        DPH: tl.constexpr = DP // 2
        # Strides assume contiguous kv [SKV, VG, TD + D], indices
        # [SQ, VG, TOPK] and max_logits/lse [SQ, H] (as asserted/allocated by
        # flash_mla_sparse_fwd).
        stride_kvg: tl.constexpr = TD + D
        stride_tg = TOPK
        stride_tm = VG * stride_tg
        stride_lm = H
        stride_mm = H

        pid = tl.program_id(0)
        programs_per_q: tl.constexpr = VG * RH
        i_sq = pid // programs_per_q
        i_grh = pid % programs_per_q
        i_g = i_grh // RH
        i_rh = i_grh % RH
        h_base = i_rh * BH
        q_head_base = i_g * G + h_base
        # 64-bit offsets for the global pointer arithmetic below.
        i_sq64 = i_sq.to(tl.int64)
        i_g64 = i_g.to(tl.int64)
        q_head_base64 = q_head_base.to(tl.int64)
        kv_base = kv + i_g64 * stride_kvg
        tkv_base = kv_base + D
        t_base = indices + i_sq64 * stride_tm + i_g64 * stride_tg
        # Placeholder pointers when the optional inputs are absent; the
        # partitions only dereference them under HAVE_TOPK_LENGTH /
        # HAVE_ATTN_SINK.
        topk_len_ptr = topk_length + i_sq64 if HAVE_TOPK_LENGTH else indices
        attn_sink_base = attn_sink if HAVE_ATTN_SINK else max_logits
        max_logits_base = max_logits + i_sq64 * stride_mm + q_head_base64
        l_base = lse + i_sq64 * stride_lm + q_head_base64
        # Row of the first query head of this program in the [SQ * H, ...]
        # descriptor views.
        q_row = i_sq * H + q_head_base
        # Unused here; kept for a uniform launch signature.
        _ = output
        _ = SQ
        _ = DQK

        # --- Q staging buffers (written once by wg0, read by wg0 and wg1) ---
        sQ_l_smem = tle.gpu.alloc(
            [1, BH, DPH], dtype=kv.dtype.element_ty, layout=None, scope=tle.gpu.smem
        )
        sQ_r_smem = tle.gpu.alloc(
            [1, BH, DPH], dtype=kv.dtype.element_ty, layout=None, scope=tle.gpu.smem
        )
        if HAVE_TAIL:
            sQ_tail_smem = tle.gpu.alloc(
                [1, BH, TDP],
                dtype=kv.dtype.element_ty,
                layout=None,
                scope=tle.gpu.smem,
            )
            q_pipe = tle.pipe(
                capacity=1,
                scope="cta",
                name="flashmla_sQ",
                readers=("wg0", "wg1"),
                one_shot=True,
                sQ_l=sQ_l_smem,
                sQ_r=sQ_r_smem,
                sQ_tail=sQ_tail_smem,
            )
        else:
            q_pipe = tle.pipe(
                capacity=1,
                scope="cta",
                name="flashmla_sQ",
                readers=("wg0", "wg1"),
                one_shot=True,
                sQ_l=sQ_l_smem,
                sQ_r=sQ_r_smem,
            )

        # --- KV buffers: one [BK, DP] tile per block of the pair ---
        # The left/right pipes of a block share the same tile; they hand out
        # different column halves of it.
        sK0_smem = tle.gpu.alloc(
            [1, BK, DP], dtype=kv.dtype.element_ty, layout=None, scope=tle.gpu.smem
        )
        sK1_smem = tle.gpu.alloc(
            [1, BK, DP], dtype=kv.dtype.element_ty, layout=None, scope=tle.gpu.smem
        )
        if HAVE_TAIL:
            sK0_tail_smem = tle.gpu.alloc(
                [1, BK, TDP],
                dtype=kv.dtype.element_ty,
                layout=None,
                scope=tle.gpu.smem,
            )
            sK1_tail_smem = tle.gpu.alloc(
                [1, BK, TDP],
                dtype=kv.dtype.element_ty,
                layout=None,
                scope=tle.gpu.smem,
            )
            # P0 ([BH, BK]) is written into the K0 tail buffer ([BK, TDP]);
            # this only fits because BH == BK == TDP (64) here.
            # NOTE(review): revisit if the tile sizes change.
            sS0_smem = sK0_tail_smem
        else:
            sS0_smem = tle.gpu.alloc(
                [1, BH, BK],
                dtype=kv.dtype.element_ty,
                layout=None,
                scope=tle.gpu.smem,
            )
        # Per-pair validity flags: row 0 -> block ck0, row 1 -> block ck1.
        is_kv_valid_smem = tle.gpu.alloc(
            [1, PAIR_BLOCKS, BK],
            dtype=tl.int8,
            layout=None,
            scope=tle.gpu.smem,
            nv_mma_shared_layout=False,
        )
        k0_l_pipe = tle.pipe(
            capacity=1, scope="cta", name="flashmla_sK0_l", sK=sK0_smem
        )
        if HAVE_TAIL:
            k0_r_pipe = tle.pipe(
                capacity=1,
                scope="cta",
                name="flashmla_sK0_r",
                readers=("qk", "remote"),
                sK=sK0_smem,
                sK_tail=sK0_tail_smem,
            )
        else:
            k0_r_pipe = tle.pipe(
                capacity=1,
                scope="cta",
                name="flashmla_sK0_r",
                readers=("qk", "remote"),
                sK=sK0_smem,
            )
        k1_l_pipe = tle.pipe(
            capacity=1,
            scope="cta",
            name="flashmla_sK1_l",
            readers=("qk", "remote"),
            sK=sK1_smem,
        )
        if HAVE_TAIL:
            k1_r_pipe = tle.pipe(
                capacity=1,
                scope="cta",
                name="flashmla_sK1_r",
                sK=sK1_smem,
                sK_tail=sK1_tail_smem,
            )
        else:
            k1_r_pipe = tle.pipe(
                capacity=1,
                scope="cta",
                name="flashmla_sK1_r",
                sK=sK1_smem,
            )
        is_kv_valid_pipe = tle.pipe(
            capacity=1,
            scope="cta",
            name="flashmla_is_kv_valid_ready",
            readers=("wg0", "wg1"),
            is_kv_valid=is_kv_valid_smem,
        )

        # --- wg0 <-> wg1 exchange buffers ---
        # sM: running row max (one buffer shared by both directions).
        sM_smem = tle.gpu.alloc(
            [1, BH],
            dtype=tl.float32,
            layout=None,
            scope=tle.gpu.smem,
            nv_mma_shared_layout=False,
        )
        # sS1: wg1's probabilities P1 for wg0.
        sS1_smem = tle.gpu.alloc(
            [1, BH, BK], dtype=kv.dtype.element_ty, layout=None, scope=tle.gpu.smem
        )
        # sL: final partial denominators, slot 0 from wg0 and slot 1 from wg1.
        sL_smem = tle.gpu.alloc(
            [2, BH],
            dtype=tl.float32,
            layout=None,
            scope=tle.gpu.smem,
            nv_mma_shared_layout=False,
        )
        # Scratch used by wg1 before storing max_logits / lse.
        final_max_logits_smem = tle.gpu.alloc(
            [BH],
            dtype=tl.float32,
            layout=None,
            scope=tle.gpu.smem,
            nv_mma_shared_layout=False,
        )
        final_lse_smem = tle.gpu.alloc(
            [BH],
            dtype=tl.float32,
            layout=None,
            scope=tle.gpu.smem,
            nv_mma_shared_layout=False,
        )
        sM_wg0_pipe = tle.pipe(
            capacity=1, scope="cta", name="flashmla_wg0_bunch_0_ready", sM=sM_smem
        )
        sM_wg1_pipe = tle.pipe(
            capacity=1, scope="cta", name="flashmla_wg1_bunch_0_ready", sM=sM_smem
        )
        sS0_pipe = tle.pipe(capacity=1, scope="cta", name="flashmla_sS0", sS0=sS0_smem)
        sS1_pipe = tle.pipe(capacity=1, scope="cta", name="flashmla_sS1", sS1=sS1_smem)
        sL_wg0_pipe = tle.pipe(
            capacity=2, scope="cta", name="flashmla_sL_wg0", sL=sL_smem
        )
        sL_wg1_pipe = tle.pipe(
            capacity=2, scope="cta", name="flashmla_sL_wg1", sL=sL_smem
        )

        # Softmax is evaluated with exp2: exp(x * sm_scale) == exp2(x * log_scale).
        log_scale: tl.constexpr = sm_scale * 1.4426950408889634

        # Partitions: consumer0 ("wg0"), consumer1 ("wg1"), producer.
        # NOTE(review): the trailing lists are presumably the warp counts and
        # register budgets of the non-default partitions; confirm against the
        # tle.gpu.warp_specialize API.
        tle.gpu.warp_specialize(
            [
                (
                    _tle_flashmla_prefill_consumer0,
                    (
                        q_pipe.writer(),
                        q_pipe.reader("wg0"),
                        q_desc,
                        tq_desc,
                        k0_l_pipe.reader(),
                        k0_r_pipe.reader("qk"),
                        k1_l_pipe.reader("remote", fields=("sK",)),
                        is_kv_valid_pipe.reader("wg0"),
                        sM_wg0_pipe.writer(),
                        sM_wg1_pipe.reader(),
                        sS0_pipe.writer(),
                        sS1_pipe.reader(),
                        sL_wg0_pipe.writer(),
                        sL_wg1_pipe.reader(),
                        output_desc,
                        q_row,
                        h_base,
                        topk_len_ptr,
                        attn_sink_base,
                        log_scale,
                        D,
                        TD,
                        kv.dtype.element_ty,
                        HAVE_ATTN_SINK,
                        TOPK,
                        HAVE_TOPK_LENGTH,
                        HAVE_TAIL,
                        BK,
                        BH,
                        DPH,
                        TDP,
                        G,
                    ),
                ),
                (
                    _tle_flashmla_prefill_consumer1,
                    (
                        q_pipe.reader("wg1"),
                        k1_r_pipe.reader(),
                        k1_l_pipe.reader("qk"),
                        k0_r_pipe.reader("remote", fields=("sK",)),
                        is_kv_valid_pipe.reader("wg1"),
                        sM_wg1_pipe.writer(),
                        sM_wg0_pipe.reader(),
                        sS1_pipe.writer(),
                        sS0_pipe.reader(),
                        sL_wg1_pipe.writer(),
                        sL_wg0_pipe.reader(),
                        final_max_logits_smem,
                        final_lse_smem,
                        output_desc,
                        q_row,
                        max_logits_base,
                        l_base,
                        h_base,
                        topk_len_ptr,
                        attn_sink_base,
                        log_scale,
                        D,
                        TD,
                        kv.dtype.element_ty,
                        HAVE_ATTN_SINK,
                        TOPK,
                        HAVE_TOPK_LENGTH,
                        HAVE_TAIL,
                        BK,
                        BH,
                        DPH,
                        TDP,
                        G,
                    ),
                ),
                (
                    _tle_flashmla_prefill_producer,
                    (
                        k0_l_pipe.writer(),
                        k0_r_pipe.writer(),
                        k1_l_pipe.writer(),
                        k1_r_pipe.writer(),
                        is_kv_valid_pipe.writer(),
                        kv_base,
                        tkv_base,
                        t_base,
                        topk_len_ptr,
                        D,
                        TD,
                        DPH,
                        TDP,
                        VG,
                        SKV,
                        TOPK,
                        HAVE_TOPK_LENGTH,
                        HAVE_TAIL,
                        BK,
                    ),
                ),
            ],
            [4, 4],
            [216, 72],
        )
992def _flash_mla_sparse_tle_enabled() -> bool:
993 value = os.environ.get("FLAGGEMS_FLASHMLA_SPARSE_TLE", "1").lower()
994 return value not in {"0", "false", "off", "no"}
def _can_use_tle_flash_mla_sparse_fwd(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    d_v: int,
    topk_length: Optional[torch.Tensor] = None,
) -> bool:
    """Decide whether the warp-specialized TLE kernel can handle this call.

    Requires TLE support (HAS_TLE_FLASHMLA_SPARSE), the environment switch to
    be on, a CUDA device, d_v == 512, a single KV head, d_qk of 512 or 576,
    a head count divisible by the TLE head block, and a positive top-k that
    is a multiple of 128. ``topk_length`` is accepted but not inspected.
    """
    if not HAS_TLE_FLASHMLA_SPARSE or not _flash_mla_sparse_tle_enabled():
        return False
    if q.device.type != "cuda":
        return False

    _, num_q_heads, head_dim = q.shape
    num_kv_heads = kv.shape[1]
    topk = indices.shape[-1]

    if d_v != 512 or num_kv_heads != 1:
        return False
    if head_dim not in (512, 576):
        return False
    if num_q_heads % TLE_FLASHMLA_PREFILL_BH != 0:
        return False
    return topk > 0 and topk % 128 == 0
def _set_triton_descriptor_allocator(device: torch.device) -> None:
    """Register a global Triton allocator that returns int8 buffers on *device*.

    The TLE path calls this before building its tensor descriptors. The
    registered callback ignores the requested alignment and stream and
    allocates ``size`` bytes with ``torch.empty``.
    """

    def alloc_fn(size: int, align: int, stream):
        del align, stream  # accepted for the callback signature, not used
        return torch.empty(size, dtype=torch.int8, device=device)

    triton.set_allocator(alloc_fn)
def flash_mla_sparse_fwd(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
    attn_sink: Optional[torch.Tensor] = None,
    topk_length: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Sparse attention prefill kernel.

    Every query token attends only to the KV rows listed in its ``indices``
    row. The warp-specialized TLE kernel is used when
    ``_can_use_tle_flash_mla_sparse_fwd`` allows it; otherwise the portable
    Triton kernel runs.

    Args:
        q: [s_q, h_q, d_qk], bfloat16, contiguous
        kv: [s_kv, h_kv, d_qk], bfloat16, contiguous
        indices: [s_q, h_kv, topk], int32, contiguous. Invalid indices should
            be set to -1 or numbers >= s_kv
        sm_scale: float, multiplier applied to q . k before the softmax
        d_v: The dimension of value vectors. Can only be 512
        attn_sink: optional, [h_q], float32, contiguous.
            If attn_sink is provided, when computing output, output will be
            additionally multiplied by exp(lse) / (exp(lse) + exp(attn_sink)).
            +-inf in attn_sink will be handled normally (i.e., -inf has no
            effect, +inf will make corresponding output all zeros).
            This argument has no effect on lse and max_logits.
        topk_length: optional, [s_q], int32, contiguous.
            If provided, the i-th q token will only attend to k tokens
            specified by indices[i, :, :topk_length[i]], ignoring later k/v
            tokens (even if provided in indices). In extremely rare cases
            (topk_length provided, there is a valid topk index between
            topk_length[i] ~ s_kv, and that topk index points to a k token
            containing NaN), operator output will contain NaN, so please
            avoid this situation.

    Returns:
        (output, max_logits, lse)
        Please refer to tests/ref.py for the precise definitions.
        - output: [s_q, h_q, d_v], bfloat16
        - max_logits: [s_q, h_q], float
        - lse: [s_q, h_q], float, log-sum-exp of attention scores

    Raises:
        AssertionError: when a layout, dtype or shape requirement above is not
            met, or when h_kv != 1, h_q not in {64, 128}, or d_qk not in
            {512, 576}.
    """
    # ---- input validation (order preserved: the first failing check wins) --
    assert all(t.is_contiguous() for t in (q, kv, indices))
    assert (q.dtype, kv.dtype, indices.dtype) == (
        torch.bfloat16,
        torch.bfloat16,
        torch.int32,
    )
    SQ, HQ, DQK = q.shape
    SKV, HKV, _ = kv.shape

    assert d_v == 512, "Unsupported d_v"

    assert kv.shape[-1] == DQK
    idx_sq, idx_hkv, TOPK = indices.shape
    assert (idx_sq, idx_hkv) == (SQ, HKV)

    has_sink = attn_sink is not None
    has_topk_length = topk_length is not None
    if has_sink:
        assert attn_sink.is_contiguous()
        assert attn_sink.dtype == torch.float32
        assert attn_sink.shape == (HQ,), "attn_sink error shape"
    if has_topk_length:
        assert topk_length.is_contiguous()
        assert topk_length.dtype == torch.int32
        assert topk_length.shape == (SQ,), "topk_length error shape"

    # check from FlashMLA
    assert HKV == 1, "h_kv is expected to be 1"
    assert HQ in (64, 128), "Unsupported h_q"
    assert DQK in (576, 512), "Unsupported d_qk"

    # ---- outputs ----
    output = torch.empty((SQ, HQ, d_v), device=q.device, dtype=q.dtype)
    max_logits = torch.empty((SQ, HQ), device=q.device, dtype=torch.float32)
    lse = torch.empty((SQ, HQ), device=q.device, dtype=torch.float32)

    def _grid(meta):
        # One program per (query token, block of BH query heads).
        return (triton.cdiv(HQ, meta["BH"]) * SQ,)

    # ---- TLE fast path ----
    if _can_use_tle_flash_mla_sparse_fwd(q, kv, indices, d_v, topk_length):
        from triton.tools.tensor_descriptor import TensorDescriptor

        main_dim = d_v
        tail_dim = DQK - main_dim
        has_tail = tail_dim > 0
        main_dim_p2 = triton.next_power_of_2(main_dim)
        tail_dim_p2 = triton.next_power_of_2(tail_dim) if has_tail else 1
        half_main = main_dim_p2 // 2
        heads_per_group = HQ // HKV
        block_h = TLE_FLASHMLA_PREFILL_BH
        block_k = TLE_FLASHMLA_PREFILL_BK
        head_blocks = heads_per_group // block_h
        total_rows = SQ * HQ

        _set_triton_descriptor_allocator(q.device)
        q_desc = TensorDescriptor(
            q,
            shape=[total_rows, DQK],
            strides=[DQK, 1],
            block_shape=[block_h, half_main],
        )
        if has_tail:
            tq_desc = TensorDescriptor(
                q,
                shape=[total_rows, DQK],
                strides=[DQK, 1],
                block_shape=[block_h, tail_dim_p2],
            )
        else:
            tq_desc = q_desc
        output_desc = TensorDescriptor(
            output,
            shape=[total_rows, main_dim],
            strides=[main_dim, 1],
            block_shape=[block_h, half_main],
        )
        _tle_flashmla_prefill_fwd[_grid](
            q_desc,
            tq_desc,
            output_desc,
            kv,
            indices,
            attn_sink,
            topk_length,
            sm_scale,
            output,
            max_logits,
            lse,
            SQ,
            HQ,
            DQK,
            SKV,
            TOPK,
            has_sink,
            has_topk_length,
            main_dim,
            tail_dim,
            main_dim_p2,
            tail_dim_p2,
            heads_per_group,
            HKV,
            head_blocks,
            has_tail,
            block_k,
            block_h,
            TLE_FLASHMLA_PREFILL_PAIR_BLOCKS,
            num_warps=TLE_FLASHMLA_PREFILL_WORKER_NUM_WARPS,
            num_stages=1,
        )
        return output, max_logits, lse

    # ---- portable Triton fallback ----
    q_strides = (q.stride(1), q.stride(0))
    kv_strides = (kv.stride(1), kv.stride(0))
    idx_strides = (indices.stride(1), indices.stride(0))
    out_strides = (output.stride(1), output.stride(0))
    triton_flash_mla_sparse_fwd[_grid](
        q,
        kv,
        indices,
        attn_sink,
        topk_length,
        sm_scale,
        output,
        max_logits,
        lse,
        *q_strides,
        *kv_strides,
        *idx_strides,
        *out_strides,
        max_logits.stride(0),
        lse.stride(0),
        SQ,
        HQ,
        DQK,
        SKV,
        TOPK,
        has_sink,
        has_topk_length,
    )
    return output, max_logits, lse