Coverage for src/flag_gems/runtime/backend/_ascend/ops/attention.py: 0%

399 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2import math 

3from functools import partial 

4 

5import torch 

6import torch.nn.functional as F 

7import triton 

8import triton.language as tl 

9 

10from flag_gems import runtime 

11from flag_gems.config import use_c_extension 

12from flag_gems.ops.flash_api import mha_fwd, mha_varlan_fwd 

13from flag_gems.ops.flash_kernel import keep 

14from flag_gems.runtime import torch_device_fn 

15from flag_gems.utils import libentry, libtuner 

16 

# Module logger; named after this module so its records can be filtered per op.
logger = logging.getLogger(name=__name__)

18 

19 

# Modified from Triton tutorial: https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
@triton.jit
def _attn_fwd_inner(
    acc,
    l_i,
    m_i,
    query,  #
    K_block_ptr,
    V_block_ptr,  #
    mask_block_ptr,  #
    stride_k_seqlen,
    stride_v_seqlen,
    stride_attn_mask_kv_seqlen,  #
    start_m,
    qk_scale,  #
    q_load_mask,
    BLOCK_M: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_N: tl.constexpr,  #
    STAGE: tl.constexpr,
    offs_m: tl.constexpr,
    offs_n: tl.constexpr,  #
    KV_CTX: tl.constexpr,
    fp8_v: tl.constexpr,
    HAS_ATTN_MASK: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
):
    # Online-softmax sweep over a range of key/value blocks for one query block.
    #
    # acc / l_i / m_i are the running (unnormalized output, row sum, row max)
    # state and are returned updated. Scores are moved to base 2 (multiplied by
    # LOG2E) so the exponentials can use exp2.
    #
    # STAGE selects the key range visited:
    #   1 -> [0, start_m * BLOCK_M)                      off-diagonal (causal)
    #   2 -> [start_m * BLOCK_M, (start_m + 1) * BLOCK_M) diagonal, causal mask
    #   3 -> [0, KV_CTX)                                 non-causal
    # K/V/mask "block pointers" are tensors of addresses that are advanced by
    # BLOCK_N rows per iteration.
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
    else:  # causal = False
        lo, hi = 0, KV_CTX

    K_block_ptr += lo * stride_k_seqlen
    V_block_ptr += lo * stride_v_seqlen
    if HAS_ATTN_MASK:
        mask_block_ptr += lo * stride_attn_mask_kv_seqlen

    LOG2E = 1.44269504  # log2(e) constant

    # loop over key, value and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        kv_load_mask = (start_n + offs_n) < KV_CTX
        # -- compute qk ----
        key = tl.load(K_block_ptr, mask=kv_load_mask[None, :], other=0.0)
        if PRE_LOAD_V:
            value = tl.load(V_block_ptr, mask=kv_load_mask[:, None], other=0.0)

        qk = tl.dot(query, key, allow_tf32=False)
        # Keys past KV_CTX (ragged last block) must receive no probability mass.
        qk = tl.where(kv_load_mask[None, :], qk, -float("inf"))

        if HAS_ATTN_MASK:
            attn_mask = tl.load(
                mask_block_ptr,
                mask=q_load_mask[:, None] & kv_load_mask[None, :],
                other=0.0,
            )

        if STAGE == 2:
            # Causal condition inside the diagonal block.
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])

            if HAS_ATTN_MASK:
                qk = qk * qk_scale + attn_mask
                qk *= LOG2E
                qk = qk + tl.where(mask, 0, -1.0e6)
            else:
                qk = qk * qk_scale * LOG2E + tl.where(mask, 0, -1.0e6)

            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
        else:
            if HAS_ATTN_MASK:
                # The additive mask is in the same (natural-log) domain as
                # qk * qk_scale, so it must be added BEFORE the change of base
                # to log2 -- exactly as in the STAGE == 2 branch. Adding it
                # after the LOG2E multiply would scale the mask by ln(2).
                qk = qk * qk_scale + attn_mask
                qk *= LOG2E
            else:
                qk *= qk_scale * LOG2E
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk = qk - m_ij[:, None]

        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        # -- update output accumulator --
        acc = acc * alpha[:, None]
        # update acc
        if not PRE_LOAD_V:
            value = tl.load(V_block_ptr, mask=kv_load_mask[:, None], other=0.0)
        if fp8_v:
            p = p.to(tl.float8e5)
        else:
            p = p.to(query.dtype)
        p = p.to(value.dtype)
        acc = tl.dot(p, value, acc, allow_tf32=False)
        # update m_i and l_i
        m_i = m_ij

        K_block_ptr += BLOCK_N * stride_k_seqlen
        V_block_ptr += BLOCK_N * stride_v_seqlen

        if HAS_ATTN_MASK:
            mask_block_ptr += BLOCK_N * stride_attn_mask_kv_seqlen

    return acc, l_i, m_i

129 

130 

# NOTE: _attn_fwd statically asserts BLOCK_N <= HEAD_DIM, so small head dims
# (16/32) need extra configs whose BLOCK_N fits.
SMALL_HEAD_DIM_CONFIGS = [
    triton.Config(
        {"BLOCK_M": BM, "BLOCK_N": BN, "PRE_LOAD_V": 0}, num_stages=s, num_warps=w
    )
    for BM in [64, 128]
    for BN in [16, 32]
    for s in [2, 3, 4]
    for w in [4, 8]
]
# Build a fresh list instead of extending the runtime's returned list in place
# (`+=`), so importing this module never mutates an object owned by the
# runtime's config loader.
configs = list(runtime.get_tuned_config("attention")) + SMALL_HEAD_DIM_CONFIGS

144 

145 

@libentry()
@libtuner(
    configs=list(filter(partial(keep, must_keep=SMALL_HEAD_DIM_CONFIGS), configs)),
    key=["KV_CTX", "HEAD_DIM"],
)
@triton.jit
def _attn_fwd(
    Q,
    K,
    V,
    attn_mask,
    sm_scale,
    M,
    Out,  #
    stride_q_batch,
    stride_q_head,
    stride_q_seqlen,
    stride_q_headsize,
    stride_k_batch,
    stride_k_head,
    stride_k_seqlen,
    stride_k_headsize,
    stride_v_batch,
    stride_v_head,
    stride_v_seqlen,
    stride_v_headsize,
    stride_attn_mask_batch,
    stride_attn_mask_head,
    stride_attn_mask_q_seqlen,
    stride_attn_mask_kv_seqlen,
    stride_o_batch,
    stride_o_head,
    stride_o_seqlen,
    stride_o_headsize,
    Z,
    q_head_num,
    kv_head_num,
    GROUP_HEAD: tl.constexpr,
    Q_CTX,
    KV_CTX,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
    HAS_ATTN_MASK: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
):
    # Fused attention forward for one tile.
    #
    # program_id(0) = start_m selects a block of BLOCK_M query rows;
    # program_id(1) = off_hz is the flattened (batch, query head) index.
    # Query head head_id reads KV head head_id // GROUP_HEAD (grouped-query
    # attention). The kernel writes the normalized output tile to Out and, per
    # query row, m_i + log2(l_i) to M at offset off_hz * Q_CTX; the backward
    # kernels read M back.
    #
    # STAGE == 1: one non-causal sweep over all keys.
    # STAGE == 3: off-diagonal sweep, then the diagonal block with causal mask.
    #
    # Z and kv_head_num are accepted but not read here.
    # NOTE(review): kv_offset is built from K's batch/head strides and reused
    # for V, so V is assumed to share K's batch/head strides.
    tl.static_assert(BLOCK_N <= HEAD_DIM)
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    batch_id = off_hz // q_head_num
    head_id = off_hz % q_head_num
    kv_head_id = head_id // GROUP_HEAD

    # 64-bit base offsets so large batch*head*seq products do not overflow.
    q_offset = (
        batch_id.to(tl.int64) * stride_q_batch + head_id.to(tl.int64) * stride_q_head
    )
    o_offset = (
        batch_id.to(tl.int64) * stride_o_batch + head_id.to(tl.int64) * stride_o_head
    )
    kv_offset = (
        batch_id.to(tl.int64) * stride_k_batch + kv_head_id.to(tl.int64) * stride_k_head
    )

    offs_headsize = tl.arange(0, HEAD_DIM)

    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # Rows past Q_CTX in the last query block are masked on load and store.
    q_load_mask = offs_m < Q_CTX
    offs_n = tl.arange(0, BLOCK_N)

    # (BLOCK_M, HEAD_DIM) query tile addresses.
    Q_block_ptr = (
        Q
        + q_offset
        + offs_m[:, None] * stride_q_seqlen
        + offs_headsize[None, :] * stride_q_headsize
    )
    # (HEAD_DIM, BLOCK_N) addresses: K is read transposed so tl.dot(query, key)
    # directly yields the (BLOCK_M, BLOCK_N) score tile.
    K_block_ptr = (
        K
        + kv_offset
        + offs_n[None, :] * stride_k_seqlen
        + offs_headsize[:, None] * stride_k_headsize
    )
    # (BLOCK_N, HEAD_DIM) value tile addresses.
    V_block_ptr = (
        V
        + kv_offset
        + offs_n[:, None] * stride_v_seqlen
        + offs_headsize[None, :] * stride_v_headsize
    )

    if HAS_ATTN_MASK:
        # The mask is indexed by the query head (head_id), not the KV head.
        attn_mask_offset = (
            batch_id.to(tl.int64) * stride_attn_mask_batch
            + head_id.to(tl.int64) * stride_attn_mask_head
        )
        mask_block_ptr = (
            attn_mask
            + attn_mask_offset
            + offs_m[:, None] * stride_attn_mask_q_seqlen
            + offs_n[None, :] * stride_attn_mask_kv_seqlen
        )
    else:
        mask_block_ptr = None

    O_block_ptr = (
        Out
        + o_offset
        + offs_m[:, None] * stride_o_seqlen
        + offs_headsize[None, :] * stride_o_headsize
    )

    # Running softmax state: row max m_i, row sum l_i, unnormalized output acc.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    # The softmax scale is applied inside _attn_fwd_inner (before its LOG2E
    # change of base).
    qk_scale = sm_scale
    # qk_scale *= 1.44269504 # 1/log(2)
    # Load the query tile once; it is reused for every key block.
    query = tl.load(Q_block_ptr, mask=q_load_mask[:, None], other=0.0)
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            query,
            K_block_ptr,
            V_block_ptr,
            mask_block_ptr,
            stride_k_seqlen,
            stride_v_seqlen,
            stride_attn_mask_kv_seqlen,
            start_m,
            qk_scale,
            q_load_mask,
            BLOCK_M,
            HEAD_DIM,
            BLOCK_N,
            4 - STAGE,
            offs_m,
            offs_n,
            KV_CTX,
            V.dtype.element_ty == tl.float8e5,
            HAS_ATTN_MASK,
            PRE_LOAD_V,
        )
    # stage 2: on-band (diagonal block, causal only)
    if STAGE & 2:
        # Kept as a separate call so the compiler can schedule the off-band
        # and on-band loops independently.
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            query,
            K_block_ptr,
            V_block_ptr,
            mask_block_ptr,
            stride_k_seqlen,
            stride_v_seqlen,
            stride_attn_mask_kv_seqlen,
            start_m,
            qk_scale,
            q_load_mask,
            BLOCK_M,
            HEAD_DIM,
            BLOCK_N,
            2,
            offs_m,
            offs_n,
            KV_CTX,
            V.dtype.element_ty == tl.float8e5,
            HAS_ATTN_MASK,
            PRE_LOAD_V,
        )
    # epilogue: store per-row m_i + log2(l_i) for the backward pass and the
    # normalized output tile.
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * Q_CTX + offs_m
    tl.store(m_ptrs, m_i, mask=q_load_mask)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=q_load_mask[:, None])

330 

331 

@triton.jit
def _attn_bwd_preprocess(
    O_, DO, Delta, Z, H, Q_CTX, BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr
):  # noqa: E741
    # Backward pre-pass: Delta[bh, row] = sum_d O[bh, row, d] * dO[bh, row, d].
    #
    # program_id(0) selects a block of BLOCK_M query rows; program_id(1)
    # (off_hz) is the flattened batch*head index. O_ and DO are addressed as
    # densely packed rows of D_HEAD elements with Q_CTX rows per head; the
    # caller asserts both tensors are contiguous. Z and H are not read.
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    # Rows past Q_CTX (ragged last block) are neither loaded nor stored.
    mask = off_m < Q_CTX

    off_hz = tl.program_id(1)
    off_n = tl.arange(0, D_HEAD)
    # load
    o = tl.load(
        O_ + off_hz * D_HEAD * Q_CTX + off_m[:, None] * D_HEAD + off_n[None, :],
        mask=mask[:, None],
        other=0.0,
    )
    do = tl.load(
        DO + off_hz * D_HEAD * Q_CTX + off_m[:, None] * D_HEAD + off_n[None, :],
        mask=mask[:, None],
        other=0.0,
    ).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    # write-back: Delta is addressed as (batch*head, Q_CTX).
    tl.store(Delta + off_hz * Q_CTX + off_m, delta, mask=mask)

355 

356 

# The main inner-loop logic for computing dK and dV.
@triton.jit
def _attn_bwd_dkdv(
    dk,
    dv,  #
    Q,
    key,
    value,
    sm_scale,  #
    DO,  #
    M,
    D,  #
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,  #
    H,
    Q_CTX,
    KV_CTX,
    BLOCK_M1: tl.constexpr,  #
    BLOCK_N1: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,  #
    # Filled in by the wrapper.
    start_n,
    start_m,
    num_steps,  #
    MASK: tl.constexpr,
):
    # Accumulate dK and dV for one block of BLOCK_N1 keys (rows from start_n)
    # by sweeping num_steps query blocks of BLOCK_M1 rows, starting at start_m.
    #
    # key / value are the (BLOCK_N1, BLOCK_DMODEL) tiles already loaded by the
    # caller; Q and DO point at the current batch/head. Probabilities are
    # recomputed as exp2(qk - M) from the forward's per-row statistics M.
    # The caller (scaled_dot_product_attention_backward) passes K pre-scaled by
    # sm_scale / ln(2), so qk is already in that base-2 domain.
    # MASK additionally applies the causal condition offs_m >= offs_n.
    # sm_scale and H are not read here. Returns the updated (dk, dv).
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_n_mask = offs_n < KV_CTX  # (BLOCK_N1, )

    offs_k = tl.arange(0, BLOCK_DMODEL)  # (BLOCK_DMODEL, )

    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    curr_m = start_m
    step_m = BLOCK_M1
    for blk_idx in range(num_steps):
        offs_m = curr_m + tl.arange(0, BLOCK_M1)  # (BLOCK_M1, )
        offs_m_mask = offs_m < Q_CTX  # (BLOCK_M1, )

        qT_ptrs = (
            Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
        )  # (BLOCK_DMODEL, BLOCK_M1)
        do_ptrs = (
            DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
        )  # (BLOCK_M1, BLOCK_DMODEL)

        qT = tl.load(
            qT_ptrs, mask=offs_m_mask[None, :], other=0.0
        )  # (BLOCK_DMODEL, BLOCK_M1)

        # Load m before computing qk to reduce pipeline stall.
        # Out-of-range rows get +inf so exp2(qk - m) becomes 0 for them.
        m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf"))  # (BLOCK_M1, )

        # key: (BLOCK_N1, BLOCK_DMODEL)
        qkT = tl.dot(key, qT)  # (BLOCK_N1, BLOCK_M1)
        m = tl.broadcast_to(m[None, :], (BLOCK_N1, BLOCK_M1))  # (BLOCK_N1, BLOCK_M1)
        # Out-of-range keys also get +inf, zeroing their probabilities.
        m = tl.where(offs_n_mask[:, None], m, float("inf"))  # (BLOCK_N1, BLOCK_M1)
        pT = tl.math.exp2(qkT - m)

        mask = (offs_m < Q_CTX)[None, :] & (offs_n < KV_CTX)[
            :, None
        ]  # (BLOCK_N1, BLOCK_M1)
        # Autoregressive masking.
        if MASK:
            mask &= offs_m[None, :] >= offs_n[:, None]
        pT = tl.where(mask, pT, 0.0)  # (BLOCK_N1, BLOCK_M1)

        # NOTE(review): dO is loaded without a row mask, so rows with
        # offs_m >= Q_CTX would read past the tensor. _attn_bwd device-asserts
        # Q_CTX % BLOCK_M1 == 0, which presumably keeps these loads in bounds;
        # confirm. The masked form would be:
        # do = tl.load(do_ptrs, mask=offs_m_mask[:, None], other=0.0)  # (BLOCK_M1, BLOCK_DMODEL)
        do = tl.load(do_ptrs)

        # Compute dV = P^T @ dO.
        dv += tl.dot(pT, do.to(tl.float32))  # (BLOCK_N1, BLOCK_DMODEL)
        # D holds Delta = rowsum(O * dO) from _attn_bwd_preprocess.
        Di = tl.load(D + offs_m, mask=offs_m_mask, other=0.0)  # (BLOCK_M1, )

        # Compute dP and dS = P * (dP - Delta).
        dpT = tl.dot(value, tl.trans(do)).to(
            tl.float32
        )  # (BLOCK_N1, BLOCK_DMODEL) @ (BLOCK_M1, BLOCK_DMODEL).T -> (BLOCK_N1, BLOCK_M1)
        dsT = pT * (dpT - Di[None, :])  # (BLOCK_N1, BLOCK_M1)
        dsT = dsT.to(qT.dtype)
        qT = tl.where(offs_m_mask[None, :], qT, 0.0)  # (BLOCK_DMODEL, BLOCK_M1)
        dsT = tl.where(
            offs_m_mask[None, :] & offs_n_mask[:, None], dsT, 0.0
        )  # (BLOCK_N1, BLOCK_M1)
        # dK accumulates dS^T @ Q; the caller multiplies by sm_scale at write-back.
        dk += tl.dot(
            dsT, tl.trans(qT)
        )  # (BLOCK_N1, BLOCK_M1) @ (BLOCK_DMODEL, BLOCK_M1).T -> (BLOCK_N1, BLOCK_DMODEL)
        # Increment pointers.
        curr_m += step_m
    return dk, dv

452 

453 

# the main inner-loop logic for computing dQ
@triton.jit
def _attn_bwd_dq(
    dq,
    query,
    K,
    V,  #
    do,
    m,
    D,
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,  #
    H,
    Q_CTX,  #
    KV_CTX,  #
    BLOCK_M2: tl.constexpr,  #
    BLOCK_N2: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,
    # Filled in by the wrapper.
    start_m,
    start_n,
    num_steps,  #
    MASK: tl.constexpr,
):
    # Accumulate dQ for one block of BLOCK_M2 query rows (from start_m) by
    # sweeping num_steps key blocks of BLOCK_N2 rows, starting at start_n.
    #
    # query, do and m (shape (BLOCK_M2, 1)) are already loaded by the caller;
    # K and V point at the current batch/KV head. MASK applies the causal
    # condition offs_m >= offs_n. H is not read here. Returns the updated dq.
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_m_mask = offs_m < Q_CTX

    offs_k = tl.arange(0, BLOCK_DMODEL)
    # D holds Delta = rowsum(O * dO) from _attn_bwd_preprocess.
    Di = tl.load(D + offs_m, mask=offs_m_mask, other=0.0)
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
    curr_n = start_n
    step_n = BLOCK_N2
    for blk_idx in range(num_steps):
        offs_n = curr_n + tl.arange(0, BLOCK_N2)
        offs_n_mask = offs_n < KV_CTX

        # (BLOCK_DMODEL, BLOCK_N2) transposed K / V tiles.
        kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
        vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d

        kT = tl.load(kT_ptrs, mask=offs_n_mask[None, :], other=0.0)
        vT = tl.load(vT_ptrs, mask=offs_n_mask[None, :], other=0.0)
        qk = tl.dot(query, kT)
        p = tl.math.exp2(qk - m)
        mask = (offs_m < Q_CTX)[:, None] & (offs_n < KV_CTX)[None, :]
        # Autoregressive masking.
        if MASK:
            mask &= offs_m[:, None] >= offs_n[None, :]
        p = tl.where(mask, p, 0.0)
        # Compute dP and dS = P * (dP - Delta).
        dp = tl.dot(do, vT).to(tl.float32)
        ds = p * (dp - Di[:, None])
        ds = tl.where(mask, ds, 0.0).to(kT.dtype)
        # Compute dQ.
        # NOTE: We need to de-scale dq in the end, because kT was pre-scaled
        # (the caller multiplies dq by ln(2) at write-back).
        dq += tl.dot(ds, tl.trans(kT))
        # Increment pointers.
        curr_n += step_n
    return dq

517 

518 

# Autotuning candidates for _attn_bwd, read from the runtime's tuned-config table.
config_backward = runtime.get_tuned_config("attention_bwd")


@libentry()
@libtuner(
    configs=config_backward,
    key=["KV_CTX", "BLOCK_DMODEL"],
)
@triton.jit
def _attn_bwd(
    Q,
    K,
    V,
    sm_scale,  #
    DO,  #
    DQ,
    DK,
    DV,  #
    M,
    D,
    # shared by Q/K/V/DO.
    stride_z,
    stride_h,
    stride_tok,
    stride_d,  #
    kv_stride_z,
    kv_stride_h,  #
    H,  # query head num
    Q_CTX,  #
    KV_CTX,  #
    kv_head_num,  #
    GROUP_HEAD: tl.constexpr,  #
    BLOCK_M1: tl.constexpr,  #
    BLOCK_N1: tl.constexpr,  #
    BLOCK_M2: tl.constexpr,  #
    BLOCK_N2: tl.constexpr,  #
    BLK_SLICE_FACTOR: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,
):
    # Attention backward for one (block index pid, batch*head bhid) pair.
    #
    # program_id(0) = pid, program_id(2) = bhid. The program first computes
    # dK/dV for keys [pid*BLOCK_N1, pid*BLOCK_N1 + BLOCK_N1), then dQ for query
    # rows [pid*BLOCK_M2, pid*BLOCK_M2 + BLOCK_M2).
    # Only the causal region is visited: dK/dV walk query rows from start_n
    # onward (diagonal block with MASK=True, the rest unmasked), and dQ walks
    # keys before end_n.
    # NOTE(review): DK/DV are offset with the query strides (adj) because the
    # caller allocates them with H (query) heads; this presumes their per-head
    # stride equals the query's, i.e. Q_CTX == KV_CTX. The dQ part reuses
    # start_n = pid * BLOCK_N1 while its rows start at pid * BLOCK_M2, so
    # configs presumably need BLOCK_M2 == BLOCK_N1. Confirm both.
    tl.device_assert(Q_CTX % BLOCK_M1 == 0, "Q_CTX must be a multiple of BLOCK_M1.")

    LN2: tl.constexpr = 0.6931471824645996  # = ln(2)

    bhid = tl.program_id(2)
    off_chz = (bhid * Q_CTX).to(tl.int64)
    batch_id = bhid // H
    q_head_id = bhid % H
    kv_head_id = q_head_id // GROUP_HEAD
    adj = (stride_h * q_head_id + stride_z * batch_id).to(tl.int64)
    kv_adj = (kv_stride_h * kv_head_id + kv_stride_z * batch_id).to(tl.int64)

    pid = tl.program_id(0)

    # offset pointers for batch/head
    Q += adj
    K += kv_adj
    V += kv_adj
    DO += adj
    DQ += adj
    DK += adj
    DV += adj
    M += off_chz
    D += off_chz

    # feature-dimension offsets shared by every tile below
    offs_k = tl.arange(0, BLOCK_DMODEL)

    start_n = pid * BLOCK_N1
    start_m = start_n

    # The diagonal (masked) part is swept with finer query blocks.
    MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_n_mask = offs_n < KV_CTX

    dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)

    # Load the K and V tiles once; they are reused for every query block.
    key = tl.load(
        K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_n_mask[:, None],
        other=0.0,
    )
    value = tl.load(
        V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_n_mask[:, None],
        other=0.0,
    )

    num_steps = BLOCK_N1 // MASK_BLOCK_M1

    # Stage 1: diagonal block (query rows start_n .. start_n + BLOCK_N1).
    dk, dv = _attn_bwd_dkdv(
        dk,
        dv,  #
        Q,
        key,
        value,
        sm_scale,  #
        DO,  #
        M,
        D,  #
        stride_tok,
        stride_d,  #
        H,
        Q_CTX,  #
        KV_CTX,  #
        MASK_BLOCK_M1,
        BLOCK_N1,
        BLOCK_DMODEL,  #
        start_n,
        start_m,
        num_steps,  #
        MASK=True,  #
    )

    # Stage 2: dK and dV for the non-masked query blocks below the diagonal.
    start_m += num_steps * MASK_BLOCK_M1
    remaining_m = Q_CTX - start_m
    num_steps = (remaining_m + BLOCK_M1 - 1) // BLOCK_M1

    if num_steps > 0 and start_m < Q_CTX:
        dk, dv = _attn_bwd_dkdv(  #
            dk,
            dv,  #
            Q,
            key,
            value,
            sm_scale,  #
            DO,  #
            M,
            D,  #
            stride_tok,
            stride_d,  #
            H,
            Q_CTX,  #
            KV_CTX,  #
            BLOCK_M1,
            BLOCK_N1,
            BLOCK_DMODEL,  #
            start_n,
            start_m,
            num_steps,  #
            MASK=False,  #
        )

    # Write back dV.
    dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dv_ptrs, dv, mask=offs_n_mask[:, None])

    # Write back dK (the softmax scale is applied here, once).
    dk *= sm_scale
    dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dk_ptrs, dk, mask=offs_n_mask[:, None])

    # THIS BLOCK DOES DQ:
    MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
    start_m = pid * BLOCK_M2
    end_n = min(start_m + BLOCK_M2, KV_CTX)  # Ensure end_n does not exceed KV_CTX
    num_steps = (end_n - start_n + MASK_BLOCK_N2 - 1) // MASK_BLOCK_N2

    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_m_mask = offs_m < Q_CTX

    query = tl.load(
        Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_m_mask[:, None],
        other=0.0,
    )
    dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
    do = tl.load(
        DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_m_mask[:, None],
        other=0.0,
    )

    # Out-of-range rows get +inf so their recomputed probabilities are 0.
    m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf"))
    m = m[:, None]

    # Stage 1 - Compute dQ for masked (diagonal) blocks.
    # NOTE: This code scans each row of QK^T backward (from right to left,
    # but inside each call to _attn_bwd_dq, from left to right), but that's
    # not due to anything important. It just reuses the loop structure of the
    # dK & dV computation above as much as possible.

    if num_steps > 0:
        dq = _attn_bwd_dq(
            dq,
            query,
            K,
            V,  #
            do,
            m,
            D,  #
            stride_tok,
            stride_d,  #
            H,
            Q_CTX,  #
            KV_CTX,  #
            BLOCK_M2,
            MASK_BLOCK_N2,
            BLOCK_DMODEL,  #
            start_m,
            start_n,
            num_steps,  #
            MASK=True,  #
        )

    # Stage 2 - non-masked blocks: keys [0, start_n).
    stage2_end_n = start_n
    stage2_num_steps = (stage2_end_n + BLOCK_N2 - 1) // BLOCK_N2

    if stage2_num_steps > 0:
        dq = _attn_bwd_dq(
            dq,
            query,
            K,
            V,  #
            do,
            m,
            D,  #
            stride_tok,
            stride_d,  #
            H,
            Q_CTX,  #
            KV_CTX,  #
            BLOCK_M2,
            BLOCK_N2,
            BLOCK_DMODEL,  #
            start_m,
            stage2_end_n - stage2_num_steps * BLOCK_N2,
            stage2_num_steps,  #
            MASK=False,  #
        )
    # Write back dQ; multiplying by ln(2) undoes the 1/ln(2) folded into K.
    dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    dq *= LN2

    tl.store(dq_ptrs, dq, mask=offs_m_mask[:, None])

758 

759 

def scaled_dot_product_attention_forward(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Compute scaled dot-product attention with the Triton ``_attn_fwd`` kernel.

    Args:
        query: tensor indexed as (batch, q_heads, q_len, head_dim).
        key, value: tensors indexed as (batch, kv_heads, kv_len, head_dim).
        attn_mask: optional mask broadcastable to (batch, q_heads, q_len, kv_len).
            A bool mask follows PyTorch semantics (True = the position takes
            part in attention). Any other dtype is added to the scaled scores.
        dropout_p: must be 0.0.
        is_causal: if True, query row i only attends to keys at positions <= i.
        scale: softmax scale; defaults to 1 / sqrt(head_dim).
        enable_gqa: allow q_heads != kv_heads. q_heads must be a multiple of
            kv_heads.

    Returns:
        (o, M): ``o`` has ``query``'s shape and ``value``'s dtype. ``M`` is a
        float32 (batch, q_heads, q_len) tensor of per-row softmax statistics
        written by the kernel for use by the backward pass.
    """
    logger.debug("GEMS_ASCEND SCALED DOT PRODUCT ATTENTION FORWARD")
    # shape constraints
    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
    assert dropout_p == 0.0, "Currenty only support dropout_p=0.0"

    o = torch.empty_like(query, dtype=value.dtype)

    # 3 = off-diagonal sweep + causally-masked diagonal block; 1 = full sweep.
    stage = 3 if is_causal else 1

    sm_scale = 1.0 / (HEAD_DIM_K**0.5) if scale is None else scale

    batch, q_head_num, q_ctx = query.shape[0], query.shape[1], query.shape[2]
    kv_head_num, kv_ctx = key.shape[1], key.shape[2]
    assert enable_gqa or q_head_num == kv_head_num, (
        f"q_head_num {q_head_num} != kv_head_num {kv_head_num}, "
        "enable_gqa must be True to support different head numbers."
    )
    # The kernel maps query head h to KV head h // (q_head_num // kv_head_num);
    # that grouping is only meaningful when the division is exact.
    assert (
        q_head_num % kv_head_num == 0
    ), f"q_head_num {q_head_num} must be divisible by kv_head_num {kv_head_num}."

    grid = lambda args: (
        triton.cdiv(q_ctx, args["BLOCK_M"]),
        batch * q_head_num,
        1,
    )

    if attn_mask is not None:
        HAS_ATTN_MASK = True
        if attn_mask.dtype == torch.bool:
            # PyTorch semantics: True marks positions that take part in
            # attention. Convert to an additive mask that is 0 where attention
            # is allowed and a large negative value where it is not.
            attn_mask = (~attn_mask).to(query.dtype) * -1.0e6
        # The kernel addresses the mask through four strides (batch, head,
        # q position, kv position). Expanding to the full score shape lets
        # lower-rank and size-1 (broadcast) masks work: broadcast dimensions
        # get stride 0 instead of stepping past the end of the tensor.
        attn_mask = attn_mask.expand(batch, q_head_num, q_ctx, kv_ctx)
        (
            stride_attn_mask_batch,
            stride_attn_mask_head,
            stride_attn_mask_q_seqlen,
            stride_attn_mask_kv_seqlen,
        ) = attn_mask.stride()
    else:
        HAS_ATTN_MASK = False
        stride_attn_mask_batch = 1
        stride_attn_mask_head = 1
        stride_attn_mask_q_seqlen = 1
        stride_attn_mask_kv_seqlen = 1

    M = torch.empty(
        (batch, q_head_num, q_ctx),
        device=query.device,
        dtype=torch.float32,
    )

    with torch_device_fn.device(query.device):
        _attn_fwd[grid](
            query,
            key,
            value,
            attn_mask,
            sm_scale,
            M,
            o,  #
            query.stride(0),
            query.stride(1),
            query.stride(2),
            query.stride(3),  #
            key.stride(0),
            key.stride(1),
            key.stride(2),
            key.stride(3),  #
            value.stride(0),
            value.stride(1),
            value.stride(2),
            value.stride(3),  #
            stride_attn_mask_batch,
            stride_attn_mask_head,
            stride_attn_mask_q_seqlen,
            stride_attn_mask_kv_seqlen,  #
            o.stride(0),
            o.stride(1),
            o.stride(2),
            o.stride(3),  #
            batch,
            q_head_num,
            kv_head_num,  #
            q_head_num // kv_head_num,  # group_head
            q_ctx,  #
            kv_ctx,  #
            HEAD_DIM_K,  #
            STAGE=stage,  #
            HAS_ATTN_MASK=HAS_ATTN_MASK,  #
            sync_solver=True,
        )
    return o, M

863 

864 

def scaled_dot_product_attention_backward(
    do,
    query,
    key,
    value,
    o,
    M,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Compute (dq, dk, dv) for the Triton attention forward in this module.

    Args:
        do: gradient w.r.t. the attention output; must share ``query``/``o`` strides.
        query, key, value, o: forward inputs and output (all must be contiguous).
        M: per-row statistics returned by ``scaled_dot_product_attention_forward``.
        attn_mask, is_causal, enable_gqa: accepted for signature parity; not read.
        dropout_p: must be 0.0.
        scale: softmax scale; defaults to 1 / sqrt(head_dim).

    Returns:
        (dq, dk, dv). dk/dv are computed per query head. Under grouped-query
        attention they are summed over each group of query heads to match
        key/value.

    NOTE(review): the ``_attn_bwd`` kernel only visits query rows at or after
    each key block's diagonal and applies the causal mask on the diagonal
    block, and no attention mask reaches it. The gradients therefore
    correspond to a causal, unmasked forward. Confirm before relying on this
    for non-causal or masked attention.
    """
    logger.debug("GEMS_ASCEND SCALED DOT PRODUCT ATTENTION BACKWARD")
    # shape constraints
    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
    assert dropout_p == 0.0, "Currenty only support dropout_p=0.0"

    sm_scale = 1.0 / (HEAD_DIM_K**0.5) if scale is None else scale

    # The kernels address Q/O/dO (and K/V) through one shared set of strides.
    assert do.is_contiguous()
    assert (
        query.is_contiguous()
        and key.is_contiguous()
        and value.is_contiguous()
        and o.is_contiguous()
    )
    assert query.stride() == o.stride() == do.stride()
    assert key.stride() == value.stride()

    BLOCK_DMODEL = HEAD_DIM_K
    BATCH, Q_HEAD, Q_CTX = query.shape[:3]
    _, KV_HEAD, KV_CTX = key.shape[:3]
    group_head = Q_HEAD // KV_HEAD

    BLK_SLICE_FACTOR = 2
    RCP_LN2 = 1.0 / math.log(2)

    # Pre-scale K by sm_scale / ln(2) so recomputed scores are in the same
    # base-2 domain as the forward statistics M. _attn_bwd multiplies dq by
    # ln(2) and dk by sm_scale at write-back.
    arg_k = key * (sm_scale * RCP_LN2)

    PRE_BLOCK = 256
    pre_grid = (triton.cdiv(Q_CTX, PRE_BLOCK), BATCH * Q_HEAD)

    delta = torch.empty_like(M)

    # NOTE that dk & dv always have the same number of heads as q
    dq = torch.empty_like(query).contiguous()
    dk = torch.empty(
        (BATCH, Q_HEAD, KV_CTX, HEAD_DIM_K),
        device=key.device,
        dtype=key.dtype,
        memory_format=torch.contiguous_format,
    )
    dv = torch.empty(
        (BATCH, Q_HEAD, KV_CTX, HEAD_DIM_V),
        device=value.device,
        dtype=value.dtype,
        memory_format=torch.contiguous_format,
    )

    _attn_bwd_preprocess[pre_grid](
        o,
        do,  #
        delta,  #
        BATCH,
        Q_HEAD,
        Q_CTX,  #
        BLOCK_M=PRE_BLOCK,
        D_HEAD=BLOCK_DMODEL,  #
    )

    # Each program owns BLOCK_N1 keys of the config actually launched, so the
    # grid must be derived from that config. A fixed grid sized with the
    # largest BLOCK_N1 leaves trailing key blocks of dk/dv unwritten whenever
    # the autotuner picks a smaller BLOCK_N1.
    grid = lambda args: (triton.cdiv(Q_CTX, args["BLOCK_N1"]), 1, BATCH * Q_HEAD)

    _attn_bwd[grid](
        query,
        arg_k,
        value,
        sm_scale,
        do,
        dq,
        dk,
        dv,  #
        M,
        delta,  #
        query.stride(0),
        query.stride(1),
        query.stride(2),
        query.stride(3),  #
        key.stride(0),
        key.stride(1),  #
        Q_HEAD,
        Q_CTX,  #
        KV_CTX,  #
        KV_HEAD,  #
        GROUP_HEAD=group_head,  #
        BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
        BLOCK_DMODEL=BLOCK_DMODEL,  #
    )

    if group_head > 1:
        # Sum the per-query-head gradients of each KV group.
        dk = dk.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_K)
        dv = dv.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_V)
        dk = dk.sum(dim=2)
        dv = dv.sum(dim=2)

    return dq, dk, dv

999 

1000 

class ScaleDotProductAttention(torch.autograd.Function):
    """Autograd bridge for the Triton attention forward/backward pair.

    ``forward`` runs ``scaled_dot_product_attention_forward`` and keeps what the
    backward needs. ``backward`` returns gradients for query/key/value and
    ``None`` for the five remaining inputs.

    NOTE(review): ``attn_mask`` is not saved, and backward always passes
    ``attn_mask=None``.
    """

    @staticmethod
    def forward(
        ctx,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=None,
        enable_gqa=False,
    ):
        # Resolve the default scale once so forward and backward agree on it.
        if scale is None:
            sm_scale = 1.0 / (key.shape[-1] ** 0.5)
        else:
            sm_scale = scale
        out, row_stats = scaled_dot_product_attention_forward(
            query,
            key,
            value,
            attn_mask,
            dropout_p,
            is_causal,
            sm_scale,
            enable_gqa,
        )

        ctx.save_for_backward(query, key, value, out, row_stats)
        ctx.sm_scale = sm_scale
        ctx.causal = is_causal
        ctx.enable_gqa = enable_gqa
        return out

    @staticmethod
    def backward(ctx, do):
        q_saved, k_saved, v_saved, out_saved, stats_saved = ctx.saved_tensors
        causal_flag = ctx.causal
        gqa_flag = ctx.enable_gqa
        saved_scale = ctx.sm_scale
        grad_q, grad_k, grad_v = scaled_dot_product_attention_backward(
            do,
            q_saved,
            k_saved,
            v_saved,
            out_saved,
            stats_saved,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=causal_flag,
            scale=saved_scale,
            enable_gqa=gqa_flag,
        )
        # attn_mask, dropout_p, is_causal, scale and enable_gqa get no gradient.
        return (grad_q, grad_k, grad_v) + (None,) * 5

1052 

1053 

def scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Differentiable scaled dot-product attention.

    Thin entry point: every argument is forwarded, positionally and in the
    same order, to ``ScaleDotProductAttention.apply``.
    """
    forwarded = (
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        enable_gqa,
    )
    return ScaleDotProductAttention.apply(*forwarded)

1074 

1075 

def flash_attention_forward(
    query,
    key,
    value,
    cumulative_sequence_length_q,
    cumulative_sequence_length_k,
    max_q,
    max_k,
    dropout_p,
    is_causal,
    return_debug_mask,
    *,
    scale=None,
    softcap=0.0,
    window_size_left=None,
    window_size_right=None,
    seqused_k=None,
    alibi_slopes=None,
    disable_splitkv=False,
):
    """Flash-attention forward via ``mha_fwd`` (dense) or ``mha_varlan_fwd`` (varlen).

    Head dims outside the supported set are zero-padded up to the next
    supported size for the kernel call, and the output is sliced back. The
    softmax scale is always derived from the original head dim.

    Returns:
        (out, lse, philox_seed, philox_offset, p) as produced by the kernel
        wrapper, with ``out`` restricted to the original head dim.
    """
    logger.debug("GEMS_ASCEND FLASH_ATTENTION_FORWARD")
    assert (
        cumulative_sequence_length_q is None and cumulative_sequence_length_k is None
    ), "varlen is not supported yet."

    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    original_head_dim = HEAD_DIM_K
    supported_head_dims = (16, 32, 64, 96, 128, 192, 256)
    if HEAD_DIM_K not in supported_head_dims:
        # Smallest supported head dim that can hold the real one.
        padded_head_dim = next(
            (d for d in supported_head_dims if d >= HEAD_DIM_K), None
        )
        assert (
            padded_head_dim is not None
        ), f"Unsupported head dim {HEAD_DIM_K}, max supported is {supported_head_dims[-1]}"
        pad = padded_head_dim - HEAD_DIM_K
        # Zero padding leaves q.k unchanged; the extra value columns only
        # produce extra output columns, which are sliced off below.
        query = F.pad(query, (0, pad))
        key = F.pad(key, (0, pad))
        value = F.pad(value, (0, pad))
        HEAD_DIM_K = padded_head_dim

    # Use `is not None` (not `or`) so an explicitly passed scale is never
    # silently replaced by the default.
    softmax_scale = scale if scale is not None else 1.0 / (original_head_dim**0.5)
    non_null_window_left = -1 if window_size_left is None else window_size_left
    non_null_window_right = -1 if window_size_right is None else window_size_right

    out = torch.empty_like(query)
    if cumulative_sequence_length_q is not None:
        # Only reachable when the varlen assert above is stripped (python -O).
        # Pass the resolved softmax_scale, as the dense path does, rather than
        # the raw (possibly None) `scale`.
        out, q, k, v, lse, philox_seed, philox_offset, p = mha_varlan_fwd(
            query,
            key,
            value,
            out,
            cumulative_sequence_length_q,
            cumulative_sequence_length_k,
            seqused_k,
            None,
            None,  # block_table
            alibi_slopes,
            max_q,
            max_k,
            dropout_p,
            softmax_scale,
            False,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask and dropout_p > 0,
            None,
        )
    else:
        out, q, k, v, lse, philox_seed, philox_offset, p = mha_fwd(
            query,
            key,
            value,
            out,
            alibi_slopes,
            dropout_p,
            softmax_scale,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask,
            disable_splitkv=disable_splitkv,
        )

    if HEAD_DIM_K != original_head_dim:
        out = out[..., :original_head_dim]
    return (out, lse, philox_seed, philox_offset, p)

1176 

1177 

1178# Adapted from https://github.com/vllm-project/flash-attention/blob/main/vllm_flash_attn/flash_attn_interface.py 

1179def maybe_contiguous(x): 

1180 return x.contiguous() if x is not None and x.stride(-1) != 1 else x 

1181 

1182 

1183def flash_attn_varlen_func( 

1184 q, 

1185 k, 

1186 v, 

1187 max_seqlen_q, 

1188 cu_seqlens_q, 

1189 max_seqlen_k, 

1190 cu_seqlens_k=None, # only used for non-paged prefill 

1191 seqused_k=None, 

1192 q_v=None, 

1193 dropout_p=0.0, 

1194 softmax_scale=None, 

1195 causal=False, 

1196 window_size=None, 

1197 softcap=0.0, # 0.0 means deactivated 

1198 alibi_slopes=None, 

1199 deterministic=False, 

1200 return_attn_probs=False, 

1201 block_table=None, 

1202 return_softmax_lse=False, 

1203 out=None, 

1204 # Dummy FA3 arguments 

1205 scheduler_metadata=None, 

1206 q_descale=None, 

1207 k_descale=None, 

1208 v_descale=None, 

1209 s_aux=None, 

1210 num_splits: int = 0, 

1211 cp_world_size: int = 1, 

1212 cp_rank: int = 0, 

1213 cp_tot_seqused_k=None, 

1214 fa_version: int = 2, 

1215): 

1216 """dropout_p should be set to 0.0 during evaluation 

1217 Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads 

1218 than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. 

1219 For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head 

1220 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. 

1221 

1222 If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. 

1223 For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: 

1224 1 1 1 1 0 

1225 1 1 1 1 1 

1226 If seqlen_q = 5 and seqlen_k = 2, the causal mask is: 

1227 0 0 

1228 0 0 

1229 0 0 

1230 1 0 

1231 1 1 

1232 If the row of the mask is all zero, the output will be zero. 

1233 

1234 If window_size != (-1, -1), implements sliding window local attention. Query at position i 

1235 will only attend to keys between 

1236 [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. 

1237 

1238 Arguments: 

1239 q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. 

1240 k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. 

1241 v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. 

1242 cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths 

1243 of the sequences in the batch, used to index into q. 

1244 cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths 

1245 of the sequences in the batch, used to index into kv. 

1246 max_seqlen_q: int. Maximum query sequence length in the batch. 

1247 max_seqlen_k: int. Maximum key sequence length in the batch. 

1248 dropout_p: float. Dropout probability. 

1249 softmax_scale: float. The scaling of QK^T before applying softmax. 

1250 Default to 1 / sqrt(headdim). 

1251 causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). 

1252 window_size: (left, right). If not (-1, -1), implements sliding window local attention. 

1253 softcap: float. Anything > 0 activates softcapping attention. 

1254 alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of 

1255 (-alibi_slope * |i + seqlen_k - seqlen_q - j|) 

1256 is added to the attention score of query i and key j. 

1257 deterministic: bool. Whether to use the deterministic implementation of the backward pass, 

1258 which is slightly slower and uses more memory. The forward pass is always deterministic. 

1259 return_attn_probs: bool. Whether to return the attention probabilities. This option is for 

1260 testing only. The returned probabilities are not guaranteed to be correct 

1261 (they might not have the right scaling). 

1262 Return: 

1263 out: (total, nheads, headdim). 

1264 softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The 

1265 logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax 

1266 normalization factor). 

1267 """ 

1268 if fa_version != 2: 

1269 raise RuntimeError("Only FA2 is implemented.") 

1270 if num_splits > 0: 

1271 raise RuntimeError("num_splits > 0 is not implemented in GEMS.") 

1272 if use_c_extension: 

1273 logger.debug("GEMS_ASCEND FLASH_ATTN_VARLEN_FUNC(C EXTENSION)") 

1274 with torch_device_fn.device(q.device): 

1275 out_cpp, softmax_lse = torch.ops.flag_gems.flash_attn_varlen_func( 

1276 q, 

1277 k, 

1278 v, 

1279 max_seqlen_q, 

1280 cu_seqlens_q, 

1281 max_seqlen_k, 

1282 cu_seqlens_k, 

1283 seqused_k, 

1284 q_v, 

1285 dropout_p, 

1286 softmax_scale, 

1287 causal, 

1288 window_size, 

1289 softcap, 

1290 alibi_slopes, 

1291 deterministic, 

1292 return_attn_probs, 

1293 block_table, 

1294 return_softmax_lse, 

1295 out, 

1296 scheduler_metadata, 

1297 q_descale, 

1298 k_descale, 

1299 v_descale, 

1300 s_aux, 

1301 num_splits, 

1302 cp_world_size, 

1303 cp_rank, 

1304 cp_tot_seqused_k, 

1305 fa_version, 

1306 ) 

1307 return (out_cpp, softmax_lse) if return_softmax_lse else out_cpp 

1308 else: 

1309 logger.debug("GEMS_ASCEND FLASH_ATTN_VARLEN_FUNC") 

1310 assert ( 

1311 cu_seqlens_k is not None or seqused_k is not None 

1312 ), "cu_seqlens_k or seqused_k must be provided" 

1313 assert ( 

1314 cu_seqlens_k is None or seqused_k is None 

1315 ), "cu_seqlens_k and seqused_k cannot be provided at the same time" 

1316 assert ( 

1317 block_table is None or seqused_k is not None 

1318 ), "seqused_k must be provided if block_table is provided" 

1319 if softmax_scale is None: 

1320 softmax_scale = q.shape[-1] ** (-0.5) 

1321 # custom op does not support non-tuple input 

1322 if window_size is None: 

1323 real_window_size = (-1, -1) 

1324 else: 

1325 assert len(window_size) == 2 

1326 real_window_size = (window_size[0], window_size[1]) 

1327 q, k, v = [maybe_contiguous(x) for x in (q, k, v)] 

1328 dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q) 

1329 max_seqlen_q = ( 

1330 max_seqlen_q.item() if hasattr(max_seqlen_q, "item") else max_seqlen_q 

1331 ) 

1332 max_seqlen_k = ( 

1333 max_seqlen_k.item() if hasattr(max_seqlen_k, "item") else max_seqlen_k 

1334 ) 

1335 out, q, k, v, softmax_lse, *_ = mha_varlan_fwd( 

1336 q, 

1337 k, 

1338 v, 

1339 out, 

1340 cu_seqlens_q, 

1341 # cu_seqlens_k not used since we use seqused_k, but flash_api.cpp 

1342 # still wants it so we pass all zeros 

1343 dummy_cu_seqlens_k if cu_seqlens_k is None else cu_seqlens_k, 

1344 seqused_k, 

1345 None, 

1346 block_table, 

1347 alibi_slopes, 

1348 max_seqlen_q, 

1349 max_seqlen_k, 

1350 dropout_p, 

1351 softmax_scale, 

1352 False, 

1353 causal, 

1354 real_window_size[0], 

1355 real_window_size[1], 

1356 softcap, 

1357 return_softmax_lse and dropout_p > 0, 

1358 None, 

1359 ) 

1360 

1361 return (out, softmax_lse) if return_softmax_lse else out