Coverage for src/flag_gems/runtime/backend/_sunrise/ops/flash_kernel.py: 0%

572 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import triton 

2import triton.language as tl 

3 

4from flag_gems.utils import libentry, tl_extra_shim 

5 

6 

@triton.jit
def u64_to_lohi(x):
    """Split a 64-bit value into two ``tl.uint32`` halves.

    NOTE(review): despite the ``lohi`` name, the returned tuple is ordered
    ``(upper 32 bits, lower 32 bits)``. ``philox_`` consumes the halves in
    exactly this order, so swapping them would change the random stream.
    """
    upper = (x >> 32).to(tl.uint32)
    lower = (x & 0xFFFFFFFF).to(tl.uint32)
    return upper, lower

10 

11 

@triton.jit
def u64_from_lohi(lo, hi):
    """Assemble a ``tl.uint64`` from a low and a high 32-bit half.

    Args:
        lo: value providing bits 0..31 of the result.
        hi: value providing bits 32..63 of the result.

    Returns:
        ``(hi << 32) + lo`` computed in ``tl.uint64``.
    """
    # The shift must be parenthesized: `+` binds tighter than `<<`, so the
    # unparenthesized form would shift `hi` by `32 + lo` bits.
    return (hi.to(tl.uint64) << 32) + lo.to(tl.uint64)

15 

16 

@triton.jit
def philox_(seed, subsequence, offset):
    """Mix ``seed``, ``subsequence`` and ``offset`` into four random words.

    The two key words come from ``seed``; the four counter words come from
    ``offset`` and ``subsequence`` (each split by ``u64_to_lohi``). Six rounds
    are run with a key bump after each one, followed by a final round without
    a key bump.

    Returns:
        Tuple ``(c0, c1, c2, c3)`` holding the final counter words.
    """
    # Per-round key increments.
    kPhilox10A: tl.constexpr = 0x9E3779B9
    kPhilox10B: tl.constexpr = 0xBB67AE85
    # Per-round counter multipliers.
    kPhiloxSA: tl.constexpr = 0xD2511F53
    kPhiloxSB: tl.constexpr = 0xCD9E8D57

    k0, k1 = u64_to_lohi(seed.to(tl.uint64))
    c0, c1 = u64_to_lohi(offset.to(tl.uint64))
    c2, c3 = u64_to_lohi(subsequence.to(tl.uint64))

    # static_range unrolls the six keyed rounds at compile time.
    for _ in tl.static_range(6):
        prod0_a, prod0_b = u64_to_lohi(kPhiloxSA * c0.to(tl.uint64))
        prod1_a, prod1_b = u64_to_lohi(kPhiloxSB * c2.to(tl.uint64))
        c0, c1, c2, c3 = prod1_b ^ c1 ^ k0, prod1_a, prod0_b ^ c3 ^ k1, prod0_a
        k0 += kPhilox10A
        k1 += kPhilox10B

    # Final round: same mixing step, no key update afterwards.
    prod0_a, prod0_b = u64_to_lohi(kPhiloxSA * c0.to(tl.uint64))
    prod1_a, prod1_b = u64_to_lohi(kPhiloxSB * c2.to(tl.uint64))
    c0, c1, c2, c3 = prod1_b ^ c1 ^ k0, prod1_a, prod0_b ^ c3 ^ k1, prod0_a

    return c0, c1, c2, c3

44 

45 

@triton.jit
def apply_dropout_mask(
    P,
    mask,
    encode_dropout_in_sign_bit: tl.constexpr,
):
    """Rewrite the entries of ``P`` selected by ``mask``.

    Where ``mask`` is true the entry is negated (when
    ``encode_dropout_in_sign_bit`` is set) or replaced by ``P * 0`` cast back
    to ``P``'s dtype; all other entries are returned unchanged.
    """
    if encode_dropout_in_sign_bit:
        replacement = -P
    else:
        replacement = (P * 0).to(P.dtype)
    return tl.where(mask, replacement, P)

57 

58 

@triton.jit
def apply_dropout(
    P,
    row_start,
    col_start,
    n_cols,
    bid,
    hid,
    philox_seed,
    philox_offset,
    p_dropout_uint8: tl.constexpr,
    is_dropout: tl.constexpr,
    encode_dropout_in_sign_bit: tl.constexpr,
    NUM_HEADS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Apply a Philox-derived dropout mask to the ``BLOCK_M x BLOCK_N`` tile ``P``.

    When ``is_dropout`` is false ``P`` is returned untouched. Otherwise one
    byte of random data is generated per element and every element whose byte
    is ``>= p_dropout_uint8`` is handed to ``apply_dropout_mask`` (negated or
    zeroed depending on ``encode_dropout_in_sign_bit``).
    """
    if is_dropout:
        row_start = tl.multiple_of(row_start, BLOCK_M)
        col_start = tl.multiple_of(col_start, BLOCK_N)
        rows = row_start + tl.arange(0, BLOCK_M)[:, None]
        # One Philox call produces four words, so only a quarter of the
        # columns need their own counter.
        quad_cols = col_start // 4 + tl.arange(0, BLOCK_N // 4)[None, :]

        subsequence = rows.to(tl.uint64) * n_cols + quad_cols.to(tl.uint64)

        offset = philox_offset + bid * NUM_HEADS + hid
        # Adding a zero tile broadcasts the scalar offset to the tile shape.
        offset += subsequence * 0
        r0, r1, r2, r3 = philox_(philox_seed, subsequence, offset)

        # Interleave the four word tiles back into a full-width tile.
        rand = tl.join(tl.join(r0, r1), tl.join(r2, r3)).reshape(BLOCK_M, BLOCK_N)

        selected = (rand & 0xFF) >= p_dropout_uint8

        P = apply_dropout_mask(
            P, selected, encode_dropout_in_sign_bit=encode_dropout_in_sign_bit
        )
    return P

97 

98 

@triton.jit
def apply_alibi(
    S,
    col_idx,
    row_idx,
    max_seqlen_q,
    max_seqlen_k,
    is_causal: tl.constexpr,
    is_alibi: tl.constexpr,
    alibi_slope: tl.constexpr = None,
):
    """Add the ALiBi bias to the score tile ``S`` when ``is_alibi`` is set.

    Causal case: a row-independent bias
    ``alibi_slope * (col - (max_seqlen_k - 1))`` is used. Softmax is shift
    invariant (softmax(A + bias + const) == softmax(A + bias)), so under a
    causal mask this gives the same attention output as the per-row bias:

        row-independent bias        per-row bias
        -4, -3, -2,  X,  X          -2, -1,  0,  X,  X
        -4, -3, -2, -1,  X          -3, -2, -1,  0,  X
        -4, -3, -2, -1,  0          -4, -3, -2, -1,  0

    Non-causal case: the bias is
    ``-alibi_slope * |col - max_seqlen_k + max_seqlen_q - row|``.
    """
    if is_alibi:
        if is_causal:
            S += alibi_slope * (-max_seqlen_k + 1 + col_idx[None, :]).to(tl.float32)
        else:
            distance = tl.abs(
                col_idx[None, :] - max_seqlen_k + max_seqlen_q - row_idx[:, None]
            ).to(tl.float32)
            S += -alibi_slope * distance

    return S

135 

136 

@triton.jit
def apply_mask(
    S,
    col_idx,
    row_idx,
    max_seqlen_q,
    max_seqlen_k,
    window_size_left,
    window_size_right,
    is_even_mn: tl.constexpr,
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
):
    """Set masked-out entries of the score tile ``S`` to ``-inf``.

    * ``is_causal``: columns to the right of the per-row right bound are masked.
    * ``is_local``: columns outside ``[col_lb, col_rb]`` are masked.
    * neither, and the tile is not even: columns ``>= max_seqlen_k`` are masked.

    Nothing is done when the tile is even and neither causal nor local.
    """
    if is_causal | is_local | (not is_even_mn):
        # Both bounds are inclusive; mind the off-by-one when editing.
        diag = row_idx + max_seqlen_k - max_seqlen_q
        col_lb = max(0, diag - window_size_left)
        col_rb = min(max_seqlen_k - 1, diag + window_size_right)

        if is_causal:
            S = tl.where(col_idx[None, :] > col_rb[:, None], float("-inf"), S)

        if is_local:
            outside = (col_idx[None, :] > col_rb[:, None]) | (
                col_idx[None, :] < col_lb[:, None]
            )
            S = tl.where(outside, float("-inf"), S)

        if (not is_local) & (not is_causal) & (not is_even_mn):
            S = tl.where(col_idx[None, :] >= max_seqlen_k, float("-inf"), S)

    return S

174 

175 

@triton.jit
def softmax_rescale(
    O_acc,
    S,
    row_max,
    row_sum,
    softmax_scale_log2e: tl.constexpr,
    is_border: tl.constexpr,
    # is_init: tl.constexpr
):
    """Perform one online-softmax update step for a new score tile ``S``.

    The running row maximum is refreshed, the previous accumulator and row sum
    are rescaled to the new maximum, and the exponentiated tile ``P`` is added
    to the row sum. All exponentials are base 2 with scores multiplied by
    ``softmax_scale_log2e``.

    Returns:
        ``(O_acc, P, row_max, row_sum)`` after the update.
    """
    old_max = row_max
    row_max = tl.maximum(old_max, tl.max(S, 1))

    # On border tiles a row may be fully masked; use 0 instead of -inf as the
    # reference so the rescale factor below is not computed from -inf - -inf.
    if is_border:
        ref_max = tl.where(row_max == float("-inf"), 0, row_max)
    else:
        ref_max = row_max

    rescale = tl.math.exp2((old_max - ref_max) * softmax_scale_log2e)
    row_sum *= rescale
    O_acc *= rescale[:, None]

    shift = tl.where(row_max == float("-inf"), 0, row_max * softmax_scale_log2e)
    P = tl.math.exp2(S * softmax_scale_log2e - shift[:, None])
    row_sum = row_sum + tl.sum(P, 1)
    return O_acc, P, row_max, row_sum

202 

203 

@triton.jit
def apply_softcap(S, softcap, is_softcap: tl.constexpr):
    """Return ``tanh(S * softcap)`` when ``is_softcap`` is set, else ``S``."""
    if is_softcap:
        scaled = S * softcap
        S = tl_extra_shim.tanh(scaled)

    return S

210 

211 

def block_m_splitkv_heuristic(headdim):
    """Pick BLOCK_M for the split-KV kernel: 128 for head dims up to 128, else 64."""
    if headdim <= 128:
        return 128
    return 64

214 

215 

def block_n_splitkv_heuristic(headdim):
    """Pick BLOCK_N for the split-KV kernel: 64 for head dims up to 64, else 32."""
    if headdim <= 64:
        return 64
    return 32

218 

219 

def is_even_mn(M, N, BM, BN, WL, WR):
    """Return True when the problem tiles evenly and needs no boundary masks.

    All of the following must hold:
      * ``M`` is a multiple of ``BM`` and ``N`` is a multiple of ``BN``;
      * one of ``M`` / ``N`` divides the other;
      * each window size (``WL``, ``WR``) is either -1 or a multiple of ``BN``.
    """
    if M % BM != 0 or N % BN != 0:
        return False
    if M % N != 0 and N % M != 0:
        return False
    if not (WL == -1 or WL % BN == 0):
        return False
    return WR == -1 or WR % BN == 0

226 

227 

def block_m_splitkv_heuristic_spec_args(args):
    """Heuristic form of the BLOCK_M choice: reads the head dim from ``args["d"]``."""
    head_dim = args["d"]
    if head_dim <= 128:
        return 128
    return 64

230 

231 

def block_n_splitkv_heuristic_spec_args(args):
    """Heuristic form of the BLOCK_N choice: reads the head dim from ``args["d"]``."""
    head_dim = args["d"]
    if head_dim <= 64:
        return 64
    return 32

234 

235 

def is_even_mn_spec_args(args):
    """Evenness check driven by a kernel-argument mapping.

    Returns True only when ``seqlen_q`` / ``seqlen_k`` are multiples of
    ``BLOCK_M`` / ``BLOCK_N``, one sequence length divides the other, and both
    window sizes are either -1 or multiples of ``BLOCK_N``. Keys are looked up
    lazily, in the same order as the checks.
    """
    if args["seqlen_q"] % args["BLOCK_M"] != 0:
        return False
    if args["seqlen_k"] % args["BLOCK_N"] != 0:
        return False

    if (
        args["seqlen_q"] % args["seqlen_k"] != 0
        and args["seqlen_k"] % args["seqlen_q"] != 0
    ):
        return False

    left_ok = (
        args["window_size_left"] == -1
        or args["window_size_left"] % args["BLOCK_N"] == 0
    )
    if not left_ok:
        return False

    return (
        args["window_size_right"] == -1
        or args["window_size_right"] % args["BLOCK_N"] == 0
    )

254 

255 

def keep(cfg, must_keep=None):
    """Decide whether an autotune config survives pruning.

    A config is kept when its ``(BLOCK_M, BLOCK_N, num_warps)`` triple is one
    of the whitelisted shapes. Otherwise the result is
    ``must_keep and cfg in must_keep``, so configs listed in ``must_keep`` are
    always kept and a falsy ``must_keep`` is returned as-is.
    """
    signature = (cfg.kwargs["BLOCK_M"], cfg.kwargs["BLOCK_N"], cfg.num_warps)
    if signature in ((128, 32, 4), (128, 128, 8)):
        return True
    return must_keep and cfg in must_keep

265 

266 

def prune_fwd_configs(configs, nargs, **kwargs):
    """Filter forward-kernel autotune configs.

    With dropout enabled (``nargs["is_dropout"]``) only configs using 4 warps
    and fewer than 4 stages are returned, as a new list. Without dropout the
    ``configs`` argument is returned unchanged.
    """
    if not nargs["is_dropout"]:
        return configs
    return [cfg for cfg in configs if cfg.num_warps == 4 and cfg.num_stages < 4]

275 

276 

def flash_fwd_kernel_heur_block_k(args):
    """Heuristic for BLOCK_K: the head dim ``args["d"]`` rounded up to a power of two."""
    head_dim = args["d"]
    return triton.next_power_of_2(head_dim)

279 

280 

@libentry()
@triton.heuristics(
    values={
        "BLOCK_K": flash_fwd_kernel_heur_block_k,
        "PRE_LOAD_V": lambda args: False,
        "IS_EVEN_MN": lambda args: is_even_mn(
            args["seqlen_q"],
            args["seqlen_k"],
            args["BLOCK_M"],
            args["BLOCK_N"],
            args["window_size_left"],
            args["window_size_right"],
        ),
    }
)
@triton.jit(
    do_not_specialize=["seqlen_q", "seqlen_k", "seqlen_q_rounded", "seqlen_k_rounded"]
)
def flash_fwd_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k,
    cu_seqlens_k_ptr,
    is_seqused_k,
    seqused_k_ptr,
    # sizes
    b: tl.constexpr,
    bk: tl.constexpr,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    is_paged: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q: tl.constexpr,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    # kernel params
    IS_EVEN_MN: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    """Flash-attention forward pass for one (query block, batch * head) program.

    Grid axis 0 selects the query row block; grid axis 1 is ``bid * h + hid``.
    The program streams over key/value column blocks in two phases: first the
    blocks at the right edge, which need causal / local / boundary masking,
    then the remaining blocks without boundary masks. Online-softmax state
    (``acc_``, ``rowmax_``, ``rowsum_``) is carried across both phases. The
    normalized output is written to ``o_ptr`` and the log-sum-exp per row to
    ``softmax_lse_ptr``. With dropout and ``return_softmax`` the dropped
    probabilities (sign-bit encoded) are also stored to ``p_ptr``.

    NOTE(review): the cu_seqlens / seqused / page-table arguments and the
    ``v_*_stride`` arguments are accepted but not read here; V is addressed
    with the K strides, which presumes K and V share a layout - confirm with
    the caller.
    """
    m_block = tl.program_id(0)
    bh = tl.program_id(1)
    hid = bh % h
    bid = bh // h
    num_m_blocks = tl.cdiv(seqlen_q, BLOCK_M)

    # We draw a minimum covering frame on the attention map that this CTA is
    # assigned to process. The frame edges are rounded to multiples of BLOCK_M
    # and BLOCK_N for rows and columns respectively.

    col_min = 0
    if is_local:
        col_min = max(0, m_block * BLOCK_M + seqlen_k - seqlen_q - window_size_left)
        if not IS_EVEN_MN:
            # round left
            col_min = (col_min // BLOCK_N) * BLOCK_N

    col_max = seqlen_k
    if is_causal or is_local:
        col_max += (m_block - num_m_blocks + 1) * BLOCK_M
        if is_local:
            col_max += window_size_right
        col_max = min(seqlen_k, col_max)

    if not IS_EVEN_MN:
        # round right
        col_max = tl.cdiv(col_max, BLOCK_N) * BLOCK_N

    # Number of right-edge columns that have to go through the masked loop.
    if (not is_causal) and (not is_local):
        if IS_EVEN_MN:
            masking_cols: tl.constexpr = 0
        else:
            masking_cols: tl.constexpr = BLOCK_N
    elif (
        is_causal | is_local
    ) and IS_EVEN_MN:  # causal implies window_size_right is zero
        masking_cols: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) * BLOCK_N
    else:
        # local
        masking_cols: tl.constexpr = (tl.cdiv(BLOCK_M, BLOCK_N) + 1) * BLOCK_N

    if is_dropout:
        philox_seed = tl.load(philox_args).to(tl.uint64)
        philox_offset = tl.load(philox_args + 1).to(tl.uint64)

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        # Scores are multiplied by scale_softmax later, so pre-divide the slope.
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0.0

    q_batch_stride = tl.multiple_of(q_batch_stride, d * h)
    q_ptr += bid * q_batch_stride + hid * q_head_stride
    row_start = m_block * BLOCK_M
    row_idx = row_start + tl.arange(0, BLOCK_M)
    q_off = row_idx[:, None] * q_row_stride + tl.arange(0, BLOCK_K)[None, :]
    dmask = tl.arange(0, BLOCK_K) < d
    qmask = dmask[None, :] & (row_idx[:, None] < seqlen_q)
    # `and` is required here: `IS_EVEN_MN & d == BLOCK_K` parses as
    # `(IS_EVEN_MN & d) == BLOCK_K`, which never selects the unmasked path.
    if IS_EVEN_MN and (d == BLOCK_K):
        Q = tl.load(q_ptr + q_off, cache_modifier=".cg")
    else:
        # other=0.0 keeps padded lanes from contributing to the dot products.
        Q = tl.load(q_ptr + q_off, mask=qmask, other=0.0, cache_modifier=".cg")

    if return_softmax:
        p_ptr += (
            (bid * h + hid) * seqlen_q_rounded + m_block * BLOCK_M
        ) * seqlen_k_rounded
        p_offset = tl.arange(0, BLOCK_M)[:, None] * seqlen_k_rounded + tl.arange(
            0, BLOCK_N
        )
        p_bp0 = p_ptr + p_offset

    acc_ = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    k_batch_stride = tl.multiple_of(k_batch_stride, d * hk)
    h_hk_ratio = h // hk
    k_ptr += bid * k_batch_stride
    k_ptr += (hid // h_hk_ratio) * k_head_stride
    v_ptr += bid * k_batch_stride
    v_ptr += (hid // h_hk_ratio) * k_head_stride

    # K is addressed transposed (BLOCK_K x BLOCK_N), V as BLOCK_N x BLOCK_K.
    k_offset = (
        tl.arange(0, BLOCK_N)[None, :] * k_row_stride + tl.arange(0, BLOCK_K)[:, None]
    )
    v_offset = (
        tl.arange(0, BLOCK_N)[:, None] * k_row_stride + tl.arange(0, BLOCK_K)[None, :]
    )

    p_bk0 = k_ptr + k_offset
    p_bv0 = v_ptr + v_offset

    # Phase 1: right-edge column blocks, walked right to left with masking.
    if is_causal | is_local | (not IS_EVEN_MN):
        # Cut short masking cols if there's not enough cols out there
        masking_cols = min(col_max - col_min, masking_cols)
        for col_shift in tl.range(0, masking_cols, step=BLOCK_N):
            col_start = col_max - col_shift - BLOCK_N
            col_start = tl.multiple_of(col_start, BLOCK_N)
            off = col_start * k_row_stride
            if IS_EVEN_MN and (d == BLOCK_K):
                K = tl.load(p_bk0 + off, cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(p_bv0 + off, cache_modifier=".cg")
            elif d == BLOCK_K:
                col_idx = col_start + tl.arange(0, BLOCK_N)
                kvmask = col_idx < seqlen_k
                K = tl.load(
                    p_bk0 + off,
                    mask=kvmask[None, :],
                    other=0.0,
                    cache_modifier=".cg",
                )
                if PRE_LOAD_V:
                    V = tl.load(
                        p_bv0 + off,
                        mask=kvmask[:, None],
                        other=0.0,
                        cache_modifier=".cg",
                    )
            else:
                col_idx = col_start + tl.arange(0, BLOCK_N)
                kvmask = col_idx < seqlen_k
                K = tl.load(
                    p_bk0 + off,
                    mask=kvmask[None, :] & dmask[:, None],
                    other=0.0,
                    cache_modifier=".cg",
                )
                if PRE_LOAD_V:
                    V = tl.load(
                        p_bv0 + off,
                        mask=kvmask[:, None] & dmask[None, :],
                        other=0.0,
                        cache_modifier=".cg",
                    )
            S = tl.dot(Q, K, allow_tf32=False)
            S = apply_softcap(S, softcap, is_softcap)
            col_idx = col_start + tl.arange(0, BLOCK_N)
            row_idx = row_start + tl.arange(0, BLOCK_M)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            S = apply_mask(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                window_size_left,
                window_size_right,
                is_even_mn=IS_EVEN_MN,
                is_causal=is_causal,
                is_local=is_local,
            )

            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=(is_causal or is_local),
            )
            P = P.to(v_ptr.type.element_ty)

            if is_dropout:
                if return_softmax:
                    # Debug copy of P with dropped entries marked by sign.
                    P_drop = P

                    P_drop = apply_dropout(
                        P_drop,
                        row_start,
                        col_start,
                        seqlen_k,
                        bid,
                        hid,
                        philox_seed,
                        philox_offset,
                        p_dropout_in_uint8_t,
                        is_dropout,
                        encode_dropout_in_sign_bit=True,
                        NUM_HEADS=h,
                        BLOCK_M=BLOCK_M,
                        BLOCK_N=BLOCK_N,
                    )
                    if IS_EVEN_MN:
                        tl.store(p_bp0 + col_start, P_drop)
                    else:
                        kvmask = col_idx < seqlen_k
                        # The tile is BLOCK_M x BLOCK_N, so the row part of the
                        # mask must come from row_idx; qmask is BLOCK_K wide.
                        tl.store(
                            p_bp0 + col_start,
                            P_drop,
                            mask=(row_idx[:, None] < seqlen_q) & kvmask[None, :],
                        )

                P = apply_dropout(
                    P,
                    row_start,
                    col_start,
                    seqlen_k,
                    bid,
                    hid,
                    philox_seed,
                    philox_offset,
                    p_dropout_in_uint8_t,
                    is_dropout,
                    encode_dropout_in_sign_bit=False,
                    NUM_HEADS=h,
                    BLOCK_M=BLOCK_M,
                    BLOCK_N=BLOCK_N,
                )

            if not PRE_LOAD_V:
                off = col_start * k_row_stride
                if IS_EVEN_MN and (d == BLOCK_K):
                    V = tl.load(p_bv0 + off, cache_modifier=".cg")
                elif d == BLOCK_K:
                    kvmask = col_idx < seqlen_k
                    V = tl.load(
                        p_bv0 + off,
                        mask=kvmask[:, None],
                        other=0.0,
                        cache_modifier=".cg",
                    )
                else:
                    kvmask = col_idx < seqlen_k
                    V = tl.load(
                        p_bv0 + off,
                        mask=kvmask[:, None] & dmask[None, :],
                        other=0.0,
                        cache_modifier=".cg",
                    )
            acc_ = tl.dot(P, V, acc_, allow_tf32=False)

    # Phase 2: remaining column blocks; no sequence-boundary masks are needed.
    for col_start in tl.range(
        col_min, col_max - masking_cols, step=BLOCK_N, num_stages=num_stages
    ):
        col_start = tl.multiple_of(col_start, BLOCK_N)
        off = col_start * k_row_stride
        if d == BLOCK_K:
            K = tl.load(p_bk0 + off, cache_modifier=".cg")
            if PRE_LOAD_V:
                V = tl.load(p_bv0 + off, cache_modifier=".cg")
        else:
            K = tl.load(
                p_bk0 + off, mask=dmask[:, None], other=0.0, cache_modifier=".cg"
            )
            if PRE_LOAD_V:
                V = tl.load(
                    p_bv0 + off, mask=dmask[None, :], other=0.0, cache_modifier=".cg"
                )

        S = tl.dot(Q, K)
        S = apply_softcap(S, softcap, is_softcap)
        col_idx = col_start + tl.arange(0, BLOCK_N)
        row_idx = row_start + tl.arange(0, BLOCK_M)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            seqlen_q,
            seqlen_k,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            seqlen_q,
            seqlen_k,
            window_size_left,
            window_size_right,
            is_even_mn=True,
            is_causal=False,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=is_local,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            if return_softmax:
                P_drop = P
                P_drop = apply_dropout(
                    P_drop,
                    row_start,
                    col_start,
                    seqlen_k,
                    bid,
                    hid,
                    philox_seed,
                    philox_offset,
                    p_dropout_in_uint8_t,
                    is_dropout,
                    encode_dropout_in_sign_bit=True,
                    NUM_HEADS=h,
                    BLOCK_M=BLOCK_M,
                    BLOCK_N=BLOCK_N,
                )
                if IS_EVEN_MN:
                    tl.store(p_bp0 + col_start, P_drop)
                else:
                    kvmask = col_idx < seqlen_k
                    tl.store(
                        p_bp0 + col_start,
                        P_drop,
                        mask=(row_idx[:, None] < seqlen_q) & kvmask[None, :],
                    )

            P = apply_dropout(
                P,
                row_start,
                col_start,
                seqlen_k,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )

        if not PRE_LOAD_V:
            off = col_start * k_row_stride
            if d == BLOCK_K:
                V = tl.load(p_bv0 + off, cache_modifier=".cg")
            else:
                V = tl.load(
                    p_bv0 + off, mask=dmask[None, :], other=0.0, cache_modifier=".cg"
                )
        acc_ = tl.dot(P, V, acc_)

    # LSE
    # Note, rowsum = exp(-rowmax) * exp(lse), therefore rowmax + log(rowsum) cancels
    # the effect of rowmax and outputs lse only.
    # Each comparison is parenthesized: `|` binds tighter than `==`, so the
    # unparenthesized form would not detect NaN row sums.
    degenerate = (rowsum_ == 0) | (rowsum_ != rowsum_)
    lse = tl.where(
        degenerate,
        float("inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where(degenerate, 1.0, 1.0 / rowsum_)

    if is_dropout:
        acc_ *= inv_sum[:, None] * rp_dropout
    else:
        acc_ *= inv_sum[:, None]

    out = acc_.to(o_ptr.type.element_ty)  # noqa

    # Write back output
    o_batch_stride = tl.multiple_of(o_batch_stride, d * h)
    o_ptr += bid * o_batch_stride
    o_ptr += hid * o_head_stride
    o_offset = row_idx[:, None] * o_row_stride + tl.arange(0, BLOCK_K)

    if IS_EVEN_MN and (d == BLOCK_K):
        tl.store(o_ptr + o_offset, out)
    else:
        tl.store(o_ptr + o_offset, out, mask=qmask)

    # Write back lse
    p_lse = softmax_lse_ptr + (bid * h + hid) * seqlen_q
    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)

    if IS_EVEN_MN:
        tl.store(p_lse + row_idx, lse)
    else:
        tl.store(p_lse + row_idx, lse, mask=row_idx < seqlen_q)

733 

734 

@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k"])
def flash_fwd_bh_parallel_kernel():
    # TODO: placeholder only - this kernel has no parameters and no body yet.
    pass

739 

740 

def flash_fwd_splitkv_kernel_heur_block_k(args):
    """Heuristic for the split-KV BLOCK_K: ``args["d"]`` rounded up to a power of two."""
    head_dim = args["d"]
    return triton.next_power_of_2(head_dim)

743 

744 

@libentry()
@triton.heuristics(
    values={
        "BLOCK_M": block_m_splitkv_heuristic_spec_args,
        "BLOCK_N": block_n_splitkv_heuristic_spec_args,
        "BLOCK_K": flash_fwd_splitkv_kernel_heur_block_k,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
        "PRE_LOAD_V": lambda args: True,
        "IS_EVEN_MN": is_even_mn_spec_args,
    }
)
@triton.jit(
    do_not_specialize=["seqlen_q", "seqlen_k", "seqlen_q_rounded", "seqlen_k_rounded"]
)
def flash_fwd_splitkv_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k: tl.constexpr,
    cu_seqlens_k_ptr,
    is_seqused_k: tl.constexpr,
    seqused_k_ptr,
    # sizes
    b: tl.constexpr,
    bk: tl.constexpr,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    is_paged: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    # kernel params
    IS_EVEN_MN: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    blocks_per_split: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    """Split-KV flash-attention forward: one program per (query block, split, batch * head).

    Grid axis 0 selects the query row block, axis 1 the KV split and axis 2 is
    ``bid * h + hid``. Each program attends over the key blocks
    ``[split_id * blocks_per_split, (split_id + 1) * blocks_per_split)`` and
    writes a locally normalized partial output plus its log-sum-exp:

      * ``o_ptr`` is indexed as (n_splits, batch * heads, seqlen_q, d)
      * ``softmax_lse_ptr`` is indexed as (n_splits, batch * heads, seqlen_q)

    The partial results are meant to be merged by a separate combine step.

    NOTE(review): dropout, local-window, cu_seqlens / seqused and page-table
    arguments as well as the ``v_*_stride`` arguments are accepted but not read
    here; V is addressed with the K strides - confirm with the caller.
    """
    m_block = tl.program_id(0)
    split_id = tl.program_id(1)
    bid = tl.program_id(2) // h
    hid = tl.program_id(2) % h

    split_block_min = split_id * blocks_per_split
    split_block_max = split_block_min + blocks_per_split

    n_block_max = tl.cdiv(seqlen_k, BLOCK_N)
    if is_causal:
        # Key blocks right of the causal boundary never contribute.
        n_block_max = min(
            n_block_max,
            tl.cdiv(
                (m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + window_size_right,
                BLOCK_N,
            ),
        )

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        # Scores are multiplied by scale_softmax later, so pre-divide the slope.
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0

    # First key block (inclusive) that needs masking.
    if not is_causal:
        if IS_EVEN_MN:
            masking_block_min = n_block_max
        else:
            masking_block_min = n_block_max - 1
    elif is_causal and IS_EVEN_MN:  # causal implies window_size_right is zero
        masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N)
    else:
        masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N) - 1

    q_ptr += bid * q_batch_stride
    q_ptr += hid * q_head_stride
    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
    q_off = row_idx[:, None] * q_row_stride + tl.arange(0, BLOCK_K)[None, :]
    p_qm = q_ptr + q_off
    dmask = tl.arange(0, BLOCK_K) < d
    qmask = dmask[None, :] & (row_idx[:, None] < seqlen_q)
    # `and` is required here: `IS_EVEN_MN & BLOCK_K == d` parses as
    # `(IS_EVEN_MN & BLOCK_K) == d`, which never selects the unmasked path.
    if IS_EVEN_MN and (BLOCK_K == d):
        Q = tl.load(p_qm, cache_modifier=".cg")
    else:
        # other=0.0 keeps padded lanes from contributing to the dot products.
        Q = tl.load(p_qm, mask=qmask, other=0.0, cache_modifier=".cg")

    h_hk_ratio = h // hk
    k_ptr += bid * k_batch_stride
    k_ptr += (hid // h_hk_ratio) * k_head_stride
    v_ptr += bid * k_batch_stride
    v_ptr += (hid // h_hk_ratio) * k_head_stride

    # K is addressed transposed (BLOCK_K x BLOCK_N), V as BLOCK_N x BLOCK_K.
    k_offset = (
        tl.arange(0, BLOCK_N)[None, :] * k_row_stride + tl.arange(0, BLOCK_K)[:, None]
    )
    p_k0 = k_ptr + k_offset

    v_offset = (
        tl.arange(0, BLOCK_N)[:, None] * k_row_stride + tl.arange(0, BLOCK_K)[None, :]
    )
    p_v0 = v_ptr + v_offset

    acc_ = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    if split_block_max <= masking_block_min:
        # no masking needed
        for n_block in tl.range(
            split_block_min, split_block_max, num_stages=num_stages
        ):
            kv_off = n_block * BLOCK_N * k_row_stride
            if d == BLOCK_K:
                K = tl.load(p_k0 + kv_off, cache_modifier=".cg")
            else:
                K = tl.load(
                    p_k0 + kv_off, mask=dmask[:, None], cache_modifier=".cg", other=0.0
                )
            if PRE_LOAD_V:
                if d == BLOCK_K:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
                else:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :],
                        cache_modifier=".cg",
                        other=0.0,
                    )
            S = tl.dot(Q, K)
            S = apply_softcap(S, softcap, is_softcap)
            col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
            row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=False,
            )

            if not PRE_LOAD_V:
                if d == BLOCK_K:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
                else:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :],
                        cache_modifier=".cg",
                        other=0.0,
                    )
            P = P.to(v_ptr.type.element_ty)
            acc_ = tl.dot(P, V, acc_)
    else:
        # This split touches blocks that need causal / boundary masking.
        for n_block in tl.range(split_block_min, min(split_block_max, n_block_max)):
            kv_off = n_block * BLOCK_N * k_row_stride
            col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
            row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
            if IS_EVEN_MN and (d == BLOCK_K):
                K = tl.load(p_k0 + kv_off, cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
            elif d == BLOCK_K:
                kvmask = col_idx < seqlen_k
                K = tl.load(
                    p_k0 + kv_off,
                    mask=kvmask[None, :],
                    cache_modifier=".cg",
                    other=0.0,
                )
                if PRE_LOAD_V:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=kvmask[:, None],
                        cache_modifier=".cg",
                        other=0.0,
                    )
            else:
                kvmask = col_idx < seqlen_k
                K = tl.load(
                    p_k0 + kv_off,
                    mask=dmask[:, None] & kvmask[None, :],
                    cache_modifier=".cg",
                    other=0.0,
                )
                if PRE_LOAD_V:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :] & kvmask[:, None],
                        cache_modifier=".cg",
                        other=0.0,
                    )

            S = tl.dot(Q, K)
            S = apply_softcap(S, softcap, is_softcap)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            S = apply_mask(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                window_size_left,
                window_size_right,
                is_even_mn=IS_EVEN_MN,
                is_causal=is_causal,
                is_local=False,
            )

            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=(is_causal or is_local),
            )

            if not PRE_LOAD_V:
                if IS_EVEN_MN and (d == BLOCK_K):
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
                elif d == BLOCK_K:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=kvmask[:, None],
                        cache_modifier=".cg",
                        other=0.0,
                    )
                else:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :] & kvmask[:, None],
                        cache_modifier=".cg",
                        other=0.0,
                    )
            P = P.to(v_ptr.type.element_ty)
            acc_ = tl.dot(P, V, acc_)

    # LSE
    # Each comparison is parenthesized: `|` binds tighter than `==`, so the
    # unparenthesized form would not detect NaN row sums.
    degenerate = (rowsum_ == 0) | (rowsum_ != rowsum_)
    lse = tl.where(
        degenerate,
        float("-inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where(degenerate, 1.0, 1.0 / rowsum_)

    # Rescale output
    acc_ *= inv_sum[:, None]

    # Write back output
    # o_splits layout = (n_splits, batch_size, num_heads, seqlen_q, head_size)
    # grid = (seq_block, split, batch * head)
    o_split_ptr = o_ptr
    # + split, batch, head offsets, seq_block offsets are already added in row_idx
    o_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q * d
    o_split_offset = row_idx[:, None] * d + tl.arange(0, BLOCK_K)
    o_split_ptr = tl.multiple_of(o_split_ptr, d)
    p_om = o_split_ptr + o_split_offset

    if IS_EVEN_MN and (BLOCK_K == d):
        tl.store(p_om, acc_, cache_modifier=".cg")
    else:
        tl.store(p_om, acc_, mask=qmask, cache_modifier=".cg")

    # Write back lse
    # lse_splits layout = (n_splits, batch_size, num_heads, seqlen_q)
    lse_split_ptr = softmax_lse_ptr
    # + split, batch, head, seq_block offsets
    lse_split_ptr += (
        split_id * tl.num_programs(2) + tl.program_id(2)
    ) * seqlen_q + m_block * BLOCK_M

    if IS_EVEN_MN:
        tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse, cache_modifier=".cg")
    else:
        tl.store(
            lse_split_ptr + tl.arange(0, BLOCK_M),
            lse,
            mask=row_idx < seqlen_q,
            cache_modifier=".cg",
        )

1085 

1086 

@libentry()
@triton.jit
def flash_fwd_splitkv_combine_kernel(
    out_ptr,
    lse_ptr,
    out_splits_ptr,
    lse_splits_ptr,
    head_size: tl.constexpr,
    out_split_stride,
    lse_split_stride,
    out_b_stride,
    out_s_stride,
    out_h_stride,
    n_splits,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    q_total,
    MAX_N_SPLITS: tl.constexpr,
):
    """Merge per-split partial outputs and LSEs into the final result.

    Each program handles ``BLOCK_M`` consecutive rows out of ``q_total``. For
    every row the split LSEs are reduced with a max-shifted log-sum-exp, the
    per-split outputs are averaged with weights ``exp(lse_i - lse)``, and both
    the combined LSE and the combined output are stored.

    NOTE(review): ``out_b_stride`` and ``out_h_stride`` are not read here.
    """
    pid = tl.program_id(0)
    row_base = pid * BLOCK_M
    lse_splits_ptr += row_base
    lse_ptr += row_base
    out_splits_ptr += row_base * head_size
    out_ptr += row_base * head_size

    m_offs = tl.arange(0, BLOCK_M)
    s_offs = tl.arange(0, MAX_N_SPLITS)
    k_offs = tl.arange(0, BLOCK_K)

    # Load the (rows x splits) LSE tile; padded entries read as -inf so they
    # contribute nothing to the reduction.
    lse_tile_off = m_offs[:, None] + s_offs[None, :] * lse_split_stride
    lse_tile_mask = (row_base + m_offs[:, None] < q_total) & (
        s_offs[None, :] < n_splits
    )
    lse_splits = tl.load(
        lse_splits_ptr + lse_tile_off, mask=lse_tile_mask, other=float("-inf")
    )
    # Subtract the per-row maximum before exponentiating, for numerical stability.
    max_lse = tl.max(lse_splits, 1)

    # Sum exp(lse_i - max_lse) over the splits: this is Z = sumexp(QK) up to
    # the factor exp(-max_lse).
    Zi_scaled = tl.exp(lse_splits - max_lse[:, None])
    Z_scaled = tl.sum(Zi_scaled, 1)
    Zi_Z = Zi_scaled / Z_scaled[:, None]

    # Write back LSE
    lse = tl.log(Z_scaled) + max_lse
    out_mask = row_base + m_offs < q_total
    tl.store(lse_ptr + m_offs, lse, mask=out_mask)

    # Load the (rows x splits x head) tile of partial outputs.
    out_tile_off = (
        m_offs[:, None, None] * head_size
        + s_offs[None, :, None] * out_split_stride
        + k_offs[None, None, :]
    )
    out_tile_mask = (
        (row_base + m_offs[:, None, None] < q_total)
        & (s_offs[None, :, None] < n_splits)
        & (k_offs[None, None, :] < head_size)
    )
    out_splits = tl.load(
        out_splits_ptr + out_tile_off, mask=out_tile_mask, other=0.0
    )
    # Weighted sum over the split axis.
    out = tl.sum(Zi_Z[:, :, None] * out_splits, 1)
    out = out.to(out_ptr.type.element_ty)

    # Write back output
    out_offset = m_offs[:, None] * out_s_stride + k_offs
    dmask = k_offs < head_size
    tl.store(out_ptr + out_offset, out, mask=out_mask[:, None] & dmask[None, :])

1155 

1156 

@triton.jit
def virtual_to_cache(
    virtual_index,
    max_virtual_index,
    page_table_ptr,
    block_size,
    boundary_check: tl.constexpr = False,
):
    """Translate KV sequence indices into indices inside the paged cache.

    ``virtual_index`` is the KV position within the current batch element and
    ``page_table_ptr`` must already point at that element's page-table row.
    Each index is split into a page number (``// block_size``) and an in-page
    offset (``% block_size``); the page number is looked up in the page table
    and the result is ``physical_page * block_size + offset``.

    With ``boundary_check`` set, page-table reads for positions at or beyond
    ``max_virtual_index`` are masked and yield page 0.
    """
    logical_page = virtual_index // block_size
    within_page = virtual_index % block_size
    entry_ptr = page_table_ptr + logical_page

    if boundary_check:
        physical_page = tl.load(
            entry_ptr,
            mask=virtual_index < max_virtual_index,
            other=0,
        )
    else:
        physical_page = tl.load(entry_ptr)
    physical_page = physical_page.to(tl.int32)

    return physical_page * block_size + within_page

1179 

1180 

@triton.jit
def load_from_kvcache(
    virtual_index,
    max_virtual_index,
    page_table_ptr,
    k_ptr_base,
    v_ptr_base,
    block_size,
    d: tl.constexpr,
    k_row_stride,
    BLOCK_K: tl.constexpr,
    boundary_check: tl.constexpr = False,
):
    """Gather one K tile and one V tile from the paged KV cache.

    The KV positions in ``virtual_index`` are mapped to cache rows with
    ``virtual_to_cache``. K is returned with the head dimension first
    (``BLOCK_K`` by number of positions) and V with the positions first, both
    addressed with ``k_row_stride``. Positions at or beyond
    ``max_virtual_index`` are read as 0.0, and when ``d != BLOCK_K`` the
    head-dimension lanes at or beyond ``d`` are read as 0.0 as well.

    Returns:
        Tuple ``(bK, bV)``.
    """
    cache_row = virtual_to_cache(
        virtual_index, max_virtual_index, page_table_ptr, block_size, boundary_check
    )
    dim_idx = tl.arange(0, BLOCK_K)

    # K is laid out (head_dim, position); V is (position, head_dim).
    k_offset = dim_idx[:, None] + cache_row[None, :] * k_row_stride
    v_offset = dim_idx[None, :] + cache_row[:, None] * k_row_stride

    if d == BLOCK_K:
        # Head dimension fills the tile exactly: only the sequence bound applies.
        k_mask = virtual_index[None, :] < max_virtual_index[None, :]
        v_mask = virtual_index[:, None] < max_virtual_index[:, None]
    else:
        # Padded head dimension: also mask off lanes at or beyond d.
        k_mask = (dim_idx[:, None] < d) & (
            virtual_index[None, :] < max_virtual_index[None, :]
        )
        v_mask = (dim_idx[None, :] < d) & (
            virtual_index[:, None] < max_virtual_index[:, None]
        )

    bK = tl.load(k_ptr_base + k_offset, mask=k_mask, other=0.0)
    bV = tl.load(v_ptr_base + v_offset, mask=v_mask, other=0.0)
    return bK, bV

1214 

1215 

@libentry()
@triton.jit(
    do_not_specialize=[
        "q_batch_stride",
        "k_batch_stride",
        "v_batch_stride",
        "o_batch_stride",
        "b",
        "bk",
        "seqlen_q",
        "seqlen_k",
        "seqlen_q_rounded",
        "seqlen_k_rounded",
        "total_q",
    ]
)
def flash_varlen_fwd_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q: tl.constexpr,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k: tl.constexpr,
    cu_seqlens_k_ptr,
    is_seqused_k: tl.constexpr,
    seqused_k_ptr,
    # sizes
    b,
    bk,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    is_paged: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    # kernel params
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    """Variable-length attention forward pass for one (query block, batch, head).

    Grid axes: ``program_id(0)`` is the query block index, ``program_id(1)``
    the batch element and ``program_id(2)`` the query head.

    Sequence bounds:
      * Q/O rows start at ``cu_seqlens_q[bid]`` when ``is_cu_seqlens_q`` is
        set, otherwise at ``bid * q_batch_stride`` / ``bid * o_batch_stride``.
      * The K length is ``cu_seqlens_k[bid + 1] - cu_seqlens_k[bid]`` when
        ``is_cu_seqlens_k`` is set, otherwise ``seqlen_k``; ``seqused_k[bid]``
        overrides it when ``is_seqused_k`` is set.

    K/V access:
      * ``is_paged``: rows are gathered through ``page_table_ptr`` with
        ``load_from_kvcache``.
      * otherwise: rows are read contiguously starting at
        ``k_bos * k_row_stride``. ``k_bos`` is only defined when
        ``is_cu_seqlens_k`` is set, so the non-paged path requires it.
      * V is addressed with ``k_head_stride`` and ``k_row_stride``;
        ``v_head_stride`` and ``v_row_stride`` are not read.

    KV blocks are visited from last to first. The first ``n_masking_steps``
    iterations use bounds-checked loads and the full causal/local mask; the
    remaining blocks use unchecked loads and only the local-window mask.

    Outputs: the normalised tile is stored to ``o_ptr`` and the per-row
    log-sum-exp to ``softmax_lse_ptr``, indexed as ``[h, total_q]``. Rows
    whose softmax denominator is zero or NaN get ``lse = +inf`` and are left
    unscaled.

    NOTE(review): ``rp_dropout`` is accepted but never applied to the
    accumulator here -- confirm whether dropout rescaling happens elsewhere.
    """
    m_block = tl.program_id(0)
    bid = tl.program_id(1)
    hid = tl.program_id(2)
    # num_m_blocks = tl.cdiv(seqlen_q, BLOCK_M)

    if is_cu_seqlens_q:
        q_eos = tl.load(cu_seqlens_q_ptr + bid + 1).to(tl.int32)
        q_bos = tl.load(cu_seqlens_q_ptr + bid).to(tl.int32)
        q_len = q_eos - q_bos
        # Current request's start offset in the batched Q
        q_offset = q_bos * q_row_stride
        o_offset = q_bos * o_row_stride
        lse_offset = q_bos * 1
    else:
        q_len = seqlen_q
        q_offset = bid * q_batch_stride
        o_offset = bid * o_batch_stride
        lse_offset = bid * seqlen_q

    if is_cu_seqlens_k:
        k_eos = tl.load(cu_seqlens_k_ptr + bid + 1).to(tl.int32)
        k_bos = tl.load(cu_seqlens_k_ptr + bid).to(tl.int32)
        k_len_cache = k_eos - k_bos
        # k_offset = k_bos * k_row_stride
    else:
        k_len_cache = seqlen_k
        # k_offset = bid * k_batch_stride

    if is_seqused_k:
        k_len = tl.load(seqused_k_ptr + bid).to(tl.int32)
    else:
        k_len = k_len_cache

    # Noop CTA: a block whose first row is at or past q_len has no valid rows.
    # (`>=` rather than `>`: the block starting exactly at q_len would only
    # issue fully masked loads and stores.)
    if m_block * BLOCK_M >= q_len:
        return

    # is_even_mn = (q_len % BLOCK_M == 0) and (k_len % BLOCK_N == 0)
    is_even_mn: tl.constexpr = False

    if is_local:
        n_block_min = max(
            0, (m_block * BLOCK_M + k_len - q_len - window_size_left) // BLOCK_N
        )
    else:
        n_block_min = 0

    n_block_max = tl.cdiv(k_len, BLOCK_N)
    if is_causal or is_local:
        n_block_max = min(
            n_block_max,
            tl.cdiv(
                (m_block + 1) * BLOCK_M + k_len - q_len + window_size_right, BLOCK_N
            ),
        )

    if is_dropout:
        philox_seed = tl.load(philox_args).to(tl.uint64)
        philox_offset = tl.load(philox_args + 1).to(tl.uint64)

    # Locate the page table entry for the current batch element
    if is_paged:
        page_table_ptr += bid * page_table_batch_stride
    # Calculate the starting offset of q for the current head
    q_row_offset = hid * q_head_stride
    # Calculate the starting offset of k and v for the current head
    k_row_offset = (hid // h_hk_ratio) * k_head_stride
    # Shift the k, v pointers to align with the current head
    k_ptr_base = k_ptr + k_row_offset
    v_ptr_base = v_ptr + k_row_offset

    gQ = tl.make_block_ptr(
        base=q_ptr + q_offset + q_row_offset,
        shape=(q_len, d),
        strides=(q_row_stride, 1),
        offsets=(0, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )
    bQ = tl.load(gQ.advance([m_block * BLOCK_M, 0]), boundary_check=(0, 1))

    acc_ = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0.0

    if not is_causal and not is_local:
        n_masking_steps = 1
    elif is_even_mn:
        n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N)
    else:
        n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N) + 1

    n_masking_steps = min(n_block_max - n_block_min, n_masking_steps)

    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)

    # Phase 1: trailing KV blocks that need bounds checks and full masking.
    n_block = n_block_max - 1
    for step in tl.range(0, n_masking_steps):
        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
        if is_paged:
            bK, bV = load_from_kvcache(
                col_idx,
                k_len,
                page_table_ptr,
                k_ptr_base,
                v_ptr_base,
                block_size,
                d,
                k_row_stride,
                BLOCK_K=BLOCK_K,
                boundary_check=True,
            )
        else:
            start_n = n_block * BLOCK_N
            k_ptr_seq = k_ptr_base + k_bos * k_row_stride
            v_ptr_seq = v_ptr_base + k_bos * k_row_stride
            gK = tl.make_block_ptr(
                base=k_ptr_seq,
                shape=(k_len, d),
                strides=(k_row_stride, 1),
                offsets=(start_n, 0),
                block_shape=(BLOCK_N, BLOCK_K),
                order=(0, 1),
            )
            gV = tl.make_block_ptr(
                base=v_ptr_seq,
                shape=(k_len, d),
                strides=(k_row_stride, 1),
                offsets=(start_n, 0),
                block_shape=(BLOCK_N, BLOCK_K),
                order=(0, 1),
            )
            bK = tl.load(gK, boundary_check=(0, 1))
            bK = tl.trans(bK)
            bV = tl.load(gV, boundary_check=(0, 1))
        S = tl.dot(bQ, bK, out_dtype=tl.float32)
        S = apply_softcap(S, softcap, is_softcap)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            window_size_left,
            window_size_right,
            is_even_mn=is_even_mn,
            is_causal=is_causal,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=True,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            # apply_dropout takes (row_start, col_start): the query-row origin
            # comes first, matching the call in the second loop below.
            P = apply_dropout(
                P,
                m_block * BLOCK_M,
                n_block * BLOCK_N,
                k_len,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )

        acc_ = tl.dot(P, bV, acc_)
        n_block -= 1

    # Phase 2: remaining KV blocks, walked backwards down to n_block_min.
    for n_block in tl.range(
        n_block_max - n_masking_steps - 1, n_block_min - 1, step=-1
    ):
        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
        if is_paged:
            bK, bV = load_from_kvcache(
                col_idx,
                k_len,
                page_table_ptr,
                k_ptr_base,
                v_ptr_base,
                block_size,
                d,
                k_row_stride,
                BLOCK_K=BLOCK_K,
            )
        else:
            start_n = n_block * BLOCK_N
            k_ptr_seq = k_ptr_base + k_bos * k_row_stride
            v_ptr_seq = v_ptr_base + k_bos * k_row_stride
            gK = tl.make_block_ptr(
                base=k_ptr_seq,
                shape=(k_len, d),
                strides=(k_row_stride, 1),
                offsets=(start_n, 0),
                block_shape=(BLOCK_N, BLOCK_K),
                order=(0, 1),
            )
            gV = tl.make_block_ptr(
                base=v_ptr_seq,
                shape=(k_len, d),
                strides=(k_row_stride, 1),
                offsets=(start_n, 0),
                block_shape=(BLOCK_N, BLOCK_K),
                order=(0, 1),
            )
            bK = tl.load(gK)
            bK = tl.trans(bK)
            bV = tl.load(gV)
        S = tl.dot(bQ, bK, out_dtype=tl.float32)
        S = apply_softcap(S, softcap, is_softcap)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            window_size_left,
            window_size_right,
            is_even_mn=True,
            is_causal=False,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=is_local,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            P = apply_dropout(
                P,
                m_block * BLOCK_M,
                n_block * BLOCK_N,
                k_len,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )
        acc_ = tl.dot(P, bV, acc_)

    # LSE
    # Rows with an empty or NaN denominator. The parentheses are required:
    # `|` binds tighter than `==`, so `rowsum_ == 0 | (...)` would compare
    # rowsum_ against the NaN flag and never catch the NaN case.
    degenerate_row = (rowsum_ == 0) | (rowsum_ != rowsum_)
    lse = tl.where(
        degenerate_row,
        float("inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where(degenerate_row, 1.0, 1.0 / rowsum_)

    acc_ *= inv_sum[:, None]

    out = acc_.to(o_ptr.type.element_ty)  # noqa

    # Write back output
    o_row_offset = hid * o_head_stride

    gO = tl.make_block_ptr(
        base=o_ptr + o_offset + o_row_offset,
        shape=(q_len, d),
        strides=(o_row_stride, 1),
        offsets=(0, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )
    tl.store(gO.advance([m_block * BLOCK_M, 0]), out, boundary_check=(0, 1))

    # Write back lse
    # lse shape: [h, total_q]
    softmax_lse_ptr += hid * total_q
    lse_row_offset = lse_offset + m_block * BLOCK_M + tl.arange(0, BLOCK_M)
    tl.store(
        softmax_lse_ptr + lse_row_offset,
        lse,
        mask=lse_row_offset < (lse_offset + q_len),
    )