Coverage for src/flag_gems/ops/flash_api.py: 86%

537 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6 

7import flag_gems 

8from flag_gems import runtime 

9from flag_gems.ops.flash_kernel import ( 

10 block_m_splitkv_heuristic, 

11 block_n_splitkv_heuristic, 

12 flash_fwd_kernel, 

13 flash_fwd_splitkv_combine_kernel, 

14 flash_fwd_splitkv_kernel, 

15 flash_varlen_fwd_kernel, 

16) 

17from flag_gems.runtime import torch_device_fn 

18from flag_gems.utils.random_utils import philox_backend_seed_offset 

19 

# Module-wide logger; the entry points in this module emit their kernel and
# config choices through logger.debug(...).
logger: logging.Logger = logging.getLogger(name=__name__)
# Module-level debug switch.
# NOTE(review): not read anywhere in the visible part of this file — confirm it is still used.
_debug: bool = False

22 

23 

def CHECK_DEVICE(x):
    """Check that tensor ``x`` lives on the device type FlagGems targets.

    Args:
        x: a tensor; only ``x.device.type`` is read.

    Raises:
        AssertionError: if ``x.device.type`` differs from ``flag_gems.device``.
            It is raised explicitly rather than via ``assert`` so the check is
            not stripped when Python runs with ``-O``.
    """
    if x.device.type != flag_gems.device:
        raise AssertionError(
            f"expected a tensor on device type {flag_gems.device!r}, "
            f"got {x.device.type!r}"
        )

26 

27 

class fwd_params:
    """Plain record of every argument handed to the flash-attention forward kernels.

    Each constructor argument is stored in the slot of the same name.
    ``args()`` returns the stored values in ``__slots__`` order, so the order
    of the names below is the positional order the kernel receives them in.
    """

    __slots__ = (
        # pointers and strides
        "q_ptr",
        "k_ptr",
        "v_ptr",
        "o_ptr",
        "p_ptr",
        "softmax_lse_ptr",
        "q_row_stride",
        "k_row_stride",
        "v_row_stride",
        "q_head_stride",
        "k_head_stride",
        "v_head_stride",
        "o_row_stride",
        "o_head_stride",
        "q_batch_stride",
        "k_batch_stride",
        "v_batch_stride",
        "o_batch_stride",
        "is_cu_seqlens_q",
        "cu_seqlens_q_ptr",
        "is_cu_seqlens_k",
        "cu_seqlens_k_ptr",
        "is_seqused_k",
        "seqused_k_ptr",
        # sizes
        "b",
        "bk",
        "h",
        "hk",
        "h_hk_ratio",
        "seqlen_q",
        "seqlen_k",
        "seqlen_q_rounded",
        "seqlen_k_rounded",
        "d",
        "d_rounded",
        # scaling factors
        "is_softcap",
        "softcap",
        "scale_softmax",
        "scale_softmax_log2",
        # dropout
        "is_dropout",
        "p_dropout",
        "rp_dropout",
        "p_dropout_in_uint8_t",
        "philox_args",
        "return_softmax",
        # masking
        "is_causal",
        "is_local",
        "window_size_left",
        "window_size_right",
        "seqlenq_ngroups_swapped",
        "is_paged",
        # alibi
        "is_alibi",
        "alibi_slopes_ptr",
        "alibi_slopes_batch_stride",
        # block table
        "total_q",
        "page_table_ptr",
        "page_table_batch_stride",
        "block_size",
        "k_page_stride",
    )

    def __init__(
        self,
        q_ptr,
        k_ptr,
        v_ptr,
        o_ptr,
        p_ptr,
        softmax_lse_ptr,
        q_row_stride,
        k_row_stride,
        v_row_stride,
        q_head_stride,
        k_head_stride,
        v_head_stride,
        o_row_stride,
        o_head_stride,
        q_batch_stride,
        k_batch_stride,
        v_batch_stride,
        o_batch_stride,
        is_cu_seqlens_q,
        cu_seqlens_q_ptr,
        is_cu_seqlens_k,
        cu_seqlens_k_ptr,
        is_seqused_k,
        seqused_k_ptr,
        # sizes
        b,
        bk,
        h,
        hk,
        h_hk_ratio,
        seqlen_q,
        seqlen_k,
        seqlen_q_rounded,
        seqlen_k_rounded,
        d,
        d_rounded,
        # scaling factors
        is_softcap,
        softcap,
        scale_softmax,
        scale_softmax_log2,
        # dropout
        is_dropout,
        p_dropout,
        rp_dropout,
        p_dropout_in_uint8_t,
        philox_args,
        return_softmax,
        # masking
        is_causal,
        is_local,
        window_size_left,
        window_size_right,
        seqlenq_ngroups_swapped,
        is_paged,
        # alibi
        is_alibi,
        alibi_slopes_ptr,
        alibi_slopes_batch_stride,
        # block table
        total_q,
        page_table_ptr,
        page_table_batch_stride,
        block_size,
        k_page_stride,
    ):
        # Snapshot the bound arguments before any other local exists.
        # Every parameter name matches a slot name, so each value can be
        # copied into its slot by name.
        bound = dict(locals())
        for field in self.__slots__:
            setattr(self, field, bound[field])

    def args(self):
        """Return every stored value as a tuple, in ``__slots__`` order."""
        return tuple([getattr(self, field) for field in self.__slots__])

234 

235 

def mha_varlan_fwd(
    q,
    k,
    v,
    out,
    cu_seqlens_q,
    cu_seqlens_k,
    seqused_k,
    leftpad_k,
    page_table,
    alibi_slopes,
    max_seqlen_q,
    max_seqlen_k,
    p_dropout,
    softmax_scale,
    zero_tensors,
    is_causal,
    window_size_left,
    window_size_right,
    softcap,
    return_softmax,
    gen,
):
    """Run the variable-length (packed) FlashAttention forward kernel.

    Validates the inputs, builds a ``fwd_params`` record and launches
    ``flash_varlen_fwd_kernel`` with a block-size config chosen by a heuristic.

    Args:
        q: query tensor of shape ``(total_q, num_heads, head_size)``, fp16/bf16.
        k, v: key/value tensors with the same dtype as ``q``. With a
            ``page_table`` they are ``(num_pages, block_size, num_heads_k,
            head_size)``; otherwise ``k.size(1)`` is used as ``num_heads_k``.
        out: optional preallocated output with ``q``'s shape and dtype.
        cu_seqlens_q, cu_seqlens_k: contiguous int32 tensors of length
            ``batch_size + 1``.
        seqused_k: optional contiguous tensor of length ``batch_size``.
        leftpad_k: must be ``None`` (unsupported).
        page_table: optional block table; its presence selects the paged path.
        alibi_slopes: optional float32 tensor of shape ``(num_heads,)`` or
            ``(batch_size, num_heads)``.
        max_seqlen_q, max_seqlen_k: maximum per-sequence lengths.
        p_dropout: dropout probability (0 disables dropout).
        softmax_scale: scale applied to the attention scores.
        zero_tensors: if true, zero ``out`` and fill ``lse`` with ``-inf``
            before the launch.
        is_causal: request a causal mask.
        window_size_left, window_size_right: sliding-window sizes; -1 means
            unbounded.
        softcap: soft-capping value (0 disables it; incompatible with dropout).
        return_softmax: allocate ``p``; only allowed together with dropout.
        gen: not read by this function.

    Returns:
        ``(out, q, k, v, lse, philox_args, unused, p)``.
        - When the single-query group swap is applied, ``q`` is returned in its
          swapped shape ``(batch_size * q_groups, num_heads_k, head_size)``.
        - In that case ``lse`` is reshaped to
          ``(num_heads_k * q_groups, batch_size)``; otherwise it is
          ``(num_heads, total_q)`` float32.
        - ``unused`` is a 0-d int64 placeholder tensor.

    Raises:
        AssertionError: for unsupported dtypes, layouts or option combinations.
    """
    for tensor in (q, k, v):
        CHECK_DEVICE(tensor)
    q_device = q.device
    q_dtype = q.dtype
    assert q_dtype in (
        torch.float16,
        torch.bfloat16,
    ), "FlashAttention only support fp16 and bf16 data type"
    assert q_dtype == k.dtype
    assert q_dtype == v.dtype
    assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension"

    assert cu_seqlens_q.dtype == torch.int32
    assert cu_seqlens_q.is_contiguous()
    assert cu_seqlens_k.dtype == torch.int32
    assert cu_seqlens_k.is_contiguous()

    is_paged = page_table is not None
    if not is_paged:
        # Empty placeholder so page_table.stride(0) and the kernel argument are valid.
        page_table = torch.empty((0, 0), device=q_device, dtype=torch.int32)

    # q shape: (total_q_tokens, num_heads, head_size)
    # paged k/v shape: (num_pages, block_size, num_heads_k, head_size)
    total_q, num_heads, head_size = q.size()
    num_heads_k = k.size(2) if is_paged else k.size(1)
    batch_size = cu_seqlens_q.numel() - 1  # number of sequences
    block_size = k.size(1) if is_paged else 1
    num_pages = k.size(0) if is_paged else 0
    k_batch_size = num_pages
    page_table_batch_stride = page_table.stride(0)

    assert k.size() == v.size()
    assert cu_seqlens_q.size() == (batch_size + 1,)
    assert cu_seqlens_k.size() == (batch_size + 1,)

    # Check output shape
    if out is not None:
        assert out.stride(-1) == 1
        assert out.dtype == q.dtype
        assert out.size() == (total_q, num_heads, head_size)

    if seqused_k is not None:
        assert seqused_k.is_contiguous()
        assert seqused_k.size() == (batch_size,)

    # Validate head geometry BEFORE the query-group swap below: the swap
    # reshapes q using num_heads // num_heads_k and then sets
    # num_heads = num_heads_k, so a divisibility check placed after it could
    # never fail (a bad ratio would instead crash inside reshape).
    assert leftpad_k is None, "leftpad_k is not supported."
    assert (
        head_size <= 256
    ), "FlashAttention forward only supports head dimension at most 256"
    assert (
        head_size % 8 == 0
    ), "head_size must be a multiple of 8, this is ensured by padding!"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"

    if max_seqlen_q == 1 and alibi_slopes is None:
        is_causal = False

    if is_causal:
        window_size_right = 0

    # A window at least as long as the longest key sequence disables SWA.
    if window_size_left >= max_seqlen_k:
        window_size_left = -1
    if window_size_right >= max_seqlen_k:
        window_size_right = -1

    is_local = window_size_left >= 0

    # For single-query batches with grouped KV heads, fold the q_groups query
    # heads that share one KV head into the sequence dimension:
    # q becomes (batch_size * q_groups, num_heads_k, head_size).
    seqlenq_ngroups_swapped = (
        max_seqlen_q == 1
        and alibi_slopes is None
        and num_heads > num_heads_k
        and window_size_left < 0
        and window_size_right < 0
        and p_dropout == 0
    )
    q_groups = num_heads // num_heads_k
    if seqlenq_ngroups_swapped:
        logger.debug("Swapping query groups and sequence dimensions")
        q = (
            q.reshape((batch_size, num_heads_k, q_groups, head_size))
            .transpose(1, 2)
            .reshape(batch_size * q_groups, num_heads_k, head_size)
        )
        max_seqlen_q = q_groups
        num_heads = num_heads_k
        cu_seqlens_q = None
        q_batch_stride = q.stride(0) * max_seqlen_q
        k_batch_stride = k.stride(0)
        v_batch_stride = v.stride(0)
        # o_batch_stride depends on the output buffer and is set further down.
    else:
        q_batch_stride = 0
        k_batch_stride = 0
        v_batch_stride = 0
        o_batch_stride = 0

    total_q = q.size(0)

    assert q.shape == (total_q, num_heads, head_size)
    if is_paged:
        assert k.shape == (num_pages, block_size, num_heads_k, head_size)
        assert v.shape == (num_pages, block_size, num_heads_k, head_size)
        # k.stride(0) is passed as the page stride for both k and v.
        assert k.stride() == v.stride()

    def round_multiple(x, m):
        return (x + m - 1) // m * m

    head_size_rounded = round_multiple(head_size, 32) if head_size <= 192 else 256
    seqlen_q_rounded = round_multiple(max_seqlen_q, 128)
    seqlen_k_rounded = round_multiple(max_seqlen_k, 32)

    M_LOG2E = 1.4426950408889634074  # log2(e)
    if softcap > 0.0:
        assert p_dropout == 0, "dropout is not supported if softcap is used."
        # The kernel receives softmax_scale / softcap as its `softcap` argument
        # and softcap itself as the softmax scale (presumably computing
        # softcap * tanh(score * softmax_scale / softcap) — confirm in kernel).
        is_softcap = True
        adjusted_scale_softmax = softcap
        adjusted_softcap = softmax_scale / softcap
        adjusted_scale_softmax_log2e = softcap * M_LOG2E
    else:
        is_softcap = False
        adjusted_softcap = 0.0
        adjusted_scale_softmax = softmax_scale
        adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E

    # Set alibi params
    if alibi_slopes is not None:
        assert alibi_slopes.device == q_device
        assert alibi_slopes.dtype in (torch.float,)
        assert alibi_slopes.stride(-1) == 1
        assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (
            batch_size,
            num_heads,
        )
        alibi_slopes_batch_stride = (
            alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
        )
        is_alibi = True
    else:
        alibi_slopes_batch_stride = 0
        is_alibi = False

    # Prepare params to kernel
    with torch_device_fn.device(q_device):
        if out is not None:
            out_ = out
            if seqlenq_ngroups_swapped:
                # The kernel writes the swapped layout into a scratch buffer;
                # it is copied back into the caller's tensor after the launch.
                out = torch.empty_like(q, dtype=v.dtype)
        else:
            out_ = None
            out = torch.empty_like(q, dtype=v.dtype)

        if seqlenq_ngroups_swapped:
            o_batch_stride = out.stride(0) * max_seqlen_q

        lse = torch.empty((num_heads, total_q), dtype=torch.float, device=q_device)

        if p_dropout > 0:
            is_dropout = True
            increment = batch_size * num_heads * 32
            philox_seed, philox_offset = philox_backend_seed_offset(increment)
            philox_args = torch.tensor(
                [philox_seed, philox_offset], dtype=torch.int64, device=q_device
            )
        else:
            is_dropout = False
            philox_args = torch.empty((2,), dtype=torch.int64, device=q_device)

        # From here on p_dropout is the KEEP probability.
        p_dropout = 1 - p_dropout
        p_dropout_in_uint8_t = math.floor(p_dropout * 255.0)
        rp_dropout = 1.0 / p_dropout

        if return_softmax:
            assert is_dropout, "Only supported with non-zero dropout."
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                device=q_device,
            )
        else:
            p = torch.empty((), device=q_device)

        if zero_tensors:
            out.zero_()
            lse.fill_(float("-inf"))

        params = fwd_params(
            q,  # q_ptr,
            k,  # k_ptr,
            v,  # v_ptr,
            out,  # o_ptr,
            p,  # p_ptr,
            lse,  # softmax_lse_ptr,
            q.stride(-3),  # q_row_stride,
            k.stride(-3),  # k_row_stride,
            v.stride(-3),  # v_row_stride,
            q.stride(-2),  # q_head_stride,
            k.stride(-2),  # k_head_stride,
            v.stride(-2),  # v_head_stride,
            out.stride(-3),  # o_row_stride,
            out.stride(-2),  # o_head_stride,
            q_batch_stride,  # q_batch_stride,
            k_batch_stride,  # k_batch_stride,
            v_batch_stride,  # v_batch_stride,
            o_batch_stride,  # o_batch_stride,
            cu_seqlens_q is not None,  # is_cu_seqlens_q,
            cu_seqlens_q,  # cu_seqlens_q_ptr,
            # NOTE(review): mha_varlan_fwd_opt passes `cu_seqlens_k is not None`
            # for this flag; confirm which one the kernel expects.
            seqused_k is None,  # is_cu_seqlens_k,
            cu_seqlens_k,  # cu_seqlens_k_ptr,
            seqused_k is not None,  # is_seqused_k,
            seqused_k,  # seqused_k_ptr,
            # sizes
            batch_size,  # b,
            k_batch_size,  # bk,
            num_heads,  # h,
            num_heads_k,  # hk,
            num_heads // num_heads_k,  # h_hk_ratio,
            max_seqlen_q,  # seqlen_q,
            max_seqlen_k,  # seqlen_k,
            seqlen_q_rounded,  # seqlen_q_rounded,
            seqlen_k_rounded,  # seqlen_k_rounded,
            head_size,  # d,
            head_size_rounded,  # d_rounded,
            # scaling factors
            is_softcap,
            adjusted_softcap,  # softcap,
            adjusted_scale_softmax,  # scale_softmax,
            adjusted_scale_softmax_log2e,  # scale_softmax_log2,
            # dropout
            is_dropout,
            p_dropout,
            rp_dropout,
            p_dropout_in_uint8_t,
            philox_args,
            return_softmax,
            # causal and swa
            is_causal,  # is_causal,
            is_local,  # is_local,
            window_size_left,  # window_size_left,
            window_size_right,  # window_size_right,
            seqlenq_ngroups_swapped,  # seqlenq_ngroups_swapped,
            is_paged,
            # alibi
            is_alibi,  #
            alibi_slopes,  # alibi_slopes_ptr,
            alibi_slopes_batch_stride,  # alibi_slopes_batch_stride,
            # block table params
            total_q,  # total_q,
            page_table,  # page_table_ptr,
            page_table_batch_stride,  # page_table_batch_stride,
            block_size,  # block_size,
            k.stride(0) if is_paged else 0,  # k_page_stride,
        )

        if flag_gems.vendor_name == "iluvatar":
            params.k_ptr = k.view(k.shape[0], k.shape[1], -1)
            params.v_ptr = v.view(v.shape[0], v.shape[1], -1)
        logger.debug("kernel: flash_varlen_fwd")

        def grid(meta):
            return (
                triton.cdiv(max_seqlen_q, meta["BLOCK_M"]),
                batch_size,
                num_heads,
            )

        kernel = flash_varlen_fwd_kernel[grid]
        args = params.args()

        # Pick a block-size config from the average number of query rows a CTA
        # would handle: min(rows per sequence, rows per SM). Larger values
        # suggest prefill, smaller ones decode. This is a rough heuristic.
        total_rows = total_q * num_heads
        num_sms = torch_device_fn.get_device_properties(
            flag_gems.device
        ).multi_processor_count
        avg_rows_per_sm = total_rows / num_sms
        avg_rows_per_batch = total_q / batch_size
        avg_rows_per_cta = min(avg_rows_per_batch, avg_rows_per_sm)
        if avg_rows_per_cta > 64:
            varlen_fwd_config_str = "mha_block_128"
        elif avg_rows_per_cta > 32:
            varlen_fwd_config_str = "mha_block_64"
        elif avg_rows_per_cta > 16:
            varlen_fwd_config_str = "mha_block_32"
        else:
            varlen_fwd_config_str = "mha_block_16"
        if flag_gems.vendor_name == "mthreads":
            varlen_fwd_config_str = "mha_block_32"

        cfg = runtime.get_heuristic_config(varlen_fwd_config_str)
        cfg_params = {
            "BLOCK_M": cfg["BLOCK_M"](args),
            "BLOCK_N": cfg["BLOCK_N"](args),
            "BLOCK_K": triton.next_power_of_2(head_size),
            "num_warps": cfg["num_warps"](args),
            "num_stages": 1 if not is_paged else cfg["num_stages"](args),
        }

        logger.debug("Running flash_varlen_fwd_kernel with config: %s", cfg_params)
        kernel(*args, **cfg_params)

        if seqlenq_ngroups_swapped:
            # Undo the query-group swap on the output.
            out = out.reshape(
                batch_size, max_seqlen_q, num_heads_k, head_size
            ).transpose(1, 2)
            if out_ is not None:
                out_.view(batch_size, num_heads_k, max_seqlen_q, head_size).copy_(out)
                out = out_
            else:
                out = out.reshape(batch_size, num_heads_k * max_seqlen_q, head_size)
            # Reshape follows logical (row-major) order, so this equals the
            # former two-step reshape through (num_heads_k, batch_size, max_seqlen_q).
            lse = lse.reshape(num_heads_k * max_seqlen_q, batch_size)

        unused = torch.empty((), dtype=torch.int64, device=q_device)
        return out, q, k, v, lse, philox_args, unused, p

581 

582 

def mha_varlan_fwd_opt(
    q,
    k,
    v,
    out,
    lse,
    cu_seqlens_q,
    cu_seqlens_k,
    seqused_k,
    leftpad_k,
    page_table,
    alibi_slopes,
    max_seqlen_q,
    max_seqlen_k,
    p_dropout,
    softmax_scale,
    zero_tensors,
    is_causal,
    window_size_left,
    window_size_right,
    softcap,
    return_softmax,
    gen,
):
    """Run the variable-length FlashAttention forward kernel, avoiding dummy allocations.

    Performs the same validation and launch as the non-``_opt`` variant, with
    three differences:
    - The caller may pass a preallocated ``lse`` buffer.
    - ``None`` is passed to the kernel instead of placeholder tensors for
      ``philox_args`` (no dropout) and ``p`` (no ``return_softmax``).
    - The returned ``unused`` slot is ``None``.

    Args:
        q: query tensor of shape ``(total_q, num_heads, head_size)``, fp16/bf16.
        k, v: key/value tensors with the same dtype as ``q``. With a
            ``page_table`` they are ``(num_pages, block_size, num_heads_k,
            head_size)``; otherwise ``k.size(1)`` is used as ``num_heads_k``.
        out: optional preallocated output with ``q``'s shape and dtype.
        lse: optional preallocated log-sum-exp buffer; allocated as
            ``(num_heads, total_q)`` float32 when ``None``.
        cu_seqlens_q, cu_seqlens_k: contiguous int32 tensors of length
            ``batch_size + 1``.
        seqused_k: optional contiguous tensor of length ``batch_size``.
        leftpad_k: must be ``None`` (unsupported).
        page_table: optional block table; its presence selects the paged path.
        alibi_slopes: optional float32 tensor of shape ``(num_heads,)`` or
            ``(batch_size, num_heads)``.
        max_seqlen_q, max_seqlen_k: maximum per-sequence lengths.
        p_dropout: dropout probability (0 disables dropout).
        softmax_scale: scale applied to the attention scores.
        zero_tensors: if true, zero ``out`` and fill ``lse`` with ``-inf``
            before the launch.
        is_causal: request a causal mask.
        window_size_left, window_size_right: sliding-window sizes; -1 means
            unbounded.
        softcap: soft-capping value (0 disables it; incompatible with dropout).
        return_softmax: allocate ``p``; only allowed together with dropout.
        gen: not read by this function.

    Returns:
        ``(out, q, k, v, lse, philox_args, None, p)``.
        - ``philox_args`` is ``None`` without dropout.
        - ``p`` is ``None`` unless ``return_softmax``.
        - When the single-query group swap is applied, ``q`` is returned
          swapped and ``lse`` is reshaped to
          ``(num_heads_k * q_groups, batch_size)``.

    Raises:
        AssertionError: for unsupported dtypes, layouts or option combinations.
    """
    for tensor in (q, k, v):
        CHECK_DEVICE(tensor)
    q_device = q.device
    q_dtype = q.dtype
    assert q_dtype in (
        torch.float16,
        torch.bfloat16,
    ), "FlashAttention only support fp16 and bf16 data type"
    assert q_dtype == k.dtype
    assert q_dtype == v.dtype
    assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension"

    assert cu_seqlens_q.dtype == torch.int32
    assert cu_seqlens_q.is_contiguous()
    assert cu_seqlens_k.dtype == torch.int32
    assert cu_seqlens_k.is_contiguous()

    is_paged = page_table is not None
    if not is_paged:
        # Empty placeholder so page_table.stride(0) and the kernel argument are valid.
        page_table = torch.empty((0, 0), device=q_device, dtype=torch.int32)

    # q shape: (total_q_tokens, num_heads, head_size)
    # paged k/v shape: (num_pages, block_size, num_heads_k, head_size)
    total_q, num_heads, head_size = q.size()
    num_heads_k = k.size(2) if is_paged else k.size(1)
    batch_size = cu_seqlens_q.numel() - 1  # number of sequences
    block_size = k.size(1) if is_paged else 1
    num_pages = k.size(0) if is_paged else 0
    k_batch_size = num_pages
    page_table_batch_stride = page_table.stride(0)

    assert k.size() == v.size()
    assert cu_seqlens_q.size() == (batch_size + 1,)
    assert cu_seqlens_k.size() == (batch_size + 1,)

    # Check output shape
    if out is not None:
        assert out.stride(-1) == 1
        assert out.dtype == q.dtype
        assert out.size() == (total_q, num_heads, head_size)

    if seqused_k is not None:
        assert seqused_k.is_contiguous()
        assert seqused_k.size() == (batch_size,)

    # Validate head geometry BEFORE the query-group swap below: the swap
    # reshapes q using num_heads // num_heads_k and then sets
    # num_heads = num_heads_k, so a divisibility check placed after it could
    # never fail (a bad ratio would instead crash inside reshape).
    assert leftpad_k is None, "leftpad_k is not supported."
    assert (
        head_size <= 256
    ), "FlashAttention forward only supports head dimension at most 256"
    assert (
        head_size % 8 == 0
    ), "head_size must be a multiple of 8, this is ensured by padding!"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"

    if max_seqlen_q == 1 and alibi_slopes is None:
        is_causal = False

    if is_causal:
        window_size_right = 0

    # A window at least as long as the longest key sequence disables SWA.
    if window_size_left >= max_seqlen_k:
        window_size_left = -1
    if window_size_right >= max_seqlen_k:
        window_size_right = -1

    is_local = window_size_left >= 0

    # For single-query batches with grouped KV heads, fold the q_groups query
    # heads that share one KV head into the sequence dimension:
    # q becomes (batch_size * q_groups, num_heads_k, head_size).
    seqlenq_ngroups_swapped = (
        max_seqlen_q == 1
        and alibi_slopes is None
        and num_heads > num_heads_k
        and window_size_left < 0
        and window_size_right < 0
        and p_dropout == 0
    )
    q_groups = num_heads // num_heads_k
    if seqlenq_ngroups_swapped:
        logger.debug("Swapping query groups and sequence dimensions")
        q = (
            q.reshape((batch_size, num_heads_k, q_groups, head_size))
            .transpose(1, 2)
            .reshape(batch_size * q_groups, num_heads_k, head_size)
        )
        max_seqlen_q = q_groups
        num_heads = num_heads_k
        cu_seqlens_q = None
        q_batch_stride = q.stride(0) * max_seqlen_q
        k_batch_stride = k.stride(0)
        v_batch_stride = v.stride(0)
        # o_batch_stride depends on the output buffer and is set further down.
    else:
        q_batch_stride = 0
        k_batch_stride = 0
        v_batch_stride = 0
        o_batch_stride = 0

    total_q = q.size(0)

    assert q.shape == (total_q, num_heads, head_size)
    if is_paged:
        assert k.shape == (num_pages, block_size, num_heads_k, head_size)
        assert v.shape == (num_pages, block_size, num_heads_k, head_size)
        # k.stride(0) is passed as the page stride for both k and v.
        assert k.stride() == v.stride()

    def round_multiple(x, m):
        return (x + m - 1) // m * m

    head_size_rounded = round_multiple(head_size, 32) if head_size <= 192 else 256
    seqlen_q_rounded = round_multiple(max_seqlen_q, 128)
    seqlen_k_rounded = round_multiple(max_seqlen_k, 32)

    M_LOG2E = 1.4426950408889634074  # log2(e)
    if softcap > 0.0:
        assert p_dropout == 0, "dropout is not supported if softcap is used."
        # The kernel receives softmax_scale / softcap as its `softcap` argument
        # and softcap itself as the softmax scale (presumably computing
        # softcap * tanh(score * softmax_scale / softcap) — confirm in kernel).
        is_softcap = True
        adjusted_scale_softmax = softcap
        adjusted_softcap = softmax_scale / softcap
        adjusted_scale_softmax_log2e = softcap * M_LOG2E
    else:
        is_softcap = False
        adjusted_softcap = 0.0
        adjusted_scale_softmax = softmax_scale
        adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E

    # Set alibi params
    if alibi_slopes is not None:
        assert alibi_slopes.device == q_device
        assert alibi_slopes.dtype in (torch.float,)
        assert alibi_slopes.stride(-1) == 1
        assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (
            batch_size,
            num_heads,
        )
        alibi_slopes_batch_stride = (
            alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
        )
        is_alibi = True
    else:
        alibi_slopes_batch_stride = 0
        is_alibi = False

    # Prepare params to kernel
    with torch_device_fn.device(q_device):
        if out is not None:
            out_ = out
            if seqlenq_ngroups_swapped:
                # The kernel writes the swapped layout into a scratch buffer;
                # it is copied back into the caller's tensor after the launch.
                out = torch.empty_like(q, dtype=v.dtype)
        else:
            out_ = None
            out = torch.empty_like(q, dtype=v.dtype)

        if seqlenq_ngroups_swapped:
            o_batch_stride = out.stride(0) * max_seqlen_q

        if lse is None:
            lse = torch.empty((num_heads, total_q), dtype=torch.float, device=q_device)
        # NOTE(review): a caller-supplied lse is used as-is; presumably it must be
        # a float32 buffer laid out as (num_heads, total_q) — confirm with callers.

        if p_dropout > 0:
            is_dropout = True
            increment = batch_size * num_heads * 32
            philox_seed, philox_offset = philox_backend_seed_offset(increment)
            philox_args = torch.tensor(
                [philox_seed, philox_offset], dtype=torch.int64, device=q_device
            )
        else:
            is_dropout = False
            # No placeholder allocation: the kernel receives None.
            philox_args = None

        # From here on p_dropout is the KEEP probability.
        p_dropout = 1 - p_dropout
        p_dropout_in_uint8_t = math.floor(p_dropout * 255.0)
        rp_dropout = 1.0 / p_dropout

        if return_softmax:
            assert is_dropout, "Only supported with non-zero dropout."
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                device=q_device,
            )
        else:
            # No placeholder allocation: the kernel receives None.
            p = None

        if zero_tensors:
            out.zero_()
            lse.fill_(float("-inf"))

        params = fwd_params(
            q,  # q_ptr,
            k,  # k_ptr,
            v,  # v_ptr,
            out,  # o_ptr,
            p,  # p_ptr,
            lse,  # softmax_lse_ptr,
            q.stride(-3),  # q_row_stride,
            k.stride(-3),  # k_row_stride,
            v.stride(-3),  # v_row_stride,
            q.stride(-2),  # q_head_stride,
            k.stride(-2),  # k_head_stride,
            v.stride(-2),  # v_head_stride,
            out.stride(-3),  # o_row_stride,
            out.stride(-2),  # o_head_stride,
            q_batch_stride,  # q_batch_stride,
            k_batch_stride,  # k_batch_stride,
            v_batch_stride,  # v_batch_stride,
            o_batch_stride,  # o_batch_stride,
            cu_seqlens_q is not None,  # is_cu_seqlens_q,
            cu_seqlens_q,  # cu_seqlens_q_ptr,
            # NOTE(review): mha_varlan_fwd passes `seqused_k is None` for this
            # flag; confirm which one the kernel expects.
            cu_seqlens_k is not None,  # is_cu_seqlens_k,
            cu_seqlens_k,  # cu_seqlens_k_ptr,
            seqused_k is not None,  # is_seqused_k,
            seqused_k,  # seqused_k_ptr,
            # sizes
            batch_size,  # b,
            k_batch_size,  # bk,
            num_heads,  # h,
            num_heads_k,  # hk,
            num_heads // num_heads_k,  # h_hk_ratio,
            max_seqlen_q,  # seqlen_q,
            max_seqlen_k,  # seqlen_k,
            seqlen_q_rounded,  # seqlen_q_rounded,
            seqlen_k_rounded,  # seqlen_k_rounded,
            head_size,  # d,
            head_size_rounded,  # d_rounded,
            # scaling factors
            is_softcap,
            adjusted_softcap,  # softcap,
            adjusted_scale_softmax,  # scale_softmax,
            adjusted_scale_softmax_log2e,  # scale_softmax_log2,
            # dropout
            is_dropout,
            p_dropout,
            rp_dropout,
            p_dropout_in_uint8_t,
            philox_args,
            return_softmax,
            # causal and swa
            is_causal,  # is_causal,
            is_local,  # is_local,
            window_size_left,  # window_size_left,
            window_size_right,  # window_size_right,
            seqlenq_ngroups_swapped,  # seqlenq_ngroups_swapped,
            is_paged,
            # alibi
            is_alibi,  #
            alibi_slopes,  # alibi_slopes_ptr,
            alibi_slopes_batch_stride,  # alibi_slopes_batch_stride,
            # block table params
            total_q,  # total_q,
            page_table,  # page_table_ptr,
            page_table_batch_stride,  # page_table_batch_stride,
            block_size,  # block_size,
            k.stride(0) if is_paged else 0,  # k_page_stride,
        )

        if flag_gems.vendor_name == "iluvatar":
            params.k_ptr = k.view(k.shape[0], k.shape[1], -1)
            params.v_ptr = v.view(v.shape[0], v.shape[1], -1)
        logger.debug("kernel: flash_varlen_fwd")

        def grid(meta):
            return (
                triton.cdiv(max_seqlen_q, meta["BLOCK_M"]),
                batch_size,
                num_heads,
            )

        kernel = flash_varlen_fwd_kernel[grid]
        args = params.args()

        # Pick a block-size config from the average number of query rows a CTA
        # would handle: min(rows per sequence, rows per SM). Larger values
        # suggest prefill, smaller ones decode. This is a rough heuristic.
        total_rows = total_q * num_heads
        num_sms = torch_device_fn.get_device_properties(
            flag_gems.device
        ).multi_processor_count
        avg_rows_per_sm = total_rows / num_sms
        avg_rows_per_batch = total_q / batch_size
        avg_rows_per_cta = min(avg_rows_per_batch, avg_rows_per_sm)
        if avg_rows_per_cta > 64:
            varlen_fwd_config_str = "mha_block_128"
        elif avg_rows_per_cta > 32:
            varlen_fwd_config_str = "mha_block_64"
        elif avg_rows_per_cta > 16:
            varlen_fwd_config_str = "mha_block_32"
        else:
            varlen_fwd_config_str = "mha_block_16"
        if flag_gems.vendor_name == "mthreads":
            varlen_fwd_config_str = "mha_block_32"

        cfg = runtime.get_heuristic_config(varlen_fwd_config_str)
        cfg_params = {
            "BLOCK_M": cfg["BLOCK_M"](args),
            "BLOCK_N": cfg["BLOCK_N"](args),
            "BLOCK_K": triton.next_power_of_2(head_size),
            "num_warps": cfg["num_warps"](args),
            "num_stages": 1 if not is_paged else cfg["num_stages"](args),
        }

        logger.debug("Running flash_varlen_fwd_kernel with config: %s", cfg_params)
        kernel(*args, **cfg_params)

        if seqlenq_ngroups_swapped:
            # Undo the query-group swap on the output.
            out = out.reshape(
                batch_size, max_seqlen_q, num_heads_k, head_size
            ).transpose(1, 2)
            if out_ is not None:
                out_.view(batch_size, num_heads_k, max_seqlen_q, head_size).copy_(out)
                out = out_
            else:
                out = out.reshape(batch_size, num_heads_k * max_seqlen_q, head_size)
            # Reshape follows logical (row-major) order, so this equals the
            # former two-step reshape through (num_heads_k, batch_size, max_seqlen_q).
            lse = lse.reshape(num_heads_k * max_seqlen_q, batch_size)

        unused = None
        return out, q, k, v, lse, philox_args, unused, p

932 

933 

934def mha_fwd( 

935 q, 

936 k, 

937 v, 

938 out, 

939 alibi_slopes, 

940 p_dropout, 

941 softmax_scale, 

942 is_causal, 

943 window_size_left, 

944 window_size_right, 

945 softcap, 

946 return_softmax, 

947 disable_splitkv=False, 

948): 

949 CHECK_DEVICE(q), CHECK_DEVICE(k), CHECK_DEVICE(v) 

950 q_dtype = q.dtype 

951 q_device = q.device 

952 assert q_dtype in ( 

953 torch.float16, 

954 torch.bfloat16, 

955 ), "FlashAttention only support fp16 and bf16 data type" 

956 assert q_dtype == k.dtype 

957 assert q_dtype == v.dtype 

958 assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension" 

959 assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension" 

960 assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension" 

961 batch_size, seqlen_q, num_heads, head_size = q.size() 

962 _, seqlen_k, num_heads_k, _ = k.size() 

963 

964 # Check output shape 

965 if out is not None: 

966 assert out.stride(-1) == 1 

967 assert out.dtype == q.dtype 

968 assert out.size() == (batch_size, seqlen_q, num_heads, head_size) 

969 CHECK_DEVICE(out) 

970 

971 assert ( 

972 head_size % 8 == 0 

973 ), "head_size must be a multiple of 8, this is ensured by padding!" 

974 assert ( 

975 num_heads % num_heads_k == 0 

976 ), "Number of heads in key/value must divide number of heads in query" 

977 if window_size_left >= seqlen_k: 

978 window_size_left = -1 

979 if window_size_right >= seqlen_k: 

980 window_size_right = -1 

981 if seqlen_q == 1 and alibi_slopes is None: 

982 is_causal = False 

983 if is_causal: 

984 window_size_right = 0 

985 

986 is_causal = window_size_left < 0 and window_size_right == 0 

987 is_local = window_size_left >= 0 and window_size_right >= 0 

988 

989 seqlenq_ngroups_swapped = ( 

990 seqlen_q == 1 

991 and alibi_slopes is None 

992 and num_heads > num_heads_k 

993 and window_size_left < 0 

994 and window_size_right < 0 

995 and p_dropout == 0 

996 ) 

997 q_groups = num_heads // num_heads_k 

998 

999 if seqlenq_ngroups_swapped: 

1000 logger.debug("q_kg swapped.") 

1001 q = q.reshape(batch_size, num_heads_k, q_groups, head_size).transpose(1, 2) 

1002 seqlen_q = q_groups 

1003 num_heads = num_heads_k 

1004 

1005 round_multiple = lambda x, m: (x + m - 1) // m * m 

1006 head_size_rounded = round_multiple(head_size, 32) 

1007 seqlen_q_rounded = round_multiple(seqlen_q, 128) 

1008 seqlen_k_rounded = round_multiple(seqlen_k, 32) 

1009 

1010 assert ( 

1011 head_size <= 256 

1012 ), "FlashAttention forward only supports head dimension at most 256" 

1013 assert head_size == head_size_rounded, "head_size must be rounded to 32" 

1014 

def splits_heuristic(num_tasks, num_sms, n_blocks):
    """Choose how many KV splits the split-KV forward path should use.

    Splitting is only considered when all tasks fit in a single wave whose
    occupancy ("wave efficiency") is at most 0.8; otherwise 1 is returned,
    meaning "do not split".

    Args:
        num_tasks: number of independent tasks the un-split kernel would
            launch (the caller passes B * H * cdiv(seqlen_q, BLOCK_M)).
        num_sms: number of multiprocessors on the device.
        n_blocks: number of KV blocks along the key sequence.

    Returns:
        The number of splits, always >= 1. When splitting, the count is
        limited so that each split keeps at least 2 KV blocks, the split
        count does not exceed ``floor(1 / eff)``, and it is at most
        ``num_sms``.
    """
    if num_tasks <= 0:
        # Nothing to launch (e.g. empty batch or empty query sequence):
        # avoid dividing by a zero wave count below.
        return 1

    # Integer ceil-division; identical to triton.cdiv for these operands.
    n_waves = (num_tasks + num_sms - 1) // num_sms
    eff = (num_tasks / num_sms) / n_waves
    # splits when wave efficiency is low
    if eff > 0.8 or n_waves > 1:
        return 1

    min_blocks_per_split = 2
    best_splits = min(
        (n_blocks + min_blocks_per_split - 1) // min_blocks_per_split,
        math.floor(1.0 / eff),
        num_sms,
    )
    # n_blocks == 0 would otherwise yield 0 splits; 1 means "no split".
    return max(best_splits, 1)

1030 

1031 with torch_device_fn.device(q_device): 

1032 # Set softmax params 

1033 lse = torch.empty( 

1034 (batch_size, num_heads, seqlen_q), dtype=torch.float, device=q_device 

1035 ) 

1036 

1037 if out is not None: 

1038 if seqlenq_ngroups_swapped: 

1039 out = out.reshape( 

1040 batch_size, num_heads_k, q_groups, head_size 

1041 ).transpose(1, 2) 

1042 else: 

1043 out = torch.empty_like(q, dtype=v.dtype) 

1044 

1045 # Set dropout params 

1046 if p_dropout > 0: 

1047 is_dropout = True 

1048 increment = batch_size * num_heads * 32 

1049 philox_seed, philox_offset = philox_backend_seed_offset(increment) 

1050 philox_args = torch.tensor( 

1051 [philox_seed, philox_offset], dtype=torch.int64, device=q_device 

1052 ) 

1053 else: 

1054 is_dropout = False 

1055 philox_args = torch.empty((2,), dtype=torch.int64, device=q_device) 

1056 

1057 p_dropout = 1 - p_dropout 

1058 p_dropout_in_uint8_t = math.floor(p_dropout * 255.0) 

1059 rp_dropout = 1.0 / p_dropout 

1060 

1061 if return_softmax: 

1062 assert is_dropout, "Only supported with non-zero dropout." 

1063 p = torch.empty( 

1064 (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), 

1065 device=q_device, 

1066 ) 

1067 else: 

1068 p = torch.empty((), device=q_device) 

1069 

1070 M_LOG2E = 1.4426950408889634074 

1071 if softcap > 0.0: 

1072 is_softcap = True 

1073 adjusted_scale_softmax = softcap 

1074 adjusted_softcap = softmax_scale / softcap 

1075 adjusted_scale_softmax_log2e = softcap * M_LOG2E 

1076 else: 

1077 is_softcap = False 

1078 adjusted_softcap = 0.0 

1079 adjusted_scale_softmax = softmax_scale 

1080 adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E 

1081 

1082 # Set alibi params 

1083 if alibi_slopes is not None: 

1084 assert alibi_slopes.device == q_device 

1085 assert alibi_slopes.dtype in (torch.float,) 

1086 assert alibi_slopes.stride(-1) == 1 

1087 assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == ( 

1088 batch_size, 

1089 num_heads, 

1090 ) 

1091 alibi_slopes_batch_stride = ( 

1092 alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0 

1093 ) 

1094 is_alibi = True 

1095 else: 

1096 alibi_slopes_batch_stride = 0 

1097 is_alibi = False 

1098 

1099 # ONLY EVEN_K IS SUPPORTED 

1100 assert head_size == head_size_rounded 

1101 

1102 # Do kernel dispatching 

def dispatch(B, H, Q, K, D, params):
    """Select and launch the forward attention kernel(s).

    The split-KV path is tried first, when dropout and local (sliding-window)
    attention are both off and ``disable_splitkv`` is not set. The key
    sequence is divided into ``n_splits`` chunks, each writing a float32
    partial output/LSE. ``flash_fwd_splitkv_combine_kernel`` then reduces
    them into ``out`` and ``lse``. If that path is not taken, or the
    heuristic picks a single split, ``flash_fwd_kernel`` is launched instead.

    Args:
        B: batch size.
        H: number of query heads.
        Q: query sequence length.
        K: key sequence length.
        D: head size.
        params: ``fwd_params`` instance whose ``args()`` supply the main
            kernel arguments. On the split-KV path its ``o_ptr`` and
            ``softmax_lse_ptr`` are overwritten with the split buffers.

    Returns:
        The object returned by launching the main kernel (split-KV or
        ``flash_fwd_kernel``). The combine-kernel launch result is not
        returned.

    Reads ``is_dropout``, ``is_local``, ``disable_splitkv``, ``q_device``,
    ``out``, ``lse`` and ``splits_heuristic`` from the enclosing scope.
    """
    # (A flash_fwd_bh_parallel_kernel path used to be sketched here; it is
    # not enabled.)

    # Try splitkv
    if not is_dropout and not is_local and not disable_splitkv:
        # Query the device that holds the inputs rather than a hard-coded
        # "cuda" string, which is wrong on non-CUDA backends.
        num_sms = torch_device_fn.get_device_properties(
            q_device
        ).multi_processor_count
        BM = block_m_splitkv_heuristic(D)
        n_tasks = B * H * triton.cdiv(Q, BM)
        BN = block_n_splitkv_heuristic(D)
        n_blocks = triton.cdiv(K, BN)
        n_splits = splits_heuristic(n_tasks, num_sms, n_blocks)

        if n_splits > 1:
            logger.debug("kernel: flash_fwd_splitkv")
            # Per-split partial results, reduced by the combine kernel below.
            lse_splits = torch.empty(
                (n_splits, B, H, Q), dtype=torch.float, device=q_device
            )
            out_splits = torch.empty(
                (n_splits, B, H, Q, D), dtype=torch.float, device=q_device
            )
            grid = lambda args: (
                triton.cdiv(Q, args["BLOCK_M"]),
                n_splits,
                B * H,
            )
            splitkv_kernel = flash_fwd_splitkv_kernel[grid]
            # NOTE(review): the o_*_stride values in params describe ``out``,
            # not out_splits; presumably the split-KV kernel derives its own
            # offsets into these buffers. Confirm.
            params.o_ptr = out_splits
            params.softmax_lse_ptr = lse_splits
            extra_args = {"blocks_per_split": triton.cdiv(n_blocks, n_splits)}
            kernel = splitkv_kernel(*params.args(), **extra_args)

            # Combine kernel: each program handles BLOCK_M of the B*H*Q rows;
            # fewer rows per program as the head size grows.
            if D >= 128:
                BLOCK_M = 4
            elif D >= 64:
                BLOCK_M = 8
            else:
                BLOCK_M = 16
            BLOCK_K = triton.next_power_of_2(D)
            grid = lambda args: (triton.cdiv(B * H * Q, BLOCK_M),)
            combine_kernel = flash_fwd_splitkv_combine_kernel[grid]
            combine_args = {
                "out_ptr": out,
                "lse_ptr": lse,
                "head_size": D,
                "out_split_stride": out_splits.stride(0),
                "lse_split_stride": lse_splits.stride(0),
                "out_b_stride": out.stride(0),
                "out_s_stride": out.stride(-3),
                # NOTE(review): stride(-1) is the head_size-dim stride, while
                # params.o_head_stride uses stride(-2). Verify which one the
                # combine kernel expects for "out_h_stride".
                "out_h_stride": out.stride(-1),
                "out_splits_ptr": out_splits,
                "lse_splits_ptr": lse_splits,
                "n_splits": n_splits,
                "BLOCK_M": BLOCK_M,
                "BLOCK_K": BLOCK_K,
                "q_total": B * H * Q,
                "MAX_N_SPLITS": triton.next_power_of_2(n_splits),
            }
            combine_kernel(**combine_args)
            return kernel

    # Last option: flash_fwd
    logger.debug("kernel: flash_fwd")
    grid = lambda args: (
        triton.cdiv(Q, args["BLOCK_M"]),
        H * B,
    )
    kernel = flash_fwd_kernel[grid]
    kernel = kernel(*params.args())
    return kernel

1179 

1180 if _debug: 

1181 p = torch.empty( 

1182 (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), 

1183 dtype=torch.float32, 

1184 device=q_device, 

1185 ) 

1186 return_softmax = True 

1187 

1188 params = fwd_params( 

1189 q, # q_ptr, 

1190 k, # k_ptr, 

1191 v, # v_ptr, 

1192 out, # o_ptr, 

1193 p, # p_ptr, 

1194 lse, # softmax_lse_ptr, 

1195 q.stride(-3), # q_row_stride, 

1196 k.stride(-3), # k_row_stride, 

1197 v.stride(-3), # v_row_stride, 

1198 q.stride(-2), # q_head_stride, 

1199 k.stride(-2), # k_head_stride, 

1200 v.stride(-2), # v_head_stride, 

1201 out.stride(-3), # o_row_stride, 

1202 out.stride(-2), # o_head_stride, 

1203 q.stride(0), # q_batch_stride, 

1204 k.stride(0), # k_batch_stride, 

1205 v.stride(0), # v_batch_stride, 

1206 out.stride(0), # o_batch_stride, 

1207 False, # is_cu_seqlens_q, 

1208 None, # cu_seqlens_q_ptr, 

1209 False, # is_cu_seqlens_k, 

1210 None, # cu_seqlens_k_ptr, 

1211 False, # is_seqused_k, 

1212 None, # seqused_k_ptr, 

1213 # sizes 

1214 batch_size, # b, 

1215 0, # bk, 

1216 num_heads, # h, 

1217 num_heads_k, # hk, 

1218 num_heads // num_heads_k, # h_hk_ratio, 

1219 seqlen_q, # seqlen_q, 

1220 seqlen_k, # seqlen_k, 

1221 seqlen_q_rounded, # seqlen_q_rounded, 

1222 seqlen_k_rounded, # seqlen_k_rounded, 

1223 head_size, # d, 

1224 head_size_rounded, # d_rounded, 

1225 # scaling factors 

1226 is_softcap, 

1227 adjusted_softcap, # softcap, 

1228 adjusted_scale_softmax, # scale_softmax, 

1229 adjusted_scale_softmax_log2e, # scale_softmax_log2, 

1230 # dropout 

1231 is_dropout, 

1232 p_dropout, 

1233 rp_dropout, 

1234 p_dropout_in_uint8_t, 

1235 philox_args, 

1236 return_softmax, 

1237 # causal and swa 

1238 is_causal, # is_causal, 

1239 is_local, # is_local, 

1240 window_size_left, # window_size_left, 

1241 window_size_right, # window_size_right, 

1242 seqlenq_ngroups_swapped, # seqlenq_ngroups_swapped, 

1243 False, # is_paged, 

1244 # alibi 

1245 is_alibi, # 

1246 alibi_slopes, # alibi_slopes_ptr, 

1247 alibi_slopes_batch_stride, # alibi_slopes_batch_stride, 

1248 # block table params 

1249 0, # total_q, 

1250 None, # page_table_ptr, 

1251 0, # page_table_batch_stride, 

1252 0, # block_size, 

1253 0, # k_page_stride, 

1254 ) 

1255 

1256 # Move TxD to last dims for correct stride in Triton tt.load 

1257 if flag_gems.vendor_name == "iluvatar": 

1258 params.q_ptr = q.transpose(1, 2) 

1259 params.k_ptr = k.transpose(1, 2) 

1260 params.v_ptr = v.transpose(1, 2) 

1261 kernel = dispatch(batch_size, num_heads, seqlen_q, seqlen_k, head_size, params) 

1262 

1263 if _debug: 

1264 print(f"{kernel.name} shared memory:", kernel.metadata.shared) 

1265 print(f"{kernel.name} num_warps:", kernel.metadata.num_warps) 

1266 print(f"{kernel.name} num_stages:", kernel.metadata.num_stages) 

1267 # print(kernel.asm['ttgir']) 

1268 

1269 if seqlenq_ngroups_swapped: 

1270 out = out.transpose(1, 2).reshape( 

1271 (batch_size, 1, num_heads_k * seqlen_q, head_size) 

1272 ) 

1273 q = q.transpose(1, 2).reshape( 

1274 (batch_size, 1, num_heads_k * seqlen_q, head_size) 

1275 ) 

1276 lse = lse.reshape((batch_size, num_heads_k * seqlen_q, 1)) 

1277 

1278 unused = torch.empty((), dtype=torch.int64, device=q_device) 

1279 

1280 return out, q, k, v, lse, philox_args, unused, p