Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/attention.py: 0%
399 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
2import math
3from functools import partial
5import torch
6import torch.nn.functional as F
7import triton
8import triton.language as tl
10from flag_gems import runtime
11from flag_gems.config import use_c_extension
12from flag_gems.runtime import torch_device_fn
13from flag_gems.utils import libentry, libtuner
15from .flash_api import mha_fwd, mha_varlan_fwd
16from .flash_kernel import keep
18logger = logging.getLogger(__name__)
21# Modified from Triton tutorial: https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
@triton.jit
def _attn_fwd_inner(
    acc,
    l_i,
    m_i,
    query,  #
    K_block_ptr,
    V_block_ptr,  #
    mask_block_ptr,  #
    stride_k_seqlen,
    stride_v_seqlen,
    stride_attn_mask_kv_seqlen,  #
    start_m,
    qk_scale,  #
    q_load_mask,
    BLOCK_M: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_N: tl.constexpr,  #
    STAGE: tl.constexpr,
    offs_m: tl.constexpr,
    offs_n: tl.constexpr,  #
    KV_CTX: tl.constexpr,
    fp8_v: tl.constexpr,
    HAS_ATTN_MASK: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
):
    # One online-softmax pass over the key/value tiles in [lo, hi).
    #
    # Inputs acc / l_i / m_i are the running output accumulator, running
    # softmax denominator and running row maximum (kept in the log2 domain);
    # the updated triple is returned.
    #
    # STAGE selects the key range:
    #   1 -> keys strictly before this query tile   [0, start_m * BLOCK_M)
    #   2 -> the diagonal tile, with the causal mask applied
    #   otherwise -> every key                      [0, KV_CTX)
    #
    # range of values handled by this stage
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
    # causal = False
    else:
        lo, hi = 0, KV_CTX

    # Advance the tile pointers to the first key handled by this stage.
    K_block_ptr += lo * stride_k_seqlen
    V_block_ptr += lo * stride_v_seqlen
    if HAS_ATTN_MASK:
        mask_block_ptr += lo * stride_attn_mask_kv_seqlen

    LOG2E = 1.44269504  # log2(e) constant

    # loop over key, value and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        # Keys past KV_CTX are padding: loaded as 0 and excluded below.
        kv_load_mask = (start_n + offs_n) < KV_CTX
        # start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        key = tl.load(K_block_ptr, mask=kv_load_mask[None, :], other=0.0)
        if PRE_LOAD_V:
            value = tl.load(V_block_ptr, mask=kv_load_mask[:, None], other=0.0)

        qk = tl.dot(query, key, allow_tf32=False)
        # In case KV_CTX is not divisible by BLOCK_N: padded keys get -inf so
        # they contribute exp2(-inf) == 0 to the softmax.
        qk = tl.where(kv_load_mask[None, :], qk, -float("inf"))
        # qk = qk.to(tl.float32)

        if HAS_ATTN_MASK:
            attn_mask = tl.load(
                mask_block_ptr,
                mask=q_load_mask[:, None] & kv_load_mask[None, :],
                other=0.0,
            )

        if STAGE == 2:
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])

            if HAS_ATTN_MASK:
                qk = qk * qk_scale + attn_mask
                qk *= LOG2E
                qk = qk + tl.where(mask, 0, -1.0e6)
            else:
                qk = qk * qk_scale * LOG2E + tl.where(mask, 0, -1.0e6)

            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
        else:
            if HAS_ATTN_MASK:
                # The softmax below is evaluated with exp2, so the additive
                # bias has to be scaled by log2(e) together with the logits
                # (exactly as the STAGE == 2 branch does). Adding it after
                # the scaling would silently multiply the bias by ln(2).
                qk = (qk * qk_scale + attn_mask) * LOG2E
            else:
                qk *= qk_scale * LOG2E
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk = qk - m_ij[:, None]

        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        # alpha rescales everything accumulated under the previous maximum.
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        # -- update output accumulator --
        acc = acc * alpha[:, None]
        # update acc
        if not PRE_LOAD_V:
            value = tl.load(V_block_ptr, mask=kv_load_mask[:, None], other=0.0)
        if fp8_v:
            p = p.to(tl.float8e5)
        else:
            p = p.to(query.dtype)
        p = p.to(value.dtype)
        acc = tl.dot(p, value, acc, allow_tf32=False)
        # update m_i and l_i
        m_i = m_ij

        K_block_ptr += BLOCK_N * stride_k_seqlen
        V_block_ptr += BLOCK_N * stride_v_seqlen

        if HAS_ATTN_MASK:
            mask_block_ptr += BLOCK_N * stride_attn_mask_kv_seqlen

    return acc, l_i, m_i
132# NOTE: we assert BLOCK_N <= HEAD_DIM in _attn_fwd, so for small head_dim,
133# we need to generate more configs.
# Base launch configurations: whatever the runtime has registered under the
# "attention" key.
configs = runtime.get_tuned_config("attention")
# Additional small-BLOCK_N candidates (2 x 2 x 3 x 2 = 24 configs, all with
# PRE_LOAD_V disabled) so that head dims as small as 16 still have configs
# satisfying the BLOCK_N <= HEAD_DIM static assert in _attn_fwd.
SMALL_HEAD_DIM_CONFIGS = [
    triton.Config(
        {"BLOCK_M": BM, "BLOCK_N": BN, "PRE_LOAD_V": 0}, num_stages=s, num_warps=w
    )
    for BM in [64, 128]
    for BN in [16, 32]
    for s in [2, 3, 4]
    for w in [4, 8]
]
# NOTE(review): `+=` extends the object returned by get_tuned_config in place;
# if the runtime hands out a shared/cached list this mutates it — confirm.
configs += SMALL_HEAD_DIM_CONFIGS
@libentry()
@libtuner(
    # configs=list(filter(partial(keep, must_keep=SMALL_HEAD_DIM_CONFIGS), configs)),
    configs=list(filter(partial(keep), configs)),
    key=["KV_CTX", "HEAD_DIM"],
)
@triton.jit
def _attn_fwd(
    Q,
    K,
    V,
    attn_mask,
    sm_scale,
    M,
    Out,  #
    stride_q_batch,
    stride_q_head,
    stride_q_seqlen,
    stride_q_headsize,
    stride_k_batch,
    stride_k_head,
    stride_k_seqlen,
    stride_k_headsize,
    stride_v_batch,
    stride_v_head,
    stride_v_seqlen,
    stride_v_headsize,
    stride_attn_mask_batch,
    stride_attn_mask_head,
    stride_attn_mask_q_seqlen,
    stride_attn_mask_kv_seqlen,
    stride_o_batch,
    stride_o_head,
    stride_o_seqlen,
    stride_o_headsize,
    Z,
    q_head_num,
    kv_head_num,
    GROUP_HEAD: tl.constexpr,
    Q_CTX,
    KV_CTX,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
    HAS_ATTN_MASK: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
):
    # Fused attention forward kernel.
    #
    # Program axis 0 selects a BLOCK_M tile of query rows, axis 1 selects a
    # flattened (batch, query head) pair. Each program writes its tile of Out
    # and the per-row log2-domain logsumexp into M.
    #
    # STAGE is a bit mask: bit 0 triggers the off-band pass, bit 1 the
    # on-band (diagonal, causally masked) pass. The host wrapper passes 3 for
    # causal attention and 1 otherwise.
    tl.static_assert(BLOCK_N <= HEAD_DIM)
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    batch_id = off_hz // q_head_num
    head_id = off_hz % q_head_num
    # Grouped-query attention: GROUP_HEAD consecutive query heads share one KV head.
    kv_head_id = head_id // GROUP_HEAD

    # Batch/head base offsets are computed in int64.
    q_offset = (
        batch_id.to(tl.int64) * stride_q_batch + head_id.to(tl.int64) * stride_q_head
    )
    o_offset = (
        batch_id.to(tl.int64) * stride_o_batch + head_id.to(tl.int64) * stride_o_head
    )
    # NOTE(review): this single offset is built from the K strides only and is
    # reused for V below (stride_v_batch / stride_v_head are never read), so K
    # and V are assumed to share batch/head strides — confirm with callers.
    kv_offset = (
        batch_id.to(tl.int64) * stride_k_batch + kv_head_id.to(tl.int64) * stride_k_head
    )

    offs_headsize = tl.arange(0, HEAD_DIM)

    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    q_load_mask = offs_m < Q_CTX  # guards the ragged last query tile
    offs_n = tl.arange(0, BLOCK_N)

    Q_block_ptr = (
        Q
        + q_offset
        + offs_m[:, None] * stride_q_seqlen
        + offs_headsize[None, :] * stride_q_headsize
    )
    # K is addressed transposed, (HEAD_DIM, BLOCK_N), so query @ key needs no
    # explicit transpose.
    K_block_ptr = (
        K
        + kv_offset
        + offs_n[None, :] * stride_k_seqlen
        + offs_headsize[:, None] * stride_k_headsize
    )
    V_block_ptr = (
        V
        + kv_offset
        + offs_n[:, None] * stride_v_seqlen
        + offs_headsize[None, :] * stride_v_headsize
    )

    if HAS_ATTN_MASK:
        attn_mask_offset = (
            batch_id.to(tl.int64) * stride_attn_mask_batch
            + head_id.to(tl.int64) * stride_attn_mask_head
        )
        mask_block_ptr = (
            attn_mask
            + attn_mask_offset
            + offs_m[:, None] * stride_attn_mask_q_seqlen
            + offs_n[None, :] * stride_attn_mask_kv_seqlen
        )
    else:
        mask_block_ptr = None

    O_block_ptr = (
        Out
        + o_offset
        + offs_m[:, None] * stride_o_seqlen
        + offs_headsize[None, :] * stride_o_headsize
    )

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    # l_i starts at 1.0; the first inner iteration multiplies it by
    # alpha = exp2(m_i - m_ij) with m_i == -inf, i.e. by 0.
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    # load scales
    qk_scale = sm_scale
    # qk_scale *= 1.44269504  # 1/log(2)
    # load query: it will stay in SRAM throughout
    query = tl.load(Q_block_ptr, mask=q_load_mask[:, None], other=0.0)
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            query,
            K_block_ptr,
            V_block_ptr,
            mask_block_ptr,
            stride_k_seqlen,
            stride_v_seqlen,
            stride_attn_mask_kv_seqlen,
            start_m,
            qk_scale,
            q_load_mask,
            BLOCK_M,
            HEAD_DIM,
            BLOCK_N,
            4 - STAGE,
            offs_m,
            offs_n,
            KV_CTX,
            V.dtype.element_ty == tl.float8e5,
            HAS_ATTN_MASK,
            PRE_LOAD_V,
        )
    # stage 2: on-band
    if STAGE & 2:
        # barrier makes it easier for compiler to schedule the
        # two loops independently
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            query,
            K_block_ptr,
            V_block_ptr,
            mask_block_ptr,
            stride_k_seqlen,
            stride_v_seqlen,
            stride_attn_mask_kv_seqlen,
            start_m,
            qk_scale,
            q_load_mask,
            BLOCK_M,
            HEAD_DIM,
            BLOCK_N,
            2,
            offs_m,
            offs_n,
            KV_CTX,
            V.dtype.element_ty == tl.float8e5,
            HAS_ATTN_MASK,
            PRE_LOAD_V,
        )
    # epilogue
    # M stores the row-wise logsumexp in the log2 domain: max + log2(sum).
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    # M is indexed as a dense (batch * heads, Q_CTX) array.
    m_ptrs = M + off_hz * Q_CTX + offs_m
    tl.store(m_ptrs, m_i, mask=q_load_mask)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=q_load_mask[:, None])
@triton.jit
def _attn_bwd_preprocess(
    O, DO, Delta, Z, H, Q_CTX, BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr
):
    # Backward pre-pass: Delta[hz, m] = sum over d of O[hz, m, d] * DO[hz, m, d].
    #
    # Program axis 0 picks a BLOCK_M tile of query rows, axis 1 a flattened
    # (batch, head) index. O and DO are addressed as dense
    # (batch * heads, Q_CTX, D_HEAD) arrays. Z and H are accepted but unused.
    head_idx = tl.program_id(1)
    rows = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tl.arange(0, D_HEAD)
    # Rows beyond Q_CTX belong to the ragged last tile and are masked out.
    row_in_range = rows < Q_CTX

    o_ptrs = O + head_idx * D_HEAD * Q_CTX + rows[:, None] * D_HEAD + cols[None, :]
    do_ptrs = DO + head_idx * D_HEAD * Q_CTX + rows[:, None] * D_HEAD + cols[None, :]

    out_tile = tl.load(o_ptrs, mask=row_in_range[:, None], other=0.0)
    grad_tile = tl.load(do_ptrs, mask=row_in_range[:, None], other=0.0).to(tl.float32)

    # Row-wise dot product of the forward output with its incoming gradient.
    row_dot = tl.sum(out_tile * grad_tile, axis=1)
    tl.store(Delta + head_idx * Q_CTX + rows, row_dot, mask=row_in_range)
359# The main inner-loop logic for computing dK and dV.
@triton.jit
def _attn_bwd_dkdv(
    dk,
    dv,  #
    Q,
    key,
    value,
    sm_scale,  #
    DO,  #
    M,
    D,  #
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,  #
    H,
    Q_CTX,
    KV_CTX,
    BLOCK_M1: tl.constexpr,  #
    BLOCK_N1: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,  #
    # Filled in by the wrapper.
    start_n,
    start_m,
    num_steps,  #
    MASK: tl.constexpr,
):
    # Accumulates dK and dV for one (BLOCK_N1, BLOCK_DMODEL) key/value tile.
    #
    # `key` / `value` are the already-loaded tiles for rows
    # [start_n, start_n + BLOCK_N1). The loop walks `num_steps` query tiles of
    # BLOCK_M1 rows starting at `start_m`, recomputing the transposed
    # probabilities from Q, key and the saved row statistics M.
    # MASK=True additionally applies the causal (query >= key) mask.
    # Returns the updated (dk, dv) accumulators.
    # BLOCK_M1: 32
    # BLOCK_N1: 128
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_n_mask = offs_n < KV_CTX  # (BLOCK_N1, )

    offs_k = tl.arange(0, BLOCK_DMODEL)  # (BLOCK_DMODEL, )

    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    curr_m = start_m
    step_m = BLOCK_M1
    for blk_idx in range(num_steps):
        offs_m = curr_m + tl.arange(0, BLOCK_M1)  # (BLOCK_M1, )
        offs_m_mask = offs_m < Q_CTX  # (BLOCK_M1, )

        qT_ptrs = (
            Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
        )  # (BLOCK_DMODEL, BLOCK_M1)
        do_ptrs = (
            DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
        )  # (BLOCK_M1, BLOCK_DMODEL)

        qT = tl.load(
            qT_ptrs, mask=offs_m_mask[None, :], other=0.0
        )  # (BLOCK_DMODEL, BLOCK_M1)

        # Load m before computing qk to reduce pipeline stall.
        # Out-of-range rows get +inf so that exp2(qkT - m) evaluates to 0.
        m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf"))  # (BLOCK_M1, )

        # key: (BLOCK_N1, BLOCK_DMODEL)
        qkT = tl.dot(key, qT)  # (BLOCK_N1, BLOCK_M1)
        m = tl.broadcast_to(m[None, :], (BLOCK_N1, BLOCK_M1))  # (BLOCK_N1, BLOCK_M1)
        m = tl.where(offs_n_mask[:, None], m, float("inf"))  # (BLOCK_N1, BLOCK_M1)
        pT = tl.math.exp2(qkT - m)
        # pT = tl.math.exp2(qkT - m[None, :])

        mask = (offs_m < Q_CTX)[None, :] & (offs_n < KV_CTX)[
            :, None
        ]  # (BLOCK_N1, BLOCK_M1)
        # Autoregressive masking.
        if MASK:
            mask &= offs_m[None, :] >= offs_n[:, None]
        pT = tl.where(mask, pT, 0.0)  # (BLOCK_N1, BLOCK_M1)

        # NOTE(review): this load is unmasked (the masked variant is commented
        # out below), so rows with offs_m >= Q_CTX are read from DO anyway.
        # pT is zero for those rows, but the read itself may be out of bounds
        # and 0 * NaN would still poison the dot products — confirm intent.
        do = tl.load(do_ptrs)
        # do = tl.load(do_ptrs, mask=offs_m_mask[:, None], other=0.0)  # (BLOCK_M1, BLOCK_DMODEL)

        # Compute dV.
        dv += tl.dot(pT, do.to(tl.float32))  # (BLOCK_N1, BLOCK_DMODEL)
        # D (= delta) is pre-divided by ds_scale.
        Di = tl.load(D + offs_m, mask=offs_m_mask, other=0.0)  # (BLOCK_M1, )

        # Compute dP and dS.
        dpT = tl.dot(value, tl.trans(do)).to(
            tl.float32
        )  # (BLOCK_N1, BLOCK_DMODEL) @ (BLOCK_M1, BLOCK_DMODEL).T -> (BLOCK_N1, BLOCK_M1)
        dsT = pT * (dpT - Di[None, :])  # (BLOCK_N1, BLOCK_M1)
        dsT = dsT.to(qT.dtype)
        # Re-zero padded rows/columns before the final dot.
        qT = tl.where(offs_m_mask[None, :], qT, 0.0)  # (BLOCK_DMODEL, BLOCK_M1)
        dsT = tl.where(
            offs_m_mask[None, :] & offs_n_mask[:, None], dsT, 0.0
        )  # (BLOCK_N1, BLOCK_M1)
        dk += tl.dot(
            dsT, tl.trans(qT)
        )  # (BLOCK_N1, BLOCK_M1) @ (BLOCK_DMODEL, BLOCK_M1).T -> (BLOCK_N1, BLOCK_DMODEL)
        # Increment pointers.
        curr_m += step_m
    return dk, dv
456# the main inner-loop logic for computing dQ
@triton.jit
def _attn_bwd_dq(
    dq,
    query,
    K,
    V,  #
    do,
    m,
    D,
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,  #
    H,
    Q_CTX,  #
    KV_CTX,  #
    BLOCK_M2: tl.constexpr,  #
    BLOCK_N2: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,
    # Filled in by the wrapper.
    start_m,
    start_n,
    num_steps,  #
    MASK: tl.constexpr,
):
    # Accumulates dQ for one (BLOCK_M2, BLOCK_DMODEL) query tile.
    #
    # `query`, `do` and `m` are the already-loaded tiles for rows
    # [start_m, start_m + BLOCK_M2). The loop walks `num_steps` key/value
    # tiles of BLOCK_N2 rows starting at `start_n`. MASK=True additionally
    # applies the causal (query >= key) mask. Returns the updated dq.
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_m_mask = offs_m < Q_CTX

    offs_k = tl.arange(0, BLOCK_DMODEL)
    # D (= delta) is pre-divided by ds_scale.
    Di = tl.load(D + offs_m, mask=offs_m_mask, other=0.0)
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
    curr_n = start_n
    step_n = BLOCK_N2
    for blk_idx in range(num_steps):
        offs_n = curr_n + tl.arange(0, BLOCK_N2)
        offs_n_mask = offs_n < KV_CTX

        # K and V are addressed transposed: (BLOCK_DMODEL, BLOCK_N2).
        kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
        vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d

        kT = tl.load(kT_ptrs, mask=offs_n_mask[None, :], other=0.0)
        vT = tl.load(vT_ptrs, mask=offs_n_mask[None, :], other=0.0)
        qk = tl.dot(query, kT)
        # m holds the saved log2-domain row statistic, broadcast over columns.
        p = tl.math.exp2(qk - m)
        mask = (offs_m < Q_CTX)[:, None] & (offs_n < KV_CTX)[None, :]
        # Autoregressive masking.
        if MASK:
            # mask = (offs_m[:, None] >= offs_n[None, :])
            # mask = (offs_m[:, None] >= offs_n[None, :]) & (offs_m < N_CTX)[:, None] & (offs_n < N_CTX)[None, :]
            mask &= offs_m[:, None] >= offs_n[None, :]
        p = tl.where(mask, p, 0.0)
        # Compute dP and dS.
        dp = tl.dot(do, vT).to(tl.float32)
        ds = p * (dp - Di[:, None])
        ds = tl.where(mask, ds, 0.0).to(kT.dtype)
        # Compute dQ.
        # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
        dq += tl.dot(ds, tl.trans(kT))
        # Increment pointers.
        curr_n += step_n
    return dq
# Launch configurations for the backward kernel, looked up in the runtime
# under the "attention_bwd" key. Also read by the host wrapper to size the
# launch grid (largest BLOCK_N1 among the configs).
config_backward = runtime.get_tuned_config("attention_bwd")
@libentry()
@libtuner(
    configs=config_backward,
    key=["KV_CTX", "BLOCK_DMODEL"],
)
@triton.jit
def _attn_bwd(
    Q,
    K,
    V,
    sm_scale,  #
    DO,  #
    DQ,
    DK,
    DV,  #
    M,
    D,
    # shared by Q/K/V/DO.
    stride_z,
    stride_h,
    stride_tok,
    stride_d,  #
    kv_stride_z,
    kv_stride_h,  #
    H,  # query head num
    Q_CTX,  #
    KV_CTX,  #
    kv_head_num,  #
    GROUP_HEAD: tl.constexpr,  #
    BLOCK_M1: tl.constexpr,  #
    BLOCK_N1: tl.constexpr,  #
    BLOCK_M2: tl.constexpr,  #
    BLOCK_N2: tl.constexpr,  #
    BLK_SLICE_FACTOR: tl.constexpr,  #
    BLOCK_DMODEL: tl.constexpr,
):
    # Fused attention backward kernel.
    #
    # Program axis 0 (pid) selects one key/value tile of BLOCK_N1 rows for the
    # dK/dV half and one query tile of BLOCK_M2 rows for the dQ half; axis 2
    # selects a flattened (batch, query head) pair.
    #
    # NOTE(review): the diagonal tiles are always processed with MASK=True and
    # there is no is_causal input, so causal masking appears to be applied
    # unconditionally — confirm this matches the non-causal forward path.
    tl.device_assert(Q_CTX % BLOCK_M1 == 0, "Q_CTX must be a multiple of BLOCK_M1.")

    LN2: tl.constexpr = 0.6931471824645996  # = ln(2)

    bhid = tl.program_id(2)
    off_chz = (bhid * Q_CTX).to(tl.int64)
    batch_id = bhid // H
    q_head_id = bhid % H
    kv_head_id = q_head_id // GROUP_HEAD
    adj = (stride_h * q_head_id + stride_z * batch_id).to(tl.int64)
    kv_adj = (kv_stride_h * kv_head_id + kv_stride_z * batch_id).to(tl.int64)

    pid = tl.program_id(0)

    # offset pointers for batch/head
    Q += adj
    K += kv_adj
    V += kv_adj
    DO += adj
    DQ += adj
    # NOTE(review): DK/DV are offset with the query batch/head strides. The
    # host allocates them as (BATCH, Q_HEAD, KV_CTX, HEAD_DIM), so this only
    # lines up when Q_CTX == KV_CTX — confirm.
    DK += adj
    DV += adj
    M += off_chz
    D += off_chz

    # load scales
    offs_k = tl.arange(0, BLOCK_DMODEL)

    start_n = pid * BLOCK_N1
    start_m = start_n

    # The diagonal (masked) region is walked in finer steps than the rest.
    MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_n_mask = offs_n < KV_CTX

    dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)

    # load K and V: they stay in SRAM throughout the inner loop.
    key = tl.load(
        K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_n_mask[:, None],
        other=0.0,
    )
    value = tl.load(
        V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_n_mask[:, None],
        other=0.0,
    )

    num_steps = BLOCK_N1 // MASK_BLOCK_M1

    # Masked (diagonal) query tiles for this key/value tile.
    dk, dv = _attn_bwd_dkdv(
        dk,
        dv,  #
        Q,
        key,
        value,
        sm_scale,  #
        DO,  #
        M,
        D,  #
        stride_tok,
        stride_d,  #
        H,
        Q_CTX,  #
        KV_CTX,  #
        MASK_BLOCK_M1,
        BLOCK_N1,
        BLOCK_DMODEL,  #
        start_n,
        start_m,
        num_steps,  #
        MASK=True,  #
    )

    # Compute dK and dV for non-masked blocks.
    start_m += num_steps * MASK_BLOCK_M1
    remaining_m = Q_CTX - start_m
    num_steps = (remaining_m + BLOCK_M1 - 1) // BLOCK_M1

    if num_steps > 0 and start_m < Q_CTX:
        dk, dv = _attn_bwd_dkdv(  #
            dk,
            dv,  #
            Q,
            key,
            value,
            sm_scale,  #
            DO,  #
            M,
            D,  #
            stride_tok,
            stride_d,  #
            H,
            Q_CTX,  #
            KV_CTX,  #
            BLOCK_M1,
            BLOCK_N1,
            BLOCK_DMODEL,  #
            start_n,
            start_m,
            num_steps,  #
            MASK=False,  #
        )
    # tl.device_print("dv: ", dv)

    dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dv_ptrs, dv, mask=offs_n_mask[:, None])

    # Write back dK.
    dk *= sm_scale
    dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dk_ptrs, dk, mask=offs_n_mask[:, None])

    # THIS BLOCK DOES DQ:
    MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
    start_m = pid * BLOCK_M2
    end_n = min(start_m + BLOCK_M2, KV_CTX)  # Ensure end_n does not exceed KV_CTX
    # NOTE(review): start_n is still pid * BLOCK_N1 here, so the masked range
    # only starts at the query tile's first row when BLOCK_N1 == BLOCK_M2 — confirm.
    num_steps = (end_n - start_n + MASK_BLOCK_N2 - 1) // MASK_BLOCK_N2

    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_m_mask = offs_m < Q_CTX

    query = tl.load(
        Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_m_mask[:, None],
        other=0.0,
    )
    dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
    do = tl.load(
        DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d,
        mask=offs_m_mask[:, None],
        other=0.0,
    )

    # Out-of-range rows get +inf so that exp2(qk - m) evaluates to 0 for them.
    m = tl.load(M + offs_m, mask=offs_m_mask, other=float("inf"))
    m = m[:, None]

    # Stage 1 - Compute dQ for masked (diagonal) blocks.
    # NOTE: This code scans each row of QK^T backward (from right to left,
    # but inside each call to _attn_bwd_dq, from left to right), but that's
    # not due to anything important.  I just wanted to reuse the loop
    # structure for dK & dV above as much as possible.

    if num_steps > 0:
        dq = _attn_bwd_dq(
            dq,
            query,
            K,
            V,  #
            do,
            m,
            D,  #
            stride_tok,
            stride_d,  #
            H,
            Q_CTX,  #
            KV_CTX,  #
            BLOCK_M2,
            MASK_BLOCK_N2,
            BLOCK_DMODEL,  #
            start_m,
            start_n,
            num_steps,  #
            MASK=True,  #
        )

    # Stage 2 - non-masked blocks
    # Covers the keys before the diagonal: stage2_num_steps tiles ending at start_n.
    stage2_end_n = start_n
    stage2_num_steps = (stage2_end_n + BLOCK_N2 - 1) // BLOCK_N2

    if stage2_num_steps > 0:
        dq = _attn_bwd_dq(
            dq,
            query,
            K,
            V,  #
            do,
            m,
            D,  #
            stride_tok,
            stride_d,  #
            H,
            Q_CTX,  #
            KV_CTX,  #
            BLOCK_M2,
            BLOCK_N2,
            BLOCK_DMODEL,  #
            start_m,
            stage2_end_n - stage2_num_steps * BLOCK_N2,
            stage2_num_steps,  #
            MASK=False,  #
        )
    # Write back dQ.
    dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    # The host pre-scales K by sm_scale / ln(2); multiplying by ln(2) undoes
    # the 1/ln(2) part of that factor.
    dq *= LN2
    # tl.store(dq_ptrs, dq)

    tl.store(dq_ptrs, dq, mask=offs_m_mask[:, None])
def scaled_dot_product_attention_forward(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Run the fused attention forward kernel.

    Args:
        query, key, value: 4-D tensors indexed as (batch, heads, seq, head_dim);
            all three must share a head_dim in {16, 32, 64, 128, 256}.
        attn_mask: optional 4-D mask. A floating-point mask is added to the
            scaled scores. A boolean mask follows the PyTorch convention:
            ``True`` means the position takes part in attention, ``False``
            positions receive a large negative bias.
        dropout_p: must be 0.0 (dropout is not implemented).
        is_causal: apply the causal mask when True.
        scale: softmax scale; defaults to ``1 / sqrt(head_dim)`` when None.
        enable_gqa: must be True when query and key/value head counts differ.

    Returns:
        ``(o, M)``: the attention output (shaped like ``query``, dtype of
        ``value``) and the float32 per-row log2-domain logsumexp of shape
        (batch, heads, q_seq).
    """
    logger.debug("GEMS_TSINGMICRO SCALED_DOT_PRODUCT_ATTENTION")
    # shape constraints
    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    # when v is in float8_e5m2 it is transposed.
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
    assert dropout_p == 0.0, "Currenty only support dropout_p=0.0"

    o = torch.empty_like(query, dtype=value.dtype)

    # Kernel STAGE bit mask: 3 = off-band + diagonal (causal), 1 = full range.
    stage = 3 if is_causal else 1

    if scale is None:
        sm_scale = 1.0 / (HEAD_DIM_K**0.5)
    else:
        sm_scale = scale

    q_head_num = query.shape[1]
    kv_head_num = key.shape[1]
    assert enable_gqa or q_head_num == kv_head_num, (
        f"q_head_num {q_head_num} != kv_head_num {kv_head_num}, "
        "enable_gqa must be True to support different head numbers."
    )

    # One program per (query tile, batch * head).
    grid = lambda args: (
        triton.cdiv(query.shape[2], args["BLOCK_M"]),
        query.shape[0] * query.shape[1],
        1,
    )

    if attn_mask is not None:
        HAS_ATTN_MASK = True
        if attn_mask.dtype == torch.bool:
            # PyTorch boolean masks mark the positions that MAY be attended
            # to, so the additive penalty belongs on the False entries.
            attn_mask = attn_mask.logical_not().to(query.dtype) * -1.0e6
        stride_attn_mask_batch = attn_mask.stride(0)
        stride_attn_mask_head = attn_mask.stride(1)
        stride_attn_mask_q_seqlen = attn_mask.stride(2)
        stride_attn_mask_kv_seqlen = attn_mask.stride(3)
    else:
        # Placeholder strides; the kernel never reads them without a mask.
        HAS_ATTN_MASK = False
        stride_attn_mask_batch = 1
        stride_attn_mask_head = 1
        stride_attn_mask_q_seqlen = 1
        stride_attn_mask_kv_seqlen = 1

    M = torch.empty(
        (query.shape[0], query.shape[1], query.shape[2]),
        device=query.device,
        dtype=torch.float32,
    )

    with torch_device_fn.device(query.device):
        _attn_fwd[grid](
            query,
            key,
            value,
            attn_mask,
            sm_scale,
            M,
            o,  #
            query.stride(0),
            query.stride(1),
            query.stride(2),
            query.stride(3),  #
            key.stride(0),
            key.stride(1),
            key.stride(2),
            key.stride(3),  #
            value.stride(0),
            value.stride(1),
            value.stride(2),
            value.stride(3),  #
            stride_attn_mask_batch,
            stride_attn_mask_head,
            stride_attn_mask_q_seqlen,
            stride_attn_mask_kv_seqlen,  #
            o.stride(0),
            o.stride(1),
            o.stride(2),
            o.stride(3),  #
            query.shape[0],
            q_head_num,
            kv_head_num,  #
            q_head_num // kv_head_num,  # group_head
            query.shape[2],  #
            key.shape[2],  #
            HEAD_DIM_K,  #
            STAGE=stage,  #
            HAS_ATTN_MASK=HAS_ATTN_MASK,  #
        )
    return o, M
def scaled_dot_product_attention_backward(
    do,
    query,
    key,
    value,
    o,
    M,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Compute dQ, dK and dV for the fused attention forward pass.

    Args:
        do: gradient of the attention output; contiguous, same strides as
            ``query`` and ``o``.
        query, key, value, o: contiguous forward tensors / output.
        M: per-row log2-domain logsumexp saved by the forward pass.
        attn_mask, is_causal, enable_gqa: accepted for signature parity.
            NOTE(review): none of them is forwarded to the kernels — confirm
            the backward matches the forward when they are set.
        dropout_p: must be 0.0.
        scale: softmax scale; defaults to ``1 / sqrt(head_dim)`` when None.

    Returns:
        ``(dq, dk, dv)``. When query has more heads than key/value, dk and dv
        are summed over each group of query heads sharing a KV head.
    """
    logger.debug("GEMS_TSINGMICRO SCALED_DOT_PRODUCT_ATTENTION_BACKWARD")
    # shape constraints
    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    # when v is in float8_e5m2 it is transposed.
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
    assert dropout_p == 0.0, "Currenty only support dropout_p=0.0"

    if scale is None:
        sm_scale = 1.0 / (HEAD_DIM_K**0.5)
    else:
        sm_scale = scale

    # The kernels share one set of strides between Q / O / dO (and K / V).
    assert do.is_contiguous()
    assert (
        query.is_contiguous()
        and key.is_contiguous()
        and value.is_contiguous()
        and o.is_contiguous()
    )
    assert query.stride() == o.stride() == do.stride()
    assert key.stride() == value.stride()

    BLOCK_DMODEL = HEAD_DIM_K
    BATCH, Q_HEAD, Q_CTX = query.shape[:3]
    _, KV_HEAD, KV_CTX = key.shape[:3]
    group_head = Q_HEAD // KV_HEAD

    BLK_SLICE_FACTOR = 2
    RCP_LN2 = 1.0 / math.log(2)

    # Pre-scale K by sm_scale / ln(2) so the kernel can recompute the
    # probabilities with exp2; the kernel undoes this on dq and dk.
    arg_k = key * (sm_scale * RCP_LN2)
    PRE_BLOCK = 256
    pre_grid = (triton.cdiv(Q_CTX, PRE_BLOCK), BATCH * Q_HEAD)

    delta = torch.empty_like(M)

    # NOTE that dk & dv always have the same number of heads as q
    # (query is asserted contiguous above, so empty_like is contiguous too).
    dq = torch.empty_like(query)
    dk = torch.empty(
        (BATCH, Q_HEAD, KV_CTX, HEAD_DIM_K),
        device=key.device,
        dtype=key.dtype,
        memory_format=torch.contiguous_format,
    )
    dv = torch.empty(
        (BATCH, Q_HEAD, KV_CTX, HEAD_DIM_V),
        device=value.device,
        dtype=value.dtype,
        memory_format=torch.contiguous_format,
    )

    # NOTE(review): the grid is sized from the largest BLOCK_N1 among the
    # tuning configs and from Q_CTX only; if the tuner picks a smaller block,
    # or KV_CTX > Q_CTX, trailing tiles would not be visited — confirm.
    max_block_n1 = (
        max(cfg.kwargs["BLOCK_N1"] for cfg in config_backward)
        if config_backward
        else 128
    )
    grid = (triton.cdiv(Q_CTX, max_block_n1), 1, BATCH * Q_HEAD)

    # Launch on the tensors' device, mirroring the forward wrapper; otherwise
    # the kernels would run on whatever device happens to be current.
    with torch_device_fn.device(query.device):
        _attn_bwd_preprocess[pre_grid](
            o,
            do,  #
            delta,  #
            BATCH,
            Q_HEAD,
            Q_CTX,  #
            BLOCK_M=PRE_BLOCK,
            D_HEAD=BLOCK_DMODEL,  #
        )

        _attn_bwd[grid](
            query,
            arg_k,
            value,
            sm_scale,
            do,
            dq,
            dk,
            dv,  #
            M,
            delta,  #
            query.stride(0),
            query.stride(1),
            query.stride(2),
            query.stride(3),  #
            key.stride(0),
            key.stride(1),  #
            Q_HEAD,
            Q_CTX,  #
            KV_CTX,  #
            KV_HEAD,  #
            GROUP_HEAD=group_head,  #
            BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
            BLOCK_DMODEL=BLOCK_DMODEL,  #
        )

    if group_head > 1:
        # Fold the per-query-head gradients back onto the shared KV heads.
        dk = dk.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_K)
        dv = dv.reshape(BATCH, Q_HEAD // group_head, group_head, KV_CTX, HEAD_DIM_V)
        dk = dk.sum(dim=2)
        dv = dv.sum(dim=2)

    return dq, dk, dv
class ScaleDotProductAttention(torch.autograd.Function):
    """Autograd wrapper tying the fused forward and backward together.

    NOTE(review): ``attn_mask`` is used in forward but is not saved on the
    context; backward always runs with ``attn_mask=None`` — confirm that
    gradients with a mask are intended to be computed this way.
    """

    @staticmethod
    def forward(
        ctx,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=None,
        enable_gqa=False,
    ):
        # Resolve the default scale here so backward reuses the same value.
        if scale is not None:
            softmax_scale = scale
        else:
            softmax_scale = 1.0 / (key.shape[-1] ** 0.5)

        out, row_stats = scaled_dot_product_attention_forward(
            query,
            key,
            value,
            attn_mask,
            dropout_p,
            is_causal,
            softmax_scale,
            enable_gqa,
        )

        ctx.save_for_backward(query, key, value, out, row_stats)
        ctx.sm_scale = softmax_scale
        ctx.causal = is_causal
        ctx.enable_gqa = enable_gqa
        return out

    @staticmethod
    def backward(ctx, do):
        query, key, value, out, row_stats = ctx.saved_tensors
        grads = scaled_dot_product_attention_backward(
            do,
            query,
            key,
            value,
            out,
            row_stats,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=ctx.causal,
            scale=ctx.sm_scale,
            enable_gqa=ctx.enable_gqa,
        )
        dq, dk, dv = grads
        # One gradient slot per forward input; only q/k/v are differentiable.
        return dq, dk, dv, None, None, None, None, None
def scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Differentiable attention entry point.

    Forwards every argument, positionally and in declaration order, to
    ``ScaleDotProductAttention.apply`` and returns its result.
    """
    positional = (
        query,
        key,
        value,
        attn_mask,
        dropout_p,
        is_causal,
        scale,
        enable_gqa,
    )
    return ScaleDotProductAttention.apply(*positional)
def flash_attention_forward(
    query,
    key,
    value,
    cumulative_sequence_length_q,
    cumulative_sequence_length_k,
    max_q,
    max_k,
    dropout_p,
    is_causal,
    return_debug_mask,
    *,
    scale=None,
    softcap=0.0,
    window_size_left=None,
    window_size_right=None,
    seqused_k=None,
    alibi_slopes=None,
    disable_splitkv=False,
):
    """Flash-attention forward pass.

    Head dims outside (16, 32, 64, 96, 128, 192, 256) are zero-padded up to
    the next supported size before the kernel call and the output is sliced
    back afterwards; the softmax scale is always derived from the ORIGINAL
    head dim.

    Args:
        query, key, value: attention inputs; the last dimension is head_dim.
        cumulative_sequence_length_q / _k: must both be None (varlen is
            rejected by the assertion below).
        scale: softmax scale; ``None`` selects ``1 / sqrt(head_dim)``. An
            explicit value, including 0.0, is used as given.
        window_size_left / window_size_right: ``None`` is translated to -1.

    Returns:
        ``(out, lse, philox_seed, philox_offset, p)`` as produced by the
        underlying ``mha_fwd`` / ``mha_varlan_fwd`` call.
    """
    logger.debug("GEMS_TSINGMICRO FLASH_ATTENTION_FORWARD")
    assert (
        cumulative_sequence_length_q is None and cumulative_sequence_length_k is None
    ), "varlen is not supported yet."

    HEAD_DIM_Q, HEAD_DIM_K = query.shape[-1], key.shape[-1]
    HEAD_DIM_V = value.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    original_head_dim = HEAD_DIM_K
    supported_head_dims = (16, 32, 64, 96, 128, 192, 256)
    if HEAD_DIM_K not in supported_head_dims:
        # Smallest supported size that can hold the requested head dim.
        padded_head_dim = next(
            (d for d in supported_head_dims if d >= HEAD_DIM_K), None
        )
        assert (
            padded_head_dim is not None
        ), f"Unsupported head dim {HEAD_DIM_K}, max supported is {supported_head_dims[-1]}"
        pad = padded_head_dim - HEAD_DIM_K
        query = F.pad(query, (0, pad))
        key = F.pad(key, (0, pad))
        value = F.pad(value, (0, pad))
        HEAD_DIM_K = padded_head_dim

    # `is not None` rather than truthiness: an explicit scale of 0.0 must not
    # silently fall back to the default.
    if scale is not None:
        softmax_scale = scale
    else:
        softmax_scale = 1.0 / (original_head_dim**0.5)
    if window_size_left is not None:
        non_null_window_left = window_size_left
    else:
        non_null_window_left = -1
    if window_size_right is not None:
        non_null_window_right = window_size_right
    else:
        non_null_window_right = -1

    out = torch.empty_like(query)
    if cumulative_sequence_length_q is not None:
        # Only reachable when assertions are disabled (python -O).
        out, q, k, v, lse, philox_seed, philox_offset, p = mha_varlan_fwd(
            query,
            key,
            value,
            out,
            cumulative_sequence_length_q,
            cumulative_sequence_length_k,
            seqused_k,
            None,
            None,  # block_table
            alibi_slopes,
            max_q,
            max_k,
            dropout_p,
            # Resolved scale: never None and based on the unpadded head dim.
            softmax_scale,
            False,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask and dropout_p > 0,
            None,
        )
    else:
        out, q, k, v, lse, philox_seed, philox_offset, p = mha_fwd(
            query,
            key,
            value,
            out,
            alibi_slopes,
            dropout_p,
            softmax_scale,
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask,
            disable_splitkv=disable_splitkv,
        )

    if HEAD_DIM_K != original_head_dim:
        # Drop the zero-padded tail of the head dimension.
        out = out[..., :original_head_dim]
    return (out, lse, philox_seed, philox_offset, p)
1179# Adapted from https://github.com/vllm-project/flash-attention/blob/main/vllm_flash_attn/flash_attn_interface.py
def maybe_contiguous(x):
    """Return ``x`` with a unit stride in its last dimension.

    ``None`` and tensors whose last-dimension stride is already 1 are
    returned unchanged; anything else is replaced by ``x.contiguous()``.
    """
    if x is None or x.stride(-1) == 1:
        return x
    return x.contiguous()
1184def flash_attn_varlen_func(
1185 q,
1186 k,
1187 v,
1188 max_seqlen_q,
1189 cu_seqlens_q,
1190 max_seqlen_k,
1191 cu_seqlens_k=None, # only used for non-paged prefill
1192 seqused_k=None,
1193 q_v=None,
1194 dropout_p=0.0,
1195 softmax_scale=None,
1196 causal=False,
1197 window_size=None,
1198 softcap=0.0, # 0.0 means deactivated
1199 alibi_slopes=None,
1200 deterministic=False,
1201 return_attn_probs=False,
1202 block_table=None,
1203 return_softmax_lse=False,
1204 out=None,
1205 # Dummy FA3 arguments
1206 scheduler_metadata=None,
1207 q_descale=None,
1208 k_descale=None,
1209 v_descale=None,
1210 s_aux=None,
1211 num_splits: int = 0,
1212 cp_world_size: int = 1,
1213 cp_rank: int = 0,
1214 cp_tot_seqused_k=None,
1215 fa_version: int = 2,
1216):
1217 """dropout_p should be set to 0.0 during evaluation
1218 Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
1219 than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
1220 For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
1221 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
1223 If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
1224 For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1225 1 1 1 1 0
1226 1 1 1 1 1
1227 If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
1228 0 0
1229 0 0
1230 0 0
1231 1 0
1232 1 1
1233 If the row of the mask is all zero, the output will be zero.
1235 If window_size != (-1, -1), implements sliding window local attention. Query at position i
1236 will only attend to keys between
1237 [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
1239 Arguments:
1240 q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
1241 k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
1242 v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
1243 cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1244 of the sequences in the batch, used to index into q.
1245 cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
1246 of the sequences in the batch, used to index into kv.
1247 max_seqlen_q: int. Maximum query sequence length in the batch.
1248 max_seqlen_k: int. Maximum key sequence length in the batch.
1249 dropout_p: float. Dropout probability.
1250 softmax_scale: float. The scaling of QK^T before applying softmax.
1251 Default to 1 / sqrt(headdim).
1252 causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
1253 window_size: (left, right). If not (-1, -1), implements sliding window local attention.
1254 softcap: float. Anything > 0 activates softcapping attention.
1255 alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
1256 (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
1257 is added to the attention score of query i and key j.
1258 deterministic: bool. Whether to use the deterministic implementation of the backward pass,
1259 which is slightly slower and uses more memory. The forward pass is always deterministic.
1260 return_attn_probs: bool. Whether to return the attention probabilities. This option is for
1261 testing only. The returned probabilities are not guaranteed to be correct
1262 (they might not have the right scaling).
1263 Return:
1264 out: (total, nheads, headdim).
1265 softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
1266 logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
1267 normalization factor).
1268 """
1269 if fa_version != 2:
1270 raise RuntimeError("Only FA2 is implemented.")
1271 if num_splits > 0:
1272 raise RuntimeError("num_splits > 0 is not implemented in GEMS.")
1273 if use_c_extension:
1274 logger.debug("GEMS_TSINGMICRO FLASH_ATTN_VARLEN_FUNC (C Extension)")
1275 with torch_device_fn.device(q.device):
1276 out_cpp, softmax_lse = torch.ops.flag_gems.flash_attn_varlen_func(
1277 q,
1278 k,
1279 v,
1280 max_seqlen_q,
1281 cu_seqlens_q,
1282 max_seqlen_k,
1283 cu_seqlens_k,
1284 seqused_k,
1285 q_v,
1286 dropout_p,
1287 softmax_scale,
1288 causal,
1289 window_size,
1290 softcap,
1291 alibi_slopes,
1292 deterministic,
1293 return_attn_probs,
1294 block_table,
1295 return_softmax_lse,
1296 out,
1297 scheduler_metadata,
1298 q_descale,
1299 k_descale,
1300 v_descale,
1301 s_aux,
1302 num_splits,
1303 cp_world_size,
1304 cp_rank,
1305 cp_tot_seqused_k,
1306 fa_version,
1307 )
1308 return (out_cpp, softmax_lse) if return_softmax_lse else out_cpp
1309 else:
1310 logger.debug("GEMS_TSINGMICRO FLASH_ATTN_VARLEN_FUNC")
1311 assert (
1312 cu_seqlens_k is not None or seqused_k is not None
1313 ), "cu_seqlens_k or seqused_k must be provided"
1314 assert (
1315 cu_seqlens_k is None or seqused_k is None
1316 ), "cu_seqlens_k and seqused_k cannot be provided at the same time"
1317 assert (
1318 block_table is None or seqused_k is not None
1319 ), "seqused_k must be provided if block_table is provided"
1320 if softmax_scale is None:
1321 softmax_scale = q.shape[-1] ** (-0.5)
1322 # custom op does not support non-tuple input
1323 if window_size is None:
1324 real_window_size = (-1, -1)
1325 else:
1326 assert len(window_size) == 2
1327 real_window_size = (window_size[0], window_size[1])
1328 q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
1329 dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q)
1330 max_seqlen_q = (
1331 max_seqlen_q.item() if hasattr(max_seqlen_q, "item") else max_seqlen_q
1332 )
1333 max_seqlen_k = (
1334 max_seqlen_k.item() if hasattr(max_seqlen_k, "item") else max_seqlen_k
1335 )
1336 out, q, k, v, softmax_lse, *_ = mha_varlan_fwd(
1337 q,
1338 k,
1339 v,
1340 out,
1341 cu_seqlens_q,
1342 # cu_seqlens_k is not supplied when seqused_k is used, but the flash API call
1343 # still takes the argument, so we pass a placeholder (torch.empty_like: uninitialized, not zeros)
1344 dummy_cu_seqlens_k if cu_seqlens_k is None else cu_seqlens_k,
1345 seqused_k,
1346 None,
1347 block_table,
1348 alibi_slopes,
1349 max_seqlen_q,
1350 max_seqlen_k,
1351 dropout_p,
1352 softmax_scale,
1353 False,
1354 causal,
1355 real_window_size[0],
1356 real_window_size[1],
1357 softcap,
1358 return_softmax_lse and dropout_p > 0,
1359 None,
1360 )
1362 return (out, softmax_lse) if return_softmax_lse else out