Coverage for src/flag_gems/runtime/backend/_sunrise/ops/attention.py: 0%

440 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2import math 

3from functools import partial 

4 

5import torch 

6import torch.nn.functional as F 

7import triton 

8import triton.language as tl 

9 

10from flag_gems import runtime 

11from flag_gems.config import use_c_extension 

12from flag_gems.runtime import torch_device_fn 

13from flag_gems.runtime.backend._sunrise.ops.flash_api import ( 

14 mha_fwd, 

15 mha_varlan_fwd, 

16 mha_varlan_fwd_opt, 

17) 

18from flag_gems.runtime.backend._sunrise.ops.flash_kernel import keep 

19from flag_gems.utils import libentry, libtuner 

20 

21logger = logging.getLogger(__name__) 

22 

23 

# Modified from Triton tutorial: https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
@triton.jit
def _attn_fwd_inner(
    acc,
    l_i,
    m_i,
    query,  #
    K_block_ptr,
    V_block_ptr,  #
    mask_block_ptr,  #
    stride_k_seqlen,
    stride_v_seqlen,
    stride_attn_mask_kv_seqlen,  #
    start_m,
    qk_scale,  #
    q_load_mask,
    BLOCK_M: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_N: tl.constexpr,  #
    STAGE: tl.constexpr,
    offs_m: tl.constexpr,
    offs_n: tl.constexpr,  #
    KV_CTX: tl.constexpr,
    fp8_v: tl.constexpr,
    HAS_ATTN_MASK: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
):
    # Online-softmax inner loop of the forward pass.
    #
    # Walks the key/value sequence in BLOCK_N-wide tiles and folds each tile
    # into the running accumulator `acc`, the running row sum `l_i` and the
    # running row maximum `m_i`, all kept in the base-2 exponent domain.
    #
    # STAGE selects the key range:
    #   1 -> [0, start_m * BLOCK_M)                       (below the diagonal)
    #   2 -> [start_m * BLOCK_M, (start_m + 1) * BLOCK_M) (diagonal tile, masked)
    #   else -> [0, KV_CTX)                               (non-causal)
    #
    # Returns the updated (acc, l_i, m_i).

    # range of values handled by this stage
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
    # causal = False
    else:
        lo, hi = 0, KV_CTX

    # Advance the tile pointers to the first key handled by this stage.
    K_block_ptr += lo * stride_k_seqlen
    V_block_ptr += lo * stride_v_seqlen
    if HAS_ATTN_MASK:
        mask_block_ptr += lo * stride_attn_mask_kv_seqlen

    LOG2E = 1.44269504  # log2(e) constant

    # loop over key, value and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        kv_load_mask = (start_n + offs_n) < KV_CTX
        # start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        key = tl.load(K_block_ptr, mask=kv_load_mask[None, :], other=0.0)
        if PRE_LOAD_V:
            value = tl.load(V_block_ptr, mask=kv_load_mask[:, None], other=0.0)

        qk = tl.dot(query, key, allow_tf32=False)
        # Keys past KV_CTX (sequence not divisible by BLOCK_N) must not
        # contribute to the softmax.
        qk = tl.where(kv_load_mask[None, :], qk, -float("inf"))
        # qk = qk.to(tl.float32)

        if HAS_ATTN_MASK:
            attn_mask = tl.load(
                mask_block_ptr,
                mask=q_load_mask[:, None] & kv_load_mask[None, :],
                other=0.0,
            )

        if STAGE == 2:
            # Causal mask inside the diagonal tile.
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])

            if HAS_ATTN_MASK:
                qk = qk * qk_scale + attn_mask
                qk *= LOG2E
                qk = qk + tl.where(mask, 0, -1.0e6)
            else:
                qk = qk * qk_scale * LOG2E + tl.where(mask, 0, -1.0e6)

            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
        else:
            if HAS_ATTN_MASK:
                # The additive mask lives in the natural-log domain, so it has
                # to be added before the LOG2E conversion, exactly as in the
                # STAGE == 2 branch; adding it afterwards would scale it by ln(2).
                qk = (qk * qk_scale + attn_mask) * LOG2E
            else:
                qk *= qk_scale * LOG2E
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk = qk - m_ij[:, None]

        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        # -- update output accumulator --
        acc = acc * alpha[:, None]
        # update acc
        if not PRE_LOAD_V:
            value = tl.load(V_block_ptr, mask=kv_load_mask[:, None], other=0.0)
        if fp8_v:
            p = p.to(tl.float8e5)
        else:
            p = p.to(query.dtype)
        p = p.to(value.dtype)
        acc = tl.dot(p, value, acc, allow_tf32=False)
        # update m_i and l_i
        m_i = m_ij

        K_block_ptr += BLOCK_N * stride_k_seqlen
        V_block_ptr += BLOCK_N * stride_v_seqlen

        if HAS_ATTN_MASK:
            mask_block_ptr += BLOCK_N * stride_attn_mask_kv_seqlen

    return acc, l_i, m_i

133 

134 

# NOTE: we assert BLOCK_N <= HEAD_DIM in _attn_fwd, so for small head_dim,
# we need to generate more configs.
# The tuned "attention" configs are fetched here, but `configs` is rebound to
# SMALL_HEAD_DIM_CONFIGS below, so only the small-tile configs reach the tuner.
configs = runtime.get_tuned_config("attention")
# Cartesian product of tile sizes and warp counts; num_stages is fixed at 0 and
# PRE_LOAD_V at 0 for every generated config.
SMALL_HEAD_DIM_CONFIGS = [
    triton.Config(
        {"BLOCK_M": BM, "BLOCK_N": BN, "PRE_LOAD_V": 0}, num_stages=s, num_warps=w
    )
    for BM in [16, 32]
    for BN in [8, 16]
    for s in [0]
    for w in [8, 16]
]
# configs += SMALL_HEAD_DIM_CONFIGS
# NOTE(review): this replaces (rather than extends) the tuned list fetched
# above — presumably intentional for this backend; confirm.
configs = SMALL_HEAD_DIM_CONFIGS

149 

150 

@libentry()
@libtuner(
    configs=list(filter(partial(keep, must_keep=SMALL_HEAD_DIM_CONFIGS), configs)),
    key=["KV_CTX", "HEAD_DIM"],
)
@triton.jit
def _attn_fwd(
    Q,
    K,
    V,
    attn_mask,
    sm_scale,
    M,
    Out,  #
    stride_q_batch,
    stride_q_head,
    stride_q_seqlen,
    stride_q_headsize,
    stride_k_batch,
    stride_k_head,
    stride_k_seqlen,
    stride_k_headsize,
    stride_v_batch,
    stride_v_head,
    stride_v_seqlen,
    stride_v_headsize,
    stride_attn_mask_batch,
    stride_attn_mask_head,
    stride_attn_mask_q_seqlen,
    stride_attn_mask_kv_seqlen,
    stride_o_batch,
    stride_o_head,
    stride_o_seqlen,
    stride_o_headsize,
    Z,
    q_head_num,
    kv_head_num,
    GROUP_HEAD: tl.constexpr,
    Q_CTX,
    KV_CTX,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
    HAS_ATTN_MASK: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
):
    # Fused attention forward kernel.
    #
    # Program axis 0 selects a BLOCK_M-row tile of queries; program axis 1
    # selects a flattened (batch, query head) pair. Each program writes one
    # output tile to `Out` and, per query row, m_i + log2(l_i) (a base-2
    # log-sum-exp) to `M`.
    #
    # STAGE is 3 for causal attention and 1 otherwise (see the stage comments
    # below). GROUP_HEAD query heads share one key/value head.
    tl.static_assert(BLOCK_N <= HEAD_DIM)
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    batch_id = off_hz // q_head_num
    head_id = off_hz % q_head_num
    kv_head_id = head_id // GROUP_HEAD

    # Per-(batch, head) base offsets, widened to int64 before multiplying by
    # the strides.
    q_offset = (
        batch_id.to(tl.int64) * stride_q_batch + head_id.to(tl.int64) * stride_q_head
    )
    o_offset = (
        batch_id.to(tl.int64) * stride_o_batch + head_id.to(tl.int64) * stride_o_head
    )
    k_offset = (
        batch_id.to(tl.int64) * stride_k_batch + kv_head_id.to(tl.int64) * stride_k_head
    )
    # V gets its own offset from its own strides: reusing K's batch/head
    # strides is only right when K and V happen to share a memory layout.
    v_offset = (
        batch_id.to(tl.int64) * stride_v_batch + kv_head_id.to(tl.int64) * stride_v_head
    )

    offs_headsize = tl.arange(0, HEAD_DIM)

    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    q_load_mask = offs_m < Q_CTX
    offs_n = tl.arange(0, BLOCK_N)

    Q_block_ptr = (
        Q
        + q_offset
        + offs_m[:, None] * stride_q_seqlen
        + offs_headsize[None, :] * stride_q_headsize
    )
    # K is addressed transposed: (HEAD_DIM, BLOCK_N).
    K_block_ptr = (
        K
        + k_offset
        + offs_n[None, :] * stride_k_seqlen
        + offs_headsize[:, None] * stride_k_headsize
    )
    V_block_ptr = (
        V
        + v_offset
        + offs_n[:, None] * stride_v_seqlen
        + offs_headsize[None, :] * stride_v_headsize
    )

    if HAS_ATTN_MASK:
        attn_mask_offset = (
            batch_id.to(tl.int64) * stride_attn_mask_batch
            + head_id.to(tl.int64) * stride_attn_mask_head
        )
        mask_block_ptr = (
            attn_mask
            + attn_mask_offset
            + offs_m[:, None] * stride_attn_mask_q_seqlen
            + offs_n[None, :] * stride_attn_mask_kv_seqlen
        )
    else:
        mask_block_ptr = None

    O_block_ptr = (
        Out
        + o_offset
        + offs_m[:, None] * stride_o_seqlen
        + offs_headsize[None, :] * stride_o_headsize
    )

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    # load scales
    qk_scale = sm_scale
    # qk_scale *= 1.44269504 # 1/log(2)
    # load query: it is loaded once and reused for every key/value tile
    query = tl.load(Q_block_ptr, mask=q_load_mask[:, None], other=0.0)
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            query,
            K_block_ptr,
            V_block_ptr,
            mask_block_ptr,
            stride_k_seqlen,
            stride_v_seqlen,
            stride_attn_mask_kv_seqlen,
            start_m,
            qk_scale,
            q_load_mask,
            BLOCK_M,
            HEAD_DIM,
            BLOCK_N,
            4 - STAGE,
            offs_m,
            offs_n,
            KV_CTX,
            V.dtype.element_ty == tl.float8e5,
            HAS_ATTN_MASK,
            PRE_LOAD_V,
        )
    # stage 2: on-band
    if STAGE & 2:
        # barrier makes it easier for compiler to schedule the
        # two loops independently
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            query,
            K_block_ptr,
            V_block_ptr,
            mask_block_ptr,
            stride_k_seqlen,
            stride_v_seqlen,
            stride_attn_mask_kv_seqlen,
            start_m,
            qk_scale,
            q_load_mask,
            BLOCK_M,
            HEAD_DIM,
            BLOCK_N,
            2,
            offs_m,
            offs_n,
            KV_CTX,
            V.dtype.element_ty == tl.float8e5,
            HAS_ATTN_MASK,
            PRE_LOAD_V,
        )
    # epilogue: normalise the accumulator and store the per-row statistic
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * Q_CTX + offs_m
    tl.store(m_ptrs, m_i, mask=q_load_mask)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=q_load_mask[:, None])

335 

336 

@triton.jit
def _attn_bwd_preprocess(
    O, DO, Delta, Z, H, Q_CTX, BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr
):
    # Backward pre-pass: Delta[row] = sum over the head dimension of O * dO.
    #
    # Program axis 0 picks a BLOCK_M tile of rows, axis 1 a flattened
    # batch*head index. O and DO are addressed as densely packed
    # (rows, D_HEAD) slabs per head; Z and H are accepted but not used here.
    head_idx = tl.program_id(1)
    rows = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    row_ok = rows < Q_CTX
    cols = tl.arange(0, D_HEAD)

    # Both operands share the same addressing pattern.
    head_base = head_idx * D_HEAD * Q_CTX
    row_offs = rows[:, None] * D_HEAD
    col_offs = cols[None, :]

    out_tile = tl.load(
        O + head_base + row_offs + col_offs,
        mask=row_ok[:, None],
        other=0.0,
    )
    grad_tile = tl.load(
        DO + head_base + row_offs + col_offs,
        mask=row_ok[:, None],
        other=0.0,
    ).to(tl.float32)

    row_dot = tl.sum(out_tile * grad_tile, axis=1)
    # write-back, skipping rows past the end of the sequence
    tl.store(Delta + head_idx * Q_CTX + rows, row_dot, mask=row_ok)

360 

361 

# The main inner-loop logic for computing dK and dV.
@triton.jit
def _attn_bwd_dkdv(
    dk,
    dv,  #
    Q,
    key,
    value,
    sm_scale,  #
    DO,  #
    M,
    D,  #
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,  #
    H,
    Q_CTX,
    KV_CTX,
    BLOCK_M1: tl.constexpr,  #
    BLOCK_N1: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,  #
    # Filled in by the wrapper.
    start_n,
    start_m,
    num_steps,  #
    MASK: tl.constexpr,
):
    # Accumulates dK and dV for one (BLOCK_N1, BLOCK_DMODEL) key/value tile.
    #
    # `key` / `value` are the already-loaded tiles starting at row `start_n`.
    # The loop visits `num_steps` query tiles of BLOCK_M1 rows starting at
    # `start_m`. With MASK set, only entries with query index >= key index
    # contribute (autoregressive masking). Returns the updated (dk, dv).
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_n_mask = offs_n < KV_CTX  # (BLOCK_N1, )

    offs_k = tl.arange(0, BLOCK_DMODEL)  # (BLOCK_DMODEL, )

    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    curr_m = start_m
    step_m = BLOCK_M1
    for blk_idx in range(num_steps):
        offs_m = curr_m + tl.arange(0, BLOCK_M1)  # (BLOCK_M1, )
        offs_m_mask = offs_m < Q_CTX  # (BLOCK_M1, )

        qT_ptrs = (
            Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
        )  # (BLOCK_DMODEL, BLOCK_M1)
        do_ptrs = (
            DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
        )  # (BLOCK_M1, BLOCK_DMODEL)

        qT = tl.load(
            qT_ptrs, mask=offs_m_mask[None, :], other=0.0
        )  # (BLOCK_DMODEL, BLOCK_M1)

        # Load m before computing qk to reduce pipeline stall.
        # Out-of-range rows get +inf so that exp2(qkT - m) evaluates to 0.
        m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf"))  # (BLOCK_M1, )

        # key: (BLOCK_N1, BLOCK_DMODEL)
        qkT = tl.dot(key, qT)  # (BLOCK_N1, BLOCK_M1)
        m = tl.broadcast_to(m[None, :], (BLOCK_N1, BLOCK_M1))  # (BLOCK_N1, BLOCK_M1)
        m = tl.where(offs_n_mask[:, None], m, float("inf"))  # (BLOCK_N1, BLOCK_M1)
        pT = tl.math.exp2(qkT - m)

        mask = offs_m_mask[None, :] & offs_n_mask[:, None]  # (BLOCK_N1, BLOCK_M1)
        # Autoregressive masking.
        if MASK:
            mask &= offs_m[None, :] >= offs_n[:, None]
        pT = tl.where(mask, pT, 0.0)  # (BLOCK_N1, BLOCK_M1)

        # Rows past Q_CTX must not be read: an unmasked load touches memory
        # outside this head's slab, and a stray NaN/Inf would survive the
        # multiplication by the zeroed probabilities in the dots below.
        do = tl.load(
            do_ptrs, mask=offs_m_mask[:, None], other=0.0
        )  # (BLOCK_M1, BLOCK_DMODEL)

        # Compute dV.
        dv += tl.dot(pT, do.to(tl.float32))  # (BLOCK_N1, BLOCK_DMODEL)
        # D (= delta) is pre-divided by ds_scale.
        Di = tl.load(D + offs_m, mask=offs_m_mask, other=0.0)  # (BLOCK_M1, )

        # Compute dP and dS.
        dpT = tl.dot(value, tl.trans(do)).to(
            tl.float32
        )  # (BLOCK_N1, BLOCK_DMODEL) @ (BLOCK_M1, BLOCK_DMODEL).T -> (BLOCK_N1, BLOCK_M1)
        dsT = pT * (dpT - Di[None, :])  # (BLOCK_N1, BLOCK_M1)
        dsT = dsT.to(qT.dtype)
        qT = tl.where(offs_m_mask[None, :], qT, 0.0)  # (BLOCK_DMODEL, BLOCK_M1)
        dsT = tl.where(
            offs_m_mask[None, :] & offs_n_mask[:, None], dsT, 0.0
        )  # (BLOCK_N1, BLOCK_M1)
        dk += tl.dot(
            dsT, tl.trans(qT)
        )  # (BLOCK_N1, BLOCK_M1) @ (BLOCK_DMODEL, BLOCK_M1).T -> (BLOCK_N1, BLOCK_DMODEL)
        # Increment pointers.
        curr_m += step_m
    return dk, dv

457 

458 

# Inner loop that accumulates dQ for one tile of query rows.
@triton.jit
def _attn_bwd_dq(
    dq,
    query,
    K,
    V,  #
    do,
    m,
    D,
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,  #
    H,
    Q_CTX,  #
    KV_CTX,  #
    BLOCK_M2: tl.constexpr,  #
    BLOCK_N2: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,
    # Filled in by the wrapper.
    start_m,
    start_n,
    num_steps,  #
    MASK: tl.constexpr,
):
    # `query`, `do` and `m` are tiles already loaded by the caller for the
    # BLOCK_M2 rows starting at `start_m`. The loop visits `num_steps` key/value
    # tiles of BLOCK_N2 columns starting at `start_n`. With MASK set, only
    # entries with query index >= key index contribute. Returns the updated dq.
    rows = start_m + tl.arange(0, BLOCK_M2)
    row_ok = rows < Q_CTX

    feat = tl.arange(0, BLOCK_DMODEL)
    # D (= delta) is pre-divided by ds_scale.
    Di = tl.load(D + rows, mask=row_ok, other=0.0)
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
    col_start = start_n
    for blk_idx in range(num_steps):
        cols = col_start + tl.arange(0, BLOCK_N2)
        col_ok = cols < KV_CTX

        # K and V tiles are read transposed: (BLOCK_DMODEL, BLOCK_N2).
        kT = tl.load(
            K + cols[None, :] * stride_tok + feat[:, None] * stride_d,
            mask=col_ok[None, :],
            other=0.0,
        )
        vT = tl.load(
            V + cols[None, :] * stride_tok + feat[:, None] * stride_d,
            mask=col_ok[None, :],
            other=0.0,
        )

        scores = tl.dot(query, kT)
        p = tl.math.exp2(scores - m)

        # Entries that fall inside both sequences.
        valid = row_ok[:, None] & col_ok[None, :]
        # Autoregressive masking.
        if MASK:
            valid &= rows[:, None] >= cols[None, :]
        p = tl.where(valid, p, 0.0)

        # Compute dP and dS.
        dp = tl.dot(do, vT).to(tl.float32)
        ds = p * (dp - Di[:, None])
        ds = tl.where(valid, ds, 0.0).to(kT.dtype)
        # Compute dQ.
        # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
        dq += tl.dot(ds, tl.trans(kT))
        # Move on to the next key/value tile.
        col_start += BLOCK_N2
    return dq

522 

523 

# Tuned configs for the backward kernel; used by the @libtuner decorator on
# _attn_bwd and read by scaled_dot_product_attention_backward.
config_backward = runtime.get_tuned_config("attention_bwd")

525 

526 

@libentry()
@libtuner(
    configs=config_backward,
    key=["KV_CTX", "BLOCK_DMODEL"],
)
@triton.jit
def _attn_bwd(
    Q,
    K,
    V,
    sm_scale,  #
    DO,  #
    DQ,
    DK,
    DV,  #
    M,
    D,
    # shared by Q/K/V/DO.
    stride_z,
    stride_h,
    stride_tok,
    stride_d,  #
    kv_stride_z,
    kv_stride_h,  #
    H,  # query head num
    Q_CTX,  #
    KV_CTX,  #
    kv_head_num,  #
    GROUP_HEAD: tl.constexpr,  #
    BLOCK_M1: tl.constexpr,  #
    BLOCK_N1: tl.constexpr,  #
    BLOCK_M2: tl.constexpr,  #
    BLOCK_N2: tl.constexpr,  #
    BLK_SLICE_FACTOR: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,
):
    # Fused attention backward kernel.
    #
    # Program axis 2 selects a flattened (batch, query head) pair and axis 0 a
    # tile index `pid`. Each program first accumulates dK/dV for the key/value
    # tile starting at pid * BLOCK_N1, then dQ for the query tile starting at
    # pid * BLOCK_M2.
    #
    # NOTE(review): no causal flag is received; the first _attn_bwd_dkdv and
    # _attn_bwd_dq calls always pass MASK=True — confirm the intended
    # behaviour for non-causal attention.
    tl.device_assert(Q_CTX % BLOCK_M1 == 0, "Q_CTX must be a multiple of BLOCK_M1.")

    LN2: tl.constexpr = 0.6931471824645996  # = ln(2)

    bhid = tl.program_id(2)
    # M and D are indexed as Q_CTX values per (batch, head).
    off_chz = (bhid * Q_CTX).to(tl.int64)
    batch_id = bhid // H
    q_head_id = bhid % H
    # GROUP_HEAD consecutive query heads map to one key/value head.
    kv_head_id = q_head_id // GROUP_HEAD
    adj = (stride_h * q_head_id + stride_z * batch_id).to(tl.int64)
    kv_adj = (kv_stride_h * kv_head_id + kv_stride_z * batch_id).to(tl.int64)

    pid = tl.program_id(0)

    # offset pointers for batch/head
    Q += adj
    K += kv_adj
    V += kv_adj
    DO += adj
    DQ += adj
    # NOTE(review): DK/DV are advanced by the query-derived `adj`; this looks
    # correct only if those buffers share Q's batch/head strides — confirm for
    # KV_CTX != Q_CTX.
    DK += adj
    DV += adj
    M += off_chz
    D += off_chz

    # load scales
    offs_k = tl.arange(0, BLOCK_DMODEL)

    start_n = pid * BLOCK_N1
    start_m = start_n

    # The diagonal (masked) part is processed with smaller query tiles.
    MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_n_mask = offs_n < KV_CTX

    dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)

    # load K and V: they are loaded once and reused by the inner loops.
    key = tl.load(
        K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_n_mask[:, None],
        other=0.0,
    )
    value = tl.load(
        V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_n_mask[:, None],
        other=0.0,
    )

    num_steps = BLOCK_N1 // MASK_BLOCK_M1

    # Compute dK and dV for the masked (diagonal) blocks.
    dk, dv = _attn_bwd_dkdv(
        dk,
        dv,  #
        Q,
        key,
        value,
        sm_scale,  #
        DO,  #
        M,
        D,  #
        stride_tok,
        stride_d,  #
        H,
        Q_CTX,  #
        KV_CTX,  #
        MASK_BLOCK_M1,
        BLOCK_N1,
        BLOCK_DMODEL,  #
        start_n,
        start_m,
        num_steps,  #
        MASK=True,  #
    )

    # Compute dK and dV for non-masked blocks.
    start_m += num_steps * MASK_BLOCK_M1
    remaining_m = Q_CTX - start_m
    num_steps = (remaining_m + BLOCK_M1 - 1) // BLOCK_M1

    if num_steps > 0 and start_m < Q_CTX:
        dk, dv = _attn_bwd_dkdv(  #
            dk,
            dv,  #
            Q,
            key,
            value,
            sm_scale,  #
            DO,  #
            M,
            D,  #
            stride_tok,
            stride_d,  #
            H,
            Q_CTX,  #
            KV_CTX,  #
            BLOCK_M1,
            BLOCK_N1,
            BLOCK_DMODEL,  #
            start_n,
            start_m,
            num_steps,  #
            MASK=False,  #
        )
    # tl.device_print("dv: ", dv)

    # Write back dV.
    dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dv_ptrs, dv, mask=offs_n_mask[:, None])

    # Write back dK.
    dk *= sm_scale
    dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dk_ptrs, dk, mask=offs_n_mask[:, None])

    # THIS BLOCK DOES DQ:
    MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
    start_m = pid * BLOCK_M2
    end_n = min(start_m + BLOCK_M2, KV_CTX)  # Ensure end_n does not exceed KV_CTX
    # `start_n` is still pid * BLOCK_N1 from the dK/dV part above.
    num_steps = (end_n - start_n + MASK_BLOCK_N2 - 1) // MASK_BLOCK_N2

    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_m_mask = offs_m < Q_CTX

    query = tl.load(
        Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_m_mask[:, None],
        other=0.0,
    )
    dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
    do = tl.load(
        DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_m_mask[:, None],
        other=0.0,
    )

    # Out-of-range rows get +inf so that exp2(qk - m) evaluates to 0.
    m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf"))
    m = m[:, None]

    # Stage 1 - Compute dQ for masked (diagonal) blocks.
    # NOTE: This code scans each row of QK^T backward (from right to left,
    # but inside each call to _attn_bwd_dq, from left to right), but that's
    # not due to anything important. I just wanted to reuse the loop
    # structure for dK & dV above as much as possible.

    if num_steps > 0:
        dq = _attn_bwd_dq(
            dq,
            query,
            K,
            V,  #
            do,
            m,
            D,  #
            stride_tok,
            stride_d,  #
            H,
            Q_CTX,  #
            KV_CTX,  #
            BLOCK_M2,
            MASK_BLOCK_N2,
            BLOCK_DMODEL,  #
            start_m,
            start_n,
            num_steps,  #
            MASK=True,  #
        )

    # Stage 2 - non-masked blocks
    # Covers the key columns that lie before `start_n`.
    stage2_end_n = start_n
    stage2_num_steps = (stage2_end_n + BLOCK_N2 - 1) // BLOCK_N2

    if stage2_num_steps > 0:
        dq = _attn_bwd_dq(
            dq,
            query,
            K,
            V,  #
            do,
            m,
            D,  #
            stride_tok,
            stride_d,  #
            H,
            Q_CTX,  #
            KV_CTX,  #
            BLOCK_M2,
            BLOCK_N2,
            BLOCK_DMODEL,  #
            start_m,
            stage2_end_n - stage2_num_steps * BLOCK_N2,
            stage2_num_steps,  #
            MASK=False,  #
        )
    # Write back dQ.
    dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    # De-scale dq by ln(2) (see the pre-scaled kT note in _attn_bwd_dq).
    dq *= LN2
    # tl.store(dq_ptrs, dq)

    tl.store(dq_ptrs, dq, mask=offs_m_mask[:, None])

763 

764 

def scaled_dot_product_attention_forward(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Run the fused attention forward kernel.

    Args:
        query: tensor indexed as (batch, q_heads, q_len, head_dim).
        key: tensor indexed as (batch, kv_heads, kv_len, head_dim).
        value: tensor indexed like ``key``.
        attn_mask: optional mask broadcastable to
            (batch, q_heads, q_len, kv_len). A boolean mask follows the
            ``torch.nn.functional.scaled_dot_product_attention`` convention
            (``True`` = attend); any other dtype is added to the scores.
        dropout_p: must be 0.0.
        is_causal: apply causal masking inside the kernel.
        scale: softmax scale; defaults to ``1 / sqrt(head_dim)``.
        enable_gqa: allow ``q_heads != kv_heads``.

    Returns:
        ``(o, M)``: the attention output (shaped like ``query``, dtype of
        ``value``) and the float32 per-row statistic written by the kernel,
        shaped (batch, q_heads, q_len).

    Raises:
        AssertionError: on unsupported head dims, non-zero dropout, or
            mismatched head counts without ``enable_gqa``.
    """
    logger.debug("GEMS SCALED DOT PRODUCT ATTENTION FORWARD")
    # shape constraints
    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    # when v is in float8_e5m2 it is transposed.
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
    assert dropout_p == 0.0, "Currenty only support dropout_p=0.0"

    o = torch.empty_like(query, dtype=value.dtype)

    # The kernel decodes STAGE bitwise: 3 -> off-band + diagonal, 1 -> full range.
    stage = 3 if is_causal else 1

    sm_scale = 1.0 / (HEAD_DIM_K**0.5) if scale is None else scale

    batch, q_head_num, q_len = query.shape[0], query.shape[1], query.shape[2]
    kv_head_num, kv_len = key.shape[1], key.shape[2]
    assert enable_gqa or q_head_num == kv_head_num, (
        f"q_head_num {q_head_num} != kv_head_num {kv_head_num}, "
        "enable_gqa must be True to support different head numbers."
    )

    grid = lambda args: (
        triton.cdiv(q_len, args["BLOCK_M"]),
        batch * q_head_num,
        1,
    )

    if attn_mask is not None:
        HAS_ATTN_MASK = True
        if attn_mask.dtype == torch.bool:
            # True means "take part in attention": only the False entries get
            # the large negative bias.
            attn_mask = attn_mask.logical_not().to(query.dtype) * -1.0e6
        # Broadcast to the full 4-D score shape; broadcast dims get stride 0,
        # so no data is copied and a full-size mask is left untouched.
        attn_mask = attn_mask.expand(batch, q_head_num, q_len, kv_len)
        stride_attn_mask_batch = attn_mask.stride(0)
        stride_attn_mask_head = attn_mask.stride(1)
        stride_attn_mask_q_seqlen = attn_mask.stride(2)
        stride_attn_mask_kv_seqlen = attn_mask.stride(3)
    else:
        HAS_ATTN_MASK = False
        # Placeholder strides; the kernel does not touch the mask in this case.
        stride_attn_mask_batch = 1
        stride_attn_mask_head = 1
        stride_attn_mask_q_seqlen = 1
        stride_attn_mask_kv_seqlen = 1

    M = torch.empty(
        (batch, q_head_num, q_len),
        device=query.device,
        dtype=torch.float32,
    )

    with torch_device_fn.device(query.device):
        _attn_fwd[grid](
            query,
            key,
            value,
            attn_mask,
            sm_scale,
            M,
            o,  #
            query.stride(0),
            query.stride(1),
            query.stride(2),
            query.stride(3),  #
            key.stride(0),
            key.stride(1),
            key.stride(2),
            key.stride(3),  #
            value.stride(0),
            value.stride(1),
            value.stride(2),
            value.stride(3),  #
            stride_attn_mask_batch,
            stride_attn_mask_head,
            stride_attn_mask_q_seqlen,
            stride_attn_mask_kv_seqlen,  #
            o.stride(0),
            o.stride(1),
            o.stride(2),
            o.stride(3),  #
            batch,
            q_head_num,
            kv_head_num,  #
            q_head_num // kv_head_num,  # group_head
            q_len,  #
            kv_len,  #
            HEAD_DIM_K,  #
            STAGE=stage,  #
            HAS_ATTN_MASK=HAS_ATTN_MASK,  #
        )
    return o, M

867 

868 

def scaled_dot_product_attention_backward(
    do,
    query,
    key,
    value,
    o,
    M,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Compute dQ, dK and dV for the fused attention forward pass.

    Args:
        do: gradient w.r.t. the attention output; must be contiguous and share
            strides with ``query`` and ``o``.
        query, key, value: forward inputs; all must be contiguous.
        o: forward output.
        M: per-row statistic returned by the forward pass.
        attn_mask, is_causal, enable_gqa: accepted for signature parity.
            NOTE(review): none of them is forwarded to ``_attn_bwd`` — confirm
            the kernel's masking matches the forward configuration.
        dropout_p: must be 0.0.
        scale: softmax scale; defaults to ``1 / sqrt(head_dim)``.

    Returns:
        ``(dq, dk, dv)``. When query heads outnumber key/value heads, dk and dv
        are summed over each head group so they have ``kv_heads`` heads.

    Raises:
        AssertionError: on unsupported head dims, non-zero dropout, or
            non-contiguous / mismatched-stride inputs.
    """
    logger.debug("GEMS SCALED DOT PRODUCT ATTENTION BACKWARD")
    # shape constraints
    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    # when v is in float8_e5m2 it is transposed.
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
    assert dropout_p == 0.0, "Currenty only support dropout_p=0.0"

    sm_scale = 1.0 / (HEAD_DIM_K**0.5) if scale is None else scale

    assert do.is_contiguous()
    assert (
        query.is_contiguous()
        and key.is_contiguous()
        and value.is_contiguous()
        and o.is_contiguous()
    )
    assert query.stride() == o.stride() == do.stride()
    assert key.stride() == value.stride()

    BLOCK_DMODEL = HEAD_DIM_K
    BATCH, Q_HEAD, Q_CTX = query.shape[:3]
    _, KV_HEAD, KV_CTX = key.shape[:3]
    group_head = Q_HEAD // KV_HEAD

    BLK_SLICE_FACTOR = 2

    RCP_LN2 = 1.0 / math.log(2)

    # Pre-scale K so the kernel can work with exp2 directly; the kernel
    # multiplies dq by ln(2) and dk by sm_scale at the end to compensate.
    arg_k = key * (sm_scale * RCP_LN2)
    PRE_BLOCK = 256

    pre_grid = (triton.cdiv(Q_CTX, PRE_BLOCK), BATCH * Q_HEAD)

    delta = torch.empty_like(M)

    # NOTE that dk & dv always have the same number of heads as q
    dq = torch.empty_like(query).contiguous()
    dk = torch.empty(
        (BATCH, Q_HEAD, KV_CTX, HEAD_DIM_K),
        device=key.device,
        dtype=key.dtype,
        memory_format=torch.contiguous_format,
    )
    dv = torch.empty(
        (BATCH, Q_HEAD, KV_CTX, HEAD_DIM_V),
        device=value.device,
        dtype=value.dtype,
        memory_format=torch.contiguous_format,
    )

    # Size the launch from the config the tuner actually selects. Using the
    # largest BLOCK_N1 among all candidate configs would leave rows uncovered
    # whenever a config with a smaller tile is chosen.
    grid = lambda meta: (
        triton.cdiv(Q_CTX, meta["BLOCK_N1"]),
        1,
        BATCH * Q_HEAD,
    )

    # Launch on the tensors' device, as the forward wrapper does.
    with torch_device_fn.device(query.device):
        _attn_bwd_preprocess[pre_grid](
            o,
            do,  #
            delta,  #
            BATCH,
            Q_HEAD,
            Q_CTX,  #
            BLOCK_M=PRE_BLOCK,
            D_HEAD=BLOCK_DMODEL,  #
        )

        _attn_bwd[grid](
            query,
            arg_k,
            value,
            sm_scale,
            do,
            dq,
            dk,
            dv,  #
            M,
            delta,  #
            query.stride(0),
            query.stride(1),
            query.stride(2),
            query.stride(3),  #
            key.stride(0),
            key.stride(1),  #
            Q_HEAD,
            Q_CTX,  #
            KV_CTX,  #
            KV_HEAD,  #
            GROUP_HEAD=group_head,  #
            BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
            BLOCK_DMODEL=BLOCK_DMODEL,  #
        )

    if group_head > 1:
        # Fold the per-query-head gradients back onto the shared kv heads.
        dk = dk.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_K)
        dv = dv.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_V)
        dk = dk.sum(dim=2)
        dv = dv.sum(dim=2)

    return dq, dk, dv

1003 

1004 

class ScaleDotProductAttention(torch.autograd.Function):
    """Autograd wrapper around the fused attention forward/backward kernels.

    Head dims that the kernels do not support are zero-padded up to the next
    supported size; the padding is stripped again from the output and from
    the gradients.
    """

    @staticmethod
    def forward(
        ctx,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=None,
        enable_gqa=False,
    ):
        """Compute attention and stash what ``backward`` needs on ``ctx``."""
        # [sunrise fix] padding for unsupported head dims, since auto lower unsupported.
        head_size = key.shape[-1]
        # Ordered tuple: the search below must find the *smallest* sufficient
        # dim, and the assertion message indexes the last element.
        supported_head_dims = (16, 32, 64, 128, 256)
        needs_padding = head_size not in supported_head_dims
        if needs_padding:
            padded_head_dim = next(
                (d for d in supported_head_dims if d >= head_size), None
            )
            assert (
                padded_head_dim is not None
            ), f"Unsupported head dim {head_size}, max supported is {supported_head_dims[-1]}"
            pad = padded_head_dim - head_size
            query = F.pad(query, (0, pad))
            key = F.pad(key, (0, pad))
            value = F.pad(value, (0, pad))

        # Zero padding leaves Q @ K^T unchanged, so the default scale must be
        # derived from the original head size, not the padded one.
        sm_scale = scale if scale is not None else 1.0 / (head_size**0.5)
        o, M = scaled_dot_product_attention_forward(
            query,
            key,
            value,
            attn_mask,
            dropout_p,
            is_causal,
            sm_scale,
            enable_gqa,
        )

        # Save the (possibly padded) contiguous tensors: the backward wrapper
        # requires contiguous inputs with matching strides.
        ctx.save_for_backward(query, key, value, o, M)
        ctx.sm_scale = sm_scale
        ctx.causal = is_causal
        ctx.enable_gqa = enable_gqa
        ctx.head_size = head_size

        # [sunrise fix] strip the padding from the returned output.
        if needs_padding:
            return o[..., :head_size]
        return o

    @staticmethod
    def backward(ctx, do):
        """Return gradients for (query, key, value); other inputs get None."""
        query, key, value, o, M = ctx.saved_tensors
        is_causal = ctx.causal
        enable_gqa = ctx.enable_gqa
        sm_scale = ctx.sm_scale
        head_size = ctx.head_size

        padded_head_dim = query.shape[-1]
        was_padded = padded_head_dim != head_size
        if was_padded:
            # The incoming gradient has the unpadded width; widen it with
            # zeros so it lines up with the saved padded tensors.
            do = F.pad(do, (0, padded_head_dim - head_size))

        dq, dk, dv = scaled_dot_product_attention_backward(
            do,
            query,
            key,
            value,
            o,
            M,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=is_causal,
            scale=sm_scale,
            enable_gqa=enable_gqa,
        )

        if was_padded:
            # Gradients must match the shapes of the original, unpadded inputs.
            dq = dq[..., :head_size]
            dk = dk[..., :head_size]
            dv = dv[..., :head_size]
        return dq, dk, dv, None, None, None, None, None

1076 

1077 

def scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Differentiable attention entry point.

    Forwards every argument positionally, in declaration order, to
    ``ScaleDotProductAttention.apply`` and returns its result.
    """
    positional_args = (
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        enable_gqa,
    )
    return ScaleDotProductAttention.apply(*positional_args)

1098 

1099 

def flash_attention_forward(
    query,
    key,
    value,
    cumulative_sequence_length_q,
    cumulative_sequence_length_k,
    max_q,
    max_k,
    dropout_p,
    is_causal,
    return_debug_mask,
    *,
    scale=None,
    softcap=0.0,
    window_size_left=None,
    window_size_right=None,
    seqused_k=None,
    alibi_slopes=None,
    disable_splitkv=False,
):
    """Flash-attention forward built on ``mha_fwd`` / ``mha_varlan_fwd``.

    Head dims not in the supported list are zero-padded up to the next
    supported size and the output is sliced back afterwards.

    Args:
        query, key, value: attention inputs sharing the same last (head) dim.
        cumulative_sequence_length_q, cumulative_sequence_length_k: must both
            be ``None`` (varlen is rejected by the assertion below).
        max_q, max_k: only forwarded on the varlen path.
        dropout_p, is_causal, return_debug_mask, softcap, alibi_slopes,
        seqused_k, disable_splitkv: forwarded to the backend call.
        scale: softmax scale; defaults to ``1 / sqrt(original head dim)``.
        window_size_left, window_size_right: ``None`` is passed on as ``-1``.

    Returns:
        ``(out, lse, philox_seed, philox_offset, p)`` as produced by the
        backend call, with ``out`` sliced to the original head dim.

    Raises:
        AssertionError: for varlen inputs, mismatched head dims, or a head dim
            larger than the largest supported one.
    """
    logger.debug("GEMS FLASH_ATTENTION_FORWARD")
    assert (
        cumulative_sequence_length_q is None and cumulative_sequence_length_k is None
    ), "varlen is not supported yet."

    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    original_head_dim = HEAD_DIM_K
    supported_head_dims = (16, 32, 64, 96, 128, 192, 256)
    if HEAD_DIM_K not in supported_head_dims:
        # Smallest supported dim that can hold the original one.
        padded_head_dim = next(
            (d for d in supported_head_dims if d >= HEAD_DIM_K), None
        )
        assert (
            padded_head_dim is not None
        ), f"Unsupported head dim {HEAD_DIM_K}, max supported is {supported_head_dims[-1]}"
        pad = padded_head_dim - HEAD_DIM_K
        query = F.pad(query, (0, pad))
        key = F.pad(key, (0, pad))
        value = F.pad(value, (0, pad))
        HEAD_DIM_K = padded_head_dim

    # Explicit None check: `scale or default` would silently replace a
    # caller-supplied 0.0. The default uses the unpadded head dim because zero
    # padding does not change Q @ K^T.
    if scale is not None:
        softmax_scale = scale
    else:
        softmax_scale = 1.0 / (original_head_dim**0.5)

    non_null_window_left = window_size_left if window_size_left is not None else -1
    non_null_window_right = window_size_right if window_size_right is not None else -1

    out = torch.empty_like(query)
    if cumulative_sequence_length_q is not None:
        # Only reachable when assertions are disabled (python -O).
        out, q, k, v, lse, philox_seed, philox_offset, p = mha_varlan_fwd(
            query,
            key,
            value,
            out,
            cumulative_sequence_length_q,
            cumulative_sequence_length_k,
            seqused_k,
            None,
            None,  # block_table
            alibi_slopes,
            max_q,
            max_k,
            dropout_p,
            # Pass the resolved scale, never the possibly-None raw argument.
            softmax_scale,
            False,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask and dropout_p > 0,
            None,
        )
    else:
        out, q, k, v, lse, philox_seed, philox_offset, p = mha_fwd(
            query,
            key,
            value,
            out,
            alibi_slopes,
            dropout_p,
            softmax_scale,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask,
            disable_splitkv=disable_splitkv,
        )

    if HEAD_DIM_K != original_head_dim:
        # Drop the zero-padded tail of the head dimension.
        out = out[..., :original_head_dim]
    return (out, lse, philox_seed, philox_offset, p)

1200 

1201 

1202# Adapted from https://github.com/vllm-project/flash-attention/blob/main/vllm_flash_attn/flash_attn_interface.py 

def maybe_contiguous(x):
    """Return ``x`` with a unit stride on its last dimension.

    ``None`` and tensors whose last-dimension stride is already 1 are returned
    unchanged; anything else is replaced by ``x.contiguous()``.
    """
    if x is None:
        return x
    if x.stride(-1) == 1:
        return x
    return x.contiguous()

1205 

1206 

def flash_attn_varlen_func(
    q,
    k,
    v,
    max_seqlen_q,
    cu_seqlens_q,
    max_seqlen_k,
    cu_seqlens_k=None,  # only used for non-paged prefill
    seqused_k=None,
    q_v=None,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=None,
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
    return_softmax_lse=False,
    out=None,
    # Dummy FA3 arguments
    scheduler_metadata=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    s_aux=None,
    num_splits: int = 0,
    cp_world_size: int = 1,
    cp_rank: int = 0,
    cp_tot_seqused_k=None,
    fa_version: int = 2,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Side effect: every call sets the process environment variable ``OFF_ASYNC`` to ``"1"``
    (NOTE(review): the consumer of this variable is not visible in this function).

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv. Exactly one of
           cu_seqlens_k and seqused_k must be given on the non-C-extension path.
        seqused_k: alternative to cu_seqlens_k; required when block_table is given.
        max_seqlen_q: int. Maximum query sequence length in the batch. Objects exposing
            ``.item()`` are converted with it.
        max_seqlen_k: int. Maximum key sequence length in the batch. Objects exposing
            ``.item()`` are converted with it.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
            None is treated as (-1, -1).
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
            Only forwarded on the C-extension path.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling). Only forwarded on the C-extension path.
        block_table: forwarded to the kernel; requires seqused_k.
        return_softmax_lse: bool. When true, ``(out, softmax_lse)`` is returned instead of ``out``.
        out: optional output argument forwarded to the kernel.
        q_v, scheduler_metadata, q_descale, k_descale, v_descale, s_aux, cp_world_size,
        cp_rank, cp_tot_seqused_k: only forwarded on the C-extension path.
        num_splits: must be 0.
        fa_version: must be 2.
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    Raises:
        RuntimeError: if fa_version != 2 or num_splits > 0.
        AssertionError: on an invalid cu_seqlens_k / seqused_k / block_table / window_size
            combination (non-C-extension path only).
    """
    import os

    os.environ["OFF_ASYNC"] = "1"
    if fa_version != 2:
        raise RuntimeError("Only FA2 is implemented.")
    if num_splits > 0:
        raise RuntimeError("num_splits > 0 is not implemented in GEMS.")

    if use_c_extension:
        logger.debug("GEMS FLASH_ATTN_VARLEN_FUNC(C EXTENSION)")
        with torch_device_fn.device(q.device):
            out_cpp, softmax_lse = torch.ops.flag_gems.flash_attn_varlen_func(
                q,
                k,
                v,
                max_seqlen_q,
                cu_seqlens_q,
                max_seqlen_k,
                cu_seqlens_k,
                seqused_k,
                q_v,
                dropout_p,
                softmax_scale,
                causal,
                window_size,
                softcap,
                alibi_slopes,
                deterministic,
                return_attn_probs,
                block_table,
                return_softmax_lse,
                out,
                scheduler_metadata,
                q_descale,
                k_descale,
                v_descale,
                s_aux,
                num_splits,
                cp_world_size,
                cp_rank,
                cp_tot_seqused_k,
                fa_version,
            )
        return (out_cpp, softmax_lse) if return_softmax_lse else out_cpp

    logger.debug("GEMS FLASH_ATTN_VARLEN_FUNC")
    assert (
        cu_seqlens_k is not None or seqused_k is not None
    ), "cu_seqlens_k or seqused_k must be provided"
    assert (
        cu_seqlens_k is None or seqused_k is None
    ), "cu_seqlens_k and seqused_k cannot be provided at the same time"
    assert (
        block_table is None or seqused_k is not None
    ), "seqused_k must be provided if block_table is provided"
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    # custom op does not support non-tuple input
    if window_size is None:
        real_window_size = (-1, -1)
    else:
        assert len(window_size) == 2
        real_window_size = (window_size[0], window_size[1])
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if cu_seqlens_k is None:
        # cu_seqlens_k is not used when seqused_k is given, but the kernel
        # entry point still expects a tensor, so pass an uninitialized
        # placeholder shaped like cu_seqlens_q. Allocated only when needed.
        cu_seqlens_k = torch.empty_like(cu_seqlens_q)
    max_seqlen_q = (
        max_seqlen_q.item() if hasattr(max_seqlen_q, "item") else max_seqlen_q
    )
    max_seqlen_k = (
        max_seqlen_k.item() if hasattr(max_seqlen_k, "item") else max_seqlen_k
    )
    out, _, _, _, softmax_lse, *_ = mha_varlan_fwd(
        q,
        k,
        v,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_k,
        None,
        block_table,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        real_window_size[0],
        real_window_size[1],
        softcap,
        return_softmax_lse and dropout_p > 0,
        None,
    )

    return (out, softmax_lse) if return_softmax_lse else out

1389 

1390 

def flash_attn_varlen_opt_func(
    q,
    k,
    v,
    max_seqlen_q,
    cu_seqlens_q,
    max_seqlen_k,
    cu_seqlens_k=None,  # only used for non-paged prefill
    seqused_k=None,
    q_v=None,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=None,
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
    return_softmax_lse=False,
    out=None,
    lse=None,
    # Dummy FA3 arguments
    scheduler_metadata=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    s_aux=None,
    num_splits: int = 0,
    cp_world_size: int = 1,
    cp_rank: int = 0,
    cp_tot_seqused_k=None,
    fa_version: int = 2,
):
    """Variable-length flash attention routed through ``mha_varlan_fwd_opt``.

    Set dropout_p to 0.0 during evaluation.

    MQA/GQA is supported by giving K and V fewer heads than Q; the number of Q heads must be
    divisible by the number of KV heads. E.g. with 6 Q heads and 2 KV heads, Q heads 0, 1, 2
    attend to KV head 0 and Q heads 3, 4, 5 attend to KV head 1.

    With causal=True the causal mask is aligned to the bottom-right corner of the attention
    matrix (1 = keep, 0 = masked out). For seqlen_q = 2, seqlen_k = 5:
        1 1 1 1 0
        1 1 1 1 1
    For seqlen_q = 5, seqlen_k = 2:
        0 0
        0 0
        0 0
        1 0
        1 1
    A mask row that is entirely zero yields a zero output row.

    With window_size != (-1, -1), sliding-window local attention is used: query i attends only
    to keys in
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (total_q, nheads, headdim); total_q is the number of query tokens in the batch.
        k: (total_k, nheads_k, headdim); total_k is the number of key tokens in the batch.
        v: (total_k, nheads_k, headdim).
        cu_seqlens_q: (batch_size + 1,), torch.int32 cumulative sequence lengths indexing q.
        cu_seqlens_k: (batch_size + 1,), torch.int32 cumulative sequence lengths indexing kv.
            Exactly one of cu_seqlens_k / seqused_k must be given on the non-C-extension path;
            when it is None, cu_seqlens_q is handed to the kernel in its place.
        seqused_k: alternative to cu_seqlens_k; mandatory when block_table is given.
        max_seqlen_q: int, maximum query sequence length (``.item()`` is applied if available).
        max_seqlen_k: int, maximum key sequence length (``.item()`` is applied if available).
        dropout_p: float dropout probability.
        softmax_scale: float scaling of QK^T before softmax; defaults to 1 / sqrt(headdim).
        causal: bool, apply the causal mask (e.g. for auto-regressive modeling).
        window_size: (left, right); None is treated as (-1, -1).
        softcap: float; any value > 0 enables softcapping.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. Adds a bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|) to the score of query i and key j.
        deterministic: bool, deterministic backward pass (slightly slower, more memory); the
            forward pass is always deterministic. Only forwarded on the C-extension path.
        return_attn_probs: bool, testing-only request for attention probabilities, which are
            not guaranteed to be correctly scaled. Only forwarded on the C-extension path.
        block_table: forwarded to the kernel; requires seqused_k.
        return_softmax_lse: bool; when true the result is ``(out, softmax_lse)``.
        out: optional output argument forwarded to the kernel.
        lse: forwarded to ``mha_varlan_fwd_opt``; not passed on the C-extension path.
        num_splits: must be 0.
        fa_version: must be 2.
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen), the
            logsumexp of each row of QK^T * scaling (log of the softmax normalization factor).
    """
    if fa_version != 2:
        raise RuntimeError("Only FA2 is implemented.")
    if num_splits > 0:
        raise RuntimeError("num_splits > 0 is not implemented in GEMS.")

    if use_c_extension:
        logger.debug("GEMS FLASH_ATTN_VARLEN_FUNC(C EXTENSION)")
        with torch_device_fn.device(q.device):
            ext_out, ext_lse = torch.ops.flag_gems.flash_attn_varlen_func(
                q,
                k,
                v,
                max_seqlen_q,
                cu_seqlens_q,
                max_seqlen_k,
                cu_seqlens_k,
                seqused_k,
                q_v,
                dropout_p,
                softmax_scale,
                causal,
                window_size,
                softcap,
                alibi_slopes,
                deterministic,
                return_attn_probs,
                block_table,
                return_softmax_lse,
                out,
                scheduler_metadata,
                q_descale,
                k_descale,
                v_descale,
                s_aux,
                num_splits,
                cp_world_size,
                cp_rank,
                cp_tot_seqused_k,
                fa_version,
            )
        if return_softmax_lse:
            return (ext_out, ext_lse)
        return ext_out

    logger.debug("GEMS FLASH_ATTN_VARLEN_OPT_FUNC")
    assert (
        cu_seqlens_k is not None or seqused_k is not None
    ), "cu_seqlens_k or seqused_k must be provided"
    assert (
        cu_seqlens_k is None or seqused_k is None
    ), "cu_seqlens_k and seqused_k cannot be provided at the same time"
    assert (
        block_table is None or seqused_k is not None
    ), "seqused_k must be provided if block_table is provided"

    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    # The kernel takes the two window bounds as separate scalars.
    if window_size is None:
        window_left, window_right = -1, -1
    else:
        assert len(window_size) == 2
        window_left, window_right = window_size[0], window_size[1]

    # Accept tensor-like maxima by unwrapping them to Python scalars.
    if hasattr(max_seqlen_q, "item"):
        max_seqlen_q = max_seqlen_q.item()
    if hasattr(max_seqlen_k, "item"):
        max_seqlen_k = max_seqlen_k.item()

    # When only seqused_k is given there is no cu_seqlens_k; cu_seqlens_q is
    # passed in that slot as a stand-in.
    kv_cu_seqlens = cu_seqlens_k if cu_seqlens_k is not None else cu_seqlens_q
    want_softmax = return_softmax_lse and dropout_p > 0

    attn_out, _, _, _, attn_lse, *_ = mha_varlan_fwd_opt(
        q,
        k,
        v,
        out,
        lse,
        cu_seqlens_q,
        kv_cu_seqlens,
        seqused_k,
        None,
        block_table,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        window_left,
        window_right,
        softcap,
        want_softmax,
        None,
    )

    if return_softmax_lse:
        return (attn_out, attn_lse)
    return attn_out