Coverage for src/flag_gems/fused/flash_mla_with_kvcache.py: 4%
1201 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1"""
2Triton implementation of flash_mla_with_kvcache for MLA attention.
3Supports both sparse (FP8 KV cache + topk indices) and dense (paged attention) modes.
4Only supports sm90 (Hopper) architecture.
5"""
7import dataclasses
8import os
9from typing import Optional, Tuple
11import torch
12import triton
13import triton.language as tl
15from flag_gems.utils.triton_version_utils import has_triton_tle
# Optional TLE support: start from the "unavailable" defaults and only flip
# them when the version gate passes and the import actually succeeds.
tle = None
HAS_TLE = False
if has_triton_tle(3, 6, 0):
    try:
        import triton.experimental.tle.language as tle
    except ImportError:
        # Leave tle = None / HAS_TLE = False as initialised above.
        pass
    else:
        HAS_TLE = True


# TLE constants for decode
TLE_DECODE_BK = 64
TLE_DECODE_BH = 64
TLE_DECODE_PAIR_BLOCKS = 2
TLE_DECODE_WORKER_NUM_WARPS = 4
37# ============================================================================
38# Data structures (compatible with original CUDA interface)
39# ============================================================================
@dataclasses.dataclass
class FlashMLASchedMeta:
    """Mutable holder for FlashMLA tile-scheduler state.

    A freshly constructed instance is "uninitialised" (``have_initialized`` is
    False and every other field is None). ``flash_mla_with_kvcache`` fills in
    ``config`` on its first call with a given instance and, on later calls,
    asserts that the call arguments still match that recorded configuration.
    """

    @dataclasses.dataclass
    class Config:
        """Call parameters recorded on first use of a ``FlashMLASchedMeta``."""

        # Populated positionally by flash_mla_with_kvcache, in this order.
        b: int  # q.shape[0]
        s_q: int  # q.shape[1]
        h_q: int  # q.shape[2]
        page_block_size: int  # k_cache.shape[1]
        h_k: int  # k_cache.shape[2]

        causal: bool
        is_fp8_kvcache: bool
        topk: Optional[int]  # indices.shape[-1], or None without indices
        extra_page_block_size: Optional[int]  # extra_k_cache.shape[1], or None
        extra_topk: Optional[int]  # extra_indices_in_kvcache.shape[-1], or None

    # True once `config` has been recorded.
    have_initialized: bool = False
    config: Optional[Config] = None
    # Kept for interface compatibility with the CUDA implementation; nothing
    # in the visible part of this file assigns these two fields.
    tile_scheduler_metadata: Optional[torch.Tensor] = None
    num_splits: Optional[torch.Tensor] = None
def get_mla_metadata(*args, **kwargs) -> Tuple[FlashMLASchedMeta, None]:
    """Create a blank scheduler-metadata object.

    All positional and keyword arguments are accepted and ignored.

    Returns:
        A ``(FlashMLASchedMeta(), None)`` pair; the metadata instance is new
        and uninitialised on every call.
    """
    fresh_meta = FlashMLASchedMeta()
    return fresh_meta, None
70# ============================================================================
71# Sparse decode kernel (FP8 KV cache + topk indices)
72#
73# KV cache layout per token (656 bytes total):
74# [0:512] - NoPE part: 512 float8_e4m3 values
75# [512:528] - Scale factors: 4 float32 values (each for 128 FP8 values)
76# [528:656] - RoPE part: 64 bfloat16 values
77#
78# The NoPE part (after dequantization) serves as BOTH K and V for MLA.
79# ============================================================================
@triton.autotune(
    configs=[
        triton.Config({"BK": 64, "BH": 64}, num_warps=8, num_stages=2),
        triton.Config({"BK": 64, "BH": 64}, num_warps=8, num_stages=4),
    ],
    key=["HQ", "DQK", "TOPK", "HAVE_ATTN_SINK", "HAVE_TOPK_LENGTH", "IS_FP8"],
)
@triton.jit
def _sparse_decode_kernel(
    q,
    kv,
    kv_scales,
    kv_rope,
    indices,
    attn_sink,
    topk_length,
    sm_scale: tl.constexpr,
    output,
    lse,
    stride_qb,
    stride_qsq,
    stride_qh,
    stride_kvn,
    stride_scales_n,
    stride_rope_n,
    stride_ib,
    stride_isq,
    stride_ob,
    stride_osq,
    stride_oh,
    stride_lseb,
    stride_lseh,
    SQ,
    HQ: tl.constexpr,
    DQK: tl.constexpr,
    SKV,
    TOPK: tl.constexpr,
    HAVE_ATTN_SINK: tl.constexpr,
    HAVE_TOPK_LENGTH: tl.constexpr,
    IS_FP8: tl.constexpr,
    BK: tl.constexpr,
    BH: tl.constexpr,
):
    """
    Sparse decode kernel with online softmax.
    Grid: (batch_size * seq_q * ceil(HQ / BH),)
    Each program handles BH heads for one (batch, seq_q) position.

    For FP8 mode:
    - kv: [num_tokens, 512] float8_e4m3fn (NoPE part)
    - kv_scales: [num_tokens, 4] float32 (per-128-element scales)
    - kv_rope: [num_tokens, 64] bfloat16 (RoPE part)
    For BF16 mode:
    - kv: [num_tokens, DQK] bfloat16 (full KV)
    - kv_scales, kv_rope: unused

    Token selection:
    - `indices` holds TOPK token ids per (batch, seq_q); when HAVE_TOPK_LENGTH
      is set only the first `topk_length[batch]` of them are read.
    - An id outside [0, SKV) is treated as invalid and contributes nothing.

    Outputs:
    - `output` receives the 512 value dims (two 256-wide halves) as bfloat16.
    - `lse` receives max + log(sum_exp); rows with no valid token get
      output 0 and lse = +inf.
    - With HAVE_ATTN_SINK, exp(attn_sink[head]) is added to the softmax
      denominator of the output (the stored lse is not affected).
    """
    # Decompose the flat program id into (batch, query position, head block).
    num_head_blocks: tl.constexpr = (HQ + BH - 1) // BH
    pid = tl.program_id(0)
    i_b = pid // (SQ * num_head_blocks)
    remainder = pid % (SQ * num_head_blocks)
    i_sq = remainder // num_head_blocks
    i_sq = i_sq.to(tl.int64)
    i_gbh = remainder % num_head_blocks
    gbh_base = i_gbh * BH

    DP: tl.constexpr = 512  # NoPE / value width
    BDP: tl.constexpr = 256  # the 512 dims are processed as two 256-wide halves

    # Base pointers
    q_base = q + i_b * stride_qb + i_sq * stride_qsq + gbh_base * stride_qh
    kv_base = kv
    t_base = indices + i_b * stride_ib + i_sq * stride_isq
    attn_sink_ptr = attn_sink + gbh_base if HAVE_ATTN_SINK else 0
    topk_length_ptr = topk_length + i_b if HAVE_TOPK_LENGTH else 0
    o_base = output + i_b * stride_ob + i_sq * stride_osq + gbh_base * stride_oh
    l_base = lse + i_b * stride_lseb + gbh_base * stride_lseh + i_sq

    offs_h = tl.arange(0, BH)
    offs_d = tl.arange(0, BDP)
    if DQK == 576:
        offs_td = tl.arange(0, 64)
    offs_t = tl.arange(0, BK)

    # Load Q in two halves [BH, 256] x 2
    q_ptr = q_base + offs_h[:, None] * stride_qh + offs_d[None, :]
    q_blk0 = tl.load(q_ptr, eviction_policy="evict_first")
    q_blk1 = tl.load(q_ptr + BDP, eviction_policy="evict_first")
    if DQK == 576:
        # Extra 64-wide RoPE tail of Q, stored right after the 512 NoPE dims.
        tq_ptr = q_base + DP + offs_h[:, None] * stride_qh + offs_td[None, :]
        tq_blk = tl.load(tq_ptr, eviction_policy="evict_first")

    # Online softmax accumulators
    max_log = tl.full([BH], float("-inf"), dtype=tl.float32)
    sum_exp = tl.full([BH], 0.0, dtype=tl.float32)
    acc0 = tl.zeros([BH, BDP], dtype=tl.float32)
    acc1 = tl.zeros([BH, BDP], dtype=tl.float32)

    topk_len = tl.load(topk_length_ptr) if HAVE_TOPK_LENGTH else TOPK
    NK = tl.cdiv(topk_len, BK)
    for ck in range(NK):
        # Load indices
        t_offs = BK * ck + offs_t
        t_msk = t_offs < topk_len
        kv_ids = tl.load(t_base + t_offs, t_msk, other=-1)
        mask_ids = (kv_ids < SKV) & (kv_ids >= 0)
        # Invalid slots are redirected to row 0 so the loads stay in bounds;
        # their scores are forced to -inf below. The row id is widened to
        # int64 so `row * stride` cannot wrap for very large caches.
        kv_rows = tl.where(mask_ids, kv_ids, 0).to(tl.int64)

        if IS_FP8:
            # FP8 mode: load FP8 values and dequantize with per-128-element scales
            # Load NoPE FP8 data: [BDP, BK] for each half
            kv_ptr = kv_base + offs_d[:, None] + kv_rows[None, :] * stride_kvn
            kv_fp8_0 = tl.load(kv_ptr, cache_modifier=".cg")  # [256, BK] float8
            kv_fp8_1 = tl.load(kv_ptr + BDP, cache_modifier=".cg")  # [256, BK] float8

            # Load 4 scales per token separately
            # Scale layout: [num_tokens, 4] float32
            scale0 = tl.load(kv_scales + kv_rows * stride_scales_n + 0)  # [BK]
            scale1 = tl.load(kv_scales + kv_rows * stride_scales_n + 1)  # [BK]
            scale2 = tl.load(kv_scales + kv_rows * stride_scales_n + 2)  # [BK]
            scale3 = tl.load(kv_scales + kv_rows * stride_scales_n + 3)  # [BK]

            # Dequantize first half [256, BK]:
            # elements [0:128] use scale0, elements [128:256] use scale1
            mask_lo = offs_d[:, None] < 128
            kv_blk0 = tl.where(
                mask_lo,
                kv_fp8_0.to(tl.float32) * scale0[None, :],
                kv_fp8_0.to(tl.float32) * scale1[None, :],
            ).to(tl.bfloat16)

            # Dequantize second half [256, BK]:
            # elements [0:128] use scale2, elements [128:256] use scale3
            kv_blk1 = tl.where(
                mask_lo,
                kv_fp8_1.to(tl.float32) * scale2[None, :],
                kv_fp8_1.to(tl.float32) * scale3[None, :],
            ).to(tl.bfloat16)
        else:
            # BF16 mode: load directly
            kv_ptr = kv_base + offs_d[:, None] + kv_rows[None, :] * stride_kvn
            kv_blk0 = tl.load(kv_ptr, cache_modifier=".cg")  # [BDP, BK]
            kv_blk1 = tl.load(kv_ptr + BDP, cache_modifier=".cg")  # [BDP, BK]

        # Compute QK^T
        qk = tl.dot(q_blk0, kv_blk0, out_dtype=tl.float32)
        qk = tl.dot(q_blk1, kv_blk1, qk, out_dtype=tl.float32)
        if DQK == 576:
            if IS_FP8:
                # RoPE part from separate tensor
                rope_ptr = (
                    kv_rope + offs_td[:, None] + kv_rows[None, :] * stride_rope_n
                )
                tkv_blk = tl.load(rope_ptr, cache_modifier=".cg")
            else:
                tkv_ptr = (
                    kv_base + DP + offs_td[:, None] + kv_rows[None, :] * stride_kvn
                )
                tkv_blk = tl.load(tkv_ptr, cache_modifier=".cg")
            qk = tl.dot(tq_blk, tkv_blk, qk, out_dtype=tl.float32)
        qk *= sm_scale

        # Mask invalid tokens
        qk = tl.where(mask_ids[None, :], qk, float("-inf"))

        # Online softmax
        new_max = tl.maximum(max_log, tl.max(qk, axis=1))
        # While a row has seen no valid token, new_max is still -inf and
        # (-inf) - (-inf) would be NaN, which then sticks in sum_exp / acc even
        # after valid tokens arrive. Subtract a finite stand-in instead: every
        # exp() below then evaluates to exactly 0 for such rows.
        safe_max = tl.where(new_max == float("-inf"), 0.0, new_max)
        exp_qk = tl.math.exp(qk - safe_max[:, None])
        sum_qk = tl.sum(exp_qk, axis=1)
        alpha = tl.math.exp(max_log - safe_max)
        sum_exp = sum_exp * alpha + sum_qk

        # Accumulate P @ V (V = K NoPE for MLA)
        acc0 = tl.dot(
            exp_qk.to(tl.bfloat16),
            kv_blk0.trans(),
            acc0 * alpha[:, None],
            out_dtype=tl.float32,
        )
        acc1 = tl.dot(
            exp_qk.to(tl.bfloat16),
            kv_blk1.trans(),
            acc1 * alpha[:, None],
            out_dtype=tl.float32,
        )
        max_log = new_max

    # Finalize output. Rows whose running max never left -inf attended to
    # nothing: they report lse = +inf and an all-zero output.
    valid_mask = max_log != float("-inf")

    orig_lse = max_log + tl.math.log(sum_exp)
    lse_out = tl.where(valid_mask, orig_lse, float("inf"))
    tl.store(l_base + offs_h * stride_lseh, lse_out)

    if HAVE_ATTN_SINK:
        # acc holds sum(exp(s - max) * v); rescale by exp(max) and divide by
        # the sink-augmented denominator exp(lse) + exp(sink).
        sink = tl.load(attn_sink_ptr + offs_h)
        sum_exp_new_lse = tl.math.exp(orig_lse) + tl.math.exp(sink)
        factor = tl.math.exp(max_log) / sum_exp_new_lse
    else:
        factor = 1.0 / sum_exp

    out_vals0 = tl.where(valid_mask[:, None], acc0 * factor[:, None], 0.0)
    out_vals1 = tl.where(valid_mask[:, None], acc1 * factor[:, None], 0.0)

    # Store output
    o_ptr = o_base + offs_h[:, None] * stride_oh + offs_d[None, :]
    tl.store(o_ptr, out_vals0.to(tl.bfloat16))
    tl.store(o_ptr + BDP, out_vals1.to(tl.bfloat16))
289# ============================================================================
290# Sparse decode kernel for FlashMLA MODEL1 layout
291#
292# MODEL1 is FlashMLA's internal name for the d_qk=512 / 584-byte layout.
293# It is not a model name. Per page:
294# [0:page_block_size*576] - token data
295# per token: 448 FP8 NoPE + 64 BF16 RoPE
296# [page_block_size*576:...] - 8 uint8 E8M0 scales per token
297#
298# The 512-dim output uses both NoPE and RoPE values as V:
299# output[0:448] = weighted NoPE
300# output[448:512] = weighted RoPE
301# ============================================================================
@triton.autotune(
    configs=[
        triton.Config({"BK": 32, "BH": 64}, num_warps=4, num_stages=1),
        triton.Config({"BK": 32, "BH": 64}, num_warps=8, num_stages=1),
    ],
    key=[
        "HQ",
        "TOPK",
        "EXTRA_TOPK",
        "HAVE_ATTN_SINK",
        "HAVE_TOPK_LENGTH",
        "HAVE_EXTRA",
        "HAVE_EXTRA_TOPK_LENGTH",
    ],
)
@triton.jit
def _sparse_decode_model1_kernel(
    q,
    kv,
    indices,
    extra_kv,
    extra_indices,
    attn_sink,
    topk_length,
    extra_topk_length,
    sm_scale: tl.constexpr,
    output,
    lse,
    stride_qb,
    stride_qsq,
    stride_qh,
    stride_kv_block,
    stride_ib,
    stride_isq,
    stride_extra_kv_block,
    stride_eib,
    stride_eisq,
    stride_ob,
    stride_osq,
    stride_oh,
    stride_lseb,
    stride_lseh,
    SQ,
    HQ: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    EXTRA_PAGE_SIZE: tl.constexpr,
    NUM_BLOCKS,
    EXTRA_NUM_BLOCKS,
    TOPK: tl.constexpr,
    EXTRA_TOPK: tl.constexpr,
    HAVE_ATTN_SINK: tl.constexpr,
    HAVE_TOPK_LENGTH: tl.constexpr,
    HAVE_EXTRA: tl.constexpr,
    HAVE_EXTRA_TOPK_LENGTH: tl.constexpr,
    BK: tl.constexpr,
    BH: tl.constexpr,
):
    """
    Sparse decode kernel for the MODEL1 (d_qk = 512) paged FP8 layout.

    The flat program id is decomposed into (batch, query position, head block);
    each program handles BH heads and runs one online softmax over the tokens
    selected by `indices` and, when HAVE_EXTRA is set, additionally over the
    tokens selected by `extra_indices` from `extra_kv`.

    Per page the cache is addressed byte-wise as:
    - PAGE_SIZE * 576 bytes of token data, per token 448 FP8 NoPE values
      followed by 64 BF16 RoPE values;
    - then 8 scale bytes per token, of which the first 7 are read here
      (one per 64 NoPE values).

    A token id is valid when it is >= 0, inside the active top-k length, and
    its page (id // page size) is below NUM_BLOCKS / EXTRA_NUM_BLOCKS. Invalid
    ids contribute nothing; rows with no valid token at all get output 0 and
    lse = +inf.

    The 512-wide output uses NoPE and RoPE values as V:
    output[0:448] = weighted NoPE, output[448:512] = weighted RoPE.
    """
    num_head_blocks: tl.constexpr = (HQ + BH - 1) // BH
    pid = tl.program_id(0)
    i_b = pid // (SQ * num_head_blocks)
    remainder = pid % (SQ * num_head_blocks)
    i_sq = remainder // num_head_blocks
    i_sq = i_sq.to(tl.int64)
    i_gbh = remainder % num_head_blocks
    gbh_base = i_gbh * BH

    NOPE: tl.constexpr = 448
    ROPE: tl.constexpr = 64
    # D: tl.constexpr = 512
    BDP: tl.constexpr = 256
    TOKEN_DATA_BYTES: tl.constexpr = 576
    SCALE_BYTES: tl.constexpr = 8

    q_base = q + i_b * stride_qb + i_sq * stride_qsq + gbh_base * stride_qh
    t_base = indices + i_b * stride_ib + i_sq * stride_isq
    et_base = extra_indices + i_b * stride_eib + i_sq * stride_eisq
    attn_sink_ptr = attn_sink + gbh_base if HAVE_ATTN_SINK else 0
    topk_length_ptr = topk_length + i_b if HAVE_TOPK_LENGTH else 0
    extra_topk_length_ptr = extra_topk_length + i_b if HAVE_EXTRA_TOPK_LENGTH else 0
    o_base = output + i_b * stride_ob + i_sq * stride_osq + gbh_base * stride_oh
    l_base = lse + i_b * stride_lseb + gbh_base * stride_lseh + i_sq

    offs_h = tl.arange(0, BH)
    offs_d = tl.arange(0, BDP)
    offs_t = tl.arange(0, BK)
    offs_rope = tl.arange(0, ROPE)

    # Q is split as [0:256] NoPE, [256:448] NoPE tail (zero-padded to 256
    # columns) and [448:512] RoPE.
    q_ptr = q_base + offs_h[:, None] * stride_qh + offs_d[None, :]
    q_blk0 = tl.load(q_ptr, eviction_policy="evict_first")
    q_blk1_nope = tl.load(
        q_ptr + BDP,
        mask=offs_d[None, :] < (NOPE - BDP),
        other=0.0,
        eviction_policy="evict_first",
    )
    q_rope = tl.load(
        q_base + offs_h[:, None] * stride_qh + (NOPE + offs_rope[None, :]),
        eviction_policy="evict_first",
    )

    # Online softmax accumulators, shared by the main and the extra pass.
    max_log = tl.full([BH], float("-inf"), dtype=tl.float32)
    sum_exp = tl.full([BH], 0.0, dtype=tl.float32)
    acc0 = tl.zeros([BH, BDP], dtype=tl.float32)
    acc1 = tl.zeros([BH, BDP], dtype=tl.float32)

    topk_len = tl.load(topk_length_ptr) if HAVE_TOPK_LENGTH else TOPK
    NK = tl.cdiv(topk_len, BK)
    for ck in range(NK):
        t_offs = BK * ck + offs_t
        t_msk = t_offs < topk_len
        kv_ids = tl.load(t_base + t_offs, t_msk, other=-1)
        # Token id -> (page, slot within page).
        block_ids = kv_ids // PAGE_SIZE
        rel_ids = kv_ids - block_ids * PAGE_SIZE
        valid_ids = t_msk & (kv_ids >= 0) & (block_ids < NUM_BLOCKS)
        block_ids = tl.where(valid_ids, block_ids, 0)
        rel_ids = tl.where(valid_ids, rel_ids, 0)

        token_base = (
            kv + block_ids.to(tl.int64) * stride_kv_block + rel_ids * TOKEN_DATA_BYTES
        )
        # Scales live after all token data of the page.
        scale_base = (
            kv
            + block_ids.to(tl.int64) * stride_kv_block
            + PAGE_SIZE * TOKEN_DATA_BYTES
            + rel_ids * SCALE_BYTES
        )

        kv_fp8_0_u8 = tl.load(
            token_base[None, :] + offs_d[:, None],
            mask=valid_ids[None, :],
            other=0,
            cache_modifier=".cg",
        )
        kv_fp8_1_u8 = tl.load(
            token_base[None, :] + (BDP + offs_d[:, None]),
            mask=valid_ids[None, :] & (offs_d[:, None] < (NOPE - BDP)),
            other=0,
            cache_modifier=".cg",
        )

        # other=127 decodes to a scale of 1.0 for masked-out tokens.
        scale0_u8 = tl.load(scale_base + 0, mask=valid_ids, other=127)
        scale1_u8 = tl.load(scale_base + 1, mask=valid_ids, other=127)
        scale2_u8 = tl.load(scale_base + 2, mask=valid_ids, other=127)
        scale3_u8 = tl.load(scale_base + 3, mask=valid_ids, other=127)
        scale4_u8 = tl.load(scale_base + 4, mask=valid_ids, other=127)
        scale5_u8 = tl.load(scale_base + 5, mask=valid_ids, other=127)
        scale6_u8 = tl.load(scale_base + 6, mask=valid_ids, other=127)

        # Place the scale byte in the float32 exponent field: 2 ** (byte - 127).
        scale0 = (scale0_u8.to(tl.int32) << 23).to(tl.float32, bitcast=True)
        scale1 = (scale1_u8.to(tl.int32) << 23).to(tl.float32, bitcast=True)
        scale2 = (scale2_u8.to(tl.int32) << 23).to(tl.float32, bitcast=True)
        scale3 = (scale3_u8.to(tl.int32) << 23).to(tl.float32, bitcast=True)
        scale4 = (scale4_u8.to(tl.int32) << 23).to(tl.float32, bitcast=True)
        scale5 = (scale5_u8.to(tl.int32) << 23).to(tl.float32, bitcast=True)
        scale6 = (scale6_u8.to(tl.int32) << 23).to(tl.float32, bitcast=True)

        # NoPE [0:256]: one scale per 64 values (scale0..scale3).
        kv_fp8_0 = kv_fp8_0_u8.to(tl.float8e4nv, bitcast=True).to(tl.float32)
        scale_0 = tl.where(
            offs_d[:, None] < 64,
            scale0[None, :],
            tl.where(
                offs_d[:, None] < 128,
                scale1[None, :],
                tl.where(offs_d[:, None] < 192, scale2[None, :], scale3[None, :]),
            ),
        )
        kv_blk0 = (kv_fp8_0 * scale_0).to(tl.bfloat16)

        # NoPE [256:448]: scale4..scale6; rows >= 192 were loaded as 0.
        kv_fp8_1 = kv_fp8_1_u8.to(tl.float8e4nv, bitcast=True).to(tl.float32)
        scale_1 = tl.where(
            offs_d[:, None] < 64,
            scale4[None, :],
            tl.where(offs_d[:, None] < 128, scale5[None, :], scale6[None, :]),
        )
        nope_tail = (kv_fp8_1 * scale_1).to(tl.bfloat16)

        # RoPE values follow the 448 NoPE bytes and are read as bfloat16.
        rope_ptr = (token_base + NOPE).to(tl.pointer_type(tl.bfloat16))
        rope_blk = tl.load(
            rope_ptr[None, :] + offs_rope[:, None],
            mask=valid_ids[None, :],
            other=0.0,
            cache_modifier=".cg",
        )

        # Second V half: rows [0:192] = NoPE tail, rows [192:256] = RoPE.
        kv_blk1 = tl.where(
            offs_d[:, None] < (NOPE - BDP),
            nope_tail,
            tl.load(
                rope_ptr[None, :] + (offs_d[:, None] - (NOPE - BDP)),
                mask=valid_ids[None, :] & (offs_d[:, None] >= (NOPE - BDP)),
                other=0.0,
                cache_modifier=".cg",
            ),
        )

        qk = tl.dot(q_blk0, kv_blk0, out_dtype=tl.float32)
        qk = tl.dot(q_blk1_nope, nope_tail, qk, out_dtype=tl.float32)
        qk = tl.dot(q_rope, rope_blk, qk, out_dtype=tl.float32)
        qk *= sm_scale
        qk = tl.where(valid_ids[None, :], qk, float("-inf"))

        new_max = tl.maximum(max_log, tl.max(qk, axis=1))
        # While a row has seen no valid token, new_max is still -inf and
        # (-inf) - (-inf) would be NaN, which then sticks in sum_exp / acc even
        # after valid tokens arrive. Subtract a finite stand-in instead: every
        # exp() below then evaluates to exactly 0 for such rows.
        safe_max = tl.where(new_max == float("-inf"), 0.0, new_max)
        exp_qk = tl.math.exp(qk - safe_max[:, None])
        sum_qk = tl.sum(exp_qk, axis=1)
        alpha = tl.math.exp(max_log - safe_max)
        sum_exp = sum_exp * alpha + sum_qk
        acc0 = tl.dot(
            exp_qk.to(tl.bfloat16),
            kv_blk0.trans(),
            acc0 * alpha[:, None],
            out_dtype=tl.float32,
        )
        acc1 = tl.dot(
            exp_qk.to(tl.bfloat16),
            kv_blk1.trans(),
            acc1 * alpha[:, None],
            out_dtype=tl.float32,
        )
        max_log = new_max

    if HAVE_EXTRA:
        # Second pass over the extra cache; same per-chunk work as above but
        # with the extra_* tensors, page size and block count.
        extra_topk_len = (
            tl.load(extra_topk_length_ptr) if HAVE_EXTRA_TOPK_LENGTH else EXTRA_TOPK
        )
        ENK = tl.cdiv(extra_topk_len, BK)
        for ck in range(ENK):
            t_offs = BK * ck + offs_t
            t_msk = t_offs < extra_topk_len
            kv_ids = tl.load(et_base + t_offs, t_msk, other=-1)
            block_ids = kv_ids // EXTRA_PAGE_SIZE
            rel_ids = kv_ids - block_ids * EXTRA_PAGE_SIZE
            valid_ids = t_msk & (kv_ids >= 0) & (block_ids < EXTRA_NUM_BLOCKS)
            block_ids = tl.where(valid_ids, block_ids, 0)
            rel_ids = tl.where(valid_ids, rel_ids, 0)

            token_base = (
                extra_kv
                + block_ids.to(tl.int64) * stride_extra_kv_block
                + rel_ids * TOKEN_DATA_BYTES
            )
            scale_base = (
                extra_kv
                + block_ids.to(tl.int64) * stride_extra_kv_block
                + EXTRA_PAGE_SIZE * TOKEN_DATA_BYTES
                + rel_ids * SCALE_BYTES
            )

            kv_fp8_0_u8 = tl.load(
                token_base[None, :] + offs_d[:, None],
                mask=valid_ids[None, :],
                other=0,
                cache_modifier=".cg",
            )
            kv_fp8_1_u8 = tl.load(
                token_base[None, :] + (BDP + offs_d[:, None]),
                mask=valid_ids[None, :] & (offs_d[:, None] < (NOPE - BDP)),
                other=0,
                cache_modifier=".cg",
            )

            scale0_u8 = tl.load(scale_base + 0, mask=valid_ids, other=127)
            scale1_u8 = tl.load(scale_base + 1, mask=valid_ids, other=127)
            scale2_u8 = tl.load(scale_base + 2, mask=valid_ids, other=127)
            scale3_u8 = tl.load(scale_base + 3, mask=valid_ids, other=127)
            scale4_u8 = tl.load(scale_base + 4, mask=valid_ids, other=127)
            scale5_u8 = tl.load(scale_base + 5, mask=valid_ids, other=127)
            scale6_u8 = tl.load(scale_base + 6, mask=valid_ids, other=127)

            scale0 = (scale0_u8.to(tl.int32) << 23).to(tl.float32, bitcast=True)
            scale1 = (scale1_u8.to(tl.int32) << 23).to(tl.float32, bitcast=True)
            scale2 = (scale2_u8.to(tl.int32) << 23).to(tl.float32, bitcast=True)
            scale3 = (scale3_u8.to(tl.int32) << 23).to(tl.float32, bitcast=True)
            scale4 = (scale4_u8.to(tl.int32) << 23).to(tl.float32, bitcast=True)
            scale5 = (scale5_u8.to(tl.int32) << 23).to(tl.float32, bitcast=True)
            scale6 = (scale6_u8.to(tl.int32) << 23).to(tl.float32, bitcast=True)

            kv_fp8_0 = kv_fp8_0_u8.to(tl.float8e4nv, bitcast=True).to(tl.float32)
            scale_0 = tl.where(
                offs_d[:, None] < 64,
                scale0[None, :],
                tl.where(
                    offs_d[:, None] < 128,
                    scale1[None, :],
                    tl.where(offs_d[:, None] < 192, scale2[None, :], scale3[None, :]),
                ),
            )
            kv_blk0 = (kv_fp8_0 * scale_0).to(tl.bfloat16)

            kv_fp8_1 = kv_fp8_1_u8.to(tl.float8e4nv, bitcast=True).to(tl.float32)
            scale_1 = tl.where(
                offs_d[:, None] < 64,
                scale4[None, :],
                tl.where(offs_d[:, None] < 128, scale5[None, :], scale6[None, :]),
            )
            nope_tail = (kv_fp8_1 * scale_1).to(tl.bfloat16)

            rope_ptr = (token_base + NOPE).to(tl.pointer_type(tl.bfloat16))
            rope_blk = tl.load(
                rope_ptr[None, :] + offs_rope[:, None],
                mask=valid_ids[None, :],
                other=0.0,
                cache_modifier=".cg",
            )
            kv_blk1 = tl.where(
                offs_d[:, None] < (NOPE - BDP),
                nope_tail,
                tl.load(
                    rope_ptr[None, :] + (offs_d[:, None] - (NOPE - BDP)),
                    mask=valid_ids[None, :] & (offs_d[:, None] >= (NOPE - BDP)),
                    other=0.0,
                    cache_modifier=".cg",
                ),
            )

            qk = tl.dot(q_blk0, kv_blk0, out_dtype=tl.float32)
            qk = tl.dot(q_blk1_nope, nope_tail, qk, out_dtype=tl.float32)
            qk = tl.dot(q_rope, rope_blk, qk, out_dtype=tl.float32)
            qk *= sm_scale
            qk = tl.where(valid_ids[None, :], qk, float("-inf"))

            new_max = tl.maximum(max_log, tl.max(qk, axis=1))
            # Same guard as in the main pass: this is what keeps a main cache
            # with only invalid ids from poisoning valid extra-cache tokens.
            safe_max = tl.where(new_max == float("-inf"), 0.0, new_max)
            exp_qk = tl.math.exp(qk - safe_max[:, None])
            sum_qk = tl.sum(exp_qk, axis=1)
            alpha = tl.math.exp(max_log - safe_max)
            sum_exp = sum_exp * alpha + sum_qk
            acc0 = tl.dot(
                exp_qk.to(tl.bfloat16),
                kv_blk0.trans(),
                acc0 * alpha[:, None],
                out_dtype=tl.float32,
            )
            acc1 = tl.dot(
                exp_qk.to(tl.bfloat16),
                kv_blk1.trans(),
                acc1 * alpha[:, None],
                out_dtype=tl.float32,
            )
            max_log = new_max

    # Rows whose running max never left -inf attended to nothing: they report
    # lse = +inf and an all-zero output.
    valid_mask = max_log != float("-inf")
    orig_lse = max_log + tl.math.log(sum_exp)
    lse_out = tl.where(valid_mask, orig_lse, float("inf"))
    tl.store(l_base + offs_h * stride_lseh, lse_out)

    if HAVE_ATTN_SINK:
        # Divide by the sink-augmented denominator exp(lse) + exp(sink).
        sink = tl.load(attn_sink_ptr + offs_h)
        sum_exp_new_lse = tl.math.exp(orig_lse) + tl.math.exp(sink)
        factor = tl.math.exp(max_log) / sum_exp_new_lse
    else:
        factor = 1.0 / sum_exp

    out_vals0 = tl.where(valid_mask[:, None], acc0 * factor[:, None], 0.0)
    out_vals1 = tl.where(valid_mask[:, None], acc1 * factor[:, None], 0.0)
    o_ptr = o_base + offs_h[:, None] * stride_oh + offs_d[None, :]
    tl.store(o_ptr, out_vals0.to(tl.bfloat16))
    tl.store(o_ptr + BDP, out_vals1.to(tl.bfloat16))
662# ============================================================================
663# Dense decode kernel (paged attention with block_table)
664# ============================================================================
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 64, "BLOCK_N": 64}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_H": 64, "BLOCK_N": 64}, num_warps=8, num_stages=3),
    ],
    key=["HQ", "DQK", "HAVE_CAUSAL"],
)
@triton.jit
def _dense_decode_kernel(
    Q_ptr,
    stride_q_b,
    stride_q_sq,
    stride_q_h,
    KV_cache,
    stride_kv_bs,
    Block_table,
    stride_bt_b,
    Seq_lens,
    Out,
    stride_o_b,
    stride_o_sq,
    stride_o_h,
    LSE,
    stride_lse_b,
    stride_lse_h,
    sm_scale,
    SQ,
    HQ: tl.constexpr,
    DQK: tl.constexpr,
    HEAD_DIM_V: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    HAVE_CAUSAL: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """
    Dense decode kernel with paged attention and online softmax.
    Grid: (ceil(HQ / BLOCK_H), batch_size * seq_q)

    Each program handles BLOCK_H heads of one (batch, seq_q) position and walks
    the first `Seq_lens[batch]` cache tokens in chunks of BLOCK_N, translating
    logical positions to cache rows through `Block_table`. The first
    HEAD_DIM_V dims of a cache row serve as both K (NoPE) and V; dims
    [HEAD_DIM_V:DQK] are the RoPE part and only enter the scores.

    A batch entry whose sequence length is 0 produces an all-zero output and
    lse = +inf, matching the sparse kernels in this file.

    NOTE(review): HAVE_CAUSAL is only used as an autotune key; the body never
    reads it, so no causal mask is applied here. Confirm that callers with
    causal=True and SQ > 1 handle masking elsewhere.
    """
    pid_h_block = tl.program_id(0)
    pid_b_sq = tl.program_id(1)
    i_b = pid_b_sq // SQ
    i_sq = pid_b_sq % SQ

    cur_head = pid_h_block * BLOCK_H + tl.arange(0, BLOCK_H)
    mask_head = cur_head < HQ

    # Load Q: NoPE part [BLOCK_H, HEAD_DIM_V] and RoPE part [BLOCK_H, DQK-HEAD_DIM_V]
    offs_d_nope = tl.arange(0, HEAD_DIM_V)
    offs_q_nope = (
        i_b * stride_q_b
        + i_sq * stride_q_sq
        + cur_head[:, None] * stride_q_h
        + offs_d_nope[None, :]
    )
    q_nope = tl.load(Q_ptr + offs_q_nope, mask=mask_head[:, None], other=0.0)

    offs_d_pe = tl.arange(HEAD_DIM_V, DQK)
    offs_q_pe = (
        i_b * stride_q_b
        + i_sq * stride_q_sq
        + cur_head[:, None] * stride_q_h
        + offs_d_pe[None, :]
    )
    q_pe = tl.load(Q_ptr + offs_q_pe, mask=mask_head[:, None], other=0.0)

    # Online softmax accumulators
    e_max = tl.full([BLOCK_H], value=float("-inf"), dtype=tl.float32)
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, HEAD_DIM_V], dtype=tl.float32)

    cur_batch_seq_len = tl.load(Seq_lens + i_b)
    Block_table += i_b * stride_bt_b

    # Full BLOCK_N chunks need no masking; the tail chunk is handled after.
    offs_n = tl.arange(0, BLOCK_N)
    loop_time = cur_batch_seq_len // BLOCK_N
    remainder = cur_batch_seq_len % BLOCK_N

    for i in range(0, loop_time):
        kv_page_number = tl.load(Block_table + offs_n // PAGE_SIZE)
        # Widen to int64 so `row * stride_kv_bs` cannot wrap for large caches.
        kv_loc = (kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE).to(tl.int64)

        # Load V (NoPE part)
        offs_v_c = kv_loc[:, None] * stride_kv_bs + offs_d_nope[None, :]
        v_c = tl.load(KV_cache + offs_v_c)
        k_c = tl.trans(v_c)

        # QK = q_nope @ k_nope^T
        qk = tl.dot(q_nope, k_c)

        # Add RoPE contribution
        offs_k_pe = kv_loc[None, :] * stride_kv_bs + offs_d_pe[:, None]
        k_pe = tl.load(KV_cache + offs_k_pe)
        qk = tl.dot(q_pe, k_pe, acc=qk)
        qk *= sm_scale

        # Online softmax update
        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        acc *= re_scale[:, None]
        acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc)
        e_sum = e_sum * re_scale + tl.sum(p, 1)
        e_max = n_e_max
        offs_n += BLOCK_N

    if remainder:
        # Tail chunk: positions at or beyond the sequence length are masked.
        mask_kvsplit = offs_n < cur_batch_seq_len
        kv_page_number = tl.load(
            Block_table + offs_n // PAGE_SIZE, mask=mask_kvsplit, other=0
        )
        kv_loc = (kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE).to(tl.int64)

        offs_v_c = kv_loc[:, None] * stride_kv_bs + offs_d_nope[None, :]
        v_c = tl.load(KV_cache + offs_v_c, mask=mask_kvsplit[:, None], other=0.0)
        k_c = tl.trans(v_c)

        qk = tl.dot(q_nope, k_c)

        offs_k_pe = kv_loc[None, :] * stride_kv_bs + offs_d_pe[:, None]
        k_pe = tl.load(KV_cache + offs_k_pe, mask=mask_kvsplit[None, :], other=0.0)
        qk = tl.dot(q_pe, k_pe, acc=qk)
        qk *= sm_scale

        qk = tl.where(mask_kvsplit[None, :], qk, float("-inf"))

        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        acc *= re_scale[:, None]
        acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc)
        e_sum = e_sum * re_scale + tl.sum(p, 1)
        e_max = n_e_max

    # With a zero-length sequence neither branch above runs: e_sum stays 0 and
    # e_max stays -inf, so acc / e_sum would be 0 / 0 = NaN. Emit 0 / +inf for
    # such rows instead, as the sparse kernels do.
    has_kv = e_max != float("-inf")

    # Store output
    offs_o = (
        i_b * stride_o_b
        + i_sq * stride_o_sq
        + cur_head[:, None] * stride_o_h
        + offs_d_nope[None, :]
    )
    out_vals = tl.where(has_kv[:, None], acc / e_sum[:, None], 0.0)
    tl.store(
        Out + offs_o,
        out_vals.to(Out.dtype.element_ty),
        mask=mask_head[:, None],
    )

    # Store LSE
    lse_val = tl.where(has_kv, e_max + tl.math.log(e_sum), float("inf"))
    lse_offset = i_b * stride_lse_b + cur_head * stride_lse_h + i_sq
    tl.store(LSE + lse_offset, lse_val, mask=mask_head)
820# ============================================================================
821# Main dispatch function
822# ============================================================================
825def flash_mla_with_kvcache(
826 q: torch.Tensor,
827 k_cache: torch.Tensor,
828 block_table: Optional[torch.Tensor],
829 cache_seqlens: Optional[torch.Tensor],
830 head_dim_v: int,
831 tile_scheduler_metadata: FlashMLASchedMeta,
832 num_splits: None = None,
833 softmax_scale: Optional[float] = None,
834 causal: bool = False,
835 is_fp8_kvcache: bool = False,
836 indices: Optional[torch.Tensor] = None,
837 attn_sink: Optional[torch.Tensor] = None,
838 extra_k_cache: Optional[torch.Tensor] = None,
839 extra_indices_in_kvcache: Optional[torch.Tensor] = None,
840 topk_length: Optional[torch.Tensor] = None,
841 extra_topk_length: Optional[torch.Tensor] = None,
842 out: Optional[torch.Tensor] = None,
843) -> Tuple[torch.Tensor, torch.Tensor]:
844 """
845 Triton implementation of flash_mla_with_kvcache.
846 Functionally equivalent to the CUDA implementation.
848 Returns:
849 out: (batch_size, seq_len_q, num_heads_q, head_dim_v)
850 softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32
851 """
852 sched_meta = tile_scheduler_metadata
853 assert isinstance(sched_meta, FlashMLASchedMeta)
854 assert num_splits is None
855 assert q.ndim == 4
856 assert k_cache.ndim == 4
858 topk = indices.shape[-1] if indices is not None else None
859 extra_k_page_block_size = (
860 extra_k_cache.shape[1] if extra_k_cache is not None else None
861 )
862 extra_topk_val = (
863 extra_indices_in_kvcache.shape[-1]
864 if extra_indices_in_kvcache is not None
865 else None
866 )
868 if softmax_scale is None:
869 softmax_scale = q.shape[-1] ** (-0.5)
871 if not sched_meta.have_initialized:
872 if indices is not None:
873 assert not causal, "causal must be False when sparse attention is enabled"
874 sched_meta.have_initialized = True
875 sched_meta.config = FlashMLASchedMeta.Config(
876 q.shape[0],
877 q.shape[1],
878 q.shape[2],
879 k_cache.shape[1],
880 k_cache.shape[2],
881 causal,
882 is_fp8_kvcache,
883 topk,
884 extra_k_page_block_size,
885 extra_topk_val,
886 )
887 else:
888 helper_msg = (
889 " Your input arguments are inconsistent with sched_meta. Please make "
890 "sure the input arguments are consistent across different invocations "
891 "of flash_mla_with_kvcache on the same sched_meta."
892 )
893 assert sched_meta.config is not None
894 assert sched_meta.config.b == q.shape[0], (
895 "sched_meta.config.b must be equal to batch_size." + helper_msg
896 )
897 assert sched_meta.config.s_q == q.shape[1], (
898 "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
899 )
900 assert sched_meta.config.h_q == q.shape[2], (
901 "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
902 )
903 assert sched_meta.config.page_block_size == k_cache.shape[1], (
904 "sched_meta.config.page_block_size must be equal to page_block_size."
905 + helper_msg
906 )
907 assert sched_meta.config.h_k == k_cache.shape[2], (
908 "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
909 )
910 assert sched_meta.config.causal == causal, (
911 "sched_meta.config.causal must be equal to causal." + helper_msg
912 )
913 assert sched_meta.config.is_fp8_kvcache == is_fp8_kvcache, (
914 "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache."
915 + helper_msg
916 )
917 assert sched_meta.config.topk == topk, (
918 "sched_meta.config.topk must be equal to the last dim of indices."
919 + helper_msg
920 )
921 assert sched_meta.config.extra_page_block_size == extra_k_page_block_size, (
922 "sched_meta.config.extra_page_block_size must be equal to the "
923 "page_block_size of extra_k_cache." + helper_msg
924 )
925 assert sched_meta.config.extra_topk == extra_topk_val, (
926 "sched_meta.config.extra_topk must be equal to the last dim of "
927 "extra_indices_in_kvcache." + helper_msg
928 )
930 batch_size, seq_q, num_heads_q, head_dim_k = q.shape
931 num_heads_k = k_cache.shape[2]
933 if out is None:
934 out = torch.empty(
935 (batch_size, seq_q, num_heads_q, head_dim_v),
936 dtype=q.dtype,
937 device=q.device,
938 )
939 else:
940 assert out.shape == (batch_size, seq_q, num_heads_q, head_dim_v)
941 assert out.dtype == q.dtype
942 assert out.device == q.device
943 assert out.stride(-1) == 1
944 lse = torch.empty(
945 (batch_size, num_heads_q, seq_q),
946 dtype=torch.float32,
947 device=q.device,
948 )
950 if indices is not None:
951 assert not causal, "causal must be False when sparse attention is enabled"
952 assert is_fp8_kvcache, "is_fp8_kvcache must be True for sparse attention"
953 assert (
954 num_heads_k == 1
955 ), "Currently only MQA (h_kv == 1) is supported for sparse decoding"
956 assert head_dim_v == 512, "Only head_size_v == 512 is supported"
957 assert num_heads_q in (64, 128), "Only h_q == 64 or 128 is supported"
958 assert head_dim_k in (
959 512,
960 576,
961 ), "Only head_size_k == 512 or 576 is supported for sparse decoding"
962 assert q.dtype == torch.bfloat16
963 assert k_cache.dtype in (torch.float8_e4m3fn, torch.int8, torch.uint8)
964 assert topk is not None and topk > 0
965 assert topk % 64 == 0, "topk must be divisible by 64"
966 assert indices.ndim == 3 and indices.shape[:2] == (batch_size, seq_q)
967 assert indices.dtype == torch.int32
968 assert indices.stride(-1) == 1
969 if topk_length is not None:
970 assert topk_length.shape == (batch_size,)
971 assert topk_length.dtype == torch.int32
972 assert topk_length.is_contiguous()
973 if attn_sink is not None:
974 assert attn_sink.shape == (num_heads_q,)
975 assert attn_sink.dtype == torch.float32
976 if extra_k_cache is not None:
977 assert extra_indices_in_kvcache is not None, (
978 "extra_indices_in_kvcache must be provided when extra_k_cache "
979 "is provided"
980 )
981 assert extra_k_cache.dtype in (
982 torch.float8_e4m3fn,
983 torch.int8,
984 torch.uint8,
985 )
986 else:
987 assert extra_indices_in_kvcache is None, (
988 "extra_indices_in_kvcache must not be provided when extra_k_cache "
989 "is not provided"
990 )
991 assert extra_topk_length is None, (
992 "extra_topk_length must not be provided when extra_k_cache is "
993 "not provided"
994 )
995 if extra_indices_in_kvcache is not None:
996 assert extra_indices_in_kvcache.ndim == 3
997 assert extra_indices_in_kvcache.shape[:2] == (batch_size, seq_q)
998 assert extra_indices_in_kvcache.dtype == torch.int32
999 assert extra_indices_in_kvcache.stride(-1) == 1
1000 assert extra_indices_in_kvcache.shape[-1] % 64 == 0
1001 if extra_topk_length is not None:
1002 assert extra_topk_length.shape == (batch_size,)
1003 assert extra_topk_length.dtype == torch.int32
1004 assert extra_topk_length.is_contiguous()
1005 if head_dim_k == 576:
1006 assert (
1007 k_cache.shape[-1] == 656
1008 ), "V32 sparse FP8 cache must use 656 bytes per token"
1009 assert (
1010 k_cache.stride(1) == 656
1011 ), "The whole block must be contiguous for V32 KV cache"
1012 assert topk_length is None, "V3.2/V32 does not support dynamic topk length"
1013 assert extra_k_cache is None, "V3.2/V32 does not support extra KV cache"
1014 assert (
1015 extra_indices_in_kvcache is None
1016 ), "V3.2/V32 does not support extra indices"
1017 assert (
1018 extra_topk_length is None
1019 ), "V3.2/V32 does not support extra topk length"
1020 else:
1021 assert (
1022 k_cache.shape[-1] == 584
1023 ), "MODEL1 sparse FP8 cache must use 584 bytes per token"
1024 assert (
1025 k_cache.stride(1) == 584
1026 ), "The whole block must be contiguous for MODEL1 KV cache"
1027 if extra_k_cache is not None:
1028 assert extra_k_cache.ndim == 4
1029 assert extra_k_cache.shape[2] == 1
1030 assert extra_k_cache.shape[-1] == 584
1031 assert extra_k_cache.stride(1) == 584
1032 _sparse_decode_dispatch(
1033 q,
1034 k_cache,
1035 indices,
1036 out,
1037 lse,
1038 attn_sink,
1039 topk_length,
1040 extra_k_cache,
1041 extra_indices_in_kvcache,
1042 extra_topk_length,
1043 batch_size,
1044 seq_q,
1045 num_heads_q,
1046 head_dim_k,
1047 head_dim_v,
1048 topk,
1049 k_cache.shape[1],
1050 softmax_scale,
1051 is_fp8_kvcache,
1052 )
1053 else:
1054 assert (
1055 attn_sink is None
1056 and extra_k_cache is None
1057 and extra_indices_in_kvcache is None
1058 and topk_length is None
1059 and extra_topk_length is None
1060 ), (
1061 "indices, attn_sink, extra_k_cache, extra_indices_in_kvcache, "
1062 "topk_length and extra_topk_length must be None when dense "
1063 "attention is used."
1064 )
1065 assert block_table is not None and cache_seqlens is not None, (
1066 "block_table and cache_seqlens must be provided when dense attention "
1067 "is used."
1068 )
1069 assert num_heads_k == 1, "Only num_heads_k == 1 is supported for dense MLA"
1070 if seq_q > 1 and causal:
1071 raise NotImplementedError(
1072 "causal dense attention with seq_q > 1 is not implemented"
1073 )
1074 _dense_decode_dispatch(
1075 q,
1076 k_cache,
1077 block_table,
1078 cache_seqlens,
1079 out,
1080 lse,
1081 batch_size,
1082 seq_q,
1083 num_heads_q,
1084 head_dim_k,
1085 head_dim_v,
1086 k_cache.shape[1],
1087 softmax_scale,
1088 causal,
1089 )
1091 return out, lse
1094# ============================================================================
1095# Kernel launch helpers
1096# ============================================================================
def _sparse_decode_dispatch(
    q,
    kv,
    indices,
    out,
    lse,
    attn_sink,
    topk_length,
    extra_kv,
    extra_indices,
    extra_topk_length,
    batch_size,
    seq_q,
    num_heads_q,
    head_dim_k,
    head_dim_v,
    topk,
    page_block_size,
    softmax_scale,
    is_fp8_kvcache,
):
    """Select and launch the Triton kernel for sparse (index-driven) decode.

    * ``head_dim_k == 512`` is routed to ``_sparse_decode_model1_kernel``,
      which receives the paged cache unchanged plus the optional extra cache
      and extra indices.
    * Any other head size is routed to ``_sparse_decode_kernel`` after the
      cache has been flattened to one row per token and, in FP8 mode, split
      into separate NoPE / scale / RoPE tensors.  The ``extra_*`` arguments
      are not forwarded on this path.

    Both kernels write into ``out`` and ``lse``; this helper returns ``None``.
    ``head_dim_v`` is accepted but not read here.
    """
    heads_per_program = 64
    head_blocks = -(-num_heads_q // heads_per_program)  # ceiling division
    launch_grid = (batch_size * seq_q * head_blocks,)

    if head_dim_k == 512:
        has_extra_kv = extra_kv is not None
        has_extra_indices = extra_indices is not None
        # When an optional tensor is absent, the primary tensor is passed in
        # its place; the boolean flags at the end of the argument list tell
        # the kernel which optional inputs are real.
        extra_kv_arg = extra_kv if has_extra_kv else kv
        extra_indices_arg = extra_indices if has_extra_indices else indices
        extra_page_block_size = extra_kv.shape[1] if has_extra_kv else 1
        extra_num_blocks = extra_kv.shape[0] if has_extra_kv else 0
        extra_topk = extra_indices.shape[-1] if has_extra_indices else 0

        _sparse_decode_model1_kernel[launch_grid](
            q,
            kv,
            indices,
            extra_kv_arg,
            extra_indices_arg,
            attn_sink,
            topk_length,
            extra_topk_length,
            softmax_scale,
            out,
            lse,
            # Q strides (dims 0..2)
            *q.stride()[:3],
            # KV and indices strides
            kv.stride(0),
            *indices.stride()[:2],
            extra_kv_arg.stride(0),
            *extra_indices_arg.stride()[:2],
            # Output strides (dims 0..2)
            *out.stride()[:3],
            # LSE strides (dims 0..1)
            *lse.stride()[:2],
            # Scalar args
            seq_q,
            num_heads_q,
            page_block_size,
            extra_page_block_size,
            kv.shape[0],
            extra_num_blocks,
            topk,
            extra_topk,
            attn_sink is not None,
            topk_length is not None,
            has_extra_kv,
            extra_topk_length is not None,
        )
        return

    # Total number of token slots in the paged cache.
    total_kv_tokens = kv.shape[0] * page_block_size

    if is_fp8_kvcache:
        # Packed per-token record of 656 one-byte elements:
        #   [0:512)   512 float8_e4m3fn values (NoPE)
        #   [512:528) 4 float32 scales (16 bytes)
        #   [528:656) 64 bfloat16 values (RoPE, 128 bytes)
        packed_rows = kv.reshape(-1, 656).contiguous()
        # Each slice is made contiguous before the dtype reinterpretation.
        kv_nope = packed_rows[:, :512].contiguous().view(torch.float8_e4m3fn)
        kv_scales = packed_rows[:, 512:528].contiguous().view(torch.float32)
        kv_rope = packed_rows[:, 528:656].contiguous().view(torch.bfloat16)
        nope_row_stride = kv_nope.stride(0)
        scales_row_stride = kv_scales.stride(0)
        rope_row_stride = kv_rope.stride(0)
    else:
        # Unquantized cache: one row of kv.shape[-1] elements per token.  The
        # scale / RoPE arguments are unused placeholders in this mode.
        kv_nope = kv.reshape(-1, kv.shape[-1]).contiguous()
        kv_scales = kv_nope
        kv_rope = kv_nope
        nope_row_stride = kv_nope.stride(0)
        scales_row_stride = 0
        rope_row_stride = 0

    # TODO: route eligible shapes through _tle_sparse_decode_launch (gated by
    # _can_use_tle_sparse_decode); the TLE warp-specialized path is currently
    # disabled for sparse decode.

    _sparse_decode_kernel[launch_grid](
        q,
        kv_nope,
        kv_scales,
        kv_rope,
        indices,
        attn_sink,
        topk_length,
        softmax_scale,
        out,
        lse,
        # Q strides (dims 0..2)
        *q.stride()[:3],
        # KV strides
        nope_row_stride,
        scales_row_stride,
        rope_row_stride,
        # Indices strides (dims 0..1)
        *indices.stride()[:2],
        # Output strides (dims 0..2)
        *out.stride()[:3],
        # LSE strides (dims 0..1)
        *lse.stride()[:2],
        # Scalar args
        seq_q,
        num_heads_q,
        head_dim_k,
        total_kv_tokens,
        topk,
        attn_sink is not None,
        topk_length is not None,
        is_fp8_kvcache,
    )
def _dense_decode_dispatch(
    q,
    kv_cache,
    block_table,
    cache_seqlens,
    out,
    lse,
    batch_size,
    seq_q,
    num_heads_q,
    head_dim_k,
    head_dim_v,
    page_block_size,
    softmax_scale,
    causal,
):
    """Launch the dense (paged, block-table driven) decode kernel.

    The paged cache is flattened to one ``head_dim_k``-wide row per token and
    handed either to the TLE warp-specialized launcher (when
    ``_can_use_tle_dense_decode`` accepts the configuration) or to
    ``_dense_decode_kernel``.  Results are written into ``out`` and ``lse``;
    nothing is returned.
    """
    # KV cache: [num_blocks, page_block_size, num_heads_k, head_dim_k].
    # Flatten to [num_tokens_total, head_dim_k] for paged access.  ``reshape``
    # (rather than ``view``) is used so that a cache whose strides cannot be
    # expressed as a view is copied instead of raising; the cache is only
    # read, so a copy is harmless.  This mirrors the sparse dispatcher.
    kv_flat = kv_cache.reshape(-1, head_dim_k).contiguous()
    block_table = block_table.contiguous()

    # TLE warp specialization path
    if _can_use_tle_dense_decode(q, kv_cache, block_table, head_dim_v, page_block_size):
        _tle_dense_decode_launch(
            q,
            kv_flat,
            block_table,
            cache_seqlens,
            out,
            lse,
            batch_size,
            seq_q,
            num_heads_q,
            head_dim_k,
            head_dim_v,
            page_block_size,
            softmax_scale,
            causal,
        )
        return

    # Fallback kernel: grid axis 0 walks blocks of 64 query heads, axis 1
    # walks the flattened (batch, query position) pairs.
    BLOCK_H = 64
    num_head_blocks = (num_heads_q + BLOCK_H - 1) // BLOCK_H
    grid = (num_head_blocks, batch_size * seq_q)

    _dense_decode_kernel[grid](
        q,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        kv_flat,
        kv_flat.stride(0),
        block_table,
        block_table.stride(0),
        cache_seqlens,
        out,
        out.stride(0),
        out.stride(1),
        out.stride(2),
        lse,
        lse.stride(0),
        lse.stride(1),
        softmax_scale,
        seq_q,
        num_heads_q,
        head_dim_k,
        head_dim_v,
        page_block_size,
        causal,
    )
1335# ============================================================================
1336# TLE Warp Specialization path for sparse decode
1337# ============================================================================
def _tle_decode_enabled() -> bool:
    """Report whether the TLE decode path is allowed by the environment.

    Reads ``FLAGGEMS_FLASHMLA_DECODE_TLE`` (default ``"1"``).  The path is
    disabled only when the value, compared case-insensitively, is one of
    ``0``, ``false``, ``off`` or ``no``; anything else leaves it enabled.
    """
    setting = os.environ.get("FLAGGEMS_FLASHMLA_DECODE_TLE", "1")
    disabling_values = ("0", "false", "off", "no")
    return setting.lower() not in disabling_values
def _can_use_tle_sparse_decode(
    q: torch.Tensor,
    indices: torch.Tensor,
    head_dim_v: int,
    head_dim_k: int,
    is_fp8: bool,
) -> bool:
    """Decide whether the TLE sparse decode kernel can serve this request.

    Requires TLE to be importable and enabled, ``q`` to live on a CUDA
    device, ``head_dim_v == 512``, a query head size of 512 or 576, a head
    count that is a multiple of ``TLE_DECODE_BH``, and a positive top-k
    (last dimension of ``indices``) that is a multiple of one block pair
    (``TLE_DECODE_BK * TLE_DECODE_PAIR_BLOCKS``).

    ``head_dim_k`` and ``is_fp8`` are accepted but not consulted; the head
    size is taken from ``q`` itself.
    """
    if not HAS_TLE or not _tle_decode_enabled():
        return False
    if q.device.type != "cuda":
        return False

    # q must be 4-D; only the last two dimensions matter for eligibility.
    _, _, n_heads, qk_dim = q.shape
    n_selected = indices.shape[-1]

    if head_dim_v != 512 or qk_dim not in (512, 576):
        return False
    if n_heads % TLE_DECODE_BH != 0:
        return False
    tokens_per_pair = TLE_DECODE_BK * TLE_DECODE_PAIR_BLOCKS
    return n_selected > 0 and n_selected % tokens_per_pair == 0
def _can_use_tle_dense_decode(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    block_table: torch.Tensor,
    head_dim_v: int,
    page_block_size: int,
) -> bool:
    """Decide whether the TLE dense decode kernel can serve this request.

    Requires TLE to be importable and enabled, ``q`` to live on a CUDA
    device, ``head_dim_v == 512``, a query head size of 512 or 576, a head
    count that is a multiple of ``TLE_DECODE_BH``, and a page size equal to
    ``TLE_DECODE_BK``.

    ``kv_cache`` and ``block_table`` are accepted but not inspected.
    """
    if not HAS_TLE or not _tle_decode_enabled():
        return False
    if q.device.type != "cuda":
        return False

    # q must be 4-D; only the last two dimensions matter for eligibility.
    _, _, n_heads, qk_dim = q.shape

    if head_dim_v != 512 or qk_dim not in (512, 576):
        return False
    if n_heads % TLE_DECODE_BH != 0:
        return False
    return page_block_size == TLE_DECODE_BK
def _set_triton_descriptor_allocator(device: torch.device) -> None:
    """Register a Triton allocator backed by ``torch.empty`` on ``device``.

    The installed callback returns an uninitialised ``int8`` tensor of the
    requested size; the alignment and stream arguments are ignored.
    Presumably Triton uses it for tensor-descriptor scratch space - confirm
    against the Triton version in use.
    """

    def alloc_fn(size: int, align: int, stream):
        # Alignment and stream hints are intentionally unused.
        del align, stream
        return torch.empty(size, dtype=torch.int8, device=device)

    triton.set_allocator(alloc_fn)
def _tle_sparse_decode_launch(
    q,
    kv_nope,
    kv_scales,
    kv_rope,
    indices,
    out,
    lse,
    attn_sink,
    topk_length,
    batch_size,
    seq_q,
    num_heads_q,
    head_dim_k,
    head_dim_v,
    topk,
    skv,
    softmax_scale,
    is_fp8_kvcache,
    stride_kvn,
    stride_scales_n,
    stride_rope_n,
):
    """Launch the TLE warp-specialized sparse decode kernel.

    ``q`` and ``out`` are flattened to one row per (batch, query position,
    head) and wrapped in tensor descriptors; the flattened KV tensors and
    their row strides are forwarded unchanged.  Results are written into
    ``out`` (and ``lse``); nothing is returned.

    If ``out`` is not contiguous the kernel writes into a dense staging
    buffer that is copied back into ``out`` afterwards, so the caller's
    tensor always receives the result.
    """
    from triton.tools.tensor_descriptor import TensorDescriptor

    _set_triton_descriptor_allocator(q.device)

    BH = TLE_DECODE_BH
    BK = TLE_DECODE_BK
    D = head_dim_v  # 512
    TD = head_dim_k - D  # 64 for DQK=576, 0 for DQK=512
    DP = triton.next_power_of_2(D)
    DPH = DP // 2
    HAVE_TAIL = TD > 0
    TDP = triton.next_power_of_2(TD) if HAVE_TAIL else 1
    G = num_heads_q
    RH = G // BH

    # One row per (batch, query position, head).
    num_rows = batch_size * seq_q * num_heads_q
    q_flat = q.reshape(num_rows, head_dim_k).contiguous()

    # The output descriptor below hard-codes row stride D, so the kernel must
    # write into dense memory.  A contiguous ``out`` is viewed in place;
    # otherwise a staging buffer is used and copied back after the launch
    # (a plain ``reshape`` could silently hand the kernel a private copy or a
    # padded-stride view).
    out_is_dense = out.is_contiguous()
    if out_is_dense:
        out_flat = out.view(num_rows, head_dim_v)
    else:
        out_flat = torch.empty(
            (num_rows, head_dim_v), dtype=out.dtype, device=out.device
        )

    q_desc = TensorDescriptor(
        q_flat,
        shape=[num_rows, head_dim_k],
        strides=[head_dim_k, 1],
        block_shape=[BH, DPH],
    )
    if HAVE_TAIL:
        # Separate descriptor with a narrower block for the RoPE tail columns.
        tq_desc = TensorDescriptor(
            q_flat,
            shape=[num_rows, head_dim_k],
            strides=[head_dim_k, 1],
            block_shape=[BH, TDP],
        )
    else:
        tq_desc = q_desc
    output_desc = TensorDescriptor(
        out_flat,
        shape=[num_rows, D],
        strides=[D, 1],
        block_shape=[BH, DPH],
    )

    # Grid: one program per (batch*seq_q, head_block)
    grid = (batch_size * seq_q * RH,)

    # The indices are passed flattened and contiguous as
    # [batch*seq_q, topk], so consecutive rows are ``topk`` elements apart.
    stride_isq = topk

    # NOTE(review): ``indices.stride(0)`` below is taken from the original
    # 3-D tensor while the kernel receives the flattened copy; the two agree
    # when seq_q == 1.  Confirm against the kernel signature for seq_q > 1.
    # NOTE(review): the caller allocates ``lse`` as
    # (batch, num_heads_q, seq_q); reshaping it to (batch*seq_q, num_heads_q)
    # reinterprets memory without transposing, which only matches when
    # seq_q == 1 - confirm the intended layout.
    _tle_sparse_decode_fwd[grid](
        q_desc,
        tq_desc,
        output_desc,
        kv_nope,
        kv_scales,
        kv_rope,
        indices.reshape(batch_size * seq_q, topk).contiguous(),
        attn_sink,
        topk_length,
        softmax_scale,
        out_flat,
        lse.reshape(batch_size * seq_q, num_heads_q).contiguous(),
        batch_size * seq_q,
        num_heads_q,
        head_dim_k,
        skv,
        topk,
        attn_sink is not None,
        topk_length is not None,
        is_fp8_kvcache,
        D,
        TD,
        DP,
        TDP,
        G,
        RH,
        HAVE_TAIL,
        BK,
        BH,
        TLE_DECODE_PAIR_BLOCKS,
        stride_kvn,
        stride_scales_n,
        stride_rope_n,
        indices.stride(0),
        stride_isq,
        num_warps=TLE_DECODE_WORKER_NUM_WARPS,
        num_stages=1,
    )

    if not out_is_dense:
        # Propagate the staged result into the caller's strided tensor.
        out.copy_(out_flat.view(out.shape))
1514if HAS_TLE:
@triton.jit
def _tle_sparse_decode_producer(
    k0_l_writer,
    k0_r_writer,
    k1_l_writer,
    k1_r_writer,
    valid_writer,
    kv_nope_base,
    kv_scales_base,
    kv_rope_base,
    t_base,
    topk_len_ptr,
    D: tl.constexpr,
    TD: tl.constexpr,
    DPH: tl.constexpr,
    TDP: tl.constexpr,
    SKV,
    TOPK: tl.constexpr,
    HAVE_TOPK_LENGTH: tl.constexpr,
    HAVE_TAIL: tl.constexpr,
    IS_FP8: tl.constexpr,
    BK: tl.constexpr,
    stride_kvn,
    stride_scales_n,
    stride_rope_n,
):
    """
    Producer warpgroup: loads KV data from global memory to shared memory.
    For FP8 mode: loads FP8 NoPE + scales + RoPE, dequantizes FP8 to BF16.
    For BF16 mode: loads BF16 KV directly.

    Selected tokens are processed two BK-sized blocks at a time ("block 0"
    and "block 1" of a pair).  For every pair the producer fills, in this
    order, the slots k0_l, k1_r (+ tail), k0_r (+ tail), k1_l and finally the
    per-token validity flags.  Each slot is obtained with ``acquire(pair)``
    and published with ``commit(pair)``.

    Pointer / scalar arguments:
        kv_nope_base, kv_scales_base, kv_rope_base: bases of the flattened
            per-token KV tensors; rows are addressed as
            ``token_id * stride_*``.
        t_base: base pointer of this program's token-index row.
        topk_len_ptr: pointer to the dynamic token count; only dereferenced
            when HAVE_TOPK_LENGTH is set, otherwise TOPK is used.
        SKV: number of addressable tokens; ids outside [0, SKV - 1] are
            treated as invalid.
        D / TD: width of the main part and of the tail part of a K row.
        DPH: width of one half ("left"/"right") of the main part.
        TDP: width of the index range used for the tail (masked to TD).
    """
    topk_len = tl.load(topk_len_ptr) if HAVE_TOPK_LENGTH else TOPK
    max_col = SKV - 1
    NK = tl.cdiv(topk_len, BK)  # number of BK-token blocks
    NPAIRS = tl.cdiv(NK, 2)  # blocks are consumed two at a time
    offs_t = tl.arange(0, BK)
    # Columns are copied in fixed 64-wide tiles.
    offs_tile = tl.arange(0, 64)
    kv_tile_rows = tl.broadcast_to(offs_t[:, None], (BK, 64))

    for pair in tl.range(NPAIRS):
        ck0 = pair * 2
        ck1 = ck0 + 1

        # Load indices for both blocks
        # Positions past topk_len read -1, which fails the range check below;
        # when NK is odd the whole second block of the last pair is invalid.
        t_offs0 = BK * ck0 + offs_t
        t_msk0 = t_offs0 < topk_len
        kv_ids0 = tl.load(t_base + t_offs0, t_msk0, other=-1)
        valid0 = t_msk0 & (kv_ids0 <= max_col) & (kv_ids0 >= 0)

        t_offs1 = BK * ck1 + offs_t
        t_msk1 = t_offs1 < topk_len
        kv_ids1 = tl.load(t_base + t_offs1, t_msk1, other=-1)
        valid1 = t_msk1 & (kv_ids1 <= max_col) & (kv_ids1 >= 0)

        # NOTE(review): every shared-memory store below is masked by the
        # validity flags, so rows of invalid tokens keep whatever the slot
        # held before.  The consumers zero the matching probabilities via the
        # validity mask - confirm stale rows can never contain NaN/Inf.

        # Process k0_l (left half of block 0)
        k0_l_slot = k0_l_writer.acquire(pair)
        for tile in tl.static_range(0, DPH, 64):
            k_cols = tile + offs_tile
            k_cols_b = tl.broadcast_to(k_cols[None, :], (BK, 64))

            if IS_FP8:
                # Load FP8 data
                kv_ptr = (
                    kv_nope_base + k_cols[None, :] + kv_ids0[:, None] * stride_kvn
                )
                k0_l_msk = valid0[:, None] & (k_cols < D)[None, :]
                k0_l_fp8 = tl.load(
                    kv_ptr, mask=k0_l_msk, other=0.0, eviction_policy="evict_last"
                )

                # Load scales for dequantization
                # Each 128 elements share one scale
                # A 64-wide tile never straddles a 128-column group, so
                # scale_idx holds one repeated value; it is added elementwise
                # to the per-token offsets (the shapes line up because the
                # tile width equals BK == 64 - TODO confirm BK is always 64).
                scale_idx = k_cols // 128  # 0 or 1 for left half
                scale0 = tl.load(
                    kv_scales_base + kv_ids0 * stride_scales_n + scale_idx,
                    mask=valid0,
                    other=1.0,
                )

                # Dequantize: FP8 * scale -> BF16
                k0_l_blk = (k0_l_fp8.to(tl.float32) * scale0[:, None]).to(
                    tl.bfloat16
                )
            else:
                # BF16 mode: load directly
                kv_ptr = (
                    kv_nope_base + k_cols[None, :] + kv_ids0[:, None] * stride_kvn
                )
                k0_l_msk = valid0[:, None] & (k_cols < D)[None, :]
                k0_l_blk = tl.load(
                    kv_ptr, mask=k0_l_msk, other=0.0, eviction_policy="evict_last"
                )

            tl.store(
                tle.gpu.local_ptr(k0_l_slot.sK, (kv_tile_rows, k_cols_b)),
                k0_l_blk,
                mask=valid0[:, None] & (k_cols < D)[None, :],
            )
        k0_l_writer.commit(pair)

        # Process k1_r (right half of block 1)
        k1_r_slot = k1_r_writer.acquire(pair)
        for tile in tl.static_range(0, DPH, 64):
            # Right-half columns start at DPH.
            k_cols = DPH + tile + offs_tile
            k_cols_b = tl.broadcast_to(k_cols[None, :], (BK, 64))

            if IS_FP8:
                kv_ptr = (
                    kv_nope_base + k_cols[None, :] + kv_ids1[:, None] * stride_kvn
                )
                k1_r_msk = valid1[:, None] & (k_cols < D)[None, :]
                k1_r_fp8 = tl.load(
                    kv_ptr, mask=k1_r_msk, other=0.0, eviction_policy="evict_last"
                )

                # Scale index: 2 or 3 for right half
                # (the hard-coded 2 assumes the left half spans two
                # 128-column scale groups, i.e. DPH == 256 - TODO confirm)
                scale_idx = 2 + (k_cols - DPH) // 128
                scale1 = tl.load(
                    kv_scales_base + kv_ids1 * stride_scales_n + scale_idx,
                    mask=valid1,
                    other=1.0,
                )

                k1_r_blk = (k1_r_fp8.to(tl.float32) * scale1[:, None]).to(
                    tl.bfloat16
                )
            else:
                kv_ptr = (
                    kv_nope_base + k_cols[None, :] + kv_ids1[:, None] * stride_kvn
                )
                k1_r_msk = valid1[:, None] & (k_cols < D)[None, :]
                k1_r_blk = tl.load(
                    kv_ptr, mask=k1_r_msk, other=0.0, eviction_policy="evict_last"
                )

            tl.store(
                tle.gpu.local_ptr(k1_r_slot.sK, (kv_tile_rows, k_cols_b)),
                k1_r_blk,
                mask=valid1[:, None] & (k_cols < D)[None, :],
            )

        # Load RoPE tail if needed
        # The tail travels in the same slot as the right half.  In FP8 mode
        # it comes from the separate RoPE tensor; otherwise it is the columns
        # that follow the first D entries of the same KV row.
        if HAVE_TAIL:
            offs_td = tl.arange(0, TDP)
            if IS_FP8:
                k1_r_tail_ptr = (
                    kv_rope_base
                    + offs_td[None, :]
                    + kv_ids1[:, None] * stride_rope_n
                )
            else:
                k1_r_tail_ptr = (
                    kv_nope_base
                    + D
                    + offs_td[None, :]
                    + kv_ids1[:, None] * stride_kvn
                )
            k1_r_tail_msk = valid1[:, None] & (offs_td < TD)[None, :]
            k1_r_tail_blk = tl.load(
                k1_r_tail_ptr,
                mask=k1_r_tail_msk,
                other=0.0,
                eviction_policy="evict_last",
            )
            tl.store(
                tle.gpu.local_ptr(k1_r_slot.sK_tail),
                k1_r_tail_blk,
                mask=k1_r_tail_msk,
            )
        k1_r_writer.commit(pair)

        # Process k0_r (right half of block 0)
        k0_r_slot = k0_r_writer.acquire(pair)
        for tile in tl.static_range(0, DPH, 64):
            k_cols = DPH + tile + offs_tile
            k_cols_b = tl.broadcast_to(k_cols[None, :], (BK, 64))

            if IS_FP8:
                kv_ptr = (
                    kv_nope_base + k_cols[None, :] + kv_ids0[:, None] * stride_kvn
                )
                k0_r_msk = valid0[:, None] & (k_cols < D)[None, :]
                k0_r_fp8 = tl.load(
                    kv_ptr, mask=k0_r_msk, other=0.0, eviction_policy="evict_last"
                )

                # Same right-half scale indexing as for k1_r above.
                scale_idx = 2 + (k_cols - DPH) // 128
                scale0 = tl.load(
                    kv_scales_base + kv_ids0 * stride_scales_n + scale_idx,
                    mask=valid0,
                    other=1.0,
                )

                k0_r_blk = (k0_r_fp8.to(tl.float32) * scale0[:, None]).to(
                    tl.bfloat16
                )
            else:
                kv_ptr = (
                    kv_nope_base + k_cols[None, :] + kv_ids0[:, None] * stride_kvn
                )
                k0_r_msk = valid0[:, None] & (k_cols < D)[None, :]
                k0_r_blk = tl.load(
                    kv_ptr, mask=k0_r_msk, other=0.0, eviction_policy="evict_last"
                )

            tl.store(
                tle.gpu.local_ptr(k0_r_slot.sK, (kv_tile_rows, k_cols_b)),
                k0_r_blk,
                mask=valid0[:, None] & (k_cols < D)[None, :],
            )

        # Tail of block 0, handled exactly like the tail of block 1 above.
        if HAVE_TAIL:
            offs_td = tl.arange(0, TDP)
            if IS_FP8:
                k0_r_tail_ptr = (
                    kv_rope_base
                    + offs_td[None, :]
                    + kv_ids0[:, None] * stride_rope_n
                )
            else:
                k0_r_tail_ptr = (
                    kv_nope_base
                    + D
                    + offs_td[None, :]
                    + kv_ids0[:, None] * stride_kvn
                )
            k0_r_tail_msk = valid0[:, None] & (offs_td < TD)[None, :]
            k0_r_tail_blk = tl.load(
                k0_r_tail_ptr,
                mask=k0_r_tail_msk,
                other=0.0,
                eviction_policy="evict_last",
            )
            tl.store(
                tle.gpu.local_ptr(k0_r_slot.sK_tail),
                k0_r_tail_blk,
                mask=k0_r_tail_msk,
            )
        k0_r_writer.commit(pair)

        # Process k1_l (left half of block 1)
        k1_l_slot = k1_l_writer.acquire(pair)
        for tile in tl.static_range(0, DPH, 64):
            k_cols = tile + offs_tile
            k_cols_b = tl.broadcast_to(k_cols[None, :], (BK, 64))

            if IS_FP8:
                kv_ptr = (
                    kv_nope_base + k_cols[None, :] + kv_ids1[:, None] * stride_kvn
                )
                k1_l_msk = valid1[:, None] & (k_cols < D)[None, :]
                k1_l_fp8 = tl.load(
                    kv_ptr, mask=k1_l_msk, other=0.0, eviction_policy="evict_last"
                )

                # Same left-half scale indexing as for k0_l above.
                scale_idx = k_cols // 128
                scale1 = tl.load(
                    kv_scales_base + kv_ids1 * stride_scales_n + scale_idx,
                    mask=valid1,
                    other=1.0,
                )

                k1_l_blk = (k1_l_fp8.to(tl.float32) * scale1[:, None]).to(
                    tl.bfloat16
                )
            else:
                kv_ptr = (
                    kv_nope_base + k_cols[None, :] + kv_ids1[:, None] * stride_kvn
                )
                k1_l_msk = valid1[:, None] & (k_cols < D)[None, :]
                k1_l_blk = tl.load(
                    kv_ptr, mask=k1_l_msk, other=0.0, eviction_policy="evict_last"
                )

            tl.store(
                tle.gpu.local_ptr(k1_l_slot.sK, (kv_tile_rows, k_cols_b)),
                k1_l_blk,
                mask=valid1[:, None] & (k_cols < D)[None, :],
            )
        k1_l_writer.commit(pair)

        # Store validity masks
        # Row 0 of is_kv_valid holds the flags of block 0, row 1 those of
        # block 1, each stored as int8.
        valid_slot = valid_writer.acquire(pair)
        valid_row0 = tl.full([BK], 0, dtype=tl.int32)
        valid_row1 = tl.full([BK], 1, dtype=tl.int32)
        valid_ptr0 = tle.gpu.local_ptr(valid_slot.is_kv_valid, (valid_row0, offs_t))
        valid_ptr1 = tle.gpu.local_ptr(valid_slot.is_kv_valid, (valid_row1, offs_t))
        tl.store(valid_ptr0, valid0.to(tl.int8))
        tl.store(valid_ptr1, valid1.to(tl.int8))
        valid_writer.commit(pair)
@triton.jit
def _tle_sparse_decode_consumer0(
    q_writer,
    q_reader,
    q_desc,
    tq_desc,
    k0_l_reader,
    k0_r_qk_reader,
    k1_l_remote_reader,
    valid_reader,
    sM_wg0_writer,
    sM_wg1_reader,
    sS0_writer,
    sS1_reader,
    sL_wg0_writer,
    sL_wg1_reader,
    output_desc,
    output_row,
    h_base,
    topk_len_ptr,
    attn_sink_base,
    log_scale: tl.constexpr,
    D: tl.constexpr,
    TD: tl.constexpr,
    OUT_DTYPE: tl.constexpr,
    HAVE_ATTN_SINK: tl.constexpr,
    TOPK: tl.constexpr,
    HAVE_TOPK_LENGTH: tl.constexpr,
    HAVE_TAIL: tl.constexpr,
    BK: tl.constexpr,
    BH: tl.constexpr,
    DPH: tl.constexpr,
    TDP: tl.constexpr,
    G: tl.constexpr,
):
    """Consumer 0: computes QK^T + online softmax + P@V_left.

    Per pair of token blocks this consumer:
      1. scores block 0 (Q against k0_l, k0_r and, if present, the tail),
      2. publishes its running row maximum and receives the maximum merged
         with block 1 from consumer 1,
      3. accumulates the left output half from block 0, rescales it to the
         merged maximum, and sends the rescaled block-0 probabilities to
         consumer 1,
      4. receives the block-1 probabilities and accumulates k1_l.

    It also copies Q into shared memory once at the start.  After the loop
    the two consumers exchange their partial softmax denominators and this
    one writes columns [0, DPH) of the output rows starting at
    ``output_row``.

    ``log_scale`` multiplies scores before ``exp2``; presumably it is the
    softmax scale times log2(e) - confirm in the launching kernel.
    """
    topk_len = tl.load(topk_len_ptr) if HAVE_TOPK_LENGTH else TOPK
    offs_h = tl.arange(0, BH)
    offs_dh = tl.arange(0, DPH)
    mask_h = h_base + offs_h < G
    mask_od_l = offs_dh < D
    # Index grids for addressing a (BK, DPH) half tile inside a K slot.
    kv_rows = tl.broadcast_to(tl.arange(0, BK)[:, None], (BK, DPH))
    kv_cols_l = tl.broadcast_to(offs_dh[None, :], (BK, DPH))
    kv_cols_r = tl.broadcast_to((DPH + offs_dh)[None, :], (BK, DPH))

    # Load Q into shared memory (one-shot)
    q_write_slot = q_writer.acquire(0)
    tle.gpu.copy(q_desc, q_write_slot.sQ_l, [BH, DPH], [output_row, 0])
    tle.gpu.copy(q_desc, q_write_slot.sQ_r, [BH, DPH], [output_row, DPH])
    if HAVE_TAIL:
        tle.gpu.copy(tq_desc, q_write_slot.sQ_tail, [BH, TDP], [output_row, D])
    q_writer.commit(0)

    q_slot = q_reader.wait(0).slot
    q_l_smem_ptr = tle.gpu.local_ptr(q_slot.sQ_l)
    q_r_smem_ptr = tle.gpu.local_ptr(q_slot.sQ_r)

    # Online-softmax state.  The running maximum starts at a large finite
    # negative value rather than -inf.
    max_prev = tl.full([BH], -1.0e30, dtype=tl.float32)
    sum_exp = tl.full([BH], 0.0, dtype=tl.float32)
    acc_l = tl.zeros([BH, DPH], dtype=tl.float32)

    NK = tl.cdiv(topk_len, BK)
    NPAIRS = tl.cdiv(NK, 2)
    for pair in tl.range(NPAIRS):
        # Wait for k0_l data
        k0_l_wait = k0_l_reader.wait(pair)
        k0_l_slot = k0_l_wait.slot

        q_l_blk = tl.load(q_l_smem_ptr)
        q_r_blk = tl.load(q_r_smem_ptr)
        k0_l_blk = tl.load(tle.gpu.local_ptr(k0_l_slot.sK, (kv_rows, kv_cols_l)))

        # QK for block 0: q_l @ k0_l^T + q_r @ k0_r^T + q_tail @ k0_tail^T
        qk0 = tl.full([BH, BK], 0.0, dtype=tl.float32)
        qk0 = tl.dot(q_l_blk, tl.trans(k0_l_blk), qk0, out_dtype=tl.float32)

        # Wait for k0_r
        k0_r_wait = k0_r_qk_reader.wait(pair)
        k0_r_slot = k0_r_wait.slot
        k0_r_blk = tl.load(tle.gpu.local_ptr(k0_r_slot.sK, (kv_rows, kv_cols_r)))
        qk0 = tl.dot(q_r_blk, tl.trans(k0_r_blk), qk0, out_dtype=tl.float32)

        if HAVE_TAIL:
            q_tail_blk = tl.load(tle.gpu.local_ptr(q_slot.sQ_tail))
            k0_t_blk = tl.load(tle.gpu.local_ptr(k0_r_slot.sK_tail))
            qk0 = tl.dot(q_tail_blk, tl.trans(k0_t_blk), qk0, out_dtype=tl.float32)

        # Get validity mask for block 0
        # (row 0 of the producer's is_kv_valid buffer)
        valid_wait = valid_reader.wait(pair)
        row0 = tl.full([BK], 0, dtype=tl.int32)
        valid0 = (
            tl.load(
                tle.gpu.local_ptr(
                    valid_wait.slot.is_kv_valid, (row0, tl.arange(0, BK))
                )
            ).to(tl.int32)
            == 1
        )

        # Invalid tokens get -inf so their probability becomes exactly 0.
        qk0 = tl.where(valid0[None, :], qk0, float("-inf"))

        # Compute local softmax for block 0 only
        local_max = tl.maximum(max_prev, tl.max(qk0, axis=1))
        alpha = tl.math.exp2((max_prev - local_max) * log_scale)
        prob0 = tl.math.exp2(qk0 * log_scale - local_max[:, None] * log_scale)
        sum_exp = sum_exp * alpha + tl.sum(prob0, axis=1)
        acc_l = acc_l * alpha[:, None]
        prob0_b = prob0.to(OUT_DTYPE)

        # Send local_max to consumer1
        sM_wg0_slot = sM_wg0_writer.acquire(pair)
        tl.store(tle.gpu.local_ptr(sM_wg0_slot.sM), local_max)
        sM_wg0_writer.commit(pair)

        # Accumulate P@V_left with prob0
        # The left half of the K tile is re-read from the slot and used as
        # the value matrix.
        k0_l_blk = tl.load(tle.gpu.local_ptr(k0_l_slot.sK, (kv_rows, kv_cols_l)))
        acc_l = tl.dot(prob0_b, k0_l_blk, acc_l, out_dtype=tl.float32)
        k0_l_reader.release(pair)
        k0_r_qk_reader.release(pair)

        # Wait for max_next from consumer1 (merged max of block0 and block1)
        sM_wg1_wait = sM_wg1_reader.wait(pair)
        max_next = tl.load(tle.gpu.local_ptr(sM_wg1_wait.slot.sM))
        sM_wg1_reader.release(pair)

        # Rescale prob0 and acc_l using the global max
        final_scale = tl.math.exp2((local_max - max_next) * log_scale)
        sum_exp = sum_exp * final_scale
        acc_l = acc_l * final_scale[:, None]

        # Send rescaled prob0 to consumer1
        prob0_scaled = prob0 * final_scale[:, None]
        sS0_slot = sS0_writer.acquire(pair)
        tl.store(tle.gpu.local_ptr(sS0_slot.sS0), prob0_scaled.to(OUT_DTYPE))
        sS0_writer.commit(pair)

        # Receive prob1 from consumer1 and accumulate k1_l
        sS1_wait = sS1_reader.wait(pair)
        prob1 = tl.load(tle.gpu.local_ptr(sS1_wait.slot.sS1))
        k1_l_wait = k1_l_remote_reader.wait(pair)
        k1_l_blk = tl.load(
            tle.gpu.local_ptr(k1_l_wait.slot.sK, (kv_rows, kv_cols_l))
        )
        acc_l = tl.dot(prob1, k1_l_blk, acc_l, out_dtype=tl.float32)
        sS1_reader.release(pair)
        k1_l_remote_reader.release(pair)

        valid_reader.release(pair)

        max_prev = max_next

    # Exchange final sum_exp with consumer1
    # sum_exp here only covers the block-0 probabilities of every pair;
    # consumer 1 holds the block-1 share.  This side publishes with index 0
    # and waits on index 1, mirroring consumer 1.
    sL_wg0_slot = sL_wg0_writer.acquire(0)
    tl.store(tle.gpu.local_ptr(sL_wg0_slot.sL), sum_exp)
    sL_wg0_writer.commit(0)
    sL_wg1_wait = sL_wg1_reader.wait(1)
    peer_sum = tl.load(tle.gpu.local_ptr(sL_wg1_wait.slot.sL))
    total_sum = sum_exp + peer_sum
    sL_wg1_reader.release(1)

    # Rows without any valid token have a zero denominator; their output is
    # forced to 0 below instead of the inf/NaN the division would produce.
    is_no_valid_tokens = total_sum == 0.0
    inv_total_sum = tl.fdiv(1.0, total_sum)
    out_l_vals = acc_l * inv_total_sum[:, None]
    if HAVE_ATTN_SINK:
        # fin_log is the natural-log sum-exp: the base-2 quantity times ln(2).
        fin_log = (
            max_prev * log_scale + tl.math.log2(total_sum)
        ) * 0.6931471805599453
        sink = tl.load(attn_sink_base + h_base + offs_h, mask_h, other=0.0)
        # Equivalent to scaling by exp(lse) / (exp(lse) + exp(sink)).
        sink_scale = tl.fdiv(1.0, 1.0 + tl.math.exp(sink - fin_log))
        out_l_vals = out_l_vals * sink_scale[:, None]
    out_l_vals = tl.where(is_no_valid_tokens[:, None], 0.0, out_l_vals)
    o_l_msk = mask_h[:, None] & mask_od_l[None, :]
    # The Q_left buffer is reused as staging space for the output tile before
    # it is copied out through the output descriptor.
    tl.store(q_l_smem_ptr, out_l_vals.to(OUT_DTYPE), o_l_msk)
    tle.gpu.copy(q_slot.sQ_l, output_desc, [BH, DPH], [output_row, 0])
@triton.jit
def _tle_sparse_decode_consumer1(
    q_reader,
    k1_r_reader,
    k1_l_qk_reader,
    k0_r_remote_reader,
    valid_reader,
    sM_wg1_writer,
    sM_wg0_reader,
    sS1_writer,
    sS0_reader,
    sL_wg1_writer,
    sL_wg0_reader,
    output_desc,
    output_row,
    lse_base,
    h_base,
    topk_len_ptr,
    attn_sink_base,
    log_scale: tl.constexpr,
    D: tl.constexpr,
    TD: tl.constexpr,
    OUT_DTYPE: tl.constexpr,
    HAVE_ATTN_SINK: tl.constexpr,
    TOPK: tl.constexpr,
    HAVE_TOPK_LENGTH: tl.constexpr,
    HAVE_TAIL: tl.constexpr,
    BK: tl.constexpr,
    BH: tl.constexpr,
    DPH: tl.constexpr,
    TDP: tl.constexpr,
    G: tl.constexpr,
):
    """Consumer 1: computes P@V_right, exchanges softmax state with consumer0.

    Per pair of token blocks this consumer:
      1. scores block 1 (Q against k1_r, the optional tail, and k1_l),
      2. merges its row maximum with the one received from consumer 0 and
         sends the merged maximum back,
      3. forms the block-1 probabilities against the merged maximum,
         accumulates the right output half with k1_r and sends the
         probabilities to consumer 0,
      4. receives the rescaled block-0 probabilities and accumulates k0_r.

    After the loop the partial softmax denominators are exchanged and this
    consumer writes columns [DPH, 2*DPH) of the output rows starting at
    ``output_row`` as well as the log-sum-exp values at ``lse_base``.

    ``log_scale`` multiplies scores before ``exp2``; presumably it is the
    softmax scale times log2(e) - confirm in the launching kernel.
    """
    topk_len = tl.load(topk_len_ptr) if HAVE_TOPK_LENGTH else TOPK
    offs_h = tl.arange(0, BH)
    offs_dh = tl.arange(0, DPH)
    mask_h = h_base + offs_h < G
    mask_od_r = DPH + offs_dh < D
    # Index grids for addressing a (BK, DPH) half tile inside a K slot.
    kv_rows = tl.broadcast_to(tl.arange(0, BK)[:, None], (BK, DPH))
    kv_cols_l = tl.broadcast_to(offs_dh[None, :], (BK, DPH))
    kv_cols_r = tl.broadcast_to((DPH + offs_dh)[None, :], (BK, DPH))

    # Q is written by consumer 0; this side only waits for it.
    q_slot = q_reader.wait(0).slot
    q_l_smem_ptr = tle.gpu.local_ptr(q_slot.sQ_l)
    q_r_smem_ptr = tle.gpu.local_ptr(q_slot.sQ_r)

    # Online-softmax state.  The running maximum starts at a large finite
    # negative value rather than -inf.
    max_prev = tl.full([BH], -1.0e30, dtype=tl.float32)
    sum_exp = tl.full([BH], 0.0, dtype=tl.float32)
    acc_r = tl.zeros([BH, DPH], dtype=tl.float32)

    NK = tl.cdiv(topk_len, BK)
    NPAIRS = tl.cdiv(NK, 2)
    for pair in tl.range(NPAIRS):
        # Wait for k1_r data
        k1_r_wait = k1_r_reader.wait(pair)
        k1_r_slot = k1_r_wait.slot

        q_l_blk = tl.load(q_l_smem_ptr)
        q_r_blk = tl.load(q_r_smem_ptr)
        k1_r_blk = tl.load(tle.gpu.local_ptr(k1_r_slot.sK, (kv_rows, kv_cols_r)))

        # QK for block 1
        # Right half and tail first (they share the k1_r slot), then the
        # left half once its slot is ready.
        qk1 = tl.full([BH, BK], 0.0, dtype=tl.float32)
        qk1 = tl.dot(q_r_blk, tl.trans(k1_r_blk), qk1, out_dtype=tl.float32)
        if HAVE_TAIL:
            q_tail_blk = tl.load(tle.gpu.local_ptr(q_slot.sQ_tail))
            k1_t_blk = tl.load(tle.gpu.local_ptr(k1_r_slot.sK_tail))
            qk1 = tl.dot(q_tail_blk, tl.trans(k1_t_blk), qk1, out_dtype=tl.float32)

        k1_l_wait = k1_l_qk_reader.wait(pair)
        k1_l_slot = k1_l_wait.slot
        k1_l_blk = tl.load(tle.gpu.local_ptr(k1_l_slot.sK, (kv_rows, kv_cols_l)))
        qk1 = tl.dot(q_l_blk, tl.trans(k1_l_blk), qk1, out_dtype=tl.float32)

        # Get validity mask for block 1
        # (row 1 of the producer's is_kv_valid buffer)
        valid_wait = valid_reader.wait(pair)
        row1 = tl.full([BK], 1, dtype=tl.int32)
        valid1 = (
            tl.load(
                tle.gpu.local_ptr(
                    valid_wait.slot.is_kv_valid, (row1, tl.arange(0, BK))
                )
            ).to(tl.int32)
            == 1
        )

        # Invalid tokens get -inf so their probability becomes exactly 0.
        qk1 = tl.where(valid1[None, :], qk1, float("-inf"))
        valid_reader.release(pair)

        # Receive candidate0 (local_max) from consumer0
        sM_wg0_wait = sM_wg0_reader.wait(pair)
        candidate0 = tl.load(tle.gpu.local_ptr(sM_wg0_wait.slot.sM))
        sM_wg0_reader.release(pair)

        # Compute candidate1 and merge to get global max_next
        candidate1 = tl.maximum(max_prev, tl.max(qk1, axis=1))
        max_next = tl.maximum(candidate1, candidate0)

        # Send max_next back to consumer0
        sM_wg1_slot = sM_wg1_writer.acquire(pair)
        tl.store(tle.gpu.local_ptr(sM_wg1_slot.sM), max_next)
        sM_wg1_writer.commit(pair)

        # Compute prob1 using global max_next
        alpha = tl.math.exp2((max_prev - max_next) * log_scale)
        prob1 = tl.math.exp2(qk1 * log_scale - max_next[:, None] * log_scale)
        sum_exp = sum_exp * alpha + tl.sum(prob1, axis=1)
        acc_r = acc_r * alpha[:, None]
        prob1_b = prob1.to(OUT_DTYPE)

        k1_l_qk_reader.release(pair)

        # Accumulate P@V_right with prob1
        # The right half of the K tile loaded above is used as the value
        # matrix.
        acc_r = tl.dot(prob1_b, k1_r_blk, acc_r, out_dtype=tl.float32)

        # Send prob1 to consumer0
        sS1_slot = sS1_writer.acquire(pair)
        tl.store(tle.gpu.local_ptr(sS1_slot.sS1), prob1_b)
        sS1_writer.commit(pair)

        # Receive rescaled prob0 from consumer0 and accumulate k0_r
        sS0_wait = sS0_reader.wait(pair)
        prob0 = tl.load(tle.gpu.local_ptr(sS0_wait.slot.sS0))
        k0_r_wait = k0_r_remote_reader.wait(pair)
        k0_r_blk = tl.load(
            tle.gpu.local_ptr(k0_r_wait.slot.sK, (kv_rows, kv_cols_r))
        )
        acc_r = tl.dot(prob0, k0_r_blk, acc_r, out_dtype=tl.float32)
        k1_r_reader.release(pair)
        sS0_reader.release(pair)
        k0_r_remote_reader.release(pair)

        max_prev = max_next

    # Exchange final sum_exp with consumer0
    # sum_exp here only covers the block-1 probabilities of every pair;
    # consumer 0 holds the block-0 share.  This side publishes with index 1
    # and waits on index 0, mirroring consumer 0.
    sL_wg1_slot = sL_wg1_writer.acquire(1)
    tl.store(tle.gpu.local_ptr(sL_wg1_slot.sL), sum_exp)
    sL_wg1_writer.commit(1)
    sL_wg0_wait = sL_wg0_reader.wait(0)
    peer_sum = tl.load(tle.gpu.local_ptr(sL_wg0_wait.slot.sL))
    total_sum = sum_exp + peer_sum
    sL_wg0_reader.release(0)

    # Rows without any valid token have a zero denominator; their output is
    # forced to 0 and their LSE to +inf below.
    is_no_valid_tokens = total_sum == 0.0
    inv_total_sum = tl.fdiv(1.0, total_sum)
    out_r_vals = acc_r * inv_total_sum[:, None]
    if HAVE_ATTN_SINK:
        # fin_log is the natural-log sum-exp: the base-2 quantity times ln(2).
        fin_log = (
            max_prev * log_scale + tl.math.log2(total_sum)
        ) * 0.6931471805599453
        sink = tl.load(attn_sink_base + h_base + offs_h, mask_h, other=0.0)
        # Equivalent to scaling by exp(lse) / (exp(lse) + exp(sink)).
        sink_scale = tl.fdiv(1.0, 1.0 + tl.math.exp(sink - fin_log))
        out_r_vals = out_r_vals * sink_scale[:, None]
    out_r_vals = tl.where(is_no_valid_tokens[:, None], 0.0, out_r_vals)
    o_r_msk = mask_h[:, None] & mask_od_r[None, :]
    # The Q_right buffer is reused as staging space for the output tile
    # before it is copied out through the output descriptor.
    tl.store(q_r_smem_ptr, out_r_vals.to(OUT_DTYPE), o_r_msk)
    tle.gpu.copy(q_slot.sQ_r, output_desc, [BH, DPH], [output_row, DPH])

    # Store LSE
    # Natural-log sum-exp per head; the attention sink is not folded in here.
    lse_val = (max_prev * log_scale + tl.math.log2(total_sum)) * 0.6931471805599453
    lse_val = tl.where(is_no_valid_tokens, float("inf"), lse_val)
    tl.store(lse_base + offs_h, lse_val, mask=mask_h)
@triton.jit
def _tle_sparse_decode_fwd(
    q_desc,
    tq_desc,
    output_desc,
    kv_nope,
    kv_scales,
    kv_rope,
    indices,
    attn_sink,
    topk_length,
    sm_scale: tl.constexpr,
    output,
    lse,
    BATCH_SQ,
    HQ: tl.constexpr,
    DQK: tl.constexpr,
    SKV,
    TOPK: tl.constexpr,
    HAVE_ATTN_SINK: tl.constexpr,
    HAVE_TOPK_LENGTH: tl.constexpr,
    IS_FP8: tl.constexpr,
    D: tl.constexpr,
    TD: tl.constexpr,
    DP: tl.constexpr,
    TDP: tl.constexpr,
    G: tl.constexpr,
    RH: tl.constexpr,
    HAVE_TAIL: tl.constexpr,
    BK: tl.constexpr,
    BH: tl.constexpr,
    PAIR_BLOCKS: tl.constexpr,
    stride_kvn,
    stride_scales_n,
    stride_rope_n,
    stride_ib,
    stride_isq,
):
    """Entry kernel of the warp-specialized sparse decode path.

    The program id is split into a (batch*seq_q) row index (``pid // RH``) and
    a head-block index (``pid % RH``).  The kernel allocates the shared-memory
    buffers, wraps them in ``tle.pipe`` channels and hands them to three
    workers through ``tle.gpu.warp_specialize``: ``_tle_sparse_decode_consumer0``,
    ``_tle_sparse_decode_consumer1`` and ``_tle_sparse_decode_producer``.

    ``output``, ``BATCH_SQ`` and ``DQK`` are accepted but not used here;
    ``PAIR_BLOCKS`` and ``stride_ib`` are likewise not referenced in this body.
    """
    DPH: tl.constexpr = DP // 2

    # Decompose the flat program id into (row, head block).
    pid = tl.program_id(0)
    row_idx = pid // RH
    h_base = (pid % RH) * BH
    row_idx64 = row_idx.to(tl.int64)

    # Per-row base pointers.  When an optional input is disabled a different
    # pointer argument is passed on in its place as a stand-in.
    t_base = indices + row_idx64 * stride_isq
    topk_len_ptr = (topk_length + row_idx64) if HAVE_TOPK_LENGTH else indices
    attn_sink_base = attn_sink if HAVE_ATTN_SINK else lse
    l_base = lse + row_idx64 * HQ + h_base
    q_row = row_idx * HQ + h_base
    _ = output
    _ = BATCH_SQ
    _ = DQK

    # ---- Q staging buffers (left half, right half, optional tail) ----
    sQ_l_smem = tle.gpu.alloc(
        [1, BH, DPH], dtype=tl.bfloat16, layout=None, scope=tle.gpu.smem
    )
    sQ_r_smem = tle.gpu.alloc(
        [1, BH, DPH], dtype=tl.bfloat16, layout=None, scope=tle.gpu.smem
    )
    if HAVE_TAIL:
        sQ_tail_smem = tle.gpu.alloc(
            [1, BH, TDP], dtype=tl.bfloat16, layout=None, scope=tle.gpu.smem
        )
        q_pipe = tle.pipe(
            capacity=1,
            scope="cta",
            name="decode_sQ",
            readers=("wg0", "wg1"),
            one_shot=True,
            sQ_l=sQ_l_smem,
            sQ_r=sQ_r_smem,
            sQ_tail=sQ_tail_smem,
        )
    else:
        q_pipe = tle.pipe(
            capacity=1,
            scope="cta",
            name="decode_sQ",
            readers=("wg0", "wg1"),
            one_shot=True,
            sQ_l=sQ_l_smem,
            sQ_r=sQ_r_smem,
        )

    # ---- K buffers for the two blocks of a pair ----
    sK0_smem = tle.gpu.alloc(
        [1, BK, DP], dtype=tl.bfloat16, layout=None, scope=tle.gpu.smem
    )
    sK1_smem = tle.gpu.alloc(
        [1, BK, DP], dtype=tl.bfloat16, layout=None, scope=tle.gpu.smem
    )
    if HAVE_TAIL:
        sK0_tail_smem = tle.gpu.alloc(
            [1, BK, TDP], dtype=tl.bfloat16, layout=None, scope=tle.gpu.smem
        )
        sK1_tail_smem = tle.gpu.alloc(
            [1, BK, TDP], dtype=tl.bfloat16, layout=None, scope=tle.gpu.smem
        )
        # With a tail, the probability buffers reuse the K-tail allocations.
        sS0_smem = sK0_tail_smem
        sS1_smem = sK1_tail_smem
    else:
        sS0_smem = tle.gpu.alloc(
            [1, BH, BK], dtype=tl.bfloat16, layout=None, scope=tle.gpu.smem
        )
        sS1_smem = tle.gpu.alloc(
            [1, BH, BK], dtype=tl.bfloat16, layout=None, scope=tle.gpu.smem
        )

    # ---- Small exchange buffers: validity flags, row max, row sum ----
    is_kv_valid_smem = tle.gpu.alloc(
        [1, 2, BK], dtype=tl.int8, layout=None, scope=tle.gpu.smem
    )
    sM_smem = tle.gpu.alloc(
        [1, BH], dtype=tl.float32, layout=None, scope=tle.gpu.smem
    )
    sL_smem = tle.gpu.alloc(
        [2, BH], dtype=tl.float32, layout=None, scope=tle.gpu.smem
    )

    # ---- K pipes.  Only the "right" pipes carry the tail field. ----
    k0_l_pipe = tle.pipe(capacity=1, scope="cta", name="decode_k0_l", sK=sK0_smem)
    if HAVE_TAIL:
        k0_r_pipe = tle.pipe(
            capacity=1,
            scope="cta",
            name="decode_k0_r",
            readers=("qk", "remote"),
            sK=sK0_smem,
            sK_tail=sK0_tail_smem,
        )
    else:
        k0_r_pipe = tle.pipe(
            capacity=1,
            scope="cta",
            name="decode_k0_r",
            readers=("qk", "remote"),
            sK=sK0_smem,
        )
    k1_l_pipe = tle.pipe(
        capacity=1,
        scope="cta",
        name="decode_k1_l",
        readers=("qk", "remote"),
        sK=sK1_smem,
    )
    if HAVE_TAIL:
        k1_r_pipe = tle.pipe(
            capacity=1,
            scope="cta",
            name="decode_k1_r",
            sK=sK1_smem,
            sK_tail=sK1_tail_smem,
        )
    else:
        k1_r_pipe = tle.pipe(
            capacity=1, scope="cta", name="decode_k1_r", sK=sK1_smem
        )

    # ---- Control / exchange pipes ----
    is_kv_valid_pipe = tle.pipe(
        capacity=1,
        scope="cta",
        name="decode_valid",
        readers=("wg0", "wg1"),
        is_kv_valid=is_kv_valid_smem,
    )
    # Both max pipes are backed by the same sM buffer.
    sM_wg0_pipe = tle.pipe(
        capacity=1, scope="cta", name="decode_wg0_max", sM=sM_smem
    )
    sM_wg1_pipe = tle.pipe(
        capacity=1, scope="cta", name="decode_wg1_max", sM=sM_smem
    )
    sS0_pipe = tle.pipe(capacity=1, scope="cta", name="decode_sS0", sS0=sS0_smem)
    sS1_pipe = tle.pipe(capacity=1, scope="cta", name="decode_sS1", sS1=sS1_smem)
    # Both sum pipes are backed by the same two-row sL buffer.
    sL_wg0_pipe = tle.pipe(
        capacity=2, scope="cta", name="decode_sL_wg0", sL=sL_smem
    )
    sL_wg1_pipe = tle.pipe(
        capacity=2, scope="cta", name="decode_sL_wg1", sL=sL_smem
    )

    # Fold log2(e) into the softmax scale so the workers can use exp2.
    log_scale: tl.constexpr = sm_scale * 1.4426950408889634

    # Worker arguments are passed positionally.
    # NOTE(review): the two trailing lists are presumably per-worker warp
    # counts and register budgets -- confirm against the TLE documentation.
    tle.gpu.warp_specialize(
        [
            (
                _tle_sparse_decode_consumer0,
                (
                    q_pipe.writer(),
                    q_pipe.reader("wg0"),
                    q_desc,
                    tq_desc,
                    k0_l_pipe.reader(),
                    k0_r_pipe.reader("qk"),
                    k1_l_pipe.reader("remote", fields=("sK",)),
                    is_kv_valid_pipe.reader("wg0"),
                    sM_wg0_pipe.writer(),
                    sM_wg1_pipe.reader(),
                    sS0_pipe.writer(),
                    sS1_pipe.reader(),
                    sL_wg0_pipe.writer(),
                    sL_wg1_pipe.reader(),
                    output_desc,
                    q_row,
                    h_base,
                    topk_len_ptr,
                    attn_sink_base,
                    log_scale,
                    D,
                    TD,
                    tl.bfloat16,
                    HAVE_ATTN_SINK,
                    TOPK,
                    HAVE_TOPK_LENGTH,
                    HAVE_TAIL,
                    BK,
                    BH,
                    DPH,
                    TDP,
                    G,
                ),
            ),
            (
                _tle_sparse_decode_consumer1,
                (
                    q_pipe.reader("wg1"),
                    k1_r_pipe.reader(),
                    k1_l_pipe.reader("qk"),
                    k0_r_pipe.reader("remote", fields=("sK",)),
                    is_kv_valid_pipe.reader("wg1"),
                    sM_wg1_pipe.writer(),
                    sM_wg0_pipe.reader(),
                    sS1_pipe.writer(),
                    sS0_pipe.reader(),
                    sL_wg1_pipe.writer(),
                    sL_wg0_pipe.reader(),
                    output_desc,
                    q_row,
                    l_base,
                    h_base,
                    topk_len_ptr,
                    attn_sink_base,
                    log_scale,
                    D,
                    TD,
                    tl.bfloat16,
                    HAVE_ATTN_SINK,
                    TOPK,
                    HAVE_TOPK_LENGTH,
                    HAVE_TAIL,
                    BK,
                    BH,
                    DPH,
                    TDP,
                    G,
                ),
            ),
            (
                _tle_sparse_decode_producer,
                (
                    k0_l_pipe.writer(),
                    k0_r_pipe.writer(),
                    k1_l_pipe.writer(),
                    k1_r_pipe.writer(),
                    is_kv_valid_pipe.writer(),
                    kv_nope,
                    kv_scales,
                    kv_rope,
                    t_base,
                    topk_len_ptr,
                    D,
                    TD,
                    DPH,
                    TDP,
                    SKV,
                    TOPK,
                    HAVE_TOPK_LENGTH,
                    HAVE_TAIL,
                    IS_FP8,
                    BK,
                    stride_kvn,
                    stride_scales_n,
                    stride_rope_n,
                ),
            ),
        ],
        [4, 4],
        [216, 72],
    )
2450# ============================================================================
2451# TLE Warp Specialization path for dense decode
2452# ============================================================================
def _tle_dense_decode_launch(
    q,
    kv_flat,
    block_table,
    cache_seqlens,
    out,
    lse,
    batch_size,
    seq_q,
    num_heads_q,
    head_dim_k,
    head_dim_v,
    page_block_size,
    softmax_scale,
    causal,
):
    """Launch the TLE warp-specialized dense decode kernel.

    ``q``, ``out`` and ``lse`` are flattened so that every
    (batch * seq_q) row owns ``num_heads_q`` consecutive entries, tensor
    descriptors are built for Q and the output, and one program is launched
    per (row, block of ``TLE_DECODE_BH`` heads).

    Args:
        q: Query tensor, reshaped to ``[batch*seq_q*num_heads_q, head_dim_k]``.
        kv_flat: KV cache; its ``stride(0)`` is forwarded as the row stride.
        block_table: Page table, reshaped to ``[batch*seq_q, -1]``.
        cache_seqlens: Sequence lengths, reshaped to ``[batch*seq_q]``.
        out: Output buffer, written in place.
        lse: Log-sum-exp buffer, written in place.
        batch_size, seq_q, num_heads_q: Problem sizes.
        head_dim_k: Q/K head dimension (forwarded as ``DQK``).
        head_dim_v: V head dimension (``D``); ``head_dim_k - head_dim_v`` is
            treated as the tail width.
        page_block_size: Tokens per KV page; must equal ``TLE_DECODE_BK``.
        softmax_scale: Scale forwarded to the kernel as ``sm_scale``.
        causal: Forwarded to the kernel as ``CAUSAL``.

    Raises:
        ValueError: If ``num_heads_q`` is not a positive multiple of
            ``TLE_DECODE_BH`` or ``page_block_size != TLE_DECODE_BK``.
    """
    from triton.tools.tensor_descriptor import TensorDescriptor

    BH = TLE_DECODE_BH
    BK = TLE_DECODE_BK

    # The grid uses num_heads_q // BH head blocks, so a head count that is
    # not a positive multiple of BH would silently leave heads uncomputed
    # (or launch an empty grid).
    if num_heads_q <= 0 or num_heads_q % BH != 0:
        raise ValueError(
            f"num_heads_q ({num_heads_q}) must be a positive multiple of {BH}"
        )
    # The producer addresses the tokens of a page with arange(0, BK), so each
    # page must hold exactly BK tokens.
    if page_block_size != BK:
        raise ValueError(
            f"page_block_size ({page_block_size}) must equal {BK}"
        )

    _set_triton_descriptor_allocator(q.device)

    D = head_dim_v  # 512
    TD = head_dim_k - D  # 64 for DQK=576, 0 for DQK=512
    DP = triton.next_power_of_2(D)
    DPH = DP // 2
    HAVE_TAIL = TD > 0
    TDP = triton.next_power_of_2(TD) if HAVE_TAIL else 1
    G = num_heads_q
    RH = G // BH

    num_rows = batch_size * seq_q
    num_q_rows = num_rows * num_heads_q

    # Flatten to 2-D row-major buffers.  The descriptors below hard-code
    # contiguous strides, so every buffer is forced contiguous; for the two
    # output buffers this may create a temporary that is copied back after
    # the launch.
    q_flat = q.reshape(num_q_rows, head_dim_k).contiguous()
    out_flat = out.reshape(num_q_rows, head_dim_v).contiguous()
    lse_flat = lse.reshape(num_rows, num_heads_q).contiguous()

    q_desc = TensorDescriptor(
        q_flat,
        shape=[num_q_rows, head_dim_k],
        strides=[head_dim_k, 1],
        block_shape=[BH, DPH],
    )
    if HAVE_TAIL:
        tq_desc = TensorDescriptor(
            q_flat,
            shape=[num_q_rows, head_dim_k],
            strides=[head_dim_k, 1],
            block_shape=[BH, TDP],
        )
    else:
        # No tail: the kernel still expects a descriptor argument.
        tq_desc = q_desc
    output_desc = TensorDescriptor(
        out_flat,
        shape=[num_q_rows, D],
        strides=[D, 1],
        block_shape=[BH, DPH],
    )

    # Grid: one program per (batch*seq_q, head_block)
    grid = (num_rows * RH,)

    # Reshape block_table and cache_seqlens for kernel
    block_table_flat = block_table.reshape(num_rows, -1).contiguous()
    cache_seqlens_flat = cache_seqlens.reshape(num_rows).contiguous()

    _tle_dense_decode_fwd[grid](
        q_desc,
        tq_desc,
        output_desc,
        kv_flat,
        block_table_flat,
        cache_seqlens_flat,
        softmax_scale,
        out_flat,
        lse_flat,
        num_rows,
        num_heads_q,
        head_dim_k,
        page_block_size,
        causal,
        D,
        TD,
        DP,
        TDP,
        G,
        RH,
        HAVE_TAIL,
        BK,
        BH,
        TLE_DECODE_PAIR_BLOCKS,
        kv_flat.stride(0),
        block_table_flat.stride(0),
        num_warps=TLE_DECODE_WORKER_NUM_WARPS,
        num_stages=1,
    )

    # If flattening had to copy (non-contiguous caller buffers), the kernel
    # wrote into the temporary; propagate the results to the caller.
    if out_flat.data_ptr() != out.data_ptr():
        out.copy_(out_flat.reshape(out.shape))
    if lse_flat.data_ptr() != lse.data_ptr():
        lse.copy_(lse_flat.reshape(lse.shape))
2552if HAS_TLE:
@triton.jit
def _tle_dense_decode_producer(
    k0_l_writer,
    k0_r_writer,
    k1_l_writer,
    k1_r_writer,
    is_kv_valid_writer,
    kv_base,
    block_table_ptr,
    cache_seqlen,
    D: tl.constexpr,
    TD: tl.constexpr,
    DPH: tl.constexpr,
    TDP: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    HAVE_TAIL: tl.constexpr,
    BK: tl.constexpr,
    stride_kvn: tl.constexpr,
    stride_bt: tl.constexpr,
):
    """Producer worker: stage paged KV data into shared memory, two pages per step.

    For every pair of logical pages it looks up the physical page ids in the
    block table, publishes the per-token validity flags, and then fills the
    four K slots in the order page0-left, page1-right (+tail), page0-right
    (+tail), page1-left.  ``stride_bt`` is not referenced in this body.

    NOTE(review): token rows inside a page are addressed with
    ``arange(0, BK)``, which appears to assume ``PAGE_SIZE == BK`` -- confirm
    with the caller.
    """
    TILE: tl.constexpr = 64  # column width of one global->shared copy step

    num_pages = tl.cdiv(cache_seqlen, PAGE_SIZE)
    num_pairs = tl.cdiv(num_pages, 2)

    tok_offs = tl.arange(0, BK)
    tile_offs = tl.arange(0, TILE)
    smem_rows = tl.broadcast_to(tok_offs[:, None], (BK, TILE))

    for pair in tl.range(num_pairs):
        logical0 = pair * 2
        logical1 = logical0 + 1

        # Physical page ids; entries past the last page read as page 0 and
        # are neutralised by the validity masks below.
        physical0 = tl.load(
            block_table_ptr + logical0, mask=logical0 < num_pages, other=0
        )
        physical1 = tl.load(
            block_table_ptr + logical1, mask=logical1 < num_pages, other=0
        )

        # Element offset of the first token row of each physical page.
        base0 = physical0.to(tl.int64) * PAGE_SIZE * stride_kvn
        base1 = physical1.to(tl.int64) * PAGE_SIZE * stride_kvn

        # Tokens at or beyond cache_seqlen are invalid.
        valid0 = logical0 * PAGE_SIZE + tok_offs < cache_seqlen
        valid1 = logical1 * PAGE_SIZE + tok_offs < cache_seqlen

        # Row pointers for both pages (column offsets are added per tile).
        page0_rows = kv_base + base0 + tok_offs[:, None] * stride_kvn
        page1_rows = kv_base + base1 + tok_offs[:, None] * stride_kvn

        # ---- validity flags: row 0 -> page0, row 1 -> page1 ----
        flag_slot = is_kv_valid_writer.acquire(pair)
        flag_row0 = tl.full([BK], 0, dtype=tl.int32)
        flag_row1 = tl.full([BK], 1, dtype=tl.int32)
        tl.store(
            tle.gpu.local_ptr(flag_slot.is_kv_valid, (flag_row0, tok_offs)),
            valid0.to(tl.int8),
        )
        tl.store(
            tle.gpu.local_ptr(flag_slot.is_kv_valid, (flag_row1, tok_offs)),
            valid1.to(tl.int8),
        )
        is_kv_valid_writer.commit(pair)

        # ---- page0, left half: columns [0, DPH) ----
        slot = k0_l_writer.acquire(pair)
        for tile in tl.static_range(0, DPH, TILE):
            cols = tile + tile_offs
            smem_cols = tl.broadcast_to(cols[None, :], (BK, TILE))
            keep = valid0[:, None] & (cols < D)[None, :]
            blk = tl.load(
                page0_rows + cols[None, :],
                mask=keep,
                other=0.0,
                eviction_policy="evict_last",
            )
            tl.store(
                tle.gpu.local_ptr(slot.sK, (smem_rows, smem_cols)),
                blk,
                mask=keep,
            )
        k0_l_writer.commit(pair)

        # ---- page1, right half: columns [DPH, 2*DPH) plus optional tail ----
        slot = k1_r_writer.acquire(pair)
        for tile in tl.static_range(0, DPH, TILE):
            cols = DPH + tile + tile_offs
            smem_cols = tl.broadcast_to(cols[None, :], (BK, TILE))
            keep = valid1[:, None] & (cols < D)[None, :]
            blk = tl.load(
                page1_rows + cols[None, :],
                mask=keep,
                other=0.0,
                eviction_policy="evict_last",
            )
            tl.store(
                tle.gpu.local_ptr(slot.sK, (smem_rows, smem_cols)),
                blk,
                mask=keep,
            )
        if HAVE_TAIL:
            tail_offs = tl.arange(0, TDP)
            tail_keep = valid1[:, None] & (tail_offs < TD)[None, :]
            tail_blk = tl.load(
                page1_rows + (D + tail_offs)[None, :],
                mask=tail_keep,
                other=0.0,
                eviction_policy="evict_last",
            )
            tl.store(
                tle.gpu.local_ptr(slot.sK_tail),
                tail_blk,
                mask=tail_keep,
            )
        k1_r_writer.commit(pair)

        # ---- page0, right half: columns [DPH, 2*DPH) plus optional tail ----
        slot = k0_r_writer.acquire(pair)
        for tile in tl.static_range(0, DPH, TILE):
            cols = DPH + tile + tile_offs
            smem_cols = tl.broadcast_to(cols[None, :], (BK, TILE))
            keep = valid0[:, None] & (cols < D)[None, :]
            blk = tl.load(
                page0_rows + cols[None, :],
                mask=keep,
                other=0.0,
                eviction_policy="evict_last",
            )
            tl.store(
                tle.gpu.local_ptr(slot.sK, (smem_rows, smem_cols)),
                blk,
                mask=keep,
            )
        if HAVE_TAIL:
            tail_offs = tl.arange(0, TDP)
            tail_keep = valid0[:, None] & (tail_offs < TD)[None, :]
            tail_blk = tl.load(
                page0_rows + (D + tail_offs)[None, :],
                mask=tail_keep,
                other=0.0,
                eviction_policy="evict_last",
            )
            tl.store(
                tle.gpu.local_ptr(slot.sK_tail),
                tail_blk,
                mask=tail_keep,
            )
        k0_r_writer.commit(pair)

        # ---- page1, left half: columns [0, DPH) ----
        slot = k1_l_writer.acquire(pair)
        for tile in tl.static_range(0, DPH, TILE):
            cols = tile + tile_offs
            smem_cols = tl.broadcast_to(cols[None, :], (BK, TILE))
            keep = valid1[:, None] & (cols < D)[None, :]
            blk = tl.load(
                page1_rows + cols[None, :],
                mask=keep,
                other=0.0,
                eviction_policy="evict_last",
            )
            tl.store(
                tle.gpu.local_ptr(slot.sK, (smem_rows, smem_cols)),
                blk,
                mask=keep,
            )
        k1_l_writer.commit(pair)
@triton.jit
def _tle_dense_decode_consumer0(
    q_writer,
    q_reader,
    q_desc,
    tq_desc,
    k0_l_reader,
    k0_r_qk_reader,
    k1_l_remote_reader,
    is_kv_valid_reader,
    sM_wg0_writer,
    sM_wg1_reader,
    sS0_writer,
    sS1_reader,
    sL_wg0_writer,
    sL_wg1_reader,
    output_desc,
    output_row,
    h_base,
    cache_seqlen,
    log_scale: tl.constexpr,
    D: tl.constexpr,
    TD: tl.constexpr,
    OUT_DTYPE: tl.constexpr,
    HAVE_TAIL: tl.constexpr,
    BK: tl.constexpr,
    BH: tl.constexpr,
    DPH: tl.constexpr,
    TDP: tl.constexpr,
    G: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
):
    """Consumer 0: scores for page0 of each pair and the left half of the output.

    Copies Q into shared memory once, then for every page pair:
    computes Q.K^T against page0, applies the validity mask, performs an
    online-softmax update, exchanges the running row maximum with consumer 1,
    and accumulates P.V for output columns ``[0, DPH)`` from both pages
    (page1 probabilities are received from consumer 1).  After the loop the
    two partial row sums are combined and the normalised left half is stored
    through ``output_desc``.  ``TD`` is not referenced in this body.
    """
    head_offs = tl.arange(0, BH)
    half_offs = tl.arange(0, DPH)
    key_offs = tl.arange(0, BK)
    head_ok = h_base + head_offs < G
    col_ok = half_offs < D
    smem_rows = tl.broadcast_to(key_offs[:, None], (BK, DPH))
    smem_cols_l = tl.broadcast_to(half_offs[None, :], (BK, DPH))
    smem_cols_r = tl.broadcast_to((DPH + half_offs)[None, :], (BK, DPH))

    # Stage Q once; this worker is the writer of the Q pipe.
    q_stage = q_writer.acquire(0)
    tle.gpu.copy(q_desc, q_stage.sQ_l, [BH, DPH], [output_row, 0])
    tle.gpu.copy(q_desc, q_stage.sQ_r, [BH, DPH], [output_row, DPH])
    if HAVE_TAIL:
        tle.gpu.copy(tq_desc, q_stage.sQ_tail, [BH, TDP], [output_row, D])
    q_writer.commit(0)

    q_slot = q_reader.wait(0).slot
    q_l_ptr = tle.gpu.local_ptr(q_slot.sQ_l)
    q_r_ptr = tle.gpu.local_ptr(q_slot.sQ_r)

    # Online-softmax state.
    row_max = tl.full([BH], -1.0e30, dtype=tl.float32)
    row_sum = tl.full([BH], 0.0, dtype=tl.float32)
    acc_left = tl.zeros([BH, DPH], dtype=tl.float32)

    num_pages = tl.cdiv(cache_seqlen, PAGE_SIZE)
    num_pairs = tl.cdiv(num_pages, 2)

    for pair in tl.range(num_pairs):
        # ---- Q.K^T for page0: left half, right half, optional tail ----
        k0_l_slot = k0_l_reader.wait(pair).slot

        q_left = tl.load(q_l_ptr)
        q_right = tl.load(q_r_ptr)
        k0_left = tl.load(tle.gpu.local_ptr(k0_l_slot.sK, (smem_rows, smem_cols_l)))

        scores = tl.zeros([BH, BK], dtype=tl.float32)
        scores = tl.dot(q_left, tl.trans(k0_left), scores, out_dtype=tl.float32)

        k0_r_slot = k0_r_qk_reader.wait(pair).slot
        k0_right = tl.load(tle.gpu.local_ptr(k0_r_slot.sK, (smem_rows, smem_cols_r)))
        scores = tl.dot(q_right, tl.trans(k0_right), scores, out_dtype=tl.float32)
        if HAVE_TAIL:
            q_tail = tl.load(tle.gpu.local_ptr(q_slot.sQ_tail))
            k0_tail = tl.load(tle.gpu.local_ptr(k0_r_slot.sK_tail))
            scores = tl.dot(q_tail, tl.trans(k0_tail), scores, out_dtype=tl.float32)

        # ---- mask out invalid tokens (flag row 0 belongs to page0) ----
        flag_slot = is_kv_valid_reader.wait(pair).slot
        flag_row = tl.full([BK], 0, dtype=tl.int32)
        token_ok = (
            tl.load(tle.gpu.local_ptr(flag_slot.is_kv_valid, (flag_row, key_offs)))
            != 0
        )
        scores = tl.where(token_ok[None, :], scores, float("-inf"))
        is_kv_valid_reader.release(pair)

        # ---- online softmax against this worker's local maximum ----
        local_max = tl.maximum(row_max, tl.max(scores, axis=1))
        rescale = tl.math.exp2((row_max - local_max) * log_scale)
        probs = tl.math.exp2(scores * log_scale - local_max[:, None] * log_scale)
        row_sum = row_sum * rescale + tl.sum(probs, axis=1)
        acc_left = acc_left * rescale[:, None]
        probs_lowp = probs.to(OUT_DTYPE)

        # Publish the local maximum to consumer 1.
        max_slot = sM_wg0_writer.acquire(pair)
        tl.store(tle.gpu.local_ptr(max_slot.sM), local_max)
        sM_wg0_writer.commit(pair)

        # ---- P.V for page0, left columns (K is re-read from shared memory) ----
        k0_left = tl.load(tle.gpu.local_ptr(k0_l_slot.sK, (smem_rows, smem_cols_l)))
        acc_left = tl.dot(probs_lowp, k0_left, acc_left, out_dtype=tl.float32)
        k0_l_reader.release(pair)
        k0_r_qk_reader.release(pair)

        # ---- adopt the merged maximum computed by consumer 1 ----
        merged_slot = sM_wg1_reader.wait(pair).slot
        merged_max = tl.load(tle.gpu.local_ptr(merged_slot.sM))
        sM_wg1_reader.release(pair)

        correction = tl.math.exp2((local_max - merged_max) * log_scale)
        row_sum = row_sum * correction
        acc_left = acc_left * correction[:, None]

        # Hand the corrected page0 probabilities to consumer 1.
        probs_corrected = probs * correction[:, None]
        p0_slot = sS0_writer.acquire(pair)
        tl.store(tle.gpu.local_ptr(p0_slot.sS0), probs_corrected.to(OUT_DTYPE))
        sS0_writer.commit(pair)

        # ---- P.V for page1, left columns, using consumer 1's probabilities ----
        p1_slot = sS1_reader.wait(pair).slot
        probs_page1 = tl.load(tle.gpu.local_ptr(p1_slot.sS1))
        k1_l_slot = k1_l_remote_reader.wait(pair).slot
        k1_left = tl.load(tle.gpu.local_ptr(k1_l_slot.sK, (smem_rows, smem_cols_l)))
        acc_left = tl.dot(probs_page1, k1_left, acc_left, out_dtype=tl.float32)
        sS1_reader.release(pair)
        k1_l_remote_reader.release(pair)

        row_max = merged_max

    # ---- combine the two partial row sums (slot 0: this worker, slot 1: peer) ----
    sum_slot = sL_wg0_writer.acquire(0)
    tl.store(tle.gpu.local_ptr(sum_slot.sL), row_sum)
    sL_wg0_writer.commit(0)
    peer_slot = sL_wg1_reader.wait(1).slot
    peer_sum = tl.load(tle.gpu.local_ptr(peer_slot.sL))
    total_sum = row_sum + peer_sum
    sL_wg1_reader.release(1)

    # ---- normalise; rows whose total sum is zero produce zeros ----
    nothing_valid = total_sum == 0.0
    inv_total = tl.fdiv(1.0, total_sum)
    out_left = acc_left * inv_total[:, None]
    out_left = tl.where(nothing_valid[:, None], 0.0, out_left)
    store_mask = head_ok[:, None] & col_ok[None, :]
    # The Q staging buffer is reused as the source of the output copy.
    tl.store(q_l_ptr, out_left.to(OUT_DTYPE), mask=store_mask)
    tle.gpu.copy(q_slot.sQ_l, output_desc, [BH, DPH], [output_row, 0])
@triton.jit
def _tle_dense_decode_consumer1(
    q_reader,
    k1_r_reader,
    k1_l_qk_reader,
    k0_r_remote_reader,
    is_kv_valid_reader,
    sM_wg1_writer,
    sM_wg0_reader,
    sS1_writer,
    sS0_reader,
    sL_wg1_writer,
    sL_wg0_reader,
    final_lse_smem,
    output_desc,
    output_row,
    l_base,
    h_base,
    cache_seqlen,
    log_scale: tl.constexpr,
    D: tl.constexpr,
    TD: tl.constexpr,
    OUT_DTYPE: tl.constexpr,
    HAVE_TAIL: tl.constexpr,
    BK: tl.constexpr,
    BH: tl.constexpr,
    DPH: tl.constexpr,
    TDP: tl.constexpr,
    G: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
):
    """Consumer 1: scores for page1 of each pair, right half of the output, and LSE.

    For every page pair it computes Q.K^T against page1, applies the validity
    mask, merges its row maximum with the one received from consumer 0 and
    sends the merged maximum back, then accumulates P.V for output columns
    ``[DPH, 2*DPH)`` from both pages (page0 probabilities are received from
    consumer 0).  After the loop the partial row sums are combined, the
    normalised right half is stored through ``output_desc`` and the
    log-sum-exp is written to ``l_base``.  ``TD`` and ``TDP`` are not
    referenced in this body.
    """
    head_offs = tl.arange(0, BH)
    half_offs = tl.arange(0, DPH)
    key_offs = tl.arange(0, BK)
    head_ok = h_base + head_offs < G
    col_ok = DPH + half_offs < D
    smem_rows = tl.broadcast_to(key_offs[:, None], (BK, DPH))
    smem_cols_l = tl.broadcast_to(half_offs[None, :], (BK, DPH))
    smem_cols_r = tl.broadcast_to((DPH + half_offs)[None, :], (BK, DPH))

    q_slot = q_reader.wait(0).slot
    q_l_ptr = tle.gpu.local_ptr(q_slot.sQ_l)
    q_r_ptr = tle.gpu.local_ptr(q_slot.sQ_r)

    # Online-softmax state.
    row_max = tl.full([BH], -1.0e30, dtype=tl.float32)
    row_sum = tl.full([BH], 0.0, dtype=tl.float32)
    acc_right = tl.zeros([BH, DPH], dtype=tl.float32)

    num_pages = tl.cdiv(cache_seqlen, PAGE_SIZE)
    num_pairs = tl.cdiv(num_pages, 2)

    for pair in tl.range(num_pairs):
        # ---- Q.K^T for page1: right half, optional tail, then left half ----
        k1_r_slot = k1_r_reader.wait(pair).slot

        q_left = tl.load(q_l_ptr)
        q_right = tl.load(q_r_ptr)
        k1_right = tl.load(tle.gpu.local_ptr(k1_r_slot.sK, (smem_rows, smem_cols_r)))

        scores = tl.zeros([BH, BK], dtype=tl.float32)
        scores = tl.dot(q_right, tl.trans(k1_right), scores, out_dtype=tl.float32)
        if HAVE_TAIL:
            q_tail = tl.load(tle.gpu.local_ptr(q_slot.sQ_tail))
            k1_tail = tl.load(tle.gpu.local_ptr(k1_r_slot.sK_tail))
            scores = tl.dot(q_tail, tl.trans(k1_tail), scores, out_dtype=tl.float32)

        k1_l_slot = k1_l_qk_reader.wait(pair).slot
        k1_left = tl.load(tle.gpu.local_ptr(k1_l_slot.sK, (smem_rows, smem_cols_l)))
        scores = tl.dot(q_left, tl.trans(k1_left), scores, out_dtype=tl.float32)

        # ---- mask out invalid tokens (flag row 1 belongs to page1) ----
        flag_slot = is_kv_valid_reader.wait(pair).slot
        flag_row = tl.full([BK], 1, dtype=tl.int32)
        token_ok = (
            tl.load(tle.gpu.local_ptr(flag_slot.is_kv_valid, (flag_row, key_offs)))
            != 0
        )
        scores = tl.where(token_ok[None, :], scores, float("-inf"))
        is_kv_valid_reader.release(pair)

        # ---- merge this worker's maximum with consumer 0's candidate ----
        peer_max_slot = sM_wg0_reader.wait(pair).slot
        peer_max = tl.load(tle.gpu.local_ptr(peer_max_slot.sM))
        sM_wg0_reader.release(pair)

        own_max = tl.maximum(row_max, tl.max(scores, axis=1))
        merged_max = tl.maximum(own_max, peer_max)
        max_slot = sM_wg1_writer.acquire(pair)
        tl.store(tle.gpu.local_ptr(max_slot.sM), merged_max)
        sM_wg1_writer.commit(pair)

        # ---- online softmax against the merged maximum ----
        rescale = tl.math.exp2((row_max - merged_max) * log_scale)
        probs = tl.math.exp2(scores * log_scale - merged_max[:, None] * log_scale)
        row_sum = row_sum * rescale + tl.sum(probs, axis=1)
        acc_right = acc_right * rescale[:, None]
        probs_lowp = probs.to(OUT_DTYPE)

        k1_l_qk_reader.release(pair)

        # ---- P.V for page1, right columns ----
        acc_right = tl.dot(probs_lowp, k1_right, acc_right, out_dtype=tl.float32)

        # Hand the page1 probabilities to consumer 0.
        p1_slot = sS1_writer.acquire(pair)
        tl.store(tle.gpu.local_ptr(p1_slot.sS1), probs_lowp)
        sS1_writer.commit(pair)

        # ---- P.V for page0, right columns, using consumer 0's probabilities ----
        p0_slot = sS0_reader.wait(pair).slot
        probs_page0 = tl.load(tle.gpu.local_ptr(p0_slot.sS0))
        k0_r_slot = k0_r_remote_reader.wait(pair).slot
        k0_right = tl.load(tle.gpu.local_ptr(k0_r_slot.sK, (smem_rows, smem_cols_r)))
        acc_right = tl.dot(probs_page0, k0_right, acc_right, out_dtype=tl.float32)
        k1_r_reader.release(pair)
        sS0_reader.release(pair)
        k0_r_remote_reader.release(pair)

        row_max = merged_max

    # ---- combine the two partial row sums (slot 1: this worker, slot 0: peer) ----
    sum_slot = sL_wg1_writer.acquire(1)
    tl.store(tle.gpu.local_ptr(sum_slot.sL), row_sum)
    sL_wg1_writer.commit(1)
    peer_slot = sL_wg0_reader.wait(0).slot
    peer_sum = tl.load(tle.gpu.local_ptr(peer_slot.sL))
    total_sum = row_sum + peer_sum
    sL_wg0_reader.release(0)

    # ---- normalise; rows whose total sum is zero produce zeros ----
    nothing_valid = total_sum == 0.0
    inv_total = tl.fdiv(1.0, total_sum)
    out_right = acc_right * inv_total[:, None]
    # Natural-log LSE: (max * log_scale + log2(sum)) * ln(2).
    lse_vals = (row_max * log_scale + tl.math.log2(total_sum)) * 0.6931471805599453
    out_right = tl.where(nothing_valid[:, None], 0.0, out_right)
    store_mask = head_ok[:, None] & col_ok[None, :]
    # The Q staging buffer is reused as the source of the output copy.
    tl.store(q_r_ptr, out_right.to(OUT_DTYPE), mask=store_mask)
    tle.gpu.copy(q_slot.sQ_r, output_desc, [BH, DPH], [output_row, DPH])

    # ---- LSE: +inf for rows without valid tokens ----
    lse_vals = tl.where(nothing_valid, float("inf"), lse_vals)
    # The values take a round trip through shared memory before the global
    # store.  NOTE(review): presumably a layout workaround -- confirm.
    tl.store(tle.gpu.local_ptr(final_lse_smem), lse_vals, mask=head_ok)
    lse_vals = tl.load(
        tle.gpu.local_ptr(final_lse_smem), mask=head_ok, other=float("inf")
    )
    tl.store(l_base + head_offs, lse_vals, mask=head_ok)
3054 @triton.jit
3055 def _tle_dense_decode_fwd(
3056 q_desc,
3057 tq_desc,
3058 output_desc,
3059 kv_cache,
3060 block_table,
3061 cache_seqlens,
3062 sm_scale: tl.constexpr,
3063 output,
3064 lse,
3065 BS,
3066 G: tl.constexpr,
3067 DQK: tl.constexpr,
3068 PAGE_SIZE: tl.constexpr,
3069 CAUSAL: tl.constexpr,
3070 D: tl.constexpr,
3071 TD: tl.constexpr,
3072 DP: tl.constexpr,
3073 TDP: tl.constexpr,
3074 H: tl.constexpr,
3075 RH: tl.constexpr,
3076 HAVE_TAIL: tl.constexpr,
3077 BK: tl.constexpr,
3078 BH: tl.constexpr,
3079 PAIR_BLOCKS: tl.constexpr,
3080 stride_kvn: tl.constexpr,
3081 stride_bt: tl.constexpr,
3082 ):
3083 DPH: tl.constexpr = DP // 2
3084 stride_lm = G
3086 pid = tl.program_id(0)
3087 i_sq = pid // RH
3088 i_rh = pid % RH
3089 h_base = i_rh * BH
3090 output_row = i_sq * G + h_base
3091 i_sq64 = i_sq.to(tl.int64)
3093 cache_seqlen = tl.load(cache_seqlens + i_sq64)
3094 block_table_ptr = block_table + i_sq64 * stride_bt
3095 kv_base = kv_cache
3096 l_base = lse + i_sq64 * stride_lm + h_base
3097 _ = output
3098 _ = BS
3099 _ = DQK
3100 _ = CAUSAL
3102 sQ_l_smem = tle.gpu.alloc(
3103 [1, BH, DPH],
3104 dtype=kv_cache.dtype.element_ty,
3105 layout=None,
3106 scope=tle.gpu.smem,
3107 )
3108 sQ_r_smem = tle.gpu.alloc(
3109 [1, BH, DPH],
3110 dtype=kv_cache.dtype.element_ty,
3111 layout=None,
3112 scope=tle.gpu.smem,
3113 )
3114 if HAVE_TAIL:
3115 sQ_tail_smem = tle.gpu.alloc(
3116 [1, BH, TDP],
3117 dtype=kv_cache.dtype.element_ty,
3118 layout=None,
3119 scope=tle.gpu.smem,
3120 )
3121 q_pipe = tle.pipe(
3122 capacity=1,
3123 scope="cta",
3124 name="dense_sQ",
3125 readers=("wg0", "wg1"),
3126 one_shot=True,
3127 sQ_l=sQ_l_smem,
3128 sQ_r=sQ_r_smem,
3129 sQ_tail=sQ_tail_smem,
3130 )
3131 else:
3132 q_pipe = tle.pipe(
3133 capacity=1,
3134 scope="cta",
3135 name="dense_sQ",
3136 readers=("wg0", "wg1"),
3137 one_shot=True,
3138 sQ_l=sQ_l_smem,
3139 sQ_r=sQ_r_smem,
3140 )
3142 sK0_smem = tle.gpu.alloc(
3143 [1, BK, DP],
3144 dtype=kv_cache.dtype.element_ty,
3145 layout=None,
3146 scope=tle.gpu.smem,
3147 )
3148 sK1_smem = tle.gpu.alloc(
3149 [1, BK, DP],
3150 dtype=kv_cache.dtype.element_ty,
3151 layout=None,
3152 scope=tle.gpu.smem,
3153 )
3154 if HAVE_TAIL:
3155 sK0_tail_smem = tle.gpu.alloc(
3156 [1, BK, TDP],
3157 dtype=kv_cache.dtype.element_ty,
3158 layout=None,
3159 scope=tle.gpu.smem,
3160 )
3161 sK1_tail_smem = tle.gpu.alloc(
3162 [1, BK, TDP],
3163 dtype=kv_cache.dtype.element_ty,
3164 layout=None,
3165 scope=tle.gpu.smem,
3166 )
3167 sS0_smem = sK0_tail_smem
3168 else:
3169 sS0_smem = tle.gpu.alloc(
3170 [1, BH, BK],
3171 dtype=kv_cache.dtype.element_ty,
3172 layout=None,
3173 scope=tle.gpu.smem,
3174 )
3176 is_kv_valid_smem = tle.gpu.alloc(
3177 [1, PAIR_BLOCKS, BK],
3178 dtype=tl.int8,
3179 layout=None,
3180 scope=tle.gpu.smem,
3181 nv_mma_shared_layout=False,
3182 )
3184 k0_l_pipe = tle.pipe(capacity=1, scope="cta", name="dense_sK0_l", sK=sK0_smem)
3185 if HAVE_TAIL:
3186 k0_r_pipe = tle.pipe(
3187 capacity=1,
3188 scope="cta",
3189 name="dense_sK0_r",
3190 readers=("qk", "remote"),
3191 sK=sK0_smem,
3192 sK_tail=sK0_tail_smem,
3193 )
3194 else:
3195 k0_r_pipe = tle.pipe(
3196 capacity=1,
3197 scope="cta",
3198 name="dense_sK0_r",
3199 readers=("qk", "remote"),
3200 sK=sK0_smem,
3201 )
3202 k1_l_pipe = tle.pipe(
3203 capacity=1,
3204 scope="cta",
3205 name="dense_sK1_l",
3206 readers=("qk", "remote"),
3207 sK=sK1_smem,
3208 )
3209 if HAVE_TAIL:
3210 k1_r_pipe = tle.pipe(
3211 capacity=1,
3212 scope="cta",
3213 name="dense_sK1_r",
3214 sK=sK1_smem,
3215 sK_tail=sK1_tail_smem,
3216 )
3217 else:
3218 k1_r_pipe = tle.pipe(
3219 capacity=1, scope="cta", name="dense_sK1_r", sK=sK1_smem
3220 )
3222 is_kv_valid_pipe = tle.pipe(
3223 capacity=1,
3224 scope="cta",
3225 name="dense_is_kv_valid",
3226 readers=("wg0", "wg1"),
3227 is_kv_valid=is_kv_valid_smem,
3228 )
3230 sM_smem = tle.gpu.alloc(
3231 [1, BH],
3232 dtype=tl.float32,
3233 layout=None,
3234 scope=tle.gpu.smem,
3235 nv_mma_shared_layout=False,
3236 )
3237 sS1_smem = tle.gpu.alloc(
3238 [1, BH, BK],
3239 dtype=kv_cache.dtype.element_ty,
3240 layout=None,
3241 scope=tle.gpu.smem,
3242 )
3243 sL_smem = tle.gpu.alloc(
3244 [2, BH],
3245 dtype=tl.float32,
3246 layout=None,
3247 scope=tle.gpu.smem,
3248 nv_mma_shared_layout=False,
3249 )
3250 final_lse_smem = tle.gpu.alloc(
3251 [BH],
3252 dtype=tl.float32,
3253 layout=None,
3254 scope=tle.gpu.smem,
3255 nv_mma_shared_layout=False,
3256 )
3258 sM_wg0_pipe = tle.pipe(
3259 capacity=1, scope="cta", name="dense_wg0_max", sM=sM_smem
3260 )
3261 sM_wg1_pipe = tle.pipe(
3262 capacity=1, scope="cta", name="dense_wg1_max", sM=sM_smem
3263 )
3264 sS0_pipe = tle.pipe(capacity=1, scope="cta", name="dense_sS0", sS0=sS0_smem)
3265 sS1_pipe = tle.pipe(capacity=1, scope="cta", name="dense_sS1", sS1=sS1_smem)
3266 sL_wg0_pipe = tle.pipe(capacity=2, scope="cta", name="dense_sL_wg0", sL=sL_smem)
3267 sL_wg1_pipe = tle.pipe(capacity=2, scope="cta", name="dense_sL_wg1", sL=sL_smem)
3269 log_scale: tl.constexpr = sm_scale * 1.4426950408889634
3271 tle.gpu.warp_specialize(
3272 [
3273 (
3274 _tle_dense_decode_consumer0,
3275 (
3276 q_pipe.writer(),
3277 q_pipe.reader("wg0"),
3278 q_desc,
3279 tq_desc,
3280 k0_l_pipe.reader(),
3281 k0_r_pipe.reader("qk"),
3282 k1_l_pipe.reader("remote", fields=("sK",)),
3283 is_kv_valid_pipe.reader("wg0"),
3284 sM_wg0_pipe.writer(),
3285 sM_wg1_pipe.reader(),
3286 sS0_pipe.writer(),
3287 sS1_pipe.reader(),
3288 sL_wg0_pipe.writer(),
3289 sL_wg1_pipe.reader(),
3290 output_desc,
3291 output_row,
3292 h_base,
3293 cache_seqlen,
3294 log_scale,
3295 D,
3296 TD,
3297 kv_cache.dtype.element_ty,
3298 HAVE_TAIL,
3299 BK,
3300 BH,
3301 DPH,
3302 TDP,
3303 G,
3304 PAGE_SIZE,
3305 ),
3306 ),
3307 (
3308 _tle_dense_decode_consumer1,
3309 (
3310 q_pipe.reader("wg1"),
3311 k1_r_pipe.reader(),
3312 k1_l_pipe.reader("qk"),
3313 k0_r_pipe.reader("remote", fields=("sK",)),
3314 is_kv_valid_pipe.reader("wg1"),
3315 sM_wg1_pipe.writer(),
3316 sM_wg0_pipe.reader(),
3317 sS1_pipe.writer(),
3318 sS0_pipe.reader(),
3319 sL_wg1_pipe.writer(),
3320 sL_wg0_pipe.reader(),
3321 final_lse_smem,
3322 output_desc,
3323 output_row,
3324 l_base,
3325 h_base,
3326 cache_seqlen,
3327 log_scale,
3328 D,
3329 TD,
3330 kv_cache.dtype.element_ty,
3331 HAVE_TAIL,
3332 BK,
3333 BH,
3334 DPH,
3335 TDP,
3336 G,
3337 PAGE_SIZE,
3338 ),
3339 ),
3340 (
3341 _tle_dense_decode_producer,
3342 (
3343 k0_l_pipe.writer(),
3344 k0_r_pipe.writer(),
3345 k1_l_pipe.writer(),
3346 k1_r_pipe.writer(),
3347 is_kv_valid_pipe.writer(),
3348 kv_base,
3349 block_table_ptr,
3350 cache_seqlen,
3351 D,
3352 TD,
3353 DPH,
3354 TDP,
3355 PAGE_SIZE,
3356 HAVE_TAIL,
3357 BK,
3358 stride_kvn,
3359 stride_bt,
3360 ),
3361 ),
3362 ],
3363 [4, 4],
3364 [216, 72],
3365 )