Coverage for src/flag_gems/ops/attention.py: 30%
430 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
2import math
3from functools import partial
5import torch
6import torch.nn.functional as F
7import triton
8import triton.language as tl
10from flag_gems import runtime
11from flag_gems.config import use_c_extension
12from flag_gems.ops.flash_api import mha_fwd, mha_varlan_fwd, mha_varlan_fwd_opt
13from flag_gems.ops.flash_kernel import keep
14from flag_gems.runtime import torch_device_fn
15from flag_gems.utils import libentry, libtuner
17logger = logging.getLogger(__name__)
# Modified from Triton tutorial: https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
#
# Online-softmax inner loop for one block of BLOCK_M query rows.
# Scores are carried in the log2 domain (multiplied by LOG2E) so that
# tl.math.exp2 can be used instead of exp. STAGE selects which keys are visited:
#   1 -> [0, start_m * BLOCK_M)                       (below the causal diagonal)
#   2 -> [start_m * BLOCK_M, (start_m + 1) * BLOCK_M) (diagonal block, masked)
#   3 -> [0, KV_CTX)                                  (non-causal: every key)
# Returns the updated running (acc, l_i, m_i): un-normalized output, softmax
# denominator and row maximum.
@triton.jit
def _attn_fwd_inner(
    acc,
    l_i,
    m_i,
    query,  #
    K_block_ptr,
    V_block_ptr,  #
    mask_block_ptr,  #
    stride_k_seqlen,
    stride_v_seqlen,
    stride_attn_mask_kv_seqlen,  #
    start_m,
    qk_scale,  #
    q_load_mask,
    BLOCK_M: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_N: tl.constexpr,  #
    STAGE: tl.constexpr,
    offs_m: tl.constexpr,
    offs_n: tl.constexpr,  #
    KV_CTX: tl.constexpr,
    fp8_v: tl.constexpr,
    HAS_ATTN_MASK: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
):
    # range of values handled by this stage
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
    # causal = False
    else:
        lo, hi = 0, KV_CTX

    K_block_ptr += lo * stride_k_seqlen
    V_block_ptr += lo * stride_v_seqlen
    if HAS_ATTN_MASK:
        mask_block_ptr += lo * stride_attn_mask_kv_seqlen

    LOG2E = 1.44269504  # log2(e) constant

    # loop over key, value and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        kv_load_mask = (start_n + offs_n) < KV_CTX
        # -- compute qk ----
        key = tl.load(K_block_ptr, mask=kv_load_mask[None, :], other=0.0)
        if PRE_LOAD_V:
            value = tl.load(V_block_ptr, mask=kv_load_mask[:, None], other=0.0)

        qk = tl.dot(query, key, allow_tf32=False)
        # Keys past KV_CTX (last, partial block) must not contribute.
        qk = tl.where(kv_load_mask[None, :], qk, -float("inf"))

        if HAS_ATTN_MASK:
            attn_mask = tl.load(
                mask_block_ptr,
                mask=q_load_mask[:, None] & kv_load_mask[None, :],
                other=0.0,
            )

        if STAGE == 2:
            # Diagonal block: apply the causal mask q_idx >= kv_idx.
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            if HAS_ATTN_MASK:
                qk = qk * qk_scale + attn_mask
                qk *= LOG2E
                qk = qk + tl.where(mask, 0, -1.0e6)
            else:
                qk = qk * qk_scale * LOG2E + tl.where(mask, 0, -1.0e6)
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
        else:
            if HAS_ATTN_MASK:
                # The additive mask is in natural-log units like the scaled
                # scores, so it must be converted to log2 units together with
                # them (same as the STAGE == 2 path above).
                qk = qk * qk_scale + attn_mask
                qk *= LOG2E
            else:
                qk *= qk_scale * LOG2E
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk = qk - m_ij[:, None]

        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        # -- update output accumulator --
        acc = acc * alpha[:, None]
        # update acc
        if not PRE_LOAD_V:
            value = tl.load(V_block_ptr, mask=kv_load_mask[:, None], other=0.0)
        if fp8_v:
            p = p.to(tl.float8e5)
        else:
            p = p.to(query.dtype)
        p = p.to(value.dtype)
        acc = tl.dot(p, value, acc, allow_tf32=False)
        # update m_i and l_i
        m_i = m_ij

        K_block_ptr += BLOCK_N * stride_k_seqlen
        V_block_ptr += BLOCK_N * stride_v_seqlen

        if HAS_ATTN_MASK:
            mask_block_ptr += BLOCK_N * stride_attn_mask_kv_seqlen

    return acc, l_i, m_i
# NOTE: _attn_fwd statically asserts BLOCK_N <= HEAD_DIM, so small head dims
# need extra autotuning candidates with a small BLOCK_N (16 or 32).
configs = runtime.get_tuned_config("attention")


def _small_head_dim_configs():
    """Build the BLOCK_N in {16, 32} autotuning candidates for small head dims.

    Iteration order (BLOCK_M, BLOCK_N, num_stages, num_warps; outermost first)
    fixes the order of the returned list.
    """
    candidates = []
    for block_m in (64, 128):
        for block_n in (16, 32):
            for stages in (2, 3, 4):
                for warps in (4, 8):
                    meta = {"BLOCK_M": block_m, "BLOCK_N": block_n, "PRE_LOAD_V": 0}
                    candidates.append(
                        triton.Config(meta, num_stages=stages, num_warps=warps)
                    )
    return candidates


SMALL_HEAD_DIM_CONFIGS = _small_head_dim_configs()
configs += SMALL_HEAD_DIM_CONFIGS
# Attention forward kernel. Grid: (cdiv(Q_CTX, BLOCK_M), batch * q_head_num).
# Each program handles BLOCK_M query rows of one (batch, query head): it writes
# the normalized output rows to Out and m_i + log2(l_i) (log2-domain row
# log-sum-exp) to M for the backward pass.
# STAGE is 1 for non-causal and 3 for causal attention (see the two
# _attn_fwd_inner calls below). Q, K, V, Out and attn_mask are addressed purely
# through their own (batch, head, seqlen, headsize) strides.
@libentry()
@libtuner(
    configs=list(filter(partial(keep, must_keep=SMALL_HEAD_DIM_CONFIGS), configs)),
    key=["KV_CTX", "HEAD_DIM"],
)
@triton.jit
def _attn_fwd(
    Q,
    K,
    V,
    attn_mask,
    sm_scale,
    M,
    Out,  #
    stride_q_batch,
    stride_q_head,
    stride_q_seqlen,
    stride_q_headsize,
    stride_k_batch,
    stride_k_head,
    stride_k_seqlen,
    stride_k_headsize,
    stride_v_batch,
    stride_v_head,
    stride_v_seqlen,
    stride_v_headsize,
    stride_attn_mask_batch,
    stride_attn_mask_head,
    stride_attn_mask_q_seqlen,
    stride_attn_mask_kv_seqlen,
    stride_o_batch,
    stride_o_head,
    stride_o_seqlen,
    stride_o_headsize,
    Z,
    q_head_num,
    kv_head_num,
    GROUP_HEAD: tl.constexpr,
    Q_CTX,
    KV_CTX,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
    HAS_ATTN_MASK: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
):
    tl.static_assert(BLOCK_N <= HEAD_DIM)
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    batch_id = off_hz // q_head_num
    head_id = off_hz % q_head_num
    # GQA: GROUP_HEAD consecutive query heads share one key/value head.
    kv_head_id = head_id // GROUP_HEAD

    # 64-bit base offsets for the (batch, head) slice of each tensor.
    batch_id64 = batch_id.to(tl.int64)
    head_id64 = head_id.to(tl.int64)
    kv_head_id64 = kv_head_id.to(tl.int64)
    q_offset = batch_id64 * stride_q_batch + head_id64 * stride_q_head
    o_offset = batch_id64 * stride_o_batch + head_id64 * stride_o_head
    k_offset = batch_id64 * stride_k_batch + kv_head_id64 * stride_k_head
    # V may be laid out differently from K, so it uses its own strides.
    v_offset = batch_id64 * stride_v_batch + kv_head_id64 * stride_v_head

    offs_headsize = tl.arange(0, HEAD_DIM)

    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    q_load_mask = offs_m < Q_CTX
    offs_n = tl.arange(0, BLOCK_N)

    Q_block_ptr = (
        Q
        + q_offset
        + offs_m[:, None] * stride_q_seqlen
        + offs_headsize[None, :] * stride_q_headsize
    )
    # K is addressed transposed: (HEAD_DIM, BLOCK_N).
    K_block_ptr = (
        K
        + k_offset
        + offs_n[None, :] * stride_k_seqlen
        + offs_headsize[:, None] * stride_k_headsize
    )
    V_block_ptr = (
        V
        + v_offset
        + offs_n[:, None] * stride_v_seqlen
        + offs_headsize[None, :] * stride_v_headsize
    )

    if HAS_ATTN_MASK:
        attn_mask_offset = (
            batch_id64 * stride_attn_mask_batch + head_id64 * stride_attn_mask_head
        )
        mask_block_ptr = (
            attn_mask
            + attn_mask_offset
            + offs_m[:, None] * stride_attn_mask_q_seqlen
            + offs_n[None, :] * stride_attn_mask_kv_seqlen
        )
    else:
        mask_block_ptr = None

    O_block_ptr = (
        Out
        + o_offset
        + offs_m[:, None] * stride_o_seqlen
        + offs_headsize[None, :] * stride_o_headsize
    )

    # Running row max (m_i), softmax denominator (l_i) and output accumulator.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    qk_scale = sm_scale
    # load query: it will stay in SRAM throughout
    query = tl.load(Q_block_ptr, mask=q_load_mask[:, None], other=0.0)
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            query,
            K_block_ptr,
            V_block_ptr,
            mask_block_ptr,
            stride_k_seqlen,
            stride_v_seqlen,
            stride_attn_mask_kv_seqlen,
            start_m,
            qk_scale,
            q_load_mask,
            BLOCK_M,
            HEAD_DIM,
            BLOCK_N,
            4 - STAGE,
            offs_m,
            offs_n,
            KV_CTX,
            V.dtype.element_ty == tl.float8e5,
            HAS_ATTN_MASK,
            PRE_LOAD_V,
        )
    # stage 2: on-band (the causally masked diagonal block); kept as a
    # separate loop so the compiler can schedule the two loops independently.
    if STAGE & 2:
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            query,
            K_block_ptr,
            V_block_ptr,
            mask_block_ptr,
            stride_k_seqlen,
            stride_v_seqlen,
            stride_attn_mask_kv_seqlen,
            start_m,
            qk_scale,
            q_load_mask,
            BLOCK_M,
            HEAD_DIM,
            BLOCK_N,
            2,
            offs_m,
            offs_n,
            KV_CTX,
            V.dtype.element_ty == tl.float8e5,
            HAS_ATTN_MASK,
            PRE_LOAD_V,
        )
    # epilogue
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz.to(tl.int64) * Q_CTX + offs_m
    tl.store(m_ptrs, m_i, mask=q_load_mask)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=q_load_mask[:, None])
# Backward pre-pass: Delta[row] = sum_d O[row, d] * dO[row, d] for each query
# row of each (batch, head). Grid: (cdiv(Q_CTX, BLOCK_M), batch * heads).
# O and DO are addressed as densely packed (batch * heads, Q_CTX, D_HEAD)
# buffers, so callers must pass contiguous tensors. Z and H are unused.
@triton.jit
def _attn_bwd_preprocess(
    O, DO, Delta, Z, H, Q_CTX, BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr
):
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    mask = off_m < Q_CTX

    # Widen before multiplying: off_hz * D_HEAD * Q_CTX can exceed 2**31.
    off_hz = tl.program_id(1).to(tl.int64)
    off_n = tl.arange(0, D_HEAD)
    head_base = off_hz * D_HEAD * Q_CTX
    row_offs = off_m[:, None] * D_HEAD + off_n[None, :]
    # load
    o = tl.load(O + head_base + row_offs, mask=mask[:, None], other=0.0)
    do = tl.load(DO + head_base + row_offs, mask=mask[:, None], other=0.0).to(
        tl.float32
    )
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(Delta + off_hz * Q_CTX + off_m, delta, mask=mask)
# The main inner-loop logic for computing dK and dV.
# For one KV block (rows offs_n = start_n + [0, BLOCK_N1), whose `key` and
# `value` tiles are already loaded by the caller), walk `num_steps` query
# blocks of BLOCK_M1 rows starting at `start_m` and accumulate:
#   dv += P^T @ dO
#   dk += dS^T @ Q,  where dS = P * (dP - Delta) and dP = dO @ V^T
# P is recomputed as exp2(key . q - M). In this file `key` is K pre-scaled by
# sm_scale / ln(2) (see scaled_dot_product_attention_backward) and M is the
# forward's m_i + log2(l_i); the final sm_scale factor of dk is applied by
# _attn_bwd. Q, DO, M and D are expected to already point at the current
# (batch, head). MASK=True additionally applies the causal mask q_idx >= kv_idx.
# sm_scale and H are accepted but not used in this function.
@triton.jit
def _attn_bwd_dkdv(
    dk,
    dv,  #
    Q,
    key,
    value,
    sm_scale,  #
    DO,  #
    M,
    D,  #
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,  #
    H,
    Q_CTX,
    KV_CTX,
    BLOCK_M1: tl.constexpr,  #
    BLOCK_N1: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,  #
    # Filled in by the wrapper.
    start_n,
    start_m,
    num_steps,  #
    MASK: tl.constexpr,
):
    # Example sizes (see the commented defaults in the Python wrapper):
    # BLOCK_M1 = 32, BLOCK_N1 = 128.
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_n_mask = offs_n < KV_CTX  # (BLOCK_N1, )

    offs_k = tl.arange(0, BLOCK_DMODEL)  # (BLOCK_DMODEL, )

    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    curr_m = start_m
    step_m = BLOCK_M1
    for blk_idx in range(num_steps):
        offs_m = curr_m + tl.arange(0, BLOCK_M1)  # (BLOCK_M1, )
        offs_m_mask = offs_m < Q_CTX  # (BLOCK_M1, )

        # Q is read transposed so that key @ qT gives (BLOCK_N1, BLOCK_M1).
        qT_ptrs = (
            Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
        )  # (BLOCK_DMODEL, BLOCK_M1)
        do_ptrs = (
            DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
        )  # (BLOCK_M1, BLOCK_DMODEL)

        qT = tl.load(
            qT_ptrs, mask=offs_m_mask[None, :], other=0.0
        )  # (BLOCK_DMODEL, BLOCK_M1)

        # Load m before computing qk to reduce pipeline stall.
        m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf"))  # (BLOCK_M1, )

        # key: (BLOCK_N1, BLOCK_DMODEL)
        qkT = tl.dot(key, qT)  # (BLOCK_N1, BLOCK_M1)
        m = tl.broadcast_to(m[None, :], (BLOCK_N1, BLOCK_M1))  # (BLOCK_N1, BLOCK_M1)
        # +inf for out-of-range keys makes their exp2 term 0.
        m = tl.where(offs_n_mask[:, None], m, float("inf"))  # (BLOCK_N1, BLOCK_M1)
        pT = tl.math.exp2(qkT - m)

        mask = (offs_m < Q_CTX)[None, :] & (offs_n < KV_CTX)[
            :, None
        ]  # (BLOCK_N1, BLOCK_M1)
        # Autoregressive masking.
        if MASK:
            mask &= offs_m[None, :] >= offs_n[:, None]
        pT = tl.where(mask, pT, 0.0)  # (BLOCK_N1, BLOCK_M1)

        do = tl.load(
            do_ptrs, mask=offs_m_mask[:, None], other=0.0
        )  # (BLOCK_M1, BLOCK_DMODEL)

        # Compute dV.
        dv += tl.dot(pT, do.to(tl.float32))  # (BLOCK_N1, BLOCK_DMODEL)
        # D (= Delta = rowsum(O * dO)) as written by _attn_bwd_preprocess.
        Di = tl.load(D + offs_m, mask=offs_m_mask, other=0.0)  # (BLOCK_M1, )

        # Compute dP and dS.
        dpT = tl.dot(value, tl.trans(do)).to(
            tl.float32
        )  # (BLOCK_N1, BLOCK_DMODEL) @ (BLOCK_M1, BLOCK_DMODEL).T -> (BLOCK_N1, BLOCK_M1)
        dsT = pT * (dpT - Di[None, :])  # (BLOCK_N1, BLOCK_M1)
        dsT = dsT.to(qT.dtype)
        qT = tl.where(offs_m_mask[None, :], qT, 0.0)  # (BLOCK_DMODEL, BLOCK_M1)
        dsT = tl.where(
            offs_m_mask[None, :] & offs_n_mask[:, None], dsT, 0.0
        )  # (BLOCK_N1, BLOCK_M1)
        dk += tl.dot(
            dsT, tl.trans(qT)
        )  # (BLOCK_N1, BLOCK_M1) @ (BLOCK_DMODEL, BLOCK_M1).T -> (BLOCK_N1, BLOCK_DMODEL)
        # Advance to the next block of query rows.
        curr_m += step_m
    return dk, dv
# the main inner-loop logic for computing dQ
# For one query block (rows offs_m = start_m + [0, BLOCK_M2), with `query`, `do`
# and the row statistic `m` already loaded by the caller), walk `num_steps` KV
# blocks of BLOCK_N2 columns starting at `start_n` and accumulate
#   dq += dS @ K,  where dS = P * (dP - Delta), dP = dO @ V^T, P = exp2(q.k - m).
# In this file K is pre-scaled by sm_scale / ln(2), so _attn_bwd multiplies the
# final dq by ln(2). MASK=True additionally applies the causal mask
# q_idx >= kv_idx. H is accepted but not used in this function.
@triton.jit
def _attn_bwd_dq(
    dq,
    query,
    K,
    V,  #
    do,
    m,
    D,
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,  #
    H,
    Q_CTX,  #
    KV_CTX,  #
    BLOCK_M2: tl.constexpr,  #
    BLOCK_N2: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,
    # Filled in by the wrapper.
    start_m,
    start_n,
    num_steps,  #
    MASK: tl.constexpr,
):
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_m_mask = offs_m < Q_CTX

    offs_k = tl.arange(0, BLOCK_DMODEL)
    # D (= Delta = rowsum(O * dO)) as written by _attn_bwd_preprocess.
    Di = tl.load(D + offs_m, mask=offs_m_mask, other=0.0)
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
    curr_n = start_n
    step_n = BLOCK_N2
    for blk_idx in range(num_steps):
        offs_n = curr_n + tl.arange(0, BLOCK_N2)
        offs_n_mask = offs_n < KV_CTX

        # K and V are read transposed: (BLOCK_DMODEL, BLOCK_N2).
        kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
        vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d

        kT = tl.load(kT_ptrs, mask=offs_n_mask[None, :], other=0.0)
        vT = tl.load(vT_ptrs, mask=offs_n_mask[None, :], other=0.0)
        qk = tl.dot(query, kT)
        p = tl.math.exp2(qk - m)
        mask = (offs_m < Q_CTX)[:, None] & (offs_n < KV_CTX)[None, :]
        # Autoregressive masking.
        if MASK:
            mask &= offs_m[:, None] >= offs_n[None, :]
        p = tl.where(mask, p, 0.0)
        # Compute dP and dS.
        dp = tl.dot(do, vT).to(tl.float32)
        ds = p * (dp - Di[:, None])
        ds = tl.where(mask, ds, 0.0).to(kT.dtype)
        # Compute dQ.
        # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
        dq += tl.dot(ds, tl.trans(kT))
        # Advance to the next block of key/value columns.
        curr_n += step_n
    return dq
config_backward = runtime.get_tuned_config("attention_bwd")


# Fused attention backward. Grid: (max(#KV blocks, #Q blocks), 1, batch * H).
# For one (batch, query head) = program_id(2), program pid = program_id(0)
# independently computes
#   * dK/dV for the KV block starting at pid * BLOCK_N1 (if in range), and
#   * dQ for the query block starting at pid * BLOCK_M2 (if in range).
# K must be pre-scaled by sm_scale / ln(2) (done by the Python wrapper), M is the
# forward's log2-domain row log-sum-exp and D is rowsum(O * dO).
# Q/DO/DQ share (stride_z, stride_h, stride_tok, stride_d). K/V use
# kv_stride_z/kv_stride_h for batch/head and stride_tok/stride_d within a head.
# DK/DV are laid out with one head per *query* head.
@libentry()
@libtuner(
    configs=config_backward,
    key=["KV_CTX", "BLOCK_DMODEL"],
)
@triton.jit
def _attn_bwd(
    Q,
    K,
    V,
    sm_scale,  #
    DO,  #
    DQ,
    DK,
    DV,  #
    M,
    D,
    # shared by Q/K/V/DO.
    stride_z,
    stride_h,
    stride_tok,
    stride_d,  #
    kv_stride_z,
    kv_stride_h,  #
    dk_stride_z,
    dk_stride_h,
    dk_stride_tok,  #
    H,  # query head num
    Q_CTX,  #
    KV_CTX,  #
    kv_head_num,  #
    GROUP_HEAD: tl.constexpr,  #
    BLOCK_M1: tl.constexpr,  #
    BLOCK_N1: tl.constexpr,  #
    BLOCK_M2: tl.constexpr,  #
    BLOCK_N2: tl.constexpr,  #
    BLK_SLICE_FACTOR: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,
    IS_CAUSAL: tl.constexpr = True,
):
    LN2: tl.constexpr = 0.6931471824645996  # = ln(2)

    bhid = tl.program_id(2)
    batch_id = bhid // H
    q_head_id = bhid % H
    # GQA: GROUP_HEAD consecutive query heads share one key/value head.
    kv_head_id = q_head_id // GROUP_HEAD

    # Widen to int64 BEFORE multiplying: casting the int32 product afterwards
    # cannot undo an overflow on large tensors.
    bhid64 = bhid.to(tl.int64)
    batch_id64 = batch_id.to(tl.int64)
    q_head_id64 = q_head_id.to(tl.int64)
    kv_head_id64 = kv_head_id.to(tl.int64)
    off_chz = bhid64 * Q_CTX
    adj = stride_h * q_head_id64 + stride_z * batch_id64
    kv_adj = kv_stride_h * kv_head_id64 + kv_stride_z * batch_id64
    dk_adj = dk_stride_h * q_head_id64 + dk_stride_z * batch_id64

    pid = tl.program_id(0)

    # offset pointers for batch/head
    Q += adj
    K += kv_adj
    V += kv_adj
    DO += adj
    DQ += adj
    DK += dk_adj
    DV += dk_adj
    M += off_chz
    D += off_chz

    # Head-dim offsets shared by every load/store below.
    offs_k = tl.arange(0, BLOCK_DMODEL)

    # dK/dV: only execute when this pid covers a valid KV block
    start_n = pid * BLOCK_N1
    if start_n < KV_CTX:
        dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
        dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)

        # load K and V: they stay in SRAM throughout the inner loop.
        offs_n = start_n + tl.arange(0, BLOCK_N1)
        offs_n_mask = offs_n < KV_CTX
        key = tl.load(
            K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d,
            mask=offs_n_mask[:, None],
            other=0.0,
        )
        value = tl.load(
            V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d,
            mask=offs_n_mask[:, None],
            other=0.0,
        )

        MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR

        # Causal: masked diagonal phase, then unmasked above-diagonal phase.
        # Non-causal: skip masked phase, single unmasked pass over all Q rows.
        if IS_CAUSAL:
            # The causal mask is q_idx >= kv_idx, so for KV block starting at
            # start_n, the first Q row that can attend is start_n itself.
            start_m = start_n
            # Clamp to valid Q range
            if start_m < Q_CTX:
                end_m = min(start_m + BLOCK_N1, Q_CTX)
                num_steps = (end_m - start_m + MASK_BLOCK_M1 - 1) // MASK_BLOCK_M1
                dk, dv = _attn_bwd_dkdv(
                    dk,
                    dv,  #
                    Q,
                    key,
                    value,
                    sm_scale,  #
                    DO,  #
                    M,
                    D,  #
                    stride_tok,
                    stride_d,  #
                    H,
                    Q_CTX,  #
                    KV_CTX,  #
                    MASK_BLOCK_M1,
                    BLOCK_N1,
                    BLOCK_DMODEL,  #
                    start_n,
                    start_m,
                    num_steps,  #
                    MASK=True,  #
                )
                start_m += num_steps * MASK_BLOCK_M1
            # else: start_n >= Q_CTX, no Q rows can attend to this KV block
        else:
            start_m = 0

        # Unmasked phase (shared): traverse remaining Q rows.
        remaining_m = Q_CTX - start_m
        num_steps = (remaining_m + BLOCK_M1 - 1) // BLOCK_M1
        if num_steps > 0:
            dk, dv = _attn_bwd_dkdv(
                dk,
                dv,  #
                Q,
                key,
                value,
                sm_scale,  #
                DO,  #
                M,
                D,  #
                stride_tok,
                stride_d,  #
                H,
                Q_CTX,  #
                KV_CTX,  #
                BLOCK_M1,
                BLOCK_N1,
                BLOCK_DMODEL,  #
                start_n,
                start_m,
                num_steps,  #
                MASK=False,  #
            )

        dv_ptrs = DV + offs_n[:, None] * dk_stride_tok + offs_k[None, :] * stride_d
        tl.store(dv_ptrs, dv, mask=offs_n_mask[:, None])

        # Write back dK (undo the sm_scale folded out of dS).
        dk *= sm_scale
        dk_ptrs = DK + offs_n[:, None] * dk_stride_tok + offs_k[None, :] * stride_d
        tl.store(dk_ptrs, dk, mask=offs_n_mask[:, None])

    # dQ: only execute when this pid covers a valid Q block
    start_m = pid * BLOCK_M2
    if start_m < Q_CTX:
        offs_m = start_m + tl.arange(0, BLOCK_M2)
        offs_m_mask = offs_m < Q_CTX
        query = tl.load(
            Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d,
            mask=offs_m_mask[:, None],
            other=0.0,
        )
        dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
        do = tl.load(
            DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d,
            mask=offs_m_mask[:, None],
            other=0.0,
        )
        m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf"))
        m = m[:, None]

        MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR

        if IS_CAUSAL:
            # Masked diagonal phase: KV columns [diag_n, end_n)
            # diag_n is the KV position where the causal boundary starts for
            # this Q block. Only needed when diag_n < KV_CTX.
            diag_n = min(start_m, KV_CTX)
            end_n = min(start_m + BLOCK_M2, KV_CTX)
            num_steps = (end_n - diag_n + MASK_BLOCK_N2 - 1) // MASK_BLOCK_N2

            if num_steps > 0:
                dq = _attn_bwd_dq(
                    dq,
                    query,
                    K,
                    V,  #
                    do,
                    m,
                    D,  #
                    stride_tok,
                    stride_d,  #
                    H,
                    Q_CTX,  #
                    KV_CTX,  #
                    BLOCK_M2,
                    MASK_BLOCK_N2,
                    BLOCK_DMODEL,  #
                    start_m,
                    diag_n,
                    num_steps,  #
                    MASK=True,  #
                )

            # Unmasked phase: KV columns [0, diag_n), all fully visible.
            stage2_num_steps = (diag_n + BLOCK_N2 - 1) // BLOCK_N2
        else:
            # Non-causal: single unmasked pass over all KV columns.
            stage2_num_steps = (KV_CTX + BLOCK_N2 - 1) // BLOCK_N2

        if stage2_num_steps > 0:
            dq = _attn_bwd_dq(
                dq,
                query,
                K,
                V,  #
                do,
                m,
                D,  #
                stride_tok,
                stride_d,  #
                H,
                Q_CTX,  #
                KV_CTX,  #
                BLOCK_M2,
                BLOCK_N2,
                BLOCK_DMODEL,  #
                start_m,
                0,
                stage2_num_steps,  #
                MASK=False,  #
            )

        # Write back dQ (K was pre-scaled by 1/ln(2), so multiply by ln(2)).
        dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
        dq *= LN2
        tl.store(dq_ptrs, dq, mask=offs_m_mask[:, None])
def scaled_dot_product_attention_forward(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Run the Triton attention forward kernel.

    Args:
        query: (batch, q_heads, q_len, head_dim) tensor.
        key, value: (batch, kv_heads, kv_len, head_dim) tensors. kv_heads may
            differ from q_heads only when ``enable_gqa`` is True; each group of
            ``q_heads // kv_heads`` query heads then shares one kv head.
        attn_mask: optional mask broadcastable to (batch, q_heads, q_len,
            kv_len). A boolean mask follows the
            ``torch.nn.functional.scaled_dot_product_attention`` convention
            (True = the position takes part in attention). Any other dtype is
            added to the scaled scores.
        dropout_p: must be 0.0.
        is_causal: apply a causal mask aligned to the top-left corner
            (query i attends to keys j <= i).
        scale: softmax scale; defaults to 1 / sqrt(head_dim).
        enable_gqa: allow q_heads != kv_heads.

    Returns:
        (o, M): the output (dtype of ``value``) and a float32
        (batch, q_heads, q_len) tensor of per-row log2-domain log-sum-exp
        values consumed by the backward pass.
    """
    logger.debug("GEMS SCALED DOT PRODUCT ATTENTION FORWARD")
    # shape constraints
    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
    assert dropout_p == 0.0, "Currenty only support dropout_p=0.0"

    o = torch.empty_like(query, dtype=value.dtype)

    # Kernel STAGE bits: 1 = blocks strictly below the diagonal (or all blocks
    # when non-causal), 2 = the causally masked diagonal block.
    stage = 3 if is_causal else 1

    sm_scale = 1.0 / (HEAD_DIM_K**0.5) if scale is None else scale

    q_head_num = query.shape[1]
    kv_head_num = key.shape[1]
    assert enable_gqa or q_head_num == kv_head_num, (
        f"q_head_num {q_head_num} != kv_head_num {kv_head_num}, "
        "enable_gqa must be True to support different head numbers."
    )

    grid = lambda args: (
        triton.cdiv(query.shape[2], args["BLOCK_M"]),
        query.shape[0] * query.shape[1],
        1,
    )

    if attn_mask is not None:
        HAS_ATTN_MASK = True
        if attn_mask.dtype == torch.bool:
            # torch convention: True = attend. Masked-out positions get a
            # large negative bias, kept positions get 0.
            attn_mask = (~attn_mask).to(query.dtype) * -1.0e6
        # The kernel indexes the mask with (batch, head, q, kv) strides, so
        # materialize broadcasting (e.g. (q_len, kv_len) or (1, 1, q_len,
        # kv_len) masks) as a stride-0 view instead of reading out of bounds.
        attn_mask = attn_mask.expand(
            query.shape[0], q_head_num, query.shape[2], key.shape[2]
        )
        stride_attn_mask_batch = attn_mask.stride(0)
        stride_attn_mask_head = attn_mask.stride(1)
        stride_attn_mask_q_seqlen = attn_mask.stride(2)
        stride_attn_mask_kv_seqlen = attn_mask.stride(3)
    else:
        HAS_ATTN_MASK = False
        stride_attn_mask_batch = 1
        stride_attn_mask_head = 1
        stride_attn_mask_q_seqlen = 1
        stride_attn_mask_kv_seqlen = 1

    M = torch.empty(
        (query.shape[0], query.shape[1], query.shape[2]),
        device=query.device,
        dtype=torch.float32,
    )

    with torch_device_fn.device(query.device):
        _attn_fwd[grid](
            query,
            key,
            value,
            attn_mask,
            sm_scale,
            M,
            o,  #
            query.stride(0),
            query.stride(1),
            query.stride(2),
            query.stride(3),  #
            key.stride(0),
            key.stride(1),
            key.stride(2),
            key.stride(3),  #
            value.stride(0),
            value.stride(1),
            value.stride(2),
            value.stride(3),  #
            stride_attn_mask_batch,
            stride_attn_mask_head,
            stride_attn_mask_q_seqlen,
            stride_attn_mask_kv_seqlen,  #
            o.stride(0),
            o.stride(1),
            o.stride(2),
            o.stride(3),  #
            query.shape[0],
            q_head_num,
            kv_head_num,  #
            q_head_num // kv_head_num,  # group_head
            query.shape[2],  #
            key.shape[2],  #
            HEAD_DIM_K,  #
            STAGE=stage,  #
            HAS_ATTN_MASK=HAS_ATTN_MASK,  #
        )
    return o, M
def scaled_dot_product_attention_backward(
    do,
    query,
    key,
    value,
    o,
    M,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Compute (dq, dk, dv) for ``scaled_dot_product_attention_forward``.

    Args:
        do: gradient w.r.t. the output; must be contiguous with o's strides.
        query, key, value, o: forward inputs and output; all must be
            contiguous, and key/value must share strides.
        M: the per-row statistic returned by the forward.
        attn_mask: must be None (no masked backward kernel yet).
        dropout_p: must be 0.0.
        is_causal: whether the forward used the causal mask.
        scale: softmax scale; defaults to 1 / sqrt(head_dim).
        enable_gqa: accepted for signature parity. When q has more heads than
            k/v, per-query-head dk/dv are summed over each head group.

    Returns:
        (dq, dk, dv) shaped like (query, key, value).
    """
    logger.debug("GEMS SCALED DOT PRODUCT ATTENTION BACKWARD")
    # shape constraints
    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
    assert dropout_p == 0.0, "Currenty only support dropout_p=0.0"
    # The backward kernel has no mask support; ignoring a mask here would
    # silently return gradients of unmasked attention.
    assert attn_mask is None, "attn_mask is not supported in backward yet"

    sm_scale = 1.0 / (HEAD_DIM_K**0.5) if scale is None else scale

    assert do.is_contiguous()
    assert (
        query.is_contiguous()
        and key.is_contiguous()
        and value.is_contiguous()
        and o.is_contiguous()
    )
    assert query.stride() == o.stride() == do.stride()
    assert key.stride() == value.stride()

    BLOCK_DMODEL = HEAD_DIM_K
    BATCH, Q_HEAD, Q_CTX = query.shape[:3]
    _, KV_HEAD, KV_CTX = key.shape[:3]
    group_head = Q_HEAD // KV_HEAD

    BLK_SLICE_FACTOR = 2
    RCP_LN2 = 1.0 / math.log(2)

    # Pre-scale K so that q . arg_k is the log2-domain score used by the
    # forward's exp2-based softmax (the kernel undoes this for dk and dq).
    arg_k = key * (sm_scale * RCP_LN2)
    PRE_BLOCK = 256
    pre_grid = (triton.cdiv(Q_CTX, PRE_BLOCK), BATCH * Q_HEAD)

    delta = torch.empty_like(M)

    # NOTE that dk & dv always have the same number of heads as q
    dq = torch.empty_like(query).contiguous()
    dk = torch.empty(
        (BATCH, Q_HEAD, KV_CTX, HEAD_DIM_K),
        device=key.device,
        dtype=key.dtype,
        memory_format=torch.contiguous_format,
    )
    dv = torch.empty(
        (BATCH, Q_HEAD, KV_CTX, HEAD_DIM_V),
        device=value.device,
        dtype=value.dtype,
        memory_format=torch.contiguous_format,
    )

    grid = lambda meta: (
        max(
            triton.cdiv(KV_CTX, meta["BLOCK_N1"]),  # KV blocks handled for dK/dV
            triton.cdiv(Q_CTX, meta["BLOCK_M2"]),  # query blocks handled for dQ
        ),
        1,
        BATCH * Q_HEAD,
    )

    # Launch on query's device, as the forward does.
    with torch_device_fn.device(query.device):
        _attn_bwd_preprocess[pre_grid](
            o,
            do,  #
            delta,  #
            BATCH,
            Q_HEAD,
            Q_CTX,  #
            BLOCK_M=PRE_BLOCK,
            D_HEAD=BLOCK_DMODEL,  #
        )

        _attn_bwd[grid](
            query,
            arg_k,
            value,
            sm_scale,
            do,
            dq,
            dk,
            dv,  #
            M,
            delta,  #
            query.stride(0),
            query.stride(1),
            query.stride(2),
            query.stride(3),  #
            key.stride(0),
            key.stride(1),  #
            dk.stride(0),
            dk.stride(1),
            dk.stride(2),  #
            Q_HEAD,
            Q_CTX,  #
            KV_CTX,  #
            KV_HEAD,  #
            GROUP_HEAD=group_head,  #
            BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
            BLOCK_DMODEL=BLOCK_DMODEL,  #
            IS_CAUSAL=is_causal,  #
        )

    if group_head > 1:
        # Reduce per-query-head gradients onto their shared kv head.
        dk = dk.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_K)
        dv = dv.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_V)
        dk = dk.sum(dim=2)
        dv = dv.sum(dim=2)

    return dq, dk, dv
class ScaleDotProductAttention(torch.autograd.Function):
    """Autograd bridge between the Triton attention forward and backward."""

    @staticmethod
    def forward(
        ctx,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=None,
        enable_gqa=False,
    ):
        sm_scale = scale if scale is not None else 1.0 / (key.shape[-1] ** 0.5)
        o, M = scaled_dot_product_attention_forward(
            query,
            key,
            value,
            attn_mask,
            dropout_p,
            is_causal,
            sm_scale,
            enable_gqa,
        )

        ctx.save_for_backward(query, key, value, o, M)
        ctx.sm_scale = sm_scale
        ctx.causal = is_causal
        ctx.enable_gqa = enable_gqa
        # The backward kernel cannot apply a mask; remember whether one was used.
        ctx.has_attn_mask = attn_mask is not None
        return o

    @staticmethod
    def backward(ctx, do):
        if ctx.has_attn_mask:
            raise NotImplementedError(
                "backward of scaled_dot_product_attention with attn_mask is not supported"
            )
        query, key, value, o, M = ctx.saved_tensors
        # The backward kernels require contiguous inputs with matching strides.
        # Autograd may hand in a non-contiguous grad (e.g. an expanded one from
        # `.sum()`), and the inputs themselves may be strided views.
        dq, dk, dv = scaled_dot_product_attention_backward(
            do.contiguous(),
            query.contiguous(),
            key.contiguous(),
            value.contiguous(),
            o.contiguous(),
            M,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=ctx.causal,
            scale=ctx.sm_scale,
            enable_gqa=ctx.enable_gqa,
        )
        # One gradient slot per forward input; non-tensor args get None.
        return dq, dk, dv, None, None, None, None, None
def scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Differentiable attention entry point backed by the Triton kernels.

    Accepts the same arguments as
    ``torch.nn.functional.scaled_dot_product_attention`` and routes them
    through ``ScaleDotProductAttention.apply`` so that autograd uses the
    custom backward.
    """
    forward_args = (
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        enable_gqa,
    )
    return ScaleDotProductAttention.apply(*forward_args)
def flash_attention_forward(
    query,
    key,
    value,
    cumulative_sequence_length_q,
    cumulative_sequence_length_k,
    max_q,
    max_k,
    dropout_p,
    is_causal,
    return_debug_mask,
    *,
    scale=None,
    softcap=0.0,
    window_size_left=None,
    window_size_right=None,
    seqused_k=None,
    alibi_slopes=None,
    disable_splitkv=False,
):
    """FlashAttention forward (dense inputs only; varlen is rejected).

    Head dims outside (16, 32, 64, 96, 128, 192, 256) are zero-padded up to the
    next supported size and the output is sliced back afterwards. The default
    softmax scale is 1 / sqrt(original head dim). A window size of None is
    passed to the kernel as -1.

    Returns:
        (out, lse, philox_seed, philox_offset, p) as produced by ``mha_fwd``,
        with ``out`` trimmed to the original head dim.
    """
    logger.debug("GEMS FLASH_ATTENTION_FORWARD")
    assert (
        cumulative_sequence_length_q is None and cumulative_sequence_length_k is None
    ), "varlen is not supported yet."

    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    original_head_dim = HEAD_DIM_K
    supported_head_dims = (16, 32, 64, 96, 128, 192, 256)
    if HEAD_DIM_K not in supported_head_dims:
        # Smallest supported head dim that can hold the real one.
        padded_head_dim = next(
            (d for d in supported_head_dims if d >= HEAD_DIM_K), None
        )
        assert (
            padded_head_dim is not None
        ), f"Unsupported head dim {HEAD_DIM_K}, max supported is {supported_head_dims[-1]}"
        # Zero padding changes neither q.k nor the output columns kept below.
        pad = padded_head_dim - HEAD_DIM_K
        query = F.pad(query, (0, pad))
        key = F.pad(key, (0, pad))
        value = F.pad(value, (0, pad))
        HEAD_DIM_K = padded_head_dim

    # An explicit scale (including 0.0) is honored; only None selects the
    # default, which uses the unpadded head dim.
    softmax_scale = scale if scale is not None else 1.0 / (original_head_dim**0.5)
    non_null_window_left = -1 if window_size_left is None else window_size_left
    non_null_window_right = -1 if window_size_right is None else window_size_right

    out = torch.empty_like(query)
    if cumulative_sequence_length_q is not None:
        out, q, k, v, lse, philox_seed, philox_offset, p = mha_varlan_fwd(
            query,
            key,
            value,
            out,
            cumulative_sequence_length_q,
            cumulative_sequence_length_k,
            seqused_k,
            None,
            None,  # block_table
            alibi_slopes,
            max_q,
            max_k,
            dropout_p,
            softmax_scale,
            False,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask and dropout_p > 0,
            None,
        )
    else:
        out, q, k, v, lse, philox_seed, philox_offset, p = mha_fwd(
            query,
            key,
            value,
            out,
            alibi_slopes,
            dropout_p,
            softmax_scale,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask,
            disable_splitkv=disable_splitkv,
        )

    if HEAD_DIM_K != original_head_dim:
        out = out[..., :original_head_dim]
    return (out, lse, philox_seed, philox_offset, p)
# Adapted from https://github.com/vllm-project/flash-attention/blob/main/vllm_flash_attn/flash_attn_interface.py
def maybe_contiguous(x):
    """Return ``x`` with a unit stride on its last dimension.

    ``None`` and tensors whose last dimension already has stride 1 are
    returned as-is (same object). Anything else is replaced by a contiguous
    copy.
    """
    if x is None or x.stride(-1) == 1:
        return x
    return x.contiguous()
1203def flash_attn_varlen_func(
1204 q,
1205 k,
1206 v,
1207 max_seqlen_q,
1208 cu_seqlens_q,
1209 max_seqlen_k,
1210 cu_seqlens_k=None, # only used for non-paged prefill
1211 seqused_k=None,
1212 q_v=None,
1213 dropout_p=0.0,
1214 softmax_scale=None,
1215 causal=False,
1216 window_size=None,
1217 softcap=0.0, # 0.0 means deactivated
1218 alibi_slopes=None,
1219 deterministic=False,
1220 return_attn_probs=False,
1221 block_table=None,
1222 return_softmax_lse=False,
1223 out=None,
1224 # Dummy FA3 arguments
1225 scheduler_metadata=None,
1226 q_descale=None,
1227 k_descale=None,
1228 v_descale=None,
1229 s_aux=None,
1230 num_splits: int = 0,
1231 cp_world_size: int = 1,
1232 cp_rank: int = 0,
1233 cp_tot_seqused_k=None,
1234 fa_version: int = 2,
1235):
1236 """dropout_p should be set to 0.0 during evaluation
1237 Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
1238 than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
1239 For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
1240 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
1242 If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
1243 For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1244 1 1 1 1 0
1245 1 1 1 1 1
1246 If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
1247 0 0
1248 0 0
1249 0 0
1250 1 0
1251 1 1
1252 If the row of the mask is all zero, the output will be zero.
1254 If window_size != (-1, -1), implements sliding window local attention. Query at position i
1255 will only attend to keys between
1256 [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
1258 Arguments:
1259 q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
1260 k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
1261 v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
1262 cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1263 of the sequences in the batch, used to index into q.
1264 cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1265 of the sequences in the batch, used to index into kv.
1266 max_seqlen_q: int. Maximum query sequence length in the batch.
1267 max_seqlen_k: int. Maximum key sequence length in the batch.
1268 dropout_p: float. Dropout probability.
1269 softmax_scale: float. The scaling of QK^T before applying softmax.
1270 Default to 1 / sqrt(headdim).
1271 causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
1272 window_size: (left, right). If not (-1, -1), implements sliding window local attention.
1273 softcap: float. Anything > 0 activates softcapping attention.
1274 alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
1275 (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
1276 is added to the attention score of query i and key j.
1277 deterministic: bool. Whether to use the deterministic implementation of the backward pass,
1278 which is slightly slower and uses more memory. The forward pass is always deterministic.
1279 return_attn_probs: bool. Whether to return the attention probabilities. This option is for
1280 testing only. The returned probabilities are not guaranteed to be correct
1281 (they might not have the right scaling).
1282 Return:
1283 out: (total, nheads, headdim).
1284 softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
1285 logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
1286 normalization factor).
1287 """
1288 if fa_version != 2:
1289 raise RuntimeError("Only FA2 is implemented.")
1290 if num_splits > 0:
1291 raise RuntimeError("num_splits > 0 is not implemented in GEMS.")
1292 if use_c_extension:
1293 logger.debug("GEMS FLASH_ATTN_VARLEN_FUNC(C EXTENSION)")
1294 with torch_device_fn.device(q.device):
1295 out_cpp, softmax_lse = torch.ops.flag_gems.flash_attn_varlen_func(
1296 q,
1297 k,
1298 v,
1299 max_seqlen_q,
1300 cu_seqlens_q,
1301 max_seqlen_k,
1302 cu_seqlens_k,
1303 seqused_k,
1304 q_v,
1305 dropout_p,
1306 softmax_scale,
1307 causal,
1308 window_size,
1309 softcap,
1310 alibi_slopes,
1311 deterministic,
1312 return_attn_probs,
1313 block_table,
1314 return_softmax_lse,
1315 out,
1316 scheduler_metadata,
1317 q_descale,
1318 k_descale,
1319 v_descale,
1320 s_aux,
1321 num_splits,
1322 cp_world_size,
1323 cp_rank,
1324 cp_tot_seqused_k,
1325 fa_version,
1326 )
1327 return (out_cpp, softmax_lse) if return_softmax_lse else out_cpp
1328 else:
1329 logger.debug("GEMS FLASH_ATTN_VARLEN_FUNC")
1330 assert (
1331 cu_seqlens_k is not None or seqused_k is not None
1332 ), "cu_seqlens_k or seqused_k must be provided"
1333 assert (
1334 cu_seqlens_k is None or seqused_k is None
1335 ), "cu_seqlens_k and seqused_k cannot be provided at the same time"
1336 assert (
1337 block_table is None or seqused_k is not None
1338 ), "seqused_k must be provided if block_table is provided"
1339 if softmax_scale is None:
1340 softmax_scale = q.shape[-1] ** (-0.5)
1341 # custom op does not support non-tuple input
1342 if window_size is None:
1343 real_window_size = (-1, -1)
1344 else:
1345 assert len(window_size) == 2
1346 real_window_size = (window_size[0], window_size[1])
1347 q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
1348 dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q)
1349 max_seqlen_q = (
1350 max_seqlen_q.item() if hasattr(max_seqlen_q, "item") else max_seqlen_q
1351 )
1352 max_seqlen_k = (
1353 max_seqlen_k.item() if hasattr(max_seqlen_k, "item") else max_seqlen_k
1354 )
1355 out, q, k, v, softmax_lse, *_ = mha_varlan_fwd(
1356 q,
1357 k,
1358 v,
1359 out,
1360 cu_seqlens_q,
1361 # cu_seqlens_k not used since we use seqused_k, but flash_api.cpp
1362 # still wants it so we pass all zeros
1363 dummy_cu_seqlens_k if cu_seqlens_k is None else cu_seqlens_k,
1364 seqused_k,
1365 None,
1366 block_table,
1367 alibi_slopes,
1368 max_seqlen_q,
1369 max_seqlen_k,
1370 dropout_p,
1371 softmax_scale,
1372 False,
1373 causal,
1374 real_window_size[0],
1375 real_window_size[1],
1376 softcap,
1377 return_softmax_lse and dropout_p > 0,
1378 None,
1379 )
1381 return (out, softmax_lse) if return_softmax_lse else out
1384def flash_attn_varlen_opt_func(
1385 q,
1386 k,
1387 v,
1388 max_seqlen_q,
1389 cu_seqlens_q,
1390 max_seqlen_k,
1391 cu_seqlens_k=None, # only used for non-paged prefill
1392 seqused_k=None,
1393 q_v=None,
1394 dropout_p=0.0,
1395 softmax_scale=None,
1396 causal=False,
1397 window_size=None,
1398 softcap=0.0, # 0.0 means deactivated
1399 alibi_slopes=None,
1400 deterministic=False,
1401 return_attn_probs=False,
1402 block_table=None,
1403 return_softmax_lse=False,
1404 out=None,
1405 lse=None,
1406 # Dummy FA3 arguments
1407 scheduler_metadata=None,
1408 q_descale=None,
1409 k_descale=None,
1410 v_descale=None,
1411 s_aux=None,
1412 num_splits: int = 0,
1413 cp_world_size: int = 1,
1414 cp_rank: int = 0,
1415 cp_tot_seqused_k=None,
1416 fa_version: int = 2,
1417):
1418 """dropout_p should be set to 0.0 during evaluation
1419 Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
1420 than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
1421 For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
1422 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
1424 If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
1425 For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1426 1 1 1 1 0
1427 1 1 1 1 1
1428 If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
1429 0 0
1430 0 0
1431 0 0
1432 1 0
1433 1 1
1434 If the row of the mask is all zero, the output will be zero.
1436 If window_size != (-1, -1), implements sliding window local attention. Query at position i
1437 will only attend to keys between
1438 [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
1440 Arguments:
1441 q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
1442 k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
1443 v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
1444 cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1445 of the sequences in the batch, used to index into q.
1446 cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1447 of the sequences in the batch, used to index into kv.
1448 max_seqlen_q: int. Maximum query sequence length in the batch.
1449 max_seqlen_k: int. Maximum key sequence length in the batch.
1450 dropout_p: float. Dropout probability.
1451 softmax_scale: float. The scaling of QK^T before applying softmax.
1452 Default to 1 / sqrt(headdim).
1453 causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
1454 window_size: (left, right). If not (-1, -1), implements sliding window local attention.
1455 softcap: float. Anything > 0 activates softcapping attention.
1456 alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
1457 (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
1458 is added to the attention score of query i and key j.
1459 deterministic: bool. Whether to use the deterministic implementation of the backward pass,
1460 which is slightly slower and uses more memory. The forward pass is always deterministic.
1461 return_attn_probs: bool. Whether to return the attention probabilities. This option is for
1462 testing only. The returned probabilities are not guaranteed to be correct
1463 (they might not have the right scaling).
1464 Return:
1465 out: (total, nheads, headdim).
1466 softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
1467 logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
1468 normalization factor).
1469 """
1470 if fa_version != 2:
1471 raise RuntimeError("Only FA2 is implemented.")
1472 if num_splits > 0:
1473 raise RuntimeError("num_splits > 0 is not implemented in GEMS.")
1474 if use_c_extension:
1475 logger.debug("GEMS FLASH_ATTN_VARLEN_FUNC(C EXTENSION)")
1476 with torch_device_fn.device(q.device):
1477 out_cpp, softmax_lse = torch.ops.flag_gems.flash_attn_varlen_func(
1478 q,
1479 k,
1480 v,
1481 max_seqlen_q,
1482 cu_seqlens_q,
1483 max_seqlen_k,
1484 cu_seqlens_k,
1485 seqused_k,
1486 q_v,
1487 dropout_p,
1488 softmax_scale,
1489 causal,
1490 window_size,
1491 softcap,
1492 alibi_slopes,
1493 deterministic,
1494 return_attn_probs,
1495 block_table,
1496 return_softmax_lse,
1497 out,
1498 scheduler_metadata,
1499 q_descale,
1500 k_descale,
1501 v_descale,
1502 s_aux,
1503 num_splits,
1504 cp_world_size,
1505 cp_rank,
1506 cp_tot_seqused_k,
1507 fa_version,
1508 )
1509 return (out_cpp, softmax_lse) if return_softmax_lse else out_cpp
1510 else:
1511 logger.debug("GEMS FLASH_ATTN_VARLEN_OPT_FUNC")
1512 assert (
1513 cu_seqlens_k is not None or seqused_k is not None
1514 ), "cu_seqlens_k or seqused_k must be provided"
1515 assert (
1516 cu_seqlens_k is None or seqused_k is None
1517 ), "cu_seqlens_k and seqused_k cannot be provided at the same time"
1518 assert (
1519 block_table is None or seqused_k is not None
1520 ), "seqused_k must be provided if block_table is provided"
1521 if softmax_scale is None:
1522 softmax_scale = q.shape[-1] ** (-0.5)
1523 # custom op does not support non-tuple input
1524 if window_size is None:
1525 real_window_size = (-1, -1)
1526 else:
1527 assert len(window_size) == 2
1528 real_window_size = (window_size[0], window_size[1])
1529 # q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
1530 # dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q)
1531 max_seqlen_q = (
1532 max_seqlen_q.item() if hasattr(max_seqlen_q, "item") else max_seqlen_q
1533 )
1534 max_seqlen_k = (
1535 max_seqlen_k.item() if hasattr(max_seqlen_k, "item") else max_seqlen_k
1536 )
1537 out, q, k, v, softmax_lse, *_ = mha_varlan_fwd_opt(
1538 q,
1539 k,
1540 v,
1541 out,
1542 lse,
1543 cu_seqlens_q,
1544 # cu_seqlens_k not used since we use seqused_k, but flash_api.cpp
1545 # still wants it so we pass all zeros
1546 # dummy_cu_seqlens_k if cu_seqlens_k is None else cu_seqlens_k,
1547 cu_seqlens_q if cu_seqlens_k is None else cu_seqlens_k,
1548 seqused_k,
1549 None,
1550 block_table,
1551 alibi_slopes,
1552 max_seqlen_q,
1553 max_seqlen_k,
1554 dropout_p,
1555 softmax_scale,
1556 False,
1557 causal,
1558 real_window_size[0],
1559 real_window_size[1],
1560 softcap,
1561 return_softmax_lse and dropout_p > 0,
1562 None,
1563 )
1565 return (out, softmax_lse) if return_softmax_lse else out