Coverage for src/flag_gems/fused/flash_mla_with_kvcache.py: 4%

1201 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1""" 

2Triton implementation of flash_mla_with_kvcache for MLA attention. 

3Supports both sparse (FP8 KV cache + topk indices) and dense (paged attention) modes. 

4Only supports sm90 (Hopper) architecture. 

5""" 

6 

7import dataclasses 

8import os 

9from typing import Optional, Tuple 

10 

11import torch 

12import triton 

13import triton.language as tl 

14 

15from flag_gems.utils.triton_version_utils import has_triton_tle 

16 

# Optional ``triton.experimental.tle.language`` module.
# ``tle`` stays None and ``HAS_TLE`` stays False unless ``has_triton_tle(3, 6, 0)``
# reports support AND the import itself succeeds.
tle = None
HAS_TLE = False
if has_triton_tle(3, 6, 0):
    try:
        import triton.experimental.tle.language as tle
    except ImportError:
        tle = None
    else:
        HAS_TLE = True


# TLE constants for decode.
# NOTE(review): not referenced in the visible part of this file — confirm consumer.
TLE_DECODE_BK = 64
TLE_DECODE_BH = 64
TLE_DECODE_PAIR_BLOCKS = 2
TLE_DECODE_WORKER_NUM_WARPS = 4

35 

36 

37# ============================================================================ 

38# Data structures (compatible with original CUDA interface) 

39# ============================================================================ 

40 

41 

@dataclasses.dataclass
class FlashMLASchedMeta:
    """Scheduler handle threaded through calls to ``flash_mla_with_kvcache``.

    A fresh instance is uninitialized: ``have_initialized`` is False and
    ``config`` is None. ``flash_mla_with_kvcache`` fills ``config`` on the first
    call made with this handle and asserts that later calls match it.
    ``tile_scheduler_metadata`` and ``num_splits`` default to None.
    """

    @dataclasses.dataclass
    class Config:
        """Shape/mode parameters captured from the first call using a handle.

        Instances are built positionally, so the field order below is part of
        the interface.
        """

        # Query shape: batch size, query length, number of query heads.
        b: int
        s_q: int
        h_q: int
        # KV-cache geometry: page (block) size and number of KV heads.
        page_block_size: int
        h_k: int
        # Attention mode flags.
        causal: bool
        is_fp8_kvcache: bool
        # Sparse sizes; None when the corresponding optional input is absent.
        topk: Optional[int]
        extra_page_block_size: Optional[int]
        extra_topk: Optional[int]

    # Becomes True once the first call has recorded ``config``.
    have_initialized: bool = dataclasses.field(default=False)
    config: Optional[Config] = dataclasses.field(default=None)
    tile_scheduler_metadata: Optional[torch.Tensor] = dataclasses.field(default=None)
    num_splits: Optional[torch.Tensor] = dataclasses.field(default=None)

63 

64 

def get_mla_metadata(*args, **kwargs) -> Tuple[FlashMLASchedMeta, None]:
    """Create scheduler metadata for ``flash_mla_with_kvcache``.

    Every positional and keyword argument is accepted and ignored. The
    returned handle starts uninitialized and is configured lazily by the
    first ``flash_mla_with_kvcache`` call that receives it.

    Returns:
        A tuple ``(FlashMLASchedMeta(), None)``.
    """
    del args, kwargs  # accepted for call-site compatibility; unused
    fresh_meta = FlashMLASchedMeta()
    return fresh_meta, None

68 

69 

70# ============================================================================ 

71# Sparse decode kernel (FP8 KV cache + topk indices) 

72# 

73# KV cache layout per token (656 bytes total): 

74# [0:512] - NoPE part: 512 float8_e4m3 values 

75# [512:528] - Scale factors: 4 float32 values (each for 128 FP8 values) 

76# [528:656] - RoPE part: 64 bfloat16 values 

77# 

78# The NoPE part (after dequantization) serves as BOTH K and V for MLA. 

79# ============================================================================ 

80 

81 

@triton.autotune(
    configs=[
        triton.Config({"BK": 64, "BH": 64}, num_warps=8, num_stages=2),
        triton.Config({"BK": 64, "BH": 64}, num_warps=8, num_stages=4),
    ],
    key=["HQ", "DQK", "TOPK", "HAVE_ATTN_SINK", "HAVE_TOPK_LENGTH", "IS_FP8"],
)
@triton.jit
def _sparse_decode_kernel(
    q,
    kv,
    kv_scales,
    kv_rope,
    indices,
    attn_sink,
    topk_length,
    sm_scale: tl.constexpr,
    output,
    lse,
    stride_qb,
    stride_qsq,
    stride_qh,
    stride_kvn,
    stride_scales_n,
    stride_rope_n,
    stride_ib,
    stride_isq,
    stride_ob,
    stride_osq,
    stride_oh,
    stride_lseb,
    stride_lseh,
    SQ,
    HQ: tl.constexpr,
    DQK: tl.constexpr,
    SKV,
    TOPK: tl.constexpr,
    HAVE_ATTN_SINK: tl.constexpr,
    HAVE_TOPK_LENGTH: tl.constexpr,
    IS_FP8: tl.constexpr,
    BK: tl.constexpr,
    BH: tl.constexpr,
):
    """
    Sparse decode kernel with online softmax.
    Grid: (batch_size * seq_q * ceil(HQ / BH),)
    Each program handles BH heads for one (batch, seq_q) position.

    For FP8 mode:
      - kv: [num_tokens, 512] float8_e4m3fn (NoPE part)
      - kv_scales: [num_tokens, 4] float32 (per-128-element scales)
      - kv_rope: [num_tokens, 64] bfloat16 (RoPE part)
    For BF16 mode:
      - kv: [num_tokens, DQK] bfloat16 (full KV)
      - kv_scales, kv_rope: unused

    Index entries that are negative, >= SKV, or past the per-batch
    ``topk_length`` (clamped to TOPK) are skipped: they contribute neither to
    the softmax nor to the output. A head block that sees no valid token
    writes zeros to ``output`` and +inf to ``lse``.

    Heads are not masked, so HQ must be a multiple of BH.
    """
    num_head_blocks: tl.constexpr = (HQ + BH - 1) // BH
    pid = tl.program_id(0)
    i_b = pid // (SQ * num_head_blocks)
    remainder = pid % (SQ * num_head_blocks)
    i_sq = (remainder // num_head_blocks).to(tl.int64)
    gbh_base = (remainder % num_head_blocks) * BH

    DP: tl.constexpr = 512  # NoPE width; also the output width (2 * BDP)
    BDP: tl.constexpr = 256  # the NoPE slice is processed as two BDP-wide halves

    # Base pointers
    q_base = q + i_b * stride_qb + i_sq * stride_qsq + gbh_base * stride_qh
    t_base = indices + i_b * stride_ib + i_sq * stride_isq
    attn_sink_ptr = attn_sink + gbh_base if HAVE_ATTN_SINK else 0
    topk_length_ptr = topk_length + i_b if HAVE_TOPK_LENGTH else 0
    o_base = output + i_b * stride_ob + i_sq * stride_osq + gbh_base * stride_oh
    l_base = lse + i_b * stride_lseb + gbh_base * stride_lseh + i_sq

    offs_h = tl.arange(0, BH)
    offs_d = tl.arange(0, BDP)
    if DQK == 576:
        offs_td = tl.arange(0, 64)
    offs_t = tl.arange(0, BK)

    # Load Q in two halves [BH, 256] x 2 (+ the 64-wide RoPE part when DQK == 576)
    q_ptr = q_base + offs_h[:, None] * stride_qh + offs_d[None, :]
    q_blk0 = tl.load(q_ptr, eviction_policy="evict_first")
    q_blk1 = tl.load(q_ptr + BDP, eviction_policy="evict_first")
    if DQK == 576:
        tq_ptr = q_base + DP + offs_h[:, None] * stride_qh + offs_td[None, :]
        tq_blk = tl.load(tq_ptr, eviction_policy="evict_first")

    # Online softmax accumulators
    max_log = tl.full([BH], float("-inf"), dtype=tl.float32)
    sum_exp = tl.full([BH], 0.0, dtype=tl.float32)
    acc0 = tl.zeros([BH, BDP], dtype=tl.float32)
    acc1 = tl.zeros([BH, BDP], dtype=tl.float32)

    # Clamp so an oversized topk_length can never read past the indices row.
    topk_len = (
        tl.minimum(tl.load(topk_length_ptr), TOPK) if HAVE_TOPK_LENGTH else TOPK
    )
    NK = tl.cdiv(topk_len, BK)
    for ck in range(NK):
        # Load indices; out-of-range slots read -1 (invalid).
        t_offs = BK * ck + offs_t
        t_msk = t_offs < topk_len
        kv_ids = tl.load(t_base + t_offs, t_msk, other=-1)
        mask_ids = (kv_ids < SKV) & (kv_ids >= 0)
        # Invalid ids are redirected to row 0 so every address stays in bounds.
        # int64: row * stride can exceed 2**31 for large caches.
        kv_rows = tl.where(mask_ids, kv_ids, 0).to(tl.int64)

        kv_ptr = kv + offs_d[:, None] + kv_rows[None, :] * stride_kvn
        if IS_FP8:
            # Load NoPE FP8 data as two [BDP, BK] halves and dequantize with
            # one float32 scale per 128 elements (4 scales per token).
            kv_fp8_0 = tl.load(kv_ptr, cache_modifier=".cg")
            kv_fp8_1 = tl.load(kv_ptr + BDP, cache_modifier=".cg")
            scale_ptr = kv_scales + kv_rows * stride_scales_n
            scale0 = tl.load(scale_ptr + 0)  # [BK]
            scale1 = tl.load(scale_ptr + 1)
            scale2 = tl.load(scale_ptr + 2)
            scale3 = tl.load(scale_ptr + 3)

            mask_lo = offs_d[:, None] < 128
            deq0 = kv_fp8_0.to(tl.float32) * tl.where(
                mask_lo, scale0[None, :], scale1[None, :]
            )
            deq1 = kv_fp8_1.to(tl.float32) * tl.where(
                mask_lo, scale2[None, :], scale3[None, :]
            )
            # Zero skipped tokens: row 0's contents (possibly uninitialized or
            # NaN) must not leak into P @ V, where 0 * NaN would still be NaN.
            kv_blk0 = tl.where(mask_ids[None, :], deq0, 0.0).to(tl.bfloat16)
            kv_blk1 = tl.where(mask_ids[None, :], deq1, 0.0).to(tl.bfloat16)
        else:
            # BF16 mode: masked loads give exact zeros for skipped tokens.
            kv_blk0 = tl.load(
                kv_ptr, mask=mask_ids[None, :], other=0.0, cache_modifier=".cg"
            )
            kv_blk1 = tl.load(
                kv_ptr + BDP, mask=mask_ids[None, :], other=0.0, cache_modifier=".cg"
            )

        # Compute QK^T
        qk = tl.dot(q_blk0, kv_blk0, out_dtype=tl.float32)
        qk = tl.dot(q_blk1, kv_blk1, qk, out_dtype=tl.float32)
        if DQK == 576:
            if IS_FP8:
                # RoPE part lives in a separate tensor
                tkv_ptr = kv_rope + offs_td[:, None] + kv_rows[None, :] * stride_rope_n
            else:
                tkv_ptr = kv + DP + offs_td[:, None] + kv_rows[None, :] * stride_kvn
            tkv_blk = tl.load(tkv_ptr, cache_modifier=".cg")
            qk = tl.dot(tq_blk, tkv_blk, qk, out_dtype=tl.float32)
        qk *= sm_scale

        # Mask invalid tokens (also discards any NaN computed for them)
        qk = tl.where(mask_ids[None, :], qk, float("-inf"))

        # Online softmax. While no valid token has been seen, new_max is -inf
        # and (-inf) - (-inf) would be NaN; subtract 0 instead so masked
        # logits give exp(-inf) = 0 and the state stays at zero.
        new_max = tl.maximum(max_log, tl.max(qk, axis=1))
        ref_max = tl.where(new_max == float("-inf"), 0.0, new_max)
        exp_qk = tl.math.exp(qk - ref_max[:, None])
        alpha = tl.math.exp(max_log - ref_max)
        sum_exp = sum_exp * alpha + tl.sum(exp_qk, axis=1)

        # Accumulate P @ V (V = K NoPE for MLA)
        p = exp_qk.to(tl.bfloat16)
        acc0 = tl.dot(p, kv_blk0.trans(), acc0 * alpha[:, None], out_dtype=tl.float32)
        acc1 = tl.dot(p, kv_blk1.trans(), acc1 * alpha[:, None], out_dtype=tl.float32)
        max_log = new_max

    # Finalize output
    valid_mask = max_log != float("-inf")
    orig_lse = max_log + tl.math.log(sum_exp)
    lse_out = tl.where(valid_mask, orig_lse, float("inf"))
    tl.store(l_base + offs_h * stride_lseh, lse_out)

    if HAVE_ATTN_SINK:
        sink = tl.load(attn_sink_ptr + offs_h)
        # Equal to exp(max_log) / (exp(orig_lse) + exp(sink)), rearranged so
        # no exponential of a raw logit is formed (those overflow above ~88).
        factor = 1.0 / (sum_exp + tl.math.exp(sink - max_log))
    else:
        factor = 1.0 / sum_exp

    out_vals0 = tl.where(valid_mask[:, None], acc0 * factor[:, None], 0.0)
    out_vals1 = tl.where(valid_mask[:, None], acc1 * factor[:, None], 0.0)

    # Store output
    o_ptr = o_base + offs_h[:, None] * stride_oh + offs_d[None, :]
    tl.store(o_ptr, out_vals0.to(tl.bfloat16))
    tl.store(o_ptr + BDP, out_vals1.to(tl.bfloat16))

287 

288 

289# ============================================================================ 

290# Sparse decode kernel for FlashMLA MODEL1 layout 

291# 

292# MODEL1 is FlashMLA's internal name for the d_qk=512 / 584-byte layout. 

293# It is not a model name. Per page: 

294# [0:page_block_size*576] - token data 

295# per token: 448 FP8 NoPE + 64 BF16 RoPE 

296# [page_block_size*576:...] - 8 uint8 E8M0 scales per token 

297# 

298# The 512-dim output uses both NoPE and RoPE values as V: 

299# output[0:448] = weighted NoPE 

300# output[448:512] = weighted RoPE 

301# ============================================================================ 

302 

303 

@triton.jit
def _model1_e8m0_scale(scale_ptr, mask):
    """Load one E8M0 scale byte per lane and decode it to float32 ``2**(e - 127)``.

    The raw byte is bitcast to uint8 before widening, so a cache viewed as
    int8 decodes exactly like one viewed as uint8 (sign extension would turn
    bytes >= 128 into negative, wrong scales). Masked lanes read 127 (scale
    1.0); a byte of 0 decodes to 0.0.
    """
    raw = tl.load(scale_ptr, mask=mask, other=127)
    biased_exp = raw.to(tl.uint8, bitcast=True).to(tl.int32)
    # Place the 8-bit biased exponent into the float32 exponent field.
    return (biased_exp << 23).to(tl.float32, bitcast=True)


@triton.jit
def _model1_attend_chunks(
    q_blk0,
    q_blk1_nope,
    q_rope,
    max_log,
    sum_exp,
    acc0,
    acc1,
    kv,
    idx_base,
    num_idx,
    stride_kv_block,
    num_blocks,
    sm_scale: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    BK: tl.constexpr,
):
    """Fold the tokens listed in ``idx_base[0:num_idx]`` into an online-softmax state.

    Index ``i`` addresses token ``i % PAGE_SIZE`` of page ``i // PAGE_SIZE`` in
    a MODEL1-layout cache ``kv`` (addressed in bytes, ``stride_kv_block`` per
    page). Indices that are negative or point at a page ``>= num_blocks`` are
    skipped.

    Args:
        q_blk0, q_blk1_nope, q_rope: query tiles for NoPE[0:256],
            NoPE[256:448] (zero-padded to 256 columns) and RoPE[0:64].
        max_log, sum_exp, acc0, acc1: running softmax state: row max, row
            sum, and the two 256-wide halves of the unnormalized output.

    Returns:
        The updated ``(max_log, sum_exp, acc0, acc1)``.
    """
    NOPE: tl.constexpr = 448
    ROPE: tl.constexpr = 64
    BDP: tl.constexpr = 256
    TOKEN_DATA_BYTES: tl.constexpr = 576
    SCALE_BYTES: tl.constexpr = 8
    TAIL: tl.constexpr = NOPE - BDP  # NoPE columns that spill into the second half

    offs_d = tl.arange(0, BDP)
    offs_t = tl.arange(0, BK)
    offs_rope = tl.arange(0, ROPE)
    in_tail = offs_d[:, None] < TAIL  # rows of the second half holding NoPE

    for ck in range(tl.cdiv(num_idx, BK)):
        t_offs = BK * ck + offs_t
        t_msk = t_offs < num_idx
        kv_ids = tl.load(idx_base + t_offs, t_msk, other=-1)
        block_ids = kv_ids // PAGE_SIZE
        rel_ids = kv_ids - block_ids * PAGE_SIZE
        valid_ids = t_msk & (kv_ids >= 0) & (block_ids < num_blocks)
        block_ids = tl.where(valid_ids, block_ids, 0)
        rel_ids = tl.where(valid_ids, rel_ids, 0)

        # int64 page offset: block_id * stride_kv_block may exceed 2**31.
        page_base = kv + block_ids.to(tl.int64) * stride_kv_block
        token_base = page_base + rel_ids * TOKEN_DATA_BYTES
        scale_base = page_base + PAGE_SIZE * TOKEN_DATA_BYTES + rel_ids * SCALE_BYTES

        # NoPE FP8 bytes: columns [0, 256) and [256, 448); masked lanes read 0.
        nope0_raw = tl.load(
            token_base[None, :] + offs_d[:, None],
            mask=valid_ids[None, :],
            other=0,
            cache_modifier=".cg",
        )
        nope1_raw = tl.load(
            token_base[None, :] + (BDP + offs_d[:, None]),
            mask=valid_ids[None, :] & in_tail,
            other=0,
            cache_modifier=".cg",
        )

        # One E8M0 scale per 64 NoPE columns (7 used of the 8 stored bytes).
        s0 = _model1_e8m0_scale(scale_base + 0, valid_ids)
        s1 = _model1_e8m0_scale(scale_base + 1, valid_ids)
        s2 = _model1_e8m0_scale(scale_base + 2, valid_ids)
        s3 = _model1_e8m0_scale(scale_base + 3, valid_ids)
        s4 = _model1_e8m0_scale(scale_base + 4, valid_ids)
        s5 = _model1_e8m0_scale(scale_base + 5, valid_ids)
        s6 = _model1_e8m0_scale(scale_base + 6, valid_ids)

        scale_lo = tl.where(
            offs_d[:, None] < 64,
            s0[None, :],
            tl.where(
                offs_d[:, None] < 128,
                s1[None, :],
                tl.where(offs_d[:, None] < 192, s2[None, :], s3[None, :]),
            ),
        )
        scale_hi = tl.where(
            offs_d[:, None] < 64,
            s4[None, :],
            tl.where(offs_d[:, None] < 128, s5[None, :], s6[None, :]),
        )
        kv_blk0 = (
            nope0_raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale_lo
        ).to(tl.bfloat16)
        nope_tail = (
            nope1_raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale_hi
        ).to(tl.bfloat16)

        rope_ptr = (token_base + NOPE).to(tl.pointer_type(tl.bfloat16))
        rope_blk = tl.load(
            rope_ptr[None, :] + offs_rope[:, None],
            mask=valid_ids[None, :],
            other=0.0,
            cache_modifier=".cg",
        )
        # Second V half: NoPE tail in rows [0, TAIL), RoPE in rows [TAIL, 256).
        kv_blk1 = tl.where(
            in_tail,
            nope_tail,
            tl.load(
                rope_ptr[None, :] + (offs_d[:, None] - TAIL),
                mask=valid_ids[None, :] & (offs_d[:, None] >= TAIL),
                other=0.0,
                cache_modifier=".cg",
            ),
        )

        qk = tl.dot(q_blk0, kv_blk0, out_dtype=tl.float32)
        qk = tl.dot(q_blk1_nope, nope_tail, qk, out_dtype=tl.float32)
        qk = tl.dot(q_rope, rope_blk, qk, out_dtype=tl.float32)
        qk *= sm_scale
        qk = tl.where(valid_ids[None, :], qk, float("-inf"))

        # Online softmax. While nothing valid has been seen new_max is -inf and
        # (-inf) - (-inf) would be NaN; subtract 0 instead so the state stays 0.
        new_max = tl.maximum(max_log, tl.max(qk, axis=1))
        ref_max = tl.where(new_max == float("-inf"), 0.0, new_max)
        exp_qk = tl.math.exp(qk - ref_max[:, None])
        alpha = tl.math.exp(max_log - ref_max)
        sum_exp = sum_exp * alpha + tl.sum(exp_qk, axis=1)
        p = exp_qk.to(tl.bfloat16)
        acc0 = tl.dot(p, kv_blk0.trans(), acc0 * alpha[:, None], out_dtype=tl.float32)
        acc1 = tl.dot(p, kv_blk1.trans(), acc1 * alpha[:, None], out_dtype=tl.float32)
        max_log = new_max

    return max_log, sum_exp, acc0, acc1


@triton.autotune(
    configs=[
        triton.Config({"BK": 32, "BH": 64}, num_warps=4, num_stages=1),
        triton.Config({"BK": 32, "BH": 64}, num_warps=8, num_stages=1),
    ],
    key=[
        "HQ",
        "TOPK",
        "EXTRA_TOPK",
        "HAVE_ATTN_SINK",
        "HAVE_TOPK_LENGTH",
        "HAVE_EXTRA",
        "HAVE_EXTRA_TOPK_LENGTH",
    ],
)
@triton.jit
def _sparse_decode_model1_kernel(
    q,
    kv,
    indices,
    extra_kv,
    extra_indices,
    attn_sink,
    topk_length,
    extra_topk_length,
    sm_scale: tl.constexpr,
    output,
    lse,
    stride_qb,
    stride_qsq,
    stride_qh,
    stride_kv_block,
    stride_ib,
    stride_isq,
    stride_extra_kv_block,
    stride_eib,
    stride_eisq,
    stride_ob,
    stride_osq,
    stride_oh,
    stride_lseb,
    stride_lseh,
    SQ,
    HQ: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    EXTRA_PAGE_SIZE: tl.constexpr,
    NUM_BLOCKS,
    EXTRA_NUM_BLOCKS,
    TOPK: tl.constexpr,
    EXTRA_TOPK: tl.constexpr,
    HAVE_ATTN_SINK: tl.constexpr,
    HAVE_TOPK_LENGTH: tl.constexpr,
    HAVE_EXTRA: tl.constexpr,
    HAVE_EXTRA_TOPK_LENGTH: tl.constexpr,
    BK: tl.constexpr,
    BH: tl.constexpr,
):
    """
    Sparse decode over a MODEL1-layout FP8 cache (d_qk = 512, 584 bytes/token).
    Grid: (batch_size * seq_q * ceil(HQ / BH),)
    Each program handles BH heads for one (batch, seq_q) position.

    Tokens selected by ``indices`` from ``kv`` and, when HAVE_EXTRA, by
    ``extra_indices`` from ``extra_kv`` are attended jointly under a single
    softmax. The 512-wide output is the weighted NoPE (448) followed by the
    weighted RoPE (64). Per-batch lengths are clamped to TOPK / EXTRA_TOPK.
    A head block with no valid token writes zeros and lse = +inf.

    Heads are not masked, so HQ must be a multiple of BH.
    """
    num_head_blocks: tl.constexpr = (HQ + BH - 1) // BH
    pid = tl.program_id(0)
    i_b = pid // (SQ * num_head_blocks)
    remainder = pid % (SQ * num_head_blocks)
    i_sq = (remainder // num_head_blocks).to(tl.int64)
    gbh_base = (remainder % num_head_blocks) * BH

    NOPE: tl.constexpr = 448
    ROPE: tl.constexpr = 64
    BDP: tl.constexpr = 256

    q_base = q + i_b * stride_qb + i_sq * stride_qsq + gbh_base * stride_qh
    attn_sink_ptr = attn_sink + gbh_base if HAVE_ATTN_SINK else 0
    topk_length_ptr = topk_length + i_b if HAVE_TOPK_LENGTH else 0
    extra_topk_length_ptr = extra_topk_length + i_b if HAVE_EXTRA_TOPK_LENGTH else 0
    o_base = output + i_b * stride_ob + i_sq * stride_osq + gbh_base * stride_oh
    l_base = lse + i_b * stride_lseb + gbh_base * stride_lseh + i_sq

    offs_h = tl.arange(0, BH)
    offs_d = tl.arange(0, BDP)
    offs_rope = tl.arange(0, ROPE)

    # Q split as NoPE[0:256], NoPE[256:448] (zero-padded to 256), RoPE[448:512].
    q_ptr = q_base + offs_h[:, None] * stride_qh + offs_d[None, :]
    q_blk0 = tl.load(q_ptr, eviction_policy="evict_first")
    q_blk1_nope = tl.load(
        q_ptr + BDP,
        mask=offs_d[None, :] < (NOPE - BDP),
        other=0.0,
        eviction_policy="evict_first",
    )
    q_rope = tl.load(
        q_base + offs_h[:, None] * stride_qh + (NOPE + offs_rope[None, :]),
        eviction_policy="evict_first",
    )

    max_log = tl.full([BH], float("-inf"), dtype=tl.float32)
    sum_exp = tl.full([BH], 0.0, dtype=tl.float32)
    acc0 = tl.zeros([BH, BDP], dtype=tl.float32)
    acc1 = tl.zeros([BH, BDP], dtype=tl.float32)

    # Main cache; the length is clamped so we never read past the indices row.
    topk_len = (
        tl.minimum(tl.load(topk_length_ptr), TOPK) if HAVE_TOPK_LENGTH else TOPK
    )
    max_log, sum_exp, acc0, acc1 = _model1_attend_chunks(
        q_blk0,
        q_blk1_nope,
        q_rope,
        max_log,
        sum_exp,
        acc0,
        acc1,
        kv,
        indices + i_b * stride_ib + i_sq * stride_isq,
        topk_len,
        stride_kv_block,
        NUM_BLOCKS,
        sm_scale,
        PAGE_SIZE,
        BK,
    )

    if HAVE_EXTRA:
        extra_topk_len = (
            tl.minimum(tl.load(extra_topk_length_ptr), EXTRA_TOPK)
            if HAVE_EXTRA_TOPK_LENGTH
            else EXTRA_TOPK
        )
        max_log, sum_exp, acc0, acc1 = _model1_attend_chunks(
            q_blk0,
            q_blk1_nope,
            q_rope,
            max_log,
            sum_exp,
            acc0,
            acc1,
            extra_kv,
            extra_indices + i_b * stride_eib + i_sq * stride_eisq,
            extra_topk_len,
            stride_extra_kv_block,
            EXTRA_NUM_BLOCKS,
            sm_scale,
            EXTRA_PAGE_SIZE,
            BK,
        )

    valid_mask = max_log != float("-inf")
    orig_lse = max_log + tl.math.log(sum_exp)
    lse_out = tl.where(valid_mask, orig_lse, float("inf"))
    tl.store(l_base + offs_h * stride_lseh, lse_out)

    if HAVE_ATTN_SINK:
        sink = tl.load(attn_sink_ptr + offs_h)
        # Equal to exp(max_log) / (exp(orig_lse) + exp(sink)), rearranged so
        # no exponential of a raw logit is formed (those overflow above ~88).
        factor = 1.0 / (sum_exp + tl.math.exp(sink - max_log))
    else:
        factor = 1.0 / sum_exp

    out_vals0 = tl.where(valid_mask[:, None], acc0 * factor[:, None], 0.0)
    out_vals1 = tl.where(valid_mask[:, None], acc1 * factor[:, None], 0.0)
    o_ptr = o_base + offs_h[:, None] * stride_oh + offs_d[None, :]
    tl.store(o_ptr, out_vals0.to(tl.bfloat16))
    tl.store(o_ptr + BDP, out_vals1.to(tl.bfloat16))

660 

661 

662# ============================================================================ 

663# Dense decode kernel (paged attention with block_table) 

664# ============================================================================ 

665 

666 

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 64, "BLOCK_N": 64}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_H": 64, "BLOCK_N": 64}, num_warps=8, num_stages=3),
    ],
    key=["HQ", "DQK", "HAVE_CAUSAL"],
)
@triton.jit
def _dense_decode_kernel(
    Q_ptr,
    stride_q_b,
    stride_q_sq,
    stride_q_h,
    KV_cache,
    stride_kv_bs,
    Block_table,
    stride_bt_b,
    Seq_lens,
    Out,
    stride_o_b,
    stride_o_sq,
    stride_o_h,
    LSE,
    stride_lse_b,
    stride_lse_h,
    sm_scale,
    SQ,
    HQ: tl.constexpr,
    DQK: tl.constexpr,
    HEAD_DIM_V: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    HAVE_CAUSAL: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """
    Dense decode kernel with paged attention and online softmax.
    Grid: (ceil(HQ / BLOCK_H), batch_size * seq_q)

    Each program handles BLOCK_H query heads of one (batch, seq_q) position
    against the cached tokens of batch ``b``, located page by page through
    ``Block_table``. Columns ``[0:HEAD_DIM_V]`` of each cached token serve as
    both K and V; columns ``[HEAD_DIM_V:DQK]`` only contribute to the logits.

    Without HAVE_CAUSAL every query attends to ``Seq_lens[b]`` tokens. With
    HAVE_CAUSAL, query ``i_sq`` attends to the first
    ``Seq_lens[b] - (SQ - 1 - i_sq)`` tokens (mask aligned to the end of the
    cache, so the last query sees everything; a no-op when SQ == 1).

    A query that attends to no token writes zeros and lse = +inf, the same
    convention the sparse kernels in this file use.
    """
    pid_h_block = tl.program_id(0)
    pid_b_sq = tl.program_id(1)
    i_b = pid_b_sq // SQ
    i_sq = pid_b_sq % SQ

    cur_head = pid_h_block * BLOCK_H + tl.arange(0, BLOCK_H)
    mask_head = cur_head < HQ

    # Load Q: NoPE part [BLOCK_H, HEAD_DIM_V] and RoPE part [BLOCK_H, DQK - HEAD_DIM_V]
    offs_d_nope = tl.arange(0, HEAD_DIM_V)
    offs_d_pe = tl.arange(HEAD_DIM_V, DQK)
    q_rows = i_b * stride_q_b + i_sq * stride_q_sq + cur_head[:, None] * stride_q_h
    q_nope = tl.load(
        Q_ptr + q_rows + offs_d_nope[None, :], mask=mask_head[:, None], other=0.0
    )
    q_pe = tl.load(
        Q_ptr + q_rows + offs_d_pe[None, :], mask=mask_head[:, None], other=0.0
    )

    # Online softmax accumulators
    e_max = tl.full([BLOCK_H], value=float("-inf"), dtype=tl.float32)
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, HEAD_DIM_V], dtype=tl.float32)

    cur_batch_seq_len = tl.load(Seq_lens + i_b)
    if HAVE_CAUSAL:
        # Query i_sq may see keys j <= i_sq + seq_len - SQ.
        cur_batch_seq_len = tl.maximum(cur_batch_seq_len - (SQ - 1 - i_sq), 0)
    Block_table += i_b * stride_bt_b

    offs_n = tl.arange(0, BLOCK_N)
    loop_time = cur_batch_seq_len // BLOCK_N
    remainder = cur_batch_seq_len % BLOCK_N

    # Full tiles: every token is valid, so no masking is needed.
    for _ in range(0, loop_time):
        kv_page_number = tl.load(Block_table + offs_n // PAGE_SIZE)
        # int64: token row * stride_kv_bs can exceed 2**31 for large caches.
        kv_loc = kv_page_number.to(tl.int64) * PAGE_SIZE + offs_n % PAGE_SIZE

        # V (NoPE part) doubles as K for the NoPE logits.
        v_c = tl.load(KV_cache + kv_loc[:, None] * stride_kv_bs + offs_d_nope[None, :])
        qk = tl.dot(q_nope, tl.trans(v_c))

        # Add RoPE contribution
        k_pe = tl.load(KV_cache + kv_loc[None, :] * stride_kv_bs + offs_d_pe[:, None])
        qk = tl.dot(q_pe, k_pe, acc=qk)
        qk *= sm_scale

        # Online softmax update
        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        acc *= re_scale[:, None]
        acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc)
        e_sum = e_sum * re_scale + tl.sum(p, 1)
        e_max = n_e_max
        offs_n += BLOCK_N

    # Partial last tile: at least one token is valid, so n_e_max is finite.
    if remainder > 0:
        mask_kvsplit = offs_n < cur_batch_seq_len
        kv_page_number = tl.load(
            Block_table + offs_n // PAGE_SIZE, mask=mask_kvsplit, other=0
        )
        kv_loc = kv_page_number.to(tl.int64) * PAGE_SIZE + offs_n % PAGE_SIZE

        v_c = tl.load(
            KV_cache + kv_loc[:, None] * stride_kv_bs + offs_d_nope[None, :],
            mask=mask_kvsplit[:, None],
            other=0.0,
        )
        qk = tl.dot(q_nope, tl.trans(v_c))

        k_pe = tl.load(
            KV_cache + kv_loc[None, :] * stride_kv_bs + offs_d_pe[:, None],
            mask=mask_kvsplit[None, :],
            other=0.0,
        )
        qk = tl.dot(q_pe, k_pe, acc=qk)
        qk *= sm_scale
        qk = tl.where(mask_kvsplit[None, :], qk, float("-inf"))

        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        acc *= re_scale[:, None]
        acc = tl.dot(p.to(v_c.dtype), v_c, acc=acc)
        e_sum = e_sum * re_scale + tl.sum(p, 1)
        e_max = n_e_max

    # e_sum is 0 only when no token was attended; avoid 0 / 0 = NaN there.
    has_keys = e_sum > 0.0
    denom = tl.where(has_keys, e_sum, 1.0)

    # Store output
    offs_o = (
        i_b * stride_o_b
        + i_sq * stride_o_sq
        + cur_head[:, None] * stride_o_h
        + offs_d_nope[None, :]
    )
    tl.store(
        Out + offs_o,
        (acc / denom[:, None]).to(Out.dtype.element_ty),
        mask=mask_head[:, None],
    )

    # Store LSE
    lse_val = tl.where(has_keys, e_max + tl.math.log(e_sum), float("inf"))
    lse_offset = i_b * stride_lse_b + cur_head * stride_lse_h + i_sq
    tl.store(LSE + lse_offset, lse_val, mask=mask_head)

818 

819 

820# ============================================================================ 

821# Main dispatch function 

822# ============================================================================ 

823 

824 

825def flash_mla_with_kvcache( 

826 q: torch.Tensor, 

827 k_cache: torch.Tensor, 

828 block_table: Optional[torch.Tensor], 

829 cache_seqlens: Optional[torch.Tensor], 

830 head_dim_v: int, 

831 tile_scheduler_metadata: FlashMLASchedMeta, 

832 num_splits: None = None, 

833 softmax_scale: Optional[float] = None, 

834 causal: bool = False, 

835 is_fp8_kvcache: bool = False, 

836 indices: Optional[torch.Tensor] = None, 

837 attn_sink: Optional[torch.Tensor] = None, 

838 extra_k_cache: Optional[torch.Tensor] = None, 

839 extra_indices_in_kvcache: Optional[torch.Tensor] = None, 

840 topk_length: Optional[torch.Tensor] = None, 

841 extra_topk_length: Optional[torch.Tensor] = None, 

842 out: Optional[torch.Tensor] = None, 

843) -> Tuple[torch.Tensor, torch.Tensor]: 

844 """ 

845 Triton implementation of flash_mla_with_kvcache. 

846 Functionally equivalent to the CUDA implementation. 

847 

848 Returns: 

849 out: (batch_size, seq_len_q, num_heads_q, head_dim_v) 

850 softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32 

851 """ 

852 sched_meta = tile_scheduler_metadata 

853 assert isinstance(sched_meta, FlashMLASchedMeta) 

854 assert num_splits is None 

855 assert q.ndim == 4 

856 assert k_cache.ndim == 4 

857 

858 topk = indices.shape[-1] if indices is not None else None 

859 extra_k_page_block_size = ( 

860 extra_k_cache.shape[1] if extra_k_cache is not None else None 

861 ) 

862 extra_topk_val = ( 

863 extra_indices_in_kvcache.shape[-1] 

864 if extra_indices_in_kvcache is not None 

865 else None 

866 ) 

867 

868 if softmax_scale is None: 

869 softmax_scale = q.shape[-1] ** (-0.5) 

870 

871 if not sched_meta.have_initialized: 

872 if indices is not None: 

873 assert not causal, "causal must be False when sparse attention is enabled" 

874 sched_meta.have_initialized = True 

875 sched_meta.config = FlashMLASchedMeta.Config( 

876 q.shape[0], 

877 q.shape[1], 

878 q.shape[2], 

879 k_cache.shape[1], 

880 k_cache.shape[2], 

881 causal, 

882 is_fp8_kvcache, 

883 topk, 

884 extra_k_page_block_size, 

885 extra_topk_val, 

886 ) 

887 else: 

888 helper_msg = ( 

889 " Your input arguments are inconsistent with sched_meta. Please make " 

890 "sure the input arguments are consistent across different invocations " 

891 "of flash_mla_with_kvcache on the same sched_meta." 

892 ) 

893 assert sched_meta.config is not None 

894 assert sched_meta.config.b == q.shape[0], ( 

895 "sched_meta.config.b must be equal to batch_size." + helper_msg 

896 ) 

897 assert sched_meta.config.s_q == q.shape[1], ( 

898 "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg 

899 ) 

900 assert sched_meta.config.h_q == q.shape[2], ( 

901 "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg 

902 ) 

903 assert sched_meta.config.page_block_size == k_cache.shape[1], ( 

904 "sched_meta.config.page_block_size must be equal to page_block_size." 

905 + helper_msg 

906 ) 

907 assert sched_meta.config.h_k == k_cache.shape[2], ( 

908 "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg 

909 ) 

910 assert sched_meta.config.causal == causal, ( 

911 "sched_meta.config.causal must be equal to causal." + helper_msg 

912 ) 

913 assert sched_meta.config.is_fp8_kvcache == is_fp8_kvcache, ( 

914 "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." 

915 + helper_msg 

916 ) 

917 assert sched_meta.config.topk == topk, ( 

918 "sched_meta.config.topk must be equal to the last dim of indices." 

919 + helper_msg 

920 ) 

921 assert sched_meta.config.extra_page_block_size == extra_k_page_block_size, ( 

922 "sched_meta.config.extra_page_block_size must be equal to the " 

923 "page_block_size of extra_k_cache." + helper_msg 

924 ) 

925 assert sched_meta.config.extra_topk == extra_topk_val, ( 

926 "sched_meta.config.extra_topk must be equal to the last dim of " 

927 "extra_indices_in_kvcache." + helper_msg 

928 ) 

929 

930 batch_size, seq_q, num_heads_q, head_dim_k = q.shape 

931 num_heads_k = k_cache.shape[2] 

932 

933 if out is None: 

934 out = torch.empty( 

935 (batch_size, seq_q, num_heads_q, head_dim_v), 

936 dtype=q.dtype, 

937 device=q.device, 

938 ) 

939 else: 

940 assert out.shape == (batch_size, seq_q, num_heads_q, head_dim_v) 

941 assert out.dtype == q.dtype 

942 assert out.device == q.device 

943 assert out.stride(-1) == 1 

944 lse = torch.empty( 

945 (batch_size, num_heads_q, seq_q), 

946 dtype=torch.float32, 

947 device=q.device, 

948 ) 

949 

950 if indices is not None: 

951 assert not causal, "causal must be False when sparse attention is enabled" 

952 assert is_fp8_kvcache, "is_fp8_kvcache must be True for sparse attention" 

953 assert ( 

954 num_heads_k == 1 

955 ), "Currently only MQA (h_kv == 1) is supported for sparse decoding" 

956 assert head_dim_v == 512, "Only head_size_v == 512 is supported" 

957 assert num_heads_q in (64, 128), "Only h_q == 64 or 128 is supported" 

958 assert head_dim_k in ( 

959 512, 

960 576, 

961 ), "Only head_size_k == 512 or 576 is supported for sparse decoding" 

962 assert q.dtype == torch.bfloat16 

963 assert k_cache.dtype in (torch.float8_e4m3fn, torch.int8, torch.uint8) 

964 assert topk is not None and topk > 0 

965 assert topk % 64 == 0, "topk must be divisible by 64" 

966 assert indices.ndim == 3 and indices.shape[:2] == (batch_size, seq_q) 

967 assert indices.dtype == torch.int32 

968 assert indices.stride(-1) == 1 

969 if topk_length is not None: 

970 assert topk_length.shape == (batch_size,) 

971 assert topk_length.dtype == torch.int32 

972 assert topk_length.is_contiguous() 

973 if attn_sink is not None: 

974 assert attn_sink.shape == (num_heads_q,) 

975 assert attn_sink.dtype == torch.float32 

976 if extra_k_cache is not None: 

977 assert extra_indices_in_kvcache is not None, ( 

978 "extra_indices_in_kvcache must be provided when extra_k_cache " 

979 "is provided" 

980 ) 

981 assert extra_k_cache.dtype in ( 

982 torch.float8_e4m3fn, 

983 torch.int8, 

984 torch.uint8, 

985 ) 

986 else: 

987 assert extra_indices_in_kvcache is None, ( 

988 "extra_indices_in_kvcache must not be provided when extra_k_cache " 

989 "is not provided" 

990 ) 

991 assert extra_topk_length is None, ( 

992 "extra_topk_length must not be provided when extra_k_cache is " 

993 "not provided" 

994 ) 

995 if extra_indices_in_kvcache is not None: 

996 assert extra_indices_in_kvcache.ndim == 3 

997 assert extra_indices_in_kvcache.shape[:2] == (batch_size, seq_q) 

998 assert extra_indices_in_kvcache.dtype == torch.int32 

999 assert extra_indices_in_kvcache.stride(-1) == 1 

1000 assert extra_indices_in_kvcache.shape[-1] % 64 == 0 

1001 if extra_topk_length is not None: 

1002 assert extra_topk_length.shape == (batch_size,) 

1003 assert extra_topk_length.dtype == torch.int32 

1004 assert extra_topk_length.is_contiguous() 

1005 if head_dim_k == 576: 

1006 assert ( 

1007 k_cache.shape[-1] == 656 

1008 ), "V32 sparse FP8 cache must use 656 bytes per token" 

1009 assert ( 

1010 k_cache.stride(1) == 656 

1011 ), "The whole block must be contiguous for V32 KV cache" 

1012 assert topk_length is None, "V3.2/V32 does not support dynamic topk length" 

1013 assert extra_k_cache is None, "V3.2/V32 does not support extra KV cache" 

1014 assert ( 

1015 extra_indices_in_kvcache is None 

1016 ), "V3.2/V32 does not support extra indices" 

1017 assert ( 

1018 extra_topk_length is None 

1019 ), "V3.2/V32 does not support extra topk length" 

1020 else: 

1021 assert ( 

1022 k_cache.shape[-1] == 584 

1023 ), "MODEL1 sparse FP8 cache must use 584 bytes per token" 

1024 assert ( 

1025 k_cache.stride(1) == 584 

1026 ), "The whole block must be contiguous for MODEL1 KV cache" 

1027 if extra_k_cache is not None: 

1028 assert extra_k_cache.ndim == 4 

1029 assert extra_k_cache.shape[2] == 1 

1030 assert extra_k_cache.shape[-1] == 584 

1031 assert extra_k_cache.stride(1) == 584 

1032 _sparse_decode_dispatch( 

1033 q, 

1034 k_cache, 

1035 indices, 

1036 out, 

1037 lse, 

1038 attn_sink, 

1039 topk_length, 

1040 extra_k_cache, 

1041 extra_indices_in_kvcache, 

1042 extra_topk_length, 

1043 batch_size, 

1044 seq_q, 

1045 num_heads_q, 

1046 head_dim_k, 

1047 head_dim_v, 

1048 topk, 

1049 k_cache.shape[1], 

1050 softmax_scale, 

1051 is_fp8_kvcache, 

1052 ) 

1053 else: 

1054 assert ( 

1055 attn_sink is None 

1056 and extra_k_cache is None 

1057 and extra_indices_in_kvcache is None 

1058 and topk_length is None 

1059 and extra_topk_length is None 

1060 ), ( 

1061 "indices, attn_sink, extra_k_cache, extra_indices_in_kvcache, " 

1062 "topk_length and extra_topk_length must be None when dense " 

1063 "attention is used." 

1064 ) 

1065 assert block_table is not None and cache_seqlens is not None, ( 

1066 "block_table and cache_seqlens must be provided when dense attention " 

1067 "is used." 

1068 ) 

1069 assert num_heads_k == 1, "Only num_heads_k == 1 is supported for dense MLA" 

1070 if seq_q > 1 and causal: 

1071 raise NotImplementedError( 

1072 "causal dense attention with seq_q > 1 is not implemented" 

1073 ) 

1074 _dense_decode_dispatch( 

1075 q, 

1076 k_cache, 

1077 block_table, 

1078 cache_seqlens, 

1079 out, 

1080 lse, 

1081 batch_size, 

1082 seq_q, 

1083 num_heads_q, 

1084 head_dim_k, 

1085 head_dim_v, 

1086 k_cache.shape[1], 

1087 softmax_scale, 

1088 causal, 

1089 ) 

1090 

1091 return out, lse 

1092 

1093 

1094# ============================================================================ 

1095# Kernel launch helpers 

1096# ============================================================================ 

1097 

1098 

def _sparse_decode_dispatch(
    q,
    kv,
    indices,
    out,
    lse,
    attn_sink,
    topk_length,
    extra_kv,
    extra_indices,
    extra_topk_length,
    batch_size,
    seq_q,
    num_heads_q,
    head_dim_k,
    head_dim_v,
    topk,
    page_block_size,
    softmax_scale,
    is_fp8_kvcache,
):
    """Launch the sparse (top-k indexed) decode kernel.

    ``head_dim_k == 512`` (MODEL1 cache layout) is handled by
    ``_sparse_decode_model1_kernel``, which reads ``kv`` and the optional extra
    cache directly. Every other head size goes to ``_sparse_decode_kernel``,
    which receives the NoPE, scale and RoPE parts of each cached token as
    separate tensors. Both launch ``batch_size * seq_q * ceil(num_heads_q / 64)``
    programs and write their results into ``out`` and ``lse``.

    Optional tensors are passed as ``None`` (or, for the extra cache/indices,
    replaced by the primary tensor as a dummy pointer) together with boolean
    flags that tell the kernel whether each one is present.
    """
    BH = 64
    num_head_blocks = (num_heads_q + BH - 1) // BH
    grid = (batch_size * seq_q * num_head_blocks,)

    if head_dim_k == 512:
        # Absent extras fall back to the primary cache/indices as dummy
        # pointers. The kernel is told they are absent via the
        # `has_extra_kv` flag and an extra top-k of 0.
        has_extra_kv = extra_kv is not None
        has_extra_indices = extra_indices is not None
        extra_kv_arg = extra_kv if has_extra_kv else kv
        extra_indices_arg = extra_indices if has_extra_indices else indices
        _sparse_decode_model1_kernel[grid](
            q,
            kv,
            indices,
            extra_kv_arg,
            extra_indices_arg,
            attn_sink,
            topk_length,
            extra_topk_length,
            softmax_scale,
            out,
            lse,
            # Q strides
            q.stride(0),
            q.stride(1),
            q.stride(2),
            # KV and indices strides
            kv.stride(0),
            indices.stride(0),
            indices.stride(1),
            extra_kv_arg.stride(0),
            extra_indices_arg.stride(0),
            extra_indices_arg.stride(1),
            # Output strides
            out.stride(0),
            out.stride(1),
            out.stride(2),
            # LSE strides
            lse.stride(0),
            lse.stride(1),
            # Scalar args
            seq_q,
            num_heads_q,
            page_block_size,
            extra_kv.shape[1] if has_extra_kv else 1,
            kv.shape[0],
            extra_kv.shape[0] if has_extra_kv else 0,
            topk,
            extra_indices.shape[-1] if has_extra_indices else 0,
            attn_sink is not None,
            topk_length is not None,
            has_extra_kv,
            extra_topk_length is not None,
        )
        return

    # Total number of token slots in the paged cache (only this path uses it).
    skv = kv.shape[0] * page_block_size

    if is_fp8_kvcache:
        # FP8 mode: kv has shape [num_blocks, page_block_size, 1, 656]
        # Layout per token (656 bytes):
        #   [0:512]   - 512 float8_e4m3fn values (NoPE)
        #   [512:528] - 4 float32 scales (16 bytes)
        #   [528:656] - 64 bfloat16 values (RoPE, 128 bytes)
        # NOTE(perf): each field slice below is a strided column slice of
        # `kv_bytes`, so `.contiguous()` materializes a dense copy of that field
        # for every token in the cache on every call. Passing strided
        # `.view(dtype)`s instead would avoid the copies. That needs
        # confirmation that `_sparse_decode_kernel` honours the row strides
        # and stays in range for the larger row offsets.
        kv_bytes = kv.reshape(-1, 656).contiguous()  # [num_tokens, 656]
        kv_nope = kv_bytes[:, :512].contiguous().view(torch.float8_e4m3fn)
        kv_scales = kv_bytes[:, 512:528].contiguous().view(torch.float32)
        kv_rope = kv_bytes[:, 528:656].contiguous().view(torch.bfloat16)
        stride_kvn = kv_nope.stride(0)
        stride_scales_n = kv_scales.stride(0)
        stride_rope_n = kv_rope.stride(0)
    else:
        # BF16 mode: kv has shape [num_blocks, page_block_size, 1, head_dim_k]
        kv_nope = kv.reshape(-1, kv.shape[-1]).contiguous()
        stride_kvn = kv_nope.stride(0)
        # Scale/RoPE tensors are unused in this mode; pass kv_nope as a
        # placeholder pointer with zero strides.
        kv_scales = kv_nope
        kv_rope = kv_nope
        stride_scales_n = 0
        stride_rope_n = 0

    # TODO: route eligible inputs through the TLE warp-specialized kernel
    # (`_can_use_tle_sparse_decode` / `_tle_sparse_decode_launch`) once it is
    # validated for this path.

    _sparse_decode_kernel[grid](
        q,
        kv_nope,
        kv_scales,
        kv_rope,
        indices,
        attn_sink,
        topk_length,
        softmax_scale,
        out,
        lse,
        # Q strides
        q.stride(0),
        q.stride(1),
        q.stride(2),
        # KV strides
        stride_kvn,
        stride_scales_n,
        stride_rope_n,
        # Indices strides
        indices.stride(0),
        indices.stride(1),
        # Output strides
        out.stride(0),
        out.stride(1),
        out.stride(2),
        # LSE strides
        lse.stride(0),
        lse.stride(1),
        # Scalar args
        seq_q,
        num_heads_q,
        head_dim_k,
        skv,
        topk,
        attn_sink is not None,
        topk_length is not None,
        is_fp8_kvcache,
    )

1259 

1260 

def _dense_decode_dispatch(
    q,
    kv_cache,
    block_table,
    cache_seqlens,
    out,
    lse,
    batch_size,
    seq_q,
    num_heads_q,
    head_dim_k,
    head_dim_v,
    page_block_size,
    softmax_scale,
    causal,
):
    """Launch dense (paged) decode.

    Flattens the paged KV cache to ``[num_tokens_total, head_dim_k]``. When
    ``_can_use_tle_dense_decode`` accepts the inputs, it launches the TLE
    warp-specialized kernel; otherwise it launches ``_dense_decode_kernel``
    on a ``(ceil(num_heads_q / 64), batch_size * seq_q)`` grid. Results are
    written into ``out`` and ``lse``.
    """
    # KV cache: [num_blocks, page_block_size, num_heads_k, head_dim_k]
    # Flatten to [num_tokens_total, head_dim_k] for paged access.
    # reshape() rather than view(): view() raises for caches whose block
    # dimension is padded/strided, whereas reshape() falls back to a copy.
    # contiguous() guarantees unit stride on the last dim, because the kernel
    # only receives the row stride.
    kv_flat = kv_cache.reshape(-1, head_dim_k).contiguous()
    block_table = block_table.contiguous()
    # No stride is passed for cache_seqlens, so make sure it is dense.
    cache_seqlens = cache_seqlens.contiguous()

    # TLE warp specialization path
    if _can_use_tle_dense_decode(q, kv_cache, block_table, head_dim_v, page_block_size):
        _tle_dense_decode_launch(
            q,
            kv_flat,
            block_table,
            cache_seqlens,
            out,
            lse,
            batch_size,
            seq_q,
            num_heads_q,
            head_dim_k,
            head_dim_v,
            page_block_size,
            softmax_scale,
            causal,
        )
        return

    BLOCK_H = 64
    num_head_blocks = (num_heads_q + BLOCK_H - 1) // BLOCK_H
    grid = (num_head_blocks, batch_size * seq_q)

    _dense_decode_kernel[grid](
        q,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        kv_flat,
        kv_flat.stride(0),
        block_table,
        block_table.stride(0),
        cache_seqlens,
        out,
        out.stride(0),
        out.stride(1),
        out.stride(2),
        lse,
        lse.stride(0),
        lse.stride(1),
        softmax_scale,
        seq_q,
        num_heads_q,
        head_dim_k,
        head_dim_v,
        page_block_size,
        causal,
    )

1333 

1334 

1335# ============================================================================ 

1336# TLE warp-specialization paths (eligibility checks, sparse decode launch and kernels)

1337# ============================================================================ 

1338 

1339 

1340def _tle_decode_enabled() -> bool: 

1341 value = os.environ.get("FLAGGEMS_FLASHMLA_DECODE_TLE", "1").lower() 

1342 return value not in {"0", "false", "off", "no"} 

1343 

1344 

def _can_use_tle_sparse_decode(
    q: torch.Tensor,
    indices: torch.Tensor,
    head_dim_v: int,
    head_dim_k: int,
    is_fp8: bool,
) -> bool:
    """Decide whether sparse decode can take the TLE warp-specialized path.

    All of the following must hold:

    * TLE is importable and enabled via the environment.
    * ``q`` lives on CUDA.
    * ``head_dim_v == 512``.
    * The q head size (``q.shape[3]``) is 512 or 576.
    * The number of query heads is a multiple of ``TLE_DECODE_BH``.
    * The top-k width is positive and covers whole pairs of
      ``TLE_DECODE_BK``-token blocks.

    ``head_dim_k`` and ``is_fp8`` are accepted but not consulted.
    """
    if not HAS_TLE or not _tle_decode_enabled():
        return False
    if q.device.type != "cuda":
        return False

    _, _, num_heads_q, d_qk = q.shape
    topk = indices.shape[-1]
    tokens_per_pair = TLE_DECODE_BK * TLE_DECODE_PAIR_BLOCKS

    dims_ok = head_dim_v == 512 and d_qk in (512, 576)
    heads_ok = num_heads_q % TLE_DECODE_BH == 0
    topk_ok = topk > 0 and topk % tokens_per_pair == 0
    return dims_ok and heads_ok and topk_ok

1365 

1366 

def _can_use_tle_dense_decode(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    block_table: torch.Tensor,
    head_dim_v: int,
    page_block_size: int,
) -> bool:
    """Decide whether dense paged decode can take the TLE warp-specialized path.

    All of the following must hold:

    * TLE is importable and enabled via the environment.
    * ``q`` lives on CUDA.
    * ``head_dim_v == 512``.
    * The q head size (``q.shape[3]``) is 512 or 576.
    * The query-head count is a multiple of ``TLE_DECODE_BH``.
    * Pages hold exactly ``TLE_DECODE_BK`` tokens.

    ``kv_cache`` and ``block_table`` are accepted but not inspected.
    """
    tle_usable = HAS_TLE and _tle_decode_enabled()
    if not tle_usable:
        return False
    if q.device.type != "cuda":
        return False

    _, _, num_heads_q, d_qk = q.shape
    if head_dim_v != 512 or d_qk not in (512, 576):
        return False
    return num_heads_q % TLE_DECODE_BH == 0 and page_block_size == TLE_DECODE_BK

1385 

1386 

def _set_triton_descriptor_allocator(device: torch.device) -> None:
    """Register a process-wide Triton allocator backed by torch tensors.

    Each allocation request becomes a fresh ``int8`` tensor of ``size`` bytes
    on ``device``. The requested alignment and stream are ignored.
    """

    def _int8_scratch(size: int, align: int, stream):
        del align, stream  # not needed: torch.empty handles placement
        return torch.empty(size, dtype=torch.int8, device=device)

    triton.set_allocator(_int8_scratch)

1394 

1395 

def _tle_sparse_decode_launch(
    q,
    kv_nope,
    kv_scales,
    kv_rope,
    indices,
    out,
    lse,
    attn_sink,
    topk_length,
    batch_size,
    seq_q,
    num_heads_q,
    head_dim_k,
    head_dim_v,
    topk,
    skv,
    softmax_scale,
    is_fp8_kvcache,
    stride_kvn,
    stride_scales_n,
    stride_rope_n,
):
    """Launch the TLE warp-specialized sparse decode kernel.

    Queries and outputs are flattened to ``[batch_size * seq_q * num_heads_q,
    dim]`` rows and wrapped in tensor descriptors. ``_tle_sparse_decode_fwd``
    then runs with ``batch_size * seq_q * (num_heads_q // TLE_DECODE_BH)``
    programs.

    Results land in ``out`` (``[batch, seq_q, heads, head_dim_v]``) and ``lse``
    (``[batch, heads, seq_q]``, the layout ``flash_mla_with_kvcache``
    allocates). Both are written in place, and non-contiguous ``out`` is
    supported via a staging buffer.
    """
    from triton.tools.tensor_descriptor import TensorDescriptor

    _set_triton_descriptor_allocator(q.device)

    BH = TLE_DECODE_BH
    BK = TLE_DECODE_BK
    D = head_dim_v  # 512
    TD = head_dim_k - D  # 64 for DQK=576, 0 for DQK=512
    DP = triton.next_power_of_2(D)
    DPH = DP // 2
    HAVE_TAIL = TD > 0
    TDP = triton.next_power_of_2(TD) if HAVE_TAIL else 1
    G = num_heads_q
    RH = G // BH

    num_rows = batch_size * seq_q  # one row per (batch, query position)
    num_q_rows = num_rows * num_heads_q  # one row per (batch, position, head)

    # Reshape q for TensorDescriptor: [batch*seq_q*HQ, DQK]
    q_flat = q.reshape(num_q_rows, head_dim_k).contiguous()

    # The output descriptor assumes a dense [num_q_rows, D] layout. Reshaping
    # a non-contiguous `out` would silently return a copy, and the kernel's
    # writes would never reach `out`, so stage through scratch in that case.
    out_is_dense = out.is_contiguous()
    if out_is_dense:
        out_flat = out.view(num_q_rows, head_dim_v)
    else:
        out_flat = torch.empty(
            (num_q_rows, head_dim_v), dtype=out.dtype, device=out.device
        )

    # `lse` is [batch, heads, seq_q], but the kernel gets no LSE strides and
    # is handed a [batch*seq_q, heads] buffer. Reshaping `lse` directly would
    # scramble values whenever seq_q > 1, so write to a dense scratch buffer
    # and transpose back afterwards (lse is small).
    # NOTE(review): assumes the kernel writes row-major [num_rows, G]; confirm
    # against `_tle_sparse_decode_fwd`.
    lse_flat = torch.empty((num_rows, num_heads_q), dtype=lse.dtype, device=lse.device)

    # Indices are flattened to a dense [batch*seq_q, topk] buffer. Both strides
    # passed to the kernel describe that buffer, not the caller's tensor (the
    # two agree when `indices` is already contiguous).
    indices_flat = indices.reshape(num_rows, topk).contiguous()
    stride_isq = indices_flat.stride(0)  # == topk
    stride_ib = seq_q * stride_isq

    q_desc = TensorDescriptor(
        q_flat,
        shape=[num_q_rows, head_dim_k],
        strides=[head_dim_k, 1],
        block_shape=[BH, DPH],
    )
    if HAVE_TAIL:
        tq_desc = TensorDescriptor(
            q_flat,
            shape=[num_q_rows, head_dim_k],
            strides=[head_dim_k, 1],
            block_shape=[BH, TDP],
        )
    else:
        tq_desc = q_desc
    output_desc = TensorDescriptor(
        out_flat,
        shape=[num_q_rows, D],
        strides=[D, 1],
        block_shape=[BH, DPH],
    )

    # Grid: one program per (batch*seq_q, head_block)
    grid = (num_rows * RH,)

    _tle_sparse_decode_fwd[grid](
        q_desc,
        tq_desc,
        output_desc,
        kv_nope,
        kv_scales,
        kv_rope,
        indices_flat,
        attn_sink,
        topk_length,
        softmax_scale,
        out_flat,
        lse_flat,
        num_rows,
        num_heads_q,
        head_dim_k,
        skv,
        topk,
        attn_sink is not None,
        topk_length is not None,
        is_fp8_kvcache,
        D,
        TD,
        DP,
        TDP,
        G,
        RH,
        HAVE_TAIL,
        BK,
        BH,
        TLE_DECODE_PAIR_BLOCKS,
        stride_kvn,
        stride_scales_n,
        stride_rope_n,
        stride_ib,
        stride_isq,
        num_warps=TLE_DECODE_WORKER_NUM_WARPS,
        num_stages=1,
    )

    if not out_is_dense:
        out.copy_(out_flat.view(batch_size, seq_q, num_heads_q, head_dim_v))
    lse.copy_(lse_flat.view(batch_size, seq_q, num_heads_q).transpose(1, 2))

1512 

1513 

1514if HAS_TLE: 

1515 

@triton.jit
def _tle_producer_load_kv_tile(
    kv_nope_base,
    kv_scales_base,
    kv_ids,
    valid,
    k_cols,
    SCALE_COL: tl.constexpr,
    D: tl.constexpr,
    IS_FP8: tl.constexpr,
    stride_kvn,
    stride_scales_n,
):
    """Load one [BK, 64] NoPE column tile for the tokens in ``kv_ids``.

    Rows whose ``valid`` flag is False, and columns ``>= D``, are loaded as 0.
    In FP8 mode the tile is dequantized with the per-token scale at slot
    ``SCALE_COL`` and returned as BF16. Every 128 NoPE columns share one
    scale, so a 64-column tile aligned to 64 always uses a single slot.
    """
    tile_msk = valid[:, None] & (k_cols < D)[None, :]
    raw = tl.load(
        kv_nope_base + k_cols[None, :] + kv_ids[:, None] * stride_kvn,
        mask=tile_msk,
        other=0.0,
        eviction_policy="evict_last",
    )
    if IS_FP8:
        scale = tl.load(
            kv_scales_base + kv_ids * stride_scales_n + SCALE_COL,
            mask=valid,
            other=1.0,
        )
        blk = (raw.to(tl.float32) * scale[:, None]).to(tl.bfloat16)
    else:
        blk = raw
    return blk


@triton.jit
def _tle_producer_load_kv_tail(
    kv_nope_base,
    kv_rope_base,
    kv_ids,
    valid,
    D: tl.constexpr,
    TD: tl.constexpr,
    TDP: tl.constexpr,
    IS_FP8: tl.constexpr,
    stride_kvn,
    stride_rope_n,
):
    """Load the [BK, TDP] RoPE tail for ``kv_ids`` (0 for invalid rows / cols >= TD).

    In FP8 mode the tail comes from ``kv_rope_base``. Otherwise it is read
    from columns ``[D, D + TD)`` of the NoPE rows.
    """
    offs_td = tl.arange(0, TDP)
    if IS_FP8:
        tail_ptr = kv_rope_base + offs_td[None, :] + kv_ids[:, None] * stride_rope_n
    else:
        tail_ptr = kv_nope_base + D + offs_td[None, :] + kv_ids[:, None] * stride_kvn
    tail_msk = valid[:, None] & (offs_td < TD)[None, :]
    return tl.load(tail_ptr, mask=tail_msk, other=0.0, eviction_policy="evict_last")


@triton.jit
def _tle_sparse_decode_producer(
    k0_l_writer,
    k0_r_writer,
    k1_l_writer,
    k1_r_writer,
    valid_writer,
    kv_nope_base,
    kv_scales_base,
    kv_rope_base,
    t_base,
    topk_len_ptr,
    D: tl.constexpr,
    TD: tl.constexpr,
    DPH: tl.constexpr,
    TDP: tl.constexpr,
    SKV,
    TOPK: tl.constexpr,
    HAVE_TOPK_LENGTH: tl.constexpr,
    HAVE_TAIL: tl.constexpr,
    IS_FP8: tl.constexpr,
    BK: tl.constexpr,
    stride_kvn,
    stride_scales_n,
    stride_rope_n,
):
    """
    Producer warpgroup: loads KV data from global memory to shared memory.

    It walks the top-k list in pairs of BK-token blocks. For each pair it
    gathers token ids from ``t_base`` and marks ids past ``topk_len`` or
    outside ``[0, SKV)`` invalid. It then fills the pipeline slots in this
    order:

    1. ``k0_l`` - block 0, left NoPE half.
    2. ``k1_r`` - block 1, right half plus RoPE tail.
    3. ``k0_r`` - block 0, right half plus RoPE tail.
    4. ``k1_l`` - block 1, left half.
    5. The per-token validity flags of both blocks.

    For FP8 mode: loads FP8 NoPE + scales + RoPE, dequantizes FP8 to BF16.
    For BF16 mode: loads BF16 KV directly.

    Tiles are stored UNMASKED. Invalid tokens are loaded as zeros and written
    as zeros, so no slot row is left with uninitialized or stale shared
    memory. The consumers zero the probability of invalid tokens, but a zero
    probability multiplied by a NaN/Inf left in an unwritten row would still
    poison the P@V accumulation.
    """
    topk_len = tl.load(topk_len_ptr) if HAVE_TOPK_LENGTH else TOPK
    max_col = SKV - 1
    NK = tl.cdiv(topk_len, BK)
    NPAIRS = tl.cdiv(NK, 2)
    offs_t = tl.arange(0, BK)
    offs_tile = tl.arange(0, 64)
    kv_tile_rows = tl.broadcast_to(offs_t[:, None], (BK, 64))

    for pair in tl.range(NPAIRS):
        ck0 = pair * 2
        ck1 = ck0 + 1

        # Load indices for both blocks; out-of-range slots read -1 (invalid).
        t_offs0 = BK * ck0 + offs_t
        t_msk0 = t_offs0 < topk_len
        kv_ids0 = tl.load(t_base + t_offs0, t_msk0, other=-1)
        valid0 = t_msk0 & (kv_ids0 <= max_col) & (kv_ids0 >= 0)

        t_offs1 = BK * ck1 + offs_t
        t_msk1 = t_offs1 < topk_len
        kv_ids1 = tl.load(t_base + t_offs1, t_msk1, other=-1)
        valid1 = t_msk1 & (kv_ids1 <= max_col) & (kv_ids1 >= 0)

        # Block 0, left half: NoPE columns [0, DPH).
        k0_l_slot = k0_l_writer.acquire(pair)
        for tile in tl.static_range(0, DPH, 64):
            k_cols = tile + offs_tile
            k0_l_blk = _tle_producer_load_kv_tile(
                kv_nope_base,
                kv_scales_base,
                kv_ids0,
                valid0,
                k_cols,
                tile // 128,
                D,
                IS_FP8,
                stride_kvn,
                stride_scales_n,
            )
            k_cols_b = tl.broadcast_to(k_cols[None, :], (BK, 64))
            tl.store(tle.gpu.local_ptr(k0_l_slot.sK, (kv_tile_rows, k_cols_b)), k0_l_blk)
        k0_l_writer.commit(pair)

        # Block 1, right half: NoPE columns [DPH, 2 * DPH) plus the RoPE tail.
        k1_r_slot = k1_r_writer.acquire(pair)
        for tile in tl.static_range(0, DPH, 64):
            k_cols = DPH + tile + offs_tile
            k1_r_blk = _tle_producer_load_kv_tile(
                kv_nope_base,
                kv_scales_base,
                kv_ids1,
                valid1,
                k_cols,
                (DPH + tile) // 128,
                D,
                IS_FP8,
                stride_kvn,
                stride_scales_n,
            )
            k_cols_b = tl.broadcast_to(k_cols[None, :], (BK, 64))
            tl.store(tle.gpu.local_ptr(k1_r_slot.sK, (kv_tile_rows, k_cols_b)), k1_r_blk)
        if HAVE_TAIL:
            k1_r_tail_blk = _tle_producer_load_kv_tail(
                kv_nope_base,
                kv_rope_base,
                kv_ids1,
                valid1,
                D,
                TD,
                TDP,
                IS_FP8,
                stride_kvn,
                stride_rope_n,
            )
            tl.store(tle.gpu.local_ptr(k1_r_slot.sK_tail), k1_r_tail_blk)
        k1_r_writer.commit(pair)

        # Block 0, right half plus the RoPE tail.
        k0_r_slot = k0_r_writer.acquire(pair)
        for tile in tl.static_range(0, DPH, 64):
            k_cols = DPH + tile + offs_tile
            k0_r_blk = _tle_producer_load_kv_tile(
                kv_nope_base,
                kv_scales_base,
                kv_ids0,
                valid0,
                k_cols,
                (DPH + tile) // 128,
                D,
                IS_FP8,
                stride_kvn,
                stride_scales_n,
            )
            k_cols_b = tl.broadcast_to(k_cols[None, :], (BK, 64))
            tl.store(tle.gpu.local_ptr(k0_r_slot.sK, (kv_tile_rows, k_cols_b)), k0_r_blk)
        if HAVE_TAIL:
            k0_r_tail_blk = _tle_producer_load_kv_tail(
                kv_nope_base,
                kv_rope_base,
                kv_ids0,
                valid0,
                D,
                TD,
                TDP,
                IS_FP8,
                stride_kvn,
                stride_rope_n,
            )
            tl.store(tle.gpu.local_ptr(k0_r_slot.sK_tail), k0_r_tail_blk)
        k0_r_writer.commit(pair)

        # Block 1, left half.
        k1_l_slot = k1_l_writer.acquire(pair)
        for tile in tl.static_range(0, DPH, 64):
            k_cols = tile + offs_tile
            k1_l_blk = _tle_producer_load_kv_tile(
                kv_nope_base,
                kv_scales_base,
                kv_ids1,
                valid1,
                k_cols,
                tile // 128,
                D,
                IS_FP8,
                stride_kvn,
                stride_scales_n,
            )
            k_cols_b = tl.broadcast_to(k_cols[None, :], (BK, 64))
            tl.store(tle.gpu.local_ptr(k1_l_slot.sK, (kv_tile_rows, k_cols_b)), k1_l_blk)
        k1_l_writer.commit(pair)

        # Store validity masks: row 0 -> block 0, row 1 -> block 1.
        valid_slot = valid_writer.acquire(pair)
        valid_row0 = tl.full([BK], 0, dtype=tl.int32)
        valid_row1 = tl.full([BK], 1, dtype=tl.int32)
        valid_ptr0 = tle.gpu.local_ptr(valid_slot.is_kv_valid, (valid_row0, offs_t))
        valid_ptr1 = tle.gpu.local_ptr(valid_slot.is_kv_valid, (valid_row1, offs_t))
        tl.store(valid_ptr0, valid0.to(tl.int8))
        tl.store(valid_ptr1, valid1.to(tl.int8))
        valid_writer.commit(pair)

1806 

@triton.jit
def _tle_sparse_decode_consumer0(
    q_writer,
    q_reader,
    q_desc,
    tq_desc,
    k0_l_reader,
    k0_r_qk_reader,
    k1_l_remote_reader,
    valid_reader,
    sM_wg0_writer,
    sM_wg1_reader,
    sS0_writer,
    sS1_reader,
    sL_wg0_writer,
    sL_wg1_reader,
    output_desc,
    output_row,
    h_base,
    topk_len_ptr,
    attn_sink_base,
    log_scale: tl.constexpr,
    D: tl.constexpr,
    TD: tl.constexpr,
    OUT_DTYPE: tl.constexpr,
    HAVE_ATTN_SINK: tl.constexpr,
    TOPK: tl.constexpr,
    HAVE_TOPK_LENGTH: tl.constexpr,
    HAVE_TAIL: tl.constexpr,
    BK: tl.constexpr,
    BH: tl.constexpr,
    DPH: tl.constexpr,
    TDP: tl.constexpr,
    G: tl.constexpr,
):
    """Consumer 0: computes QK^T + online softmax + P@V_left.

    Consumer 0 loads the BH-head Q tile (left/right halves plus the optional
    tail) once through ``q_desc``/``tq_desc``. For every pair of BK-token
    blocks it then:

    * scores block 0 against the full Q row (left + right NoPE halves and the
      RoPE tail), masking invalid tokens to ``-inf``;
    * publishes its running row max on ``sM_wg0`` and reads the merged max
      back from ``sM_wg1``;
    * shares its rescaled block-0 probabilities on ``sS0`` and receives
      block-1 probabilities on ``sS1``;
    * accumulates both blocks against the LEFT half of V (columns
      ``[0, DPH)``). The right half is consumer 1's job.

    After the loop it adds its partial softmax sum to consumer 1's (exchanged
    via ``sL_wg0``/``sL_wg1``) and normalizes. It optionally applies the
    attention-sink scaling, zeroes rows that saw no valid token, and writes
    the left output half to ``output_desc`` at ``output_row``.

    ``log_scale`` is applied before ``exp2``/``log2``, so it is presumably
    ``softmax_scale * log2(e)``; confirm at the launch site.
    """
    topk_len = tl.load(topk_len_ptr) if HAVE_TOPK_LENGTH else TOPK
    offs_h = tl.arange(0, BH)
    offs_dh = tl.arange(0, DPH)
    mask_h = h_base + offs_h < G
    mask_od_l = offs_dh < D
    # Row/column index grids for reading [BK, DPH] halves out of the sK slots.
    kv_rows = tl.broadcast_to(tl.arange(0, BK)[:, None], (BK, DPH))
    kv_cols_l = tl.broadcast_to(offs_dh[None, :], (BK, DPH))
    kv_cols_r = tl.broadcast_to((DPH + offs_dh)[None, :], (BK, DPH))

    # Load Q into shared memory (one-shot)
    q_write_slot = q_writer.acquire(0)
    tle.gpu.copy(q_desc, q_write_slot.sQ_l, [BH, DPH], [output_row, 0])
    tle.gpu.copy(q_desc, q_write_slot.sQ_r, [BH, DPH], [output_row, DPH])
    if HAVE_TAIL:
        tle.gpu.copy(tq_desc, q_write_slot.sQ_tail, [BH, TDP], [output_row, D])
    q_writer.commit(0)

    q_slot = q_reader.wait(0).slot
    q_l_smem_ptr = tle.gpu.local_ptr(q_slot.sQ_l)
    q_r_smem_ptr = tle.gpu.local_ptr(q_slot.sQ_r)

    # Finite "minus infinity" for the running max: if a whole block is masked
    # (all -inf scores), (max_prev - local_max) stays finite instead of
    # producing (-inf) - (-inf) = NaN.
    max_prev = tl.full([BH], -1.0e30, dtype=tl.float32)
    sum_exp = tl.full([BH], 0.0, dtype=tl.float32)
    acc_l = tl.zeros([BH, DPH], dtype=tl.float32)

    NK = tl.cdiv(topk_len, BK)
    NPAIRS = tl.cdiv(NK, 2)
    for pair in tl.range(NPAIRS):
        # Wait for k0_l data
        k0_l_wait = k0_l_reader.wait(pair)
        k0_l_slot = k0_l_wait.slot

        q_l_blk = tl.load(q_l_smem_ptr)
        q_r_blk = tl.load(q_r_smem_ptr)
        k0_l_blk = tl.load(tle.gpu.local_ptr(k0_l_slot.sK, (kv_rows, kv_cols_l)))

        # QK for block 0: q_l @ k0_l^T + q_r @ k0_r^T + q_tail @ k0_tail^T
        qk0 = tl.full([BH, BK], 0.0, dtype=tl.float32)
        qk0 = tl.dot(q_l_blk, tl.trans(k0_l_blk), qk0, out_dtype=tl.float32)

        # Wait for k0_r
        k0_r_wait = k0_r_qk_reader.wait(pair)
        k0_r_slot = k0_r_wait.slot
        k0_r_blk = tl.load(tle.gpu.local_ptr(k0_r_slot.sK, (kv_rows, kv_cols_r)))
        qk0 = tl.dot(q_r_blk, tl.trans(k0_r_blk), qk0, out_dtype=tl.float32)

        if HAVE_TAIL:
            q_tail_blk = tl.load(tle.gpu.local_ptr(q_slot.sQ_tail))
            k0_t_blk = tl.load(tle.gpu.local_ptr(k0_r_slot.sK_tail))
            qk0 = tl.dot(q_tail_blk, tl.trans(k0_t_blk), qk0, out_dtype=tl.float32)

        # Get validity mask for block 0 (row 0 of the producer's flag buffer)
        valid_wait = valid_reader.wait(pair)
        row0 = tl.full([BK], 0, dtype=tl.int32)
        valid0 = (
            tl.load(
                tle.gpu.local_ptr(
                    valid_wait.slot.is_kv_valid, (row0, tl.arange(0, BK))
                )
            ).to(tl.int32)
            == 1
        )

        # Invalid tokens get -inf scores, i.e. zero probability below.
        qk0 = tl.where(valid0[None, :], qk0, float("-inf"))

        # Compute local softmax for block 0 only
        local_max = tl.maximum(max_prev, tl.max(qk0, axis=1))
        alpha = tl.math.exp2((max_prev - local_max) * log_scale)
        prob0 = tl.math.exp2(qk0 * log_scale - local_max[:, None] * log_scale)
        sum_exp = sum_exp * alpha + tl.sum(prob0, axis=1)
        acc_l = acc_l * alpha[:, None]
        prob0_b = prob0.to(OUT_DTYPE)

        # Send local_max to consumer1
        sM_wg0_slot = sM_wg0_writer.acquire(pair)
        tl.store(tle.gpu.local_ptr(sM_wg0_slot.sM), local_max)
        sM_wg0_writer.commit(pair)

        # Accumulate P@V_left with prob0 (K's left half doubles as V's left half).
        # The k0_l tile is re-read from its slot rather than reusing the
        # earlier load.
        k0_l_blk = tl.load(tle.gpu.local_ptr(k0_l_slot.sK, (kv_rows, kv_cols_l)))
        acc_l = tl.dot(prob0_b, k0_l_blk, acc_l, out_dtype=tl.float32)
        k0_l_reader.release(pair)
        k0_r_qk_reader.release(pair)

        # Wait for max_next from consumer1 (merged max of block0 and block1)
        sM_wg1_wait = sM_wg1_reader.wait(pair)
        max_next = tl.load(tle.gpu.local_ptr(sM_wg1_wait.slot.sM))
        sM_wg1_reader.release(pair)

        # Rescale prob0 and acc_l using the global max
        final_scale = tl.math.exp2((local_max - max_next) * log_scale)
        sum_exp = sum_exp * final_scale
        acc_l = acc_l * final_scale[:, None]

        # Send rescaled prob0 to consumer1
        prob0_scaled = prob0 * final_scale[:, None]
        sS0_slot = sS0_writer.acquire(pair)
        tl.store(tle.gpu.local_ptr(sS0_slot.sS0), prob0_scaled.to(OUT_DTYPE))
        sS0_writer.commit(pair)

        # Receive prob1 from consumer1 and accumulate k1_l
        sS1_wait = sS1_reader.wait(pair)
        prob1 = tl.load(tle.gpu.local_ptr(sS1_wait.slot.sS1))
        k1_l_wait = k1_l_remote_reader.wait(pair)
        k1_l_blk = tl.load(
            tle.gpu.local_ptr(k1_l_wait.slot.sK, (kv_rows, kv_cols_l))
        )
        acc_l = tl.dot(prob1, k1_l_blk, acc_l, out_dtype=tl.float32)
        sS1_reader.release(pair)
        k1_l_remote_reader.release(pair)

        valid_reader.release(pair)

        max_prev = max_next

    # Exchange final sum_exp with consumer1: sum_exp here only covers the
    # block-0 probabilities; peer_sum is consumer1's partial.
    # NOTE(review): this writes sL_wg0 at index 0 but waits on sL_wg1 at
    # index 1; confirm consumer1 commits sL_wg1 at index 1.
    sL_wg0_slot = sL_wg0_writer.acquire(0)
    tl.store(tle.gpu.local_ptr(sL_wg0_slot.sL), sum_exp)
    sL_wg0_writer.commit(0)
    sL_wg1_wait = sL_wg1_reader.wait(1)
    peer_sum = tl.load(tle.gpu.local_ptr(sL_wg1_wait.slot.sL))
    total_sum = sum_exp + peer_sum
    sL_wg1_reader.release(1)

    # Rows that saw no valid token divide by zero here; they are zeroed below.
    is_no_valid_tokens = total_sum == 0.0
    inv_total_sum = tl.fdiv(1.0, total_sum)
    out_l_vals = acc_l * inv_total_sum[:, None]
    if HAVE_ATTN_SINK:
        # Natural-log LSE: (max * log_scale + log2(sum)) * ln(2).
        fin_log = (
            max_prev * log_scale + tl.math.log2(total_sum)
        ) * 0.6931471805599453
        sink = tl.load(attn_sink_base + h_base + offs_h, mask_h, other=0.0)
        # Equivalent to sum / (sum + exp(sink)): the sink adds mass to the
        # softmax denominator.
        sink_scale = tl.fdiv(1.0, 1.0 + tl.math.exp(sink - fin_log))
        out_l_vals = out_l_vals * sink_scale[:, None]
    out_l_vals = tl.where(is_no_valid_tokens[:, None], 0.0, out_l_vals)
    o_l_msk = mask_h[:, None] & mask_od_l[None, :]
    # Reuse the sQ_l buffer as staging for the left output half, then copy it
    # out through the output descriptor.
    tl.store(q_l_smem_ptr, out_l_vals.to(OUT_DTYPE), o_l_msk)
    tle.gpu.copy(q_slot.sQ_l, output_desc, [BH, DPH], [output_row, 0])

1981 

    @triton.jit
    def _tle_sparse_decode_consumer1(
        q_reader,
        k1_r_reader,
        k1_l_qk_reader,
        k0_r_remote_reader,
        valid_reader,
        sM_wg1_writer,
        sM_wg0_reader,
        sS1_writer,
        sS0_reader,
        sL_wg1_writer,
        sL_wg0_reader,
        output_desc,
        output_row,
        lse_base,
        h_base,
        topk_len_ptr,
        attn_sink_base,
        log_scale: tl.constexpr,
        D: tl.constexpr,
        TD: tl.constexpr,
        OUT_DTYPE: tl.constexpr,
        HAVE_ATTN_SINK: tl.constexpr,
        TOPK: tl.constexpr,
        HAVE_TOPK_LENGTH: tl.constexpr,
        HAVE_TAIL: tl.constexpr,
        BK: tl.constexpr,
        BH: tl.constexpr,
        DPH: tl.constexpr,
        TDP: tl.constexpr,
        G: tl.constexpr,
    ):
        """Consumer 1 of the sparse TLE decode: block 1 of each pair, right half.

        The selected KV tokens are processed in pairs of BK-wide blocks
        (``NPAIRS = cdiv(cdiv(topk_len, BK), 2)``). For every pair this
        partition:

        1. Computes ``qk1 = Q @ K1^T`` for block 1, using the right half
           (``sQ_r``), the left half (``sQ_l``) and, if ``HAVE_TAIL``, the
           tail (``sQ_tail``) of Q. Then it masks invalid tokens (row 1 of
           ``is_kv_valid``) to ``-inf``.
        2. Receives consumer0's running max for block 0 (``sM_wg0``) and merges
           it with its own to get ``max_next``. It sends ``max_next`` back
           through ``sM_wg1``.
        3. Runs the online-softmax update for block 1 and accumulates
           ``acc_r += P1 @ V1[:, DPH:]``.
        4. Sends ``P1`` to consumer0 (``sS1``). Receives consumer0's ``P0``
           (``sS0``), which consumer0 has already rescaled to ``max_next``, and
           accumulates ``acc_r += P0 @ V0[:, DPH:]``.

        The same ``sK`` smem tile serves as K (for QK^T) and as V (for P@V).

        After the loop, the per-partition ``sum_exp`` values are exchanged
        through ``sL`` (slot 1 = this partition, slot 0 = consumer0). The right
        half of the normalized output and the LSE row are then written.

        ``log_scale`` is the softmax scale multiplied by log2(e), as computed
        by ``_tle_sparse_decode_fwd``, so all exponentials are base-2.
        """
        # Number of KV tokens to attend to: per-row length when provided,
        # otherwise the static TOPK.
        topk_len = tl.load(topk_len_ptr) if HAVE_TOPK_LENGTH else TOPK
        offs_h = tl.arange(0, BH)
        offs_dh = tl.arange(0, DPH)
        mask_h = h_base + offs_h < G
        # This partition owns output columns [DPH, 2 * DPH).
        mask_od_r = DPH + offs_dh < D
        kv_rows = tl.broadcast_to(tl.arange(0, BK)[:, None], (BK, DPH))
        kv_cols_l = tl.broadcast_to(offs_dh[None, :], (BK, DPH))
        kv_cols_r = tl.broadcast_to((DPH + offs_dh)[None, :], (BK, DPH))

        # Q is published once by consumer0 (one-shot pipe, slot 0).
        q_slot = q_reader.wait(0).slot
        q_l_smem_ptr = tle.gpu.local_ptr(q_slot.sQ_l)
        q_r_smem_ptr = tle.gpu.local_ptr(q_slot.sQ_r)

        # Online-softmax state. max_prev starts at a large negative finite
        # value (not -inf) so that exp2((max_prev - max_next) * log_scale)
        # never evaluates -inf - (-inf).
        max_prev = tl.full([BH], -1.0e30, dtype=tl.float32)
        sum_exp = tl.full([BH], 0.0, dtype=tl.float32)
        acc_r = tl.zeros([BH, DPH], dtype=tl.float32)

        NK = tl.cdiv(topk_len, BK)
        NPAIRS = tl.cdiv(NK, 2)
        for pair in tl.range(NPAIRS):
            # Wait for k1_r data
            k1_r_wait = k1_r_reader.wait(pair)
            k1_r_slot = k1_r_wait.slot

            q_l_blk = tl.load(q_l_smem_ptr)
            q_r_blk = tl.load(q_r_smem_ptr)
            # Kept in registers: reused below as V for P1 @ V1[:, DPH:].
            k1_r_blk = tl.load(tle.gpu.local_ptr(k1_r_slot.sK, (kv_rows, kv_cols_r)))

            # QK for block 1
            qk1 = tl.full([BH, BK], 0.0, dtype=tl.float32)
            qk1 = tl.dot(q_r_blk, tl.trans(k1_r_blk), qk1, out_dtype=tl.float32)
            if HAVE_TAIL:
                q_tail_blk = tl.load(tle.gpu.local_ptr(q_slot.sQ_tail))
                k1_t_blk = tl.load(tle.gpu.local_ptr(k1_r_slot.sK_tail))
                qk1 = tl.dot(q_tail_blk, tl.trans(k1_t_blk), qk1, out_dtype=tl.float32)

            # Left half of block 1 (shared with consumer0, which uses it as V).
            k1_l_wait = k1_l_qk_reader.wait(pair)
            k1_l_slot = k1_l_wait.slot
            k1_l_blk = tl.load(tle.gpu.local_ptr(k1_l_slot.sK, (kv_rows, kv_cols_l)))
            qk1 = tl.dot(q_l_blk, tl.trans(k1_l_blk), qk1, out_dtype=tl.float32)

            # Get validity mask for block 1
            valid_wait = valid_reader.wait(pair)
            row1 = tl.full([BK], 1, dtype=tl.int32)
            valid1 = (
                tl.load(
                    tle.gpu.local_ptr(
                        valid_wait.slot.is_kv_valid, (row1, tl.arange(0, BK))
                    )
                ).to(tl.int32)
                == 1
            )

            qk1 = tl.where(valid1[None, :], qk1, float("-inf"))
            valid_reader.release(pair)

            # Receive candidate0 (local_max) from consumer0
            sM_wg0_wait = sM_wg0_reader.wait(pair)
            candidate0 = tl.load(tle.gpu.local_ptr(sM_wg0_wait.slot.sM))
            sM_wg0_reader.release(pair)

            # Compute candidate1 and merge to get global max_next
            candidate1 = tl.maximum(max_prev, tl.max(qk1, axis=1))
            max_next = tl.maximum(candidate1, candidate0)

            # Send max_next back to consumer0
            sM_wg1_slot = sM_wg1_writer.acquire(pair)
            tl.store(tle.gpu.local_ptr(sM_wg1_slot.sM), max_next)
            sM_wg1_writer.commit(pair)

            # Compute prob1 using global max_next
            alpha = tl.math.exp2((max_prev - max_next) * log_scale)
            prob1 = tl.math.exp2(qk1 * log_scale - max_next[:, None] * log_scale)
            sum_exp = sum_exp * alpha + tl.sum(prob1, axis=1)
            acc_r = acc_r * alpha[:, None]
            prob1_b = prob1.to(OUT_DTYPE)

            k1_l_qk_reader.release(pair)

            # Accumulate P@V_right with prob1
            acc_r = tl.dot(prob1_b, k1_r_blk, acc_r, out_dtype=tl.float32)

            # Send prob1 to consumer0
            sS1_slot = sS1_writer.acquire(pair)
            tl.store(tle.gpu.local_ptr(sS1_slot.sS1), prob1_b)
            sS1_writer.commit(pair)

            # Receive rescaled prob0 from consumer0 and accumulate k0_r
            sS0_wait = sS0_reader.wait(pair)
            prob0 = tl.load(tle.gpu.local_ptr(sS0_wait.slot.sS0))
            k0_r_wait = k0_r_remote_reader.wait(pair)
            k0_r_blk = tl.load(
                tle.gpu.local_ptr(k0_r_wait.slot.sK, (kv_rows, kv_cols_r))
            )
            acc_r = tl.dot(prob0, k0_r_blk, acc_r, out_dtype=tl.float32)
            # k1_r is released only here because k1_r_blk was used as V above.
            k1_r_reader.release(pair)
            sS0_reader.release(pair)
            k0_r_remote_reader.release(pair)

            max_prev = max_next

        # Exchange final sum_exp with consumer0
        sL_wg1_slot = sL_wg1_writer.acquire(1)
        tl.store(tle.gpu.local_ptr(sL_wg1_slot.sL), sum_exp)
        sL_wg1_writer.commit(1)
        # Consumer0 commits sL slot 0 only after its own loop, so once this
        # wait returns, consumer0 no longer reads Q. That is why sQ_r can be
        # reused below as the output staging buffer.
        sL_wg0_wait = sL_wg0_reader.wait(0)
        peer_sum = tl.load(tle.gpu.local_ptr(sL_wg0_wait.slot.sL))
        total_sum = sum_exp + peer_sum
        sL_wg0_reader.release(0)

        is_no_valid_tokens = total_sum == 0.0
        inv_total_sum = tl.fdiv(1.0, total_sum)
        out_r_vals = acc_r * inv_total_sum[:, None]
        if HAVE_ATTN_SINK:
            # Natural-log LSE (0.6931... = ln 2). The sink rescales the output
            # by exp(lse) / (exp(lse) + exp(sink)).
            fin_log = (
                max_prev * log_scale + tl.math.log2(total_sum)
            ) * 0.6931471805599453
            sink = tl.load(attn_sink_base + h_base + offs_h, mask_h, other=0.0)
            sink_scale = tl.fdiv(1.0, 1.0 + tl.math.exp(sink - fin_log))
            out_r_vals = out_r_vals * sink_scale[:, None]
        # Rows with no valid token (total_sum == 0) produce zeros instead of NaN.
        out_r_vals = tl.where(is_no_valid_tokens[:, None], 0.0, out_r_vals)
        o_r_msk = mask_h[:, None] & mask_od_r[None, :]
        # NOTE(review): the masked smem store leaves rows with h >= G holding Q
        # data, but the copy below writes the whole [BH, DPH] tile. This seems
        # to assume G is a multiple of BH; confirm with the launcher.
        tl.store(q_r_smem_ptr, out_r_vals.to(OUT_DTYPE), o_r_msk)
        tle.gpu.copy(q_slot.sQ_r, output_desc, [BH, DPH], [output_row, DPH])

        # Store LSE
        # Only this partition writes LSE. It is in natural log, and it is +inf
        # when there are no valid tokens.
        lse_val = (max_prev * log_scale + tl.math.log2(total_sum)) * 0.6931471805599453
        lse_val = tl.where(is_no_valid_tokens, float("inf"), lse_val)
        tl.store(lse_base + offs_h, lse_val, mask=mask_h)

2145 

    @triton.jit
    def _tle_sparse_decode_fwd(
        q_desc,
        tq_desc,
        output_desc,
        kv_nope,
        kv_scales,
        kv_rope,
        indices,
        attn_sink,
        topk_length,
        sm_scale: tl.constexpr,
        output,
        lse,
        BATCH_SQ,
        HQ: tl.constexpr,
        DQK: tl.constexpr,
        SKV,
        TOPK: tl.constexpr,
        HAVE_ATTN_SINK: tl.constexpr,
        HAVE_TOPK_LENGTH: tl.constexpr,
        IS_FP8: tl.constexpr,
        D: tl.constexpr,
        TD: tl.constexpr,
        DP: tl.constexpr,
        TDP: tl.constexpr,
        G: tl.constexpr,
        RH: tl.constexpr,
        HAVE_TAIL: tl.constexpr,
        BK: tl.constexpr,
        BH: tl.constexpr,
        PAIR_BLOCKS: tl.constexpr,
        stride_kvn,
        stride_scales_n,
        stride_rope_n,
        stride_ib,
        stride_isq,
    ):
        """Sparse MLA decode entry kernel: allocate smem, wire pipes, specialize.

        The program id is decomposed as ``pid = i_bsq * RH + i_rh``. Each
        program handles one flattened (batch, seq_q) row ``i_bsq`` and the
        head block ``[h_base, h_base + BH)`` with ``h_base = i_rh * BH``.

        The body allocates shared-memory buffers, builds the ``tle.pipe``
        channels between partitions, and passes them to
        ``tle.gpu.warp_specialize`` with three partitions:

        * ``_tle_sparse_decode_consumer0`` holds the Q writer and ``q_desc``
          (it loads Q). It uses the ``k0`` pipes for QK^T and produces the
          left output half ``[0, DPH)``.
        * ``_tle_sparse_decode_consumer1`` uses the ``k1`` pipes for QK^T and
          produces the right output half ``[DPH, 2 * DPH)`` plus LSE.
        * ``_tle_sparse_decode_producer`` holds the writer ends of the K and
          validity pipes, along with ``kv_nope``/``kv_scales``/``kv_rope`` and
          the per-row ``indices``.

        Each consumer also reads the other consumer's K half ("remote"
        readers, ``sK`` only) to form the full P@V.

        Notes:
            ``output``, ``BATCH_SQ`` and ``DQK`` are only assigned to ``_`` and
            not otherwise used. ``PAIR_BLOCKS`` and ``stride_ib`` are not
            referenced in this body.
        """
        DPH: tl.constexpr = DP // 2
        # LSE is laid out with HQ entries per (batch, seq_q) row.
        stride_lm = HQ

        pid = tl.program_id(0)
        programs_per_bsq: tl.constexpr = RH
        i_bsq = pid // programs_per_bsq
        i_rh = pid % programs_per_bsq
        h_base = i_rh * BH
        # 64-bit row index for pointer arithmetic on large tensors.
        i_bsq64 = i_bsq.to(tl.int64)

        kv_nope_base = kv_nope
        kv_scales_base = kv_scales
        kv_rope_base = kv_rope
        t_base = indices + i_bsq64 * stride_isq
        # When a feature is disabled, any valid pointer is passed as a
        # placeholder; the partitions guard its use with the constexpr flag.
        topk_len_ptr = topk_length + i_bsq64 if HAVE_TOPK_LENGTH else indices
        attn_sink_base = attn_sink if HAVE_ATTN_SINK else lse
        l_base = lse + i_bsq64 * stride_lm + h_base
        # Row of this head block in the flattened [rows, D*] Q/output views.
        q_row = i_bsq * HQ + h_base
        _ = output
        _ = BATCH_SQ
        _ = DQK

        # Q staging buffers. They are later reused by the consumers as output
        # staging buffers.
        sQ_l_smem = tle.gpu.alloc(
            [1, BH, DPH], dtype=tl.bfloat16, layout=None, scope=tle.gpu.smem
        )
        sQ_r_smem = tle.gpu.alloc(
            [1, BH, DPH], dtype=tl.bfloat16, layout=None, scope=tle.gpu.smem
        )
        if HAVE_TAIL:
            sQ_tail_smem = tle.gpu.alloc(
                [1, BH, TDP], dtype=tl.bfloat16, layout=None, scope=tle.gpu.smem
            )
            q_pipe = tle.pipe(
                capacity=1,
                scope="cta",
                name="decode_sQ",
                readers=("wg0", "wg1"),
                one_shot=True,
                sQ_l=sQ_l_smem,
                sQ_r=sQ_r_smem,
                sQ_tail=sQ_tail_smem,
            )
        else:
            q_pipe = tle.pipe(
                capacity=1,
                scope="cta",
                name="decode_sQ",
                readers=("wg0", "wg1"),
                one_shot=True,
                sQ_l=sQ_l_smem,
                sQ_r=sQ_r_smem,
            )

        # One full-width [BK, DP] K tile per block of the pair. The left and
        # right pipes below are views of the same buffer (sK0 / sK1).
        sK0_smem = tle.gpu.alloc(
            [1, BK, DP], dtype=tl.bfloat16, layout=None, scope=tle.gpu.smem
        )
        sK1_smem = tle.gpu.alloc(
            [1, BK, DP], dtype=tl.bfloat16, layout=None, scope=tle.gpu.smem
        )
        if HAVE_TAIL:
            sK0_tail_smem = tle.gpu.alloc(
                [1, BK, TDP], dtype=tl.bfloat16, layout=None, scope=tle.gpu.smem
            )
            sK1_tail_smem = tle.gpu.alloc(
                [1, BK, TDP], dtype=tl.bfloat16, layout=None, scope=tle.gpu.smem
            )
            # The probability exchange buffers alias the K tail buffers to
            # save smem. NOTE(review): this needs [BK, TDP] to cover [BH, BK]
            # (true for BH == BK == TDP == 64), and it relies on the pipe
            # ordering to keep the tail and probability uses from overlapping.
            # Confirm before changing block sizes.
            sS0_smem = sK0_tail_smem
            sS1_smem = sK1_tail_smem
        else:
            sS0_smem = tle.gpu.alloc(
                [1, BH, BK], dtype=tl.bfloat16, layout=None, scope=tle.gpu.smem
            )
            sS1_smem = tle.gpu.alloc(
                [1, BH, BK], dtype=tl.bfloat16, layout=None, scope=tle.gpu.smem
            )
        # Row 0: validity of block 0's tokens; row 1: block 1's tokens.
        is_kv_valid_smem = tle.gpu.alloc(
            [1, 2, BK], dtype=tl.int8, layout=None, scope=tle.gpu.smem
        )
        # A single max buffer shared by both sM pipes. Consumer1 reads
        # consumer0's value before writing max_next into the same memory.
        sM_smem = tle.gpu.alloc(
            [1, BH], dtype=tl.float32, layout=None, scope=tle.gpu.smem
        )
        # Final sum_exp exchange. Consumer0 uses slot 0 and consumer1 uses
        # slot 1.
        sL_smem = tle.gpu.alloc(
            [2, BH], dtype=tl.float32, layout=None, scope=tle.gpu.smem
        )

        # Pipe definitions
        if HAVE_TAIL:
            k0_l_pipe = tle.pipe(
                capacity=1, scope="cta", name="decode_k0_l", sK=sK0_smem
            )
            k0_r_pipe = tle.pipe(
                capacity=1,
                scope="cta",
                name="decode_k0_r",
                readers=("qk", "remote"),
                sK=sK0_smem,
                sK_tail=sK0_tail_smem,
            )
            k1_l_pipe = tle.pipe(
                capacity=1,
                scope="cta",
                name="decode_k1_l",
                readers=("qk", "remote"),
                sK=sK1_smem,
            )
            k1_r_pipe = tle.pipe(
                capacity=1,
                scope="cta",
                name="decode_k1_r",
                sK=sK1_smem,
                sK_tail=sK1_tail_smem,
            )
        else:
            k0_l_pipe = tle.pipe(
                capacity=1, scope="cta", name="decode_k0_l", sK=sK0_smem
            )
            k0_r_pipe = tle.pipe(
                capacity=1,
                scope="cta",
                name="decode_k0_r",
                readers=("qk", "remote"),
                sK=sK0_smem,
            )
            k1_l_pipe = tle.pipe(
                capacity=1,
                scope="cta",
                name="decode_k1_l",
                readers=("qk", "remote"),
                sK=sK1_smem,
            )
            k1_r_pipe = tle.pipe(
                capacity=1, scope="cta", name="decode_k1_r", sK=sK1_smem
            )

        is_kv_valid_pipe = tle.pipe(
            capacity=1,
            scope="cta",
            name="decode_valid",
            readers=("wg0", "wg1"),
            is_kv_valid=is_kv_valid_smem,
        )
        sM_wg0_pipe = tle.pipe(
            capacity=1, scope="cta", name="decode_wg0_max", sM=sM_smem
        )
        sM_wg1_pipe = tle.pipe(
            capacity=1, scope="cta", name="decode_wg1_max", sM=sM_smem
        )
        sS0_pipe = tle.pipe(capacity=1, scope="cta", name="decode_sS0", sS0=sS0_smem)
        sS1_pipe = tle.pipe(capacity=1, scope="cta", name="decode_sS1", sS1=sS1_smem)
        sL_wg0_pipe = tle.pipe(
            capacity=2, scope="cta", name="decode_sL_wg0", sL=sL_smem
        )
        sL_wg1_pipe = tle.pipe(
            capacity=2, scope="cta", name="decode_sL_wg1", sL=sL_smem
        )

        # sm_scale * log2(e): the partitions use exp2/log2 throughout.
        log_scale: tl.constexpr = sm_scale * 1.4426950408889634

        tle.gpu.warp_specialize(
            [
                (
                    _tle_sparse_decode_consumer0,
                    (
                        q_pipe.writer(),
                        q_pipe.reader("wg0"),
                        q_desc,
                        tq_desc,
                        k0_l_pipe.reader(),
                        k0_r_pipe.reader("qk"),
                        k1_l_pipe.reader("remote", fields=("sK",)),
                        is_kv_valid_pipe.reader("wg0"),
                        sM_wg0_pipe.writer(),
                        sM_wg1_pipe.reader(),
                        sS0_pipe.writer(),
                        sS1_pipe.reader(),
                        sL_wg0_pipe.writer(),
                        sL_wg1_pipe.reader(),
                        output_desc,
                        q_row,
                        h_base,
                        topk_len_ptr,
                        attn_sink_base,
                        log_scale,
                        D,
                        TD,
                        tl.bfloat16,
                        HAVE_ATTN_SINK,
                        TOPK,
                        HAVE_TOPK_LENGTH,
                        HAVE_TAIL,
                        BK,
                        BH,
                        DPH,
                        TDP,
                        G,
                    ),
                ),
                (
                    _tle_sparse_decode_consumer1,
                    (
                        q_pipe.reader("wg1"),
                        k1_r_pipe.reader(),
                        k1_l_pipe.reader("qk"),
                        k0_r_pipe.reader("remote", fields=("sK",)),
                        is_kv_valid_pipe.reader("wg1"),
                        sM_wg1_pipe.writer(),
                        sM_wg0_pipe.reader(),
                        sS1_pipe.writer(),
                        sS0_pipe.reader(),
                        sL_wg1_pipe.writer(),
                        sL_wg0_pipe.reader(),
                        output_desc,
                        q_row,
                        l_base,
                        h_base,
                        topk_len_ptr,
                        attn_sink_base,
                        log_scale,
                        D,
                        TD,
                        tl.bfloat16,
                        HAVE_ATTN_SINK,
                        TOPK,
                        HAVE_TOPK_LENGTH,
                        HAVE_TAIL,
                        BK,
                        BH,
                        DPH,
                        TDP,
                        G,
                    ),
                ),
                (
                    _tle_sparse_decode_producer,
                    (
                        k0_l_pipe.writer(),
                        k0_r_pipe.writer(),
                        k1_l_pipe.writer(),
                        k1_r_pipe.writer(),
                        is_kv_valid_pipe.writer(),
                        kv_nope_base,
                        kv_scales_base,
                        kv_rope_base,
                        t_base,
                        topk_len_ptr,
                        D,
                        TD,
                        DPH,
                        TDP,
                        SKV,
                        TOPK,
                        HAVE_TOPK_LENGTH,
                        HAVE_TAIL,
                        IS_FP8,
                        BK,
                        stride_kvn,
                        stride_scales_n,
                        stride_rope_n,
                    ),
                ),
            ],
            # NOTE(review): presumably the warp counts and register budgets for
            # the partitions after the first one. Confirm against
            # tle.gpu.warp_specialize.
            [4, 4],
            [216, 72],
        )

2448 

2449 

2450# ============================================================================ 

2451# TLE Warp Specialization path for dense decode 

2452# ============================================================================ 

2453 

2454 

def _tle_dense_decode_launch(
    q,
    kv_flat,
    block_table,
    cache_seqlens,
    out,
    lse,
    batch_size,
    seq_q,
    num_heads_q,
    head_dim_k,
    head_dim_v,
    page_block_size,
    softmax_scale,
    causal,
):
    """Launch the TLE warp-specialized dense (paged) MLA decode kernel.

    One program is launched per (flattened ``batch * seq_q`` row, block of
    ``TLE_DECODE_BH`` query heads). Results are written in place into ``out``
    and ``lse``.

    Args:
        q: Query tensor, flattened to
            ``[batch_size * seq_q * num_heads_q, head_dim_k]`` (made
            contiguous if needed).
        kv_flat: Paged KV cache, passed to the kernel as a base pointer with
            row stride ``kv_flat.stride(0)``.
        block_table: Page table, reshaped to ``[batch_size * seq_q, -1]``.
        cache_seqlens: KV lengths, reshaped to ``[batch_size * seq_q]``.
        out: Output tensor written in place, viewed as
            ``[batch_size * seq_q * num_heads_q, head_dim_v]``.
        lse: Log-sum-exp tensor written in place, viewed as
            ``[batch_size * seq_q, num_heads_q]``.
        batch_size, seq_q, num_heads_q: Problem sizes.
        head_dim_k: Q/K head dim (``head_dim_v`` plus an optional tail).
        head_dim_v: Value head dim.
        page_block_size: Tokens per KV page.
        softmax_scale: Scale applied to ``q @ k^T``.
        causal: Forwarded to the kernel unchanged.

    Raises:
        ValueError: If ``num_heads_q`` is not a positive multiple of
            ``TLE_DECODE_BH``, or if ``page_block_size != TLE_DECODE_BK``.
    """
    BH = TLE_DECODE_BH
    BK = TLE_DECODE_BK
    G = num_heads_q

    # The consumers copy whole [BH, ...] tiles through tensor descriptors, so
    # a partial head block would touch rows of the neighbouring
    # (batch, seq_q) entry. Also, G // BH would silently launch no programs
    # for G < BH and skip the trailing heads otherwise.
    if G <= 0 or G % BH != 0:
        raise ValueError(
            f"TLE dense decode requires num_heads_q to be a positive multiple "
            f"of {BH}, got {G}"
        )
    # The dense producer indexes in-page rows with arange(0, BK) and page
    # starts with page_idx * PAGE_SIZE, so each page must be exactly BK rows.
    if page_block_size != BK:
        raise ValueError(
            f"TLE dense decode requires page_block_size == {BK}, "
            f"got {page_block_size}"
        )

    from triton.tools.tensor_descriptor import TensorDescriptor

    _set_triton_descriptor_allocator(q.device)

    D = head_dim_v  # typically 512
    TD = head_dim_k - D  # 64 for DQK=576, 0 for DQK=512
    DP = triton.next_power_of_2(D)
    DPH = DP // 2
    HAVE_TAIL = TD > 0
    TDP = triton.next_power_of_2(TD) if HAVE_TAIL else 1
    RH = G // BH

    num_rows = batch_size * seq_q * num_heads_q

    # Reshape q for TensorDescriptor: [batch*seq_q*HQ, DQK]
    q_flat = q.reshape(num_rows, head_dim_k).contiguous()

    # The kernel writes through flat row-major views of `out` and `lse`. For a
    # non-contiguous tensor, reshape()/contiguous() would return a copy and
    # the kernel's results would be lost. Instead, stage through contiguous
    # buffers and copy back after the launch.
    out_buf = (
        out
        if out.is_contiguous()
        else torch.empty_like(out, memory_format=torch.contiguous_format)
    )
    lse_buf = (
        lse
        if lse.is_contiguous()
        else torch.empty_like(lse, memory_format=torch.contiguous_format)
    )
    out_flat = out_buf.view(num_rows, head_dim_v)
    lse_flat = lse_buf.view(batch_size * seq_q, num_heads_q)

    q_desc = TensorDescriptor(
        q_flat,
        shape=[num_rows, head_dim_k],
        strides=[head_dim_k, 1],
        block_shape=[BH, DPH],
    )
    if HAVE_TAIL:
        tq_desc = TensorDescriptor(
            q_flat,
            shape=[num_rows, head_dim_k],
            strides=[head_dim_k, 1],
            block_shape=[BH, TDP],
        )
    else:
        # Unused by the kernel when there is no tail; any valid descriptor works.
        tq_desc = q_desc
    output_desc = TensorDescriptor(
        out_flat,
        shape=[num_rows, D],
        strides=[D, 1],
        block_shape=[BH, DPH],
    )

    # Grid: one program per (batch*seq_q, head_block)
    grid = (batch_size * seq_q * RH,)

    # Reshape block_table and cache_seqlens for kernel.
    # NOTE(review): this expects one page-table row / length per flattened
    # (batch, seq_q) entry. A [batch, max_blocks] table with seq_q > 1 would
    # be split incorrectly; confirm callers pre-expand when seq_q > 1.
    block_table_flat = block_table.reshape(batch_size * seq_q, -1).contiguous()
    cache_seqlens_flat = cache_seqlens.reshape(batch_size * seq_q).contiguous()

    _tle_dense_decode_fwd[grid](
        q_desc,
        tq_desc,
        output_desc,
        kv_flat,
        block_table_flat,
        cache_seqlens_flat,
        softmax_scale,
        out_flat,
        lse_flat,
        batch_size * seq_q,
        num_heads_q,
        head_dim_k,
        page_block_size,
        causal,
        D,
        TD,
        DP,
        TDP,
        G,
        RH,
        HAVE_TAIL,
        BK,
        BH,
        TLE_DECODE_PAIR_BLOCKS,
        kv_flat.stride(0),
        block_table_flat.stride(0),
        num_warps=TLE_DECODE_WORKER_NUM_WARPS,
        num_stages=1,
    )

    # Propagate results written into staging buffers back to the caller.
    if out_buf is not out:
        out.copy_(out_buf)
    if lse_buf is not lse:
        lse.copy_(lse_buf)

2550 

2551 

2552if HAS_TLE: 

2553 

2554 @triton.jit 

2555 def _tle_dense_decode_producer( 

2556 k0_l_writer, 

2557 k0_r_writer, 

2558 k1_l_writer, 

2559 k1_r_writer, 

2560 is_kv_valid_writer, 

2561 kv_base, 

2562 block_table_ptr, 

2563 cache_seqlen, 

2564 D: tl.constexpr, 

2565 TD: tl.constexpr, 

2566 DPH: tl.constexpr, 

2567 TDP: tl.constexpr, 

2568 PAGE_SIZE: tl.constexpr, 

2569 HAVE_TAIL: tl.constexpr, 

2570 BK: tl.constexpr, 

2571 stride_kvn: tl.constexpr, 

2572 stride_bt: tl.constexpr, 

2573 ): 

2574 """ 

2575 Producer: Load KV pages from paged cache to shared memory. 

2576 Key difference from sparse: pages are contiguous, enabling efficient loads. 

2577 """ 

2578 num_pages = tl.cdiv(cache_seqlen, PAGE_SIZE) 

2579 NPAIRS = tl.cdiv(num_pages, 2) 

2580 

2581 offs_t = tl.arange(0, BK) 

2582 offs_tile = tl.arange(0, 64) 

2583 kv_tile_rows = tl.broadcast_to(offs_t[:, None], (BK, 64)) 

2584 

2585 for pair in tl.range(NPAIRS): 

2586 page_idx0 = pair * 2 

2587 page_idx1 = page_idx0 + 1 

2588 

2589 # Load physical page numbers from block_table (stride within row is 1) 

2590 phys_page0 = tl.load( 

2591 block_table_ptr + page_idx0, 

2592 mask=page_idx0 < tl.cdiv(cache_seqlen, PAGE_SIZE), 

2593 other=0, 

2594 ) 

2595 phys_page1 = tl.load( 

2596 block_table_ptr + page_idx1, 

2597 mask=page_idx1 < tl.cdiv(cache_seqlen, PAGE_SIZE), 

2598 other=0, 

2599 ) 

2600 

2601 # Compute base addresses for contiguous page data 

2602 base0 = phys_page0.to(tl.int64) * PAGE_SIZE * stride_kvn 

2603 base1 = phys_page1.to(tl.int64) * PAGE_SIZE * stride_kvn 

2604 

2605 # Validity masks for partial last page 

2606 t_offs0 = page_idx0 * PAGE_SIZE + offs_t 

2607 t_offs1 = page_idx1 * PAGE_SIZE + offs_t 

2608 valid0 = t_offs0 < cache_seqlen 

2609 valid1 = t_offs1 < cache_seqlen 

2610 

2611 # Store validity masks 

2612 valid_slot = is_kv_valid_writer.acquire(pair) 

2613 valid_row0 = tl.full([BK], 0, dtype=tl.int32) 

2614 valid_row1 = tl.full([BK], 1, dtype=tl.int32) 

2615 tl.store( 

2616 tle.gpu.local_ptr(valid_slot.is_kv_valid, (valid_row0, offs_t)), 

2617 valid0.to(tl.int8), 

2618 ) 

2619 tl.store( 

2620 tle.gpu.local_ptr(valid_slot.is_kv_valid, (valid_row1, offs_t)), 

2621 valid1.to(tl.int8), 

2622 ) 

2623 is_kv_valid_writer.commit(pair) 

2624 

2625 # Load page0 left half (NoPE [:256]) 

2626 k0_l_slot = k0_l_writer.acquire(pair) 

2627 for tile in tl.static_range(0, DPH, 64): 

2628 k_cols = tile + offs_tile 

2629 k_cols_b = tl.broadcast_to(k_cols[None, :], (BK, 64)) 

2630 k0_l_ptr = ( 

2631 kv_base + base0 + offs_t[:, None] * stride_kvn + k_cols[None, :] 

2632 ) 

2633 k0_l_msk = valid0[:, None] & (k_cols < D)[None, :] 

2634 k0_l_blk = tl.load( 

2635 k0_l_ptr, mask=k0_l_msk, other=0.0, eviction_policy="evict_last" 

2636 ) 

2637 tl.store( 

2638 tle.gpu.local_ptr(k0_l_slot.sK, (kv_tile_rows, k_cols_b)), 

2639 k0_l_blk, 

2640 mask=k0_l_msk, 

2641 ) 

2642 k0_l_writer.commit(pair) 

2643 

2644 # Load page1 right half (NoPE [256:512]) 

2645 k1_r_slot = k1_r_writer.acquire(pair) 

2646 for tile in tl.static_range(0, DPH, 64): 

2647 k_cols = DPH + tile + offs_tile 

2648 k_cols_b = tl.broadcast_to(k_cols[None, :], (BK, 64)) 

2649 k1_r_ptr = ( 

2650 kv_base + base1 + offs_t[:, None] * stride_kvn + k_cols[None, :] 

2651 ) 

2652 k1_r_msk = valid1[:, None] & (k_cols < D)[None, :] 

2653 k1_r_blk = tl.load( 

2654 k1_r_ptr, mask=k1_r_msk, other=0.0, eviction_policy="evict_last" 

2655 ) 

2656 tl.store( 

2657 tle.gpu.local_ptr(k1_r_slot.sK, (kv_tile_rows, k_cols_b)), 

2658 k1_r_blk, 

2659 mask=k1_r_msk, 

2660 ) 

2661 if HAVE_TAIL: 

2662 offs_td = tl.arange(0, TDP) 

2663 k1_r_tail_ptr = ( 

2664 kv_base 

2665 + base1 

2666 + offs_t[:, None] * stride_kvn 

2667 + (D + offs_td)[None, :] 

2668 ) 

2669 k1_r_tail_msk = valid1[:, None] & (offs_td < TD)[None, :] 

2670 k1_r_tail_blk = tl.load( 

2671 k1_r_tail_ptr, 

2672 mask=k1_r_tail_msk, 

2673 other=0.0, 

2674 eviction_policy="evict_last", 

2675 ) 

2676 tl.store( 

2677 tle.gpu.local_ptr(k1_r_slot.sK_tail), 

2678 k1_r_tail_blk, 

2679 mask=k1_r_tail_msk, 

2680 ) 

2681 k1_r_writer.commit(pair) 

2682 

2683 # Load page0 right half (NoPE [256:512]) 

2684 k0_r_slot = k0_r_writer.acquire(pair) 

2685 for tile in tl.static_range(0, DPH, 64): 

2686 k_cols = DPH + tile + offs_tile 

2687 k_cols_b = tl.broadcast_to(k_cols[None, :], (BK, 64)) 

2688 k0_r_ptr = ( 

2689 kv_base + base0 + offs_t[:, None] * stride_kvn + k_cols[None, :] 

2690 ) 

2691 k0_r_msk = valid0[:, None] & (k_cols < D)[None, :] 

2692 k0_r_blk = tl.load( 

2693 k0_r_ptr, mask=k0_r_msk, other=0.0, eviction_policy="evict_last" 

2694 ) 

2695 tl.store( 

2696 tle.gpu.local_ptr(k0_r_slot.sK, (kv_tile_rows, k_cols_b)), 

2697 k0_r_blk, 

2698 mask=k0_r_msk, 

2699 ) 

2700 if HAVE_TAIL: 

2701 offs_td = tl.arange(0, TDP) 

2702 k0_r_tail_ptr = ( 

2703 kv_base 

2704 + base0 

2705 + offs_t[:, None] * stride_kvn 

2706 + (D + offs_td)[None, :] 

2707 ) 

2708 k0_r_tail_msk = valid0[:, None] & (offs_td < TD)[None, :] 

2709 k0_r_tail_blk = tl.load( 

2710 k0_r_tail_ptr, 

2711 mask=k0_r_tail_msk, 

2712 other=0.0, 

2713 eviction_policy="evict_last", 

2714 ) 

2715 tl.store( 

2716 tle.gpu.local_ptr(k0_r_slot.sK_tail), 

2717 k0_r_tail_blk, 

2718 mask=k0_r_tail_msk, 

2719 ) 

2720 k0_r_writer.commit(pair) 

2721 

2722 # Load page1 left half (NoPE [:256]) 

2723 k1_l_slot = k1_l_writer.acquire(pair) 

2724 for tile in tl.static_range(0, DPH, 64): 

2725 k_cols = tile + offs_tile 

2726 k_cols_b = tl.broadcast_to(k_cols[None, :], (BK, 64)) 

2727 k1_l_ptr = ( 

2728 kv_base + base1 + offs_t[:, None] * stride_kvn + k_cols[None, :] 

2729 ) 

2730 k1_l_msk = valid1[:, None] & (k_cols < D)[None, :] 

2731 k1_l_blk = tl.load( 

2732 k1_l_ptr, mask=k1_l_msk, other=0.0, eviction_policy="evict_last" 

2733 ) 

2734 tl.store( 

2735 tle.gpu.local_ptr(k1_l_slot.sK, (kv_tile_rows, k_cols_b)), 

2736 k1_l_blk, 

2737 mask=k1_l_msk, 

2738 ) 

2739 k1_l_writer.commit(pair) 

2740 

    @triton.jit
    def _tle_dense_decode_consumer0(
        q_writer,
        q_reader,
        q_desc,
        tq_desc,
        k0_l_reader,
        k0_r_qk_reader,
        k1_l_remote_reader,
        is_kv_valid_reader,
        sM_wg0_writer,
        sM_wg1_reader,
        sS0_writer,
        sS1_reader,
        sL_wg0_writer,
        sL_wg1_reader,
        output_desc,
        output_row,
        h_base,
        cache_seqlen,
        log_scale: tl.constexpr,
        D: tl.constexpr,
        TD: tl.constexpr,
        OUT_DTYPE: tl.constexpr,
        HAVE_TAIL: tl.constexpr,
        BK: tl.constexpr,
        BH: tl.constexpr,
        DPH: tl.constexpr,
        TDP: tl.constexpr,
        G: tl.constexpr,
        PAGE_SIZE: tl.constexpr,
    ):
        """Consumer 0 of the dense decode: page0 QK^T and left-half P@V.

        This partition first copies Q (left half, right half and optional tail)
        from ``q_desc``/``tq_desc`` into the shared Q pipe, starting at row
        ``output_row``. It then walks the KV pages in pairs
        (``NPAIRS = cdiv(cdiv(cache_seqlen, PAGE_SIZE), 2)``). For each pair
        it:

        1. Computes the full ``qk0 = Q @ K0^T`` for page0 (left, right and
           tail parts) and masks invalid tokens (validity row 0) to ``-inf``.
        2. Updates its online-softmax state with its local max and sends that
           max to consumer1 (``sM_wg0``). Accumulates
           ``acc_l += P0 @ V0[:, :DPH]``.
        3. Receives the merged max ``max_next`` from consumer1 (``sM_wg1``)
           and rescales its state and ``P0`` to it.
        4. Sends the rescaled ``P0`` to consumer1 (``sS0``). Receives ``P1``
           (``sS1``) and accumulates ``acc_l += P1 @ V1[:, :DPH]``.

        The same ``sK`` smem tile serves as K for QK^T and as V for P@V.
        After the loop, ``sum_exp`` is exchanged with consumer1 (slot 0 = this
        partition, slot 1 = consumer1). The normalized left half of the output
        is then written. LSE is not written by this partition.
        """
        offs_h = tl.arange(0, BH)
        offs_dh = tl.arange(0, DPH)
        mask_h = h_base + offs_h < G
        # This partition owns output columns [0, DPH).
        mask_od_l = offs_dh < D
        kv_rows = tl.broadcast_to(tl.arange(0, BK)[:, None], (BK, DPH))
        kv_cols_l = tl.broadcast_to(offs_dh[None, :], (BK, DPH))
        kv_cols_r = tl.broadcast_to((DPH + offs_dh)[None, :], (BK, DPH))

        # Load Q once. Q and output share row indexing, so output_row also
        # addresses this head block's Q rows.
        q_write_slot = q_writer.acquire(0)
        tle.gpu.copy(q_desc, q_write_slot.sQ_l, [BH, DPH], [output_row, 0])
        tle.gpu.copy(q_desc, q_write_slot.sQ_r, [BH, DPH], [output_row, DPH])
        if HAVE_TAIL:
            tle.gpu.copy(tq_desc, q_write_slot.sQ_tail, [BH, TDP], [output_row, D])
        q_writer.commit(0)

        # This partition also reads Q through its own reader end.
        q_slot = q_reader.wait(0).slot
        q_l_smem_ptr = tle.gpu.local_ptr(q_slot.sQ_l)
        q_r_smem_ptr = tle.gpu.local_ptr(q_slot.sQ_r)
        # Large negative finite start avoids (-inf) - (-inf) in alpha.
        max_prev = tl.full([BH], -1.0e30, dtype=tl.float32)
        sum_exp = tl.full([BH], 0.0, dtype=tl.float32)
        acc_l = tl.zeros([BH, DPH], dtype=tl.float32)

        num_pages = tl.cdiv(cache_seqlen, PAGE_SIZE)
        NPAIRS = tl.cdiv(num_pages, 2)

        for pair in tl.range(NPAIRS):
            # Compute QK^T for page0
            k0_l_wait = k0_l_reader.wait(pair)
            k0_l_slot = k0_l_wait.slot

            q_l_blk = tl.load(q_l_smem_ptr)
            q_r_blk = tl.load(q_r_smem_ptr)
            k0_l_blk = tl.load(tle.gpu.local_ptr(k0_l_slot.sK, (kv_rows, kv_cols_l)))

            qk0 = tl.full([BH, BK], 0.0, dtype=tl.float32)
            qk0 = tl.dot(q_l_blk, tl.trans(k0_l_blk), qk0, out_dtype=tl.float32)

            k0_r_wait = k0_r_qk_reader.wait(pair)
            k0_r_slot = k0_r_wait.slot
            k0_r_blk = tl.load(tle.gpu.local_ptr(k0_r_slot.sK, (kv_rows, kv_cols_r)))
            qk0 = tl.dot(q_r_blk, tl.trans(k0_r_blk), qk0, out_dtype=tl.float32)
            if HAVE_TAIL:
                q_tail_blk = tl.load(tle.gpu.local_ptr(q_slot.sQ_tail))
                k0_t_blk = tl.load(tle.gpu.local_ptr(k0_r_slot.sK_tail))
                qk0 = tl.dot(q_tail_blk, tl.trans(k0_t_blk), qk0, out_dtype=tl.float32)

            # Apply validity mask
            valid_wait = is_kv_valid_reader.wait(pair)
            row0 = tl.full([BK], 0, dtype=tl.int32)
            valid0 = (
                tl.load(
                    tle.gpu.local_ptr(
                        valid_wait.slot.is_kv_valid, (row0, tl.arange(0, BK))
                    )
                )
                != 0
            )
            qk0 = tl.where(valid0[None, :], qk0, float("-inf"))
            is_kv_valid_reader.release(pair)

            # Online softmax
            local_max = tl.maximum(max_prev, tl.max(qk0, axis=1))
            alpha = tl.math.exp2((max_prev - local_max) * log_scale)
            prob0 = tl.math.exp2(qk0 * log_scale - local_max[:, None] * log_scale)
            sum_exp = sum_exp * alpha + tl.sum(prob0, axis=1)
            acc_l = acc_l * alpha[:, None]
            prob0_b = prob0.to(OUT_DTYPE)

            # Send local max to consumer1
            sM_wg0_slot = sM_wg0_writer.acquire(pair)
            tl.store(tle.gpu.local_ptr(sM_wg0_slot.sM), local_max)
            sM_wg0_writer.commit(pair)

            # Accumulate P@V_left
            k0_l_blk = tl.load(tle.gpu.local_ptr(k0_l_slot.sK, (kv_rows, kv_cols_l)))
            acc_l = tl.dot(prob0_b, k0_l_blk, acc_l, out_dtype=tl.float32)
            k0_l_reader.release(pair)
            k0_r_qk_reader.release(pair)

            # Receive final max from consumer1
            sM_wg1_wait = sM_wg1_reader.wait(pair)
            max_next = tl.load(tle.gpu.local_ptr(sM_wg1_wait.slot.sM))
            sM_wg1_reader.release(pair)

            # Rescale with final max (local_max -> max_next)
            final_scale = tl.math.exp2((local_max - max_next) * log_scale)
            sum_exp = sum_exp * final_scale
            acc_l = acc_l * final_scale[:, None]

            # Send rescaled prob0 to consumer1
            prob0_scaled = prob0 * final_scale[:, None]
            sS0_slot = sS0_writer.acquire(pair)
            tl.store(tle.gpu.local_ptr(sS0_slot.sS0), prob0_scaled.to(OUT_DTYPE))
            sS0_writer.commit(pair)

            # Receive prob1 and accumulate P@V_left from page1
            sS1_wait = sS1_reader.wait(pair)
            prob1 = tl.load(tle.gpu.local_ptr(sS1_wait.slot.sS1))
            k1_l_wait = k1_l_remote_reader.wait(pair)
            k1_l_blk = tl.load(
                tle.gpu.local_ptr(k1_l_wait.slot.sK, (kv_rows, kv_cols_l))
            )
            acc_l = tl.dot(prob1, k1_l_blk, acc_l, out_dtype=tl.float32)
            sS1_reader.release(pair)
            k1_l_remote_reader.release(pair)

            max_prev = max_next

        # Exchange sum_exp with consumer1
        sL_wg0_slot = sL_wg0_writer.acquire(0)
        tl.store(tle.gpu.local_ptr(sL_wg0_slot.sL), sum_exp)
        sL_wg0_writer.commit(0)
        # Waiting for consumer1's sum presumably also guarantees it has
        # finished reading Q (it publishes sL after its loop; confirm in
        # consumer1). That allows sQ_l to be reused below as the output
        # staging buffer.
        sL_wg1_wait = sL_wg1_reader.wait(1)
        peer_sum = tl.load(tle.gpu.local_ptr(sL_wg1_wait.slot.sL))
        total_sum = sum_exp + peer_sum
        sL_wg1_reader.release(1)

        # Normalize and write output left half
        # Rows with no valid token (total_sum == 0) produce zeros instead of NaN.
        is_no_valid_tokens = total_sum == 0.0
        inv_total_sum = tl.fdiv(1.0, total_sum)
        out_l_vals = acc_l * inv_total_sum[:, None]
        out_l_vals = tl.where(is_no_valid_tokens[:, None], 0.0, out_l_vals)
        o_l_msk = mask_h[:, None] & mask_od_l[None, :]
        # NOTE(review): the copy writes the whole [BH, DPH] tile even though
        # the smem store is masked by mask_h. This seems to assume G is a
        # multiple of BH; confirm.
        tl.store(q_l_smem_ptr, out_l_vals.to(OUT_DTYPE), o_l_msk)
        tle.gpu.copy(q_slot.sQ_l, output_desc, [BH, DPH], [output_row, 0])

2900 

2901 @triton.jit 

2902 def _tle_dense_decode_consumer1( 

2903 q_reader, 

2904 k1_r_reader, 

2905 k1_l_qk_reader, 

2906 k0_r_remote_reader, 

2907 is_kv_valid_reader, 

2908 sM_wg1_writer, 

2909 sM_wg0_reader, 

2910 sS1_writer, 

2911 sS0_reader, 

2912 sL_wg1_writer, 

2913 sL_wg0_reader, 

2914 final_lse_smem, 

2915 output_desc, 

2916 output_row, 

2917 l_base, 

2918 h_base, 

2919 cache_seqlen, 

2920 log_scale: tl.constexpr, 

2921 D: tl.constexpr, 

2922 TD: tl.constexpr, 

2923 OUT_DTYPE: tl.constexpr, 

2924 HAVE_TAIL: tl.constexpr, 

2925 BK: tl.constexpr, 

2926 BH: tl.constexpr, 

2927 DPH: tl.constexpr, 

2928 TDP: tl.constexpr, 

2929 G: tl.constexpr, 

2930 PAGE_SIZE: tl.constexpr, 

2931 ): 

2932 """Consumer 1: QK^T right half + P@V_right.""" 

2933 offs_h = tl.arange(0, BH) 

2934 offs_dh = tl.arange(0, DPH) 

2935 mask_h = h_base + offs_h < G 

2936 mask_od_r = DPH + offs_dh < D 

2937 kv_rows = tl.broadcast_to(tl.arange(0, BK)[:, None], (BK, DPH)) 

2938 kv_cols_l = tl.broadcast_to(offs_dh[None, :], (BK, DPH)) 

2939 kv_cols_r = tl.broadcast_to((DPH + offs_dh)[None, :], (BK, DPH)) 

2940 

2941 q_slot = q_reader.wait(0).slot 

2942 q_l_smem_ptr = tle.gpu.local_ptr(q_slot.sQ_l) 

2943 q_r_smem_ptr = tle.gpu.local_ptr(q_slot.sQ_r) 

2944 max_prev = tl.full([BH], -1.0e30, dtype=tl.float32) 

2945 sum_exp = tl.full([BH], 0.0, dtype=tl.float32) 

2946 acc_r = tl.zeros([BH, DPH], dtype=tl.float32) 

2947 

2948 num_pages = tl.cdiv(cache_seqlen, PAGE_SIZE) 

2949 NPAIRS = tl.cdiv(num_pages, 2) 

2950 

2951 for pair in tl.range(NPAIRS): 

2952 # Compute QK^T for page1 

2953 k1_r_wait = k1_r_reader.wait(pair) 

2954 k1_r_slot = k1_r_wait.slot 

2955 

2956 q_l_blk = tl.load(q_l_smem_ptr) 

2957 q_r_blk = tl.load(q_r_smem_ptr) 

2958 k1_r_blk = tl.load(tle.gpu.local_ptr(k1_r_slot.sK, (kv_rows, kv_cols_r))) 

2959 

2960 qk1 = tl.full([BH, BK], 0.0, dtype=tl.float32) 

2961 qk1 = tl.dot(q_r_blk, tl.trans(k1_r_blk), qk1, out_dtype=tl.float32) 

2962 if HAVE_TAIL: 

2963 q_tail_blk = tl.load(tle.gpu.local_ptr(q_slot.sQ_tail)) 

2964 k1_t_blk = tl.load(tle.gpu.local_ptr(k1_r_slot.sK_tail)) 

2965 qk1 = tl.dot(q_tail_blk, tl.trans(k1_t_blk), qk1, out_dtype=tl.float32) 

2966 

2967 k1_l_wait = k1_l_qk_reader.wait(pair) 

2968 k1_l_slot = k1_l_wait.slot 

2969 k1_l_blk = tl.load(tle.gpu.local_ptr(k1_l_slot.sK, (kv_rows, kv_cols_l))) 

2970 qk1 = tl.dot(q_l_blk, tl.trans(k1_l_blk), qk1, out_dtype=tl.float32) 

2971 

2972 # Apply validity mask 

2973 valid_wait = is_kv_valid_reader.wait(pair) 

2974 row1 = tl.full([BK], 1, dtype=tl.int32) 

2975 valid1 = ( 

2976 tl.load( 

2977 tle.gpu.local_ptr( 

2978 valid_wait.slot.is_kv_valid, (row1, tl.arange(0, BK)) 

2979 ) 

2980 ) 

2981 != 0 

2982 ) 

2983 qk1 = tl.where(valid1[None, :], qk1, float("-inf")) 

2984 is_kv_valid_reader.release(pair) 

2985 

2986 # Receive candidate0 from consumer0 

2987 sM_wg0_wait = sM_wg0_reader.wait(pair) 

2988 candidate0 = tl.load(tle.gpu.local_ptr(sM_wg0_wait.slot.sM)) 

2989 sM_wg0_reader.release(pair) 

2990 

2991 # Compute final max 

2992 candidate1 = tl.maximum(max_prev, tl.max(qk1, axis=1)) 

2993 max_next = tl.maximum(candidate1, candidate0) 

2994 sM_wg1_slot = sM_wg1_writer.acquire(pair) 

2995 tl.store(tle.gpu.local_ptr(sM_wg1_slot.sM), max_next) 

2996 sM_wg1_writer.commit(pair) 

2997 

2998 # Online softmax 

2999 alpha = tl.math.exp2((max_prev - max_next) * log_scale) 

3000 prob1 = tl.math.exp2(qk1 * log_scale - max_next[:, None] * log_scale) 

3001 sum_exp = sum_exp * alpha + tl.sum(prob1, axis=1) 

3002 acc_r = acc_r * alpha[:, None] 

3003 prob1_b = prob1.to(OUT_DTYPE) 

3004 

3005 k1_l_qk_reader.release(pair) 

3006 

3007 # Accumulate P@V_right from page1 

3008 acc_r = tl.dot(prob1_b, k1_r_blk, acc_r, out_dtype=tl.float32) 

3009 

3010 # Send prob1 to consumer0 

3011 sS1_slot = sS1_writer.acquire(pair) 

3012 tl.store(tle.gpu.local_ptr(sS1_slot.sS1), prob1_b) 

3013 sS1_writer.commit(pair) 

3014 

3015 # Receive prob0 and accumulate P@V_right from page0 

3016 sS0_wait = sS0_reader.wait(pair) 

3017 prob0 = tl.load(tle.gpu.local_ptr(sS0_wait.slot.sS0)) 

3018 k0_r_wait = k0_r_remote_reader.wait(pair) 

3019 k0_r_blk = tl.load( 

3020 tle.gpu.local_ptr(k0_r_wait.slot.sK, (kv_rows, kv_cols_r)) 

3021 ) 

3022 acc_r = tl.dot(prob0, k0_r_blk, acc_r, out_dtype=tl.float32) 

3023 k1_r_reader.release(pair) 

3024 sS0_reader.release(pair) 

3025 k0_r_remote_reader.release(pair) 

3026 max_prev = max_next 

3027 

3028 # Exchange sum_exp with consumer0 

3029 sL_wg1_slot = sL_wg1_writer.acquire(1) 

3030 tl.store(tle.gpu.local_ptr(sL_wg1_slot.sL), sum_exp) 

3031 sL_wg1_writer.commit(1) 

3032 sL_wg0_wait = sL_wg0_reader.wait(0) 

3033 peer_sum = tl.load(tle.gpu.local_ptr(sL_wg0_wait.slot.sL)) 

3034 total_sum = sum_exp + peer_sum 

3035 sL_wg0_reader.release(0) 

3036 

3037 # Normalize and write output right half 

3038 is_no_valid_tokens = total_sum == 0.0 

3039 inv_total_sum = tl.fdiv(1.0, total_sum) 

3040 out_r_vals = acc_r * inv_total_sum[:, None] 

3041 final_max_logits_log2 = max_prev * log_scale 

3042 fin_log = (final_max_logits_log2 + tl.math.log2(total_sum)) * 0.6931471805599453 

3043 out_r_vals = tl.where(is_no_valid_tokens[:, None], 0.0, out_r_vals) 

3044 o_r_msk = mask_h[:, None] & mask_od_r[None, :] 

3045 tl.store(q_r_smem_ptr, out_r_vals.to(OUT_DTYPE), o_r_msk) 

3046 tle.gpu.copy(q_slot.sQ_r, output_desc, [BH, DPH], [output_row, DPH]) 

3047 

3048 # Write LSE 

3049 fin_log = tl.where(is_no_valid_tokens, float("inf"), fin_log) 

3050 tl.store(tle.gpu.local_ptr(final_lse_smem), fin_log, mask_h) 

3051 fin_log = tl.load(tle.gpu.local_ptr(final_lse_smem), mask_h, other=float("inf")) 

3052 tl.store(l_base + offs_h, fin_log, mask_h) 

3053 

@triton.jit
def _tle_dense_decode_fwd(
    q_desc,
    tq_desc,
    output_desc,
    kv_cache,
    block_table,
    cache_seqlens,
    sm_scale: tl.constexpr,
    output,
    lse,
    BS,
    G: tl.constexpr,
    DQK: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    CAUSAL: tl.constexpr,
    D: tl.constexpr,
    TD: tl.constexpr,
    DP: tl.constexpr,
    TDP: tl.constexpr,
    H: tl.constexpr,
    RH: tl.constexpr,
    HAVE_TAIL: tl.constexpr,
    BK: tl.constexpr,
    BH: tl.constexpr,
    PAIR_BLOCKS: tl.constexpr,
    stride_kvn: tl.constexpr,
    stride_bt: tl.constexpr,
):
    """Dense (paged KV cache) MLA decode kernel built on TLE warp specialization.

    Program ``pid`` handles head block ``i_rh = pid % RH`` (heads
    ``[i_rh * BH, i_rh * BH + BH)``) of row ``i_sq = pid // RH``. The row
    selects this program's entry of ``cache_seqlens``, its row of
    ``block_table`` and its row of ``lse``.

    The body itself does no math. It allocates shared-memory staging buffers,
    wraps them in ``tle.pipe`` channels and dispatches three partitions
    through ``tle.gpu.warp_specialize``:

    * ``_tle_dense_decode_consumer0``: receives ``q_desc``/``tq_desc`` and
      the writer side of ``q_pipe`` (so it stages Q), reads K page 0 of each
      pair, and exchanges row maxima, probabilities and row sums with
      consumer1.
    * ``_tle_dense_decode_consumer1``: "QK^T right half + P@V_right" per its
      own docstring. It reads K page 1 of each pair plus the right half of
      page 0, and writes the right output half and the LSE.
    * ``_tle_dense_decode_producer``: receives the writer side of every K
      pipe and of ``is_kv_valid_pipe``, plus ``kv_cache``, this row's
      ``block_table`` pointer and ``cache_seqlen``. Presumably it gathers
      paged KV tiles into shared memory (TODO confirm against its body).

    Args (shapes/layouts not fixed by this body are hedged):
        q_desc, tq_desc: forwarded to consumer0; presumably descriptors for
            the main and tail parts of Q (TODO confirm).
        output_desc: descriptor the consumers write the output through,
            addressed by ``output_row = i_sq * G + h_base``.
        kv_cache: KV cache base pointer. Its element type is used for the
            Q/K/probability staging buffers and is passed to both consumers
            (consumer1 receives it as ``OUT_DTYPE``).
        block_table: per-row page table; row ``i_sq`` starts at
            ``block_table + i_sq * stride_bt``.
        cache_seqlens: per-row cached length, loaded at index ``i_sq``.
        sm_scale: softmax scale, folded into ``log_scale`` below.
        output, BS, DQK, CAUSAL: not consulted; explicitly discarded.
        lse: LSE output. This program's slots start at
            ``lse + i_sq * G + h_base`` (assumes row stride G -- confirm).
        G: heads per row. Used as the lse/output row stride and forwarded to
            the consumers, which mask heads with ``h_base + offs_h < G``.
        PAGE_SIZE: forwarded. The consumers iterate over
            ``cdiv(cdiv(cache_seqlen, PAGE_SIZE), 2)`` page pairs.
        D, TD: forwarded. Consumer1 masks output columns against ``D``;
            ``TD`` is presumably the tail width.
        DP, TDP: widths of the main and tail staging buffers (presumably D
            and TD padded). ``DPH = DP // 2`` is each consumer's half.
        H: not referenced in this body.
        RH: number of head blocks per row.
        HAVE_TAIL: whether tail buffers are allocated and carried in pipes.
        BK: rows per K staging tile.
        BH: heads per program.
        PAIR_BLOCKS: rows in the validity buffer (consumer1 reads row 1;
            presumably 2, one per page of a pair).
        stride_kvn: forwarded to the producer.
        stride_bt: block_table row stride; also forwarded to the producer.
    """
    # Width of each column half; consumer1 accumulates [BH, DPH] outputs.
    DPH: tl.constexpr = DP // 2
    # lse is indexed as lse[i_sq * G + head]
    # (assumes a contiguous (rows, G) layout -- TODO confirm on the host side).
    stride_lm = G

    # Map program id -> (row i_sq, head block i_rh).
    pid = tl.program_id(0)
    i_sq = pid // RH
    i_rh = pid % RH
    h_base = i_rh * BH
    # First output row of this head block (G output rows per i_sq).
    # NOTE(review): if G is not a multiple of BH, the [BH, DPH] tile that
    # consumer1 copies to output_desc at output_row extends into rows that
    # belong to i_sq + 1. Consumer1's h_base + offs_h < G mask applies only
    # to its shared-memory store, not to the descriptor copy. Confirm that the
    # host guarantees G % BH == 0 or otherwise prevents the overlap.
    output_row = i_sq * G + h_base
    # 64-bit row index for the pointer offsets below (presumably to avoid
    # 32-bit overflow when multiplied by row strides).
    i_sq64 = i_sq.to(tl.int64)

    # Per-row inputs: cached length, page-table row and lse slot base.
    cache_seqlen = tl.load(cache_seqlens + i_sq64)
    block_table_ptr = block_table + i_sq64 * stride_bt
    kv_base = kv_cache
    l_base = lse + i_sq64 * stride_lm + h_base
    # Parameters this kernel does not consult. `output` is written through
    # output_desc instead. CAUSAL is ignored, which only matches causal
    # semantics if each row has a single query position.
    # NOTE(review): confirm.
    _ = output
    _ = BS
    _ = DQK
    _ = CAUSAL

    # ---- Q staging: left half, right half and (optionally) tail of Q.
    # consumer0 holds q_pipe.writer(); "wg0" (consumer0) and "wg1"
    # (consumer1) are the readers. Consumer1 later reuses sQ_r as the staging
    # tile for its right output half before copying it to output_desc.
    sQ_l_smem = tle.gpu.alloc(
        [1, BH, DPH],
        dtype=kv_cache.dtype.element_ty,
        layout=None,
        scope=tle.gpu.smem,
    )
    sQ_r_smem = tle.gpu.alloc(
        [1, BH, DPH],
        dtype=kv_cache.dtype.element_ty,
        layout=None,
        scope=tle.gpu.smem,
    )
    if HAVE_TAIL:
        sQ_tail_smem = tle.gpu.alloc(
            [1, BH, TDP],
            dtype=kv_cache.dtype.element_ty,
            layout=None,
            scope=tle.gpu.smem,
        )
        q_pipe = tle.pipe(
            capacity=1,
            scope="cta",
            name="dense_sQ",
            readers=("wg0", "wg1"),
            one_shot=True,
            sQ_l=sQ_l_smem,
            sQ_r=sQ_r_smem,
            sQ_tail=sQ_tail_smem,
        )
    else:
        q_pipe = tle.pipe(
            capacity=1,
            scope="cta",
            name="dense_sQ",
            readers=("wg0", "wg1"),
            one_shot=True,
            sQ_l=sQ_l_smem,
            sQ_r=sQ_r_smem,
        )

    # ---- K staging: one full-width [BK, DP] tile per page of a pair
    # (sK0 = page 0, sK1 = page 1), plus optional tail tiles.
    sK0_smem = tle.gpu.alloc(
        [1, BK, DP],
        dtype=kv_cache.dtype.element_ty,
        layout=None,
        scope=tle.gpu.smem,
    )
    sK1_smem = tle.gpu.alloc(
        [1, BK, DP],
        dtype=kv_cache.dtype.element_ty,
        layout=None,
        scope=tle.gpu.smem,
    )
    if HAVE_TAIL:
        sK0_tail_smem = tle.gpu.alloc(
            [1, BK, TDP],
            dtype=kv_cache.dtype.element_ty,
            layout=None,
            scope=tle.gpu.smem,
        )
        sK1_tail_smem = tle.gpu.alloc(
            [1, BK, TDP],
            dtype=kv_cache.dtype.element_ty,
            layout=None,
            scope=tle.gpu.smem,
        )
        # sS0 (page-0 probabilities; [BH, BK] in the branch below) reuses the
        # K0 tail buffer, presumably to save shared memory.
        # NOTE(review): this buffer is [BK, TDP], so it only matches the
        # [BH, BK] tile consumer1 loads when BK == BH and TDP == BK. It is only
        # safe if the K0 tail is no longer needed once sS0 is written --
        # confirm.
        sS0_smem = sK0_tail_smem
    else:
        sS0_smem = tle.gpu.alloc(
            [1, BH, BK],
            dtype=kv_cache.dtype.element_ty,
            layout=None,
            scope=tle.gpu.smem,
        )

    # int8 key-validity flags, one BK-wide row per page of the pair
    # (consumer1 reads row 1). Written by the producer, read by both consumers.
    is_kv_valid_smem = tle.gpu.alloc(
        [1, PAIR_BLOCKS, BK],
        dtype=tl.int8,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )

    # K pipes. Each *_l / *_r pair wraps the SAME [BK, DP] tile; consumers
    # index its column halves [0, DPH) and [DPH, DP) (consumer1 uses
    # kv_cols_l / kv_cols_r). The tail tile, when present, rides on the *_r
    # pipe. "qk" is the consumer computing QK^T for that page; "remote" is the
    # other consumer, restricted to the sK field.
    #   k0_l: consumer0 only.
    #   k0_r: "qk" = consumer0, "remote" = consumer1.
    #   k1_l: "qk" = consumer1, "remote" = consumer0.
    #   k1_r: consumer1 only.
    k0_l_pipe = tle.pipe(capacity=1, scope="cta", name="dense_sK0_l", sK=sK0_smem)
    if HAVE_TAIL:
        k0_r_pipe = tle.pipe(
            capacity=1,
            scope="cta",
            name="dense_sK0_r",
            readers=("qk", "remote"),
            sK=sK0_smem,
            sK_tail=sK0_tail_smem,
        )
    else:
        k0_r_pipe = tle.pipe(
            capacity=1,
            scope="cta",
            name="dense_sK0_r",
            readers=("qk", "remote"),
            sK=sK0_smem,
        )
    k1_l_pipe = tle.pipe(
        capacity=1,
        scope="cta",
        name="dense_sK1_l",
        readers=("qk", "remote"),
        sK=sK1_smem,
    )
    if HAVE_TAIL:
        k1_r_pipe = tle.pipe(
            capacity=1,
            scope="cta",
            name="dense_sK1_r",
            sK=sK1_smem,
            sK_tail=sK1_tail_smem,
        )
    else:
        k1_r_pipe = tle.pipe(
            capacity=1, scope="cta", name="dense_sK1_r", sK=sK1_smem
        )

    is_kv_valid_pipe = tle.pipe(
        capacity=1,
        scope="cta",
        name="dense_is_kv_valid",
        readers=("wg0", "wg1"),
        is_kv_valid=is_kv_valid_smem,
    )

    # ---- Cross-consumer exchange buffers.
    # [BH] float32 row-max buffer. sM_wg0_pipe (consumer0 -> consumer1) and
    # sM_wg1_pipe (consumer1 -> consumer0) both wrap this single buffer, so
    # presumably the pipe handshakes serialize the two directions.
    sM_smem = tle.gpu.alloc(
        [1, BH],
        dtype=tl.float32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    # Page-1 probabilities sent from consumer1 to consumer0 (in the KV dtype).
    sS1_smem = tle.gpu.alloc(
        [1, BH, BK],
        dtype=kv_cache.dtype.element_ty,
        layout=None,
        scope=tle.gpu.smem,
    )
    # [2, BH] float32 row-sum buffer shared by sL_wg0_pipe and sL_wg1_pipe.
    # Consumer1 writes slot 1 and reads slot 0.
    sL_smem = tle.gpu.alloc(
        [2, BH],
        dtype=tl.float32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )
    # [BH] float32 staging for consumer1's final LSE before it is stored to
    # l_base.
    final_lse_smem = tle.gpu.alloc(
        [BH],
        dtype=tl.float32,
        layout=None,
        scope=tle.gpu.smem,
        nv_mma_shared_layout=False,
    )

    sM_wg0_pipe = tle.pipe(
        capacity=1, scope="cta", name="dense_wg0_max", sM=sM_smem
    )
    sM_wg1_pipe = tle.pipe(
        capacity=1, scope="cta", name="dense_wg1_max", sM=sM_smem
    )
    sS0_pipe = tle.pipe(capacity=1, scope="cta", name="dense_sS0", sS0=sS0_smem)
    sS1_pipe = tle.pipe(capacity=1, scope="cta", name="dense_sS1", sS1=sS1_smem)
    sL_wg0_pipe = tle.pipe(capacity=2, scope="cta", name="dense_sL_wg0", sL=sL_smem)
    sL_wg1_pipe = tle.pipe(capacity=2, scope="cta", name="dense_sL_wg1", sL=sL_smem)

    # Fold sm_scale and log2(e) into one constant so the consumers can use
    # exp2: exp2(x * log_scale) == exp(x * sm_scale).
    log_scale: tl.constexpr = sm_scale * 1.4426950408889634

    # Three partitions: consumer0, consumer1, producer. The trailing lists
    # have one entry per partition after the first -- presumably worker warp
    # counts ([4, 4]) and register budgets ([216, 72]) for consumer1 and the
    # producer; confirm against the tle.gpu.warp_specialize API.
    tle.gpu.warp_specialize(
        [
            (
                _tle_dense_decode_consumer0,
                (
                    q_pipe.writer(),
                    q_pipe.reader("wg0"),
                    q_desc,
                    tq_desc,
                    k0_l_pipe.reader(),
                    k0_r_pipe.reader("qk"),
                    k1_l_pipe.reader("remote", fields=("sK",)),
                    is_kv_valid_pipe.reader("wg0"),
                    sM_wg0_pipe.writer(),
                    sM_wg1_pipe.reader(),
                    sS0_pipe.writer(),
                    sS1_pipe.reader(),
                    sL_wg0_pipe.writer(),
                    sL_wg1_pipe.reader(),
                    output_desc,
                    output_row,
                    h_base,
                    cache_seqlen,
                    log_scale,
                    D,
                    TD,
                    kv_cache.dtype.element_ty,
                    HAVE_TAIL,
                    BK,
                    BH,
                    DPH,
                    TDP,
                    G,
                    PAGE_SIZE,
                ),
            ),
            (
                _tle_dense_decode_consumer1,
                (
                    q_pipe.reader("wg1"),
                    k1_r_pipe.reader(),
                    k1_l_pipe.reader("qk"),
                    k0_r_pipe.reader("remote", fields=("sK",)),
                    is_kv_valid_pipe.reader("wg1"),
                    sM_wg1_pipe.writer(),
                    sM_wg0_pipe.reader(),
                    sS1_pipe.writer(),
                    sS0_pipe.reader(),
                    sL_wg1_pipe.writer(),
                    sL_wg0_pipe.reader(),
                    final_lse_smem,
                    output_desc,
                    output_row,
                    l_base,
                    h_base,
                    cache_seqlen,
                    log_scale,
                    D,
                    TD,
                    kv_cache.dtype.element_ty,
                    HAVE_TAIL,
                    BK,
                    BH,
                    DPH,
                    TDP,
                    G,
                    PAGE_SIZE,
                ),
            ),
            (
                _tle_dense_decode_producer,
                (
                    k0_l_pipe.writer(),
                    k0_r_pipe.writer(),
                    k1_l_pipe.writer(),
                    k1_r_pipe.writer(),
                    is_kv_valid_pipe.writer(),
                    kv_base,
                    block_table_ptr,
                    cache_seqlen,
                    D,
                    TD,
                    DPH,
                    TDP,
                    PAGE_SIZE,
                    HAVE_TAIL,
                    BK,
                    stride_kvn,
                    stride_bt,
                ),
            ),
        ],
        [4, 4],
        [216, 72],
    )