Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/attention.py: 0%

399 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2import math 

3from functools import partial 

4 

5import torch 

6import torch.nn.functional as F 

7import triton 

8import triton.language as tl 

9 

10from flag_gems import runtime 

11from flag_gems.config import use_c_extension 

12from flag_gems.runtime import torch_device_fn 

13from flag_gems.utils import libentry, libtuner 

14 

15from .flash_api import mha_fwd, mha_varlan_fwd 

16from .flash_kernel import keep 

17 

# Module-level logger; the public entry points below log their invocation at DEBUG level.
logger = logging.getLogger(__name__)

19 

20 

# Modified from Triton tutorial: https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
@triton.jit
def _attn_fwd_inner(
    acc,
    l_i,
    m_i,
    query,  #
    K_block_ptr,
    V_block_ptr,  #
    mask_block_ptr,  #
    stride_k_seqlen,
    stride_v_seqlen,
    stride_attn_mask_kv_seqlen,  #
    start_m,
    qk_scale,  #
    q_load_mask,
    BLOCK_M: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_N: tl.constexpr,  #
    STAGE: tl.constexpr,
    offs_m: tl.constexpr,
    offs_n: tl.constexpr,  #
    KV_CTX: tl.constexpr,
    fp8_v: tl.constexpr,
    HAS_ATTN_MASK: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
):
    """Online-softmax loop over key/value blocks for one block of query rows.

    Scores are converted to base-2 units (multiplied by log2(e)) so that
    ``exp2`` can be used. ``m_i`` is the running row maximum and ``l_i`` the
    running row sum of ``exp2(score - m_i)``, both in those units.

    STAGE selects which keys are visited:
      * 1 -- keys ``[0, start_m * BLOCK_M)`` (causal, strictly below diagonal);
      * 2 -- the diagonal block ``[start_m * BLOCK_M, (start_m + 1) * BLOCK_M)``
        with the causal mask applied;
      * otherwise -- all ``KV_CTX`` keys (non-causal).

    When ``HAS_ATTN_MASK`` is set, the additive mask is applied to the scaled
    scores in natural-log units, i.e. softmax(qk * qk_scale + attn_mask).

    Returns:
        The updated ``(acc, l_i, m_i)``.
    """
    # range of values handled by this stage
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
    # causal = False
    else:
        lo, hi = 0, KV_CTX

    K_block_ptr += lo * stride_k_seqlen
    V_block_ptr += lo * stride_v_seqlen
    if HAS_ATTN_MASK:
        mask_block_ptr += lo * stride_attn_mask_kv_seqlen

    LOG2E = 1.44269504  # log2(e) constant

    # loop over key, value and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        kv_load_mask = (start_n + offs_n) < KV_CTX
        # -- compute qk ----
        key = tl.load(K_block_ptr, mask=kv_load_mask[None, :], other=0.0)
        if PRE_LOAD_V:
            value = tl.load(V_block_ptr, mask=kv_load_mask[:, None], other=0.0)

        qk = tl.dot(query, key, allow_tf32=False)
        # Keys past KV_CTX (trailing partial block) must not contribute.
        qk = tl.where(kv_load_mask[None, :], qk, -float("inf"))

        if HAS_ATTN_MASK:
            attn_mask = tl.load(
                mask_block_ptr,
                mask=q_load_mask[:, None] & kv_load_mask[None, :],
                other=0.0,
            )

        if STAGE == 2:
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])

            if HAS_ATTN_MASK:
                qk = qk * qk_scale + attn_mask
                qk *= LOG2E
                qk = qk + tl.where(mask, 0, -1.0e6)
            else:
                qk = qk * qk_scale * LOG2E + tl.where(mask, 0, -1.0e6)

            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
        else:
            if HAS_ATTN_MASK:
                # Add the mask in natural-log units *before* converting to
                # base 2, exactly like the STAGE == 2 path. Adding it after the
                # LOG2E scaling would scale every additive mask by ln(2).
                qk = qk * qk_scale + attn_mask
                qk *= LOG2E
            else:
                qk *= qk_scale * LOG2E
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk = qk - m_ij[:, None]

        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        # -- update output accumulator --
        acc = acc * alpha[:, None]
        # update acc
        if not PRE_LOAD_V:
            value = tl.load(V_block_ptr, mask=kv_load_mask[:, None], other=0.0)
        if fp8_v:
            p = p.to(tl.float8e5)
        else:
            p = p.to(query.dtype)
        p = p.to(value.dtype)
        acc = tl.dot(p, value, acc, allow_tf32=False)
        # update m_i and l_i
        m_i = m_ij

        K_block_ptr += BLOCK_N * stride_k_seqlen
        V_block_ptr += BLOCK_N * stride_v_seqlen

        if HAS_ATTN_MASK:
            mask_block_ptr += BLOCK_N * stride_attn_mask_kv_seqlen

    return acc, l_i, m_i

130 

131 

# _attn_fwd statically asserts BLOCK_N <= HEAD_DIM, so the tuned configs alone
# may offer no valid candidate for small head dims. These extra configs use
# BLOCK_N in {16, 32} to cover that case.
SMALL_HEAD_DIM_CONFIGS = [
    triton.Config(
        {"BLOCK_M": block_m, "BLOCK_N": block_n, "PRE_LOAD_V": 0},
        num_stages=stages,
        num_warps=warps,
    )
    for block_m in (64, 128)
    for block_n in (16, 32)
    for stages in (2, 3, 4)
    for warps in (4, 8)
]
configs = runtime.get_tuned_config("attention")
configs += SMALL_HEAD_DIM_CONFIGS

145 

146 

@libentry()
@libtuner(
    # configs=list(filter(partial(keep, must_keep=SMALL_HEAD_DIM_CONFIGS), configs)),
    configs=list(filter(partial(keep), configs)),
    key=["KV_CTX", "HEAD_DIM"],
)
@triton.jit
def _attn_fwd(
    Q,
    K,
    V,
    attn_mask,
    sm_scale,
    M,
    Out,  #
    stride_q_batch,
    stride_q_head,
    stride_q_seqlen,
    stride_q_headsize,
    stride_k_batch,
    stride_k_head,
    stride_k_seqlen,
    stride_k_headsize,
    stride_v_batch,
    stride_v_head,
    stride_v_seqlen,
    stride_v_headsize,
    stride_attn_mask_batch,
    stride_attn_mask_head,
    stride_attn_mask_q_seqlen,
    stride_attn_mask_kv_seqlen,
    stride_o_batch,
    stride_o_head,
    stride_o_seqlen,
    stride_o_headsize,
    Z,
    q_head_num,
    kv_head_num,
    GROUP_HEAD: tl.constexpr,
    Q_CTX,
    KV_CTX,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
    HAS_ATTN_MASK: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
):
    """Fused attention forward for one block of BLOCK_M query rows.

    Program axis 0 selects the query block; axis 1 selects ``batch * q_head_num +
    head``. Query head ``h`` reads KV head ``h // GROUP_HEAD`` (GQA/MQA).

    STAGE is 3 for causal attention (off-diagonal key blocks, then the diagonal
    block with the causal mask) and 1 for non-causal attention (all keys).

    Writes the normalized output to ``Out``. It also writes to ``M`` each row's
    log-sum-exp of the scaled scores in base-2 units (``m_i + log2(l_i)``).
    Rows past ``Q_CTX`` are masked out of both stores.
    """
    tl.static_assert(BLOCK_N <= HEAD_DIM)
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    batch_id = off_hz // q_head_num
    head_id = off_hz % q_head_num
    kv_head_id = head_id // GROUP_HEAD

    q_offset = (
        batch_id.to(tl.int64) * stride_q_batch + head_id.to(tl.int64) * stride_q_head
    )
    o_offset = (
        batch_id.to(tl.int64) * stride_o_batch + head_id.to(tl.int64) * stride_o_head
    )
    k_offset = (
        batch_id.to(tl.int64) * stride_k_batch + kv_head_id.to(tl.int64) * stride_k_head
    )
    # V gets its own batch/head offset: K and V need not share a memory layout
    # (e.g. one of them may be a permuted view).
    v_offset = (
        batch_id.to(tl.int64) * stride_v_batch + kv_head_id.to(tl.int64) * stride_v_head
    )

    offs_headsize = tl.arange(0, HEAD_DIM)

    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    q_load_mask = offs_m < Q_CTX
    offs_n = tl.arange(0, BLOCK_N)

    Q_block_ptr = (
        Q
        + q_offset
        + offs_m[:, None] * stride_q_seqlen
        + offs_headsize[None, :] * stride_q_headsize
    )
    # K is addressed transposed: (HEAD_DIM, BLOCK_N), so query @ key works directly.
    K_block_ptr = (
        K
        + k_offset
        + offs_n[None, :] * stride_k_seqlen
        + offs_headsize[:, None] * stride_k_headsize
    )
    V_block_ptr = (
        V
        + v_offset
        + offs_n[:, None] * stride_v_seqlen
        + offs_headsize[None, :] * stride_v_headsize
    )

    if HAS_ATTN_MASK:
        attn_mask_offset = (
            batch_id.to(tl.int64) * stride_attn_mask_batch
            + head_id.to(tl.int64) * stride_attn_mask_head
        )
        mask_block_ptr = (
            attn_mask
            + attn_mask_offset
            + offs_m[:, None] * stride_attn_mask_q_seqlen
            + offs_n[None, :] * stride_attn_mask_kv_seqlen
        )
    else:
        mask_block_ptr = None

    O_block_ptr = (
        Out
        + o_offset
        + offs_m[:, None] * stride_o_seqlen
        + offs_headsize[None, :] * stride_o_headsize
    )

    # initialize running max (m_i), running sum (l_i) and output accumulator
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    qk_scale = sm_scale
    # load query once; it is reused for every key block
    query = tl.load(Q_block_ptr, mask=q_load_mask[:, None], other=0.0)
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            query,
            K_block_ptr,
            V_block_ptr,
            mask_block_ptr,
            stride_k_seqlen,
            stride_v_seqlen,
            stride_attn_mask_kv_seqlen,
            start_m,
            qk_scale,
            q_load_mask,
            BLOCK_M,
            HEAD_DIM,
            BLOCK_N,
            4 - STAGE,
            offs_m,
            offs_n,
            KV_CTX,
            V.dtype.element_ty == tl.float8e5,
            HAS_ATTN_MASK,
            PRE_LOAD_V,
        )
    # stage 2: on-band (diagonal block with causal mask); kept as a separate
    # call so the compiler can schedule the two loops independently
    if STAGE & 2:
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            query,
            K_block_ptr,
            V_block_ptr,
            mask_block_ptr,
            stride_k_seqlen,
            stride_v_seqlen,
            stride_attn_mask_kv_seqlen,
            start_m,
            qk_scale,
            q_load_mask,
            BLOCK_M,
            HEAD_DIM,
            BLOCK_N,
            2,
            offs_m,
            offs_n,
            KV_CTX,
            V.dtype.element_ty == tl.float8e5,
            HAS_ATTN_MASK,
            PRE_LOAD_V,
        )
    # epilogue
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    # 64-bit row offset: batch * heads * Q_CTX can exceed the int32 range.
    m_ptrs = M + off_hz.to(tl.int64) * Q_CTX + offs_m
    tl.store(m_ptrs, m_i, mask=q_load_mask)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=q_load_mask[:, None])

332 

333 

@triton.jit
def _attn_bwd_preprocess(
    O, DO, Delta, Z, H, Q_CTX, BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr
):
    """Write Delta[bh, row] = sum_d O[bh, row, d] * DO[bh, row, d].

    O and DO are addressed as densely packed (bh, Q_CTX, D_HEAD) buffers.
    Program axis 0 picks a block of BLOCK_M rows; axis 1 picks the
    batch*head slice. Rows at or beyond Q_CTX are neither read nor written.
    """
    row_idx = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    row_valid = row_idx < Q_CTX
    bh = tl.program_id(1)
    col_idx = tl.arange(0, D_HEAD)

    # Same element offsets serve both O and DO (identical packed layout).
    tile_offsets = bh * D_HEAD * Q_CTX + row_idx[:, None] * D_HEAD + col_idx[None, :]
    out_tile = tl.load(O + tile_offsets, mask=row_valid[:, None], other=0.0)
    grad_tile = tl.load(
        DO + tile_offsets, mask=row_valid[:, None], other=0.0
    ).to(tl.float32)

    row_dot = tl.sum(out_tile * grad_tile, axis=1)
    tl.store(Delta + bh * Q_CTX + row_idx, row_dot, mask=row_valid)

357 

358 

# The main inner-loop logic for computing dK and dV.
@triton.jit
def _attn_bwd_dkdv(
    dk,
    dv,  #
    Q,
    key,
    value,
    sm_scale,  #
    DO,  #
    M,
    D,  #
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,  #
    H,
    Q_CTX,
    KV_CTX,
    BLOCK_M1: tl.constexpr,  #
    BLOCK_N1: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,  #
    # Filled in by the wrapper.
    start_n,
    start_m,
    num_steps,  #
    MASK: tl.constexpr,
):
    """Accumulate dK and dV for one block of BLOCK_N1 keys.

    Walks ``num_steps`` query blocks of BLOCK_M1 rows, starting at row
    ``start_m``. ``key``/``value`` are the (BLOCK_N1, BLOCK_DMODEL) tiles
    already loaded for keys ``start_n .. start_n + BLOCK_N1``. ``M`` holds
    the forward pass's per-row log-sum-exp and ``D`` holds rowsum(O * dO).
    With ``MASK`` set, the causal constraint (query row >= key row) is applied.
    Query rows >= Q_CTX and keys >= KV_CTX contribute nothing.

    Returns:
        The updated ``(dk, dv)`` accumulators; dk still needs the caller's
        scaling.
    """
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_n_mask = offs_n < KV_CTX  # (BLOCK_N1, )

    offs_k = tl.arange(0, BLOCK_DMODEL)  # (BLOCK_DMODEL, )

    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    curr_m = start_m
    step_m = BLOCK_M1
    for blk_idx in range(num_steps):
        offs_m = curr_m + tl.arange(0, BLOCK_M1)  # (BLOCK_M1, )
        offs_m_mask = offs_m < Q_CTX  # (BLOCK_M1, )

        qT_ptrs = (
            Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
        )  # (BLOCK_DMODEL, BLOCK_M1)
        do_ptrs = (
            DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
        )  # (BLOCK_M1, BLOCK_DMODEL)

        qT = tl.load(
            qT_ptrs, mask=offs_m_mask[None, :], other=0.0
        )  # (BLOCK_DMODEL, BLOCK_M1)

        # Load m before computing qk to reduce pipeline stall.
        m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf"))  # (BLOCK_M1, )

        # key: (BLOCK_N1, BLOCK_DMODEL)
        qkT = tl.dot(key, qT)  # (BLOCK_N1, BLOCK_M1)
        m = tl.broadcast_to(m[None, :], (BLOCK_N1, BLOCK_M1))  # (BLOCK_N1, BLOCK_M1)
        # Out-of-range keys get m = +inf so their probabilities are exactly 0.
        m = tl.where(offs_n_mask[:, None], m, float("inf"))  # (BLOCK_N1, BLOCK_M1)
        pT = tl.math.exp2(qkT - m)

        mask = (offs_m < Q_CTX)[None, :] & (offs_n < KV_CTX)[
            :, None
        ]  # (BLOCK_N1, BLOCK_M1)
        # Autoregressive masking.
        if MASK:
            mask &= offs_m[None, :] >= offs_n[:, None]
        pT = tl.where(mask, pT, 0.0)  # (BLOCK_N1, BLOCK_M1)

        # Masked load: rows past Q_CTX must read as 0. An unmasked load reads
        # out of bounds, and any NaN/Inf there would leak into dv through
        # 0 * NaN = NaN in the dot product below.
        do = tl.load(
            do_ptrs, mask=offs_m_mask[:, None], other=0.0
        )  # (BLOCK_M1, BLOCK_DMODEL)

        # Compute dV.
        dv += tl.dot(pT, do.to(tl.float32))  # (BLOCK_N1, BLOCK_DMODEL)
        # D (= delta) holds rowsum(O * dO) for each query row.
        Di = tl.load(D + offs_m, mask=offs_m_mask, other=0.0)  # (BLOCK_M1, )

        # Compute dP and dS.
        dpT = tl.dot(value, tl.trans(do)).to(
            tl.float32
        )  # (BLOCK_N1, BLOCK_DMODEL) @ (BLOCK_M1, BLOCK_DMODEL).T -> (BLOCK_N1, BLOCK_M1)
        dsT = pT * (dpT - Di[None, :])  # (BLOCK_N1, BLOCK_M1)
        dsT = dsT.to(qT.dtype)
        qT = tl.where(offs_m_mask[None, :], qT, 0.0)  # (BLOCK_DMODEL, BLOCK_M1)
        dsT = tl.where(
            offs_m_mask[None, :] & offs_n_mask[:, None], dsT, 0.0
        )  # (BLOCK_N1, BLOCK_M1)
        dk += tl.dot(
            dsT, tl.trans(qT)
        )  # (BLOCK_N1, BLOCK_M1) @ (BLOCK_DMODEL, BLOCK_M1).T -> (BLOCK_N1, BLOCK_DMODEL)
        # Increment pointers.
        curr_m += step_m
    return dk, dv

454 

455 

# the main inner-loop logic for computing dQ
@triton.jit
def _attn_bwd_dq(
    dq,
    query,
    K,
    V,  #
    do,
    m,
    D,
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,  #
    H,
    Q_CTX,  #
    KV_CTX,  #
    BLOCK_M2: tl.constexpr,  #
    BLOCK_N2: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,
    # Filled in by the wrapper.
    start_m,
    start_n,
    num_steps,  #
    MASK: tl.constexpr,
):
    """Accumulate dQ for one block of BLOCK_M2 query rows.

    Walks ``num_steps`` key blocks of BLOCK_N2 keys starting at ``start_n``.
    ``query``, ``do`` and ``m`` are the tiles already loaded by the caller
    for rows ``start_m .. start_m + BLOCK_M2``; ``m`` is passed as a column,
    so it broadcasts against each row of ``qk``. ``D`` holds rowsum(O * dO).
    With ``MASK`` set, the causal constraint (query row >= key row) is applied.

    NOTE(review): the probabilities are formed as exp2(query @ K^T - m). This
    assumes K was pre-scaled by sm_scale / ln(2) (the caller passes
    key * sm_scale * RCP_LN2), which is why the caller multiplies dq by ln(2)
    at the end.
    NOTE(review): ``offs_n_mask`` only guards the upper bound. A negative
    ``start_n`` would not be masked -- callers must keep it >= 0.

    Returns:
        The updated ``dq`` accumulator.
    """
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_m_mask = offs_m < Q_CTX

    offs_k = tl.arange(0, BLOCK_DMODEL)
    # D (= delta) holds rowsum(O * dO) for each query row.
    Di = tl.load(D + offs_m, mask=offs_m_mask, other=0.0)
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
    curr_n = start_n
    step_n = BLOCK_N2
    for blk_idx in range(num_steps):
        offs_n = curr_n + tl.arange(0, BLOCK_N2)
        offs_n_mask = offs_n < KV_CTX

        # K and V are both addressed transposed: (BLOCK_DMODEL, BLOCK_N2).
        kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
        vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d

        kT = tl.load(kT_ptrs, mask=offs_n_mask[None, :], other=0.0)
        vT = tl.load(vT_ptrs, mask=offs_n_mask[None, :], other=0.0)
        qk = tl.dot(query, kT)
        p = tl.math.exp2(qk - m)
        # Validity mask: in-range query rows and in-range keys.
        mask = (offs_m < Q_CTX)[:, None] & (offs_n < KV_CTX)[None, :]
        # Autoregressive masking: additionally require query row >= key row.
        if MASK:
            mask &= offs_m[:, None] >= offs_n[None, :]
        p = tl.where(mask, p, 0.0)
        # Compute dP and dS.
        dp = tl.dot(do, vT).to(tl.float32)
        ds = p * (dp - Di[:, None])
        ds = tl.where(mask, ds, 0.0).to(kT.dtype)
        # Compute dQ.
        # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
        dq += tl.dot(ds, tl.trans(kT))
        # Increment pointers.
        curr_n += step_n
    return dq

519 

520 

# Autotuning configs for _attn_bwd (looked up under "attention_bwd"). Each config
# must supply BLOCK_M1/BLOCK_N1/BLOCK_M2/BLOCK_N2: _attn_bwd declares them as
# constexprs, but its launcher does not pass them explicitly.
config_backward = runtime.get_tuned_config("attention_bwd")

522 

523 

@libentry()
@libtuner(
    configs=config_backward,
    key=["KV_CTX", "BLOCK_DMODEL"],
)
@triton.jit
def _attn_bwd(
    Q,
    K,
    V,
    sm_scale,  #
    DO,  #
    DQ,
    DK,
    DV,  #
    M,
    D,
    # shared by Q/K/V/DO.
    stride_z,
    stride_h,
    stride_tok,
    stride_d,  #
    kv_stride_z,
    kv_stride_h,  #
    H,  # query head num
    Q_CTX,  #
    KV_CTX,  #
    kv_head_num,  #
    GROUP_HEAD: tl.constexpr,  #
    BLOCK_M1: tl.constexpr,  #
    BLOCK_N1: tl.constexpr,  #
    BLOCK_M2: tl.constexpr,  #
    BLOCK_N2: tl.constexpr,  #
    BLK_SLICE_FACTOR: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,
):
    """Attention backward: dK/dV for one key block, then dQ for one query block.

    Program axis 2 selects ``batch * H + head``. Program axis 0 (``pid``)
    selects key block ``pid * BLOCK_N1`` for dK/dV and query block
    ``pid * BLOCK_M2`` for dQ. Query head ``h`` reads KV head ``h // GROUP_HEAD``.
    ``M`` is the forward per-row log-sum-exp; ``D`` is rowsum(O * dO).

    NOTE(review): the masking is causal-only. The diagonal block is always
    processed with MASK=True, and only query rows >= the key block start are
    visited for dK/dV (likewise only keys up to the diagonal for dQ).
    Gradients for non-causal attention are therefore incomplete.
    NOTE(review): DK/DV are offset with the query-side strides (``adj``).
    This is only correct when they share Q's per-(batch, head) layout, which
    with contiguous buffers means Q_CTX == KV_CTX -- confirm.
    NOTE(review): the dQ part starts its diagonal at ``start_n = pid * BLOCK_N1``
    while its rows start at ``pid * BLOCK_M2``. This presumably assumes
    BLOCK_N1 == BLOCK_M2 (and BLOCK_N1 a multiple of BLOCK_N2) in every config.
    """
    tl.device_assert(Q_CTX % BLOCK_M1 == 0, "Q_CTX must be a multiple of BLOCK_M1.")

    LN2: tl.constexpr = 0.6931471824645996  # = ln(2)

    bhid = tl.program_id(2)
    # NOTE(review): bhid * Q_CTX is formed before the int64 cast, so it may
    # overflow for very large batch * heads * Q_CTX.
    off_chz = (bhid * Q_CTX).to(tl.int64)
    batch_id = bhid // H
    q_head_id = bhid % H
    kv_head_id = q_head_id // GROUP_HEAD
    adj = (stride_h * q_head_id + stride_z * batch_id).to(tl.int64)
    kv_adj = (kv_stride_h * kv_head_id + kv_stride_z * batch_id).to(tl.int64)

    pid = tl.program_id(0)

    # offset pointers for batch/head
    Q += adj
    K += kv_adj
    V += kv_adj
    DO += adj
    DQ += adj
    DK += adj
    DV += adj
    M += off_chz
    D += off_chz

    # head-dim offsets shared by every tile below
    offs_k = tl.arange(0, BLOCK_DMODEL)

    start_n = pid * BLOCK_N1
    start_m = start_n

    # The diagonal (masked) part is walked in smaller query steps.
    MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_n_mask = offs_n < KV_CTX

    dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)

    # Load K and V once; the same tiles are reused by both _attn_bwd_dkdv calls.
    key = tl.load(
        K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_n_mask[:, None],
        other=0.0,
    )
    value = tl.load(
        V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_n_mask[:, None],
        other=0.0,
    )

    # Number of MASK_BLOCK_M1 steps needed to cover the BLOCK_N1 diagonal rows.
    num_steps = BLOCK_N1 // MASK_BLOCK_M1

    # Compute dK and dV for the diagonal (causally masked) block.
    dk, dv = _attn_bwd_dkdv(
        dk,
        dv,  #
        Q,
        key,
        value,
        sm_scale,  #
        DO,  #
        M,
        D,  #
        stride_tok,
        stride_d,  #
        H,
        Q_CTX,  #
        KV_CTX,  #
        MASK_BLOCK_M1,
        BLOCK_N1,
        BLOCK_DMODEL,  #
        start_n,
        start_m,
        num_steps,  #
        MASK=True,  #
    )

    # Compute dK and dV for non-masked blocks (query rows after the diagonal).
    start_m += num_steps * MASK_BLOCK_M1
    remaining_m = Q_CTX - start_m
    num_steps = (remaining_m + BLOCK_M1 - 1) // BLOCK_M1

    if num_steps > 0 and start_m < Q_CTX:
        dk, dv = _attn_bwd_dkdv(  #
            dk,
            dv,  #
            Q,
            key,
            value,
            sm_scale,  #
            DO,  #
            M,
            D,  #
            stride_tok,
            stride_d,  #
            H,
            Q_CTX,  #
            KV_CTX,  #
            BLOCK_M1,
            BLOCK_N1,
            BLOCK_DMODEL,  #
            start_n,
            start_m,
            num_steps,  #
            MASK=False,  #
        )

    dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dv_ptrs, dv, mask=offs_n_mask[:, None])

    # Write back dK (the inner loop accumulated dS^T Q without the softmax scale).
    dk *= sm_scale
    dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dk_ptrs, dk, mask=offs_n_mask[:, None])

    # THIS BLOCK DOES DQ:
    MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
    start_m = pid * BLOCK_M2
    end_n = min(start_m + BLOCK_M2, KV_CTX)  # Ensure end_n does not exceed KV_CTX
    num_steps = (end_n - start_n + MASK_BLOCK_N2 - 1) // MASK_BLOCK_N2

    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_m_mask = offs_m < Q_CTX

    query = tl.load(
        Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_m_mask[:, None],
        other=0.0,
    )
    dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
    do = tl.load(
        DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_m_mask[:, None],
        other=0.0,
    )

    # Out-of-range rows get m = +inf so exp2(qk - m) is 0 for them.
    m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf"))
    m = m[:, None]

    # Stage 1 - Compute dQ for masked (diagonal) blocks.
    # NOTE: This code scans each row of QK^T backward (from right to left,
    # but inside each call to _attn_bwd_dq, from left to right), but that's
    # not due to anything important. I just wanted to reuse the loop
    # structure for dK & dV above as much as possible.

    if num_steps > 0:
        dq = _attn_bwd_dq(
            dq,
            query,
            K,
            V,  #
            do,
            m,
            D,  #
            stride_tok,
            stride_d,  #
            H,
            Q_CTX,  #
            KV_CTX,  #
            BLOCK_M2,
            MASK_BLOCK_N2,
            BLOCK_DMODEL,  #
            start_m,
            start_n,
            num_steps,  #
            MASK=True,  #
        )

    # Stage 2 - non-masked blocks: keys [0, start_n), walked in BLOCK_N2 steps.
    stage2_end_n = start_n
    stage2_num_steps = (stage2_end_n + BLOCK_N2 - 1) // BLOCK_N2

    if stage2_num_steps > 0:
        dq = _attn_bwd_dq(
            dq,
            query,
            K,
            V,  #
            do,
            m,
            D,  #
            stride_tok,
            stride_d,  #
            H,
            Q_CTX,  #
            KV_CTX,  #
            BLOCK_M2,
            BLOCK_N2,
            BLOCK_DMODEL,  #
            start_m,
            stage2_end_n - stage2_num_steps * BLOCK_N2,
            stage2_num_steps,  #
            MASK=False,  #
        )
    # Write back dQ, undoing the 1/ln(2) folded into the pre-scaled K.
    dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    dq *= LN2

    tl.store(dq_ptrs, dq, mask=offs_m_mask[:, None])

760 

761 

def scaled_dot_product_attention_forward(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Launch the fused attention forward kernel.

    Args:
        query: tensor indexed as (batch, q_heads, q_len, head_dim).
        key: tensor indexed as (batch, kv_heads, kv_len, head_dim).
        value: same layout as ``key``.
        attn_mask: optional mask broadcastable to (batch, q_heads, q_len, kv_len).
            A boolean mask follows ``torch.nn.functional.scaled_dot_product_attention``
            semantics (True = take part in attention). Any other dtype is added
            to the scaled scores.
        dropout_p: must be 0.0.
        is_causal: apply a causal mask.
        scale: softmax scale; defaults to 1 / sqrt(head_dim).
        enable_gqa: allow ``q_heads`` to be a multiple of ``kv_heads``.

    Returns:
        (o, M): the attention output, and a float32 (batch, q_heads, q_len)
        tensor holding each row's base-2 log-sum-exp (used by the backward).

    Raises:
        AssertionError: on unsupported head dims, non-zero dropout, or
            incompatible head counts.
    """
    logger.debug("GEMS_TSINGMICRO SCALED_DOT_PRODUCT_ATTENTION")
    # shape constraints
    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    # when v is in float8_e5m2 it is transposed.
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
    assert dropout_p == 0.0, "Currenty only support dropout_p=0.0"

    o = torch.empty_like(query, dtype=value.dtype)

    stage = 3 if is_causal else 1

    if scale is None:
        sm_scale = 1.0 / (HEAD_DIM_K**0.5)
    else:
        sm_scale = scale

    q_head_num = query.shape[1]
    kv_head_num = key.shape[1]
    assert enable_gqa or q_head_num == kv_head_num, (
        f"q_head_num {q_head_num} != kv_head_num {kv_head_num}, "
        "enable_gqa must be True to support different head numbers."
    )
    # The kernel maps query head h to KV head h // (q_head_num // kv_head_num);
    # without exact divisibility that index can run past kv_head_num.
    assert q_head_num % kv_head_num == 0, (
        f"q_head_num {q_head_num} must be divisible by kv_head_num {kv_head_num}."
    )

    grid = lambda args: (
        triton.cdiv(query.shape[2], args["BLOCK_M"]),
        query.shape[0] * query.shape[1],
        1,
    )

    if attn_mask is not None:
        HAS_ATTN_MASK = True
        if attn_mask.dtype == torch.bool:
            # torch SDPA semantics: True = attend (bias 0), False = masked out
            # (large negative bias). Hence the logical_not before scaling.
            attn_mask = attn_mask.logical_not().to(query.dtype) * -1.0e6
        # Materialize broadcasting as a stride-0 view so the kernel can index
        # every batch/head/row without reading past a size-1 dimension.
        attn_mask = attn_mask.expand(
            query.shape[0], q_head_num, query.shape[2], key.shape[2]
        )
        (
            stride_attn_mask_batch,
            stride_attn_mask_head,
            stride_attn_mask_q_seqlen,
            stride_attn_mask_kv_seqlen,
        ) = attn_mask.stride()
    else:
        HAS_ATTN_MASK = False
        stride_attn_mask_batch = 1
        stride_attn_mask_head = 1
        stride_attn_mask_q_seqlen = 1
        stride_attn_mask_kv_seqlen = 1

    M = torch.empty(
        (query.shape[0], query.shape[1], query.shape[2]),
        device=query.device,
        dtype=torch.float32,
    )

    with torch_device_fn.device(query.device):
        _attn_fwd[grid](
            query,
            key,
            value,
            attn_mask,
            sm_scale,
            M,
            o,  #
            query.stride(0),
            query.stride(1),
            query.stride(2),
            query.stride(3),  #
            key.stride(0),
            key.stride(1),
            key.stride(2),
            key.stride(3),  #
            value.stride(0),
            value.stride(1),
            value.stride(2),
            value.stride(3),  #
            stride_attn_mask_batch,
            stride_attn_mask_head,
            stride_attn_mask_q_seqlen,
            stride_attn_mask_kv_seqlen,  #
            o.stride(0),
            o.stride(1),
            o.stride(2),
            o.stride(3),  #
            query.shape[0],
            q_head_num,
            kv_head_num,  #
            q_head_num // kv_head_num,  # group_head
            query.shape[2],  #
            key.shape[2],  #
            HEAD_DIM_K,  #
            STAGE=stage,  #
            HAS_ATTN_MASK=HAS_ATTN_MASK,  #
        )
    return o, M

864 

865 

def scaled_dot_product_attention_backward(
    do,
    query,
    key,
    value,
    o,
    M,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Compute (dq, dk, dv) for the fused attention forward.

    Args:
        do: gradient of the output, same layout as ``o``.
        query, key, value, o: forward inputs/output (must be contiguous).
        M: per-row base-2 log-sum-exp returned by the forward.
        attn_mask, is_causal, enable_gqa: accepted for signature parity.
        dropout_p: must be 0.0.
        scale: softmax scale; defaults to 1 / sqrt(head_dim).

    Returns:
        (dq, dk, dv). For GQA, dk/dv are summed over each group of query
        heads back to the KV head count.

    NOTE(review): ``attn_mask`` and ``is_causal`` are not forwarded to the
    kernels. ``_attn_bwd`` always uses causal masking, so gradients are only
    correct for causal, mask-free attention. The kernel also addresses dk/dv
    with query strides, which presumes Q_CTX == KV_CTX.
    """
    logger.debug("GEMS_TSINGMICRO SCALED_DOT_PRODUCT_ATTENTION_BACKWARD")
    # shape constraints
    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    # when v is in float8_e5m2 it is transposed.
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
    assert dropout_p == 0.0, "Currenty only support dropout_p=0.0"

    if scale is None:
        sm_scale = 1.0 / (HEAD_DIM_K**0.5)
    else:
        sm_scale = scale

    assert do.is_contiguous()
    assert (
        query.is_contiguous()
        and key.is_contiguous()
        and value.is_contiguous()
        and o.is_contiguous()
    )
    assert query.stride() == o.stride() == do.stride()
    assert key.stride() == value.stride()

    BLOCK_DMODEL = HEAD_DIM_K
    BATCH, Q_HEAD, Q_CTX = query.shape[:3]
    _, KV_HEAD, KV_CTX = key.shape[:3]
    group_head = Q_HEAD // KV_HEAD

    BLK_SLICE_FACTOR = 2
    RCP_LN2 = 1.0 / math.log(2)

    # Pre-scale K by sm_scale / ln(2) so the kernels can use exp2 directly
    # against the base-2 log-sum-exp stored in M; dq is rescaled by ln(2) later.
    arg_k = key * (sm_scale * RCP_LN2)
    PRE_BLOCK = 256
    pre_grid = (triton.cdiv(Q_CTX, PRE_BLOCK), BATCH * Q_HEAD)

    delta = torch.empty_like(M)

    # NOTE that dk & dv always have the same number of heads as q
    dq = torch.empty_like(query).contiguous()
    dk = torch.empty(
        (BATCH, Q_HEAD, KV_CTX, HEAD_DIM_K),
        device=key.device,
        dtype=key.dtype,
        memory_format=torch.contiguous_format,
    )
    dv = torch.empty(
        (BATCH, Q_HEAD, KV_CTX, HEAD_DIM_V),
        device=value.device,
        dtype=value.dtype,
        memory_format=torch.contiguous_format,
    )

    # Size the grid from the BLOCK_N1 the autotuner actually selected. A grid
    # fixed to the largest BLOCK_N1 among all configs launches too few
    # programs (and skips trailing blocks) whenever a smaller BLOCK_N1 wins.
    grid = lambda args: (
        triton.cdiv(Q_CTX, args["BLOCK_N1"]),
        1,
        BATCH * Q_HEAD,
    )

    with torch_device_fn.device(query.device):
        _attn_bwd_preprocess[pre_grid](
            o,
            do,  #
            delta,  #
            BATCH,
            Q_HEAD,
            Q_CTX,  #
            BLOCK_M=PRE_BLOCK,
            D_HEAD=BLOCK_DMODEL,  #
        )

        _attn_bwd[grid](
            query,
            arg_k,
            value,
            sm_scale,
            do,
            dq,
            dk,
            dv,  #
            M,
            delta,  #
            query.stride(0),
            query.stride(1),
            query.stride(2),
            query.stride(3),  #
            key.stride(0),
            key.stride(1),  #
            Q_HEAD,
            Q_CTX,  #
            KV_CTX,  #
            KV_HEAD,  #
            GROUP_HEAD=group_head,  #
            BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
            BLOCK_DMODEL=BLOCK_DMODEL,  #
        )

    if group_head > 1:
        # Fold the per-query-head gradients back onto the shared KV heads.
        dk = dk.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_K)
        dv = dv.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_V)
        dk = dk.sum(dim=2)
        dv = dv.sum(dim=2)

    return dq, dk, dv

1000 

1001 

class ScaleDotProductAttention(torch.autograd.Function):
    """Autograd bridge between PyTorch and the fused attention kernels.

    ``forward`` resolves the softmax scale, runs the forward kernel and stashes
    what the backward kernel needs. ``backward`` returns gradients for
    query/key/value only.

    NOTE(review): backward always calls the backward kernel with
    ``attn_mask=None`` and ``dropout_p=0.0``, so a mask given to forward does
    not influence the gradients.
    """

    @staticmethod
    def forward(
        ctx,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=None,
        enable_gqa=False,
    ):
        if scale is None:
            sm_scale = 1.0 / (key.shape[-1] ** 0.5)
        else:
            sm_scale = scale

        o, M = scaled_dot_product_attention_forward(
            query, key, value, attn_mask, dropout_p, is_causal, sm_scale, enable_gqa
        )

        # Plain Python attributes for non-tensor state; tensors go through
        # save_for_backward.
        ctx.sm_scale = sm_scale
        ctx.causal = is_causal
        ctx.enable_gqa = enable_gqa
        ctx.save_for_backward(query, key, value, o, M)
        return o

    @staticmethod
    def backward(ctx, do):
        query, key, value, o, M = ctx.saved_tensors
        grads = scaled_dot_product_attention_backward(
            do,
            query,
            key,
            value,
            o,
            M,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=ctx.causal,
            scale=ctx.sm_scale,
            enable_gqa=ctx.enable_gqa,
        )
        # One slot per forward input: (dq, dk, dv) then None for
        # attn_mask, dropout_p, is_causal, scale and enable_gqa.
        return (*grads, None, None, None, None, None)

1053 

1054 

def scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Differentiable scaled dot-product attention backed by the fused kernels.

    Arguments are forwarded positionally, in the order
    ``ScaleDotProductAttention.forward`` declares them.
    """
    forwarded = (query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa)
    return ScaleDotProductAttention.apply(*forwarded)

1075 

1076 

def flash_attention_forward(
    query,
    key,
    value,
    cumulative_sequence_length_q,
    cumulative_sequence_length_k,
    max_q,
    max_k,
    dropout_p,
    is_causal,
    return_debug_mask,
    *,
    scale=None,
    softcap=0.0,
    window_size_left=None,
    window_size_right=None,
    seqused_k=None,
    alibi_slopes=None,
    disable_splitkv=False,
):
    """Flash-attention forward (fixed-length only) via ``mha_fwd``.

    Head dims not in the supported set are zero-padded up to the next
    supported size. Zero padding leaves q @ k^T unchanged and only adds zero
    output columns, which are sliced off before returning. The default softmax
    scale uses the original (unpadded) head dim.

    Returns:
        ``(out, lse, philox_seed, philox_offset, p)`` as produced by the
        underlying kernel, with ``out`` trimmed to the original head dim.

    Raises:
        AssertionError: if varlen inputs are given, q/k/v head dims differ, or
            the head dim exceeds the largest supported size.
    """
    logger.debug("GEMS_TSINGMICRO FLASH_ATTENTION_FORWARD")
    assert (
        cumulative_sequence_length_q is None and cumulative_sequence_length_k is None
    ), "varlen is not supported yet."

    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    original_head_dim = HEAD_DIM_K
    supported_head_dims = (16, 32, 64, 96, 128, 192, 256)
    if HEAD_DIM_K not in supported_head_dims:
        # Smallest supported head dim that can hold the actual one.
        padded_head_dim = next(
            (d for d in supported_head_dims if d >= HEAD_DIM_K), None
        )
        assert (
            padded_head_dim is not None
        ), f"Unsupported head dim {HEAD_DIM_K}, max supported is {supported_head_dims[-1]}"
        pad = padded_head_dim - HEAD_DIM_K
        query = F.pad(query, (0, pad))
        key = F.pad(key, (0, pad))
        value = F.pad(value, (0, pad))
        HEAD_DIM_K = padded_head_dim

    # Explicit None check: an explicitly passed scale is honoured even if falsy.
    if scale is not None:
        softmax_scale = scale
    else:
        softmax_scale = 1.0 / (original_head_dim**0.5)
    non_null_window_left = -1 if window_size_left is None else window_size_left
    non_null_window_right = -1 if window_size_right is None else window_size_right

    out = torch.empty_like(query)
    if cumulative_sequence_length_q is not None:
        # Currently unreachable (see the varlen assertion above). It passes the
        # resolved softmax_scale so padding cannot change the default scale.
        out, q, k, v, lse, philox_seed, philox_offset, p = mha_varlan_fwd(
            query,
            key,
            value,
            out,
            cumulative_sequence_length_q,
            cumulative_sequence_length_k,
            seqused_k,
            None,
            None,  # block_table
            alibi_slopes,
            max_q,
            max_k,
            dropout_p,
            softmax_scale,
            False,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask and dropout_p > 0,
            None,
        )
    else:
        out, q, k, v, lse, philox_seed, philox_offset, p = mha_fwd(
            query,
            key,
            value,
            out,
            alibi_slopes,
            dropout_p,
            softmax_scale,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask,
            disable_splitkv=disable_splitkv,
        )

    if HEAD_DIM_K != original_head_dim:
        out = out[..., :original_head_dim]
    return (out, lse, philox_seed, philox_offset, p)

1177 

1178 

# Adapted from https://github.com/vllm-project/flash-attention/blob/main/vllm_flash_attn/flash_attn_interface.py
def maybe_contiguous(x):
    """Ensure the innermost dimension of ``x`` has unit stride.

    ``None`` passes straight through. A tensor whose last-dimension stride is
    already 1 is returned unchanged (the very same object, even if its outer
    dimensions are strided). Any other tensor is replaced by ``x.contiguous()``.
    """
    if x is None:
        return x
    needs_copy = x.stride(-1) != 1
    if not needs_copy:
        return x
    return x.contiguous()

1182 

1183 

def flash_attn_varlen_func(
    q,
    k,
    v,
    max_seqlen_q,
    cu_seqlens_q,
    max_seqlen_k,
    cu_seqlens_k=None,  # only used for non-paged prefill
    seqused_k=None,
    q_v=None,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=None,
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
    return_softmax_lse=False,
    out=None,
    # Dummy FA3 arguments
    scheduler_metadata=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    s_aux=None,
    num_splits: int = 0,
    cp_world_size: int = 1,
    cp_rank: int = 0,
    cp_tot_seqused_k=None,
    fa_version: int = 2,
):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
    window_size=None is treated as (-1, -1).

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into kv. Exactly one of cu_seqlens_k
           and seqused_k must be provided.
        seqused_k: alternative to cu_seqlens_k (see above). Required when block_table is given.
        q_v: forwarded to the C extension only; ignored by the Triton path.
        max_seqlen_q: int (or a tensor exposing .item()). Maximum query sequence length in the batch.
        max_seqlen_k: int (or a tensor exposing .item()). Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
            Only forwarded to the C extension.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
           testing only. The returned probabilities are not guaranteed to be correct
           (they might not have the right scaling). Only forwarded to the C extension.
        block_table: forwarded to the kernel (presumably the paged-KV block table); requires
           seqused_k.
        return_softmax_lse: bool. If True, return (out, softmax_lse) instead of out.
        out: optional preallocated output tensor, forwarded to the kernel.
        scheduler_metadata, q_descale, k_descale, v_descale, s_aux, cp_world_size, cp_rank,
        cp_tot_seqused_k: FA3-compatibility arguments. They are forwarded to the C extension
           and ignored by the Triton path.
        num_splits: int. Must be 0.
        fa_version: int. Must be 2.
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    Raises:
        RuntimeError: if fa_version != 2 or num_splits > 0.
        AssertionError: on invalid cu_seqlens_k / seqused_k / block_table / window_size
            combinations (Triton path only).
    """
    if fa_version != 2:
        raise RuntimeError("Only FA2 is implemented.")
    if num_splits > 0:
        raise RuntimeError("num_splits > 0 is not implemented in GEMS.")

    if use_c_extension:
        logger.debug("GEMS_TSINGMICRO FLASH_ATTN_VARLEN_FUNC (C Extension)")
        with torch_device_fn.device(q.device):
            out_cpp, softmax_lse = torch.ops.flag_gems.flash_attn_varlen_func(
                q,
                k,
                v,
                max_seqlen_q,
                cu_seqlens_q,
                max_seqlen_k,
                cu_seqlens_k,
                seqused_k,
                q_v,
                dropout_p,
                softmax_scale,
                causal,
                window_size,
                softcap,
                alibi_slopes,
                deterministic,
                return_attn_probs,
                block_table,
                return_softmax_lse,
                out,
                scheduler_metadata,
                q_descale,
                k_descale,
                v_descale,
                s_aux,
                num_splits,
                cp_world_size,
                cp_rank,
                cp_tot_seqused_k,
                fa_version,
            )
        return (out_cpp, softmax_lse) if return_softmax_lse else out_cpp

    logger.debug("GEMS_TSINGMICRO FLASH_ATTN_VARLEN_FUNC")
    assert (
        cu_seqlens_k is not None or seqused_k is not None
    ), "cu_seqlens_k or seqused_k must be provided"
    assert (
        cu_seqlens_k is None or seqused_k is None
    ), "cu_seqlens_k and seqused_k cannot be provided at the same time"
    assert (
        block_table is None or seqused_k is not None
    ), "seqused_k must be provided if block_table is provided"
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    # custom op does not support non-tuple input
    if window_size is None:
        real_window_size = (-1, -1)
    else:
        assert len(window_size) == 2
        real_window_size = (window_size[0], window_size[1])
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]

    if cu_seqlens_k is None:
        # cu_seqlens_k is not used since we use seqused_k, but flash_api.cpp
        # still wants a tensor there. Pass genuine zeros rather than
        # uninitialized memory, and allocate only on this path.
        cu_seqlens_k_arg = torch.zeros_like(cu_seqlens_q)
    else:
        cu_seqlens_k_arg = cu_seqlens_k

    # Host-side ints are expected here; tensors are unwrapped via .item().
    max_seqlen_q = (
        max_seqlen_q.item() if hasattr(max_seqlen_q, "item") else max_seqlen_q
    )
    max_seqlen_k = (
        max_seqlen_k.item() if hasattr(max_seqlen_k, "item") else max_seqlen_k
    )
    out, q, k, v, softmax_lse, *_ = mha_varlan_fwd(
        q,
        k,
        v,
        out,
        cu_seqlens_q,
        cu_seqlens_k_arg,
        seqused_k,
        None,
        block_table,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        real_window_size[0],
        real_window_size[1],
        softcap,
        # NOTE(review): this slot looks like the debug/attn-probs flag; it is
        # driven by return_softmax_lse (as upstream vLLM does), not by
        # return_attn_probs - confirm against mha_varlan_fwd's signature.
        return_softmax_lse and dropout_p > 0,
        None,
    )

    return (out, softmax_lse) if return_softmax_lse else out