Coverage for src/flag_gems/ops/attention.py: 30%

430 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2import math 

3from functools import partial 

4 

5import torch 

6import torch.nn.functional as F 

7import triton 

8import triton.language as tl 

9 

10from flag_gems import runtime 

11from flag_gems.config import use_c_extension 

12from flag_gems.ops.flash_api import mha_fwd, mha_varlan_fwd, mha_varlan_fwd_opt 

13from flag_gems.ops.flash_kernel import keep 

14from flag_gems.runtime import torch_device_fn 

15from flag_gems.utils import libentry, libtuner 

16 

17logger = logging.getLogger(__name__) 

18 

19 

# Modified from Triton tutorial: https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
@triton.jit
def _attn_fwd_inner(
    acc,
    l_i,
    m_i,
    query,  #
    K_block_ptr,
    V_block_ptr,  #
    mask_block_ptr,  #
    stride_k_seqlen,
    stride_v_seqlen,
    stride_attn_mask_kv_seqlen,  #
    start_m,
    qk_scale,  #
    q_load_mask,
    BLOCK_M: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_N: tl.constexpr,  #
    STAGE: tl.constexpr,
    offs_m: tl.constexpr,
    offs_n: tl.constexpr,  #
    KV_CTX: tl.constexpr,
    fp8_v: tl.constexpr,
    HAS_ATTN_MASK: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
):
    # Online-softmax accumulation of one BLOCK_M query tile over a range of KV
    # tiles of BLOCK_N rows each.
    #
    # acc / l_i / m_i are the running (unnormalized output, row sum, row max)
    # for the query rows and are returned updated. Scores are converted to the
    # base-2 domain (multiplied by log2(e)) so exp2 can be used instead of exp;
    # every additive term (scale, attn_mask) must be converted the same way.
    #
    # STAGE selects the KV range visited:
    #   1 -> [0, start_m * BLOCK_M)                       (causal, left of diagonal)
    #   2 -> [start_m * BLOCK_M, (start_m + 1) * BLOCK_M)  (causal diagonal tile)
    #   3 -> [0, KV_CTX)                                  (non-causal)
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
    else:
        # causal = False
        lo, hi = 0, KV_CTX

    # Advance the KV (and mask) pointers to the first tile of this stage.
    K_block_ptr += lo * stride_k_seqlen
    V_block_ptr += lo * stride_v_seqlen
    if HAS_ATTN_MASK:
        mask_block_ptr += lo * stride_attn_mask_kv_seqlen

    LOG2E = 1.44269504  # log2(e) constant

    # loop over key, value and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        kv_load_mask = (start_n + offs_n) < KV_CTX
        # -- compute qk ----
        key = tl.load(K_block_ptr, mask=kv_load_mask[None, :], other=0.0)
        if PRE_LOAD_V:
            value = tl.load(V_block_ptr, mask=kv_load_mask[:, None], other=0.0)

        qk = tl.dot(query, key, allow_tf32=False)
        # KV columns past KV_CTX (when KV_CTX is not a multiple of BLOCK_N)
        # must not contribute to the softmax.
        qk = tl.where(kv_load_mask[None, :], qk, -float("inf"))

        if HAS_ATTN_MASK:
            attn_mask = tl.load(
                mask_block_ptr,
                mask=q_load_mask[:, None] & kv_load_mask[None, :],
                other=0.0,
            )

        if STAGE == 2:
            # Diagonal tile: apply the causal constraint q_idx >= kv_idx.
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])

            if HAS_ATTN_MASK:
                qk = qk * qk_scale + attn_mask
                qk *= LOG2E
                qk = qk + tl.where(mask, 0, -1.0e6)
            else:
                qk = qk * qk_scale * LOG2E + tl.where(mask, 0, -1.0e6)

            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
        else:
            qk *= qk_scale * LOG2E
            if HAS_ATTN_MASK:
                # The additive mask is defined in the natural-log domain (it is
                # added to qk * scale), so it must be converted to base 2 too,
                # exactly as the STAGE == 2 branch does.
                qk = qk + attn_mask * LOG2E
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk = qk - m_ij[:, None]

        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        # alpha rescales the previous partial sums to the new running max.
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        # -- update output accumulator --
        acc = acc * alpha[:, None]
        if not PRE_LOAD_V:
            value = tl.load(V_block_ptr, mask=kv_load_mask[:, None], other=0.0)
        if fp8_v:
            p = p.to(tl.float8e5)
        else:
            p = p.to(query.dtype)
        p = p.to(value.dtype)
        acc = tl.dot(p, value, acc, allow_tf32=False)
        # update m_i
        m_i = m_ij

        K_block_ptr += BLOCK_N * stride_k_seqlen
        V_block_ptr += BLOCK_N * stride_v_seqlen

        if HAS_ATTN_MASK:
            mask_block_ptr += BLOCK_N * stride_attn_mask_kv_seqlen

    return acc, l_i, m_i

129 

130 

# NOTE: we assert BLOCK_N <= HEAD_DIM in _attn_fwd, so for small head_dim,
# we need to generate more configs.
#
# Copy the tuned configs into a fresh list before appending, so the object
# returned by runtime.get_tuned_config is not extended in place.
# NOTE(review): presumably the runtime may hand out a shared/cached list — confirm.
configs = list(runtime.get_tuned_config("attention"))
SMALL_HEAD_DIM_CONFIGS = [
    triton.Config(
        {"BLOCK_M": BM, "BLOCK_N": BN, "PRE_LOAD_V": 0}, num_stages=s, num_warps=w
    )
    for BM in [64, 128]
    for BN in [16, 32]
    for s in [2, 3, 4]
    for w in [4, 8]
]
configs = configs + SMALL_HEAD_DIM_CONFIGS

144 

145 

# Fused scaled-dot-product-attention forward kernel.
#
# Launch grid (see scaled_dot_product_attention_forward):
#   program_id(0) -> query tile index (BLOCK_M rows)
#   program_id(1) -> flattened batch * q_head index
# Writes the normalized output tile to Out and, per query row, the value
# m_i + log2(l_i) (a base-2 log-sum-exp of the scaled scores) to M, which the
# backward kernels read back.
@libentry()
@libtuner(
    configs=list(filter(partial(keep, must_keep=SMALL_HEAD_DIM_CONFIGS), configs)),
    key=["KV_CTX", "HEAD_DIM"],
)
@triton.jit
def _attn_fwd(
    Q,
    K,
    V,
    attn_mask,
    sm_scale,
    M,
    Out,  #
    stride_q_batch,
    stride_q_head,
    stride_q_seqlen,
    stride_q_headsize,
    stride_k_batch,
    stride_k_head,
    stride_k_seqlen,
    stride_k_headsize,
    stride_v_batch,
    stride_v_head,
    stride_v_seqlen,
    stride_v_headsize,
    stride_attn_mask_batch,
    stride_attn_mask_head,
    stride_attn_mask_q_seqlen,
    stride_attn_mask_kv_seqlen,
    stride_o_batch,
    stride_o_head,
    stride_o_seqlen,
    stride_o_headsize,
    Z,
    q_head_num,
    kv_head_num,
    GROUP_HEAD: tl.constexpr,  # q heads per kv head (GQA group size)
    Q_CTX,
    KV_CTX,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,  # 3 = causal, 1 = non-causal (set by the host wrapper)
    HAS_ATTN_MASK: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
):
    tl.static_assert(BLOCK_N <= HEAD_DIM)
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    batch_id = off_hz // q_head_num
    head_id = off_hz % q_head_num
    # GQA: consecutive groups of GROUP_HEAD query heads share one kv head.
    kv_head_id = head_id // GROUP_HEAD

    # Base offsets are widened to int64 before multiplying so large tensors
    # do not overflow 32-bit arithmetic.
    q_offset = (
        batch_id.to(tl.int64) * stride_q_batch + head_id.to(tl.int64) * stride_q_head
    )
    o_offset = (
        batch_id.to(tl.int64) * stride_o_batch + head_id.to(tl.int64) * stride_o_head
    )
    # NOTE(review): kv_offset is built from K's batch/head strides but is also
    # applied to V below, so this assumes V shares K's batch/head strides —
    # confirm callers guarantee this (the forward wrapper does not check it).
    kv_offset = (
        batch_id.to(tl.int64) * stride_k_batch + kv_head_id.to(tl.int64) * stride_k_head
    )

    offs_headsize = tl.arange(0, HEAD_DIM)

    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    q_load_mask = offs_m < Q_CTX
    offs_n = tl.arange(0, BLOCK_N)

    # Q tile: (BLOCK_M, HEAD_DIM)
    Q_block_ptr = (
        Q
        + q_offset
        + offs_m[:, None] * stride_q_seqlen
        + offs_headsize[None, :] * stride_q_headsize
    )
    # K tile is addressed transposed: (HEAD_DIM, BLOCK_N), ready for Q @ K^T.
    K_block_ptr = (
        K
        + kv_offset
        + offs_n[None, :] * stride_k_seqlen
        + offs_headsize[:, None] * stride_k_headsize
    )
    # V tile: (BLOCK_N, HEAD_DIM)
    V_block_ptr = (
        V
        + kv_offset
        + offs_n[:, None] * stride_v_seqlen
        + offs_headsize[None, :] * stride_v_headsize
    )

    if HAS_ATTN_MASK:
        attn_mask_offset = (
            batch_id.to(tl.int64) * stride_attn_mask_batch
            + head_id.to(tl.int64) * stride_attn_mask_head
        )
        # Mask tile: (BLOCK_M, BLOCK_N); advanced along KV inside _attn_fwd_inner.
        mask_block_ptr = (
            attn_mask
            + attn_mask_offset
            + offs_m[:, None] * stride_attn_mask_q_seqlen
            + offs_n[None, :] * stride_attn_mask_kv_seqlen
        )
    else:
        mask_block_ptr = None

    O_block_ptr = (
        Out
        + o_offset
        + offs_m[:, None] * stride_o_seqlen
        + offs_headsize[None, :] * stride_o_headsize
    )

    # Running softmax statistics: row max m_i (-inf) and row sum l_i. The
    # initial 1.0 in l_i is multiplied by alpha = exp2(-inf - m_ij) = 0 in the
    # first KV tile that yields a finite row max, so it does not leak into
    # the result.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    # The log2(e) factor is applied inside _attn_fwd_inner, not folded here.
    qk_scale = sm_scale
    # Load the query tile once; it is reused for every KV tile below.
    query = tl.load(Q_block_ptr, mask=q_load_mask[:, None], other=0.0)
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            query,
            K_block_ptr,
            V_block_ptr,
            mask_block_ptr,
            stride_k_seqlen,
            stride_v_seqlen,
            stride_attn_mask_kv_seqlen,
            start_m,
            qk_scale,
            q_load_mask,
            BLOCK_M,
            HEAD_DIM,
            BLOCK_N,
            4 - STAGE,
            offs_m,
            offs_n,
            KV_CTX,
            V.dtype.element_ty == tl.float8e5,
            HAS_ATTN_MASK,
            PRE_LOAD_V,
        )
    # stage 2: on-band (the causal diagonal tile; only when STAGE == 3)
    if STAGE & 2:
        # Kept as a separate loop from stage 1, which is intended to make it
        # easier for the compiler to schedule the two loops independently
        # (no explicit barrier is issued here).
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            query,
            K_block_ptr,
            V_block_ptr,
            mask_block_ptr,
            stride_k_seqlen,
            stride_v_seqlen,
            stride_attn_mask_kv_seqlen,
            start_m,
            qk_scale,
            q_load_mask,
            BLOCK_M,
            HEAD_DIM,
            BLOCK_N,
            2,
            offs_m,
            offs_n,
            KV_CTX,
            V.dtype.element_ty == tl.float8e5,
            HAS_ATTN_MASK,
            PRE_LOAD_V,
        )
    # epilogue: M stores the base-2 log-sum-exp; the output is normalized by l_i.
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    # Assumes M is a dense (Z * q_head_num, Q_CTX) buffer, as allocated by the wrapper.
    m_ptrs = M + off_hz * Q_CTX + offs_m
    tl.store(m_ptrs, m_i, mask=q_load_mask)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=q_load_mask[:, None])

330 

331 

@triton.jit
def _attn_bwd_preprocess(
    O, DO, Delta, Z, H, Q_CTX, BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr
):
    # Backward preprocessing: for every query row i, write
    #     Delta[i] = sum_d O[i, d] * dO[i, d]
    # which the dK/dV and dQ kernels subtract from dP when forming dS.
    # O and DO are addressed as dense (batch*heads, Q_CTX, D_HEAD) buffers;
    # program_id(1) selects the (batch, head) slice and program_id(0) a tile of
    # BLOCK_M rows. Z and H are not used.
    bh = tl.program_id(1)
    rows = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tl.arange(0, D_HEAD)
    row_in_bounds = rows < Q_CTX
    tile_mask = row_in_bounds[:, None]

    out_tile = tl.load(
        O + bh * D_HEAD * Q_CTX + rows[:, None] * D_HEAD + cols[None, :],
        mask=tile_mask,
        other=0.0,
    )
    grad_tile = tl.load(
        DO + bh * D_HEAD * Q_CTX + rows[:, None] * D_HEAD + cols[None, :],
        mask=tile_mask,
        other=0.0,
    ).to(tl.float32)

    tl.store(
        Delta + bh * Q_CTX + rows,
        tl.sum(out_tile * grad_tile, axis=1),
        mask=row_in_bounds,
    )

355 

356 

# The main inner-loop logic for computing dK and dV.
#
# Accumulates into dk/dv (fp32, (BLOCK_N1, BLOCK_DMODEL)) the contributions of
# `num_steps` query tiles of BLOCK_M1 rows, starting at query row start_m, for
# the fixed KV tile whose rows start at start_n (`key`/`value` already loaded).
# `key` is expected to be pre-scaled so that key @ q^T lands in the same base-2
# logit domain as M (the host wrapper passes key * sm_scale / ln 2).
# With MASK=True the causal constraint q_idx >= kv_idx is applied.
# sm_scale and H are accepted but not used here.
@triton.jit
def _attn_bwd_dkdv(
    dk,
    dv,  #
    Q,
    key,
    value,
    sm_scale,  #
    DO,  #
    M,
    D,  #
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,  #
    H,
    Q_CTX,
    KV_CTX,
    BLOCK_M1: tl.constexpr,  #
    BLOCK_N1: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,  #
    # Filled in by the wrapper.
    start_n,
    start_m,
    num_steps,  #
    MASK: tl.constexpr,
):
    # Example sizes from the upstream tutorial (actual values come from the
    # caller): BLOCK_M1: 32, BLOCK_N1: 128
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_n_mask = offs_n < KV_CTX  # (BLOCK_N1, )

    offs_k = tl.arange(0, BLOCK_DMODEL)  # (BLOCK_DMODEL, )

    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    curr_m = start_m
    step_m = BLOCK_M1
    for blk_idx in range(num_steps):
        offs_m = curr_m + tl.arange(0, BLOCK_M1)  # (BLOCK_M1, )
        offs_m_mask = offs_m < Q_CTX  # (BLOCK_M1, )

        qT_ptrs = (
            Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
        )  # (BLOCK_DMODEL, BLOCK_M1)
        do_ptrs = (
            DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
        )  # (BLOCK_M1, BLOCK_DMODEL)

        qT = tl.load(
            qT_ptrs, mask=offs_m_mask[None, :], other=0.0
        )  # (BLOCK_DMODEL, BLOCK_M1)

        # Load m before computing qk to reduce pipeline stall.
        # Out-of-range query rows get m = +inf so their probabilities become 0.
        m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf"))  # (BLOCK_M1, )

        # key: (BLOCK_N1, BLOCK_DMODEL)
        qkT = tl.dot(key, qT)  # (BLOCK_N1, BLOCK_M1)
        m = tl.broadcast_to(m[None, :], (BLOCK_N1, BLOCK_M1))  # (BLOCK_N1, BLOCK_M1)
        # Out-of-range KV rows also get m = +inf, zeroing their probabilities.
        m = tl.where(offs_n_mask[:, None], m, float("inf"))  # (BLOCK_N1, BLOCK_M1)
        # Recompute the forward probabilities P^T = exp2(S^T - lse).
        pT = tl.math.exp2(qkT - m)

        mask = (offs_m < Q_CTX)[None, :] & (offs_n < KV_CTX)[
            :, None
        ]  # (BLOCK_N1, BLOCK_M1)
        # Autoregressive masking.
        if MASK:
            mask &= offs_m[None, :] >= offs_n[:, None]
        pT = tl.where(mask, pT, 0.0)  # (BLOCK_N1, BLOCK_M1)

        do = tl.load(
            do_ptrs, mask=offs_m_mask[:, None], other=0.0
        )  # (BLOCK_M1, BLOCK_DMODEL)

        # Compute dV = P^T @ dO.
        dv += tl.dot(pT, do.to(tl.float32))  # (BLOCK_N1, BLOCK_DMODEL)
        # D holds Delta_i = rowsum(O_i * dO_i) written by _attn_bwd_preprocess.
        Di = tl.load(D + offs_m, mask=offs_m_mask, other=0.0)  # (BLOCK_M1, )

        # Compute dP and dS.
        dpT = tl.dot(value, tl.trans(do)).to(
            tl.float32
        )  # (BLOCK_N1, BLOCK_DMODEL) @ (BLOCK_M1, BLOCK_DMODEL).T -> (BLOCK_N1, BLOCK_M1)
        dsT = pT * (dpT - Di[None, :])  # (BLOCK_N1, BLOCK_M1)
        dsT = dsT.to(qT.dtype)
        qT = tl.where(offs_m_mask[None, :], qT, 0.0)  # (BLOCK_DMODEL, BLOCK_M1)
        dsT = tl.where(
            offs_m_mask[None, :] & offs_n_mask[:, None], dsT, 0.0
        )  # (BLOCK_N1, BLOCK_M1)
        # dK (before the final sm_scale applied by the caller) = dS^T @ Q.
        dk += tl.dot(
            dsT, tl.trans(qT)
        )  # (BLOCK_N1, BLOCK_M1) @ (BLOCK_DMODEL, BLOCK_M1).T -> (BLOCK_N1, BLOCK_DMODEL)
        # Advance to the next query tile.
        curr_m += step_m
    return dk, dv

453 

454 

# the main inner-loop logic for computing dQ
#
# Accumulates into dq (fp32, (BLOCK_M2, BLOCK_DMODEL)) the contributions of
# `num_steps` KV tiles of BLOCK_N2 rows starting at start_n, for the fixed query
# tile starting at start_m. `query`, `do` and `m` (row LSE, shape (BLOCK_M2, 1))
# are already loaded by the caller. K is expected to be pre-scaled by
# sm_scale / ln 2 (see the host wrapper), so the caller rescales dq by ln 2.
# With MASK=True the causal constraint q_idx >= kv_idx is applied.
# H is accepted but not used here.
@triton.jit
def _attn_bwd_dq(
    dq,
    query,
    K,
    V,  #
    do,
    m,
    D,
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,  #
    H,
    Q_CTX,  #
    KV_CTX,  #
    BLOCK_M2: tl.constexpr,  #
    BLOCK_N2: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,
    # Filled in by the wrapper.
    start_m,
    start_n,
    num_steps,  #
    MASK: tl.constexpr,
):
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_m_mask = offs_m < Q_CTX

    offs_k = tl.arange(0, BLOCK_DMODEL)
    # D holds Delta_i = rowsum(O_i * dO_i) written by _attn_bwd_preprocess.
    Di = tl.load(D + offs_m, mask=offs_m_mask, other=0.0)
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
    curr_n = start_n
    step_n = BLOCK_N2
    for blk_idx in range(num_steps):
        offs_n = curr_n + tl.arange(0, BLOCK_N2)
        offs_n_mask = offs_n < KV_CTX

        # K and V tiles are addressed transposed: (BLOCK_DMODEL, BLOCK_N2).
        kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
        vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d

        kT = tl.load(kT_ptrs, mask=offs_n_mask[None, :], other=0.0)
        vT = tl.load(vT_ptrs, mask=offs_n_mask[None, :], other=0.0)
        # Recompute the forward probabilities P = exp2(S - lse).
        qk = tl.dot(query, kT)
        p = tl.math.exp2(qk - m)
        mask = (offs_m < Q_CTX)[:, None] & (offs_n < KV_CTX)[None, :]
        # Autoregressive masking.
        if MASK:
            mask &= offs_m[:, None] >= offs_n[None, :]
        p = tl.where(mask, p, 0.0)
        # Compute dP and dS.
        dp = tl.dot(do, vT).to(tl.float32)
        ds = p * (dp - Di[:, None])
        ds = tl.where(mask, ds, 0.0).to(kT.dtype)
        # Compute dQ.
        # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
        dq += tl.dot(ds, tl.trans(kT))
        # Advance to the next KV tile.
        curr_n += step_n
    return dq

518 

519 

# Autotuning configs for _attn_bwd. BLOCK_M1/BLOCK_N1/BLOCK_M2/BLOCK_N2 are not
# passed at launch, so presumably they are supplied by these configs.
config_backward = runtime.get_tuned_config("attention_bwd")


@libentry()
@libtuner(
    configs=config_backward,
    key=["KV_CTX", "BLOCK_DMODEL"],
)
@triton.jit
def _attn_bwd(
    Q,
    K,
    V,
    sm_scale,  #
    DO,  #
    DQ,
    DK,
    DV,  #
    M,
    D,
    # shared by Q/K/V/DO.
    stride_z,
    stride_h,
    stride_tok,
    stride_d,  #
    kv_stride_z,
    kv_stride_h,  #
    dk_stride_z,
    dk_stride_h,
    dk_stride_tok,  #
    H,  # query head num
    Q_CTX,  #
    KV_CTX,  #
    kv_head_num,  #
    GROUP_HEAD: tl.constexpr,  #
    BLOCK_M1: tl.constexpr,  #
    BLOCK_N1: tl.constexpr,  #
    BLOCK_M2: tl.constexpr,  #
    BLOCK_N2: tl.constexpr,  #
    BLK_SLICE_FACTOR: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,
    IS_CAUSAL: tl.constexpr = True,
):
    # Fused attention backward. program_id(2) selects one (batch, query-head)
    # pair; each program along axis 0 then does two independent jobs:
    #   1. dK/dV for KV rows   [pid * BLOCK_N1, (pid + 1) * BLOCK_N1)
    #   2. dQ    for query rows [pid * BLOCK_M2, (pid + 1) * BLOCK_M2)
    # K is expected to be pre-scaled by sm_scale / ln 2 (the host wrapper passes
    # key * sm_scale / ln 2), so dK is rescaled by sm_scale and dQ by ln 2
    # before being written back. DK/DV have one head per *query* head; the
    # host wrapper reduces them over GQA groups.
    LN2: tl.constexpr = 0.6931471824645996  # = ln(2)

    bhid = tl.program_id(2)
    batch_id = bhid // H
    q_head_id = bhid % H
    kv_head_id = q_head_id // GROUP_HEAD
    # Widen to int64 *before* multiplying by strides (as _attn_fwd does):
    # casting after the multiply can overflow 32-bit arithmetic on large inputs.
    off_chz = bhid.to(tl.int64) * Q_CTX
    adj = batch_id.to(tl.int64) * stride_z + q_head_id.to(tl.int64) * stride_h
    kv_adj = (
        batch_id.to(tl.int64) * kv_stride_z + kv_head_id.to(tl.int64) * kv_stride_h
    )
    dk_adj = (
        batch_id.to(tl.int64) * dk_stride_z + q_head_id.to(tl.int64) * dk_stride_h
    )

    pid = tl.program_id(0)

    # offset pointers for batch/head
    Q += adj
    K += kv_adj
    V += kv_adj
    DO += adj
    DQ += adj
    DK += dk_adj
    DV += dk_adj
    M += off_chz
    D += off_chz

    # Head-dim offsets shared by all loads/stores below.
    offs_k = tl.arange(0, BLOCK_DMODEL)

    # dK/dV: only execute when this pid covers a valid KV block
    start_n = pid * BLOCK_N1
    if start_n < KV_CTX:
        dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
        dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)

        # Load K and V once; they are reused across all query tiles below.
        offs_n = start_n + tl.arange(0, BLOCK_N1)
        offs_n_mask = offs_n < KV_CTX
        key = tl.load(
            K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d,
            mask=offs_n_mask[:, None],
            other=0.0,
        )
        value = tl.load(
            V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d,
            mask=offs_n_mask[:, None],
            other=0.0,
        )

        # The masked (diagonal) phase uses smaller query tiles.
        MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR

        # Causal: masked diagonal phase, then unmasked below-diagonal phase.
        # Non-causal: skip masked phase, single unmasked pass over all Q rows.
        if IS_CAUSAL:
            # The causal mask is q_idx >= kv_idx, so for KV block starting at
            # start_n, the first Q row that can attend is start_n itself.
            start_m = start_n
            # Clamp to valid Q range
            if start_m < Q_CTX:
                end_m = min(start_m + BLOCK_N1, Q_CTX)
                num_steps = (end_m - start_m + MASK_BLOCK_M1 - 1) // MASK_BLOCK_M1
                dk, dv = _attn_bwd_dkdv(
                    dk,
                    dv,  #
                    Q,
                    key,
                    value,
                    sm_scale,  #
                    DO,  #
                    M,
                    D,  #
                    stride_tok,
                    stride_d,  #
                    H,
                    Q_CTX,  #
                    KV_CTX,  #
                    MASK_BLOCK_M1,
                    BLOCK_N1,
                    BLOCK_DMODEL,  #
                    start_n,
                    start_m,
                    num_steps,  #
                    MASK=True,  #
                )
                start_m += num_steps * MASK_BLOCK_M1
            # else: start_n >= Q_CTX, no Q rows can attend to this KV block
        else:
            start_m = 0

        # Unmasked phase (shared): traverse remaining Q rows.
        remaining_m = Q_CTX - start_m
        num_steps = (remaining_m + BLOCK_M1 - 1) // BLOCK_M1
        if num_steps > 0:
            dk, dv = _attn_bwd_dkdv(
                dk,
                dv,  #
                Q,
                key,
                value,
                sm_scale,  #
                DO,  #
                M,
                D,  #
                stride_tok,
                stride_d,  #
                H,
                Q_CTX,  #
                KV_CTX,  #
                BLOCK_M1,
                BLOCK_N1,
                BLOCK_DMODEL,  #
                start_n,
                start_m,
                num_steps,  #
                MASK=False,  #
            )

        dv_ptrs = DV + offs_n[:, None] * dk_stride_tok + offs_k[None, :] * stride_d
        tl.store(dv_ptrs, dv, mask=offs_n_mask[:, None])

        # Write back dK (undo the softmax scale folded into the score gradient).
        dk *= sm_scale
        dk_ptrs = DK + offs_n[:, None] * dk_stride_tok + offs_k[None, :] * stride_d
        tl.store(dk_ptrs, dk, mask=offs_n_mask[:, None])

    # dQ: only execute when this pid covers a valid Q block
    start_m = pid * BLOCK_M2
    if start_m < Q_CTX:
        offs_m = start_m + tl.arange(0, BLOCK_M2)
        offs_m_mask = offs_m < Q_CTX
        query = tl.load(
            Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d,
            mask=offs_m_mask[:, None],
            other=0.0,
        )
        dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
        do = tl.load(
            DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d,
            mask=offs_m_mask[:, None],
            other=0.0,
        )
        # Out-of-range rows get +inf so their recomputed probabilities are 0.
        m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf"))
        m = m[:, None]

        # The masked (diagonal) phase uses smaller KV tiles.
        MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR

        if IS_CAUSAL:
            # Masked diagonal phase: KV columns [diag_n, end_n)
            # diag_n is the KV position where the causal boundary starts for
            # this Q block. Only needed when diag_n < KV_CTX.
            diag_n = min(start_m, KV_CTX)
            end_n = min(start_m + BLOCK_M2, KV_CTX)
            num_steps = (end_n - diag_n + MASK_BLOCK_N2 - 1) // MASK_BLOCK_N2

            if num_steps > 0:
                dq = _attn_bwd_dq(
                    dq,
                    query,
                    K,
                    V,  #
                    do,
                    m,
                    D,  #
                    stride_tok,
                    stride_d,  #
                    H,
                    Q_CTX,  #
                    KV_CTX,  #
                    BLOCK_M2,
                    MASK_BLOCK_N2,
                    BLOCK_DMODEL,  #
                    start_m,
                    diag_n,
                    num_steps,  #
                    MASK=True,  #
                )

            # Unmasked phase: KV columns [0, diag_n), all fully visible.
            stage2_num_steps = (diag_n + BLOCK_N2 - 1) // BLOCK_N2
        else:
            # Non-causal: single unmasked pass over all KV columns.
            stage2_num_steps = (KV_CTX + BLOCK_N2 - 1) // BLOCK_N2

        if stage2_num_steps > 0:
            dq = _attn_bwd_dq(
                dq,
                query,
                K,
                V,  #
                do,
                m,
                D,  #
                stride_tok,
                stride_d,  #
                H,
                Q_CTX,  #
                KV_CTX,  #
                BLOCK_M2,
                BLOCK_N2,
                BLOCK_DMODEL,  #
                start_m,
                0,
                stage2_num_steps,  #
                MASK=False,  #
            )

        # Write back dQ (ln 2 undoes the 1/ln 2 folded into the pre-scaled K).
        dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
        dq *= LN2
        tl.store(dq_ptrs, dq, mask=offs_m_mask[:, None])

771 

772 

def scaled_dot_product_attention_forward(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Run the fused Triton scaled-dot-product-attention forward pass.

    Args:
        query: tensor indexed as ``(batch, q_heads, q_len, head_dim)``.
        key: tensor indexed as ``(batch, kv_heads, kv_len, head_dim)``.
        value: tensor with the same head dim as ``key``.
        attn_mask: optional mask broadcastable to ``(batch, q_heads, q_len, kv_len)``.
            A boolean mask follows PyTorch semantics (``True`` = may attend);
            a floating mask is added to the scaled scores.
        dropout_p: must be ``0.0``.
        is_causal: apply a top-left aligned causal mask (``q_idx >= kv_idx``).
        scale: softmax scale; defaults to ``1 / sqrt(head_dim)``.
        enable_gqa: allow ``q_heads`` to be a multiple of ``kv_heads``.

    Returns:
        ``(o, M)``: ``o`` has ``query``'s shape and ``value``'s dtype; ``M`` is a
        float32 ``(batch, q_heads, q_len)`` tensor holding, per query row, the
        base-2 log-sum-exp of the scaled scores (consumed by the backward pass).
    """
    logger.debug("GEMS SCALED DOT PRODUCT ATTENTION FORWARD")
    # shape constraints
    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
    assert dropout_p == 0.0, "Currenty only support dropout_p=0.0"

    o = torch.empty_like(query, dtype=value.dtype)

    # 3 -> causal (off-band + diagonal stages), 1 -> non-causal single stage.
    stage = 3 if is_causal else 1

    sm_scale = 1.0 / (HEAD_DIM_K**0.5) if scale is None else scale

    batch, q_head_num, q_ctx = query.shape[0], query.shape[1], query.shape[2]
    kv_head_num, kv_ctx = key.shape[1], key.shape[2]
    assert enable_gqa or q_head_num == kv_head_num, (
        f"q_head_num {q_head_num} != kv_head_num {kv_head_num}, "
        "enable_gqa must be True to support different head numbers."
    )
    # The kernel maps query head h to kv head h // (q_head_num // kv_head_num);
    # without exact divisibility that index can run past the last kv head.
    assert (
        q_head_num % kv_head_num == 0
    ), f"q_head_num {q_head_num} must be divisible by kv_head_num {kv_head_num}."

    grid = lambda args: (
        triton.cdiv(q_ctx, args["BLOCK_M"]),
        batch * q_head_num,
        1,
    )

    if attn_mask is not None:
        HAS_ATTN_MASK = True
        if attn_mask.dtype == torch.bool:
            # PyTorch semantics: True means "takes part in attention". Convert
            # to an additive mask: 0 where allowed, a large *finite* negative
            # value where blocked. float32 keeps -1e6 finite (in fp16 it would
            # become -inf, and a fully masked first KV tile would then yield
            # NaN through (-inf) - (-inf) in the online softmax).
            attn_mask = torch.zeros(
                attn_mask.shape, dtype=torch.float32, device=attn_mask.device
            ).masked_fill(~attn_mask, -1.0e6)
        # Broadcast to the full score shape; broadcast dims get stride 0, so the
        # kernel can index any broadcastable mask with the four strides below.
        attn_mask = attn_mask.expand(batch, q_head_num, q_ctx, kv_ctx)
        stride_attn_mask_batch = attn_mask.stride(0)
        stride_attn_mask_head = attn_mask.stride(1)
        stride_attn_mask_q_seqlen = attn_mask.stride(2)
        stride_attn_mask_kv_seqlen = attn_mask.stride(3)
    else:
        HAS_ATTN_MASK = False
        stride_attn_mask_batch = 1
        stride_attn_mask_head = 1
        stride_attn_mask_q_seqlen = 1
        stride_attn_mask_kv_seqlen = 1

    M = torch.empty(
        (batch, q_head_num, q_ctx),
        device=query.device,
        dtype=torch.float32,
    )

    with torch_device_fn.device(query.device):
        _attn_fwd[grid](
            query,
            key,
            value,
            attn_mask,
            sm_scale,
            M,
            o,  #
            query.stride(0),
            query.stride(1),
            query.stride(2),
            query.stride(3),  #
            key.stride(0),
            key.stride(1),
            key.stride(2),
            key.stride(3),  #
            value.stride(0),
            value.stride(1),
            value.stride(2),
            value.stride(3),  #
            stride_attn_mask_batch,
            stride_attn_mask_head,
            stride_attn_mask_q_seqlen,
            stride_attn_mask_kv_seqlen,  #
            o.stride(0),
            o.stride(1),
            o.stride(2),
            o.stride(3),  #
            batch,
            q_head_num,
            kv_head_num,  #
            q_head_num // kv_head_num,  # group_head
            q_ctx,  #
            kv_ctx,  #
            HEAD_DIM_K,  #
            STAGE=stage,  #
            HAS_ATTN_MASK=HAS_ATTN_MASK,  #
        )
    return o, M

875 

876 

def scaled_dot_product_attention_backward(
    do,
    query,
    key,
    value,
    o,
    M,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Compute ``(dq, dk, dv)`` for the Triton SDPA forward pass.

    Args:
        do: gradient of the output; must be contiguous with ``query``'s strides.
        query, key, value: the forward inputs (must be contiguous; ``key`` and
            ``value`` must share strides).
        o: forward output (contiguous, same strides as ``query``).
        M: per-row statistics written by ``scaled_dot_product_attention_forward``.
        attn_mask: accepted for signature parity but NOT used by these kernels.
        dropout_p: must be ``0.0``.
        is_causal: must match the value used in the forward pass.
        scale: softmax scale; defaults to ``1 / sqrt(head_dim)``.
        enable_gqa: accepted for signature parity; GQA is inferred from head counts.

    Returns:
        ``(dq, dk, dv)``. ``dk``/``dv`` are first computed per query head and,
        for GQA, summed over each group of query heads sharing a kv head.
    """
    logger.debug("GEMS SCALED DOT PRODUCT ATTENTION BACKWARD")
    # shape constraints
    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
    assert dropout_p == 0.0, "Currenty only support dropout_p=0.0"

    if scale is None:
        sm_scale = 1.0 / (HEAD_DIM_K**0.5)
    else:
        sm_scale = scale

    # The kernels share one set of (token, head-dim) strides across
    # Q/K/V/O/dO, hence the contiguity and stride-equality requirements.
    assert do.is_contiguous()
    assert (
        query.is_contiguous()
        and key.is_contiguous()
        and value.is_contiguous()
        and o.is_contiguous()
    )
    assert query.stride() == o.stride() == do.stride()
    assert key.stride() == value.stride()

    BLOCK_DMODEL = HEAD_DIM_K
    BATCH, Q_HEAD, Q_CTX = query.shape[:3]
    _, KV_HEAD, KV_CTX = key.shape[:3]
    group_head = Q_HEAD // KV_HEAD

    # Ratio between the regular and the masked (diagonal) tile sizes.
    BLK_SLICE_FACTOR = 2

    RCP_LN2 = 1.0 / math.log(2)

    # Pre-scale K by sm_scale / ln 2 so Q @ K^T lands directly in the base-2
    # logit domain of M; the kernel rescales dk by sm_scale and dq by ln 2.
    arg_k = key * (sm_scale * RCP_LN2)
    # Rows per program in the Delta preprocessing kernel.
    PRE_BLOCK = 256

    pre_grid = (triton.cdiv(Q_CTX, PRE_BLOCK), BATCH * Q_HEAD)

    delta = torch.empty_like(M)

    # NOTE that dk & dv always have the same number of heads as q
    dq = torch.empty_like(query).contiguous()
    dk = torch.empty(
        (BATCH, Q_HEAD, KV_CTX, HEAD_DIM_K),
        device=key.device,
        dtype=key.dtype,
        memory_format=torch.contiguous_format,
    )
    dv = torch.empty(
        (BATCH, Q_HEAD, KV_CTX, HEAD_DIM_V),
        device=value.device,
        dtype=value.dtype,
        memory_format=torch.contiguous_format,
    )

    # delta[i] = rowsum(o[i] * do[i])
    # NOTE(review): unlike the forward wrapper, these launches are not wrapped
    # in `torch_device_fn.device(...)` — confirm this is intended.
    _attn_bwd_preprocess[pre_grid](
        o,
        do,  #
        delta,  #
        BATCH,
        Q_HEAD,
        Q_CTX,  #
        BLOCK_M=PRE_BLOCK,
        D_HEAD=BLOCK_DMODEL,  #
    )

    # Axis 0 must cover both jobs done by each _attn_bwd program: the dK/dV
    # pass over KV tiles of BLOCK_N1 rows and the dQ pass over query tiles of
    # BLOCK_M2 rows. Axis 2 enumerates (batch, query head) pairs.
    grid = lambda meta: (
        max(
            triton.cdiv(
                KV_CTX, meta["BLOCK_N1"]
            ),  # dK/dV programs (_attn_bwd_dkdv iterates over the query sequence)
            triton.cdiv(
                Q_CTX, meta["BLOCK_M2"]
            ),  # dQ programs (_attn_bwd_dq iterates over the key/value sequence)
        ),
        1,
        BATCH * Q_HEAD,
    )

    _attn_bwd[grid](
        query,
        arg_k,
        value,
        sm_scale,
        do,
        dq,
        dk,
        dv,  #
        M,
        delta,  #
        query.stride(0),
        query.stride(1),
        query.stride(2),
        query.stride(3),  #
        key.stride(0),
        key.stride(1),  #
        dk.stride(0),
        dk.stride(1),
        dk.stride(2),  #
        Q_HEAD,
        Q_CTX,  #
        KV_CTX,  #
        KV_HEAD,  #
        GROUP_HEAD=group_head,  #
        # BLOCK_M1/BLOCK_N1/BLOCK_M2/BLOCK_N2 are not passed here; presumably
        # they come from the libtuner configs (config_backward).
        BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
        BLOCK_DMODEL=BLOCK_DMODEL,  #
        IS_CAUSAL=is_causal,  #
    )

    # GQA: query head h used kv head h // group_head, so consecutive groups of
    # group_head query heads are summed into one kv-head gradient.
    if group_head > 1:
        dk = dk.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_K)
        dv = dv.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_V)
        dk = dk.sum(dim=2)
        dv = dv.sum(dim=2)

    return dq, dk, dv

1019 

1020 

class ScaleDotProductAttention(torch.autograd.Function):
    """Autograd binding for the Triton scaled-dot-product-attention kernels.

    ``forward`` runs ``scaled_dot_product_attention_forward`` and saves
    ``(query, key, value, o, M)``; ``backward`` feeds them to
    ``scaled_dot_product_attention_backward``. The backward kernels take no
    attention-mask input, so backward refuses to run when a mask was used in
    forward instead of silently returning gradients for the unmasked problem.
    """

    @staticmethod
    def forward(
        ctx,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=None,
        enable_gqa=False,
    ):
        sm_scale = scale if scale is not None else 1.0 / (key.shape[-1] ** 0.5)
        o, M = scaled_dot_product_attention_forward(
            query,
            key,
            value,
            attn_mask,
            dropout_p,
            is_causal,
            sm_scale,
            enable_gqa,
        )

        ctx.save_for_backward(query, key, value, o, M)
        ctx.sm_scale = sm_scale
        ctx.causal = is_causal
        ctx.enable_gqa = enable_gqa
        # Remembered so backward can reject masked problems it cannot handle.
        ctx.has_attn_mask = attn_mask is not None
        return o

    @staticmethod
    def backward(ctx, do):
        if ctx.has_attn_mask:
            raise NotImplementedError(
                "GEMS scaled_dot_product_attention backward does not support "
                "attn_mask; ignoring it would produce incorrect gradients."
            )
        query, key, value, o, M = ctx.saved_tensors
        # The backward wrapper asserts dO is contiguous; upstream gradients are
        # not guaranteed to be (no copy is made if it already is).
        do = do.contiguous()
        dq, dk, dv = scaled_dot_product_attention_backward(
            do,
            query,
            key,
            value,
            o,
            M,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=ctx.causal,
            scale=ctx.sm_scale,
            enable_gqa=ctx.enable_gqa,
        )
        # One gradient slot per forward input:
        # (query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa)
        return dq, dk, dv, None, None, None, None, None

1072 

1073 

def scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Differentiable SDPA entry point backed by this module's Triton kernels.

    Mirrors the signature of ``torch.nn.functional.scaled_dot_product_attention``
    and dispatches to ``ScaleDotProductAttention``. The arguments are forwarded
    positionally, in the order ``ScaleDotProductAttention.forward`` declares them.
    """
    forwarded = (
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        enable_gqa,
    )
    return ScaleDotProductAttention.apply(*forwarded)

1094 

1095 

def flash_attention_forward(
    query,
    key,
    value,
    cumulative_sequence_length_q,
    cumulative_sequence_length_k,
    max_q,
    max_k,
    dropout_p,
    is_causal,
    return_debug_mask,
    *,
    scale=None,
    softcap=0.0,
    window_size_left=None,
    window_size_right=None,
    seqused_k=None,
    alibi_slopes=None,
    disable_splitkv=False,
):
    """FlashAttention forward via the ``flash_api`` kernels (dense inputs only).

    Head dims outside ``supported_head_dims`` are zero-padded up to the next
    supported size and the output is sliced back to the original head dim.
    An explicit ``scale`` is always honoured (including ``0.0``); the default is
    ``1 / sqrt(original_head_dim)``, so padding does not change the result.
    ``None`` window sizes are passed to the kernels as ``-1``.

    Returns:
        ``(out, lse, philox_seed, philox_offset, p)`` as produced by ``mha_fwd``,
        with ``out`` restricted to the original head dim.

    Raises:
        AssertionError: for varlen inputs (cumulative sequence lengths), mismatched
            head dims, or head dims larger than the largest supported one.
    """
    logger.debug("GEMS FLASH_ATTENTION_FORWARD")
    assert (
        cumulative_sequence_length_q is None and cumulative_sequence_length_k is None
    ), "varlen is not supported yet."

    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    original_head_dim = HEAD_DIM_K
    supported_head_dims = (16, 32, 64, 96, 128, 192, 256)
    if HEAD_DIM_K not in supported_head_dims:
        # Smallest supported head dim that can hold the real one.
        padded_head_dim = next(
            (d for d in supported_head_dims if d >= HEAD_DIM_K), None
        )
        assert (
            padded_head_dim is not None
        ), f"Unsupported head dim {HEAD_DIM_K}, max supported is {supported_head_dims[-1]}"
        pad = padded_head_dim - HEAD_DIM_K
        query = F.pad(query, (0, pad))
        key = F.pad(key, (0, pad))
        value = F.pad(value, (0, pad))
        HEAD_DIM_K = padded_head_dim

    # `scale or default` would silently replace an explicit 0.0.
    softmax_scale = scale if scale is not None else 1.0 / (original_head_dim**0.5)
    non_null_window_left = -1 if window_size_left is None else window_size_left
    non_null_window_right = -1 if window_size_right is None else window_size_right

    out = torch.empty_like(query)
    if cumulative_sequence_length_q is not None:
        # Currently unreachable (see the varlen assert above). Uses the same
        # resolved softmax_scale as the dense path rather than the raw `scale`,
        # which may be None.
        out, q, k, v, lse, philox_seed, philox_offset, p = mha_varlan_fwd(
            query,
            key,
            value,
            out,
            cumulative_sequence_length_q,
            cumulative_sequence_length_k,
            seqused_k,
            None,
            None,  # block_table
            alibi_slopes,
            max_q,
            max_k,
            dropout_p,
            softmax_scale,
            False,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask and dropout_p > 0,
            None,
        )
    else:
        out, q, k, v, lse, philox_seed, philox_offset, p = mha_fwd(
            query,
            key,
            value,
            out,
            alibi_slopes,
            dropout_p,
            softmax_scale,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask,
            disable_splitkv=disable_splitkv,
        )

    if HEAD_DIM_K != original_head_dim:
        # Drop the zero-padded head-dim columns.
        out = out[..., :original_head_dim]
    return (out, lse, philox_seed, philox_offset, p)

1196 

1197 

1198# Adapted from https://github.com/vllm-project/flash-attention/blob/main/vllm_flash_attn/flash_attn_interface.py 

1199def maybe_contiguous(x): 

1200 return x.contiguous() if x is not None and x.stride(-1) != 1 else x 

1201 

1202 

1203def flash_attn_varlen_func( 

1204 q, 

1205 k, 

1206 v, 

1207 max_seqlen_q, 

1208 cu_seqlens_q, 

1209 max_seqlen_k, 

1210 cu_seqlens_k=None, # only used for non-paged prefill 

1211 seqused_k=None, 

1212 q_v=None, 

1213 dropout_p=0.0, 

1214 softmax_scale=None, 

1215 causal=False, 

1216 window_size=None, 

1217 softcap=0.0, # 0.0 means deactivated 

1218 alibi_slopes=None, 

1219 deterministic=False, 

1220 return_attn_probs=False, 

1221 block_table=None, 

1222 return_softmax_lse=False, 

1223 out=None, 

1224 # Dummy FA3 arguments 

1225 scheduler_metadata=None, 

1226 q_descale=None, 

1227 k_descale=None, 

1228 v_descale=None, 

1229 s_aux=None, 

1230 num_splits: int = 0, 

1231 cp_world_size: int = 1, 

1232 cp_rank: int = 0, 

1233 cp_tot_seqused_k=None, 

1234 fa_version: int = 2, 

1235): 

1236 """dropout_p should be set to 0.0 during evaluation 

1237 Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads 

1238 than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. 

1239 For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head 

1240 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. 

1241 

1242 If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. 

1243 For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: 

1244 1 1 1 1 0 

1245 1 1 1 1 1 

1246 If seqlen_q = 5 and seqlen_k = 2, the causal mask is: 

1247 0 0 

1248 0 0 

1249 0 0 

1250 1 0 

1251 1 1 

1252 If the row of the mask is all zero, the output will be zero. 

1253 

1254 If window_size != (-1, -1), implements sliding window local attention. Query at position i 

1255 will only attend to keys between 

1256 [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. 

1257 

1258 Arguments: 

1259 q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. 

1260 k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. 

1261 v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. 

1262 cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths 

1263 of the sequences in the batch, used to index into q. 

1264 cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths 

1265 of the sequences in the batch, used to index into kv. 

1266 max_seqlen_q: int. Maximum query sequence length in the batch. 

1267 max_seqlen_k: int. Maximum key sequence length in the batch. 

1268 dropout_p: float. Dropout probability. 

1269 softmax_scale: float. The scaling of QK^T before applying softmax. 

1270 Default to 1 / sqrt(headdim). 

1271 causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). 

1272 window_size: (left, right). If not (-1, -1), implements sliding window local attention. 

1273 softcap: float. Anything > 0 activates softcapping attention. 

1274 alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of 

1275 (-alibi_slope * |i + seqlen_k - seqlen_q - j|) 

1276 is added to the attention score of query i and key j. 

1277 deterministic: bool. Whether to use the deterministic implementation of the backward pass, 

1278 which is slightly slower and uses more memory. The forward pass is always deterministic. 

1279 return_attn_probs: bool. Whether to return the attention probabilities. This option is for 

1280 testing only. The returned probabilities are not guaranteed to be correct 

1281 (they might not have the right scaling). 

1282 Return: 

1283 out: (total, nheads, headdim). 

1284 softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The 

1285 logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax 

1286 normalization factor). 

1287 """ 

1288 if fa_version != 2: 

1289 raise RuntimeError("Only FA2 is implemented.") 

1290 if num_splits > 0: 

1291 raise RuntimeError("num_splits > 0 is not implemented in GEMS.") 

1292 if use_c_extension: 

1293 logger.debug("GEMS FLASH_ATTN_VARLEN_FUNC(C EXTENSION)") 

1294 with torch_device_fn.device(q.device): 

1295 out_cpp, softmax_lse = torch.ops.flag_gems.flash_attn_varlen_func( 

1296 q, 

1297 k, 

1298 v, 

1299 max_seqlen_q, 

1300 cu_seqlens_q, 

1301 max_seqlen_k, 

1302 cu_seqlens_k, 

1303 seqused_k, 

1304 q_v, 

1305 dropout_p, 

1306 softmax_scale, 

1307 causal, 

1308 window_size, 

1309 softcap, 

1310 alibi_slopes, 

1311 deterministic, 

1312 return_attn_probs, 

1313 block_table, 

1314 return_softmax_lse, 

1315 out, 

1316 scheduler_metadata, 

1317 q_descale, 

1318 k_descale, 

1319 v_descale, 

1320 s_aux, 

1321 num_splits, 

1322 cp_world_size, 

1323 cp_rank, 

1324 cp_tot_seqused_k, 

1325 fa_version, 

1326 ) 

1327 return (out_cpp, softmax_lse) if return_softmax_lse else out_cpp 

1328 else: 

1329 logger.debug("GEMS FLASH_ATTN_VARLEN_FUNC") 

1330 assert ( 

1331 cu_seqlens_k is not None or seqused_k is not None 

1332 ), "cu_seqlens_k or seqused_k must be provided" 

1333 assert ( 

1334 cu_seqlens_k is None or seqused_k is None 

1335 ), "cu_seqlens_k and seqused_k cannot be provided at the same time" 

1336 assert ( 

1337 block_table is None or seqused_k is not None 

1338 ), "seqused_k must be provided if block_table is provided" 

1339 if softmax_scale is None: 

1340 softmax_scale = q.shape[-1] ** (-0.5) 

1341 # custom op does not support non-tuple input 

1342 if window_size is None: 

1343 real_window_size = (-1, -1) 

1344 else: 

1345 assert len(window_size) == 2 

1346 real_window_size = (window_size[0], window_size[1]) 

1347 q, k, v = [maybe_contiguous(x) for x in (q, k, v)] 

1348 dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q) 

1349 max_seqlen_q = ( 

1350 max_seqlen_q.item() if hasattr(max_seqlen_q, "item") else max_seqlen_q 

1351 ) 

1352 max_seqlen_k = ( 

1353 max_seqlen_k.item() if hasattr(max_seqlen_k, "item") else max_seqlen_k 

1354 ) 

1355 out, q, k, v, softmax_lse, *_ = mha_varlan_fwd( 

1356 q, 

1357 k, 

1358 v, 

1359 out, 

1360 cu_seqlens_q, 

1361 # cu_seqlens_k not used since we use seqused_k, but flash_api.cpp 

1362 # still wants it so we pass all zeros 

1363 dummy_cu_seqlens_k if cu_seqlens_k is None else cu_seqlens_k, 

1364 seqused_k, 

1365 None, 

1366 block_table, 

1367 alibi_slopes, 

1368 max_seqlen_q, 

1369 max_seqlen_k, 

1370 dropout_p, 

1371 softmax_scale, 

1372 False, 

1373 causal, 

1374 real_window_size[0], 

1375 real_window_size[1], 

1376 softcap, 

1377 return_softmax_lse and dropout_p > 0, 

1378 None, 

1379 ) 

1380 

1381 return (out, softmax_lse) if return_softmax_lse else out 

1382 

1383 

1384def flash_attn_varlen_opt_func( 

1385 q, 

1386 k, 

1387 v, 

1388 max_seqlen_q, 

1389 cu_seqlens_q, 

1390 max_seqlen_k, 

1391 cu_seqlens_k=None, # only used for non-paged prefill 

1392 seqused_k=None, 

1393 q_v=None, 

1394 dropout_p=0.0, 

1395 softmax_scale=None, 

1396 causal=False, 

1397 window_size=None, 

1398 softcap=0.0, # 0.0 means deactivated 

1399 alibi_slopes=None, 

1400 deterministic=False, 

1401 return_attn_probs=False, 

1402 block_table=None, 

1403 return_softmax_lse=False, 

1404 out=None, 

1405 lse=None, 

1406 # Dummy FA3 arguments 

1407 scheduler_metadata=None, 

1408 q_descale=None, 

1409 k_descale=None, 

1410 v_descale=None, 

1411 s_aux=None, 

1412 num_splits: int = 0, 

1413 cp_world_size: int = 1, 

1414 cp_rank: int = 0, 

1415 cp_tot_seqused_k=None, 

1416 fa_version: int = 2, 

1417): 

1418 """dropout_p should be set to 0.0 during evaluation 

1419 Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads 

1420 than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. 

1421 For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head 

1422 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. 

1423 

1424 If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. 

1425 For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: 

1426 1 1 1 1 0 

1427 1 1 1 1 1 

1428 If seqlen_q = 5 and seqlen_k = 2, the causal mask is: 

1429 0 0 

1430 0 0 

1431 0 0 

1432 1 0 

1433 1 1 

1434 If the row of the mask is all zero, the output will be zero. 

1435 

1436 If window_size != (-1, -1), implements sliding window local attention. Query at position i 

1437 will only attend to keys between 

1438 [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. 

1439 

1440 Arguments: 

1441 q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. 

1442 k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. 

1443 v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. 

1444 cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths 

1445 of the sequences in the batch, used to index into q. 

1446 cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths 

1447 of the sequences in the batch, used to index into kv. 

1448 max_seqlen_q: int. Maximum query sequence length in the batch. 

1449 max_seqlen_k: int. Maximum key sequence length in the batch. 

1450 dropout_p: float. Dropout probability. 

1451 softmax_scale: float. The scaling of QK^T before applying softmax. 

1452 Default to 1 / sqrt(headdim). 

1453 causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). 

1454 window_size: (left, right). If not (-1, -1), implements sliding window local attention. 

1455 softcap: float. Anything > 0 activates softcapping attention. 

1456 alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of 

1457 (-alibi_slope * |i + seqlen_k - seqlen_q - j|) 

1458 is added to the attention score of query i and key j. 

1459 deterministic: bool. Whether to use the deterministic implementation of the backward pass, 

1460 which is slightly slower and uses more memory. The forward pass is always deterministic. 

1461 return_attn_probs: bool. Whether to return the attention probabilities. This option is for 

1462 testing only. The returned probabilities are not guaranteed to be correct 

1463 (they might not have the right scaling). 

1464 Return: 

1465 out: (total, nheads, headdim). 

1466 softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The 

1467 logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax 

1468 normalization factor). 

1469 """ 

1470 if fa_version != 2: 

1471 raise RuntimeError("Only FA2 is implemented.") 

1472 if num_splits > 0: 

1473 raise RuntimeError("num_splits > 0 is not implemented in GEMS.") 

1474 if use_c_extension: 

1475 logger.debug("GEMS FLASH_ATTN_VARLEN_FUNC(C EXTENSION)") 

1476 with torch_device_fn.device(q.device): 

1477 out_cpp, softmax_lse = torch.ops.flag_gems.flash_attn_varlen_func( 

1478 q, 

1479 k, 

1480 v, 

1481 max_seqlen_q, 

1482 cu_seqlens_q, 

1483 max_seqlen_k, 

1484 cu_seqlens_k, 

1485 seqused_k, 

1486 q_v, 

1487 dropout_p, 

1488 softmax_scale, 

1489 causal, 

1490 window_size, 

1491 softcap, 

1492 alibi_slopes, 

1493 deterministic, 

1494 return_attn_probs, 

1495 block_table, 

1496 return_softmax_lse, 

1497 out, 

1498 scheduler_metadata, 

1499 q_descale, 

1500 k_descale, 

1501 v_descale, 

1502 s_aux, 

1503 num_splits, 

1504 cp_world_size, 

1505 cp_rank, 

1506 cp_tot_seqused_k, 

1507 fa_version, 

1508 ) 

1509 return (out_cpp, softmax_lse) if return_softmax_lse else out_cpp 

1510 else: 

1511 logger.debug("GEMS FLASH_ATTN_VARLEN_OPT_FUNC") 

1512 assert ( 

1513 cu_seqlens_k is not None or seqused_k is not None 

1514 ), "cu_seqlens_k or seqused_k must be provided" 

1515 assert ( 

1516 cu_seqlens_k is None or seqused_k is None 

1517 ), "cu_seqlens_k and seqused_k cannot be provided at the same time" 

1518 assert ( 

1519 block_table is None or seqused_k is not None 

1520 ), "seqused_k must be provided if block_table is provided" 

1521 if softmax_scale is None: 

1522 softmax_scale = q.shape[-1] ** (-0.5) 

1523 # custom op does not support non-tuple input 

1524 if window_size is None: 

1525 real_window_size = (-1, -1) 

1526 else: 

1527 assert len(window_size) == 2 

1528 real_window_size = (window_size[0], window_size[1]) 

1529 # q, k, v = [maybe_contiguous(x) for x in (q, k, v)] 

1530 # dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q) 

1531 max_seqlen_q = ( 

1532 max_seqlen_q.item() if hasattr(max_seqlen_q, "item") else max_seqlen_q 

1533 ) 

1534 max_seqlen_k = ( 

1535 max_seqlen_k.item() if hasattr(max_seqlen_k, "item") else max_seqlen_k 

1536 ) 

1537 out, q, k, v, softmax_lse, *_ = mha_varlan_fwd_opt( 

1538 q, 

1539 k, 

1540 v, 

1541 out, 

1542 lse, 

1543 cu_seqlens_q, 

1544 # cu_seqlens_k not used since we use seqused_k, but flash_api.cpp 

1545 # still wants it so we pass all zeros 

1546 # dummy_cu_seqlens_k if cu_seqlens_k is None else cu_seqlens_k, 

1547 cu_seqlens_q if cu_seqlens_k is None else cu_seqlens_k, 

1548 seqused_k, 

1549 None, 

1550 block_table, 

1551 alibi_slopes, 

1552 max_seqlen_q, 

1553 max_seqlen_k, 

1554 dropout_p, 

1555 softmax_scale, 

1556 False, 

1557 causal, 

1558 real_window_size[0], 

1559 real_window_size[1], 

1560 softcap, 

1561 return_softmax_lse and dropout_p > 0, 

1562 None, 

1563 ) 

1564 

1565 return (out, softmax_lse) if return_softmax_lse else out