Coverage for src/flag_gems/runtime/backend/_sunrise/ops/attention.py: 0%
416 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
2import math
3from functools import partial
5import torch
6import torch.nn.functional as F
7import triton
8import triton.language as tl
10from flag_gems import runtime
11from flag_gems.config import use_c_extension
12from flag_gems.runtime import torch_device_fn
13from flag_gems.runtime.backend._sunrise.ops.flash_api import mha_fwd, mha_varlan_fwd
14from flag_gems.runtime.backend._sunrise.ops.flash_kernel import keep
15from flag_gems.utils import libentry, libtuner
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
# Modified from Triton tutorial: https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
#
# Inner loop of the forward kernel. Streams BLOCK_N-wide tiles of K/V (and of
# the optional additive attention mask) over the key range chosen by STAGE and
# folds them into the online-softmax state:
#   acc: un-normalized output accumulator (fp32, BLOCK_M x HEAD_DIM)
#   l_i: running row sum of exp2(scores - m_i)
#   m_i: running row max of the log2-domain scores
# STAGE: 1 = key blocks strictly before the diagonal block,
#        2 = the diagonal block itself (causal masking applied),
#        3 = every key (non-causal).
# All scores are kept in the log2 domain (multiplied by LOG2E) so exp2 can be
# used; the additive mask is converted to the same domain in every stage.
@triton.jit
def _attn_fwd_inner(
    acc,
    l_i,
    m_i,
    query,  #
    K_block_ptr,
    V_block_ptr,  #
    mask_block_ptr,  #
    stride_k_seqlen,
    stride_v_seqlen,
    stride_attn_mask_kv_seqlen,  #
    start_m,
    qk_scale,  #
    q_load_mask,
    BLOCK_M: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_N: tl.constexpr,  #
    STAGE: tl.constexpr,
    offs_m: tl.constexpr,
    offs_n: tl.constexpr,  #
    KV_CTX: tl.constexpr,
    fp8_v: tl.constexpr,
    HAS_ATTN_MASK: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
):
    # range of key positions handled by this stage
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
    # causal = False
    else:
        lo, hi = 0, KV_CTX

    K_block_ptr += lo * stride_k_seqlen
    V_block_ptr += lo * stride_v_seqlen
    if HAS_ATTN_MASK:
        mask_block_ptr += lo * stride_attn_mask_kv_seqlen

    LOG2E = 1.44269504  # log2(e) constant

    # loop over key, value and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        kv_load_mask = (start_n + offs_n) < KV_CTX
        # -- compute qk ----
        key = tl.load(K_block_ptr, mask=kv_load_mask[None, :], other=0.0)
        if PRE_LOAD_V:
            value = tl.load(V_block_ptr, mask=kv_load_mask[:, None], other=0.0)

        qk = tl.dot(query, key, allow_tf32=False)
        # Keys past KV_CTX (ragged last tile) must not receive any probability.
        qk = tl.where(kv_load_mask[None, :], qk, -float("inf"))

        if HAS_ATTN_MASK:
            attn_mask = tl.load(
                mask_block_ptr,
                mask=q_load_mask[:, None] & kv_load_mask[None, :],
                other=0.0,
            )

        if STAGE == 2:
            # Causal mask on the diagonal block: query index >= key index.
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])

            if HAS_ATTN_MASK:
                qk = qk * qk_scale + attn_mask
                qk *= LOG2E
                qk = qk + tl.where(mask, 0, -1.0e6)
            else:
                qk = qk * qk_scale * LOG2E + tl.where(mask, 0, -1.0e6)

            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
        else:
            qk *= qk_scale * LOG2E
            if HAS_ATTN_MASK:
                # The additive mask is in natural-log units (like qk * qk_scale),
                # so it must be scaled by LOG2E too, exactly as in STAGE == 2.
                qk = qk + attn_mask * LOG2E
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk = qk - m_ij[:, None]

        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i: rescale the old state to the new row max
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        # -- update output accumulator --
        acc = acc * alpha[:, None]
        if not PRE_LOAD_V:
            value = tl.load(V_block_ptr, mask=kv_load_mask[:, None], other=0.0)
        if fp8_v:
            p = p.to(tl.float8e5)
        else:
            p = p.to(query.dtype)
        p = p.to(value.dtype)
        acc = tl.dot(p, value, acc, allow_tf32=False)
        m_i = m_ij

        K_block_ptr += BLOCK_N * stride_k_seqlen
        V_block_ptr += BLOCK_N * stride_v_seqlen

        if HAS_ATTN_MASK:
            mask_block_ptr += BLOCK_N * stride_attn_mask_kv_seqlen

    return acc, l_i, m_i
# NOTE: _attn_fwd statically asserts BLOCK_N <= HEAD_DIM, so small head dims
# need their own, smaller tile configs.
# The tuned "attention" configs are still fetched from the runtime, but on this
# backend the list is currently replaced wholesale by the small-head-dim grid
# below (the additive form `configs += SMALL_HEAD_DIM_CONFIGS` is disabled).
configs = runtime.get_tuned_config("attention")
SMALL_HEAD_DIM_CONFIGS = [
    triton.Config(
        {"BLOCK_M": block_m, "BLOCK_N": block_n, "PRE_LOAD_V": 0},
        num_stages=stages,
        num_warps=warps,
    )
    for block_m in (16, 32)
    for block_n in (8, 16)
    for stages in (0,)
    for warps in (8, 16)
]
configs = SMALL_HEAD_DIM_CONFIGS
# Forward attention kernel. One program per (query block start_m, flattened
# batch*head off_hz). Writes the normalized output to Out and, for each query
# row, the log2-domain log-sum-exp (m_i + log2(l_i)) to M for the backward pass.
# K and V each get their own batch/head base offset because their layouts are
# independent (their strides are passed separately).
@libentry()
@libtuner(
    configs=list(filter(partial(keep, must_keep=SMALL_HEAD_DIM_CONFIGS), configs)),
    key=["KV_CTX", "HEAD_DIM"],
)
@triton.jit
def _attn_fwd(
    Q,
    K,
    V,
    attn_mask,
    sm_scale,
    M,
    Out,  #
    stride_q_batch,
    stride_q_head,
    stride_q_seqlen,
    stride_q_headsize,
    stride_k_batch,
    stride_k_head,
    stride_k_seqlen,
    stride_k_headsize,
    stride_v_batch,
    stride_v_head,
    stride_v_seqlen,
    stride_v_headsize,
    stride_attn_mask_batch,
    stride_attn_mask_head,
    stride_attn_mask_q_seqlen,
    stride_attn_mask_kv_seqlen,
    stride_o_batch,
    stride_o_head,
    stride_o_seqlen,
    stride_o_headsize,
    Z,
    q_head_num,
    kv_head_num,
    GROUP_HEAD: tl.constexpr,
    Q_CTX,
    KV_CTX,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
    HAS_ATTN_MASK: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
):
    tl.static_assert(BLOCK_N <= HEAD_DIM)
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    batch_id = off_hz // q_head_num
    head_id = off_hz % q_head_num
    # GQA: GROUP_HEAD consecutive query heads share one K/V head.
    kv_head_id = head_id // GROUP_HEAD

    q_offset = (
        batch_id.to(tl.int64) * stride_q_batch + head_id.to(tl.int64) * stride_q_head
    )
    o_offset = (
        batch_id.to(tl.int64) * stride_o_batch + head_id.to(tl.int64) * stride_o_head
    )
    k_offset = (
        batch_id.to(tl.int64) * stride_k_batch + kv_head_id.to(tl.int64) * stride_k_head
    )
    v_offset = (
        batch_id.to(tl.int64) * stride_v_batch + kv_head_id.to(tl.int64) * stride_v_head
    )

    offs_headsize = tl.arange(0, HEAD_DIM)

    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    q_load_mask = offs_m < Q_CTX
    offs_n = tl.arange(0, BLOCK_N)

    Q_block_ptr = (
        Q
        + q_offset
        + offs_m[:, None] * stride_q_seqlen
        + offs_headsize[None, :] * stride_q_headsize
    )
    # K is addressed transposed (HEAD_DIM x BLOCK_N) so tl.dot(query, key) works.
    K_block_ptr = (
        K
        + k_offset
        + offs_n[None, :] * stride_k_seqlen
        + offs_headsize[:, None] * stride_k_headsize
    )
    V_block_ptr = (
        V
        + v_offset
        + offs_n[:, None] * stride_v_seqlen
        + offs_headsize[None, :] * stride_v_headsize
    )

    if HAS_ATTN_MASK:
        attn_mask_offset = (
            batch_id.to(tl.int64) * stride_attn_mask_batch
            + head_id.to(tl.int64) * stride_attn_mask_head
        )
        mask_block_ptr = (
            attn_mask
            + attn_mask_offset
            + offs_m[:, None] * stride_attn_mask_q_seqlen
            + offs_n[None, :] * stride_attn_mask_kv_seqlen
        )
    else:
        mask_block_ptr = None

    O_block_ptr = (
        Out
        + o_offset
        + offs_m[:, None] * stride_o_seqlen
        + offs_headsize[None, :] * stride_o_headsize
    )

    # Online-softmax state. l_i starts at 1.0, but the first update multiplies
    # it by alpha = exp2(-inf - m_ij) = 0, so the initial value drops out.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    qk_scale = sm_scale
    # load query: it stays resident for the whole key loop
    query = tl.load(Q_block_ptr, mask=q_load_mask[:, None], other=0.0)
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            query,
            K_block_ptr,
            V_block_ptr,
            mask_block_ptr,
            stride_k_seqlen,
            stride_v_seqlen,
            stride_attn_mask_kv_seqlen,
            start_m,
            qk_scale,
            q_load_mask,
            BLOCK_M,
            HEAD_DIM,
            BLOCK_N,
            4 - STAGE,
            offs_m,
            offs_n,
            KV_CTX,
            V.dtype.element_ty == tl.float8e5,
            HAS_ATTN_MASK,
            PRE_LOAD_V,
        )
    # stage 2: on-band (diagonal block, causal only)
    if STAGE & 2:
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            query,
            K_block_ptr,
            V_block_ptr,
            mask_block_ptr,
            stride_k_seqlen,
            stride_v_seqlen,
            stride_attn_mask_kv_seqlen,
            start_m,
            qk_scale,
            q_load_mask,
            BLOCK_M,
            HEAD_DIM,
            BLOCK_N,
            2,
            offs_m,
            offs_n,
            KV_CTX,
            V.dtype.element_ty == tl.float8e5,
            HAS_ATTN_MASK,
            PRE_LOAD_V,
        )
    # epilogue: normalize and store the log2-domain log-sum-exp for backward
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * Q_CTX + offs_m
    tl.store(m_ptrs, m_i, mask=q_load_mask)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=q_load_mask[:, None])
@triton.jit
def _attn_bwd_preprocess(
    O, DO, Delta, Z, H, Q_CTX, BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr
):
    # Backward pre-pass: Delta[bh, row] = sum_d O[bh, row, d] * dO[bh, row, d].
    # Grid axis 0 walks BLOCK_M-row tiles of the sequence; axis 1 walks the
    # flattened batch*head index. O and DO are addressed as densely packed
    # (Q_CTX, D_HEAD) slabs per batch*head. Z and H are not used here.
    rows = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    row_ok = rows < Q_CTX
    bh = tl.program_id(1)
    cols = tl.arange(0, D_HEAD)

    slab = bh * D_HEAD * Q_CTX
    tile = rows[:, None] * D_HEAD + cols[None, :]
    out_tile = tl.load(O + slab + tile, mask=row_ok[:, None], other=0.0)
    grad_tile = tl.load(
        DO + slab + tile, mask=row_ok[:, None], other=0.0
    ).to(tl.float32)

    # Row-wise dot product of O and dO, stored per (batch*head, row).
    row_dots = tl.sum(out_tile * grad_tile, axis=1)
    tl.store(Delta + bh * Q_CTX + rows, row_dots, mask=row_ok)
# The main inner-loop logic for computing dK and dV.
#
# For one fixed block of BLOCK_N1 keys (rows start_n .. start_n + BLOCK_N1),
# walks num_steps query tiles of BLOCK_M1 rows starting at start_m and
# accumulates (fp32) into dk and dv:
#   pT  = exp2(key @ qT - M)        recomputed transposed probabilities
#   dv += pT @ dO
#   dsT = pT * (value @ dO^T - Delta)
#   dk += dsT @ q
# `key` is expected to be pre-scaled by the caller (sm_scale / ln 2) so that
# key @ qT lands in the same log2 domain as the forward's M; the final
# sm_scale factor on dk is applied by the caller.
# MASK=True additionally applies the causal mask (query index >= key index).
@triton.jit
def _attn_bwd_dkdv(
    dk,
    dv,  #
    Q,
    key,
    value,
    sm_scale,  #
    DO,  #
    M,
    D,  #
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,  #
    H,
    Q_CTX,
    KV_CTX,
    BLOCK_M1: tl.constexpr,  #
    BLOCK_N1: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,  #
    # Filled in by the wrapper.
    start_n,
    start_m,
    num_steps,  #
    MASK: tl.constexpr,
):
    # Typical tutorial tile sizes; actual values come from the caller/config.
    # BLOCK_M1: 32
    # BLOCK_N1: 128
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_n_mask = offs_n < KV_CTX  # (BLOCK_N1, )

    offs_k = tl.arange(0, BLOCK_DMODEL)  # (BLOCK_DMODEL, )

    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    curr_m = start_m
    step_m = BLOCK_M1
    for blk_idx in range(num_steps):
        offs_m = curr_m + tl.arange(0, BLOCK_M1)  # (BLOCK_M1, )
        offs_m_mask = offs_m < Q_CTX  # (BLOCK_M1, )

        qT_ptrs = (
            Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
        )  # (BLOCK_DMODEL, BLOCK_M1)
        do_ptrs = (
            DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
        )  # (BLOCK_M1, BLOCK_DMODEL)

        qT = tl.load(
            qT_ptrs, mask=offs_m_mask[None, :], other=0.0
        )  # (BLOCK_DMODEL, BLOCK_M1)

        # Load m before computing qk to reduce pipeline stall.
        # Out-of-range query rows get m = +inf so their exp2(...) is 0.
        m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf"))  # (BLOCK_M1, )

        # key: (BLOCK_N1, BLOCK_DMODEL)
        qkT = tl.dot(key, qT)  # (BLOCK_N1, BLOCK_M1)
        m = tl.broadcast_to(m[None, :], (BLOCK_N1, BLOCK_M1))  # (BLOCK_N1, BLOCK_M1)
        # Out-of-range key rows also get m = +inf, zeroing their probabilities.
        m = tl.where(offs_n_mask[:, None], m, float("inf"))  # (BLOCK_N1, BLOCK_M1)
        pT = tl.math.exp2(qkT - m)
        # pT = tl.math.exp2(qkT - m[None, :])

        mask = (offs_m < Q_CTX)[None, :] & (offs_n < KV_CTX)[
            :, None
        ]  # (BLOCK_N1, BLOCK_M1)
        # Autoregressive masking.
        if MASK:
            mask &= offs_m[None, :] >= offs_n[:, None]
        pT = tl.where(mask, pT, 0.0)  # (BLOCK_N1, BLOCK_M1)

        # NOTE(review): this dO load is unmasked (the masked variant below is
        # commented out). If curr_m + BLOCK_M1 can exceed Q_CTX it reads past
        # the rows of DO, and garbage/NaN there can leak into dv via
        # pT @ dO even where pT == 0. Confirm why the masked load was disabled.
        do = tl.load(do_ptrs)
        # do = tl.load(do_ptrs, mask=offs_m_mask[:, None], other=0.0)  # (BLOCK_M1, BLOCK_DMODEL)

        # Compute dV.
        dv += tl.dot(pT, do.to(tl.float32))  # (BLOCK_N1, BLOCK_DMODEL)
        # Di is the per-row delta (rowsum(O * dO)) written by
        # _attn_bwd_preprocess; no extra pre-scaling is visible in this file.
        Di = tl.load(D + offs_m, mask=offs_m_mask, other=0.0)  # (BLOCK_M1, )

        # Compute dP and dS.
        dpT = tl.dot(value, tl.trans(do)).to(
            tl.float32
        )  # (BLOCK_N1, BLOCK_DMODEL) @ (BLOCK_M1, BLOCK_DMODEL).T -> (BLOCK_N1, BLOCK_M1)
        dsT = pT * (dpT - Di[None, :])  # (BLOCK_N1, BLOCK_M1)
        dsT = dsT.to(qT.dtype)
        qT = tl.where(offs_m_mask[None, :], qT, 0.0)  # (BLOCK_DMODEL, BLOCK_M1)
        dsT = tl.where(
            offs_m_mask[None, :] & offs_n_mask[:, None], dsT, 0.0
        )  # (BLOCK_N1, BLOCK_M1)
        dk += tl.dot(
            dsT, tl.trans(qT)
        )  # (BLOCK_N1, BLOCK_M1) @ (BLOCK_DMODEL, BLOCK_M1).T -> (BLOCK_N1, BLOCK_DMODEL)
        # Advance to the next query tile.
        curr_m += step_m
    return dk, dv
# the main inner-loop logic for computing dQ
#
# For one fixed block of BLOCK_M2 query rows (already loaded as `query`, with
# its dO tile `do` and per-row log-sum-exp `m`), walks num_steps key tiles of
# BLOCK_N2 columns starting at start_n and accumulates (fp32):
#   p   = exp2(query @ kT - m)
#   ds  = p * (dO @ vT - Delta)
#   dq += ds @ k
# K is the caller's pre-scaled key, so the caller rescales dq afterwards.
# MASK=True additionally applies the causal mask (query index >= key index).
@triton.jit
def _attn_bwd_dq(
    dq,
    query,
    K,
    V,  #
    do,
    m,
    D,
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,  #
    H,
    Q_CTX,  #
    KV_CTX,  #
    BLOCK_M2: tl.constexpr,  #
    BLOCK_N2: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,
    # Filled in by the wrapper.
    start_m,
    start_n,
    num_steps,  #
    MASK: tl.constexpr,
):
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_m_mask = offs_m < Q_CTX

    offs_k = tl.arange(0, BLOCK_DMODEL)
    # Di is the per-row delta (rowsum(O * dO)) written by _attn_bwd_preprocess.
    Di = tl.load(D + offs_m, mask=offs_m_mask, other=0.0)
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
    curr_n = start_n
    step_n = BLOCK_N2
    for blk_idx in range(num_steps):
        offs_n = curr_n + tl.arange(0, BLOCK_N2)
        offs_n_mask = offs_n < KV_CTX

        # K and V tiles are loaded transposed: (BLOCK_DMODEL, BLOCK_N2).
        kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
        vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d

        kT = tl.load(kT_ptrs, mask=offs_n_mask[None, :], other=0.0)
        vT = tl.load(vT_ptrs, mask=offs_n_mask[None, :], other=0.0)
        qk = tl.dot(query, kT)
        p = tl.math.exp2(qk - m)
        mask = (offs_m < Q_CTX)[:, None] & (offs_n < KV_CTX)[None, :]
        # Autoregressive masking.
        if MASK:
            # mask = (offs_m[:, None] >= offs_n[None, :])
            # mask = (offs_m[:, None] >= offs_n[None, :]) & (offs_m < N_CTX)[:, None] & (offs_n < N_CTX)[None, :]
            mask &= offs_m[:, None] >= offs_n[None, :]
        p = tl.where(mask, p, 0.0)
        # Compute dP and dS.
        dp = tl.dot(do, vT).to(tl.float32)
        ds = p * (dp - Di[:, None])
        ds = tl.where(mask, ds, 0.0).to(kT.dtype)
        # Compute dQ.
        # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
        dq += tl.dot(ds, tl.trans(kT))
        # Advance to the next key tile.
        curr_n += step_n
    return dq
config_backward = runtime.get_tuned_config("attention_bwd")


# Fused backward kernel. Program (pid, _, bhid) handles one flattened
# batch*head `bhid` and:
#   1. a BLOCK_N1 block of keys starting at pid * BLOCK_N1 -> writes dK, dV
#   2. a BLOCK_M2 block of queries starting at pid * BLOCK_M2 -> writes dQ
# Assumptions visible in this code (NOTE(review): confirm these hold for every
# caller/config):
#   - Causal attention only: the diagonal block is always processed with
#     MASK=True, dK/dV only visit query rows >= start_n, and dQ only visits
#     keys < min(start_m + BLOCK_M2, KV_CTX).
#   - DK and DV are offset with the *query* strides (`adj`), so the dk/dv
#     buffers must have the same per-(batch, head) layout as Q, i.e. a
#     Q_HEAD-headed buffer with Q_CTX == KV_CTX.
#   - The dQ stage-1 range mixes start_n (pid * BLOCK_N1) with start_m
#     (pid * BLOCK_M2), which presumably requires BLOCK_N1 == BLOCK_M2.
#   - K arrives pre-scaled by sm_scale / ln 2; dk is rescaled by sm_scale and
#     dq by ln 2 before being stored.
@libentry()
@libtuner(
    configs=config_backward,
    key=["KV_CTX", "BLOCK_DMODEL"],
)
@triton.jit
def _attn_bwd(
    Q,
    K,
    V,
    sm_scale,  #
    DO,  #
    DQ,
    DK,
    DV,  #
    M,
    D,
    # shared by Q/K/V/DO.
    stride_z,
    stride_h,
    stride_tok,
    stride_d,  #
    kv_stride_z,
    kv_stride_h,  #
    H,  # query head num
    Q_CTX,  #
    KV_CTX,  #
    kv_head_num,  #
    GROUP_HEAD: tl.constexpr,  #
    BLOCK_M1: tl.constexpr,  #
    BLOCK_N1: tl.constexpr,  #
    BLOCK_M2: tl.constexpr,  #
    BLOCK_N2: tl.constexpr,  #
    BLK_SLICE_FACTOR: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,
):
    tl.device_assert(Q_CTX % BLOCK_M1 == 0, "Q_CTX must be a multiple of BLOCK_M1.")

    LN2: tl.constexpr = 0.6931471824645996  # = ln(2)

    bhid = tl.program_id(2)
    off_chz = (bhid * Q_CTX).to(tl.int64)
    batch_id = bhid // H
    q_head_id = bhid % H
    # GQA: GROUP_HEAD consecutive query heads share one K/V head.
    kv_head_id = q_head_id // GROUP_HEAD
    adj = (stride_h * q_head_id + stride_z * batch_id).to(tl.int64)
    kv_adj = (kv_stride_h * kv_head_id + kv_stride_z * batch_id).to(tl.int64)

    pid = tl.program_id(0)

    # offset pointers for batch/head (DK/DV use the query offsets, see above)
    Q += adj
    K += kv_adj
    V += kv_adj
    DO += adj
    DQ += adj
    DK += adj
    DV += adj
    M += off_chz
    D += off_chz

    # head-dim offsets shared by every tile below
    offs_k = tl.arange(0, BLOCK_DMODEL)

    start_n = pid * BLOCK_N1
    start_m = start_n

    # The diagonal (masked) block is walked in thinner query slices.
    MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_n_mask = offs_n < KV_CTX

    dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)

    # load K and V: they stay resident throughout the inner loop.
    key = tl.load(
        K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_n_mask[:, None],
        other=0.0,
    )
    value = tl.load(
        V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_n_mask[:, None],
        other=0.0,
    )

    num_steps = BLOCK_N1 // MASK_BLOCK_M1

    # Diagonal block: query rows [start_n, start_n + BLOCK_N1) with causal mask.
    dk, dv = _attn_bwd_dkdv(
        dk,
        dv,  #
        Q,
        key,
        value,
        sm_scale,  #
        DO,  #
        M,
        D,  #
        stride_tok,
        stride_d,  #
        H,
        Q_CTX,  #
        KV_CTX,  #
        MASK_BLOCK_M1,
        BLOCK_N1,
        BLOCK_DMODEL,  #
        start_n,
        start_m,
        num_steps,  #
        MASK=True,  #
    )

    # Compute dK and dV for non-masked blocks (all query rows below the diagonal).
    start_m += num_steps * MASK_BLOCK_M1
    remaining_m = Q_CTX - start_m
    num_steps = (remaining_m + BLOCK_M1 - 1) // BLOCK_M1

    if num_steps > 0 and start_m < Q_CTX:
        dk, dv = _attn_bwd_dkdv(  #
            dk,
            dv,  #
            Q,
            key,
            value,
            sm_scale,  #
            DO,  #
            M,
            D,  #
            stride_tok,
            stride_d,  #
            H,
            Q_CTX,  #
            KV_CTX,  #
            BLOCK_M1,
            BLOCK_N1,
            BLOCK_DMODEL,  #
            start_n,
            start_m,
            num_steps,  #
            MASK=False,  #
        )
    # tl.device_print("dv: ", dv)

    dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dv_ptrs, dv, mask=offs_n_mask[:, None])

    # Write back dK (undo the sm_scale folded out of the pre-scaled key).
    dk *= sm_scale
    dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dk_ptrs, dk, mask=offs_n_mask[:, None])

    # THIS BLOCK DOES DQ:
    MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
    start_m = pid * BLOCK_M2
    end_n = min(start_m + BLOCK_M2, KV_CTX)  # Ensure end_n does not exceed N_CTX
    num_steps = (end_n - start_n + MASK_BLOCK_N2 - 1) // MASK_BLOCK_N2

    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_m_mask = offs_m < Q_CTX

    query = tl.load(
        Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_m_mask[:, None],
        other=0.0,
    )
    dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
    do = tl.load(
        DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_m_mask[:, None],
        other=0.0,
    )

    # Out-of-range query rows get m = +inf so their probabilities are 0.
    m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf"))
    m = m[:, None]

    # Stage 1 - Compute dQ for masked (diagonal) blocks.
    # NOTE: This code scans each row of QK^T backward (from right to left,
    # but inside each call to _attn_bwd_dq, from left to right), but that's
    # not due to anything important. I just wanted to reuse the loop
    # structure for dK & dV above as much as possible.

    if num_steps > 0:
        dq = _attn_bwd_dq(
            dq,
            query,
            K,
            V,  #
            do,
            m,
            D,  #
            stride_tok,
            stride_d,  #
            H,
            Q_CTX,  #
            KV_CTX,  #
            BLOCK_M2,
            MASK_BLOCK_N2,
            BLOCK_DMODEL,  #
            start_m,
            start_n,
            num_steps,  #
            MASK=True,  #
        )

    # Stage 2 - non-masked blocks: keys [0, start_n), walked in whole BLOCK_N2
    # tiles that end exactly at start_n.
    stage2_end_n = start_n
    stage2_num_steps = (stage2_end_n + BLOCK_N2 - 1) // BLOCK_N2

    if stage2_num_steps > 0:
        dq = _attn_bwd_dq(
            dq,
            query,
            K,
            V,  #
            do,
            m,
            D,  #
            stride_tok,
            stride_d,  #
            H,
            Q_CTX,  #
            KV_CTX,  #
            BLOCK_M2,
            BLOCK_N2,
            BLOCK_DMODEL,  #
            start_m,
            stage2_end_n - stage2_num_steps * BLOCK_N2,
            stage2_num_steps,  #
            MASK=False,  #
        )
    # Write back dQ (multiply by ln 2 to undo the 1/ln 2 folded into K).
    dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    dq *= LN2
    # tl.store(dq_ptrs, dq)

    tl.store(dq_ptrs, dq, mask=offs_m_mask[:, None])
761def scaled_dot_product_attention_forward(
762 query,
763 key,
764 value,
765 attn_mask=None,
766 dropout_p=0.0,
767 is_causal=False,
768 scale=None,
769 enable_gqa=False,
770):
771 logger.debug("GEMS SCALED DOT PRODUCT ATTENTION FORWARD")
772 # shape constraints
773 HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
774 # when v is in float8_e5m2 it is transposed.
775 HEAD_DIM_V = value.shape[-1]
776 assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
777 assert HEAD_DIM_K in {16, 32, 64, 128, 256}
778 assert dropout_p == 0.0, "Currenty only support dropout_p=0.0"
780 o = torch.empty_like(query, dtype=value.dtype)
782 stage = 3 if is_causal else 1
784 if scale is None:
785 sm_scale = 1.0 / (HEAD_DIM_K**0.5)
786 else:
787 sm_scale = scale
789 q_head_num = query.shape[1]
790 kv_head_num = key.shape[1]
791 assert enable_gqa or q_head_num == kv_head_num, (
792 f"q_head_num {q_head_num} != kv_head_num {kv_head_num}, "
793 "enable_gqa must be True to support different head numbers."
794 )
796 grid = lambda args: (
797 triton.cdiv(query.shape[2], args["BLOCK_M"]),
798 query.shape[0] * query.shape[1],
799 1,
800 )
802 if attn_mask is not None:
803 HAS_ATTN_MASK = True
804 if attn_mask.dtype == torch.bool:
805 attn_mask = attn_mask.to(query.dtype) * -1.0e6
806 stride_attn_mask_batch = attn_mask.stride(0)
807 stride_attn_mask_head = attn_mask.stride(1)
808 stride_attn_mask_q_seqlen = attn_mask.stride(2)
809 stride_attn_mask_kv_seqlen = attn_mask.stride(3)
810 else:
811 HAS_ATTN_MASK = False
812 stride_attn_mask_batch = 1
813 stride_attn_mask_head = 1
814 stride_attn_mask_q_seqlen = 1
815 stride_attn_mask_kv_seqlen = 1
817 M = torch.empty(
818 (query.shape[0], query.shape[1], query.shape[2]),
819 device=query.device,
820 dtype=torch.float32,
821 )
823 with torch_device_fn.device(query.device):
824 _attn_fwd[grid](
825 query,
826 key,
827 value,
828 attn_mask,
829 sm_scale,
830 M,
831 o, #
832 query.stride(0),
833 query.stride(1),
834 query.stride(2),
835 query.stride(3), #
836 key.stride(0),
837 key.stride(1),
838 key.stride(2),
839 key.stride(3), #
840 value.stride(0),
841 value.stride(1),
842 value.stride(2),
843 value.stride(3), #
844 stride_attn_mask_batch,
845 stride_attn_mask_head,
846 stride_attn_mask_q_seqlen,
847 stride_attn_mask_kv_seqlen, #
848 o.stride(0),
849 o.stride(1),
850 o.stride(2),
851 o.stride(3), #
852 query.shape[0],
853 q_head_num,
854 kv_head_num, #
855 q_head_num // kv_head_num, # group_head
856 query.shape[2], #
857 key.shape[2], #
858 HEAD_DIM_K, #
859 STAGE=stage, #
860 HAS_ATTN_MASK=HAS_ATTN_MASK, #
861 )
862 return o, M
def scaled_dot_product_attention_backward(
    do,
    query,
    key,
    value,
    o,
    M,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Compute ``(dq, dk, dv)`` for the Triton SDPA forward.

    Args:
        do: upstream gradient of the output; contiguous, same strides as
            ``query`` and ``o``.
        query, key, value: contiguous ``(batch, heads, seq, head_dim)`` forward
            inputs; key/value may have fewer heads than query (GQA).
        o: forward output.
        M: per-row log2-domain log-sum-exp written by the forward kernel.
        attn_mask: must be ``None``; the backward kernel has no mask input.
        dropout_p: must be 0.0.
        is_causal: must be ``True``. The backward kernel always applies the
            causal mask on the diagonal block and skips key blocks past it.
        scale: softmax scale; defaults to ``1 / sqrt(head_dim)``.
        enable_gqa: accepted for signature parity; GQA is derived from the
            head counts.

    Returns:
        ``(dq, dk, dv)``; dk/dv are reduced to key/value's head count.

    Raises:
        AssertionError: for any unsupported configuration listed above.
    """
    logger.debug("GEMS SCALED DOT PRODUCT ATTENTION BACKWARD")
    # shape constraints
    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
    assert dropout_p == 0.0, "Currenty only support dropout_p=0.0"
    assert attn_mask is None, "attn_mask is not supported in the backward pass"

    sm_scale = 1.0 / (HEAD_DIM_K**0.5) if scale is None else scale

    assert do.is_contiguous()
    assert (
        query.is_contiguous()
        and key.is_contiguous()
        and value.is_contiguous()
        and o.is_contiguous()
    )
    assert query.stride() == o.stride() == do.stride()
    assert key.stride() == value.stride()

    BLOCK_DMODEL = HEAD_DIM_K
    BATCH, Q_HEAD, Q_CTX = query.shape[:3]
    _, KV_HEAD, KV_CTX = key.shape[:3]
    # _attn_bwd pairs query block i with key block i, always applies the
    # causal mask on that diagonal block, and offsets DK/DV with the query's
    # strides. Anything other than square causal attention would produce
    # wrong gradients or out-of-bounds writes into dk/dv.
    assert is_causal, "Only causal attention is supported in the backward pass"
    assert Q_CTX == KV_CTX, (
        f"Backward requires equal query/key lengths, got {Q_CTX} and {KV_CTX}"
    )
    assert Q_HEAD % KV_HEAD == 0, (
        f"q_head_num {Q_HEAD} must be divisible by kv_head_num {KV_HEAD}."
    )
    group_head = Q_HEAD // KV_HEAD

    BLK_SLICE_FACTOR = 2
    RCP_LN2 = 1.0 / math.log(2)

    # Pre-scale K so q @ k^T lands in the same log2 domain as the forward's M.
    arg_k = key * (sm_scale * RCP_LN2)
    PRE_BLOCK = 256
    pre_grid = (triton.cdiv(Q_CTX, PRE_BLOCK), BATCH * Q_HEAD)

    delta = torch.empty_like(M)

    # NOTE that dk & dv always have the same number of heads as q; they are
    # reduced to KV_HEAD below.
    dq = torch.empty_like(query).contiguous()
    dk = torch.empty(
        (BATCH, Q_HEAD, KV_CTX, HEAD_DIM_K),
        device=key.device,
        dtype=key.dtype,
        memory_format=torch.contiguous_format,
    )
    dv = torch.empty(
        (BATCH, Q_HEAD, KV_CTX, HEAD_DIM_V),
        device=value.device,
        dtype=value.dtype,
        memory_format=torch.contiguous_format,
    )

    # Size the grid from the BLOCK_N1 the autotuner actually picked. A grid
    # sized from the largest candidate leaves key blocks (and their
    # torch.empty-allocated dk/dv rows) unprocessed when a smaller BLOCK_N1 wins.
    grid = lambda meta: (triton.cdiv(Q_CTX, meta["BLOCK_N1"]), 1, BATCH * Q_HEAD)

    with torch_device_fn.device(query.device):
        _attn_bwd_preprocess[pre_grid](
            o,
            do,  #
            delta,  #
            BATCH,
            Q_HEAD,
            Q_CTX,  #
            BLOCK_M=PRE_BLOCK,
            D_HEAD=BLOCK_DMODEL,  #
        )

        _attn_bwd[grid](
            query,
            arg_k,
            value,
            sm_scale,
            do,
            dq,
            dk,
            dv,  #
            M,
            delta,  #
            query.stride(0),
            query.stride(1),
            query.stride(2),
            query.stride(3),  #
            key.stride(0),
            key.stride(1),  #
            Q_HEAD,
            Q_CTX,  #
            KV_CTX,  #
            KV_HEAD,  #
            GROUP_HEAD=group_head,  #
            BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
            BLOCK_DMODEL=BLOCK_DMODEL,  #
        )

    if group_head > 1:
        # Query heads h*group .. h*group+group-1 share kv head h: sum them.
        dk = dk.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_K)
        dv = dv.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_V)
        dk = dk.sum(dim=2)
        dv = dv.sum(dim=2)

    return dq, dk, dv
class ScaleDotProductAttention(torch.autograd.Function):
    """Autograd bridge for the Triton SDPA forward/backward kernels.

    The kernels only handle the head dims in ``_SUPPORTED_HEAD_DIMS``. Other
    head dims, up to the largest supported one, are zero-padded in
    ``forward``. Zero columns leave ``q @ k^T`` unchanged and yield zero output
    columns. The output and the gradients are sliced back to the caller's
    head dim.
    """

    # [sunrise fix] padding for unsupported head dims, since auto lower unsupported.
    # Sorted ascending so the first match is the smallest usable padded size.
    _SUPPORTED_HEAD_DIMS = (16, 32, 64, 128, 256)

    @staticmethod
    def _padded_head_dim(head_size):
        """Return the smallest supported head dim that is >= ``head_size``.

        Raises:
            AssertionError: if ``head_size`` exceeds the largest supported dim.
        """
        supported = ScaleDotProductAttention._SUPPORTED_HEAD_DIMS
        padded_head_dim = next((d for d in supported if d >= head_size), None)
        assert (
            padded_head_dim is not None
        ), f"Unsupported head dim {head_size}, max supported is {supported[-1]}"
        return padded_head_dim

    @staticmethod
    def forward(
        ctx,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=None,
        enable_gqa=False,
    ):
        head_size = key.shape[-1]
        # The default scale depends on the caller's head dim, so compute it
        # before padding; using the padded dim would change the softmax.
        sm_scale = scale if scale is not None else 1.0 / (head_size**0.5)

        pad = ScaleDotProductAttention._padded_head_dim(head_size) - head_size
        if pad:
            query = F.pad(query, (0, pad))
            key = F.pad(key, (0, pad))
            value = F.pad(value, (0, pad))

        o, M = scaled_dot_product_attention_forward(
            query,
            key,
            value,
            attn_mask,
            dropout_p,
            is_causal,
            sm_scale,
            enable_gqa,
        )

        # Save the full (padded) output: the backward wrapper requires `o` to
        # be contiguous with the same strides as the padded query.
        ctx.save_for_backward(query, key, value, o, M)
        ctx.sm_scale = sm_scale
        ctx.causal = is_causal
        ctx.enable_gqa = enable_gqa
        ctx.head_size = head_size
        return o[..., :head_size] if pad else o

    @staticmethod
    def backward(ctx, do):
        query, key, value, o, M = ctx.saved_tensors
        head_size = ctx.head_size
        pad = query.shape[-1] - head_size
        if pad:
            # Match the padded head dim of the saved tensors. The extra dO
            # columns are zero, so the gradients of the real columns are unchanged.
            do = F.pad(do, (0, pad))
        dq, dk, dv = scaled_dot_product_attention_backward(
            do.contiguous(),
            query,
            key,
            value,
            o,
            M,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=ctx.causal,
            scale=ctx.sm_scale,
            enable_gqa=ctx.enable_gqa,
        )
        if pad:
            dq = dq[..., :head_size]
            dk = dk[..., :head_size]
            dv = dv[..., :head_size]
        # One gradient slot per forward input (query ... enable_gqa).
        return dq, dk, dv, None, None, None, None, None
def scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Public SDPA entry point with autograd support.

    Uses the same parameter names and defaults as
    ``torch.nn.functional.scaled_dot_product_attention``. Every argument is
    forwarded positionally to ``ScaleDotProductAttention.apply``.
    """
    forwarded = (
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        enable_gqa,
    )
    return ScaleDotProductAttention.apply(*forwarded)
def flash_attention_forward(
    query,
    key,
    value,
    cumulative_sequence_length_q,
    cumulative_sequence_length_k,
    max_q,
    max_k,
    dropout_p,
    is_causal,
    return_debug_mask,
    *,
    scale=None,
    softcap=0.0,
    window_size_left=None,
    window_size_right=None,
    seqused_k=None,
    alibi_slopes=None,
    disable_splitkv=False,
):
    """Flash-attention forward backed by ``mha_fwd``.

    Head dims outside ``(16, 32, 64, 96, 128, 192, 256)`` are zero-padded to the
    next supported size, and the output is sliced back. The default softmax
    scale always uses the caller's (unpadded) head dim.

    Returns:
        ``(out, lse, philox_seed, philox_offset, p)`` as produced by the
        underlying kernel wrapper, with ``out`` at the original head dim.

    Raises:
        AssertionError: for varlen inputs (not supported yet), mismatched
            head dims, or head dims above 256.
    """
    logger.debug("GEMS FLASH_ATTENTION_FORWARD")
    assert (
        cumulative_sequence_length_q is None and cumulative_sequence_length_k is None
    ), "varlen is not supported yet."

    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    original_head_dim = HEAD_DIM_K
    # Ascending, so the first match below is the smallest usable padded size.
    supported_head_dims = (16, 32, 64, 96, 128, 192, 256)
    if HEAD_DIM_K not in supported_head_dims:
        padded_head_dim = next(
            (d for d in supported_head_dims if d >= HEAD_DIM_K), None
        )
        assert (
            padded_head_dim is not None
        ), f"Unsupported head dim {HEAD_DIM_K}, max supported is {supported_head_dims[-1]}"
        pad = padded_head_dim - HEAD_DIM_K
        query = F.pad(query, (0, pad))
        key = F.pad(key, (0, pad))
        value = F.pad(value, (0, pad))
        HEAD_DIM_K = padded_head_dim

    # Compare against None so an explicit scale (even 0.0) is honored.
    softmax_scale = 1.0 / (original_head_dim**0.5) if scale is None else scale
    # -1 means "unbounded" for the sliding-window limits.
    non_null_window_left = -1 if window_size_left is None else window_size_left
    non_null_window_right = -1 if window_size_right is None else window_size_right

    out = torch.empty_like(query)
    if cumulative_sequence_length_q is not None:
        # Unreachable while the varlen assertion above is in place; kept in
        # sync (resolved softmax_scale) for when varlen support lands.
        out, q, k, v, lse, philox_seed, philox_offset, p = mha_varlan_fwd(
            query,
            key,
            value,
            out,
            cumulative_sequence_length_q,
            cumulative_sequence_length_k,
            seqused_k,
            None,
            None,  # block_table
            alibi_slopes,
            max_q,
            max_k,
            dropout_p,
            softmax_scale,
            False,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask and dropout_p > 0,
            None,
        )
    else:
        out, q, k, v, lse, philox_seed, philox_offset, p = mha_fwd(
            query,
            key,
            value,
            out,
            alibi_slopes,
            dropout_p,
            softmax_scale,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask,
            disable_splitkv=disable_splitkv,
        )

    if HEAD_DIM_K != original_head_dim:
        out = out[..., :original_head_dim]
    return (out, lse, philox_seed, philox_offset, p)
# Adapted from https://github.com/vllm-project/flash-attention/blob/main/vllm_flash_attn/flash_attn_interface.py
def maybe_contiguous(x):
    """Return ``x``, copied to contiguous memory only if its last dim is strided.

    ``None`` passes through unchanged, as does any tensor whose innermost
    stride is already 1 (even if it is not fully contiguous).
    """
    if x is None or x.stride(-1) == 1:
        return x
    return x.contiguous()
def flash_attn_varlen_func(
    q,
    k,
    v,
    max_seqlen_q,
    cu_seqlens_q,
    max_seqlen_k,
    cu_seqlens_k=None,  # only used for non-paged prefill
    seqused_k=None,
    q_v=None,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=None,
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
    return_softmax_lse=False,
    out=None,
    # Dummy FA3 arguments
    scheduler_metadata=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    s_aux=None,
    num_splits: int = 0,
    cp_world_size: int = 1,
    cp_rank: int = 0,
    cp_tot_seqused_k=None,
    fa_version: int = 2,
):
    """Variable-length (packed) FlashAttention-2 forward pass.

    dropout_p should be set to 0.0 during evaluation.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in
    K, V with fewer heads than Q. Note that the number of heads in Q must be
    divisible by the number of heads in KV. For example, if Q has 6 heads and
    K, V have 2 heads, heads 0, 1, 2 of Q will attend to head 0 of K, V, and
    heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of
    the attention matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the
    causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    If window_size != (-1, -1), implements sliding window local attention.
    Query at position i will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0],
     i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Two execution paths exist:
      * ``use_c_extension`` true: every argument is forwarded unchanged to
        ``torch.ops.flag_gems.flash_attn_varlen_func``.
      * otherwise: arguments are validated and dispatched to
        ``mha_varlan_fwd``. On this path ``q_v``, ``deterministic``,
        ``return_attn_probs`` and the FA3 placeholders are NOT used.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query
            tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key
            tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key
            tokens in the batch.
        max_seqlen_q: int (or a 0-d tensor, converted with ``.item()``).
            Maximum query sequence length in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative
            sequence lengths of the sequences in the batch, used to index
            into q.
        max_seqlen_k: int (or a 0-d tensor, converted with ``.item()``).
            Maximum key sequence length in the batch.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative
            sequence lengths of the sequences in the batch, used to index
            into kv. Exactly one of cu_seqlens_k / seqused_k must be given
            (checked on the non-C-extension path).
        seqused_k: per-sequence key lengths. NOTE(review): presumably
            (batch_size,) int32 as in vLLM; confirm. Required whenever
            block_table is given.
        q_v: forwarded to the C extension only.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for
            auto-regressive modeling).
        window_size: (left, right) or None (treated as (-1, -1)). If not
            (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to the
            attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation
            of the backward pass. Forwarded to the C extension only.
        return_attn_probs: bool. Testing-only flag; forwarded to the C
            extension only.
        block_table: paged-KV block table; requires seqused_k.
        return_softmax_lse: bool. Also return the softmax log-sum-exp.
        out: optional output tensor passed through to the kernel
            (presumably written in place; confirm in flash_api).
        scheduler_metadata, q_descale, k_descale, v_descale, s_aux,
        cp_world_size, cp_rank, cp_tot_seqused_k: FA3-compatibility
            placeholders; forwarded to the C extension only.
        num_splits: must be 0.
        fa_version: must be 2.

    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]:
            (nheads, total_q_seqlen). The logsumexp of each row of the matrix
            QK^T * scaling (e.g., log of the softmax normalization factor).

    Raises:
        RuntimeError: if fa_version != 2 or num_splits > 0.
        AssertionError: on the non-C-extension path, if cu_seqlens_k /
            seqused_k / block_table are inconsistent, or if window_size does
            not have exactly two entries.
    """
    import os

    # Process-wide side effect, applied on every call (even ones that raise
    # below). NOTE(review): presumably this disables asynchronous execution
    # on the sunrise backend; confirm and consider setting it once at startup.
    os.environ["OFF_ASYNC"] = "1"
    if fa_version != 2:
        raise RuntimeError("Only FA2 is implemented.")
    if num_splits > 0:
        raise RuntimeError("num_splits > 0 is not implemented in GEMS.")

    if use_c_extension:
        logger.debug("GEMS FLASH_ATTN_VARLEN_FUNC(C EXTENSION)")
        with torch_device_fn.device(q.device):
            out_cpp, softmax_lse = torch.ops.flag_gems.flash_attn_varlen_func(
                q,
                k,
                v,
                max_seqlen_q,
                cu_seqlens_q,
                max_seqlen_k,
                cu_seqlens_k,
                seqused_k,
                q_v,
                dropout_p,
                softmax_scale,
                causal,
                window_size,
                softcap,
                alibi_slopes,
                deterministic,
                return_attn_probs,
                block_table,
                return_softmax_lse,
                out,
                scheduler_metadata,
                q_descale,
                k_descale,
                v_descale,
                s_aux,
                num_splits,
                cp_world_size,
                cp_rank,
                cp_tot_seqused_k,
                fa_version,
            )
        return (out_cpp, softmax_lse) if return_softmax_lse else out_cpp

    logger.debug("GEMS FLASH_ATTN_VARLEN_FUNC")
    assert (
        cu_seqlens_k is not None or seqused_k is not None
    ), "cu_seqlens_k or seqused_k must be provided"
    assert (
        cu_seqlens_k is None or seqused_k is None
    ), "cu_seqlens_k and seqused_k cannot be provided at the same time"
    assert (
        block_table is None or seqused_k is not None
    ), "seqused_k must be provided if block_table is provided"

    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    # The kernel takes (left, right) as two separate ints; None means "no window".
    if window_size is None:
        real_window_size = (-1, -1)
    else:
        assert len(window_size) == 2
        real_window_size = (window_size[0], window_size[1])

    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]

    # cu_seqlens_k is not used when seqused_k supplies the key lengths, but
    # the kernel API still takes a tensor in that slot. Pass real zeros:
    # torch.empty_like would hand over uninitialized memory. Only allocate
    # when the caller did not supply one.
    if cu_seqlens_k is None:
        cu_seqlens_k = torch.zeros_like(cu_seqlens_q)

    # Accept both Python ints and 0-d tensors for the max sequence lengths.
    max_seqlen_q = (
        max_seqlen_q.item() if hasattr(max_seqlen_q, "item") else max_seqlen_q
    )
    max_seqlen_k = (
        max_seqlen_k.item() if hasattr(max_seqlen_k, "item") else max_seqlen_k
    )

    out, q, k, v, softmax_lse, *_ = mha_varlan_fwd(
        q,
        k,
        v,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_k,
        None,
        block_table,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        real_window_size[0],
        real_window_size[1],
        softcap,
        # NOTE(review): derived from return_softmax_lse, not return_attn_probs
        # (mirrors the vLLM adaptation); confirm this is the intended flag.
        return_softmax_lse and dropout_p > 0,
        None,
    )

    return (out, softmax_lse) if return_softmax_lse else out