Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/flash_kernel.py: 0%

521 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import triton 

2import triton.language as tl 

3 

4from flag_gems import runtime 

5from flag_gems.utils import libentry, libtuner, tl_extra_shim 

6 

7 

@triton.jit
def u64_to_lohi(x):
    """Split a 64-bit value into two 32-bit words.

    Despite the name, the word built from the upper 32 bits is returned
    first and the word built from the lower 32 bits second.
    """
    upper = (x >> 32).to(tl.uint32)
    lower = (x & 0xFFFFFFFF).to(tl.uint32)
    return upper, lower

11 

12 

@triton.jit
def u64_from_lohi(lo, hi):
    """Combine two 32-bit words into one 64-bit value.

    Args:
        lo: word that becomes the low 32 bits of the result.
        hi: word that becomes the high 32 bits of the result.

    Returns:
        ``(hi << 32) + lo`` computed in ``tl.uint64``.
    """
    # The shift must be parenthesized: `+` binds tighter than `<<`, so the
    # unparenthesized form would evaluate `hi << (32 + lo)`.
    return (hi.to(tl.uint64) << 32) + lo.to(tl.uint64)

16 

17 

@triton.jit
def philox_(seed, subsequence, offset):
    """Produce four 32-bit words from a Philox-style counter-based generator.

    ``seed`` supplies the two key words, ``offset`` the first two counter
    words and ``subsequence`` the last two. Seven rounds are applied: six
    inside the static loop (each followed by a key bump) and a final one
    without a key bump.

    Returns:
        The four counter words ``(c0, c1, c2, c3)`` after the last round.
    """
    # Key-bump constants.
    kPhilox10A: tl.constexpr = 0x9E3779B9
    kPhilox10B: tl.constexpr = 0xBB67AE85
    # Round multipliers.
    kPhiloxSA: tl.constexpr = 0xD2511F53
    kPhiloxSB: tl.constexpr = 0xCD9E8D57

    k0, k1 = u64_to_lohi(seed.to(tl.uint64))
    c0, c1 = u64_to_lohi(offset.to(tl.uint64))
    c2, c3 = u64_to_lohi(subsequence.to(tl.uint64))

    # pragma unroll
    for _ in tl.static_range(6):
        prod0_first, prod0_second = u64_to_lohi(kPhiloxSA * c0.to(tl.uint64))
        prod1_first, prod1_second = u64_to_lohi(kPhiloxSB * c2.to(tl.uint64))
        c0, c1, c2, c3 = (
            prod1_second ^ c1 ^ k0,
            prod1_first,
            prod0_second ^ c3 ^ k1,
            prod0_first,
        )
        k0 += kPhilox10A
        k1 += kPhilox10B

    # Final round: same mixing step, keys are not advanced afterwards.
    prod0_first, prod0_second = u64_to_lohi(kPhiloxSA * c0.to(tl.uint64))
    prod1_first, prod1_second = u64_to_lohi(kPhiloxSB * c2.to(tl.uint64))
    c0, c1, c2, c3 = (
        prod1_second ^ c1 ^ k0,
        prod1_first,
        prod0_second ^ c3 ^ k1,
        prod0_first,
    )

    return c0, c1, c2, c3

45 

46 

@triton.jit
def apply_dropout_mask(
    P,
    mask,
    encode_dropout_in_sign_bit: tl.constexpr,
):
    """Replace the entries of ``P`` selected by ``mask``.

    Where ``mask`` is true the entry becomes ``-P`` (sign-bit encoding) or a
    zero of ``P``'s dtype; elsewhere ``P`` is returned unchanged.
    """
    if encode_dropout_in_sign_bit:
        replacement = -P
    else:
        replacement = (P * 0).to(P.dtype)
    return tl.where(mask, replacement, P)

58 

59 

@triton.jit
def apply_dropout(
    P,
    row_start,
    col_start,
    n_cols,
    bid,
    hid,
    philox_seed,
    philox_offset,
    p_dropout_uint8: tl.constexpr,
    is_dropout: tl.constexpr,
    encode_dropout_in_sign_bit: tl.constexpr,
    NUM_HEADS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Apply a Philox-derived dropout mask to a (BLOCK_M, BLOCK_N) tile ``P``.

    When ``is_dropout`` is false ``P`` is returned untouched. Otherwise one
    Philox call per (row, group of 4 columns) yields four random words; the
    low byte of each word is compared with ``p_dropout_uint8`` and entries
    whose byte is ``>=`` that threshold are handed to ``apply_dropout_mask``.
    """
    if is_dropout:
        aligned_row = tl.multiple_of(row_start, BLOCK_M)
        aligned_col = tl.multiple_of(col_start, BLOCK_N)
        rows = aligned_row + tl.arange(0, BLOCK_M)[:, None]
        # Column index is scaled down by 4: each Philox call yields 4 words.
        col_groups = aligned_col // 4 + tl.arange(0, BLOCK_N // 4)[None, :]

        subsequence = rows.to(tl.uint64) * n_cols + col_groups.to(tl.uint64)

        offset = philox_offset + bid * NUM_HEADS + hid
        # Adding a zero tensor broadcasts the scalar offset to the tile shape.
        offset = offset + subsequence * 0
        w0, w1, w2, w3 = philox_(philox_seed, subsequence, offset)

        rand_words = tl.join(tl.join(w0, w1), tl.join(w2, w3)).reshape(
            BLOCK_M, BLOCK_N
        )

        drop = (rand_words & 0xFF) >= p_dropout_uint8

        P = apply_dropout_mask(
            P, drop, encode_dropout_in_sign_bit=encode_dropout_in_sign_bit
        )
    return P

98 

99 

@triton.jit
def apply_alibi(
    S,
    col_idx,
    row_idx,
    max_seqlen_q,
    max_seqlen_k,
    is_causal: tl.constexpr,
    is_alibi: tl.constexpr,
    alibi_slope: tl.constexpr = None,
):
    """Add an ALiBi bias to the score tile ``S`` when ``is_alibi`` is set.

    The causal variant uses a bias that depends on the column only; the
    non-causal variant uses ``-slope * |col - seqlen_k + seqlen_q - row|``.
    ``S`` is returned unchanged when ``is_alibi`` is false.
    """
    if is_alibi:
        if is_causal:
            # The row independent alibi bias renders the same attention output
            # as with the standard alibi because softmax is shift invariant, i.e.,
            # softmax(A + bias + const) = softmax(A + bias). The following two
            # biases are no different if causal is true.
            # bias_1 = [
            #    -4, -3, -2, X, X,
            #    -4, -3, -2, -1, X,
            #    -4, -3, -2, -1, 0,
            # ]
            # bias_2 = [
            #    -2, -1, 0, X, X,
            #    -3, -2, -1, 0, X,
            #    -4, -3, -2, -1, 0,
            # ]
            col_term = (-max_seqlen_k + 1 + col_idx[None, :]).to(tl.float32)
            bias = alibi_slope * col_term
        else:
            distance = tl.abs(
                col_idx[None, :] - max_seqlen_k + max_seqlen_q - row_idx[:, None]
            ).to(tl.float32)
            bias = -alibi_slope * distance
        S += bias

    return S

136 

137 

@triton.jit
def apply_mask(
    S,
    col_idx,
    row_idx,
    max_seqlen_q,
    max_seqlen_k,
    window_size_left,
    window_size_right,
    is_even_mn: tl.constexpr,
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
):
    """Set masked-out entries of the score tile ``S`` to ``-inf``.

    Causal masking removes columns right of the per-row bound, local masking
    removes columns outside ``[left_bound, right_bound]``, and the plain
    uneven case removes columns at or beyond ``max_seqlen_k``. ``S`` is
    returned unchanged when none of the three applies.
    """
    need_mask = is_causal | is_local | (not is_even_mn)
    # need_mask: tl.constexpr = is_causal | is_local
    if need_mask:
        # Take care to avoid off-by-one errors: both bounds are inclusive!
        left_bound = max(
            0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left
        )
        right_bound = min(
            max_seqlen_k - 1, row_idx + max_seqlen_k - max_seqlen_q + window_size_right
        )

        if is_causal:
            beyond_right = col_idx[None, :] > right_bound[:, None]
            S = tl.where(beyond_right, float("-inf"), S)

        if is_local:
            outside_window = (col_idx[None, :] > right_bound[:, None]) | (
                col_idx[None, :] < left_bound[:, None]
            )
            S = tl.where(outside_window, float("-inf"), S)

        if (not is_local) & (not is_causal) & (not is_even_mn):
            past_end = col_idx[None, :] >= max_seqlen_k
            S = tl.where(past_end, float("-inf"), S)

    return S

175 

176 

@triton.jit
def softmax_rescale(
    O_acc,
    S,
    row_max,
    row_sum,
    softmax_scale_log2e: tl.constexpr,
    is_border: tl.constexpr,
    # is_init: tl.constexpr
):
    """Perform one online-softmax update for a new score tile ``S``.

    The running row maximum is refreshed, the previous row sum and output
    accumulator are rescaled to the new maximum, and the base-2 exponentials
    of ``S`` are produced.

    Returns:
        ``(O_acc, P, row_max, row_sum)`` with ``P`` the exponentiated tile.
    """
    old_max = row_max
    new_max = tl.maximum(row_max, tl.max(S, 1))

    if is_border:
        # A row that is still entirely -inf is rescaled against 0 instead.
        ref_max = tl.where(new_max == float("-inf"), 0, new_max)
    else:
        ref_max = new_max

    correction = tl.math.exp2((old_max - ref_max) * softmax_scale_log2e)
    row_sum *= correction
    O_acc *= correction[:, None]

    scaled_max = tl.where(
        new_max == float("-inf"), 0, new_max * softmax_scale_log2e
    )
    P = tl.math.exp2(S * softmax_scale_log2e - scaled_max[:, None])
    row_sum = row_sum + tl.sum(P, 1)
    return O_acc, P, new_max, row_sum

203 

204 

@triton.jit
def apply_softcap(S, softcap, is_softcap: tl.constexpr):
    """Return ``tanh(S * softcap)`` when ``is_softcap`` is set, else ``S``."""
    if is_softcap:
        scaled = S * softcap
        S = tl_extra_shim.tanh(scaled)

    return S

211 

212 

def block_m_splitkv_heuristic(headdim):
    """Pick BLOCK_M for the split-KV kernel: 128 up to head dim 128, else 64."""
    if headdim <= 128:
        return 128
    return 64

215 

216 

def block_n_splitkv_heuristic(headdim):
    """Pick BLOCK_N for the split-KV kernel: 64 up to head dim 64, else 32."""
    if headdim <= 64:
        return 64
    return 32

219 

220 

def is_even_mn(M, N, BM, BN, WL, WR):
    """Tell whether the problem tiles evenly.

    True only when ``M`` is a multiple of ``BM``, ``N`` is a multiple of
    ``BN``, one of ``M``/``N`` divides the other, and each window size is
    either -1 or a multiple of ``BN``.
    """
    if M % BM != 0 or N % BN != 0:
        return False
    if M % N != 0 and N % M != 0:
        return False
    for window in (WL, WR):
        # -1 must be tested explicitly: -1 % BN is not 0 in Python.
        if window != -1 and window % BN != 0:
            return False
    return True

227 

228 

def block_m_splitkv_heuristic_spec_args(args):
    """Heuristic form of the BLOCK_M choice, reading head dim from ``args["d"]``."""
    if args["d"] <= 128:
        return 128
    return 64

231 

232 

def block_n_splitkv_heuristic_spec_args(args):
    """Heuristic form of the BLOCK_N choice, reading head dim from ``args["d"]``."""
    if args["d"] <= 64:
        return 64
    return 32

235 

236 

def is_even_mn_spec_args(args):
    """Dict-based variant of the even-tiling check.

    Reads ``seqlen_q``, ``seqlen_k``, ``BLOCK_M``, ``BLOCK_N``,
    ``window_size_left`` and ``window_size_right`` from ``args`` and returns
    True only when both sequence lengths are block multiples, one sequence
    length divides the other, and each window size is -1 or a multiple of
    ``BLOCK_N``.
    """
    if args["seqlen_q"] % args["BLOCK_M"] != 0:
        return False
    if args["seqlen_k"] % args["BLOCK_N"] != 0:
        return False
    if (
        args["seqlen_q"] % args["seqlen_k"] != 0
        and args["seqlen_k"] % args["seqlen_q"] != 0
    ):
        return False
    for window_key in ("window_size_left", "window_size_right"):
        if args[window_key] != -1 and args[window_key] % args["BLOCK_N"] != 0:
            return False
    return True

255 

256 

def keep(cfg):
    """Accept only the two whitelisted (BLOCK_M, BLOCK_N, num_warps) combos."""
    shape = (cfg.kwargs["BLOCK_M"], cfg.kwargs["BLOCK_N"], cfg.num_warps)
    return shape == (128, 32, 4) or shape == (128, 128, 8)

263 

264 

def prune_fwd_configs(configs, nargs, **kwargs):
    """Narrow the autotune candidates before tuning.

    With dropout enabled, only configs with 4 warps and fewer than 4 stages
    survive. Otherwise the candidates are restricted to a single BLOCK_M
    chosen from ``seqlen_q``; for ``seqlen_q`` below 64 the input is returned
    unchanged.
    """
    if nargs["is_dropout"]:
        return [
            cfg for cfg in configs if cfg.num_warps == 4 and cfg.num_stages < 4
        ]

    seqlen_q = nargs["seqlen_q"]
    # (minimum seqlen_q, required BLOCK_M), checked from largest to smallest.
    block_m_by_seqlen = ((1024, 512), (512, 256), (256, 128), (128, 64), (64, 32))
    for min_seqlen, block_m in block_m_by_seqlen:
        if seqlen_q >= min_seqlen:
            return [cfg for cfg in configs if cfg.kwargs["BLOCK_M"] == block_m]
    return configs

284 

285 

@libentry()
@libtuner(
    # configs=list(filter(keep, runtime.get_tuned_config("attention"))),
    configs=runtime.get_tuned_config("attention"),
    prune_configs_by={"early_config_prune": prune_fwd_configs},
    key=["seqlen_q", "d", "is_dropout"],
    strategy=[
        "align32",
        "align32",
        lambda a: a,
    ],
    warmup=1,
    rep=1,
)
@triton.heuristics(
    values={
        "PRE_LOAD_V": lambda args: False,
        "IS_EVEN_MN": lambda args: is_even_mn(
            args["seqlen_q"],
            args["seqlen_k"],
            args["BLOCK_M"],
            args["BLOCK_N"],
            args["window_size_left"],
            args["window_size_right"],
        ),
    }
)
@triton.jit(
    do_not_specialize=["seqlen_q", "seqlen_k", "seqlen_q_rounded", "seqlen_k_rounded"]
)
def flash_fwd_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k,
    cu_seqlens_k_ptr,
    is_seqused_k,
    seqused_k_ptr,
    # sizes
    b: tl.constexpr,
    bk: tl.constexpr,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q: tl.constexpr,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    # kernel params
    IS_EVEN_MN: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    """Attention forward kernel using an online softmax over K/V column blocks.

    Program axis 0 selects a block of BLOCK_M query rows; program axis 1
    encodes ``batch * h + head``. The kernel loads one Q tile, walks the K/V
    columns in steps of BLOCK_N, and writes the normalized output tile to
    ``o_ptr`` and the per-row log-sum-exp to ``softmax_lse_ptr``.

    Columns are processed in two loops: first the trailing "masking" blocks
    (right to left, with bounds/causal/local masks applied), then the
    remaining blocks starting at ``col_min`` with unmasked loads.

    When ``is_dropout`` is set, dropout is applied to the probabilities, and
    when ``return_softmax`` is also set a copy with dropped entries encoded
    in the sign bit is stored to ``p_ptr``.

    NOTE(review): V is addressed with ``k_batch_stride``/``k_head_stride``/
    ``k_row_stride``; the ``v_*`` strides are not read. This presumably
    assumes K and V share a layout - confirm against the caller.
    """
    m_block = tl.program_id(0)
    bh = tl.program_id(1)
    hid = bh % h
    bid = bh // h
    num_m_blocks = tl.cdiv(seqlen_q, BLOCK_M)

    # We draw a minimum covering frame on the attention map that this CTA is assigned to process.
    # The frame edges are rounded to multiples of BLOCK_M and BLOCK_N for rows and columns respectively.

    col_min = 0
    if is_local:
        col_min = max(0, m_block * BLOCK_M + seqlen_k - seqlen_q - window_size_left)
        if not IS_EVEN_MN:
            # round left
            col_min = (col_min // BLOCK_N) * BLOCK_N

    col_max = seqlen_k
    if is_causal or is_local:
        col_max += (m_block - num_m_blocks + 1) * BLOCK_M
        if is_local:
            col_max += window_size_right
        col_max = min(seqlen_k, col_max)

    if not IS_EVEN_MN:
        # round right
        col_max = tl.cdiv(col_max, BLOCK_N) * BLOCK_N

    # Number of trailing columns that must go through the masked loop.
    if (not is_causal) and (not is_local):
        if IS_EVEN_MN:
            masking_cols: tl.constexpr = 0
        else:
            masking_cols: tl.constexpr = BLOCK_N
    elif (
        is_causal | is_local
    ) and IS_EVEN_MN:  # causal implies window_size_right is zero
        masking_cols: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) * BLOCK_N
    else:
        # local
        masking_cols: tl.constexpr = (tl.cdiv(BLOCK_M, BLOCK_N) + 1) * BLOCK_N

    if is_dropout:
        philox_seed = tl.load(philox_args).to(tl.uint64)
        philox_offset = tl.load(philox_args + 1).to(tl.uint64)

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0.0

    # Load the Q tile for this row block.
    q_batch_stride = tl.multiple_of(q_batch_stride, d * h)
    q_ptr += bid * q_batch_stride + hid * q_head_stride
    row_start = m_block * BLOCK_M
    row_idx = row_start + tl.arange(0, BLOCK_M)
    q_off = row_idx[:, None] * q_row_stride + tl.arange(0, d)[None, :]
    qmask = row_idx[:, None] < seqlen_q
    if IS_EVEN_MN:
        Q = tl.load(q_ptr + q_off, cache_modifier=".cg")
    else:
        Q = tl.load(q_ptr + q_off, mask=qmask, cache_modifier=".cg")

    if return_softmax:
        p_ptr += (
            (bid * h + hid) * seqlen_q_rounded + m_block * BLOCK_M
        ) * seqlen_k_rounded
        p_offset = tl.arange(0, BLOCK_M)[:, None] * seqlen_k_rounded + tl.arange(
            0, BLOCK_N
        )
        p_bp0 = p_ptr + p_offset

    acc_ = tl.zeros((BLOCK_M, d), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    k_batch_stride = tl.multiple_of(k_batch_stride, d * hk)
    h_hk_ratio = h // hk
    k_ptr += bid * k_batch_stride
    k_ptr += (hid // h_hk_ratio) * k_head_stride
    v_ptr += bid * k_batch_stride
    v_ptr += (hid // h_hk_ratio) * k_head_stride

    # K is addressed transposed (d, BLOCK_N); V as (BLOCK_N, d).
    k_offset = tl.arange(0, BLOCK_N)[None, :] * k_row_stride + tl.arange(0, d)[:, None]
    v_offset = tl.arange(0, BLOCK_N)[:, None] * k_row_stride + tl.arange(0, d)[None, :]

    p_bk0 = k_ptr + k_offset
    p_bv0 = v_ptr + v_offset

    if is_causal | is_local | (not IS_EVEN_MN):
        # Cut short masking cols if there's not enough cols out there
        masking_cols = min(col_max - col_min, masking_cols)
        for col_shift in tl.range(0, masking_cols, step=BLOCK_N):
            col_start = col_max - col_shift - BLOCK_N
            col_start = tl.multiple_of(col_start, BLOCK_N)
            off = col_start * k_row_stride
            if IS_EVEN_MN:
                K = tl.load(p_bk0 + off, cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(p_bv0 + off, cache_modifier=".cg")
            else:
                col_idx = col_start + tl.arange(0, BLOCK_N)
                kvmask = col_idx < seqlen_k
                K = tl.load(p_bk0 + off, mask=kvmask[None, :], cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg")
            S = tl.dot(Q, K, allow_tf32=False)
            S = apply_softcap(S, softcap, is_softcap)
            col_idx = col_start + tl.arange(0, BLOCK_N)
            row_idx = row_start + tl.arange(0, BLOCK_M)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            # tl.store(p_bp0 + col_start, S)
            S = apply_mask(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                window_size_left,
                window_size_right,
                is_even_mn=IS_EVEN_MN,
                is_causal=is_causal,
                is_local=is_local,
            )

            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=(is_causal or is_local),
            )
            P = P.to(v_ptr.type.element_ty)

            if is_dropout:
                if return_softmax:
                    P_drop = P

                    # Sign-encoded copy of P is what gets written to p_ptr.
                    P_drop = apply_dropout(
                        P_drop,
                        row_start,
                        col_start,
                        seqlen_k,
                        bid,
                        hid,
                        philox_seed,
                        philox_offset,
                        p_dropout_in_uint8_t,
                        is_dropout,
                        encode_dropout_in_sign_bit=True,
                        NUM_HEADS=h,
                        BLOCK_M=BLOCK_M,
                        BLOCK_N=BLOCK_N,
                    )
                    if IS_EVEN_MN:
                        tl.store(p_bp0 + col_start, P_drop)
                    else:
                        kvmask = col_idx < seqlen_k
                        tl.store(
                            p_bp0 + col_start, P_drop, mask=qmask & kvmask[None, :]
                        )

                P = apply_dropout(
                    P,
                    row_start,
                    col_start,
                    seqlen_k,
                    bid,
                    hid,
                    philox_seed,
                    philox_offset,
                    p_dropout_in_uint8_t,
                    is_dropout,
                    encode_dropout_in_sign_bit=False,
                    NUM_HEADS=h,
                    BLOCK_M=BLOCK_M,
                    BLOCK_N=BLOCK_N,
                )

            if not PRE_LOAD_V:
                off = col_start * k_row_stride
                if IS_EVEN_MN:
                    V = tl.load(p_bv0 + off, cache_modifier=".cg")
                else:
                    kvmask = col_idx < seqlen_k
                    V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg")
            acc_ = tl.dot(P, V, acc_, allow_tf32=False)

    # Remaining column blocks: loads are unmasked here.
    for col_start in tl.range(
        col_min, col_max - masking_cols, step=BLOCK_N, num_stages=num_stages
    ):
        col_start = tl.multiple_of(col_start, BLOCK_N)
        off = col_start * k_row_stride
        K = tl.load(p_bk0 + off, cache_modifier=".cg")
        if PRE_LOAD_V:
            V = tl.load(p_bv0 + off, cache_modifier=".cg")
        S = tl.dot(Q, K)
        S = apply_softcap(S, softcap, is_softcap)
        col_idx = col_start + tl.arange(0, BLOCK_N)
        row_idx = row_start + tl.arange(0, BLOCK_M)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            seqlen_q,
            seqlen_k,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            seqlen_q,
            seqlen_k,
            window_size_left,
            window_size_right,
            is_even_mn=True,
            is_causal=False,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=is_local,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            if return_softmax:
                P_drop = P
                P_drop = apply_dropout(
                    P_drop,
                    row_start,
                    col_start,
                    seqlen_k,
                    bid,
                    hid,
                    philox_seed,
                    philox_offset,
                    p_dropout_in_uint8_t,
                    is_dropout,
                    encode_dropout_in_sign_bit=True,
                    NUM_HEADS=h,
                    BLOCK_M=BLOCK_M,
                    BLOCK_N=BLOCK_N,
                )
                if IS_EVEN_MN:
                    tl.store(p_bp0 + col_start, P_drop)
                else:
                    kvmask = col_idx < seqlen_k
                    tl.store(p_bp0 + col_start, P_drop, mask=qmask & kvmask[None, :])

            P = apply_dropout(
                P,
                row_start,
                col_start,
                seqlen_k,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )

        if not PRE_LOAD_V:
            off = col_start * k_row_stride
            V = tl.load(p_bv0 + off, cache_modifier=".cg")
        acc_ = tl.dot(P, V, acc_)

    # LSE
    # Note, rowsum = exp(-rowmax) * exp(lse), therefore rowmax + log(rowsum) cancels
    # the effect of rowmax and outputs lse only.
    # The comparison must be parenthesized: `|` binds tighter than `==`, so
    # `rowsum_ == 0 | (...)` would compare rowsum_ against `0 | (...)` and
    # never flag NaN row sums.
    invalid_sum = (rowsum_ == 0) | (rowsum_ != rowsum_)
    lse = tl.where(
        invalid_sum,
        float("inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where(invalid_sum, 1.0, 1.0 / rowsum_)

    if is_dropout:
        acc_ *= inv_sum[:, None] * rp_dropout
    else:
        acc_ *= inv_sum[:, None]

    out = acc_.to(o_ptr.type.element_ty)  # noqa

    # Write back output
    o_batch_stride = tl.multiple_of(o_batch_stride, d * h)
    o_ptr += bid * o_batch_stride
    o_ptr += hid * o_head_stride
    o_offset = row_idx[:, None] * o_row_stride + tl.arange(0, d)

    if IS_EVEN_MN:
        tl.store(o_ptr + o_offset, out)
    else:
        tl.store(o_ptr + o_offset, out, mask=qmask)

    # Write back lse
    p_lse = softmax_lse_ptr + (bid * h + hid) * seqlen_q
    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)

    if IS_EVEN_MN:
        tl.store(p_lse + row_idx, lse)
    else:
        tl.store(p_lse + row_idx, lse, mask=row_idx < seqlen_q)

713 

714 

@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k"])
def flash_fwd_bh_parallel_kernel():
    # Placeholder kernel: it takes no parameters and has an empty body.
    # It is not referenced anywhere in the visible part of this file.
    # (TODO)
    pass

719 

720 

@libentry()
@triton.heuristics(
    values={
        "BLOCK_M": block_m_splitkv_heuristic_spec_args,
        "BLOCK_N": block_n_splitkv_heuristic_spec_args,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
        "PRE_LOAD_V": lambda args: True,
        "IS_EVEN_MN": is_even_mn_spec_args,
    }
)
@triton.jit(
    do_not_specialize=["seqlen_q", "seqlen_k", "seqlen_q_rounded", "seqlen_k_rounded"]
)
def flash_fwd_splitkv_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k: tl.constexpr,
    cu_seqlens_k_ptr,
    is_seqused_k: tl.constexpr,
    seqused_k_ptr,
    # sizes
    b: tl.constexpr,
    bk: tl.constexpr,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    # kernel params
    IS_EVEN_MN: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    blocks_per_split: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    """Split-KV attention forward kernel: one partial result per KV split.

    Program axis 0 selects a block of BLOCK_M query rows, axis 1 the split
    index, and axis 2 encodes ``batch * h + head``. Each program processes
    the K/V blocks ``[split_id * blocks_per_split, ...)`` assigned to its
    split, and writes a normalized partial output and a partial
    log-sum-exp into per-split buffers addressed through ``o_ptr`` and
    ``softmax_lse_ptr``.

    Splits lying entirely before ``masking_block_min`` take an unmasked
    fast path; the others apply bounds/causal masking per block.

    NOTE(review): V is addressed with the ``k_*`` strides; the ``v_*``
    strides are not read. This presumably assumes K and V share a layout -
    confirm against the caller.
    """
    m_block = tl.program_id(0)
    split_id = tl.program_id(1)
    bid = tl.program_id(2) // h
    hid = tl.program_id(2) % h

    split_block_min = split_id * blocks_per_split
    split_block_max = split_block_min + blocks_per_split

    n_block_max = tl.cdiv(seqlen_k, BLOCK_N)
    if is_causal:
        n_block_max = min(
            n_block_max,
            tl.cdiv(
                (m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + window_size_right,
                BLOCK_N,
            ),
        )

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0

    # First K/V block index that needs masking.
    if not is_causal:
        if IS_EVEN_MN:
            masking_block_min = n_block_max
        else:
            masking_block_min = n_block_max - 1
    elif is_causal and IS_EVEN_MN:  # causal implies window_size_right is zero
        masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N)
    else:
        masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N) - 1

    q_ptr += bid * q_batch_stride
    q_ptr += hid * q_head_stride
    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
    q_off = row_idx[:, None] * q_row_stride + tl.arange(0, d)[None, :]
    p_qm = q_ptr + q_off
    qmask = row_idx[:, None] < seqlen_q
    if IS_EVEN_MN:
        Q = tl.load(p_qm, cache_modifier=".cg")
    else:
        Q = tl.load(p_qm, mask=qmask, cache_modifier=".cg")

    h_hk_ratio = h // hk
    k_ptr += bid * k_batch_stride
    k_ptr += (hid // h_hk_ratio) * k_head_stride
    v_ptr += bid * k_batch_stride
    v_ptr += (hid // h_hk_ratio) * k_head_stride

    # K is addressed transposed (d, BLOCK_N); V as (BLOCK_N, d).
    k_offset = tl.arange(0, BLOCK_N)[None, :] * k_row_stride + tl.arange(0, d)[:, None]
    p_k0 = k_ptr + k_offset

    v_offset = tl.arange(0, BLOCK_N)[:, None] * k_row_stride + tl.arange(0, d)[None, :]
    p_v0 = v_ptr + v_offset

    acc_ = tl.zeros((BLOCK_M, d), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    if split_block_max <= masking_block_min:
        # no masking needed
        for n_block in tl.range(
            split_block_min, split_block_max, num_stages=num_stages
        ):
            kv_off = n_block * BLOCK_N * k_row_stride
            K = tl.load(p_k0 + kv_off, cache_modifier=".cg")
            if PRE_LOAD_V:
                V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
            S = tl.dot(Q, K)
            S = apply_softcap(S, softcap, is_softcap)
            col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
            row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=False,
            )

            if not PRE_LOAD_V:
                V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
            P = P.to(v_ptr.type.element_ty)
            acc_ = tl.dot(P, V, acc_)
    else:
        for n_block in tl.range(split_block_min, min(split_block_max, n_block_max)):
            kv_off = n_block * BLOCK_N * k_row_stride
            col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
            row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
            if IS_EVEN_MN:
                K = tl.load(p_k0 + kv_off, cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
            else:
                kvmask = col_idx < seqlen_k
                K = tl.load(p_k0 + kv_off, mask=kvmask[None, :], cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(
                        p_v0 + kv_off, mask=kvmask[:, None], cache_modifier=".cg"
                    )

            S = tl.dot(Q, K)
            S = apply_softcap(S, softcap, is_softcap)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            S = apply_mask(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                window_size_left,
                window_size_right,
                is_even_mn=IS_EVEN_MN,
                is_causal=is_causal,
                is_local=False,
            )

            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=(is_causal or is_local),
            )

            if not PRE_LOAD_V:
                if IS_EVEN_MN:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
                else:
                    V = tl.load(
                        p_v0 + kv_off, mask=kvmask[:, None], cache_modifier=".cg"
                    )
            P = P.to(v_ptr.type.element_ty)
            acc_ = tl.dot(P, V, acc_)

    # LSE
    # The comparison must be parenthesized: `|` binds tighter than `==`, so
    # `rowsum_ == 0 | (...)` would compare rowsum_ against `0 | (...)` and
    # never flag NaN row sums.
    invalid_sum = (rowsum_ == 0) | (rowsum_ != rowsum_)
    lse = tl.where(
        invalid_sum,
        float("-inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where(invalid_sum, 1.0, 1.0 / rowsum_)

    # Rescale output
    acc_ *= inv_sum[:, None]

    # Write back output
    # o_splits layout = (n_splits, batch_size, num_heads, seqlen_q, head_size)
    # grid = (seq_block, split, batch * head)
    o_split_ptr = o_ptr
    # + split, batch, head offsets, seq_block offsets are already added in row_idx
    o_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q * d
    o_split_offset = row_idx[:, None] * d + tl.arange(0, d)
    o_split_ptr = tl.multiple_of(o_split_ptr, d)
    p_om = o_split_ptr + o_split_offset

    if IS_EVEN_MN:
        tl.store(p_om, acc_, cache_modifier=".cg")
    else:
        tl.store(p_om, acc_, mask=qmask, cache_modifier=".cg")

    # Write back lse
    # lse_splits layout = (n_splits, batch_size, num_heads, seqlen_q)
    lse_split_ptr = softmax_lse_ptr
    # + split, batch, head, seq_block offsets
    lse_split_ptr += (
        split_id * tl.num_programs(2) + tl.program_id(2)
    ) * seqlen_q + m_block * BLOCK_M

    if IS_EVEN_MN:
        tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse, cache_modifier=".cg")
    else:
        tl.store(
            lse_split_ptr + tl.arange(0, BLOCK_M),
            lse,
            mask=row_idx < seqlen_q,
            cache_modifier=".cg",
        )

1010 

1011 

@libentry()
@triton.jit
def flash_fwd_splitkv_combine_kernel(
    out_ptr,
    lse_ptr,
    out_splits_ptr,
    lse_splits_ptr,
    head_size: tl.constexpr,
    out_b_stride,
    out_s_stride,
    out_h_stride,
    n_splits,
    BLOCK_M: tl.constexpr,
    q_total,
    MAX_N_SPLITS: tl.constexpr,
):
    """Merge per-split partial outputs and LSEs into the final result.

    Each program handles BLOCK_M consecutive rows. It loads the LSE of every
    split for those rows, derives each split's softmax weight, stores the
    combined LSE to ``lse_ptr`` and the weighted sum of the split outputs to
    ``out_ptr``. Rows at or beyond ``q_total`` and splits at or beyond
    ``n_splits`` are masked out. ``out_b_stride`` and ``out_h_stride`` are
    not read.
    """
    pid = tl.program_id(0)
    row_base = pid * BLOCK_M
    lse_splits_ptr += row_base
    lse_ptr += row_base
    out_splits_ptr += row_base * head_size
    out_ptr += row_base * head_size
    lse_split_stride = tl.num_programs(0) * BLOCK_M
    out_split_stride = tl.num_programs(0) * BLOCK_M * head_size

    rows = tl.arange(0, BLOCK_M)
    splits = tl.arange(0, MAX_N_SPLITS)

    # Subtracting maximum from each of the split lse's for better numerical stability
    lse_split_offset = rows[:, None] + splits[None, :] * lse_split_stride
    lse_split_mask = (pid * BLOCK_M + rows[:, None] < q_total) & (
        splits[None, :] < n_splits
    )
    lse_splits = tl.load(
        lse_splits_ptr + lse_split_offset, mask=lse_split_mask, other=float("-inf")
    )
    max_lse = tl.max(lse_splits, 1)

    # Sum exp(lse(i) - max_lse) over all split i to obtain Z=sumexp(QK) up to a scaled factor exp(-max_lse)
    Zi_scaled = tl.exp(lse_splits - max_lse[:, None])
    Z_scaled = tl.sum(Zi_scaled, 1)
    split_weight = Zi_scaled / Z_scaled[:, None]

    # Write back LSE
    lse = tl.log(Z_scaled) + max_lse
    out_mask = pid * BLOCK_M + rows < q_total
    tl.store(lse_ptr + rows, lse, mask=out_mask)

    out_split_offset = (
        rows[:, None, None] * head_size
        + splits[None, :, None] * out_split_stride
        + tl.arange(0, head_size)[None, None, :]
    )
    out_split_mask = (pid * BLOCK_M + rows[:, None, None] < q_total) & (
        splits[None, :, None] < n_splits
    )
    out_splits = tl.load(
        out_splits_ptr + out_split_offset, mask=out_split_mask, other=0
    )
    out = tl.sum(split_weight[:, :, None] * out_splits, 1)
    out = out.to(out_ptr.type.element_ty)

    # Write back output
    out_offset = rows[:, None] * out_s_stride + tl.arange(0, head_size)
    tl.store(out_ptr + out_offset, out, mask=out_mask[:, None])

1076 

1077 

@triton.jit
def block_to_cache_index(
    n_block, page_table_ptr, block_size, page_stride, row_stride, BLOCK_N
):
    """Translate a K/V block number into a row index inside the paged cache.

    The block's first row is split into a virtual page number and an offset
    within that page; the virtual page is looked up in the page table and the
    offset is re-applied to the resulting page. ``page_stride`` and
    ``row_stride`` are not read.
    """
    first_row = n_block * BLOCK_N
    offset_in_page = first_row % block_size
    virtual_page = first_row // block_size
    # page_table_ptr is already pointed at the start of the current batch element
    physical_page = tl.load(page_table_ptr + virtual_page).to(tl.int32)
    return physical_page * block_size + offset_in_page

1088 

1089 

@libentry()
@triton.jit(
    do_not_specialize=[
        "seqlen_q",
        "seqlen_k",
        "seqlen_q_rounded",
        "seqlen_k_rounded",
        "total_q",
    ]
)
def flash_varlen_fwd_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q: tl.constexpr,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k: tl.constexpr,
    cu_seqlens_k_ptr,
    is_seqused_k: tl.constexpr,
    seqused_k_ptr,
    # sizes
    b: tl.constexpr,
    bk: tl.constexpr,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    # kernel params
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    """Attention forward pass for variable-length Q over a paged K/V cache.

    Each program handles one (m_block, batch, head) triple taken from
    program ids 0, 1 and 2. It loads a BLOCK_M x d tile of Q, walks the KV
    blocks from ``n_block_max - 1`` down to ``n_block_min`` while keeping a
    running row max / row sum, and finally stores the normalised output tile
    and the per-row log-sum-exp (LSE).

    Sequence bounds:
      * If ``is_cu_seqlens_q``, the query range of batch ``bid`` is
        ``cu_seqlens_q_ptr[bid] .. cu_seqlens_q_ptr[bid + 1]`` and rows are
        addressed with ``q_row_stride`` / ``o_row_stride``; otherwise
        ``seqlen_q`` and the batch strides are used.
      * The key length comes from ``seqused_k_ptr[bid]`` when ``is_seqused_k``
        is set, else from ``cu_seqlens_k_ptr`` or ``seqlen_k``.

    K/V addressing: logical KV blocks are translated to cache rows through
    ``page_table_ptr`` (see ``block_to_cache_index``).
    NOTE(review): V is addressed with ``k_row_stride`` / ``k_head_stride``;
    ``v_row_stride`` / ``v_head_stride`` are not read, so K and V are
    presumably expected to share a layout. Confirm against the caller.

    Dropout: ``philox_args[0]`` is read as the seed and ``philox_args[1]`` as
    the offset when ``is_dropout`` is set.

    Outputs: ``o_ptr`` receives the output tile (rows beyond ``q_len`` are
    boundary-checked away); ``softmax_lse_ptr`` is addressed as
    ``[h, total_q]`` and receives one LSE value per valid query row.
    """
    m_block = tl.program_id(0)
    bid = tl.program_id(1)
    hid = tl.program_id(2)

    if is_cu_seqlens_q:
        q_eos = tl.load(cu_seqlens_q_ptr + bid + 1).to(tl.int32)
        q_bos = tl.load(cu_seqlens_q_ptr + bid).to(tl.int32)
        q_len = q_eos - q_bos
        # Current request's start offset in the batched Q
        q_offset = q_bos * q_row_stride
        o_offset = q_bos * o_row_stride
        lse_offset = q_bos * 1
    else:
        q_len = seqlen_q
        q_offset = bid * q_batch_stride
        o_offset = bid * o_batch_stride
        lse_offset = bid * seqlen_q

    if is_cu_seqlens_k:
        k_eos = tl.load(cu_seqlens_k_ptr + bid + 1).to(tl.int32)
        k_bos = tl.load(cu_seqlens_k_ptr + bid).to(tl.int32)
        k_len_cache = k_eos - k_bos
    else:
        k_len_cache = seqlen_k

    if is_seqused_k:
        k_len = tl.load(seqused_k_ptr + bid).to(tl.int32)
    else:
        k_len = k_len_cache

    # Noop CTA: the tile starts at or past the end of this sequence, so it
    # owns no query rows. `>=` (not `>`) also skips the tile that begins
    # exactly at q_len, which would otherwise run with zero valid rows.
    if m_block * BLOCK_M >= q_len:
        return

    # Evenness of q_len / k_len is not known at compile time, so the border
    # handling is always enabled.
    is_even_mn: tl.constexpr = False

    if is_local:
        n_block_min = max(
            0, (m_block * BLOCK_M + k_len - q_len - window_size_left) // BLOCK_N
        )
    else:
        n_block_min = 0

    n_block_max = tl.cdiv(k_len, BLOCK_N)
    if is_causal or is_local:
        n_block_max = min(
            n_block_max,
            tl.cdiv(
                (m_block + 1) * BLOCK_M + k_len - q_len + window_size_right, BLOCK_N
            ),
        )

    if is_dropout:
        philox_seed = tl.load(philox_args).to(tl.uint64)
        philox_offset = tl.load(philox_args + 1).to(tl.uint64)

    # start processing kv blocks
    page_table_ptr += bid * page_table_batch_stride
    q_row_offset = hid * q_head_stride
    # Several query heads share one KV head (h_hk_ratio query heads per KV head).
    k_row_offset = (hid // h_hk_ratio) * k_head_stride

    gQ = tl.make_block_ptr(
        base=q_ptr + q_offset + q_row_offset,
        shape=(q_len, d),
        strides=(q_row_stride, 1),
        offsets=(0, 0),
        block_shape=(BLOCK_M, d),
        order=(1, 0),
    )

    # K is viewed transposed (d x rows) so that Q @ K needs no explicit transpose.
    gK = tl.make_block_ptr(
        base=k_ptr + k_row_offset,
        shape=(d, bk * block_size),
        strides=(1, k_row_stride),
        offsets=(0, 0),
        block_shape=(d, BLOCK_N),
        order=(0, 1),
    )

    gV = tl.make_block_ptr(
        base=v_ptr + k_row_offset,
        shape=(bk * block_size, d),
        strides=(k_row_stride, 1),
        offsets=(0, 0),
        block_shape=(BLOCK_N, d),
        order=(1, 0),
    )

    bQ = tl.load(gQ.advance([m_block * BLOCK_M, 0]), boundary_check=(0,))

    acc_ = tl.zeros((BLOCK_M, d), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0.0

    if not is_causal and not is_local:
        n_masking_steps = 1
    elif is_even_mn:
        n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N)
    else:
        n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N) + 1

    n_masking_steps = min(n_block_max - n_block_min, n_masking_steps)

    # Query row indices of this tile within the sequence; loop-invariant.
    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)

    # Phase 1: the trailing KV blocks, which need boundary checks and the
    # full (causal / local / length) mask.
    for step in tl.range(0, n_masking_steps):
        n_block = n_block_max - 1 - step
        cache_row_index = block_to_cache_index(
            n_block,
            page_table_ptr,
            block_size,
            page_table_batch_stride,
            k_row_stride,
            BLOCK_N,
        )
        bK = tl.load(gK.advance([0, cache_row_index]), boundary_check=(1,))
        # preload V
        bV = tl.load(gV.advance([cache_row_index, 0]), boundary_check=(0,))
        S = tl.dot(bQ, bK, out_dtype=tl.float32)
        S = apply_softcap(S, softcap, is_softcap)
        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            window_size_left,
            window_size_right,
            is_even_mn=is_even_mn,
            is_causal=is_causal,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=True,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            # apply_dropout takes (row_start, col_start): rows follow the
            # query tile, columns follow the KV block. Same order as the
            # second loop so both phases draw from one consistent mask.
            P = apply_dropout(
                P,
                m_block * BLOCK_M,
                n_block * BLOCK_N,
                k_len,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )

        acc_ = tl.dot(P, bV, acc_)

    # Phase 2: the remaining KV blocks, loaded without boundary checks and
    # with the causal / evenness handling switched off.
    for n_block in tl.range(
        n_block_max - n_masking_steps - 1, n_block_min - 1, step=-1
    ):
        cache_row_index = block_to_cache_index(
            n_block,
            page_table_ptr,
            block_size,
            page_table_batch_stride,
            k_row_stride,
            BLOCK_N,
        )
        bK = tl.load(gK.advance([0, cache_row_index]))
        # preload V
        bV = tl.load(gV.advance([cache_row_index, 0]))
        S = tl.dot(bQ, bK, out_dtype=tl.float32)
        S = apply_softcap(S, softcap, is_softcap)
        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            window_size_left,
            window_size_right,
            is_even_mn=True,
            is_causal=False,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=is_local,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            P = apply_dropout(
                P,
                m_block * BLOCK_M,
                n_block * BLOCK_N,
                k_len,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )
        acc_ = tl.dot(P, bV, acc_)

    # LSE
    # Each comparison is parenthesised on purpose: `|` binds tighter than
    # `==`, so `rowsum_ == 0 | (rowsum_ != rowsum_)` would compare rowsum_
    # against the OR result and never flag a NaN row sum.
    is_degenerate = (rowsum_ == 0) | (rowsum_ != rowsum_)
    lse = tl.where(
        is_degenerate,
        float("inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where(is_degenerate, 1.0, 1.0 / rowsum_)

    acc_ *= inv_sum[:, None]

    out = acc_.to(o_ptr.type.element_ty)  # noqa

    # Write back output
    o_row_offset = hid * o_head_stride

    gO = tl.make_block_ptr(
        base=o_ptr + o_offset + o_row_offset,
        shape=(q_len, d),
        strides=(o_row_stride, 1),
        offsets=(0, 0),
        block_shape=(BLOCK_M, d),
        order=(1, 0),
    )
    tl.store(gO.advance([m_block * BLOCK_M, 0]), out, boundary_check=(0,))

    # Write back lse
    # lse shape: [h, total_q]
    softmax_lse_ptr += hid * total_q
    lse_row_offset = lse_offset + row_idx
    # Only rows that belong to this sequence may be written. Sequences are
    # packed back to back along total_q, so padding rows of a partial tile
    # would otherwise land on the next sequence's LSE entries.
    lse_mask = (row_idx < q_len) & (lse_row_offset < total_q)
    tl.store(softmax_lse_ptr + lse_row_offset, lse, mask=lse_mask)

1455 tl.store(softmax_lse_ptr + lse_row_offset, lse, mask=lse_row_offset < total_q)