Coverage for src/flag_gems/ops/flash_kernel.py: 13%

574 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import triton 

2import triton.language as tl 

3 

4from flag_gems import runtime 

5from flag_gems.utils import libentry, tl_extra_shim 

6 

7 

@triton.jit
def u64_to_lohi(x):
    """Split a 64-bit value into two 32-bit unsigned words.

    NOTE: despite the "lohi" name, the upper word (x >> 32) is returned
    first and the lower word (x & 0xFFFFFFFF) second. philox_ unpacks the
    pair in exactly this order, so the order must not be swapped here.
    """
    upper_word = (x >> 32).to(tl.uint32)
    lower_word = (x & 0xFFFFFFFF).to(tl.uint32)
    return upper_word, lower_word

11 

12 

@triton.jit
def u64_from_lohi(lo, hi):
    """Combine two 32-bit words into one unsigned 64-bit value.

    Args:
        lo: word placed in bits 0..31 of the result.
        hi: word placed in bits 32..63 of the result.

    Returns:
        (hi << 32) + lo, computed in tl.uint64.
    """
    # The shift must be parenthesised: `+` binds tighter than `<<`, so the
    # unparenthesised form `hi << 32 + lo` evaluated `hi << (32 + lo)`.
    return (hi.to(tl.uint64) << 32) + lo.to(tl.uint64)

16 

17 

@triton.jit
def philox_(seed, subsequence, offset):
    """Generate four 32-bit pseudo-random words with a Philox-style mixer.

    The key is taken from `seed` and the four counter words from `offset`
    and `subsequence` (each split with u64_to_lohi). Six rounds are run that
    also advance the key, followed by one final round with the key left
    untouched.

    Returns:
        The four counter words (c0, c1, c2, c3) after the last round.
    """
    # Per-round key increments.
    key_step_a: tl.constexpr = 0x9E3779B9
    key_step_b: tl.constexpr = 0xBB67AE85
    # Round multipliers.
    mul_a: tl.constexpr = 0xD2511F53
    mul_b: tl.constexpr = 0xCD9E8D57

    k0, k1 = u64_to_lohi(seed.to(tl.uint64))
    c0, c1 = u64_to_lohi(offset.to(tl.uint64))
    c2, c3 = u64_to_lohi(subsequence.to(tl.uint64))

    # static_range unrolls the six key-advancing rounds at compile time.
    for _ in tl.static_range(6):
        prod_a = mul_a * c0.to(tl.uint64)
        prod_b = mul_b * c2.to(tl.uint64)
        a_first, a_second = u64_to_lohi(prod_a)
        b_first, b_second = u64_to_lohi(prod_b)
        c0, c1, c2, c3 = b_second ^ c1 ^ k0, b_first, a_second ^ c3 ^ k1, a_first
        k0 += key_step_a
        k1 += key_step_b

    # Final round: same mixing step, but the key is not advanced afterwards.
    prod_a = mul_a * c0.to(tl.uint64)
    prod_b = mul_b * c2.to(tl.uint64)
    a_first, a_second = u64_to_lohi(prod_a)
    b_first, b_second = u64_to_lohi(prod_b)
    c0, c1, c2, c3 = b_second ^ c1 ^ k0, b_first, a_second ^ c3 ^ k1, a_first

    return c0, c1, c2, c3

45 

46 

@triton.jit
def apply_dropout_mask(
    P,
    mask,
    encode_dropout_in_sign_bit: tl.constexpr,
):
    """Apply a dropout mask to P.

    Where `mask` is set, the element is either negated (when
    `encode_dropout_in_sign_bit` is true) or replaced by zero in P's dtype.
    Elements where `mask` is not set are returned unchanged.
    """
    if encode_dropout_in_sign_bit:
        masked_value = -P
    else:
        masked_value = (P * 0).to(P.dtype)
    P = tl.where(mask, masked_value, P)
    return P

58 

59 

@triton.jit
def apply_dropout(
    P,
    row_start,
    col_start,
    n_cols,
    bid,
    hid,
    philox_seed,
    philox_offset,
    p_dropout_uint8: tl.constexpr,
    is_dropout: tl.constexpr,
    encode_dropout_in_sign_bit: tl.constexpr,
    NUM_HEADS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Apply Philox-based dropout to a (BLOCK_M, BLOCK_N) tile P.

    When `is_dropout` is false, P is returned untouched. Otherwise one
    Philox call is made per (row, group of 4 columns); its four output words
    are joined and reshaped into a (BLOCK_M, BLOCK_N) tile of random words.
    An element is masked when the low byte of its word is
    >= `p_dropout_uint8`; masked elements are handled by apply_dropout_mask.
    """
    if is_dropout:
        row_start = tl.multiple_of(row_start, BLOCK_M)
        col_start = tl.multiple_of(col_start, BLOCK_N)
        rows = row_start + tl.arange(0, BLOCK_M)[:, None]
        # Column indices are scaled down by 4: each Philox call yields 4 words.
        col_groups = col_start // 4 + tl.arange(0, BLOCK_N // 4)[None, :]

        subsequence = rows.to(tl.uint64) * n_cols + col_groups.to(tl.uint64)

        offset = philox_offset + bid * NUM_HEADS + hid
        # Adding a zero tile gives the offset the same shape as `subsequence`.
        offset += subsequence * 0
        w0, w1, w2, w3 = philox_(philox_seed, subsequence, offset)

        words = tl.join(tl.join(w0, w1), tl.join(w2, w3)).reshape(BLOCK_M, BLOCK_N)

        drop = (words & 0xFF) >= p_dropout_uint8

        P = apply_dropout_mask(
            P, drop, encode_dropout_in_sign_bit=encode_dropout_in_sign_bit
        )
    return P

98 

99 

@triton.jit
def apply_alibi(
    S,
    col_idx,
    row_idx,
    max_seqlen_q,
    max_seqlen_k,
    is_causal: tl.constexpr,
    is_alibi: tl.constexpr,
    alibi_slope: tl.constexpr = None,
):
    """Add the ALiBi bias to the score tile S when `is_alibi` is set.

    Causal case: a row-independent bias slope * (col - max_seqlen_k + 1).
    Non-causal case: -slope * |col - max_seqlen_k + max_seqlen_q - row|.
    When `is_alibi` is false, S is returned unchanged.
    """
    if is_alibi:
        if is_causal:
            # The row-independent alibi bias renders the same attention output
            # as the standard alibi because softmax is shift invariant, i.e.,
            # softmax(A + bias + const) = softmax(A + bias). The following two
            # biases are no different if causal is true.
            # bias_1 = [
            #    -4, -3, -2, X, X,
            #    -4, -3, -2, -1, X,
            #    -4, -3, -2, -1, 0,
            # ]
            # bias_2 = [
            #    -2, -1, 0, X, X,
            #    -3, -2, -1, 0, X,
            #    -4, -3, -2, -1, 0,
            # ]
            col_dist = (-max_seqlen_k + 1 + col_idx[None, :]).to(tl.float32)
            bias = alibi_slope * col_dist
        else:
            diag_dist = tl.abs(
                col_idx[None, :] - max_seqlen_k + max_seqlen_q - row_idx[:, None]
            ).to(tl.float32)
            bias = -alibi_slope * diag_dist
        S += bias

    return S

136 

137 

@triton.jit
def apply_mask(
    S,
    col_idx,
    row_idx,
    max_seqlen_q,
    max_seqlen_k,
    window_size_left,
    window_size_right,
    is_even_mn: tl.constexpr,
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
):
    """Set out-of-window / out-of-range score entries of S to -inf.

    - causal: columns to the right of the per-row right bound are masked;
    - local: columns outside [left bound, right bound] are masked;
    - neither, but the tile is not even: columns >= max_seqlen_k are masked.
    When none of these apply, S is returned unchanged.
    """
    need_mask = is_causal | is_local | (not is_even_mn)
    if need_mask:
        # Take care to avoid off-by-one errors: both col_lb and col_rb are
        # inclusive bounds.
        diag = row_idx + max_seqlen_k - max_seqlen_q
        col_lb = max(0, diag - window_size_left)
        col_rb = min(max_seqlen_k - 1, diag + window_size_right)

        cols = col_idx[None, :]

        if is_causal:
            S = tl.where(cols > col_rb[:, None], float("-inf"), S)

        if is_local:
            outside_window = (cols > col_rb[:, None]) | (cols < col_lb[:, None])
            S = tl.where(outside_window, float("-inf"), S)

        if (not is_local) & (not is_causal) & (not is_even_mn):
            S = tl.where(cols >= max_seqlen_k, float("-inf"), S)

    return S

175 

176 

@triton.jit
def softmax_rescale(
    O_acc,
    S,
    row_max,
    row_sum,
    softmax_scale_log2e: tl.constexpr,
    is_border: tl.constexpr,
    # is_init: tl.constexpr
):
    """Perform one online-softmax update step for a new score tile S.

    The running row maximum is updated with the row-wise max of S, the
    running sum and the output accumulator are rescaled by
    exp2((old_max - new_max) * softmax_scale_log2e), and the tile of
    exponentials P = exp2(S * scale - new_max * scale) is added to the sum.

    Returns:
        (O_acc, P, row_max, row_sum) after the update.
    """
    new_max = tl.maximum(row_max, tl.max(S, 1))

    # On border tiles a whole row can be masked (max == -inf); use 0 as the
    # reference there so the rescale factor does not become -inf - (-inf).
    if is_border:
        ref_max = tl.where(new_max == float("-inf"), 0, new_max)
    else:
        ref_max = new_max

    rescale = tl.math.exp2((row_max - ref_max) * softmax_scale_log2e)
    row_sum = row_sum * rescale
    O_acc = O_acc * rescale[:, None]

    shift = tl.where(new_max == float("-inf"), 0, new_max * softmax_scale_log2e)
    P = tl.math.exp2(S * softmax_scale_log2e - shift[:, None])
    row_sum = row_sum + tl.sum(P, 1)
    return O_acc, P, new_max, row_sum

203 

204 

@triton.jit
def apply_softcap(S, softcap, is_softcap: tl.constexpr):
    """Return tanh(S * softcap) when `is_softcap` is set, otherwise S unchanged."""
    if is_softcap:
        scaled = S * softcap
        S = tl_extra_shim.tanh(scaled)

    return S

211 

212 

def block_m_splitkv_heuristic(headdim):
    """Return the split-KV BLOCK_M: 128 for head dims up to 128, otherwise 64."""
    if headdim <= 128:
        return 128
    return 64

215 

216 

def block_n_splitkv_heuristic(headdim):
    """Return the split-KV BLOCK_N: 64 for head dims up to 64, otherwise 32."""
    if headdim <= 64:
        return 64
    return 32

219 

220 

def is_even_mn(M, N, BM, BN, WL, WR):
    """Return True when the (M, N) problem tiles evenly.

    All of the following must hold:
      * M is a multiple of BM and N is a multiple of BN;
      * one of M, N divides the other;
      * each window size (WL, WR) is either -1 or a multiple of BN.
    """
    if M % BM != 0 or N % BN != 0:
        return False
    if M % N != 0 and N % M != 0:
        return False
    left_ok = WL == -1 or WL % BN == 0
    if not left_ok:
        return False
    right_ok = WR == -1 or WR % BN == 0
    return right_ok

227 

228 

def block_m_splitkv_heuristic_spec_args(args):
    """Return the split-KV BLOCK_M from a kernel-argument dict.

    Reads the head dimension from args["d"]: 128 when it is at most 128,
    otherwise 64.
    """
    head_dim = args["d"]
    if head_dim <= 128:
        return 128
    return 64

231 

232 

def block_n_splitkv_heuristic_spec_args(args):
    """Return the split-KV BLOCK_N from a kernel-argument dict.

    Reads the head dimension from args["d"]: 64 when it is at most 64,
    otherwise 32.
    """
    head_dim = args["d"]
    if head_dim <= 64:
        return 64
    return 32

235 

236 

def is_even_mn_spec_args(args):
    """Return True when the problem described by `args` tiles evenly.

    Reads "seqlen_q", "seqlen_k", "BLOCK_M", "BLOCK_N", "window_size_left"
    and "window_size_right" from `args`. All of the following must hold:
      * seqlen_q is a multiple of BLOCK_M and seqlen_k a multiple of BLOCK_N;
      * one of seqlen_q, seqlen_k divides the other;
      * each window size is either -1 or a multiple of BLOCK_N.
    """
    if args["seqlen_q"] % args["BLOCK_M"] != 0:
        return False
    if args["seqlen_k"] % args["BLOCK_N"] != 0:
        return False

    if (
        args["seqlen_q"] % args["seqlen_k"] != 0
        and args["seqlen_k"] % args["seqlen_q"] != 0
    ):
        return False

    left_ok = (
        args["window_size_left"] == -1
        or args["window_size_left"] % args["BLOCK_N"] == 0
    )
    if not left_ok:
        return False

    right_ok = (
        args["window_size_right"] == -1
        or args["window_size_right"] % args["BLOCK_N"] == 0
    )
    return right_ok

255 

256 

def keep(cfg, must_keep=None):
    """Decide whether an autotune config should be kept.

    A config is kept when its (BLOCK_M, BLOCK_N, num_warps) triple is
    (128, 32, 4) or (128, 128, 8). Otherwise the result is
    `must_keep and cfg in must_keep`, i.e. a falsy `must_keep` is returned
    as-is and a non-empty one yields the membership test.
    """
    shape = (cfg.kwargs["BLOCK_M"], cfg.kwargs["BLOCK_N"], cfg.num_warps)
    allowed_shapes = ((128, 32, 4), (128, 128, 8))
    if shape in allowed_shapes:
        return True

    # we always keep configurations in `must_keep`
    return must_keep and cfg in must_keep

266 

267 

def prune_fwd_configs(configs, nargs, **kwargs):
    """Prune forward-kernel autotune configs.

    When nargs["is_dropout"] is truthy, only configs with num_warps == 4
    and num_stages < 4 are returned (as a new list). Otherwise `configs`
    is returned unchanged.
    """
    if not nargs["is_dropout"]:
        return configs
    return [cfg for cfg in configs if cfg.num_warps == 4 and cfg.num_stages < 4]

276 

277 

def flash_fwd_kernel_heur_block_k(args):
    """Return BLOCK_K for the forward kernel: triton.next_power_of_2(args["d"])."""
    head_dim = args["d"]
    return triton.next_power_of_2(head_dim)

280 

281 

@libentry()
@triton.autotune(
    configs=list(filter(keep, runtime.get_tuned_config("attention"))),
    prune_configs_by={"early_config_prune": prune_fwd_configs},
    key=["d", "is_dropout"],
)
@triton.heuristics(
    values={
        "BLOCK_K": flash_fwd_kernel_heur_block_k,
        "PRE_LOAD_V": lambda args: False,
        "IS_EVEN_MN": lambda args: is_even_mn(
            args["seqlen_q"],
            args["seqlen_k"],
            args["BLOCK_M"],
            args["BLOCK_N"],
            args["window_size_left"],
            args["window_size_right"],
        ),
    }
)
@triton.jit(
    do_not_specialize=["seqlen_q", "seqlen_k", "seqlen_q_rounded", "seqlen_k_rounded"]
)
def flash_fwd_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k,
    cu_seqlens_k_ptr,
    is_seqused_k,
    seqused_k_ptr,
    # sizes
    b: tl.constexpr,
    bk: tl.constexpr,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    is_paged: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q: tl.constexpr,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    k_page_stride: tl.constexpr,
    # kernel params
    IS_EVEN_MN: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    """Flash-attention forward kernel.

    Each program handles one BLOCK_M-row block of queries (program_id(0))
    for one (batch, head) pair (program_id(1) = batch * h + head). Keys are
    visited in BLOCK_N-column tiles: first the tiles that need masking
    (walked right to left from col_max), then the remaining unmasked tiles.
    An online softmax keeps a running row max / row sum, and the normalised
    output plus the log-sum-exp per row are written at the end. When
    `return_softmax` and `is_dropout` are set, the dropout-encoded
    probability tiles are also stored through `p_ptr`.

    NOTE(review): V is addressed with the K strides (k_batch_stride,
    k_head_stride, k_row_stride); the v_* stride arguments are not used.
    This is only correct when K and V share a layout -- confirm with callers.
    """
    m_block = tl.program_id(0)
    bh = tl.program_id(1)
    hid = bh % h
    bid = bh // h
    num_m_blocks = tl.cdiv(seqlen_q, BLOCK_M)

    # We draw a minimum covering frame on the attention map that this CTA is assigned to process.
    # The frame edges are rounded to multiples of BLOCK_M and BLOCK_N for rows and columns respectively.

    col_min = 0
    if is_local:
        col_min = max(0, m_block * BLOCK_M + seqlen_k - seqlen_q - window_size_left)
        if not IS_EVEN_MN:
            # round left
            col_min = (col_min // BLOCK_N) * BLOCK_N

    col_max = seqlen_k
    if is_causal or is_local:
        col_max += (m_block - num_m_blocks + 1) * BLOCK_M
        if is_local:
            col_max += window_size_right
        col_max = min(seqlen_k, col_max)

    if not IS_EVEN_MN:
        # round right
        col_max = tl.cdiv(col_max, BLOCK_N) * BLOCK_N

    # Number of right-most columns that go through the masking loop.
    if (not is_causal) and (not is_local):
        if IS_EVEN_MN:
            masking_cols: tl.constexpr = 0
        else:
            masking_cols: tl.constexpr = BLOCK_N
    elif (
        is_causal | is_local
    ) and IS_EVEN_MN:  # causal implies window_size_right is zero
        masking_cols: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) * BLOCK_N
    else:
        # local
        masking_cols: tl.constexpr = (tl.cdiv(BLOCK_M, BLOCK_N) + 1) * BLOCK_N

    if is_dropout:
        philox_seed = tl.load(philox_args).to(tl.uint64)
        philox_offset = tl.load(philox_args + 1).to(tl.uint64)

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0.0

    q_batch_stride = tl.multiple_of(q_batch_stride, d * h)
    q_ptr += bid * q_batch_stride + hid * q_head_stride
    row_start = m_block * BLOCK_M
    row_idx = row_start + tl.arange(0, BLOCK_M)
    q_off = row_idx[:, None] * q_row_stride + tl.arange(0, BLOCK_K)[None, :]
    dmask = tl.arange(0, BLOCK_K) < d
    # (BLOCK_M, 1) row-validity mask; combined with dmask for Q/O tiles and
    # with the key mask for P tiles.
    row_in_range = row_idx[:, None] < seqlen_q
    qmask = dmask[None, :] & row_in_range
    # The comparison must be parenthesised: `&` binds tighter than `==`, so
    # `IS_EVEN_MN & d == BLOCK_K` evaluated `(IS_EVEN_MN & d) == BLOCK_K`.
    if IS_EVEN_MN & (d == BLOCK_K):
        Q = tl.load(q_ptr + q_off, cache_modifier=".cg")
    else:
        # `other=0.0` gives masked-off lanes a defined value.
        Q = tl.load(q_ptr + q_off, mask=qmask, other=0.0, cache_modifier=".cg")

    if return_softmax:
        p_ptr += (
            (bid * h + hid) * seqlen_q_rounded + m_block * BLOCK_M
        ) * seqlen_k_rounded
        p_offset = tl.arange(0, BLOCK_M)[:, None] * seqlen_k_rounded + tl.arange(
            0, BLOCK_N
        )
        p_bp0 = p_ptr + p_offset

    acc_ = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    k_batch_stride = tl.multiple_of(k_batch_stride, d * hk)
    h_hk_ratio = h // hk
    k_ptr += bid * k_batch_stride
    k_ptr += (hid // h_hk_ratio) * k_head_stride
    v_ptr += bid * k_batch_stride
    v_ptr += (hid // h_hk_ratio) * k_head_stride

    # K is addressed transposed, as a (BLOCK_K, BLOCK_N) tile, so that
    # tl.dot(Q, K) yields the (BLOCK_M, BLOCK_N) score tile.
    k_offset = (
        tl.arange(0, BLOCK_N)[None, :] * k_row_stride + tl.arange(0, BLOCK_K)[:, None]
    )
    v_offset = (
        tl.arange(0, BLOCK_N)[:, None] * k_row_stride + tl.arange(0, BLOCK_K)[None, :]
    )

    p_bk0 = k_ptr + k_offset
    p_bv0 = v_ptr + v_offset

    if is_causal | is_local | (not IS_EVEN_MN):
        # Cut short masking cols if there's not enough cols out there
        masking_cols = min(col_max - col_min, masking_cols)
        for col_shift in tl.range(0, masking_cols, step=BLOCK_N):
            col_start = col_max - col_shift - BLOCK_N
            col_start = tl.multiple_of(col_start, BLOCK_N)
            off = col_start * k_row_stride
            if IS_EVEN_MN & (d == BLOCK_K):
                K = tl.load(p_bk0 + off, cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(p_bv0 + off, cache_modifier=".cg")
            elif d == BLOCK_K:
                col_idx = col_start + tl.arange(0, BLOCK_N)
                kvmask = col_idx < seqlen_k
                K = tl.load(
                    p_bk0 + off,
                    mask=kvmask[None, :],
                    other=0.0,
                    cache_modifier=".cg",
                )
                if PRE_LOAD_V:
                    V = tl.load(
                        p_bv0 + off,
                        mask=kvmask[:, None],
                        other=0.0,
                        cache_modifier=".cg",
                    )
            else:
                col_idx = col_start + tl.arange(0, BLOCK_N)
                kvmask = col_idx < seqlen_k
                K = tl.load(
                    p_bk0 + off,
                    mask=kvmask[None, :] & dmask[:, None],
                    other=0.0,
                    cache_modifier=".cg",
                )
                if PRE_LOAD_V:
                    V = tl.load(
                        p_bv0 + off,
                        mask=kvmask[:, None] & dmask[None, :],
                        other=0.0,
                        cache_modifier=".cg",
                    )
            S = tl.dot(Q, K, allow_tf32=False)
            S = apply_softcap(S, softcap, is_softcap)
            col_idx = col_start + tl.arange(0, BLOCK_N)
            row_idx = row_start + tl.arange(0, BLOCK_M)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            # tl.store(p_bp0 + col_start, S)
            S = apply_mask(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                window_size_left,
                window_size_right,
                is_even_mn=IS_EVEN_MN,
                is_causal=is_causal,
                is_local=is_local,
            )

            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=(is_causal or is_local),
            )
            P = P.to(v_ptr.type.element_ty)

            if is_dropout:
                if return_softmax:
                    P_drop = P

                    # Dropped entries are encoded in the sign bit of the
                    # stored probabilities.
                    P_drop = apply_dropout(
                        P_drop,
                        row_start,
                        col_start,
                        seqlen_k,
                        bid,
                        hid,
                        philox_seed,
                        philox_offset,
                        p_dropout_in_uint8_t,
                        is_dropout,
                        encode_dropout_in_sign_bit=True,
                        NUM_HEADS=h,
                        BLOCK_M=BLOCK_M,
                        BLOCK_N=BLOCK_N,
                    )
                    if IS_EVEN_MN:
                        tl.store(p_bp0 + col_start, P_drop)
                    else:
                        kvmask = col_idx < seqlen_k
                        # The P tile is (BLOCK_M, BLOCK_N); qmask has BLOCK_K
                        # columns, so only the row part of it applies here.
                        tl.store(
                            p_bp0 + col_start,
                            P_drop,
                            mask=row_in_range & kvmask[None, :],
                        )

                P = apply_dropout(
                    P,
                    row_start,
                    col_start,
                    seqlen_k,
                    bid,
                    hid,
                    philox_seed,
                    philox_offset,
                    p_dropout_in_uint8_t,
                    is_dropout,
                    encode_dropout_in_sign_bit=False,
                    NUM_HEADS=h,
                    BLOCK_M=BLOCK_M,
                    BLOCK_N=BLOCK_N,
                )

            if not PRE_LOAD_V:
                off = col_start * k_row_stride
                if IS_EVEN_MN & (d == BLOCK_K):
                    V = tl.load(p_bv0 + off, cache_modifier=".cg")
                elif d == BLOCK_K:
                    kvmask = col_idx < seqlen_k
                    V = tl.load(
                        p_bv0 + off,
                        mask=kvmask[:, None],
                        other=0.0,
                        cache_modifier=".cg",
                    )
                else:
                    kvmask = col_idx < seqlen_k
                    V = tl.load(
                        p_bv0 + off,
                        mask=kvmask[:, None] & dmask[None, :],
                        other=0.0,
                        cache_modifier=".cg",
                    )
            acc_ = tl.dot(P, V, acc_, allow_tf32=False)

    # Remaining tiles: no sequence-boundary masking is needed here, only the
    # head-dim mask when d != BLOCK_K (and the local window, via apply_mask).
    for col_start in tl.range(
        col_min, col_max - masking_cols, step=BLOCK_N, num_stages=num_stages
    ):
        col_start = tl.multiple_of(col_start, BLOCK_N)
        off = col_start * k_row_stride
        if d == BLOCK_K:
            K = tl.load(p_bk0 + off, cache_modifier=".cg")
            if PRE_LOAD_V:
                V = tl.load(p_bv0 + off, cache_modifier=".cg")
        else:
            K = tl.load(
                p_bk0 + off, mask=dmask[:, None], other=0.0, cache_modifier=".cg"
            )
            if PRE_LOAD_V:
                V = tl.load(
                    p_bv0 + off, mask=dmask[None, :], other=0.0, cache_modifier=".cg"
                )

        S = tl.dot(Q, K)
        S = apply_softcap(S, softcap, is_softcap)
        col_idx = col_start + tl.arange(0, BLOCK_N)
        row_idx = row_start + tl.arange(0, BLOCK_M)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            seqlen_q,
            seqlen_k,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            seqlen_q,
            seqlen_k,
            window_size_left,
            window_size_right,
            is_even_mn=True,
            is_causal=False,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=is_local,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            if return_softmax:
                P_drop = P
                P_drop = apply_dropout(
                    P_drop,
                    row_start,
                    col_start,
                    seqlen_k,
                    bid,
                    hid,
                    philox_seed,
                    philox_offset,
                    p_dropout_in_uint8_t,
                    is_dropout,
                    encode_dropout_in_sign_bit=True,
                    NUM_HEADS=h,
                    BLOCK_M=BLOCK_M,
                    BLOCK_N=BLOCK_N,
                )
                if IS_EVEN_MN:
                    tl.store(p_bp0 + col_start, P_drop)
                else:
                    kvmask = col_idx < seqlen_k
                    # Row mask only: the P tile has BLOCK_N columns, not BLOCK_K.
                    tl.store(
                        p_bp0 + col_start,
                        P_drop,
                        mask=row_in_range & kvmask[None, :],
                    )

            P = apply_dropout(
                P,
                row_start,
                col_start,
                seqlen_k,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )

        if not PRE_LOAD_V:
            off = col_start * k_row_stride
            if d == BLOCK_K:
                V = tl.load(p_bv0 + off, cache_modifier=".cg")
            else:
                V = tl.load(
                    p_bv0 + off, mask=dmask[None, :], other=0.0, cache_modifier=".cg"
                )
        acc_ = tl.dot(P, V, acc_)

    # LSE
    # Note, rowsum = exp(-rowmax) * exp(lse), therefore rowmax + log(rowsum) cancels
    # the effect of rowmax and outputs lse only.
    # Rows whose sum is zero or NaN get lse = +inf and are left unscaled.
    # `(rowsum_ == 0)` must be parenthesised: `|` binds tighter than `==`, so
    # the unparenthesised form never detected NaN sums.
    invalid_sum = (rowsum_ == 0) | (rowsum_ != rowsum_)
    lse = tl.where(
        invalid_sum,
        float("inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where(invalid_sum, 1.0, 1.0 / rowsum_)

    if is_dropout:
        acc_ *= inv_sum[:, None] * rp_dropout
    else:
        acc_ *= inv_sum[:, None]

    out = acc_.to(o_ptr.type.element_ty)  # noqa

    # Write back output
    o_batch_stride = tl.multiple_of(o_batch_stride, d * h)
    o_ptr += bid * o_batch_stride
    o_ptr += hid * o_head_stride
    o_offset = row_idx[:, None] * o_row_stride + tl.arange(0, BLOCK_K)

    if IS_EVEN_MN & (d == BLOCK_K):
        tl.store(o_ptr + o_offset, out)
    else:
        tl.store(o_ptr + o_offset, out, mask=qmask)

    # Write back lse
    p_lse = softmax_lse_ptr + (bid * h + hid) * seqlen_q
    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)

    if IS_EVEN_MN:
        tl.store(p_lse + row_idx, lse)
    else:
        tl.store(p_lse + row_idx, lse, mask=row_idx < seqlen_q)

740 

741 

@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k"])
def flash_fwd_bh_parallel_kernel():
    """Placeholder kernel: not implemented yet (TODO)."""
    pass

746 

747 

def flash_fwd_splitkv_kernel_heur_block_k(args):
    """Return BLOCK_K for the split-KV kernel: triton.next_power_of_2(args["d"])."""
    head_dim = args["d"]
    return triton.next_power_of_2(head_dim)

750 

751 

@libentry()
@triton.heuristics(
    values={
        "BLOCK_M": block_m_splitkv_heuristic_spec_args,
        "BLOCK_N": block_n_splitkv_heuristic_spec_args,
        "BLOCK_K": flash_fwd_splitkv_kernel_heur_block_k,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
        "PRE_LOAD_V": lambda args: True,
        "IS_EVEN_MN": is_even_mn_spec_args,
    }
)
@triton.jit(
    do_not_specialize=["seqlen_q", "seqlen_k", "seqlen_q_rounded", "seqlen_k_rounded"]
)
def flash_fwd_splitkv_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k: tl.constexpr,
    cu_seqlens_k_ptr,
    is_seqused_k: tl.constexpr,
    seqused_k_ptr,
    # sizes
    b: tl.constexpr,
    bk: tl.constexpr,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    is_paged: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    k_page_stride: tl.constexpr,
    # kernel params
    IS_EVEN_MN: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    blocks_per_split: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    """Split-KV flash-attention forward kernel.

    Grid is (seq_block, split, batch * head). Each program processes the
    key blocks [split_id * blocks_per_split, (split_id + 1) * blocks_per_split)
    for one BLOCK_M-row block of queries, and writes a per-split normalised
    partial output and per-split log-sum-exp. A split with no contribution
    for a row writes lse = -inf for it.

    NOTE(review): V is addressed with the K strides (k_batch_stride,
    k_head_stride, k_row_stride); the v_* stride arguments are not used.
    This is only correct when K and V share a layout -- confirm with callers.
    """
    m_block = tl.program_id(0)
    split_id = tl.program_id(1)
    bid = tl.program_id(2) // h
    hid = tl.program_id(2) % h

    split_block_min = split_id * blocks_per_split
    split_block_max = split_block_min + blocks_per_split

    n_block_max = tl.cdiv(seqlen_k, BLOCK_N)
    if is_causal:
        n_block_max = min(
            n_block_max,
            tl.cdiv(
                (m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + window_size_right,
                BLOCK_N,
            ),
        )

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0

    # First key block (inclusive) that needs masking.
    if not is_causal:
        if IS_EVEN_MN:
            masking_block_min = n_block_max
        else:
            masking_block_min = n_block_max - 1
    elif is_causal and IS_EVEN_MN:  # causal implies window_size_right is zero
        masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N)
    else:
        masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N) - 1

    q_ptr += bid * q_batch_stride
    q_ptr += hid * q_head_stride
    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
    q_off = row_idx[:, None] * q_row_stride + tl.arange(0, BLOCK_K)[None, :]
    p_qm = q_ptr + q_off
    dmask = tl.arange(0, BLOCK_K) < d
    qmask = dmask[None, :] & (row_idx[:, None] < seqlen_q)
    # The comparison must be parenthesised: `&` binds tighter than `==`, so
    # `IS_EVEN_MN & BLOCK_K == d` evaluated `(IS_EVEN_MN & BLOCK_K) == d`.
    if IS_EVEN_MN & (BLOCK_K == d):
        Q = tl.load(p_qm, cache_modifier=".cg")
    else:
        # `other=0.0` gives masked-off lanes a defined value.
        Q = tl.load(p_qm, mask=qmask, other=0.0, cache_modifier=".cg")

    h_hk_ratio = h // hk
    k_ptr += bid * k_batch_stride
    k_ptr += (hid // h_hk_ratio) * k_head_stride
    v_ptr += bid * k_batch_stride
    v_ptr += (hid // h_hk_ratio) * k_head_stride

    # K is addressed transposed, as a (BLOCK_K, BLOCK_N) tile.
    k_offset = (
        tl.arange(0, BLOCK_N)[None, :] * k_row_stride + tl.arange(0, BLOCK_K)[:, None]
    )
    p_k0 = k_ptr + k_offset

    v_offset = (
        tl.arange(0, BLOCK_N)[:, None] * k_row_stride + tl.arange(0, BLOCK_K)[None, :]
    )
    p_v0 = v_ptr + v_offset

    acc_ = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    if split_block_max <= masking_block_min:
        # no masking needed
        for n_block in tl.range(
            split_block_min, split_block_max, num_stages=num_stages
        ):
            kv_off = n_block * BLOCK_N * k_row_stride
            if d == BLOCK_K:
                K = tl.load(p_k0 + kv_off, cache_modifier=".cg")
            else:
                K = tl.load(
                    p_k0 + kv_off, mask=dmask[:, None], cache_modifier=".cg", other=0.0
                )
            if PRE_LOAD_V:
                if d == BLOCK_K:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
                else:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :],
                        cache_modifier=".cg",
                        other=0.0,
                    )
            S = tl.dot(Q, K)
            S = apply_softcap(S, softcap, is_softcap)
            col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
            row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=False,
            )

            if not PRE_LOAD_V:
                if d == BLOCK_K:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
                else:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :],
                        cache_modifier=".cg",
                        other=0.0,
                    )
            P = P.to(v_ptr.type.element_ty)
            acc_ = tl.dot(P, V, acc_)
    else:
        # At least one block of this split needs masking; stop at n_block_max.
        for n_block in tl.range(split_block_min, min(split_block_max, n_block_max)):
            kv_off = n_block * BLOCK_N * k_row_stride
            col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
            row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
            if IS_EVEN_MN & (d == BLOCK_K):
                K = tl.load(p_k0 + kv_off, cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
            elif d == BLOCK_K:
                kvmask = col_idx < seqlen_k
                K = tl.load(
                    p_k0 + kv_off,
                    mask=kvmask[None, :],
                    cache_modifier=".cg",
                    other=0.0,
                )
                if PRE_LOAD_V:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=kvmask[:, None],
                        cache_modifier=".cg",
                        other=0.0,
                    )
            else:
                kvmask = col_idx < seqlen_k
                K = tl.load(
                    p_k0 + kv_off,
                    mask=dmask[:, None] & kvmask[None, :],
                    cache_modifier=".cg",
                    other=0.0,
                )
                if PRE_LOAD_V:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :] & kvmask[:, None],
                        cache_modifier=".cg",
                        other=0.0,
                    )

            S = tl.dot(Q, K)
            S = apply_softcap(S, softcap, is_softcap)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            S = apply_mask(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                window_size_left,
                window_size_right,
                is_even_mn=IS_EVEN_MN,
                is_causal=is_causal,
                is_local=False,
            )

            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=(is_causal or is_local),
            )

            if not PRE_LOAD_V:
                if IS_EVEN_MN & (d == BLOCK_K):
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
                elif d == BLOCK_K:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=kvmask[:, None],
                        cache_modifier=".cg",
                        other=0.0,
                    )
                else:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :] & kvmask[:, None],
                        cache_modifier=".cg",
                        other=0.0,
                    )
            P = P.to(v_ptr.type.element_ty)
            acc_ = tl.dot(P, V, acc_)

    # LSE
    # Rows whose sum is zero or NaN get lse = -inf and are left unscaled.
    # `(rowsum_ == 0)` must be parenthesised: `|` binds tighter than `==`, so
    # the unparenthesised form never detected NaN sums.
    invalid_sum = (rowsum_ == 0) | (rowsum_ != rowsum_)
    lse = tl.where(
        invalid_sum,
        float("-inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where(invalid_sum, 1.0, 1.0 / rowsum_)

    # Rescale output
    acc_ *= inv_sum[:, None]

    # Write back output
    # o_splits layout = (n_splits, batch_size, num_heads, seqlen_q, head_size)
    # grid = (seq_block, split, batch * head)
    o_split_ptr = o_ptr
    # + split, batch, head offsets, seq_block offsets are already added in row_idx
    o_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q * d
    o_split_offset = row_idx[:, None] * d + tl.arange(0, BLOCK_K)
    o_split_ptr = tl.multiple_of(o_split_ptr, d)
    p_om = o_split_ptr + o_split_offset

    if IS_EVEN_MN & (BLOCK_K == d):
        tl.store(p_om, acc_, cache_modifier=".cg")
    else:
        tl.store(p_om, acc_, mask=qmask, cache_modifier=".cg")

    # Write back lse
    # lse_splits layout = (n_splits, batch_size, num_heads, seqlen_q)
    lse_split_ptr = softmax_lse_ptr
    # + split, batch, head, seq_block offsets
    lse_split_ptr += (
        split_id * tl.num_programs(2) + tl.program_id(2)
    ) * seqlen_q + m_block * BLOCK_M

    if IS_EVEN_MN:
        tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse, cache_modifier=".cg")
    else:
        tl.store(
            lse_split_ptr + tl.arange(0, BLOCK_M),
            lse,
            mask=row_idx < seqlen_q,
            cache_modifier=".cg",
        )

1093 

1094 

@libentry()
@triton.jit
def flash_fwd_splitkv_combine_kernel(
    out_ptr,
    lse_ptr,
    out_splits_ptr,
    lse_splits_ptr,
    head_size: tl.constexpr,
    out_split_stride,
    lse_split_stride,
    out_b_stride,
    out_s_stride,
    out_h_stride,
    n_splits,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    q_total,
    MAX_N_SPLITS: tl.constexpr,
):
    """Combine per-split partial outputs and LSEs into the final result.

    Each program handles BLOCK_M consecutive rows starting at
    program_id(0) * BLOCK_M. For every row the per-split LSEs are turned
    into weights exp(lse_i - max_lse) / sum_j exp(lse_j - max_lse); the final
    LSE is log(sum) + max_lse and the final output is the weighted sum of
    the per-split outputs. Rows >= q_total and splits >= n_splits are
    masked out.
    """
    pid = tl.program_id(0)
    row_base = pid * BLOCK_M
    lse_splits_ptr += row_base
    lse_ptr += row_base
    out_splits_ptr += row_base * head_size
    out_ptr += row_base * head_size

    rows = tl.arange(0, BLOCK_M)
    splits = tl.arange(0, MAX_N_SPLITS)
    feats = tl.arange(0, BLOCK_K)

    row_ok = row_base + rows < q_total
    split_ok = splits < n_splits
    feat_ok = feats < head_size

    # Subtracting maximum from each of the split lse's for better numerical stability
    lse_offs = rows[:, None] + splits[None, :] * lse_split_stride
    lse_parts = tl.load(
        lse_splits_ptr + lse_offs,
        mask=row_ok[:, None] & split_ok[None, :],
        other=float("-inf"),
    )
    lse_peak = tl.max(lse_parts, 1)

    # Sum exp(lse(i) - max_lse) over all split i to obtain Z=sumexp(QK) up to a scaled factor exp(-max_lse)
    part_weight = tl.exp(lse_parts - lse_peak[:, None])
    weight_total = tl.sum(part_weight, 1)
    norm_weight = part_weight / weight_total[:, None]

    # Write back LSE
    lse = tl.log(weight_total) + lse_peak
    tl.store(lse_ptr + rows, lse, mask=row_ok)

    # Load the (BLOCK_M, MAX_N_SPLITS, BLOCK_K) tile of per-split outputs.
    part_offs = (
        rows[:, None, None] * head_size
        + splits[None, :, None] * out_split_stride
        + feats[None, None, :]
    )
    part_mask = (
        row_ok[:, None, None] & split_ok[None, :, None] & feat_ok[None, None, :]
    )
    parts = tl.load(out_splits_ptr + part_offs, mask=part_mask, other=0.0)
    out = tl.sum(norm_weight[:, :, None] * parts, 1)
    out = out.to(out_ptr.type.element_ty)

    # Write back output
    out_offs = rows[:, None] * out_s_stride + feats
    tl.store(out_ptr + out_offs, out, mask=row_ok[:, None] & feat_ok[None, :])

1163 

1164 

@triton.jit
def virtual_to_cache_offset(
    virtual_index,
    max_virtual_index,
    page_table_ptr,
    block_size,
    k_row_stride,
    k_page_stride,
    boundary_check: tl.constexpr = False,
):
    """Translate logical KV positions into element offsets in the paged cache.

    virtual_index is the KV sequence index within the current batch element,
    page_table_ptr already points at that batch element's block-table row and
    block_size is the number of rows per page.

    With boundary_check enabled, positions at or beyond max_virtual_index do
    not read the page table and resolve to page 0 instead.

    Returns page * k_page_stride + row_in_page * k_row_stride, with the page
    index widened to int64 before the multiplication.
    """
    logical_page = virtual_index // block_size
    row_in_page = virtual_index % block_size
    table_entry_ptr = page_table_ptr + logical_page

    if boundary_check:
        physical_page = tl.load(
            table_entry_ptr,
            mask=virtual_index < max_virtual_index,
            other=0,
        )
    else:
        physical_page = tl.load(table_entry_ptr)

    physical_page = physical_page.to(tl.int64)
    return physical_page * k_page_stride + row_in_page * k_row_stride

1189 

1190 

@triton.jit
def load_from_kvcache(
    virtual_index,
    max_virtual_index,
    page_table_ptr,
    k_ptr_base,
    v_ptr_base,
    block_size,
    d: tl.constexpr,
    k_row_stride,
    BLOCK_K: tl.constexpr,
    k_page_stride=0,
    boundary_check: tl.constexpr = False,
):
    """Gather one K tile and one V tile from a paged KV cache.

    The logical positions in virtual_index are mapped to cache offsets with
    virtual_to_cache_offset; the same offsets (and therefore the same row and
    page strides) are used for both K and V.

    Returns (bK, bV): bK has the head dimension on axis 0 and the sequence on
    axis 1 (i.e. it is already transposed), bV has the sequence on axis 0 and
    the head dimension on axis 1. Positions at or beyond max_virtual_index,
    and head-dimension lanes at or beyond d, are filled with 0.0.
    """
    cache_offset = virtual_to_cache_offset(
        virtual_index,
        max_virtual_index,
        page_table_ptr,
        block_size,
        k_row_stride,
        k_page_stride,
        boundary_check,
    )

    # Head-dimension lanes are added directly, i.e. with unit stride.
    dim_idx = tl.arange(0, BLOCK_K)
    k_ptrs = k_ptr_base + (dim_idx[:, None] + cache_offset[None, :])
    v_ptrs = v_ptr_base + (dim_idx[None, :] + cache_offset[:, None])

    # Sequence-range masks, laid out to match the K and V tile orientations.
    k_in_seq = virtual_index[None, :] < max_virtual_index[None, :]
    v_in_seq = virtual_index[:, None] < max_virtual_index[:, None]

    if d == BLOCK_K:
        # The tile width equals the head size: only the sequence needs masking.
        k_mask = k_in_seq
        v_mask = v_in_seq
    else:
        # Padded head dimension: additionally mask lanes beyond d.
        k_mask = (dim_idx[:, None] < d) & k_in_seq
        v_mask = (dim_idx[None, :] < d) & v_in_seq

    bK = tl.load(k_ptrs, mask=k_mask, other=0.0)
    bV = tl.load(v_ptrs, mask=v_mask, other=0.0)
    return bK, bV

1231 

1232 

@libentry()
@triton.jit(
    do_not_specialize=[
        "q_batch_stride",
        "k_batch_stride",
        "v_batch_stride",
        "o_batch_stride",
        "b",
        "bk",
        "seqlen_q",
        "seqlen_k",
        "seqlen_q_rounded",
        "seqlen_k_rounded",
        "total_q",
        "k_page_stride",
    ]
)
def flash_varlen_fwd_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q: tl.constexpr,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k: tl.constexpr,
    cu_seqlens_k_ptr,
    is_seqused_k: tl.constexpr,
    seqused_k_ptr,
    # sizes
    b,
    bk,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    is_paged: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    k_page_stride,
    # kernel params
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    """Variable-length flash-attention forward pass for one (BLOCK_M, d) Q tile.

    Grid layout: program_id(0) is the query tile index, program_id(1) the
    batch element and program_id(2) the query head.

    Sequence extents:
      * Q: taken from cu_seqlens_q_ptr when is_cu_seqlens_q, otherwise
        seqlen_q with batch addressing via q_batch_stride / o_batch_stride.
      * K/V: taken from cu_seqlens_k_ptr when is_cu_seqlens_k, otherwise
        seqlen_k; seqused_k_ptr overrides the length when is_seqused_k.

    K/V tiles come either from a paged cache (is_paged, via
    load_from_kvcache) or from contiguous storage through block pointers.
    KV blocks are visited from the last one to the first: the trailing
    n_masking_steps blocks are processed with bounds/causal masking, the
    remaining ones without causal masking.

    Results: the normalised output tile is stored to o_ptr and the
    log-sum-exp of each row to softmax_lse_ptr (indexed as [h, total_q]).
    Rows with an empty or NaN softmax sum get lse = +inf and are left
    unscaled.

    NOTE(review): V is addressed with K's head/row/batch strides throughout
    (v_row_stride, v_head_stride and v_batch_stride are never read), so K and
    V presumably must share a layout - confirm against the caller. p_ptr and
    return_softmax are likewise unused here.
    """
    m_block = tl.program_id(0)
    bid = tl.program_id(1)
    hid = tl.program_id(2)

    if is_cu_seqlens_q:
        q_eos = tl.load(cu_seqlens_q_ptr + bid + 1).to(tl.int32)
        q_bos = tl.load(cu_seqlens_q_ptr + bid).to(tl.int32)
        q_len = q_eos - q_bos
        # Current request's start offset in the batched Q
        q_offset = q_bos * q_row_stride
        o_offset = q_bos * o_row_stride
        lse_offset = q_bos * 1
    else:
        q_len = seqlen_q
        q_offset = bid * q_batch_stride
        o_offset = bid * o_batch_stride
        lse_offset = bid * seqlen_q

    # kv_seq_offset is the start of this request's K/V rows for the non-paged
    # path. It is defined in both branches so the contiguous path also
    # compiles when cu_seqlens_k is not supplied (previously it referenced
    # k_bos, which only exists in the first branch).
    if is_cu_seqlens_k:
        k_eos = tl.load(cu_seqlens_k_ptr + bid + 1).to(tl.int32)
        k_bos = tl.load(cu_seqlens_k_ptr + bid).to(tl.int32)
        k_len_cache = k_eos - k_bos
        kv_seq_offset = k_bos * k_row_stride
    else:
        k_len_cache = seqlen_k
        kv_seq_offset = bid * k_batch_stride

    if is_seqused_k:
        k_len = tl.load(seqused_k_ptr + bid).to(tl.int32)
    else:
        k_len = k_len_cache

    # Noop CTA: this tile starts at or past the end of the query sequence, so
    # every row would be masked out of the stores below.
    if m_block * BLOCK_M >= q_len:
        return

    # Evenness of q_len / k_len w.r.t. the block sizes is a runtime property,
    # so the kernel always takes the bounds-checked path.
    is_even_mn: tl.constexpr = False

    if is_local:
        n_block_min = max(
            0, (m_block * BLOCK_M + k_len - q_len - window_size_left) // BLOCK_N
        )
    else:
        n_block_min = 0

    n_block_max = tl.cdiv(k_len, BLOCK_N)
    if is_causal or is_local:
        n_block_max = min(
            n_block_max,
            tl.cdiv(
                (m_block + 1) * BLOCK_M + k_len - q_len + window_size_right, BLOCK_N
            ),
        )

    if is_dropout:
        philox_seed = tl.load(philox_args).to(tl.uint64)
        philox_offset = tl.load(philox_args + 1).to(tl.uint64)

    # Locate the page table entry for the current batch element
    if is_paged:
        page_table_ptr += bid * page_table_batch_stride
    # Calculate the starting offset of q for the current head
    q_row_offset = hid * q_head_stride
    # Calculate the starting offset of k and v for the current head
    k_row_offset = (hid // h_hk_ratio) * k_head_stride
    # Shift the k, v pointers to align with the current head
    k_ptr_base = k_ptr + k_row_offset
    v_ptr_base = v_ptr + k_row_offset
    # Loop-invariant start of this request's K/V rows; only the non-paged
    # path reads these.
    k_ptr_seq = k_ptr_base + kv_seq_offset
    v_ptr_seq = v_ptr_base + kv_seq_offset

    gQ = tl.make_block_ptr(
        base=q_ptr + q_offset + q_row_offset,
        shape=(q_len, d),
        strides=(q_row_stride, 1),
        offsets=(0, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )
    bQ = tl.load(gQ.advance([m_block * BLOCK_M, 0]), boundary_check=(0, 1))

    acc_ = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0.0

    if not is_causal and not is_local:
        n_masking_steps = 1
    elif is_even_mn:
        n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N)
    else:
        n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N) + 1

    n_masking_steps = min(n_block_max - n_block_min, n_masking_steps)

    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)

    # Phase 1: trailing KV blocks, processed with bounds and causal masking.
    n_block = n_block_max - 1
    for _ in tl.range(0, n_masking_steps):
        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
        if is_paged:
            bK, bV = load_from_kvcache(
                col_idx,
                k_len,
                page_table_ptr,
                k_ptr_base,
                v_ptr_base,
                block_size,
                d,
                k_row_stride,
                BLOCK_K=BLOCK_K,
                k_page_stride=k_page_stride,
                boundary_check=True,
            )
        else:
            start_n = n_block * BLOCK_N
            gK = tl.make_block_ptr(
                base=k_ptr_seq,
                shape=(k_len, d),
                strides=(k_row_stride, 1),
                offsets=(start_n, 0),
                block_shape=(BLOCK_N, BLOCK_K),
                order=(0, 1),
            )
            gV = tl.make_block_ptr(
                base=v_ptr_seq,
                shape=(k_len, d),
                strides=(k_row_stride, 1),
                offsets=(start_n, 0),
                block_shape=(BLOCK_N, BLOCK_K),
                order=(0, 1),
            )
            bK = tl.load(gK, boundary_check=(0, 1))
            bK = tl.trans(bK)
            bV = tl.load(gV, boundary_check=(0, 1))
        S = tl.dot(bQ, bK, out_dtype=tl.float32)
        S = apply_softcap(S, softcap, is_softcap)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            window_size_left,
            window_size_right,
            is_even_mn=is_even_mn,
            is_causal=is_causal,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=True,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            # apply_dropout takes (row_start, col_start): rows are query
            # positions, columns are key positions - same order as phase 2.
            P = apply_dropout(
                P,
                m_block * BLOCK_M,
                n_block * BLOCK_N,
                k_len,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )

        acc_ = tl.dot(P, bV, acc_)
        n_block -= 1

    # Phase 2: remaining KV blocks, walked backwards down to n_block_min.
    # They precede the last block, so they lie fully inside [0, k_len) and the
    # sequence dimension needs no bounds check.
    # NOTE(review): the non-paged loads below are also unchecked along the
    # head dimension - presumably safe only when d == BLOCK_K; confirm.
    for n_block in tl.range(
        n_block_max - n_masking_steps - 1, n_block_min - 1, step=-1
    ):
        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
        if is_paged:
            bK, bV = load_from_kvcache(
                col_idx,
                k_len,
                page_table_ptr,
                k_ptr_base,
                v_ptr_base,
                block_size,
                d,
                k_row_stride,
                BLOCK_K=BLOCK_K,
                k_page_stride=k_page_stride,
            )
        else:
            start_n = n_block * BLOCK_N
            gK = tl.make_block_ptr(
                base=k_ptr_seq,
                shape=(k_len, d),
                strides=(k_row_stride, 1),
                offsets=(start_n, 0),
                block_shape=(BLOCK_N, BLOCK_K),
                order=(0, 1),
            )
            gV = tl.make_block_ptr(
                base=v_ptr_seq,
                shape=(k_len, d),
                strides=(k_row_stride, 1),
                offsets=(start_n, 0),
                block_shape=(BLOCK_N, BLOCK_K),
                order=(0, 1),
            )
            bK = tl.load(gK)
            bK = tl.trans(bK)
            bV = tl.load(gV)
        S = tl.dot(bQ, bK, out_dtype=tl.float32)
        S = apply_softcap(S, softcap, is_softcap)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            window_size_left,
            window_size_right,
            is_even_mn=True,
            is_causal=False,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=is_local,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            P = apply_dropout(
                P,
                m_block * BLOCK_M,
                n_block * BLOCK_N,
                k_len,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )
        acc_ = tl.dot(P, bV, acc_)

    # LSE
    # Rows whose softmax sum is zero or NaN. The comparison must be
    # parenthesised: `|` binds tighter than `==`, so the unparenthesised form
    # compared rowsum_ against (0 | isnan) and never flagged NaN rows.
    degenerate_row = (rowsum_ == 0) | (rowsum_ != rowsum_)
    lse = tl.where(
        degenerate_row,
        float("inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where(degenerate_row, 1.0, 1.0 / rowsum_)

    acc_ *= inv_sum[:, None]

    out = acc_.to(o_ptr.type.element_ty)  # noqa

    # Write back output
    o_row_offset = hid * o_head_stride

    gO = tl.make_block_ptr(
        base=o_ptr + o_offset + o_row_offset,
        shape=(q_len, d),
        strides=(o_row_stride, 1),
        offsets=(0, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )
    tl.store(gO.advance([m_block * BLOCK_M, 0]), out, boundary_check=(0, 1))

    # Write back lse
    # lse shape: [h, total_q]
    softmax_lse_ptr += hid * total_q
    lse_row_offset = lse_offset + m_block * BLOCK_M + tl.arange(0, BLOCK_M)
    tl.store(
        softmax_lse_ptr + lse_row_offset,
        lse,
        mask=lse_row_offset < (lse_offset + q_len),
    )