Coverage for src/flag_gems/runtime/backend/_sunrise/ops/attention.py: 0%
440 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
2import math
3from functools import partial
5import torch
6import torch.nn.functional as F
7import triton
8import triton.language as tl
10from flag_gems import runtime
11from flag_gems.config import use_c_extension
12from flag_gems.runtime import torch_device_fn
13from flag_gems.runtime.backend._sunrise.ops.flash_api import (
14 mha_fwd,
15 mha_varlan_fwd,
16 mha_varlan_fwd_opt,
17)
18from flag_gems.runtime.backend._sunrise.ops.flash_kernel import keep
19from flag_gems.utils import libentry, libtuner
# Module-level logger; the op entry points in this file emit their debug
# messages ("GEMS ...") through it.
logger = logging.getLogger(__name__)
# Modified from Triton tutorial: https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
@triton.jit
def _attn_fwd_inner(
    acc,
    l_i,
    m_i,
    query,  #
    K_block_ptr,
    V_block_ptr,  #
    mask_block_ptr,  #
    stride_k_seqlen,
    stride_v_seqlen,
    stride_attn_mask_kv_seqlen,  #
    start_m,
    qk_scale,  #
    q_load_mask,
    BLOCK_M: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_N: tl.constexpr,  #
    STAGE: tl.constexpr,
    offs_m: tl.constexpr,
    offs_n: tl.constexpr,  #
    KV_CTX: tl.constexpr,
    fp8_v: tl.constexpr,
    HAS_ATTN_MASK: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
):
    """One online-softmax sweep over a range of key/value blocks.

    For a fixed tile of BLOCK_M query rows, walks key/value blocks of BLOCK_N
    rows and updates the running output accumulator ``acc``, running row sums
    ``l_i`` and running row maxima ``m_i``. Scores are kept in the base-2
    domain (multiplied by ``LOG2E``) so ``exp2`` can be used.

    STAGE selects the key range:
      * 1 -> keys [0, start_m * BLOCK_M), the blocks left of the diagonal;
      * 2 -> keys [start_m * BLOCK_M, (start_m + 1) * BLOCK_M), the diagonal
             block, with the causal mask ``query_row >= key_row`` applied;
      * otherwise -> all keys [0, KV_CTX) (non-causal).

    When HAS_ATTN_MASK is set, the additive mask is applied to the scaled
    scores before the log2(e) conversion in every stage, i.e. the softmax
    input is ``q·k * qk_scale + attn_mask``.

    Returns the updated ``(acc, l_i, m_i)``.
    """
    # range of values handled by this stage
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
    # causal = False
    else:
        lo, hi = 0, KV_CTX

    K_block_ptr += lo * stride_k_seqlen
    V_block_ptr += lo * stride_v_seqlen
    if HAS_ATTN_MASK:
        mask_block_ptr += lo * stride_attn_mask_kv_seqlen

    LOG2E = 1.44269504  # log2(e) constant

    # loop over key, value and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        # Key columns past KV_CTX only exist in the last, partial block.
        kv_load_mask = (start_n + offs_n) < KV_CTX
        # -- compute qk ----
        key = tl.load(K_block_ptr, mask=kv_load_mask[None, :], other=0.0)
        if PRE_LOAD_V:
            value = tl.load(V_block_ptr, mask=kv_load_mask[:, None], other=0.0)

        qk = tl.dot(query, key, allow_tf32=False)
        # Out-of-range key columns must not contribute to the softmax.
        qk = tl.where(kv_load_mask[None, :], qk, -float("inf"))

        if HAS_ATTN_MASK:
            attn_mask = tl.load(
                mask_block_ptr,
                mask=q_load_mask[:, None] & kv_load_mask[None, :],
                other=0.0,
            )

        if STAGE == 2:
            # Diagonal block: apply the causal constraint.
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])

            if HAS_ATTN_MASK:
                qk = qk * qk_scale + attn_mask
                qk *= LOG2E
                qk = qk + tl.where(mask, 0, -1.0e6)
            else:
                qk = qk * qk_scale * LOG2E + tl.where(mask, 0, -1.0e6)

            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
        else:
            if HAS_ATTN_MASK:
                # Add the mask to the *natural-log* scores before converting
                # to base 2 (same order as the STAGE == 2 branch); adding it
                # after the LOG2E multiply would scale the bias by ln(2).
                qk = qk * qk_scale + attn_mask
                qk *= LOG2E
            else:
                qk *= qk_scale * LOG2E
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk = qk - m_ij[:, None]

        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        # Rescale the previous running sum/accumulator to the new row max.
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        # -- update output accumulator --
        acc = acc * alpha[:, None]
        # update acc
        if not PRE_LOAD_V:
            value = tl.load(V_block_ptr, mask=kv_load_mask[:, None], other=0.0)
        if fp8_v:
            p = p.to(tl.float8e5)
        else:
            p = p.to(query.dtype)
        p = p.to(value.dtype)
        acc = tl.dot(p, value, acc, allow_tf32=False)
        # update m_i and l_i
        m_i = m_ij

        K_block_ptr += BLOCK_N * stride_k_seqlen
        V_block_ptr += BLOCK_N * stride_v_seqlen

        if HAS_ATTN_MASK:
            mask_block_ptr += BLOCK_N * stride_attn_mask_kv_seqlen

    return acc, l_i, m_i
# NOTE: _attn_fwd statically asserts BLOCK_N <= HEAD_DIM, so small head dims
# need their own, smaller candidate configs.
configs = runtime.get_tuned_config("attention")
SMALL_HEAD_DIM_CONFIGS = [
    triton.Config(
        {"BLOCK_M": block_m, "BLOCK_N": block_n, "PRE_LOAD_V": 0},
        num_stages=stages,
        num_warps=warps,
    )
    for block_m in (16, 32)
    for block_n in (8, 16)
    for stages in (0,)
    for warps in (8, 16)
]
# On this backend the tuned "attention" configs are replaced (not extended):
# only the small-head-dim grid above is offered to the autotuner.
configs = SMALL_HEAD_DIM_CONFIGS
@libentry()
@libtuner(
    configs=list(filter(partial(keep, must_keep=SMALL_HEAD_DIM_CONFIGS), configs)),
    key=["KV_CTX", "HEAD_DIM"],
)
@triton.jit
def _attn_fwd(
    Q,
    K,
    V,
    attn_mask,
    sm_scale,
    M,
    Out,  #
    stride_q_batch,
    stride_q_head,
    stride_q_seqlen,
    stride_q_headsize,
    stride_k_batch,
    stride_k_head,
    stride_k_seqlen,
    stride_k_headsize,
    stride_v_batch,
    stride_v_head,
    stride_v_seqlen,
    stride_v_headsize,
    stride_attn_mask_batch,
    stride_attn_mask_head,
    stride_attn_mask_q_seqlen,
    stride_attn_mask_kv_seqlen,
    stride_o_batch,
    stride_o_head,
    stride_o_seqlen,
    stride_o_headsize,
    Z,
    q_head_num,
    kv_head_num,
    GROUP_HEAD: tl.constexpr,
    Q_CTX,
    KV_CTX,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
    HAS_ATTN_MASK: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
):
    """Fused scaled-dot-product attention forward kernel.

    Each program handles one tile of BLOCK_M query rows for one
    (batch, query head) pair:
      * program_id(0) -> query tile index ``start_m``;
      * program_id(1) -> flattened ``batch_id * q_head_num + head_id``.
    Query head ``head_id`` reads key/value head ``head_id // GROUP_HEAD``
    (grouped-query attention).

    Writes the normalized output tile to ``Out`` and, to ``M``, the per-row
    value ``m_i + log2(l_i)``, i.e. the base-2 log of the softmax
    denominator, which the backward pass reuses.

    STAGE == 3 -> causal: off-diagonal sweep, then the masked diagonal block.
    STAGE == 1 -> non-causal: a single sweep over all keys.
    """
    tl.static_assert(BLOCK_N <= HEAD_DIM)
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    batch_id = off_hz // q_head_num
    head_id = off_hz % q_head_num
    # GQA: GROUP_HEAD consecutive query heads share one key/value head.
    kv_head_id = head_id // GROUP_HEAD

    # 64-bit base offsets so large batch*head*seq products do not overflow.
    q_offset = (
        batch_id.to(tl.int64) * stride_q_batch + head_id.to(tl.int64) * stride_q_head
    )
    o_offset = (
        batch_id.to(tl.int64) * stride_o_batch + head_id.to(tl.int64) * stride_o_head
    )
    # NOTE(review): K and V both use the K strides here; this assumes V is laid
    # out like K at the batch/head level — confirm for non-contiguous inputs.
    kv_offset = (
        batch_id.to(tl.int64) * stride_k_batch + kv_head_id.to(tl.int64) * stride_k_head
    )

    offs_headsize = tl.arange(0, HEAD_DIM)

    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    q_load_mask = offs_m < Q_CTX
    offs_n = tl.arange(0, BLOCK_N)

    # (BLOCK_M, HEAD_DIM) query tile pointers.
    Q_block_ptr = (
        Q
        + q_offset
        + offs_m[:, None] * stride_q_seqlen
        + offs_headsize[None, :] * stride_q_headsize
    )
    # (HEAD_DIM, BLOCK_N) pointers: K is read transposed so q @ k needs no trans.
    K_block_ptr = (
        K
        + kv_offset
        + offs_n[None, :] * stride_k_seqlen
        + offs_headsize[:, None] * stride_k_headsize
    )
    # (BLOCK_N, HEAD_DIM) value tile pointers.
    V_block_ptr = (
        V
        + kv_offset
        + offs_n[:, None] * stride_v_seqlen
        + offs_headsize[None, :] * stride_v_headsize
    )

    if HAS_ATTN_MASK:
        # The mask is indexed by the *query* head, not the kv head.
        attn_mask_offset = (
            batch_id.to(tl.int64) * stride_attn_mask_batch
            + head_id.to(tl.int64) * stride_attn_mask_head
        )
        mask_block_ptr = (
            attn_mask
            + attn_mask_offset
            + offs_m[:, None] * stride_attn_mask_q_seqlen
            + offs_n[None, :] * stride_attn_mask_kv_seqlen
        )
    else:
        mask_block_ptr = None

    O_block_ptr = (
        Out
        + o_offset
        + offs_m[:, None] * stride_o_seqlen
        + offs_headsize[None, :] * stride_o_headsize
    )

    # initialize pointer to m and l
    # m_i starts at -inf, so the first block's rescale factor alpha is
    # exp2(-inf) == 0 and the initial l_i value of 1.0 is discarded
    # (assuming that block has at least one finite score).
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    # load scales
    # The log2(e) factor is applied inside _attn_fwd_inner, not here.
    qk_scale = sm_scale
    # qk_scale *= 1.44269504 # 1/log(2)
    # load query: it will stay in SRAM throughout
    query = tl.load(Q_block_ptr, mask=q_load_mask[:, None], other=0.0)
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            query,
            K_block_ptr,
            V_block_ptr,
            mask_block_ptr,
            stride_k_seqlen,
            stride_v_seqlen,
            stride_attn_mask_kv_seqlen,
            start_m,
            qk_scale,
            q_load_mask,
            BLOCK_M,
            HEAD_DIM,
            BLOCK_N,
            4 - STAGE,
            offs_m,
            offs_n,
            KV_CTX,
            V.dtype.element_ty == tl.float8e5,
            HAS_ATTN_MASK,
            PRE_LOAD_V,
        )
    # stage 2: on-band
    if STAGE & 2:
        # Kept as a separate call so the compiler can schedule the
        # off-diagonal and diagonal loops independently.
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            query,
            K_block_ptr,
            V_block_ptr,
            mask_block_ptr,
            stride_k_seqlen,
            stride_v_seqlen,
            stride_attn_mask_kv_seqlen,
            start_m,
            qk_scale,
            q_load_mask,
            BLOCK_M,
            HEAD_DIM,
            BLOCK_N,
            2,
            offs_m,
            offs_n,
            KV_CTX,
            V.dtype.element_ty == tl.float8e5,
            HAS_ATTN_MASK,
            PRE_LOAD_V,
        )
    # epilogue
    # M = m_i + log2(l_i): base-2 log of the softmax denominator per row.
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    # M is addressed as a dense (batch * q_head_num, Q_CTX) buffer.
    m_ptrs = M + off_hz * Q_CTX + offs_m
    tl.store(m_ptrs, m_i, mask=q_load_mask)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=q_load_mask[:, None])
@triton.jit
def _attn_bwd_preprocess(
    O, DO, Delta, Z, H, Q_CTX, BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr
):
    """Compute Delta = rowsum(O * dO) for one BLOCK_M-row tile.

    program_id(0) selects the row tile and program_id(1) the flattened
    (batch, head) slice. O and DO are addressed as dense
    (slices, Q_CTX, D_HEAD) buffers; rows at or beyond Q_CTX are skipped.
    ``Z`` and ``H`` are accepted for signature compatibility but unused.
    """
    row_idx = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    row_valid = row_idx < Q_CTX

    slice_id = tl.program_id(1)
    col_idx = tl.arange(0, D_HEAD)
    # Shared element offsets for the O and dO tiles of this slice.
    tile_offsets = (
        slice_id * D_HEAD * Q_CTX + row_idx[:, None] * D_HEAD + col_idx[None, :]
    )

    out_tile = tl.load(O + tile_offsets, mask=row_valid[:, None], other=0.0)
    grad_tile = tl.load(DO + tile_offsets, mask=row_valid[:, None], other=0.0).to(
        tl.float32
    )
    row_dot = tl.sum(out_tile * grad_tile, axis=1)

    tl.store(Delta + slice_id * Q_CTX + row_idx, row_dot, mask=row_valid)
# The main inner-loop logic for computing dK and dV.
@triton.jit
def _attn_bwd_dkdv(
    dk,
    dv,  #
    Q,
    key,
    value,
    sm_scale,  #
    DO,  #
    M,
    D,  #
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,  #
    H,
    Q_CTX,
    KV_CTX,
    BLOCK_M1: tl.constexpr,  #
    BLOCK_N1: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,  #
    # Filled in by the wrapper.
    start_n,
    start_m,
    num_steps,  #
    MASK: tl.constexpr,
):
    """Accumulate dK and dV for one key tile over ``num_steps`` query blocks.

    ``key`` / ``value`` are the already-loaded (BLOCK_N1, BLOCK_DMODEL) tiles
    for key rows starting at ``start_n``. Query rows are visited in blocks of
    BLOCK_M1 starting at ``start_m``. ``M`` is the per-row value stored by
    the forward kernel and ``D`` the per-row delta from
    ``_attn_bwd_preprocess``. With MASK set, only pairs with
    query_row >= key_row contribute (causal diagonal).

    Every load is masked against Q_CTX / KV_CTX, so blocks that run past the
    end of the sequence contribute exactly zero.

    Returns the updated float32 ``(dk, dv)`` accumulators.
    """
    # BLOCK_M1: 32
    # BLOCK_N1: 128
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_n_mask = offs_n < KV_CTX  # (BLOCK_N1, )

    offs_k = tl.arange(0, BLOCK_DMODEL)  # (BLOCK_DMODEL, )

    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    curr_m = start_m
    step_m = BLOCK_M1
    for blk_idx in range(num_steps):
        offs_m = curr_m + tl.arange(0, BLOCK_M1)  # (BLOCK_M1, )
        offs_m_mask = offs_m < Q_CTX  # (BLOCK_M1, )

        qT_ptrs = (
            Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
        )  # (BLOCK_DMODEL, BLOCK_M1)
        do_ptrs = (
            DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
        )  # (BLOCK_M1, BLOCK_DMODEL)

        qT = tl.load(
            qT_ptrs, mask=offs_m_mask[None, :], other=0.0
        )  # (BLOCK_DMODEL, BLOCK_M1)

        # Load m before computing qk to reduce pipeline stall.
        # other=+inf makes exp2(qk - m) == 0 for rows past Q_CTX.
        m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf"))  # (BLOCK_M1, )

        # key: (BLOCK_N1, BLOCK_DMODEL)
        qkT = tl.dot(key, qT)  # (BLOCK_N1, BLOCK_M1)
        m = tl.broadcast_to(m[None, :], (BLOCK_N1, BLOCK_M1))  # (BLOCK_N1, BLOCK_M1)
        m = tl.where(offs_n_mask[:, None], m, float("inf"))  # (BLOCK_N1, BLOCK_M1)
        pT = tl.math.exp2(qkT - m)

        mask = (offs_m < Q_CTX)[None, :] & (offs_n < KV_CTX)[
            :, None
        ]  # (BLOCK_N1, BLOCK_M1)
        # Autoregressive masking.
        if MASK:
            mask &= offs_m[None, :] >= offs_n[:, None]
        pT = tl.where(mask, pT, 0.0)  # (BLOCK_N1, BLOCK_M1)

        # Masked load: the last query block can extend past Q_CTX (e.g. when
        # Q_CTX is not a multiple of BLOCK_N1). An unmasked load would read out
        # of bounds, and any NaN/Inf garbage would leak into dv via 0 * NaN.
        do = tl.load(
            do_ptrs, mask=offs_m_mask[:, None], other=0.0
        )  # (BLOCK_M1, BLOCK_DMODEL)

        # Compute dV.
        dv += tl.dot(pT, do.to(tl.float32))  # (BLOCK_N1, BLOCK_DMODEL)
        # D (= delta) is pre-divided by ds_scale.
        Di = tl.load(D + offs_m, mask=offs_m_mask, other=0.0)  # (BLOCK_M1, )

        # Compute dP and dS.
        dpT = tl.dot(value, tl.trans(do)).to(
            tl.float32
        )  # (BLOCK_N1, BLOCK_DMODEL) @ (BLOCK_M1, BLOCK_DMODEL).T -> (BLOCK_N1, BLOCK_M1)
        dsT = pT * (dpT - Di[None, :])  # (BLOCK_N1, BLOCK_M1)
        dsT = dsT.to(qT.dtype)
        qT = tl.where(offs_m_mask[None, :], qT, 0.0)  # (BLOCK_DMODEL, BLOCK_M1)
        dsT = tl.where(
            offs_m_mask[None, :] & offs_n_mask[:, None], dsT, 0.0
        )  # (BLOCK_N1, BLOCK_M1)
        dk += tl.dot(
            dsT, tl.trans(qT)
        )  # (BLOCK_N1, BLOCK_M1) @ (BLOCK_DMODEL, BLOCK_M1).T -> (BLOCK_N1, BLOCK_DMODEL)
        # Increment pointers.
        curr_m += step_m
    return dk, dv
# the main inner-loop logic for computing dQ
@triton.jit
def _attn_bwd_dq(
    dq,
    query,
    K,
    V,  #
    do,
    m,
    D,
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,  #
    H,
    Q_CTX,  #
    KV_CTX,  #
    BLOCK_M2: tl.constexpr,  #
    BLOCK_N2: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,
    # Filled in by the wrapper.
    start_m,
    start_n,
    num_steps,  #
    MASK: tl.constexpr,
):
    """Accumulate dQ for one query tile over ``num_steps`` key blocks.

    ``query``, ``do`` and ``m`` are the already-loaded tiles for query rows
    starting at ``start_m`` (``m`` with shape (BLOCK_M2, 1)). Key/value rows
    are visited in blocks of BLOCK_N2 starting at ``start_n``. With MASK set,
    only pairs with query_row >= key_row contribute (causal diagonal).

    Returns the updated float32 ``dq`` accumulator (still to be rescaled by
    the caller).
    """
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_m_mask = offs_m < Q_CTX

    offs_k = tl.arange(0, BLOCK_DMODEL)
    # D (= delta) is pre-divided by ds_scale.
    Di = tl.load(D + offs_m, mask=offs_m_mask, other=0.0)
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
    curr_n = start_n
    step_n = BLOCK_N2
    for blk_idx in range(num_steps):
        offs_n = curr_n + tl.arange(0, BLOCK_N2)
        offs_n_mask = offs_n < KV_CTX

        # K and V are both read transposed: (BLOCK_DMODEL, BLOCK_N2).
        kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
        vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d

        kT = tl.load(kT_ptrs, mask=offs_n_mask[None, :], other=0.0)
        vT = tl.load(vT_ptrs, mask=offs_n_mask[None, :], other=0.0)
        qk = tl.dot(query, kT)
        # Recompute the softmax probabilities from the stored row statistic.
        p = tl.math.exp2(qk - m)
        mask = (offs_m < Q_CTX)[:, None] & (offs_n < KV_CTX)[None, :]
        # Autoregressive masking.
        if MASK:
            mask &= offs_m[:, None] >= offs_n[None, :]
        p = tl.where(mask, p, 0.0)
        # Compute dP and dS.
        dp = tl.dot(do, vT).to(tl.float32)
        ds = p * (dp - Di[:, None])
        ds = tl.where(mask, ds, 0.0).to(kT.dtype)
        # Compute dQ.
        # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
        dq += tl.dot(ds, tl.trans(kT))
        # Increment pointers.
        curr_n += step_n
    return dq
# Autotuning candidates for _attn_bwd. The wrapper does not pass
# BLOCK_M1/BLOCK_N1/BLOCK_M2/BLOCK_N2, so each config here must supply them.
config_backward = runtime.get_tuned_config("attention_bwd")
@libentry()
@libtuner(
    configs=config_backward,
    key=["KV_CTX", "BLOCK_DMODEL"],
)
@triton.jit
def _attn_bwd(
    Q,
    K,
    V,
    sm_scale,  #
    DO,  #
    DQ,
    DK,
    DV,  #
    M,
    D,
    # shared by Q/K/V/DO.
    stride_z,
    stride_h,
    stride_tok,
    stride_d,  #
    kv_stride_z,
    kv_stride_h,  #
    H,  # query head num
    Q_CTX,  #
    KV_CTX,  #
    kv_head_num,  #
    GROUP_HEAD: tl.constexpr,  #
    BLOCK_M1: tl.constexpr,  #
    BLOCK_N1: tl.constexpr,  #
    BLOCK_M2: tl.constexpr,  #
    BLOCK_N2: tl.constexpr,  #
    BLK_SLICE_FACTOR: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,
):
    """Fused attention backward: dK/dV for one key tile, dQ for one query tile.

    program_id(2) selects the flattened (batch, query head) slice, and
    program_id(0) selects key tile ``pid * BLOCK_N1`` and query tile
    ``pid * BLOCK_M2``.

    The loop structure is causal: dK/dV only see queries at or after the key
    tile, and dQ only sees keys up to the end of the diagonal block.

    NOTE(review): DQ, DK and DV are all offset with the *query* batch/head
    strides (``adj``). DK/DV therefore must share Q's layout; in practice that
    means Q_CTX == KV_CTX and contiguous tensors.

    K is expected to be pre-multiplied by ``sm_scale / ln(2)`` by the caller,
    which is why dQ is multiplied by LN2 before being stored.
    """
    tl.device_assert(Q_CTX % BLOCK_M1 == 0, "Q_CTX must be a multiple of BLOCK_M1.")

    LN2: tl.constexpr = 0.6931471824645996  # = ln(2)

    bhid = tl.program_id(2)
    off_chz = (bhid * Q_CTX).to(tl.int64)
    batch_id = bhid // H
    q_head_id = bhid % H
    # GQA: GROUP_HEAD consecutive query heads share one key/value head.
    kv_head_id = q_head_id // GROUP_HEAD
    adj = (stride_h * q_head_id + stride_z * batch_id).to(tl.int64)
    kv_adj = (kv_stride_h * kv_head_id + kv_stride_z * batch_id).to(tl.int64)

    pid = tl.program_id(0)

    # offset pointers for batch/head
    Q += adj
    K += kv_adj
    V += kv_adj
    DO += adj
    DQ += adj
    DK += adj
    DV += adj
    M += off_chz
    D += off_chz

    # Head-dimension offsets shared by every tile below.
    offs_k = tl.arange(0, BLOCK_DMODEL)

    start_n = pid * BLOCK_N1
    start_m = start_n

    # The masked (diagonal) part is processed with finer query blocks.
    MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_n_mask = offs_n < KV_CTX

    dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)

    # load K and V: they stay in SRAM throughout the inner loop.
    key = tl.load(
        K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_n_mask[:, None],
        other=0.0,
    )
    value = tl.load(
        V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_n_mask[:, None],
        other=0.0,
    )

    # Diagonal part: query rows [start_n, start_n + BLOCK_N1) with causal mask.
    num_steps = BLOCK_N1 // MASK_BLOCK_M1

    dk, dv = _attn_bwd_dkdv(
        dk,
        dv,  #
        Q,
        key,
        value,
        sm_scale,  #
        DO,  #
        M,
        D,  #
        stride_tok,
        stride_d,  #
        H,
        Q_CTX,  #
        KV_CTX,  #
        MASK_BLOCK_M1,
        BLOCK_N1,
        BLOCK_DMODEL,  #
        start_n,
        start_m,
        num_steps,  #
        MASK=True,  #
    )

    # Compute dK and dV for non-masked blocks.
    # Remaining query rows after the diagonal, up to Q_CTX, without the mask.
    start_m += num_steps * MASK_BLOCK_M1
    remaining_m = Q_CTX - start_m
    num_steps = (remaining_m + BLOCK_M1 - 1) // BLOCK_M1

    if num_steps > 0 and start_m < Q_CTX:
        dk, dv = _attn_bwd_dkdv(  #
            dk,
            dv,  #
            Q,
            key,
            value,
            sm_scale,  #
            DO,  #
            M,
            D,  #
            stride_tok,
            stride_d,  #
            H,
            Q_CTX,  #
            KV_CTX,  #
            BLOCK_M1,
            BLOCK_N1,
            BLOCK_DMODEL,  #
            start_n,
            start_m,
            num_steps,  #
            MASK=False,  #
        )

    dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dv_ptrs, dv, mask=offs_n_mask[:, None])

    # Write back dK.
    # Chain rule through the sm_scale that multiplies Q @ K^T in the forward.
    dk *= sm_scale
    dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dk_ptrs, dk, mask=offs_n_mask[:, None])

    # THIS BLOCK DOES DQ:
    MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
    start_m = pid * BLOCK_M2
    # NOTE(review): the diagonal split below uses start_n (= pid * BLOCK_N1) as
    # the first key of this query tile's diagonal, which presumably requires
    # BLOCK_M2 == BLOCK_N1 in the tuned configs — confirm.
    end_n = min(start_m + BLOCK_M2, KV_CTX)  # Ensure end_n does not exceed N_CTX
    num_steps = (end_n - start_n + MASK_BLOCK_N2 - 1) // MASK_BLOCK_N2

    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_m_mask = offs_m < Q_CTX

    query = tl.load(
        Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_m_mask[:, None],
        other=0.0,
    )
    dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
    do = tl.load(
        DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_m_mask[:, None],
        other=0.0,
    )

    # other=+inf makes the recomputed probabilities zero for rows past Q_CTX.
    m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf"))
    m = m[:, None]

    # Stage 1 - Compute dQ for masked (diagonal) blocks.
    # NOTE: This code scans each row of QK^T backward (from right to left,
    # but inside each call to _attn_bwd_dq, from left to right), but that's
    # not due to anything important. I just wanted to reuse the loop
    # structure for dK & dV above as much as possible.

    if num_steps > 0:
        dq = _attn_bwd_dq(
            dq,
            query,
            K,
            V,  #
            do,
            m,
            D,  #
            stride_tok,
            stride_d,  #
            H,
            Q_CTX,  #
            KV_CTX,  #
            BLOCK_M2,
            MASK_BLOCK_N2,
            BLOCK_DMODEL,  #
            start_m,
            start_n,
            num_steps,  #
            MASK=True,  #
        )

    # Stage 2 - non-masked blocks
    # Keys [0, start_n): strictly before the diagonal, no causal mask needed.
    stage2_end_n = start_n
    stage2_num_steps = (stage2_end_n + BLOCK_N2 - 1) // BLOCK_N2

    if stage2_num_steps > 0:
        dq = _attn_bwd_dq(
            dq,
            query,
            K,
            V,  #
            do,
            m,
            D,  #
            stride_tok,
            stride_d,  #
            H,
            Q_CTX,  #
            KV_CTX,  #
            BLOCK_M2,
            BLOCK_N2,
            BLOCK_DMODEL,  #
            start_m,
            stage2_end_n - stage2_num_steps * BLOCK_N2,
            stage2_num_steps,  #
            MASK=False,  #
        )
    # Write back dQ.
    dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    # Undo the 1/ln(2) folded into the pre-scaled K.
    dq *= LN2

    tl.store(dq_ptrs, dq, mask=offs_m_mask[:, None])
def scaled_dot_product_attention_forward(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Launch the Triton SDPA forward kernel.

    Args:
        query: indexed as (batch, q_heads, q_len, head_dim) by the launch below.
        key, value: indexed as (batch, kv_heads, kv_len, head_dim).
        attn_mask: optional mask broadcastable to
            (batch, q_heads, q_len, kv_len). A boolean mask follows
            ``torch.nn.functional.scaled_dot_product_attention``: True means
            the position takes part in attention, and False positions get a
            large negative bias. A floating mask is added to the scaled scores.
        dropout_p: must be 0.0.
        is_causal: apply a causal mask (query row i sees key rows <= i).
        scale: softmax scale; defaults to ``1 / sqrt(head_dim)``.
        enable_gqa: allow q_heads != kv_heads (q_heads must be a multiple).

    Returns:
        ``(o, M)``. ``o`` has query's shape and value's dtype. ``M`` is a
        float32 (batch, q_heads, q_len) tensor holding the per-row base-2 log
        of the softmax denominator, which is needed by the backward pass.

    Raises:
        AssertionError: on unsupported head dims, dropout, or head counts.
        RuntimeError: if ``attn_mask`` cannot be broadcast to the score shape.
    """
    logger.debug("GEMS SCALED DOT PRODUCT ATTENTION FORWARD")
    # shape constraints
    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
    assert dropout_p == 0.0, "Currenty only support dropout_p=0.0"

    o = torch.empty_like(query, dtype=value.dtype)

    stage = 3 if is_causal else 1
    sm_scale = 1.0 / (HEAD_DIM_K**0.5) if scale is None else scale

    batch, q_head_num, q_len = query.shape[0], query.shape[1], query.shape[2]
    kv_head_num, kv_len = key.shape[1], key.shape[2]
    assert enable_gqa or q_head_num == kv_head_num, (
        f"q_head_num {q_head_num} != kv_head_num {kv_head_num}, "
        "enable_gqa must be True to support different head numbers."
    )
    # The kernel maps query head h to kv head h // (q_head_num // kv_head_num).
    assert q_head_num % kv_head_num == 0, (
        f"q_head_num {q_head_num} must be a multiple of kv_head_num {kv_head_num}."
    )

    grid = lambda args: (
        triton.cdiv(q_len, args["BLOCK_M"]),
        batch * q_head_num,
        1,
    )

    if attn_mask is not None:
        HAS_ATTN_MASK = True
        if attn_mask.dtype == torch.bool:
            # SDPA semantics: True = attend (bias 0), False = masked out.
            attn_mask = (~attn_mask).to(query.dtype) * -1.0e6
        # Broadcast to the full 4-D score shape. expand() gives broadcast dims a
        # zero stride (no copy), so the kernel's per-(batch, head) offsets stay
        # in bounds for masks such as (q_len, kv_len) or (batch, 1, q, kv).
        attn_mask = attn_mask.expand(batch, q_head_num, q_len, kv_len)
        (
            stride_attn_mask_batch,
            stride_attn_mask_head,
            stride_attn_mask_q_seqlen,
            stride_attn_mask_kv_seqlen,
        ) = attn_mask.stride()
    else:
        HAS_ATTN_MASK = False
        stride_attn_mask_batch = 1
        stride_attn_mask_head = 1
        stride_attn_mask_q_seqlen = 1
        stride_attn_mask_kv_seqlen = 1

    M = torch.empty(
        (batch, q_head_num, q_len),
        device=query.device,
        dtype=torch.float32,
    )

    with torch_device_fn.device(query.device):
        _attn_fwd[grid](
            query,
            key,
            value,
            attn_mask,
            sm_scale,
            M,
            o,  #
            query.stride(0),
            query.stride(1),
            query.stride(2),
            query.stride(3),  #
            key.stride(0),
            key.stride(1),
            key.stride(2),
            key.stride(3),  #
            value.stride(0),
            value.stride(1),
            value.stride(2),
            value.stride(3),  #
            stride_attn_mask_batch,
            stride_attn_mask_head,
            stride_attn_mask_q_seqlen,
            stride_attn_mask_kv_seqlen,  #
            o.stride(0),
            o.stride(1),
            o.stride(2),
            o.stride(3),  #
            batch,
            q_head_num,
            kv_head_num,  #
            q_head_num // kv_head_num,  # group_head
            q_len,  #
            kv_len,  #
            HEAD_DIM_K,  #
            STAGE=stage,  #
            HAS_ATTN_MASK=HAS_ATTN_MASK,  #
        )
    return o, M
def scaled_dot_product_attention_backward(
    do,
    query,
    key,
    value,
    o,
    M,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Compute (dq, dk, dv) for scaled-dot-product attention with Triton.

    Args:
        do: upstream gradient, same shape/strides as ``o``.
        query, key, value, o: the forward inputs/output; all must be
            contiguous and indexed as (batch, heads, seq, head_dim).
        M: per-row statistic returned by ``scaled_dot_product_attention_forward``.
        scale: softmax scale; defaults to ``1 / sqrt(head_dim)``.

    Returns:
        ``(dq, dk, dv)``. For GQA the per-query-head dk/dv are summed back to
        the kv head count.

    NOTE(review): ``_attn_bwd`` always uses the causal block split, and
    neither ``is_causal`` nor ``attn_mask`` is forwarded to it. Gradients
    presumably match PyTorch only for causal, unmasked attention — confirm
    before relying on the non-causal path.
    """
    logger.debug("GEMS SCALED DOT PRODUCT ATTENTION BACKWARD")
    # shape constraints
    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
    assert dropout_p == 0.0, "Currenty only support dropout_p=0.0"

    sm_scale = 1.0 / (HEAD_DIM_K**0.5) if scale is None else scale

    # The kernels share one set of strides between Q/O/dO and between K/V.
    assert do.is_contiguous()
    assert (
        query.is_contiguous()
        and key.is_contiguous()
        and value.is_contiguous()
        and o.is_contiguous()
    )
    assert query.stride() == o.stride() == do.stride()
    assert key.stride() == value.stride()

    BLOCK_DMODEL = HEAD_DIM_K
    BATCH, Q_HEAD, Q_CTX = query.shape[:3]
    _, KV_HEAD, KV_CTX = key.shape[:3]
    assert Q_HEAD % KV_HEAD == 0, (
        f"q_head_num {Q_HEAD} must be a multiple of kv_head_num {KV_HEAD}."
    )
    # _attn_bwd offsets DK/DV with the *query* batch/head strides and launches
    # one program per BLOCK_N1 query rows, so dk/dv (KV_CTX rows per head) are
    # only addressed correctly when both lengths match; otherwise it would
    # write to the wrong (or out-of-bounds) locations.
    assert Q_CTX == KV_CTX, (
        f"backward requires q_len == kv_len, got {Q_CTX} and {KV_CTX}."
    )
    group_head = Q_HEAD // KV_HEAD

    BLK_SLICE_FACTOR = 2
    RCP_LN2 = 1.0 / math.log(2)

    # Pre-scale K by sm_scale / ln(2) so the kernel can use exp2 directly;
    # the kernel multiplies dq by ln(2) to undo the 1/ln(2) factor.
    arg_k = key * (sm_scale * RCP_LN2)
    PRE_BLOCK = 256
    pre_grid = (triton.cdiv(Q_CTX, PRE_BLOCK), BATCH * Q_HEAD)

    delta = torch.empty_like(M)

    # NOTE that dk & dv always have the same number of heads as q
    dq = torch.empty_like(query).contiguous()
    dk = torch.empty(
        (BATCH, Q_HEAD, KV_CTX, HEAD_DIM_K),
        device=key.device,
        dtype=key.dtype,
        memory_format=torch.contiguous_format,
    )
    dv = torch.empty(
        (BATCH, Q_HEAD, KV_CTX, HEAD_DIM_V),
        device=value.device,
        dtype=value.dtype,
        memory_format=torch.contiguous_format,
    )

    # Size the grid with the BLOCK_N1 of the config actually selected by the
    # autotuner. Sizing it with the max BLOCK_N1 over all configs would leave
    # key tiles without a program (and dk/dv uninitialized) whenever a
    # smaller-BLOCK_N1 config wins.
    grid = lambda meta: (triton.cdiv(Q_CTX, meta["BLOCK_N1"]), 1, BATCH * Q_HEAD)

    with torch_device_fn.device(query.device):
        _attn_bwd_preprocess[pre_grid](
            o,
            do,  #
            delta,  #
            BATCH,
            Q_HEAD,
            Q_CTX,  #
            BLOCK_M=PRE_BLOCK,
            D_HEAD=BLOCK_DMODEL,  #
        )

        _attn_bwd[grid](
            query,
            arg_k,
            value,
            sm_scale,
            do,
            dq,
            dk,
            dv,  #
            M,
            delta,  #
            query.stride(0),
            query.stride(1),
            query.stride(2),
            query.stride(3),  #
            key.stride(0),
            key.stride(1),  #
            Q_HEAD,
            Q_CTX,  #
            KV_CTX,  #
            KV_HEAD,  #
            GROUP_HEAD=group_head,  #
            BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
            BLOCK_DMODEL=BLOCK_DMODEL,  #
        )

    if group_head > 1:
        # Query heads sharing a kv head are contiguous; sum their grads.
        dk = dk.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_K)
        dv = dv.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_V)
        dk = dk.sum(dim=2)
        dv = dv.sum(dim=2)

    return dq, dk, dv
class ScaleDotProductAttention(torch.autograd.Function):
    """Autograd bridge between the functional SDPA API and the Triton kernels.

    The kernels only support a fixed set of head dims. Any other head dim is
    zero-padded up to the next supported size in ``forward``, and the padding
    is stripped from the output and, in ``backward``, from the gradients.
    """

    # Ascending, so the first entry >= head_size is the smallest legal padding.
    SUPPORTED_HEAD_DIMS = (16, 32, 64, 128, 256)

    @staticmethod
    def forward(
        ctx,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=None,
        enable_gqa=False,
    ):
        head_size = key.shape[-1]
        # The default scale follows the *logical* head size; zero padding does
        # not change q·k and must not change the softmax temperature either.
        sm_scale = scale if scale is not None else 1.0 / (head_size**0.5)

        # [sunrise fix] padding for unsupported head dims, since auto lower unsupported.
        supported = ScaleDotProductAttention.SUPPORTED_HEAD_DIMS
        padded_head_dim = head_size
        if head_size not in supported:
            padded_head_dim = next((d for d in supported if d >= head_size), None)
            assert (
                padded_head_dim is not None
            ), f"Unsupported head dim {head_size}, max supported is {supported[-1]}"
            pad = padded_head_dim - head_size
            query = F.pad(query, (0, pad))
            key = F.pad(key, (0, pad))
            value = F.pad(value, (0, pad))

        o, M = scaled_dot_product_attention_forward(
            query,
            key,
            value,
            attn_mask,
            dropout_p,
            is_causal,
            sm_scale,
            enable_gqa,
        )
        if padded_head_dim != head_size:
            o = o[..., :head_size]

        ctx.save_for_backward(query, key, value, o, M)
        ctx.sm_scale = sm_scale
        ctx.causal = is_causal
        ctx.enable_gqa = enable_gqa
        ctx.head_size = head_size
        return o

    @staticmethod
    def backward(ctx, do):
        query, key, value, o, M = ctx.saved_tensors
        head_size = ctx.head_size
        # query/key/value were saved already padded; o and do were not.
        pad = query.shape[-1] - head_size
        if pad > 0:
            # The padded value columns are zero, so the kernel's output was zero
            # there as well; zero-padding o and do reproduces the full-width
            # tensors the backward kernels expect.
            do = F.pad(do, (0, pad))
            o = F.pad(o, (0, pad))

        # The backward kernels require dense, identically-strided inputs.
        dq, dk, dv = scaled_dot_product_attention_backward(
            do.contiguous(),
            query.contiguous(),
            key.contiguous(),
            value.contiguous(),
            o.contiguous(),
            M,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=ctx.causal,
            scale=ctx.sm_scale,
            enable_gqa=ctx.enable_gqa,
        )
        if pad > 0:
            dq = dq[..., :head_size]
            dk = dk[..., :head_size]
            dv = dv[..., :head_size]
        # One gradient per forward input; the non-tensor options get None.
        return dq, dk, dv, None, None, None, None, None
def scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Differentiable scaled-dot-product attention backed by the Triton kernels.

    Thin functional front-end over ``ScaleDotProductAttention``. Every
    argument is forwarded positionally, in the order its ``forward`` declares.
    """
    forwarded = (
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        enable_gqa,
    )
    return ScaleDotProductAttention.apply(*forwarded)
def flash_attention_forward(
    query,
    key,
    value,
    cumulative_sequence_length_q,
    cumulative_sequence_length_k,
    max_q,
    max_k,
    dropout_p,
    is_causal,
    return_debug_mask,
    *,
    scale=None,
    softcap=0.0,
    window_size_left=None,
    window_size_right=None,
    seqused_k=None,
    alibi_slopes=None,
    disable_splitkv=False,
):
    """Dense (non-varlen) flash-attention forward built on ``mha_fwd``.

    Head dims outside ``supported_head_dims`` are zero-padded to the next
    supported size. The default softmax scale is derived from the original
    head dim, and the output is sliced back to it.

    Returns:
        ``(out, lse, philox_seed, philox_offset, p)`` as produced by the
        ``mha_fwd`` call, with ``out`` trimmed to the original head dim.

    Raises:
        AssertionError: for varlen inputs, mismatched head dims, or head dims
            larger than the largest supported one.
    """
    logger.debug("GEMS FLASH_ATTENTION_FORWARD")
    assert (
        cumulative_sequence_length_q is None and cumulative_sequence_length_k is None
    ), "varlen is not supported yet."

    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    original_head_dim = HEAD_DIM_K
    # Ascending, so the first entry >= HEAD_DIM_K is the smallest padding.
    supported_head_dims = (16, 32, 64, 96, 128, 192, 256)
    if HEAD_DIM_K not in supported_head_dims:
        padded_head_dim = next(
            (d for d in supported_head_dims if d >= HEAD_DIM_K), None
        )
        assert (
            padded_head_dim is not None
        ), f"Unsupported head dim {HEAD_DIM_K}, max supported is {supported_head_dims[-1]}"
        pad = padded_head_dim - HEAD_DIM_K
        query = F.pad(query, (0, pad))
        key = F.pad(key, (0, pad))
        value = F.pad(value, (0, pad))
        HEAD_DIM_K = padded_head_dim

    # Explicit None check: an explicit scale of 0.0 must not fall back to the
    # default (``scale or ...`` would treat it as unset).
    softmax_scale = (
        scale if scale is not None else 1.0 / (original_head_dim**0.5)
    )
    non_null_window_left = -1 if window_size_left is None else window_size_left
    non_null_window_right = -1 if window_size_right is None else window_size_right

    out = torch.empty_like(query)
    if cumulative_sequence_length_q is not None:
        # Currently unreachable because of the varlen assert above; kept for
        # when varlen support lands. Uses the same computed scale as mha_fwd.
        out, q, k, v, lse, philox_seed, philox_offset, p = mha_varlan_fwd(
            query,
            key,
            value,
            out,
            cumulative_sequence_length_q,
            cumulative_sequence_length_k,
            seqused_k,
            None,
            None,  # block_table
            alibi_slopes,
            max_q,
            max_k,
            dropout_p,
            softmax_scale,
            False,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask and dropout_p > 0,
            None,
        )
    else:
        out, q, k, v, lse, philox_seed, philox_offset, p = mha_fwd(
            query,
            key,
            value,
            out,
            alibi_slopes,
            dropout_p,
            softmax_scale,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask,
            disable_splitkv=disable_splitkv,
        )

    if HEAD_DIM_K != original_head_dim:
        out = out[..., :original_head_dim]
    return (out, lse, philox_seed, philox_offset, p)
# Adapted from https://github.com/vllm-project/flash-attention/blob/main/vllm_flash_attn/flash_attn_interface.py
def maybe_contiguous(x):
    """Return ``x`` with a unit stride in its last dimension.

    ``None`` and tensors whose last dimension is already densely packed are
    returned unchanged (same object); anything else is copied with
    ``.contiguous()``.
    """
    if x is None or x.stride(-1) == 1:
        return x
    return x.contiguous()
1207def flash_attn_varlen_func(
1208 q,
1209 k,
1210 v,
1211 max_seqlen_q,
1212 cu_seqlens_q,
1213 max_seqlen_k,
1214 cu_seqlens_k=None, # only used for non-paged prefill
1215 seqused_k=None,
1216 q_v=None,
1217 dropout_p=0.0,
1218 softmax_scale=None,
1219 causal=False,
1220 window_size=None,
1221 softcap=0.0, # 0.0 means deactivated
1222 alibi_slopes=None,
1223 deterministic=False,
1224 return_attn_probs=False,
1225 block_table=None,
1226 return_softmax_lse=False,
1227 out=None,
1228 # Dummy FA3 arguments
1229 scheduler_metadata=None,
1230 q_descale=None,
1231 k_descale=None,
1232 v_descale=None,
1233 s_aux=None,
1234 num_splits: int = 0,
1235 cp_world_size: int = 1,
1236 cp_rank: int = 0,
1237 cp_tot_seqused_k=None,
1238 fa_version: int = 2,
1239):
1240 """dropout_p should be set to 0.0 during evaluation
1241 Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
1242 than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
1243 For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
1244 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
1246 If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
1247 For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1248 1 1 1 1 0
1249 1 1 1 1 1
1250 If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
1251 0 0
1252 0 0
1253 0 0
1254 1 0
1255 1 1
1256 If the row of the mask is all zero, the output will be zero.
1258 If window_size != (-1, -1), implements sliding window local attention. Query at position i
1259 will only attend to keys between
1260 [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
1262 Arguments:
1263 q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
1264 k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
1265 v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
1266 cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1267 of the sequences in the batch, used to index into q.
1268 cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1269 of the sequences in the batch, used to index into kv.
1270 max_seqlen_q: int. Maximum query sequence length in the batch.
1271 max_seqlen_k: int. Maximum key sequence length in the batch.
1272 dropout_p: float. Dropout probability.
1273 softmax_scale: float. The scaling of QK^T before applying softmax.
1274 Default to 1 / sqrt(headdim).
1275 causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
1276 window_size: (left, right). If not (-1, -1), implements sliding window local attention.
1277 softcap: float. Anything > 0 activates softcapping attention.
1278 alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
1279 (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
1280 is added to the attention score of query i and key j.
1281 deterministic: bool. Whether to use the deterministic implementation of the backward pass,
1282 which is slightly slower and uses more memory. The forward pass is always deterministic.
1283 return_attn_probs: bool. Whether to return the attention probabilities. This option is for
1284 testing only. The returned probabilities are not guaranteed to be correct
1285 (they might not have the right scaling).
1286 Return:
1287 out: (total, nheads, headdim).
1288 softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
1289 logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
1290 normalization factor).
1291 """
1292 import os
1294 os.environ["OFF_ASYNC"] = "1"
1295 if fa_version != 2:
1296 raise RuntimeError("Only FA2 is implemented.")
1297 if num_splits > 0:
1298 raise RuntimeError("num_splits > 0 is not implemented in GEMS.")
1299 if use_c_extension:
1300 logger.debug("GEMS FLASH_ATTN_VARLEN_FUNC(C EXTENSION)")
1301 with torch_device_fn.device(q.device):
1302 out_cpp, softmax_lse = torch.ops.flag_gems.flash_attn_varlen_func(
1303 q,
1304 k,
1305 v,
1306 max_seqlen_q,
1307 cu_seqlens_q,
1308 max_seqlen_k,
1309 cu_seqlens_k,
1310 seqused_k,
1311 q_v,
1312 dropout_p,
1313 softmax_scale,
1314 causal,
1315 window_size,
1316 softcap,
1317 alibi_slopes,
1318 deterministic,
1319 return_attn_probs,
1320 block_table,
1321 return_softmax_lse,
1322 out,
1323 scheduler_metadata,
1324 q_descale,
1325 k_descale,
1326 v_descale,
1327 s_aux,
1328 num_splits,
1329 cp_world_size,
1330 cp_rank,
1331 cp_tot_seqused_k,
1332 fa_version,
1333 )
1334 return (out_cpp, softmax_lse) if return_softmax_lse else out_cpp
1335 else:
1336 logger.debug("GEMS FLASH_ATTN_VARLEN_FUNC")
1337 assert (
1338 cu_seqlens_k is not None or seqused_k is not None
1339 ), "cu_seqlens_k or seqused_k must be provided"
1340 assert (
1341 cu_seqlens_k is None or seqused_k is None
1342 ), "cu_seqlens_k and seqused_k cannot be provided at the same time"
1343 assert (
1344 block_table is None or seqused_k is not None
1345 ), "seqused_k must be provided if block_table is provided"
1346 if softmax_scale is None:
1347 softmax_scale = q.shape[-1] ** (-0.5)
1348 # custom op does not support non-tuple input
1349 if window_size is None:
1350 real_window_size = (-1, -1)
1351 else:
1352 assert len(window_size) == 2
1353 real_window_size = (window_size[0], window_size[1])
1354 q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
1355 dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q)
1356 max_seqlen_q = (
1357 max_seqlen_q.item() if hasattr(max_seqlen_q, "item") else max_seqlen_q
1358 )
1359 max_seqlen_k = (
1360 max_seqlen_k.item() if hasattr(max_seqlen_k, "item") else max_seqlen_k
1361 )
1362 out, q, k, v, softmax_lse, *_ = mha_varlan_fwd(
1363 q,
1364 k,
1365 v,
1366 out,
1367 cu_seqlens_q,
1368 # cu_seqlens_k not used since we use seqused_k, but flash_api.cpp
1369 # still wants it so we pass all zeros
1370 dummy_cu_seqlens_k if cu_seqlens_k is None else cu_seqlens_k,
1371 seqused_k,
1372 None,
1373 block_table,
1374 alibi_slopes,
1375 max_seqlen_q,
1376 max_seqlen_k,
1377 dropout_p,
1378 softmax_scale,
1379 False,
1380 causal,
1381 real_window_size[0],
1382 real_window_size[1],
1383 softcap,
1384 return_softmax_lse and dropout_p > 0,
1385 None,
1386 )
1388 return (out, softmax_lse) if return_softmax_lse else out
1391def flash_attn_varlen_opt_func(
1392 q,
1393 k,
1394 v,
1395 max_seqlen_q,
1396 cu_seqlens_q,
1397 max_seqlen_k,
1398 cu_seqlens_k=None, # only used for non-paged prefill
1399 seqused_k=None,
1400 q_v=None,
1401 dropout_p=0.0,
1402 softmax_scale=None,
1403 causal=False,
1404 window_size=None,
1405 softcap=0.0, # 0.0 means deactivated
1406 alibi_slopes=None,
1407 deterministic=False,
1408 return_attn_probs=False,
1409 block_table=None,
1410 return_softmax_lse=False,
1411 out=None,
1412 lse=None,
1413 # Dummy FA3 arguments
1414 scheduler_metadata=None,
1415 q_descale=None,
1416 k_descale=None,
1417 v_descale=None,
1418 s_aux=None,
1419 num_splits: int = 0,
1420 cp_world_size: int = 1,
1421 cp_rank: int = 0,
1422 cp_tot_seqused_k=None,
1423 fa_version: int = 2,
1424):
1425 """dropout_p should be set to 0.0 during evaluation
1426 Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
1427 than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
1428 For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
1429 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
1431 If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
1432 For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1433 1 1 1 1 0
1434 1 1 1 1 1
1435 If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
1436 0 0
1437 0 0
1438 0 0
1439 1 0
1440 1 1
1441 If the row of the mask is all zero, the output will be zero.
1443 If window_size != (-1, -1), implements sliding window local attention. Query at position i
1444 will only attend to keys between
1445 [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
1447 Arguments:
1448 q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
1449 k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
1450 v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
1451 cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1452 of the sequences in the batch, used to index into q.
1453 cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1454 of the sequences in the batch, used to index into kv.
1455 max_seqlen_q: int. Maximum query sequence length in the batch.
1456 max_seqlen_k: int. Maximum key sequence length in the batch.
1457 dropout_p: float. Dropout probability.
1458 softmax_scale: float. The scaling of QK^T before applying softmax.
1459 Default to 1 / sqrt(headdim).
1460 causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
1461 window_size: (left, right). If not (-1, -1), implements sliding window local attention.
1462 softcap: float. Anything > 0 activates softcapping attention.
1463 alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
1464 (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
1465 is added to the attention score of query i and key j.
1466 deterministic: bool. Whether to use the deterministic implementation of the backward pass,
1467 which is slightly slower and uses more memory. The forward pass is always deterministic.
1468 return_attn_probs: bool. Whether to return the attention probabilities. This option is for
1469 testing only. The returned probabilities are not guaranteed to be correct
1470 (they might not have the right scaling).
1471 Return:
1472 out: (total, nheads, headdim).
1473 softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
1474 logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
1475 normalization factor).
1476 """
1477 if fa_version != 2:
1478 raise RuntimeError("Only FA2 is implemented.")
1479 if num_splits > 0:
1480 raise RuntimeError("num_splits > 0 is not implemented in GEMS.")
1481 if use_c_extension:
1482 logger.debug("GEMS FLASH_ATTN_VARLEN_FUNC(C EXTENSION)")
1483 with torch_device_fn.device(q.device):
1484 out_cpp, softmax_lse = torch.ops.flag_gems.flash_attn_varlen_func(
1485 q,
1486 k,
1487 v,
1488 max_seqlen_q,
1489 cu_seqlens_q,
1490 max_seqlen_k,
1491 cu_seqlens_k,
1492 seqused_k,
1493 q_v,
1494 dropout_p,
1495 softmax_scale,
1496 causal,
1497 window_size,
1498 softcap,
1499 alibi_slopes,
1500 deterministic,
1501 return_attn_probs,
1502 block_table,
1503 return_softmax_lse,
1504 out,
1505 scheduler_metadata,
1506 q_descale,
1507 k_descale,
1508 v_descale,
1509 s_aux,
1510 num_splits,
1511 cp_world_size,
1512 cp_rank,
1513 cp_tot_seqused_k,
1514 fa_version,
1515 )
1516 return (out_cpp, softmax_lse) if return_softmax_lse else out_cpp
1517 else:
1518 logger.debug("GEMS FLASH_ATTN_VARLEN_OPT_FUNC")
1519 assert (
1520 cu_seqlens_k is not None or seqused_k is not None
1521 ), "cu_seqlens_k or seqused_k must be provided"
1522 assert (
1523 cu_seqlens_k is None or seqused_k is None
1524 ), "cu_seqlens_k and seqused_k cannot be provided at the same time"
1525 assert (
1526 block_table is None or seqused_k is not None
1527 ), "seqused_k must be provided if block_table is provided"
1528 if softmax_scale is None:
1529 softmax_scale = q.shape[-1] ** (-0.5)
1530 # custom op does not support non-tuple input
1531 if window_size is None:
1532 real_window_size = (-1, -1)
1533 else:
1534 assert len(window_size) == 2
1535 real_window_size = (window_size[0], window_size[1])
1536 # q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
1537 # dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q)
1538 max_seqlen_q = (
1539 max_seqlen_q.item() if hasattr(max_seqlen_q, "item") else max_seqlen_q
1540 )
1541 max_seqlen_k = (
1542 max_seqlen_k.item() if hasattr(max_seqlen_k, "item") else max_seqlen_k
1543 )
1544 out, q, k, v, softmax_lse, *_ = mha_varlan_fwd_opt(
1545 q,
1546 k,
1547 v,
1548 out,
1549 lse,
1550 cu_seqlens_q,
1551 # cu_seqlens_k not used since we use seqused_k, but flash_api.cpp
1552 # still wants it so we pass all zeros
1553 # dummy_cu_seqlens_k if cu_seqlens_k is None else cu_seqlens_k,
1554 cu_seqlens_q if cu_seqlens_k is None else cu_seqlens_k,
1555 seqused_k,
1556 None,
1557 block_table,
1558 alibi_slopes,
1559 max_seqlen_q,
1560 max_seqlen_k,
1561 dropout_p,
1562 softmax_scale,
1563 False,
1564 causal,
1565 real_window_size[0],
1566 real_window_size[1],
1567 softcap,
1568 return_softmax_lse and dropout_p > 0,
1569 None,
1570 )
1572 return (out, softmax_lse) if return_softmax_lse else out