Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/flash_api.py: 0%

349 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6 

7import flag_gems 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils.random_utils import philox_backend_seed_offset 

11 

12from .flash_kernel import ( 

13 block_m_splitkv_heuristic, 

14 block_n_splitkv_heuristic, 

15 flash_fwd_kernel, 

16 flash_fwd_splitkv_combine_kernel, 

17 flash_fwd_splitkv_kernel, 

18 flash_varlen_fwd_kernel, 

19) 

20 

logger = logging.getLogger(__name__)

# When True, mha_fwd forces return_softmax, allocates a float32 ``p`` buffer
# and prints the launched kernel's shared-memory / warp / stage metadata.
_debug = False

# Core count reported by the device properties; mha_fwd's split-KV heuristic
# uses it as the number of streaming multiprocessors available per wave.
TOTAL_CORE_NUM = torch_device_fn.get_device_properties().multi_processor_count

25 

26 

def CHECK_DEVICE(x):
    """Assert that tensor ``x`` lives on the device type FlagGems targets.

    Args:
        x: a tensor (anything exposing ``x.device.type``).

    Raises:
        AssertionError: if ``x.device.type`` differs from ``flag_gems.device``.
            The message names both device types to make the failure actionable.
    """
    assert x.device.type == flag_gems.device, (
        f"expected a tensor on device type {flag_gems.device!r}, "
        f"got {x.device.type!r}"
    )

29 

30 

class fwd_params:
    """Flat bundle of every argument handed to the flash-attention kernels.

    The order of ``__slots__`` is significant: :meth:`args` yields the
    attribute values in exactly this order, and callers splat that tuple
    positionally into the Triton kernels. Every constructor parameter has the
    same name as a slot.
    """

    __slots__ = (
        # tensors ("pointers") and their strides
        "q_ptr",
        "k_ptr",
        "v_ptr",
        "o_ptr",
        "p_ptr",
        "softmax_lse_ptr",
        "q_row_stride",
        "k_row_stride",
        "v_row_stride",
        "q_head_stride",
        "k_head_stride",
        "v_head_stride",
        "o_row_stride",
        "o_head_stride",
        "q_batch_stride",
        "k_batch_stride",
        "v_batch_stride",
        "o_batch_stride",
        "is_cu_seqlens_q",
        "cu_seqlens_q_ptr",
        "is_cu_seqlens_k",
        "cu_seqlens_k_ptr",
        "is_seqused_k",
        "seqused_k_ptr",
        # problem sizes
        "b",
        "bk",
        "h",
        "hk",
        "h_hk_ratio",
        "seqlen_q",
        "seqlen_k",
        "seqlen_q_rounded",
        "seqlen_k_rounded",
        "d",
        "d_rounded",
        # softmax scaling / soft-capping
        "is_softcap",
        "softcap",
        "scale_softmax",
        "scale_softmax_log2",
        # dropout
        "is_dropout",
        "p_dropout",
        "rp_dropout",
        "p_dropout_in_uint8_t",
        "philox_args",
        "return_softmax",
        # causal / sliding-window masking
        "is_causal",
        "is_local",
        "window_size_left",
        "window_size_right",
        "seqlenq_ngroups_swapped",
        # alibi
        "is_alibi",
        "alibi_slopes_ptr",
        "alibi_slopes_batch_stride",
        # paged KV block table
        "total_q",
        "page_table_ptr",
        "page_table_batch_stride",
        "block_size",
    )

    def __init__(
        self,
        q_ptr,
        k_ptr,
        v_ptr,
        o_ptr,
        p_ptr,
        softmax_lse_ptr,
        q_row_stride,
        k_row_stride,
        v_row_stride,
        q_head_stride,
        k_head_stride,
        v_head_stride,
        o_row_stride,
        o_head_stride,
        q_batch_stride,
        k_batch_stride,
        v_batch_stride,
        o_batch_stride,
        is_cu_seqlens_q,
        cu_seqlens_q_ptr,
        is_cu_seqlens_k,
        cu_seqlens_k_ptr,
        is_seqused_k,
        seqused_k_ptr,
        # sizes
        b,
        bk,
        h,
        hk,
        h_hk_ratio,
        seqlen_q,
        seqlen_k,
        seqlen_q_rounded,
        seqlen_k_rounded,
        d,
        d_rounded,
        # scaling factors
        is_softcap,
        softcap,
        scale_softmax,
        scale_softmax_log2,
        # dropout
        is_dropout,
        p_dropout,
        rp_dropout,
        p_dropout_in_uint8_t,
        philox_args,
        return_softmax,
        # masking
        is_causal,
        is_local,
        window_size_left,
        window_size_right,
        seqlenq_ngroups_swapped,
        # alibi
        is_alibi,
        alibi_slopes_ptr,
        alibi_slopes_batch_stride,
        # block table
        total_q,
        page_table_ptr,
        page_table_batch_stride,
        block_size,
    ):
        # Each parameter shares its name with a slot, so snapshot the
        # arguments and copy them over by name, in slot order.
        passed = locals()
        for slot_name in self.__slots__:
            setattr(self, slot_name, passed[slot_name])

    def args(self):
        """Return the slot values as a tuple, ordered exactly like ``__slots__``."""
        return tuple([getattr(self, slot_name) for slot_name in self.__slots__])

231 

232 

def mha_varlan_fwd(
    q,
    k,
    v,
    out,
    cu_seqlens_q,
    cu_seqlens_k,
    seqused_k,
    leftpad_k,
    page_table,
    alibi_slopes,
    max_seqlen_q,
    max_seqlen_k,
    p_dropout,
    softmax_scale,
    zero_tensors,
    is_causal,
    window_size_left,
    window_size_right,
    softcap,
    return_softmax,
    gen,
):
    """Variable-length flash-attention forward pass over a paged K/V cache.

    Validates the inputs, normalises the masking options, packs everything
    into a :class:`fwd_params` and launches ``flash_varlen_fwd_kernel`` with a
    fixed (non-autotuned) config from ``runtime.get_heuristic_config``.

    Args:
        q: packed queries of shape ``(total_q, num_heads, head_size)``.
        k, v: paged cache of shape
            ``(num_pages, block_size, num_heads_k, head_size)`` with identical
            shapes and strides.
        out: optional preallocated output of shape
            ``(total_q, num_heads, head_size)``; allocated when ``None``.
        cu_seqlens_q, cu_seqlens_k: contiguous int32 tensors of length
            ``batch_size + 1``; ``batch_size`` is derived from ``cu_seqlens_q``.
        seqused_k: optional contiguous tensor of length ``batch_size``.
        leftpad_k: must be ``None`` (not supported).
        page_table: block table of the paged cache; must not be ``None``.
        alibi_slopes: optional float32 tensor of shape ``(num_heads,)`` or
            ``(batch_size, num_heads)``.
        max_seqlen_q, max_seqlen_k: maximum query / key length in the batch.
        p_dropout: dropout probability; dropout is enabled when ``> 0``.
        softmax_scale: softmax scaling factor (folded into the soft-cap terms
            when ``softcap > 0``).
        zero_tensors: if true, zero ``out`` and fill ``lse`` with ``-inf``
            before the launch.
        is_causal: request causal masking.
        window_size_left, window_size_right: sliding-window bounds; values
            ``>= max_seqlen_k`` are treated as ``-1`` (disabled).
        softcap: soft-capping value; enabled when ``> 0``.
        return_softmax: also allocate ``p``; only allowed with dropout.
        gen: not referenced here; the dropout seed/offset come from
            ``philox_backend_seed_offset``.

    Returns:
        ``(out, q, k, v, lse, philox_args, unused, p)``. ``q`` is returned as
        the kernel saw it, i.e. in its regrouped layout when the single-query
        head-swap optimisation was applied.

    Raises:
        AssertionError: on unsupported dtype, layout, shape or option
            combination.
    """
    CHECK_DEVICE(q), CHECK_DEVICE(k), CHECK_DEVICE(v)
    q_device = q.device
    q_dtype = q.dtype
    assert q_dtype in (
        torch.float16,
        torch.bfloat16,
    ), "FlashAttention only support fp16 and bf16 data type"
    assert q_dtype == k.dtype
    assert q_dtype == v.dtype
    assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension"

    assert cu_seqlens_q.dtype == torch.int32
    assert cu_seqlens_q.is_contiguous()

    assert cu_seqlens_k.dtype == torch.int32
    assert cu_seqlens_k.is_contiguous()

    assert page_table is not None

    # q shape: [total_q_tokens, num_heads, head_size]
    # k/v shape (paged KV): [num_pages, block_size, num_heads_k, head_size]
    total_q, num_heads, head_size = q.size()
    num_heads_k = k.size(2)
    batch_size = cu_seqlens_q.numel() - 1  # number of sequences
    block_size = k.size(1)
    num_pages = k.size(0)
    k_batch_size = num_pages
    page_table_batch_stride = page_table.stride(0)

    assert k.size() == v.size()
    assert cu_seqlens_q.size() == (batch_size + 1,)
    assert cu_seqlens_k.size() == (batch_size + 1,)

    # Check output shape and placement (mirrors mha_fwd).
    if out is not None:
        assert out.stride(-1) == 1
        assert out.dtype == q.dtype
        assert out.size() == (total_q, num_heads, head_size)
        CHECK_DEVICE(out)

    if seqused_k is not None:
        assert seqused_k.is_contiguous()
        assert seqused_k.size() == (batch_size,)

    # With a single query token, causal masking is a no-op unless alibi
    # biases depend on it.
    if max_seqlen_q == 1 and alibi_slopes is None:
        is_causal = False

    if is_causal:
        window_size_right = 0

    # Windows at least as wide as the longest key sequence disable SWA.
    if window_size_left >= max_seqlen_k:
        window_size_left = -1
    if window_size_right >= max_seqlen_k:
        window_size_right = -1

    is_local = window_size_left >= 0

    # Optimize all single-query sequences by swapping the query-group and
    # sequence dimensions: the q heads sharing one kv head become "tokens".
    seqlenq_ngroups_swapped = (
        max_seqlen_q == 1
        and alibi_slopes is None
        and num_heads > num_heads_k
        and window_size_left < 0
        and window_size_right < 0
        and p_dropout == 0
    )
    q_groups = num_heads // num_heads_k
    if seqlenq_ngroups_swapped:
        q = (
            q.reshape((batch_size, num_heads_k, q_groups, head_size))
            .transpose(1, 2)
            .reshape(batch_size * q_groups, num_heads_k, head_size)
        )
        max_seqlen_q = q_groups
        num_heads = num_heads_k
        cu_seqlens_q = None
        q_batch_stride = q.stride(0) * max_seqlen_q
        k_batch_stride = k.stride(0)
        v_batch_stride = v.stride(0)
        # o_batch_stride is set once `out` has been allocated below.
    else:
        q_batch_stride = 0
        k_batch_stride = 0
        v_batch_stride = 0
        o_batch_stride = 0

    total_q = q.size(0)

    assert leftpad_k is None, "leftpad_k is not supported."
    assert (
        head_size <= 256
    ), "FlashAttention forward only supports head dimension at most 256"
    assert (
        head_size % 8 == 0
    ), "head_size must be a multiple of 8, this is ensured by padding!"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"

    assert q.shape == (total_q, num_heads, head_size)
    assert k.shape == (num_pages, block_size, num_heads_k, head_size)
    assert v.shape == (num_pages, block_size, num_heads_k, head_size)
    assert k.stride() == v.stride()

    if softcap > 0.0:
        assert p_dropout == 0, "dropout is not supported if softcap is used."

    round_multiple = lambda x, m: (x + m - 1) // m * m
    head_size_rounded = round_multiple(head_size, 32) if head_size < 192 else 256
    seqlen_q_rounded = round_multiple(max_seqlen_q, 128)
    seqlen_k_rounded = round_multiple(max_seqlen_k, 32)

    M_LOG2E = 1.4426950408889634074
    if softcap > 0.0:
        is_softcap = True
        adjusted_scale_softmax = softcap
        adjusted_softcap = softmax_scale / softcap
        adjusted_scale_softmax_log2e = softcap * M_LOG2E
    else:
        is_softcap = False
        adjusted_softcap = 0.0
        adjusted_scale_softmax = softmax_scale
        adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E

    # Set alibi params
    if alibi_slopes is not None:
        assert alibi_slopes.device == q_device
        assert alibi_slopes.dtype in (torch.float,)
        assert alibi_slopes.stride(-1) == 1
        assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (
            batch_size,
            num_heads,
        )
        alibi_slopes_batch_stride = (
            alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
        )
        is_alibi = True
    else:
        alibi_slopes_batch_stride = 0
        is_alibi = False

    # Prepare params to kernel
    with torch_device_fn.device(q_device):
        if out is not None:
            out_ = out
            if seqlenq_ngroups_swapped:
                # The kernel writes into a scratch buffer shaped like the
                # regrouped q; results are copied back into `out_` below.
                out = torch.empty_like(q, dtype=v.dtype)
        else:
            out_ = None
            out = torch.empty_like(q, dtype=v.dtype)

        if seqlenq_ngroups_swapped:
            o_batch_stride = out.stride(0) * max_seqlen_q

        lse = torch.empty((num_heads, total_q), dtype=torch.float, device=q_device)

        if p_dropout > 0:
            is_dropout = True
            increment = batch_size * num_heads * 32
            philox_seed, philox_offset = philox_backend_seed_offset(increment)
            philox_args = torch.tensor(
                [philox_seed, philox_offset], dtype=torch.int64, device=q_device
            )
        else:
            is_dropout = False
            philox_args = torch.empty((2,), dtype=torch.int64, device=q_device)

        # From here on p_dropout holds the keep probability.
        p_dropout = 1 - p_dropout
        p_dropout_in_uint8_t = math.floor(p_dropout * 255.0)
        rp_dropout = 1.0 / p_dropout

        if return_softmax:
            assert is_dropout, "Only supported with non-zero dropout."
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                device=q_device,
            )
        else:
            p = torch.empty((), device=q_device)

        if zero_tensors:
            out.zero_()
            lse.fill_(float("-inf"))

        params = fwd_params(
            q,  # q_ptr,
            k,  # k_ptr,
            v,  # v_ptr,
            out,  # o_ptr,
            p,  # p_ptr,
            lse,  # softmax_lse_ptr,
            q.stride(-3),  # q_row_stride,
            k.stride(-3),  # k_row_stride,
            v.stride(-3),  # v_row_stride,
            q.stride(-2),  # q_head_stride,
            k.stride(-2),  # k_head_stride,
            v.stride(-2),  # v_head_stride,
            out.stride(-3),  # o_row_stride,
            out.stride(-2),  # o_head_stride,
            q_batch_stride,  # q_batch_stride,
            k_batch_stride,  # k_batch_stride,
            v_batch_stride,  # v_batch_stride,
            o_batch_stride,  # o_batch_stride,
            cu_seqlens_q is not None,  # is_cu_seqlens_q,
            cu_seqlens_q,  # cu_seqlens_q_ptr,
            # NOTE(review): presumably cu_seqlens_k is ignored by the kernel
            # when seqused_k supplies per-sequence key lengths — confirm.
            seqused_k is None,  # is_cu_seqlens_k,
            cu_seqlens_k,  # cu_seqlens_k_ptr,
            seqused_k is not None,  # is_seqused_k,
            seqused_k,  # seqused_k_ptr,
            # sizes
            batch_size,  # b,
            k_batch_size,  # bk,
            num_heads,  # h,
            num_heads_k,  # hk,
            num_heads // num_heads_k,  # h_hk_ratio,
            max_seqlen_q,  # seqlen_q,
            max_seqlen_k,  # seqlen_k,
            seqlen_q_rounded,  # seqlen_q_rounded,
            seqlen_k_rounded,  # seqlen_k_rounded,
            head_size,  # d,
            head_size_rounded,  # d_rounded,
            # scaling factors
            is_softcap,
            adjusted_softcap,  # softcap,
            adjusted_scale_softmax,  # scale_softmax,
            adjusted_scale_softmax_log2e,  # scale_softmax_log2,
            # dropout
            is_dropout,
            p_dropout,
            rp_dropout,
            p_dropout_in_uint8_t,
            philox_args,
            return_softmax,
            # causal and swa
            is_causal,  # is_causal,
            is_local,  # is_local,
            window_size_left,  # window_size_left,
            window_size_right,  # window_size_right,
            seqlenq_ngroups_swapped,  # seqlenq_ngroups_swapped,
            # alibi
            is_alibi,  # is_alibi,
            alibi_slopes,  # alibi_slopes_ptr,
            alibi_slopes_batch_stride,  # alibi_slopes_batch_stride,
            # block table params
            total_q,  # total_q,
            page_table,  # page_table_ptr,
            page_table_batch_stride,  # page_table_batch_stride,
            block_size,  # block_size,
        )

        logger.debug("kernel: flash_varlen_fwd")
        grid = lambda args: (
            triton.cdiv(max_seqlen_q, args["BLOCK_M"]),
            batch_size,
            num_heads,
        )
        kernel = flash_varlen_fwd_kernel[grid]
        # Snapshot the positional kernel arguments in __slots__ order.
        args = params.args()

        # We have to forego parameter autotuning and particularly fix BLOCK_N
        # to avoid breaking a kv block onto multiple cache pages.
        cfg = runtime.get_heuristic_config("mha_varlen_fwd")
        cfg_params = {
            "BLOCK_M": cfg["BLOCK_M"](params),
            "BLOCK_N": cfg["BLOCK_N"](params),
            "num_warps": cfg["num_warps"](params),
            "num_stages": cfg["num_stages"](params),
        }
        assert (
            block_size % cfg_params["BLOCK_N"] == 0
        ), f"block_size must be divisible by {cfg_params['BLOCK_N']}."
        kernel(*args, **cfg_params)

        if seqlenq_ngroups_swapped:
            # Undo the regrouping: the kernel output is laid out as
            # (batch, q_groups, num_heads_k, head_size); move heads back
            # in front of the group dimension.
            out = out.reshape(
                batch_size, max_seqlen_q, num_heads_k, head_size
            ).transpose(1, 2)
            if out_ is not None:
                out_.view(batch_size, num_heads_k, max_seqlen_q, head_size).copy_(out)
                out = out_
            else:
                out = out.reshape(batch_size, num_heads_k * max_seqlen_q, head_size)
            lse = lse.reshape(num_heads_k, batch_size, max_seqlen_q)
            lse = lse.reshape(num_heads_k * max_seqlen_q, batch_size)

    unused = torch.empty((), dtype=torch.int64, device=q_device)
    return out, q, k, v, lse, philox_args, unused, p

551 

552 

553def mha_fwd( 

554 q, 

555 k, 

556 v, 

557 out, 

558 alibi_slopes, 

559 p_dropout, 

560 softmax_scale, 

561 is_causal, 

562 window_size_left, 

563 window_size_right, 

564 softcap, 

565 return_softmax, 

566 disable_splitkv=False, 

567): 

568 CHECK_DEVICE(q), CHECK_DEVICE(k), CHECK_DEVICE(v) 

569 q_dtype = q.dtype 

570 q_device = q.device 

571 assert q_dtype in ( 

572 torch.float16, 

573 torch.bfloat16, 

574 ), "FlashAttention only support fp16 and bf16 data type" 

575 assert q_dtype == k.dtype 

576 assert q_dtype == v.dtype 

577 assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension" 

578 assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension" 

579 assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension" 

580 batch_size, seqlen_q, num_heads, head_size = q.size() 

581 _, seqlen_k, num_heads_k, _ = k.size() 

582 

583 # Check output shape 

584 if out is not None: 

585 assert out.stride(-1) == 1 

586 assert out.dtype == q.dtype 

587 assert out.size() == (batch_size, seqlen_q, num_heads, head_size) 

588 CHECK_DEVICE(out) 

589 

590 assert ( 

591 head_size % 8 == 0 

592 ), "head_size must be a multiple of 8, this is ensured by padding!" 

593 assert ( 

594 num_heads % num_heads_k == 0 

595 ), "Number of heads in key/value must divide number of heads in query" 

596 if window_size_left >= seqlen_k: 

597 window_size_left = -1 

598 if window_size_right >= seqlen_k: 

599 window_size_right = -1 

600 if seqlen_q == 1 and alibi_slopes is None: 

601 is_causal = False 

602 if is_causal: 

603 window_size_right = 0 

604 

605 is_causal = window_size_left < 0 and window_size_right == 0 

606 is_local = window_size_left >= 0 and window_size_right >= 0 

607 

608 seqlenq_ngroups_swapped = ( 

609 seqlen_q == 1 

610 and alibi_slopes is None 

611 and num_heads > num_heads_k 

612 and window_size_left < 0 

613 and window_size_right < 0 

614 and p_dropout == 0 

615 ) 

616 q_groups = num_heads // num_heads_k 

617 

618 if seqlenq_ngroups_swapped: 

619 logger.debug("q_kg swapped.") 

620 q = q.reshape(batch_size, num_heads_k, q_groups, head_size).transpose(1, 2) 

621 seqlen_q = q_groups 

622 num_heads = num_heads_k 

623 

624 round_multiple = lambda x, m: (x + m - 1) // m * m 

625 head_size_rounded = round_multiple(head_size, 32) 

626 seqlen_q_rounded = round_multiple(seqlen_q, 128) 

627 seqlen_k_rounded = round_multiple(seqlen_k, 32) 

628 

629 def splits_heuristic(num_tasks, num_sms, n_blocks): 

630 # splits when wave efficiency is low 

631 n_waves = triton.cdiv(num_tasks, num_sms) 

632 eff = (num_tasks / num_sms) / n_waves 

633 if eff > 0.8 or n_waves > 1: 

634 return 1 

635 

636 min_blocks_per_split = 2 

637 best_splits = min( 

638 triton.cdiv(n_blocks, min_blocks_per_split), 

639 int(math.floor(1.0 / eff)), 

640 num_sms, 

641 ) 

642 

643 return best_splits 

644 

645 with torch_device_fn.device(q_device): 

646 # Set softmax params 

647 lse = torch.empty( 

648 (batch_size, num_heads, seqlen_q), dtype=torch.float, device=q_device 

649 ) 

650 

651 if out is not None: 

652 if seqlenq_ngroups_swapped: 

653 out = out.reshape( 

654 batch_size, num_heads_k, q_groups, head_size 

655 ).transpose(1, 2) 

656 else: 

657 out = torch.empty_like(q, dtype=v.dtype) 

658 

659 # Set dropout params 

660 if p_dropout > 0: 

661 is_dropout = True 

662 increment = batch_size * num_heads * 32 

663 philox_seed, philox_offset = philox_backend_seed_offset(increment) 

664 philox_args = torch.tensor( 

665 [philox_seed, philox_offset], dtype=torch.int64, device=q_device 

666 ) 

667 else: 

668 is_dropout = False 

669 philox_args = torch.empty((2,), dtype=torch.int64, device=q_device) 

670 

671 p_dropout = 1 - p_dropout 

672 p_dropout_in_uint8_t = math.floor(p_dropout * 255.0) 

673 rp_dropout = 1.0 / p_dropout 

674 

675 if return_softmax: 

676 assert is_dropout, "Only supported with non-zero dropout." 

677 p = torch.empty( 

678 (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), 

679 device=q_device, 

680 ) 

681 else: 

682 p = torch.empty((), device=q_device) 

683 

684 M_LOG2E = 1.4426950408889634074 

685 if softcap > 0.0: 

686 is_softcap = True 

687 adjusted_scale_softmax = softcap 

688 adjusted_softcap = softmax_scale / softcap 

689 adjusted_scale_softmax_log2e = softcap * M_LOG2E 

690 else: 

691 is_softcap = False 

692 adjusted_softcap = 0.0 

693 adjusted_scale_softmax = softmax_scale 

694 adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E 

695 

696 # Set alibi params 

697 if alibi_slopes is not None: 

698 assert alibi_slopes.device == q_device 

699 assert alibi_slopes.dtype in (torch.float,) 

700 assert alibi_slopes.stride(-1) == 1 

701 assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == ( 

702 batch_size, 

703 num_heads, 

704 ) 

705 alibi_slopes_batch_stride = ( 

706 alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0 

707 ) 

708 is_alibi = True 

709 else: 

710 alibi_slopes_batch_stride = 0 

711 is_alibi = False 

712 

713 # ONLY EVEN_K IS SUPPORTED 

714 assert head_size == head_size_rounded 

715 

716 # Do kernel dispatching 

717 def dispatch(B, H, Q, K, D, params): 

718 num_sms = TOTAL_CORE_NUM 

719 

720 # Try bh parallel 

721 # if B * H > 0.8 * num_sms: 

722 # kernel = flash_fwd_bh_parallel_kernel[(H, B)] 

723 # # Yield kernel and prefilled args 

724 # return kernel, default_args, None, None 

725 

726 # Try splitkv 

727 if not is_dropout and not is_local and not disable_splitkv: 

728 BM = block_m_splitkv_heuristic(D) 

729 n_tasks = B * H * triton.cdiv(seqlen_q, BM) 

730 BN = block_n_splitkv_heuristic(D) 

731 n_blocks = triton.cdiv(seqlen_k, BN) 

732 n_splits = splits_heuristic(n_tasks, num_sms, n_blocks) 

733 

734 # if _debug: 

735 # n_splits = 32 

736 # n_blocks = triton.cdiv(K, BN) 

737 # blocks_per_split = triton.cdiv(n_blocks, n_splits) 

738 # print("block_n:", BN) 

739 # print("n_splits:", n_splits) 

740 # print("blocks_per_split", blocks_per_split) 

741 

742 if n_splits > 1: 

743 logger.debug("kernel: flash_fwd_splitkv") 

744 lse_splits = torch.empty( 

745 (n_splits, B, H, Q), dtype=torch.float, device=q_device 

746 ) 

747 out_splits = torch.empty( 

748 (n_splits, B, H, Q, D), dtype=torch.float, device=q_device 

749 ) 

750 grid = lambda args: ( 

751 triton.cdiv(Q, args["BLOCK_M"]), 

752 n_splits, 

753 B * H, 

754 ) 

755 splitkv_kernel = flash_fwd_splitkv_kernel[grid] 

756 params.o_ptr = out_splits 

757 params.softmax_lse_ptr = lse_splits 

758 extra_args = {"blocks_per_split": triton.cdiv(n_blocks, n_splits)} 

759 kernel = splitkv_kernel(*params.args(), **extra_args) 

760 

761 if D % 128 == 0: 

762 BLOCK_M = 4 

763 elif D % 64 == 0: 

764 BLOCK_M = 8 

765 else: 

766 BLOCK_M = 16 

767 grid = lambda args: (triton.cdiv(B * H * Q, BLOCK_M),) 

768 combine_kernel = flash_fwd_splitkv_combine_kernel[grid] 

769 combine_args = { 

770 "out_ptr": out, 

771 "lse_ptr": lse, 

772 "head_size": head_size, 

773 "out_b_stride": out.stride(0), 

774 "out_s_stride": out.stride(-3), 

775 "out_h_stride": out.stride(-1), 

776 "out_splits_ptr": out_splits, 

777 "lse_splits_ptr": lse_splits, 

778 "n_splits": n_splits, 

779 "BLOCK_M": BLOCK_M, 

780 "q_total": B * H * Q, 

781 "MAX_N_SPLITS": triton.next_power_of_2(n_splits), 

782 } 

783 combine_kernel(**combine_args) 

784 return kernel 

785 

786 # Last option: flash_fwd 

787 logger.debug("kernel: flash_fwd") 

788 grid = lambda args: ( 

789 triton.cdiv(Q, args["BLOCK_M"]), 

790 H * B, 

791 ) 

792 kernel = flash_fwd_kernel[grid] 

793 kernel = kernel(*params.args()) 

794 return kernel 

795 

796 if _debug: 

797 p = torch.empty( 

798 (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), 

799 dtype=torch.float32, 

800 device=q_device, 

801 ) 

802 return_softmax = True 

803 

804 params = fwd_params( 

805 q, # q_ptr, 

806 k, # k_ptr, 

807 v, # v_ptr, 

808 out, # o_ptr, 

809 p, # p_ptr, 

810 lse, # softmax_lse_ptr, 

811 q.stride(-3), # q_row_stride, 

812 k.stride(-3), # k_row_stride, 

813 v.stride(-3), # v_row_stride, 

814 q.stride(-2), # q_head_stride, 

815 k.stride(-2), # k_head_stride, 

816 v.stride(-2), # v_head_stride, 

817 out.stride(-3), # o_row_stride, 

818 out.stride(-2), # o_head_stride, 

819 q.stride(0), # q_batch_stride, 

820 k.stride(0), # k_batch_stride, 

821 v.stride(0), # v_batch_stride, 

822 out.stride(0), # o_batch_stride, 

823 False, # is_cu_seqlens_q, 

824 None, # cu_seqlens_q_ptr, 

825 False, # is_cu_seqlens_k, 

826 None, # cu_seqlens_k_ptr, 

827 False, # is_seqused_k, 

828 None, # seqused_k_ptr, 

829 # sizes 

830 batch_size, # b, 

831 0, # bk, 

832 num_heads, # h, 

833 num_heads_k, # hk, 

834 num_heads // num_heads_k, # h_hk_ratio, 

835 seqlen_q, # seqlen_q, 

836 seqlen_k, # seqlen_k, 

837 seqlen_q_rounded, # seqlen_q_rounded, 

838 seqlen_k_rounded, # seqlen_k_rounded, 

839 head_size, # d, 

840 head_size_rounded, # d_rounded, 

841 # scaling factors 

842 is_softcap, 

843 adjusted_softcap, # softcap, 

844 adjusted_scale_softmax, # scale_softmax, 

845 adjusted_scale_softmax_log2e, # scale_softmax_log2, 

846 # dropout 

847 is_dropout, 

848 p_dropout, 

849 rp_dropout, 

850 p_dropout_in_uint8_t, 

851 philox_args, 

852 return_softmax, 

853 # causal and swa 

854 is_causal, # is_causal, 

855 is_local, # is_local, 

856 window_size_left, # window_size_left, 

857 window_size_right, # window_size_right, 

858 seqlenq_ngroups_swapped, # seqlenq_ngroups_swapped, 

859 # alibi 

860 is_alibi, # 

861 alibi_slopes, # alibi_slopes_ptr, 

862 alibi_slopes_batch_stride, # alibi_slopes_batch_stride, 

863 # block table params 

864 0, # total_q, 

865 None, # page_table_ptr, 

866 0, # page_table_batch_stride, 

867 0, # block_size, 

868 ) 

869 

870 kernel = dispatch(batch_size, num_heads, seqlen_q, seqlen_k, head_size, params) 

871 

872 if _debug: 

873 print(f"{kernel.name} shared memory:", kernel.metadata.shared) 

874 print(f"{kernel.name} num_warps:", kernel.metadata.num_warps) 

875 print(f"{kernel.name} num_stages:", kernel.metadata.num_stages) 

876 # print(kernel.asm['ttgir']) 

877 

878 if seqlenq_ngroups_swapped: 

879 out = out.transpose(1, 2).reshape( 

880 (batch_size, 1, num_heads_k * seqlen_q, head_size) 

881 ) 

882 q = q.transpose(1, 2).reshape( 

883 (batch_size, 1, num_heads_k * seqlen_q, head_size) 

884 ) 

885 lse = lse.reshape((batch_size, num_heads_k * seqlen_q, 1)) 

886 

887 unused = torch.empty((), dtype=torch.int64, device=q_device) 

888 

889 return out, q, k, v, lse, philox_args, unused, p