Coverage for src/flag_gems/ops/flash_api.py: 86%

536 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6 

7import flag_gems 

8from flag_gems import runtime 

9from flag_gems.ops.flash_kernel import ( 

10 block_m_splitkv_heuristic, 

11 block_n_splitkv_heuristic, 

12 flash_fwd_kernel, 

13 flash_fwd_splitkv_combine_kernel, 

14 flash_fwd_splitkv_kernel, 

15 flash_varlen_fwd_kernel, 

16) 

17from flag_gems.runtime import torch_device_fn 

18from flag_gems.utils.random_utils import philox_backend_seed_offset 

19 

# Module-level debug switch; nothing in the code visible here reads it.
_debug = False
# Logger named after this module.
logger = logging.getLogger(__name__)

22 

23 

def CHECK_DEVICE(x):
    """Assert that tensor ``x`` lives on the FlagGems device type.

    Compares ``x.device.type`` with ``flag_gems.device``. On mismatch it
    raises ``AssertionError`` (same contract as before), now with a message
    naming both device types so the failure is diagnosable.
    """
    assert x.device.type == flag_gems.device, (
        f"Expected tensor on device type {flag_gems.device!r}, "
        f"got {x.device.type!r}"
    )

26 

27 

class fwd_params:
    """Bundle of every value passed to the flash-attention forward kernels.

    Each field is a slot. ``args()`` returns the field values in
    ``__slots__`` order; callers in this module unpack that tuple
    positionally into a Triton kernel, so the slot order is part of the
    kernel-calling contract and must not be rearranged.
    """

    __slots__ = (
        # pointers and strides
        "q_ptr",
        "k_ptr",
        "v_ptr",
        "o_ptr",
        "p_ptr",
        "softmax_lse_ptr",
        "q_row_stride",
        "k_row_stride",
        "v_row_stride",
        "q_head_stride",
        "k_head_stride",
        "v_head_stride",
        "o_row_stride",
        "o_head_stride",
        "q_batch_stride",
        "k_batch_stride",
        "v_batch_stride",
        "o_batch_stride",
        "is_cu_seqlens_q",
        "cu_seqlens_q_ptr",
        "is_cu_seqlens_k",
        "cu_seqlens_k_ptr",
        "is_seqused_k",
        "seqused_k_ptr",
        # sizes
        "b",
        "bk",
        "h",
        "hk",
        "h_hk_ratio",
        "seqlen_q",
        "seqlen_k",
        "seqlen_q_rounded",
        "seqlen_k_rounded",
        "d",
        "d_rounded",
        # scaling factors
        "is_softcap",
        "softcap",
        "scale_softmax",
        "scale_softmax_log2",
        # dropout
        "is_dropout",
        "p_dropout",
        "rp_dropout",
        "p_dropout_in_uint8_t",
        "philox_args",
        "return_softmax",
        # masking
        "is_causal",
        "is_local",
        "window_size_left",
        "window_size_right",
        "seqlenq_ngroups_swapped",
        "is_paged",
        # alibi
        "is_alibi",
        "alibi_slopes_ptr",
        "alibi_slopes_batch_stride",
        # block table
        "total_q",
        "page_table_ptr",
        "page_table_batch_stride",
        "block_size",
    )

    def __init__(
        self,
        q_ptr,
        k_ptr,
        v_ptr,
        o_ptr,
        p_ptr,
        softmax_lse_ptr,
        q_row_stride,
        k_row_stride,
        v_row_stride,
        q_head_stride,
        k_head_stride,
        v_head_stride,
        o_row_stride,
        o_head_stride,
        q_batch_stride,
        k_batch_stride,
        v_batch_stride,
        o_batch_stride,
        is_cu_seqlens_q,
        cu_seqlens_q_ptr,
        is_cu_seqlens_k,
        cu_seqlens_k_ptr,
        is_seqused_k,
        seqused_k_ptr,
        # sizes
        b,
        bk,
        h,
        hk,
        h_hk_ratio,
        seqlen_q,
        seqlen_k,
        seqlen_q_rounded,
        seqlen_k_rounded,
        d,
        d_rounded,
        # scaling factors
        is_softcap,
        softcap,
        scale_softmax,
        scale_softmax_log2,
        # dropout
        is_dropout,
        p_dropout,
        rp_dropout,
        p_dropout_in_uint8_t,
        philox_args,
        return_softmax,
        # masking
        is_causal,
        is_local,
        window_size_left,
        window_size_right,
        seqlenq_ngroups_swapped,
        is_paged,
        # alibi
        is_alibi,
        alibi_slopes_ptr,
        alibi_slopes_batch_stride,
        # block table
        total_q,
        page_table_ptr,
        page_table_batch_stride,
        block_size,
    ):
        # Snapshot the constructor arguments before any other local is bound.
        # Every parameter carries exactly the name of the slot it fills.
        bound_args = dict(locals())
        for field in fwd_params.__slots__:
            setattr(self, field, bound_args[field])

    def args(self):
        """Return the field values as a tuple ordered like ``__slots__``."""
        return tuple(getattr(self, field) for field in self.__slots__)

231 

232 

def mha_varlan_fwd(
    q,
    k,
    v,
    out,
    cu_seqlens_q,
    cu_seqlens_k,
    seqused_k,
    leftpad_k,
    page_table,
    alibi_slopes,
    max_seqlen_q,
    max_seqlen_k,
    p_dropout,
    softmax_scale,
    zero_tensors,
    is_causal,
    window_size_left,
    window_size_right,
    softcap,
    return_softmax,
    gen,
):
    """Run the variable-length (packed) flash-attention forward kernel.

    Validates the inputs, derives every kernel parameter into a
    ``fwd_params`` record and launches ``flash_varlen_fwd_kernel``. The launch
    config is chosen from a work-per-CTA heuristic.

    Args:
        q: Query tensor of shape ``(total_q, num_heads, head_size)``. It must
            be fp16 or bf16 with a contiguous last dimension.
        k, v: Key/value tensors with ``q``'s dtype and identical shapes. The
            paged layout is ``(num_pages, block_size, num_heads_k,
            head_size)``; otherwise ``k.size(1)`` is taken as ``num_heads_k``.
        out: Optional preallocated output with ``q``'s shape and dtype.
            Allocated here when ``None``.
        cu_seqlens_q, cu_seqlens_k: Contiguous int32 tensors of length
            ``batch_size + 1``. ``batch_size`` is derived from ``cu_seqlens_q``.
        seqused_k: Optional contiguous tensor of shape ``(batch_size,)``.
        leftpad_k: Must be ``None`` (not supported).
        page_table: Optional page table. Its presence selects the paged-KV
            path.
        alibi_slopes: Optional float32 tensor of shape ``(num_heads,)`` or
            ``(batch_size, num_heads)``.
        max_seqlen_q, max_seqlen_k: Maximum query / key sequence lengths.
        p_dropout: Dropout probability (``0`` disables dropout). The kernel
            receives ``1 - p_dropout``.
        softmax_scale: Softmax scale. It is multiplied by log2(e) for the
            kernel, or divided by ``softcap`` when soft-capping.
        zero_tensors: If true, ``out`` is zeroed and ``lse`` is filled with
            -inf before the launch.
        is_causal: Request causal masking. It is forced off when
            ``max_seqlen_q == 1`` and there is no ALiBi.
        window_size_left, window_size_right: Sliding-window extents. A
            negative value, or a value ``>= max_seqlen_k``, disables that side.
        softcap: Logit soft-cap; ``> 0`` enables it. It cannot be combined
            with dropout.
        return_softmax: Allocate the probability buffer ``p``. Only allowed
            together with dropout.
        gen: Not read by this function.

    Returns:
        ``(out, q, k, v, lse, philox_args, unused, p)``.
        - ``out`` and ``lse`` are reshaped when the query-group swap
          optimisation is applied.
        - ``q`` may be the swapped view.
        - ``lse`` is float32.
        - ``unused`` is an empty int64 0-d tensor.
    """
    CHECK_DEVICE(q), CHECK_DEVICE(k), CHECK_DEVICE(v)
    q_device = q.device
    q_dtype = q.dtype
    assert q_dtype in (
        torch.float16,
        torch.bfloat16,
    ), "FlashAttention only support fp16 and bf16 data type"
    assert q_dtype == k.dtype
    assert q_dtype == v.dtype
    assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension"

    assert cu_seqlens_q.dtype == torch.int32
    assert cu_seqlens_q.is_contiguous()

    assert cu_seqlens_k.dtype == torch.int32
    assert cu_seqlens_k.is_contiguous()

    is_paged = page_table is not None
    if not is_paged:
        # The kernel always takes a page-table argument; the non-paged path
        # passes an empty int32 placeholder.
        page_table = torch.empty((0, 0), device=q_device, dtype=torch.int32)

    # q shape: [total_q_tokens, num_heads, head_size]
    # k shape:
    #   paged_kv: [num_pages, block_size, num_heads_k, head_size]
    # batch_size, number of sentences
    total_q, num_heads, head_size = q.size()
    num_heads_k = k.size(2) if is_paged else k.size(1)
    batch_size = cu_seqlens_q.numel() - 1
    block_size = k.size(1) if is_paged else 1
    num_pages = k.size(0) if is_paged else 0
    k_batch_size = num_pages
    # max_num_pages_per_seq = page_table.size(1)
    page_table_batch_stride = page_table.stride(0)
    # These two values are overwritten by both branches of the
    # seqlenq_ngroups_swapped if/else below before they are read.
    k_batch_stride = k.stride(0)
    v_batch_stride = v.stride(0)

    assert k.size() == v.size()
    assert cu_seqlens_q.size() == (batch_size + 1,)
    assert cu_seqlens_k.size() == (batch_size + 1,)

    # Check output shape
    if out is not None:
        assert out.stride(-1) == 1
        assert out.dtype == q.dtype
        assert out.size() == (total_q, num_heads, head_size)

    if seqused_k is not None:
        assert seqused_k.is_contiguous()
        assert seqused_k.size() == (batch_size,)

    # With a single query token (and no ALiBi) causal masking is a no-op.
    if max_seqlen_q == 1 and alibi_slopes is None:
        is_causal = False

    # Causal masking is expressed as a right window of 0.
    if is_causal:
        window_size_right = 0

    # check disable swa: a window covering the whole key range is the same as
    # no window, signalled by -1.
    if window_size_left >= max_seqlen_k:
        window_size_left = -1
    if window_size_right >= max_seqlen_k:
        window_size_right = -1

    is_local = window_size_left >= 0

    # Optimize all single-query sequences by swapping the query-group and
    # sequence dimensions: the q_groups query heads sharing one KV head become
    # q_groups "rows" of a sequence over num_heads_k heads.
    seqlenq_ngroups_swapped = (
        max_seqlen_q == 1
        and alibi_slopes is None
        and num_heads > num_heads_k
        and window_size_left < 0
        and window_size_right < 0
        and p_dropout == 0
    )
    q_groups = num_heads // num_heads_k
    if seqlenq_ngroups_swapped:
        logger.debug("Swapping query groups and sequence dimensions")
        q = (
            q.reshape((batch_size, num_heads_k, q_groups, head_size))
            .transpose(1, 2)
            .reshape(batch_size * q_groups, num_heads_k, head_size)
        )
        max_seqlen_q = q_groups
        num_heads = num_heads_k
        # With cu_seqlens_q cleared, is_cu_seqlens_q is passed as False and
        # the fixed q/o batch strides computed here are passed instead.
        cu_seqlens_q = None
        q_batch_stride = q.stride(0) * max_seqlen_q
        k_batch_stride = k.stride(0)
        v_batch_stride = v.stride(0)
        # o_batch_stride = out.stride(0) * max_seqlen_q
    else:
        q_batch_stride = 0
        k_batch_stride = 0
        v_batch_stride = 0
        o_batch_stride = 0

    total_q = q.size(0)

    assert leftpad_k is None, "leftpad_k is not supported."
    assert (
        head_size <= 256
    ), "FlashAttention forward only supports head dimension at most 256"
    assert (
        head_size % 8 == 0
    ), "head_size must be a multiple of 8, this is ensured by padding!"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"

    assert q.shape == (total_q, num_heads, head_size)
    if is_paged:
        assert k.shape == (num_pages, block_size, num_heads_k, head_size)
        assert v.shape == (num_pages, block_size, num_heads_k, head_size)
        assert k.stride() == v.stride()

    if softcap > 0.0:
        assert p_dropout == 0, "dropout is not supported if softcap is used."

    round_multiple = lambda x, m: (x + m - 1) // m * m
    head_size_rounded = round_multiple(head_size, 32) if head_size <= 192 else 256
    seqlen_q_rounded = round_multiple(max_seqlen_q, 128)
    seqlen_k_rounded = round_multiple(max_seqlen_k, 32)

    # log2(e): scales are pre-multiplied so the kernel can use exp2.
    M_LOG2E = 1.4426950408889634074
    if softcap > 0.0:
        # Soft-capping: the kernel gets softmax_scale / softcap as "softcap"
        # and softcap itself as the softmax scale.
        is_softcap = True
        adjusted_scale_softmax = softcap
        adjusted_softcap = softmax_scale / softcap
        adjusted_scale_softmax_log2e = softcap * M_LOG2E
    else:
        is_softcap = False
        adjusted_softcap = 0.0
        adjusted_scale_softmax = softmax_scale
        adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E

    # Set alibi params
    if alibi_slopes is not None:
        assert alibi_slopes.device == q_device
        assert alibi_slopes.dtype in (torch.float,)
        assert alibi_slopes.stride(-1) == 1
        assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (
            batch_size,
            num_heads,
        )
        # Per-head slopes shared across the batch use a batch stride of 0.
        alibi_slopes_batch_stride = (
            alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
        )
        is_alibi = True
    else:
        alibi_slopes_batch_stride = 0
        is_alibi = False

    # Prepare params to kernel
    with torch_device_fn.device(q_device):
        if out is not None:
            # Keep the caller's buffer in out_. When swapped, the kernel
            # writes to a scratch tensor laid out like the swapped q, and the
            # result is copied back into out_ after the launch.
            out_ = out
            if seqlenq_ngroups_swapped:
                out = torch.empty_like(q, dtype=v.dtype)
        else:
            out_ = None
            out = torch.empty_like(q, dtype=v.dtype)

        if seqlenq_ngroups_swapped:
            o_batch_stride = out.stride(0) * max_seqlen_q

        lse = torch.empty((num_heads, total_q), dtype=torch.float, device=q_device)

        if p_dropout > 0:
            is_dropout = True
            increment = batch_size * num_heads * 32
            philox_seed, philox_offset = philox_backend_seed_offset(increment)
            philox_args = torch.tensor(
                [philox_seed, philox_offset], dtype=torch.int64, device=q_device
            )
        else:
            is_dropout = False
            philox_args = torch.empty((2,), dtype=torch.int64, device=q_device)

        # From here on p_dropout is the keep probability.
        p_dropout = 1 - p_dropout
        p_dropout_in_uint8_t = math.floor(p_dropout * 255.0)
        rp_dropout = 1.0 / p_dropout

        if return_softmax:
            assert is_dropout, "Only supported with non-zero dropout."
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                device=q_device,
            )
        else:
            p = torch.empty((), device=q_device)

        if zero_tensors:
            out.zero_()
            lse.fill_(float("-inf"))

        params = fwd_params(
            q,  # q_ptr,
            k,  # k_ptr,
            v,  # v_ptr,
            out,  # o_ptr,
            p,  # p_ptr,
            lse,  # softmax_lse_ptr,
            q.stride(-3),  # q_row_stride,
            k.stride(-3),  # k_row_stride,
            v.stride(-3),  # v_row_stride,
            q.stride(-2),  # q_head_stride,
            k.stride(-2),  # k_head_stride,
            v.stride(-2),  # v_head_stride,
            out.stride(-3),  # o_row_stride,
            out.stride(-2),  # o_head_stride,
            q_batch_stride,  # q_batch_stride,
            k_batch_stride,  # k_batch_stride,
            v_batch_stride,  # v_batch_stride,
            o_batch_stride,  # o_batch_stride,
            cu_seqlens_q is not None,  # is_cu_seqlens_q,
            cu_seqlens_q,  # cu_seqlens_q_ptr,
            # NOTE(review): mha_varlan_fwd_opt passes `cu_seqlens_k is not
            # None` for this flag. The two disagree whenever seqused_k is
            # given. Confirm which the kernel expects.
            seqused_k is None,  # is_cu_seqlens_k,
            cu_seqlens_k,  # cu_seqlens_k_ptr,
            seqused_k is not None,  # is_seqused_k,
            seqused_k,  # seqused_k_ptr,
            # sizes
            batch_size,  # b,
            k_batch_size,  # bk,
            num_heads,  # h,
            num_heads_k,  # hk,
            num_heads // num_heads_k,  # h_hk_ratio,
            max_seqlen_q,  # seqlen_q,
            max_seqlen_k,  # seqlen_k,
            seqlen_q_rounded,  # seqlen_q_rounded,
            seqlen_k_rounded,  # seqlen_k_rounded,
            head_size,  # d,
            head_size_rounded,  # d_rounded,
            # scaling factors
            is_softcap,
            adjusted_softcap,  # softcap,
            adjusted_scale_softmax,  # scale_softmax,
            adjusted_scale_softmax_log2e,  # scale_softmax_log2,
            # dropout
            is_dropout,
            p_dropout,
            rp_dropout,
            p_dropout_in_uint8_t,
            philox_args,
            return_softmax,
            # causal and swa
            is_causal,  # is_causal,
            is_local,  # is_local,
            window_size_left,  # window_size_left,
            window_size_right,  # window_size_right,
            seqlenq_ngroups_swapped,  # seqlenq_ngroups_swapped,
            is_paged,
            # alibi
            is_alibi,  #
            alibi_slopes,  # alibi_slopes_ptr,
            alibi_slopes_batch_stride,  # alibi_slopes_batch_stride,
            # block table params
            total_q,  # total_q,
            page_table,  # page_table_ptr,
            page_table_batch_stride,  # page_table_batch_stride,
            block_size,  # block_size,
        )

        if flag_gems.vendor_name == "iluvatar":
            # On this vendor, K/V are passed with their trailing dims merged.
            params.k_ptr = k.view(k.shape[0], k.shape[1], -1)
            params.v_ptr = v.view(v.shape[0], v.shape[1], -1)
        logger.debug("kernel: flash_varlen_fwd")
        # Grid: (query blocks of the longest sequence, batch, heads).
        grid = lambda args: (
            triton.cdiv(max_seqlen_q, args["BLOCK_M"]),
            batch_size,
            num_heads,
        )
        kernel = flash_varlen_fwd_kernel[grid]
        args = tuple(getattr(params, k) for k in params.__slots__)

        # We assess which phase the requests are likely to be in and set the config accordingly.
        total_rows = total_q * num_heads
        num_sms = torch_device_fn.get_device_properties(
            flag_gems.device
        ).multi_processor_count
        avg_rows_per_sm = total_rows / num_sms
        avg_rows_per_batch = total_q / batch_size
        avg_rows_per_cta = min(avg_rows_per_batch, avg_rows_per_sm)
        # Heuristic: estimate work per CTA as the smaller of the average query
        # rows per sequence and the average (row, head) pairs per SM. Above
        # 64 / 32 / 16 it selects the 128 / 64 / 32 config, otherwise the 16
        # config. The config names presumably encode the block size. This
        # is a rough heuristic and may not be accurate for all scenarios.
        if avg_rows_per_cta > 64:
            varlen_fwd_config_str = "mha_block_128"
        elif avg_rows_per_cta > 32:
            varlen_fwd_config_str = "mha_block_64"
        elif avg_rows_per_cta > 16:
            varlen_fwd_config_str = "mha_block_32"
        else:
            varlen_fwd_config_str = "mha_block_16"
        # Vendor override: always use the 32 config on mthreads.
        if flag_gems.vendor_name == "mthreads":
            varlen_fwd_config_str = "mha_block_32"

        cfg = runtime.get_heuristic_config(varlen_fwd_config_str)
        cfg_params = {
            "BLOCK_M": cfg["BLOCK_M"](args),
            "BLOCK_N": cfg["BLOCK_N"](args),
            "BLOCK_K": triton.next_power_of_2(head_size),
            "num_warps": cfg["num_warps"](args),
            "num_stages": 1 if not is_paged else cfg["num_stages"](args),
        }

        logger.debug("Running flash_varlen_fwd_kernel with config: %s", cfg_params)
        kernel(*args, **cfg_params)

    if seqlenq_ngroups_swapped:
        # Undo the query-group swap on the output.
        out = out.reshape(
            batch_size, max_seqlen_q, num_heads_k, head_size
        ).transpose(1, 2)
        if out_ is not None:
            out_.view(batch_size, num_heads_k, max_seqlen_q, head_size).copy_(out)
            out = out_
        else:
            out = out.reshape(batch_size, num_heads_k * max_seqlen_q, head_size)
        lse = lse.reshape(num_heads_k, batch_size, max_seqlen_q)
        lse = lse.reshape(num_heads_k * max_seqlen_q, batch_size)

    unused = torch.empty((), dtype=torch.int64, device=q_device)
    return out, q, k, v, lse, philox_args, unused, p

577 

578 

def mha_varlan_fwd_opt(
    q,
    k,
    v,
    out,
    lse,
    cu_seqlens_q,
    cu_seqlens_k,
    seqused_k,
    leftpad_k,
    page_table,
    alibi_slopes,
    max_seqlen_q,
    max_seqlen_k,
    p_dropout,
    softmax_scale,
    zero_tensors,
    is_causal,
    window_size_left,
    window_size_right,
    softcap,
    return_softmax,
    gen,
):
    """Variable-length flash-attention forward with fewer placeholder buffers.

    Follows the same validation and launch procedure as ``mha_varlan_fwd``,
    with these differences:

    * ``lse`` may be supplied by the caller; it is allocated here only when
      ``None``.
    * Without dropout, ``philox_args`` is ``None`` rather than a placeholder
      tensor.
    * Without ``return_softmax``, ``p`` is ``None``, and the returned
      ``unused`` slot is ``None``.
    * ``is_cu_seqlens_k`` is passed as ``cu_seqlens_k is not None``.

    Args:
        q: Query tensor ``(total_q, num_heads, head_size)``. It must be fp16
            or bf16 with a contiguous last dimension.
        k, v: Key/value tensors with ``q``'s dtype and identical shapes. The
            paged layout is ``(num_pages, block_size, num_heads_k,
            head_size)``; otherwise ``k.size(1)`` is ``num_heads_k``.
        out: Optional preallocated output with ``q``'s shape and dtype.
        lse: Optional preallocated float buffer for the log-sum-exp. When
            allocated here it has shape ``(num_heads, total_q)``.
        cu_seqlens_q, cu_seqlens_k: Contiguous int32 tensors of length
            ``batch_size + 1``.
        seqused_k: Optional contiguous tensor of shape ``(batch_size,)``.
        leftpad_k: Must be ``None``.
        page_table: Optional page table; its presence selects the paged path.
        alibi_slopes: Optional float32 ``(num_heads,)`` or
            ``(batch_size, num_heads)``.
        max_seqlen_q, max_seqlen_k: Maximum query / key sequence lengths.
        p_dropout: Dropout probability; the kernel receives ``1 - p_dropout``.
        softmax_scale: Softmax scale. It is multiplied by log2(e), or divided
            by ``softcap`` when soft-capping.
        zero_tensors: Zero ``out`` and fill ``lse`` with -inf before launch.
        is_causal: Causal masking, forced off for ``max_seqlen_q == 1``
            without ALiBi.
        window_size_left, window_size_right: Sliding-window extents. A
            negative value, or a value ``>= max_seqlen_k``, disables that side.
        softcap: Logit soft-cap (``> 0`` enables it; incompatible with
            dropout).
        return_softmax: Allocate the probability buffer ``p``. Requires
            dropout.
        gen: Not read by this function.

    Returns:
        ``(out, q, k, v, lse, philox_args, unused, p)``. ``out`` and ``lse``
        are reshaped, and ``q`` is the swapped view, when the query-group
        swap optimisation is applied. ``unused`` is always ``None``.
    """
    CHECK_DEVICE(q), CHECK_DEVICE(k), CHECK_DEVICE(v)
    q_device = q.device
    q_dtype = q.dtype
    assert q_dtype in (
        torch.float16,
        torch.bfloat16,
    ), "FlashAttention only support fp16 and bf16 data type"
    assert q_dtype == k.dtype
    assert q_dtype == v.dtype
    assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension"

    assert cu_seqlens_q.dtype == torch.int32
    assert cu_seqlens_q.is_contiguous()

    assert cu_seqlens_k.dtype == torch.int32
    assert cu_seqlens_k.is_contiguous()

    is_paged = page_table is not None
    if not is_paged:
        # The kernel always takes a page-table argument; the non-paged path
        # passes an empty int32 placeholder.
        page_table = torch.empty((0, 0), device=q_device, dtype=torch.int32)

    # q:             [total_q_tokens, num_heads, head_size]
    # paged k/v:     [num_pages, block_size, num_heads_k, head_size]
    # non-paged k/v: dim 1 is num_heads_k
    total_q, num_heads, head_size = q.size()
    num_heads_k = k.size(2) if is_paged else k.size(1)
    batch_size = cu_seqlens_q.numel() - 1  # number of sequences
    block_size = k.size(1) if is_paged else 1
    num_pages = k.size(0) if is_paged else 0
    k_batch_size = num_pages
    page_table_batch_stride = page_table.stride(0)

    assert k.size() == v.size()
    assert cu_seqlens_q.size() == (batch_size + 1,)
    assert cu_seqlens_k.size() == (batch_size + 1,)

    # Check output shape
    if out is not None:
        assert out.stride(-1) == 1
        assert out.dtype == q.dtype
        assert out.size() == (total_q, num_heads, head_size)

    if seqused_k is not None:
        assert seqused_k.is_contiguous()
        assert seqused_k.size() == (batch_size,)

    # With a single query token (and no ALiBi) causal masking is a no-op.
    if max_seqlen_q == 1 and alibi_slopes is None:
        is_causal = False

    # Causal masking is expressed as a right window of 0.
    if is_causal:
        window_size_right = 0

    # A window covering the whole key range is the same as no window (-1).
    if window_size_left >= max_seqlen_k:
        window_size_left = -1
    if window_size_right >= max_seqlen_k:
        window_size_right = -1

    is_local = window_size_left >= 0

    # For all-single-query batches with grouped KV heads, fold the q_groups
    # query heads sharing one KV head into the sequence dimension.
    seqlenq_ngroups_swapped = (
        max_seqlen_q == 1
        and alibi_slopes is None
        and num_heads > num_heads_k
        and window_size_left < 0
        and window_size_right < 0
        and p_dropout == 0
    )
    q_groups = num_heads // num_heads_k
    if seqlenq_ngroups_swapped:
        logger.debug("Swapping query groups and sequence dimensions")
        q = (
            q.reshape((batch_size, num_heads_k, q_groups, head_size))
            .transpose(1, 2)
            .reshape(batch_size * q_groups, num_heads_k, head_size)
        )
        max_seqlen_q = q_groups
        num_heads = num_heads_k
        # Without cu_seqlens_q the kernel gets is_cu_seqlens_q=False plus the
        # fixed batch strides below; o_batch_stride is set once out exists.
        cu_seqlens_q = None
        q_batch_stride = q.stride(0) * max_seqlen_q
        k_batch_stride = k.stride(0)
        v_batch_stride = v.stride(0)
    else:
        q_batch_stride = 0
        k_batch_stride = 0
        v_batch_stride = 0
        o_batch_stride = 0

    total_q = q.size(0)

    assert leftpad_k is None, "leftpad_k is not supported."
    assert (
        head_size <= 256
    ), "FlashAttention forward only supports head dimension at most 256"
    assert (
        head_size % 8 == 0
    ), "head_size must be a multiple of 8, this is ensured by padding!"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"

    assert q.shape == (total_q, num_heads, head_size)
    if is_paged:
        assert k.shape == (num_pages, block_size, num_heads_k, head_size)
        assert v.shape == (num_pages, block_size, num_heads_k, head_size)
        assert k.stride() == v.stride()

    if softcap > 0.0:
        assert p_dropout == 0, "dropout is not supported if softcap is used."

    def round_multiple(x, m):
        # Round x up to the next multiple of m.
        return (x + m - 1) // m * m

    head_size_rounded = round_multiple(head_size, 32) if head_size <= 192 else 256
    seqlen_q_rounded = round_multiple(max_seqlen_q, 128)
    seqlen_k_rounded = round_multiple(max_seqlen_k, 32)

    # log2(e): scales are pre-multiplied so the kernel can use exp2.
    M_LOG2E = 1.4426950408889634074
    if softcap > 0.0:
        # Soft-capping: the kernel gets softmax_scale / softcap as "softcap"
        # and softcap itself as the softmax scale.
        is_softcap = True
        adjusted_scale_softmax = softcap
        adjusted_softcap = softmax_scale / softcap
        adjusted_scale_softmax_log2e = softcap * M_LOG2E
    else:
        is_softcap = False
        adjusted_softcap = 0.0
        adjusted_scale_softmax = softmax_scale
        adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E

    # Set alibi params
    if alibi_slopes is not None:
        assert alibi_slopes.device == q_device
        assert alibi_slopes.dtype in (torch.float,)
        assert alibi_slopes.stride(-1) == 1
        assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (
            batch_size,
            num_heads,
        )
        # Per-head slopes shared across the batch use a batch stride of 0.
        alibi_slopes_batch_stride = (
            alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
        )
        is_alibi = True
    else:
        alibi_slopes_batch_stride = 0
        is_alibi = False

    # Prepare params to kernel
    with torch_device_fn.device(q_device):
        if out is not None:
            # Keep the caller's buffer in out_. When swapped, the kernel
            # writes to a scratch tensor shaped like the swapped q, and the
            # result is copied back after the launch.
            out_ = out
            if seqlenq_ngroups_swapped:
                out = torch.empty_like(q, dtype=v.dtype)
        else:
            out_ = None
            out = torch.empty_like(q, dtype=v.dtype)

        if seqlenq_ngroups_swapped:
            o_batch_stride = out.stride(0) * max_seqlen_q

        if lse is None:
            lse = torch.empty((num_heads, total_q), dtype=torch.float, device=q_device)

        if p_dropout > 0:
            is_dropout = True
            increment = batch_size * num_heads * 32
            philox_seed, philox_offset = philox_backend_seed_offset(increment)
            philox_args = torch.tensor(
                [philox_seed, philox_offset], dtype=torch.int64, device=q_device
            )
        else:
            is_dropout = False
            # No RNG state is needed without dropout; skip the allocation.
            philox_args = None

        # From here on p_dropout is the keep probability.
        p_dropout = 1 - p_dropout
        p_dropout_in_uint8_t = math.floor(p_dropout * 255.0)
        rp_dropout = 1.0 / p_dropout

        if return_softmax:
            assert is_dropout, "Only supported with non-zero dropout."
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                device=q_device,
            )
        else:
            # No probability buffer is requested; skip the allocation.
            p = None

        if zero_tensors:
            out.zero_()
            lse.fill_(float("-inf"))

        params = fwd_params(
            q,  # q_ptr,
            k,  # k_ptr,
            v,  # v_ptr,
            out,  # o_ptr,
            p,  # p_ptr,
            lse,  # softmax_lse_ptr,
            q.stride(-3),  # q_row_stride,
            k.stride(-3),  # k_row_stride,
            v.stride(-3),  # v_row_stride,
            q.stride(-2),  # q_head_stride,
            k.stride(-2),  # k_head_stride,
            v.stride(-2),  # v_head_stride,
            out.stride(-3),  # o_row_stride,
            out.stride(-2),  # o_head_stride,
            q_batch_stride,  # q_batch_stride,
            k_batch_stride,  # k_batch_stride,
            v_batch_stride,  # v_batch_stride,
            o_batch_stride,  # o_batch_stride,
            cu_seqlens_q is not None,  # is_cu_seqlens_q,
            cu_seqlens_q,  # cu_seqlens_q_ptr,
            # NOTE(review): mha_varlan_fwd passes `seqused_k is None` for
            # this flag. The two disagree when seqused_k is given. Confirm
            # the intended kernel semantics.
            cu_seqlens_k is not None,  # is_cu_seqlens_k,
            cu_seqlens_k,  # cu_seqlens_k_ptr,
            seqused_k is not None,  # is_seqused_k,
            seqused_k,  # seqused_k_ptr,
            # sizes
            batch_size,  # b,
            k_batch_size,  # bk,
            num_heads,  # h,
            num_heads_k,  # hk,
            num_heads // num_heads_k,  # h_hk_ratio,
            max_seqlen_q,  # seqlen_q,
            max_seqlen_k,  # seqlen_k,
            seqlen_q_rounded,  # seqlen_q_rounded,
            seqlen_k_rounded,  # seqlen_k_rounded,
            head_size,  # d,
            head_size_rounded,  # d_rounded,
            # scaling factors
            is_softcap,
            adjusted_softcap,  # softcap,
            adjusted_scale_softmax,  # scale_softmax,
            adjusted_scale_softmax_log2e,  # scale_softmax_log2,
            # dropout
            is_dropout,
            p_dropout,
            rp_dropout,
            p_dropout_in_uint8_t,
            philox_args,
            return_softmax,
            # causal and swa
            is_causal,  # is_causal,
            is_local,  # is_local,
            window_size_left,  # window_size_left,
            window_size_right,  # window_size_right,
            seqlenq_ngroups_swapped,  # seqlenq_ngroups_swapped,
            is_paged,
            # alibi
            is_alibi,  # is_alibi,
            alibi_slopes,  # alibi_slopes_ptr,
            alibi_slopes_batch_stride,  # alibi_slopes_batch_stride,
            # block table params
            total_q,  # total_q,
            page_table,  # page_table_ptr,
            page_table_batch_stride,  # page_table_batch_stride,
            block_size,  # block_size,
        )

        if flag_gems.vendor_name == "iluvatar":
            # On this vendor, K/V are passed with their trailing dims merged.
            params.k_ptr = k.view(k.shape[0], k.shape[1], -1)
            params.v_ptr = v.view(v.shape[0], v.shape[1], -1)
        logger.debug("kernel: flash_varlen_fwd")

        def grid(meta):
            # (query blocks of the longest sequence, batch, heads)
            return (
                triton.cdiv(max_seqlen_q, meta["BLOCK_M"]),
                batch_size,
                num_heads,
            )

        kernel = flash_varlen_fwd_kernel[grid]
        # Positional kernel arguments, in fwd_params.__slots__ order.
        args = params.args()

        # Heuristic launch config: estimate work per CTA as the smaller of the
        # average query rows per sequence and the average (row, head) pairs
        # per SM. Above 64 / 32 / 16 it selects the 128 / 64 / 32 config,
        # otherwise the 16 config. This is rough and may not suit every
        # workload.
        total_rows = total_q * num_heads
        num_sms = torch_device_fn.get_device_properties(
            flag_gems.device
        ).multi_processor_count
        avg_rows_per_sm = total_rows / num_sms
        avg_rows_per_batch = total_q / batch_size
        avg_rows_per_cta = min(avg_rows_per_batch, avg_rows_per_sm)
        if avg_rows_per_cta > 64:
            varlen_fwd_config_str = "mha_block_128"
        elif avg_rows_per_cta > 32:
            varlen_fwd_config_str = "mha_block_64"
        elif avg_rows_per_cta > 16:
            varlen_fwd_config_str = "mha_block_32"
        else:
            varlen_fwd_config_str = "mha_block_16"
        # Vendor override: always use the 32 config on mthreads.
        if flag_gems.vendor_name == "mthreads":
            varlen_fwd_config_str = "mha_block_32"

        cfg = runtime.get_heuristic_config(varlen_fwd_config_str)
        cfg_params = {
            "BLOCK_M": cfg["BLOCK_M"](args),
            "BLOCK_N": cfg["BLOCK_N"](args),
            "BLOCK_K": triton.next_power_of_2(head_size),
            "num_warps": cfg["num_warps"](args),
            "num_stages": 1 if not is_paged else cfg["num_stages"](args),
        }

        logger.debug("Running flash_varlen_fwd_kernel with config: %s", cfg_params)
        kernel(*args, **cfg_params)

    if seqlenq_ngroups_swapped:
        # Undo the query-group swap on the output.
        out = out.reshape(
            batch_size, max_seqlen_q, num_heads_k, head_size
        ).transpose(1, 2)
        if out_ is not None:
            out_.view(batch_size, num_heads_k, max_seqlen_q, head_size).copy_(out)
            out = out_
        else:
            out = out.reshape(batch_size, num_heads_k * max_seqlen_q, head_size)
        lse = lse.reshape(num_heads_k, batch_size, max_seqlen_q)
        lse = lse.reshape(num_heads_k * max_seqlen_q, batch_size)

    unused = None
    return out, q, k, v, lse, philox_args, unused, p

927 

928 

929def mha_fwd( 

930 q, 

931 k, 

932 v, 

933 out, 

934 alibi_slopes, 

935 p_dropout, 

936 softmax_scale, 

937 is_causal, 

938 window_size_left, 

939 window_size_right, 

940 softcap, 

941 return_softmax, 

942 disable_splitkv=False, 

943): 

944 CHECK_DEVICE(q), CHECK_DEVICE(k), CHECK_DEVICE(v) 

945 q_dtype = q.dtype 

946 q_device = q.device 

947 assert q_dtype in ( 

948 torch.float16, 

949 torch.bfloat16, 

950 ), "FlashAttention only support fp16 and bf16 data type" 

951 assert q_dtype == k.dtype 

952 assert q_dtype == v.dtype 

953 assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension" 

954 assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension" 

955 assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension" 

956 batch_size, seqlen_q, num_heads, head_size = q.size() 

957 _, seqlen_k, num_heads_k, _ = k.size() 

958 

959 # Check output shape 

960 if out is not None: 

961 assert out.stride(-1) == 1 

962 assert out.dtype == q.dtype 

963 assert out.size() == (batch_size, seqlen_q, num_heads, head_size) 

964 CHECK_DEVICE(out) 

965 

966 assert ( 

967 head_size % 8 == 0 

968 ), "head_size must be a multiple of 8, this is ensured by padding!" 

969 assert ( 

970 num_heads % num_heads_k == 0 

971 ), "Number of heads in key/value must divide number of heads in query" 

972 if window_size_left >= seqlen_k: 

973 window_size_left = -1 

974 if window_size_right >= seqlen_k: 

975 window_size_right = -1 

976 if seqlen_q == 1 and alibi_slopes is None: 

977 is_causal = False 

978 if is_causal: 

979 window_size_right = 0 

980 

981 is_causal = window_size_left < 0 and window_size_right == 0 

982 is_local = window_size_left >= 0 and window_size_right >= 0 

983 

984 seqlenq_ngroups_swapped = ( 

985 seqlen_q == 1 

986 and alibi_slopes is None 

987 and num_heads > num_heads_k 

988 and window_size_left < 0 

989 and window_size_right < 0 

990 and p_dropout == 0 

991 ) 

992 q_groups = num_heads // num_heads_k 

993 

994 if seqlenq_ngroups_swapped: 

995 logger.debug("q_kg swapped.") 

996 q = q.reshape(batch_size, num_heads_k, q_groups, head_size).transpose(1, 2) 

997 seqlen_q = q_groups 

998 num_heads = num_heads_k 

999 

1000 round_multiple = lambda x, m: (x + m - 1) // m * m 

1001 head_size_rounded = round_multiple(head_size, 32) 

1002 seqlen_q_rounded = round_multiple(seqlen_q, 128) 

1003 seqlen_k_rounded = round_multiple(seqlen_k, 32) 

1004 

1005 assert ( 

1006 head_size <= 256 

1007 ), "FlashAttention forward only supports head dimension at most 256" 

1008 assert head_size == head_size_rounded, "head_size must be rounded to 32" 

1009 

def splits_heuristic(num_tasks, num_sms, n_blocks):
    """Return how many key/value splits the split-KV forward kernel should use.

    Splitting the key/value sequence only pays off when the non-split launch
    fits in a single, poorly occupied wave of SMs. In that case the number of
    splits is the smallest of:

    * ``ceil(n_blocks / 2)``, so every split gets at least two key blocks;
    * ``floor(1 / eff)``, where ``eff`` is the fraction of SMs the launch uses;
    * ``num_sms``.

    Args:
        num_tasks: Number of independent (batch, head, query-block) work
            items of the non-split launch.
        num_sms: Number of streaming multiprocessors on the device.
        n_blocks: Number of key/value blocks along the key sequence.

    Returns:
        An int >= 1, where 1 means "do not split". Degenerate inputs (no
        tasks, no SMs or no key blocks) also yield 1, instead of raising
        ``ZeroDivisionError`` (num_tasks == 0) or returning 0 (n_blocks == 0).
    """
    if num_tasks <= 0 or num_sms <= 0 or n_blocks <= 0:
        return 1

    # Ceil division; identical to triton.cdiv for non-negative ints.
    n_waves = (num_tasks + num_sms - 1) // num_sms
    # Average fraction of SMs kept busy per wave, in (0, 1].
    eff = (num_tasks / num_sms) / n_waves
    # Splitting only helps when a single wave leaves many SMs idle.
    if eff > 0.8 or n_waves > 1:
        return 1

    min_blocks_per_split = 2
    max_splits_by_blocks = (
        n_blocks + min_blocks_per_split - 1
    ) // min_blocks_per_split
    max_splits_by_occupancy = int(math.floor(1.0 / eff))
    return min(max_splits_by_blocks, max_splits_by_occupancy, num_sms)

1025 

1026 with torch_device_fn.device(q_device): 

1027 # Set softmax params 

1028 lse = torch.empty( 

1029 (batch_size, num_heads, seqlen_q), dtype=torch.float, device=q_device 

1030 ) 

1031 

1032 if out is not None: 

1033 if seqlenq_ngroups_swapped: 

1034 out = out.reshape( 

1035 batch_size, num_heads_k, q_groups, head_size 

1036 ).transpose(1, 2) 

1037 else: 

1038 out = torch.empty_like(q, dtype=v.dtype) 

1039 

1040 # Set dropout params 

1041 if p_dropout > 0: 

1042 is_dropout = True 

1043 increment = batch_size * num_heads * 32 

1044 philox_seed, philox_offset = philox_backend_seed_offset(increment) 

1045 philox_args = torch.tensor( 

1046 [philox_seed, philox_offset], dtype=torch.int64, device=q_device 

1047 ) 

1048 else: 

1049 is_dropout = False 

1050 philox_args = torch.empty((2,), dtype=torch.int64, device=q_device) 

1051 

1052 p_dropout = 1 - p_dropout 

1053 p_dropout_in_uint8_t = math.floor(p_dropout * 255.0) 

1054 rp_dropout = 1.0 / p_dropout 

1055 

1056 if return_softmax: 

1057 assert is_dropout, "Only supported with non-zero dropout." 

1058 p = torch.empty( 

1059 (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), 

1060 device=q_device, 

1061 ) 

1062 else: 

1063 p = torch.empty((), device=q_device) 

1064 

1065 M_LOG2E = 1.4426950408889634074 

1066 if softcap > 0.0: 

1067 is_softcap = True 

1068 adjusted_scale_softmax = softcap 

1069 adjusted_softcap = softmax_scale / softcap 

1070 adjusted_scale_softmax_log2e = softcap * M_LOG2E 

1071 else: 

1072 is_softcap = False 

1073 adjusted_softcap = 0.0 

1074 adjusted_scale_softmax = softmax_scale 

1075 adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E 

1076 

1077 # Set alibi params 

1078 if alibi_slopes is not None: 

1079 assert alibi_slopes.device == q_device 

1080 assert alibi_slopes.dtype in (torch.float,) 

1081 assert alibi_slopes.stride(-1) == 1 

1082 assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == ( 

1083 batch_size, 

1084 num_heads, 

1085 ) 

1086 alibi_slopes_batch_stride = ( 

1087 alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0 

1088 ) 

1089 is_alibi = True 

1090 else: 

1091 alibi_slopes_batch_stride = 0 

1092 is_alibi = False 

1093 

1094 # ONLY EVEN_K IS SUPPORTED 

1095 assert head_size == head_size_rounded 

1096 

1097 # Do kernel dispatching 

def dispatch(B, H, Q, K, D, params):
    """Choose and launch the forward kernel for a (B, H, Q, K, D) problem.

    The split-KV path is tried first. It is taken only when dropout and
    local (sliding-window) attention are both off, ``disable_splitkv`` is
    false, and ``splits_heuristic`` asks for more than one split. Otherwise
    the single-pass ``flash_fwd_kernel`` is launched.

    Args:
        B: Batch size.
        H: Number of query heads, after any seqlen_q/ngroups swap.
        Q: Query sequence length.
        K: Key sequence length.
        D: Head dimension.
        params: ``fwd_params`` instance. On the split-KV path its ``o_ptr``
            and ``softmax_lse_ptr`` are redirected to per-split scratch
            buffers before launch.

    Returns:
        The object returned by launching the main (split-KV or flash_fwd)
        Triton kernel. The caller reads ``.name`` and ``.metadata`` from it
        in debug mode.

    Reads from the enclosing scope: ``is_dropout``, ``is_local``,
    ``disable_splitkv``, ``q_device``, ``out``, ``lse`` and
    ``splits_heuristic``.
    """
    # Query the device that actually holds the inputs rather than a
    # hard-coded "cuda" device string, so non-CUDA backends exposed through
    # torch_device_fn and multi-device processes report the right SM count.
    num_sms = torch_device_fn.get_device_properties(
        q_device
    ).multi_processor_count

    # Try split-KV first; it is not used with dropout or local attention.
    if not is_dropout and not is_local and not disable_splitkv:
        BM = block_m_splitkv_heuristic(D)
        n_tasks = B * H * triton.cdiv(Q, BM)
        BN = block_n_splitkv_heuristic(D)
        n_blocks = triton.cdiv(K, BN)
        n_splits = splits_heuristic(n_tasks, num_sms, n_blocks)

        if n_splits > 1:
            logger.debug("kernel: flash_fwd_splitkv")
            # fp32 scratch buffers holding each split's partial lse / output.
            lse_splits = torch.empty(
                (n_splits, B, H, Q), dtype=torch.float, device=q_device
            )
            out_splits = torch.empty(
                (n_splits, B, H, Q, D), dtype=torch.float, device=q_device
            )
            splitkv_grid = lambda args: (
                triton.cdiv(Q, args["BLOCK_M"]),
                n_splits,
                B * H,
            )
            # The split kernel writes into the scratch buffers; the combine
            # kernel below reduces them into the real `out` / `lse`.
            params.o_ptr = out_splits
            params.softmax_lse_ptr = lse_splits
            extra_args = {"blocks_per_split": triton.cdiv(n_blocks, n_splits)}
            kernel = flash_fwd_splitkv_kernel[splitkv_grid](
                *params.args(), **extra_args
            )

            # Combine launch: each program handles BLOCK_M of the
            # B * H * Q rows; fewer rows per program for larger head dims.
            if D >= 128:
                BLOCK_M = 4
            elif D >= 64:
                BLOCK_M = 8
            else:
                BLOCK_M = 16
            BLOCK_K = triton.next_power_of_2(D)
            q_total = B * H * Q
            combine_grid = lambda args: (triton.cdiv(q_total, BLOCK_M),)
            combine_args = {
                "out_ptr": out,
                "lse_ptr": lse,
                "head_size": D,
                "out_split_stride": out_splits.stride(0),
                "lse_split_stride": lse_splits.stride(0),
                "out_b_stride": out.stride(0),
                "out_s_stride": out.stride(-3),
                # NOTE(review): this passes the last-dim stride, not the head
                # stride out.stride(-2); confirm this is what
                # flash_fwd_splitkv_combine_kernel expects for out_h_stride.
                "out_h_stride": out.stride(-1),
                "out_splits_ptr": out_splits,
                "lse_splits_ptr": lse_splits,
                "n_splits": n_splits,
                "BLOCK_M": BLOCK_M,
                "BLOCK_K": BLOCK_K,
                "q_total": q_total,
                "MAX_N_SPLITS": triton.next_power_of_2(n_splits),
            }
            flash_fwd_splitkv_combine_kernel[combine_grid](**combine_args)
            return kernel

    # Last option: single-pass flash_fwd.
    logger.debug("kernel: flash_fwd")
    grid = lambda args: (
        triton.cdiv(Q, args["BLOCK_M"]),
        H * B,
    )
    return flash_fwd_kernel[grid](*params.args())

1174 

1175 if _debug: 

1176 p = torch.empty( 

1177 (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), 

1178 dtype=torch.float32, 

1179 device=q_device, 

1180 ) 

1181 return_softmax = True 

1182 

1183 params = fwd_params( 

1184 q, # q_ptr, 

1185 k, # k_ptr, 

1186 v, # v_ptr, 

1187 out, # o_ptr, 

1188 p, # p_ptr, 

1189 lse, # softmax_lse_ptr, 

1190 q.stride(-3), # q_row_stride, 

1191 k.stride(-3), # k_row_stride, 

1192 v.stride(-3), # v_row_stride, 

1193 q.stride(-2), # q_head_stride, 

1194 k.stride(-2), # k_head_stride, 

1195 v.stride(-2), # v_head_stride, 

1196 out.stride(-3), # o_row_stride, 

1197 out.stride(-2), # o_head_stride, 

1198 q.stride(0), # q_batch_stride, 

1199 k.stride(0), # k_batch_stride, 

1200 v.stride(0), # v_batch_stride, 

1201 out.stride(0), # o_batch_stride, 

1202 False, # is_cu_seqlens_q, 

1203 None, # cu_seqlens_q_ptr, 

1204 False, # is_cu_seqlens_k, 

1205 None, # cu_seqlens_k_ptr, 

1206 False, # is_seqused_k, 

1207 None, # seqused_k_ptr, 

1208 # sizes 

1209 batch_size, # b, 

1210 0, # bk, 

1211 num_heads, # h, 

1212 num_heads_k, # hk, 

1213 num_heads // num_heads_k, # h_hk_ratio, 

1214 seqlen_q, # seqlen_q, 

1215 seqlen_k, # seqlen_k, 

1216 seqlen_q_rounded, # seqlen_q_rounded, 

1217 seqlen_k_rounded, # seqlen_k_rounded, 

1218 head_size, # d, 

1219 head_size_rounded, # d_rounded, 

1220 # scaling factors 

1221 is_softcap, 

1222 adjusted_softcap, # softcap, 

1223 adjusted_scale_softmax, # scale_softmax, 

1224 adjusted_scale_softmax_log2e, # scale_softmax_log2, 

1225 # dropout 

1226 is_dropout, 

1227 p_dropout, 

1228 rp_dropout, 

1229 p_dropout_in_uint8_t, 

1230 philox_args, 

1231 return_softmax, 

1232 # causal and swa 

1233 is_causal, # is_causal, 

1234 is_local, # is_local, 

1235 window_size_left, # window_size_left, 

1236 window_size_right, # window_size_right, 

1237 seqlenq_ngroups_swapped, # seqlenq_ngroups_swapped, 

1238 False, # is_paged, 

1239 # alibi 

1240 is_alibi, # 

1241 alibi_slopes, # alibi_slopes_ptr, 

1242 alibi_slopes_batch_stride, # alibi_slopes_batch_stride, 

1243 # block table params 

1244 0, # total_q, 

1245 None, # page_table_ptr, 

1246 0, # page_table_batch_stride, 

1247 0, # block_size, 

1248 ) 

1249 

1250 # Move TxD to last dims for correct stride in Triton tt.load 

1251 if flag_gems.vendor_name == "iluvatar": 

1252 params.q_ptr = q.transpose(1, 2) 

1253 params.k_ptr = k.transpose(1, 2) 

1254 params.v_ptr = v.transpose(1, 2) 

1255 kernel = dispatch(batch_size, num_heads, seqlen_q, seqlen_k, head_size, params) 

1256 

1257 if _debug: 

1258 print(f"{kernel.name} shared memory:", kernel.metadata.shared) 

1259 print(f"{kernel.name} num_warps:", kernel.metadata.num_warps) 

1260 print(f"{kernel.name} num_stages:", kernel.metadata.num_stages) 

1261 # print(kernel.asm['ttgir']) 

1262 

1263 if seqlenq_ngroups_swapped: 

1264 out = out.transpose(1, 2).reshape( 

1265 (batch_size, 1, num_heads_k * seqlen_q, head_size) 

1266 ) 

1267 q = q.transpose(1, 2).reshape( 

1268 (batch_size, 1, num_heads_k * seqlen_q, head_size) 

1269 ) 

1270 lse = lse.reshape((batch_size, num_heads_k * seqlen_q, 1)) 

1271 

1272 unused = torch.empty((), dtype=torch.int64, device=q_device) 

1273 

1274 return out, q, k, v, lse, philox_args, unused, p