Coverage for src/flag_gems/ops/flash_api.py: 86%
537 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import math
4import torch
5import triton
7import flag_gems
8from flag_gems import runtime
9from flag_gems.ops.flash_kernel import (
10 block_m_splitkv_heuristic,
11 block_n_splitkv_heuristic,
12 flash_fwd_kernel,
13 flash_fwd_splitkv_combine_kernel,
14 flash_fwd_splitkv_kernel,
15 flash_varlen_fwd_kernel,
16)
17from flag_gems.runtime import torch_device_fn
18from flag_gems.utils.random_utils import philox_backend_seed_offset
20logger = logging.getLogger(__name__)
21_debug = False
def CHECK_DEVICE(x):
    """Assert that ``x`` is placed on the device type flag_gems is configured for.

    Args:
        x: Object exposing ``.device.type`` (e.g. a ``torch.Tensor``).

    Raises:
        AssertionError: If ``x.device.type`` differs from ``flag_gems.device``.
    """
    actual_device_type = x.device.type
    expected_device_type = flag_gems.device
    assert actual_device_type == expected_device_type
class fwd_params:
    """Flat, slot-backed record of every argument handed to the forward kernels.

    The order of ``__slots__`` is significant: ``args()`` returns the stored
    values in exactly this order so they can be splatted positionally into a
    kernel launch. The constructor takes one parameter per slot, with the same
    name and in the same order.
    """

    __slots__ = (
        # pointers and strides
        "q_ptr",
        "k_ptr",
        "v_ptr",
        "o_ptr",
        "p_ptr",
        "softmax_lse_ptr",
        "q_row_stride",
        "k_row_stride",
        "v_row_stride",
        "q_head_stride",
        "k_head_stride",
        "v_head_stride",
        "o_row_stride",
        "o_head_stride",
        "q_batch_stride",
        "k_batch_stride",
        "v_batch_stride",
        "o_batch_stride",
        "is_cu_seqlens_q",
        "cu_seqlens_q_ptr",
        "is_cu_seqlens_k",
        "cu_seqlens_k_ptr",
        "is_seqused_k",
        "seqused_k_ptr",
        # sizes
        "b",
        "bk",
        "h",
        "hk",
        "h_hk_ratio",
        "seqlen_q",
        "seqlen_k",
        "seqlen_q_rounded",
        "seqlen_k_rounded",
        "d",
        "d_rounded",
        # scaling factors
        "is_softcap",
        "softcap",
        "scale_softmax",
        "scale_softmax_log2",
        # dropout
        "is_dropout",
        "p_dropout",
        "rp_dropout",
        "p_dropout_in_uint8_t",
        "philox_args",
        "return_softmax",
        # masking
        "is_causal",
        "is_local",
        "window_size_left",
        "window_size_right",
        "seqlenq_ngroups_swapped",
        "is_paged",
        # alibi
        "is_alibi",
        "alibi_slopes_ptr",
        "alibi_slopes_batch_stride",
        # block table
        "total_q",
        "page_table_ptr",
        "page_table_batch_stride",
        "block_size",
        "k_page_stride",
    )

    def __init__(
        self,
        q_ptr,
        k_ptr,
        v_ptr,
        o_ptr,
        p_ptr,
        softmax_lse_ptr,
        q_row_stride,
        k_row_stride,
        v_row_stride,
        q_head_stride,
        k_head_stride,
        v_head_stride,
        o_row_stride,
        o_head_stride,
        q_batch_stride,
        k_batch_stride,
        v_batch_stride,
        o_batch_stride,
        is_cu_seqlens_q,
        cu_seqlens_q_ptr,
        is_cu_seqlens_k,
        cu_seqlens_k_ptr,
        is_seqused_k,
        seqused_k_ptr,
        # sizes
        b,
        bk,
        h,
        hk,
        h_hk_ratio,
        seqlen_q,
        seqlen_k,
        seqlen_q_rounded,
        seqlen_k_rounded,
        d,
        d_rounded,
        # scaling factors
        is_softcap,
        softcap,
        scale_softmax,
        scale_softmax_log2,
        # dropout
        is_dropout,
        p_dropout,
        rp_dropout,
        p_dropout_in_uint8_t,
        philox_args,
        return_softmax,
        # masking
        is_causal,
        is_local,
        window_size_left,
        window_size_right,
        seqlenq_ngroups_swapped,
        is_paged,
        # alibi
        is_alibi,
        alibi_slopes_ptr,
        alibi_slopes_batch_stride,
        # block table
        total_q,
        page_table_ptr,
        page_table_batch_stride,
        block_size,
        k_page_stride,
    ):
        """Store every argument on the slot of the same name."""
        # Snapshot the arguments before any other local is created; every
        # parameter name is also a slot name, so one loop covers all fields.
        received = locals()
        for field in fwd_params.__slots__:
            setattr(self, field, received[field])

    def args(self):
        """Return all stored values as a tuple, ordered like ``__slots__``."""
        return tuple(getattr(self, field) for field in self.__slots__)
def mha_varlan_fwd(
    q,
    k,
    v,
    out,
    cu_seqlens_q,
    cu_seqlens_k,
    seqused_k,
    leftpad_k,
    page_table,
    alibi_slopes,
    max_seqlen_q,
    max_seqlen_k,
    p_dropout,
    softmax_scale,
    zero_tensors,
    is_causal,
    window_size_left,
    window_size_right,
    softcap,
    return_softmax,
    gen,
):
    """Validate inputs, build kernel parameters and launch ``flash_varlen_fwd_kernel``.

    Args:
        q: fp16/bf16 query tensor of shape ``[total_q_tokens, num_heads, head_size]``.
        k: Key tensor, same dtype as ``q``. With a page table it is
            ``[num_pages, block_size, num_heads_k, head_size]``; otherwise
            dimension 1 is read as ``num_heads_k``.
        v: Value tensor; must have the same size as ``k``.
        out: Optional output tensor of shape ``[total_q_tokens, num_heads, head_size]``.
            A new tensor is allocated when ``None``.
        cu_seqlens_q: Contiguous int32 tensor with ``batch_size + 1`` entries.
        cu_seqlens_k: Contiguous int32 tensor with ``batch_size + 1`` entries.
        seqused_k: Optional contiguous tensor with ``batch_size`` entries.
        leftpad_k: Must be ``None`` (not supported).
        page_table: Optional page table; its presence selects the paged-KV layout.
        alibi_slopes: Optional float32 tensor of shape ``(num_heads,)`` or
            ``(batch_size, num_heads)``.
        max_seqlen_q: Maximum query sequence length.
        max_seqlen_k: Maximum key sequence length.
        p_dropout: Dropout probability; ``0`` disables dropout.
        softmax_scale: Scale applied to the attention scores.
        zero_tensors: When true, ``out`` is zeroed and ``lse`` filled with ``-inf``
            before the launch.
        is_causal: Requests causal masking (forces ``window_size_right = 0``).
        window_size_left: Left window size; values ``>= max_seqlen_k`` become ``-1``.
        window_size_right: Right window size; values ``>= max_seqlen_k`` become ``-1``.
        softcap: Soft-capping value; ``> 0`` enables it and forbids dropout.
        return_softmax: When true (requires dropout) a buffer ``p`` is allocated
            and handed to the kernel.
        gen: Not used by this function.

    Returns:
        Tuple ``(out, q, k, v, lse, philox_args, unused, p)``. When the
        single-query group swap is applied, the returned ``q`` is the reshaped
        query and ``lse`` is reshaped to
        ``(num_heads_k * q_groups, batch_size)``; otherwise ``lse`` has shape
        ``(num_heads, total_q)``.

    Raises:
        AssertionError: If any of the input checks fails.
    """
    CHECK_DEVICE(q)
    CHECK_DEVICE(k)
    CHECK_DEVICE(v)
    q_device = q.device
    q_dtype = q.dtype
    assert q_dtype in (
        torch.float16,
        torch.bfloat16,
    ), "FlashAttention only support fp16 and bf16 data type"
    assert q_dtype == k.dtype
    assert q_dtype == v.dtype
    assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension"

    assert cu_seqlens_q.dtype == torch.int32
    assert cu_seqlens_q.is_contiguous()

    assert cu_seqlens_k.dtype == torch.int32
    assert cu_seqlens_k.is_contiguous()

    is_paged = page_table is not None
    if not is_paged:
        # Zero-sized placeholder so a tensor is always passed to the kernel.
        page_table = torch.empty((0, 0), device=q_device, dtype=torch.int32)

    # q shape: [total_q_tokens, num_heads, head_size]
    # k shape:
    #   paged_kv: [num_pages, block_size, num_heads_k, head_size]
    # batch_size, number of sentences
    total_q, num_heads, head_size = q.size()
    num_heads_k = k.size(2) if is_paged else k.size(1)
    batch_size = cu_seqlens_q.numel() - 1
    block_size = k.size(1) if is_paged else 1
    num_pages = k.size(0) if is_paged else 0
    k_batch_size = num_pages
    # max_num_pages_per_seq = page_table.size(1)
    page_table_batch_stride = page_table.stride(0)

    assert k.size() == v.size()
    assert cu_seqlens_q.size() == (batch_size + 1,)
    assert cu_seqlens_k.size() == (batch_size + 1,)

    # Check output shape
    if out is not None:
        assert out.stride(-1) == 1
        assert out.dtype == q.dtype
        assert out.size() == (total_q, num_heads, head_size)

    if seqused_k is not None:
        assert seqused_k.is_contiguous()
        assert seqused_k.size() == (batch_size,)

    if max_seqlen_q == 1 and alibi_slopes is None:
        is_causal = False

    if is_causal:
        window_size_right = 0

    # check disable swa
    if window_size_left >= max_seqlen_k:
        window_size_left = -1
    if window_size_right >= max_seqlen_k:
        window_size_right = -1

    is_local = window_size_left >= 0

    # Optimize all single-query sequences by swapping the query-group and sequence dimensions
    seqlenq_ngroups_swapped = (
        max_seqlen_q == 1
        and alibi_slopes is None
        and num_heads > num_heads_k
        and window_size_left < 0
        and window_size_right < 0
        and p_dropout == 0
    )
    q_groups = num_heads // num_heads_k
    if seqlenq_ngroups_swapped:
        logger.debug("Swapping query groups and sequence dimensions")
        q = (
            q.reshape((batch_size, num_heads_k, q_groups, head_size))
            .transpose(1, 2)
            .reshape(batch_size * q_groups, num_heads_k, head_size)
        )
        max_seqlen_q = q_groups
        num_heads = num_heads_k
        cu_seqlens_q = None
        q_batch_stride = q.stride(0) * max_seqlen_q
        k_batch_stride = k.stride(0)
        v_batch_stride = v.stride(0)
        # o_batch_stride depends on the output buffer and is set once it exists.
    else:
        q_batch_stride = 0
        k_batch_stride = 0
        v_batch_stride = 0
        o_batch_stride = 0

    total_q = q.size(0)

    assert leftpad_k is None, "leftpad_k is not supported."
    assert (
        head_size <= 256
    ), "FlashAttention forward only supports head dimension at most 256"
    assert (
        head_size % 8 == 0
    ), "head_size must be a multiple of 8, this is ensured by padding!"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"

    assert q.shape == (total_q, num_heads, head_size)
    if is_paged:
        assert k.shape == (num_pages, block_size, num_heads_k, head_size)
        assert v.shape == (num_pages, block_size, num_heads_k, head_size)
        assert k.stride() == v.stride()

    if softcap > 0.0:
        assert p_dropout == 0, "dropout is not supported if softcap is used."

    def round_multiple(x, m):
        # Round x up to the next multiple of m.
        return (x + m - 1) // m * m

    head_size_rounded = round_multiple(head_size, 32) if head_size <= 192 else 256
    seqlen_q_rounded = round_multiple(max_seqlen_q, 128)
    seqlen_k_rounded = round_multiple(max_seqlen_k, 32)

    M_LOG2E = 1.4426950408889634074
    if softcap > 0.0:
        # With soft-capping, the kernel receives softcap as the softmax scale
        # and softmax_scale / softcap as the cap parameter.
        is_softcap = True
        adjusted_scale_softmax = softcap
        adjusted_softcap = softmax_scale / softcap
        adjusted_scale_softmax_log2e = softcap * M_LOG2E
    else:
        is_softcap = False
        adjusted_softcap = 0.0
        adjusted_scale_softmax = softmax_scale
        adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E

    # Set alibi params
    if alibi_slopes is not None:
        assert alibi_slopes.device == q_device
        assert alibi_slopes.dtype in (torch.float,)
        assert alibi_slopes.stride(-1) == 1
        assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (
            batch_size,
            num_heads,
        )
        alibi_slopes_batch_stride = (
            alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
        )
        is_alibi = True
    else:
        alibi_slopes_batch_stride = 0
        is_alibi = False

    # Prepare params to kernel
    with torch_device_fn.device(q_device):
        if out is not None:
            out_ = out
            if seqlenq_ngroups_swapped:
                # The kernel writes in the swapped layout; the result is
                # copied back into the caller's buffer after the launch.
                out = torch.empty_like(q, dtype=v.dtype)
        else:
            out_ = None
            out = torch.empty_like(q, dtype=v.dtype)

        if seqlenq_ngroups_swapped:
            o_batch_stride = out.stride(0) * max_seqlen_q

        lse = torch.empty((num_heads, total_q), dtype=torch.float, device=q_device)

        if p_dropout > 0:
            is_dropout = True
            increment = batch_size * num_heads * 32
            philox_seed, philox_offset = philox_backend_seed_offset(increment)
            philox_args = torch.tensor(
                [philox_seed, philox_offset], dtype=torch.int64, device=q_device
            )
        else:
            is_dropout = False
            philox_args = torch.empty((2,), dtype=torch.int64, device=q_device)

        # From here on p_dropout holds the keep probability.
        p_dropout = 1 - p_dropout
        p_dropout_in_uint8_t = math.floor(p_dropout * 255.0)
        rp_dropout = 1.0 / p_dropout

        if return_softmax:
            assert is_dropout, "Only supported with non-zero dropout."
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                device=q_device,
            )
        else:
            p = torch.empty((), device=q_device)

        if zero_tensors:
            out.zero_()
            lse.fill_(float("-inf"))

        # Keyword arguments tie each value to its slot by name instead of by
        # position in a 59-element list.
        params = fwd_params(
            # pointers and strides
            q_ptr=q,
            k_ptr=k,
            v_ptr=v,
            o_ptr=out,
            p_ptr=p,
            softmax_lse_ptr=lse,
            q_row_stride=q.stride(-3),
            k_row_stride=k.stride(-3),
            v_row_stride=v.stride(-3),
            q_head_stride=q.stride(-2),
            k_head_stride=k.stride(-2),
            v_head_stride=v.stride(-2),
            o_row_stride=out.stride(-3),
            o_head_stride=out.stride(-2),
            q_batch_stride=q_batch_stride,
            k_batch_stride=k_batch_stride,
            v_batch_stride=v_batch_stride,
            o_batch_stride=o_batch_stride,
            is_cu_seqlens_q=cu_seqlens_q is not None,
            cu_seqlens_q_ptr=cu_seqlens_q,
            # NOTE(review): mha_varlan_fwd_opt passes `cu_seqlens_k is not None`
            # for this flag instead; confirm which one the kernel expects.
            is_cu_seqlens_k=seqused_k is None,
            cu_seqlens_k_ptr=cu_seqlens_k,
            is_seqused_k=seqused_k is not None,
            seqused_k_ptr=seqused_k,
            # sizes
            b=batch_size,
            bk=k_batch_size,
            h=num_heads,
            hk=num_heads_k,
            h_hk_ratio=num_heads // num_heads_k,
            seqlen_q=max_seqlen_q,
            seqlen_k=max_seqlen_k,
            seqlen_q_rounded=seqlen_q_rounded,
            seqlen_k_rounded=seqlen_k_rounded,
            d=head_size,
            d_rounded=head_size_rounded,
            # scaling factors
            is_softcap=is_softcap,
            softcap=adjusted_softcap,
            scale_softmax=adjusted_scale_softmax,
            scale_softmax_log2=adjusted_scale_softmax_log2e,
            # dropout
            is_dropout=is_dropout,
            p_dropout=p_dropout,
            rp_dropout=rp_dropout,
            p_dropout_in_uint8_t=p_dropout_in_uint8_t,
            philox_args=philox_args,
            return_softmax=return_softmax,
            # causal and swa
            is_causal=is_causal,
            is_local=is_local,
            window_size_left=window_size_left,
            window_size_right=window_size_right,
            seqlenq_ngroups_swapped=seqlenq_ngroups_swapped,
            is_paged=is_paged,
            # alibi
            is_alibi=is_alibi,
            alibi_slopes_ptr=alibi_slopes,
            alibi_slopes_batch_stride=alibi_slopes_batch_stride,
            # block table params
            total_q=total_q,
            page_table_ptr=page_table,
            page_table_batch_stride=page_table_batch_stride,
            block_size=block_size,
            k_page_stride=k.stride(0) if is_paged else 0,
        )

        if flag_gems.vendor_name == "iluvatar":
            # On this vendor k/v are passed with all trailing dims flattened.
            params.k_ptr = k.view(k.shape[0], k.shape[1], -1)
            params.v_ptr = v.view(v.shape[0], v.shape[1], -1)
        logger.debug("kernel: flash_varlen_fwd")
        grid = lambda meta: (
            triton.cdiv(max_seqlen_q, meta["BLOCK_M"]),
            batch_size,
            num_heads,
        )
        kernel = flash_varlen_fwd_kernel[grid]
        # Must be taken after the vendor-specific overrides above.
        args = params.args()

        # We assess which phase the requests are likely to be in and set the config accordingly.
        total_rows = total_q * num_heads
        num_sms = torch_device_fn.get_device_properties(
            flag_gems.device
        ).multi_processor_count
        avg_rows_per_sm = total_rows / num_sms
        avg_rows_per_batch = total_q / batch_size
        avg_rows_per_cta = min(avg_rows_per_batch, avg_rows_per_sm)
        # Heuristic: the more rows a CTA is expected to handle, the larger the
        # block config (thresholds 64 / 32 / 16). Large values presumably
        # indicate prefill, small values decode; this is a rough estimate and
        # may not be accurate for all scenarios.
        if avg_rows_per_cta > 64:
            varlen_fwd_config_str = "mha_block_128"
        elif avg_rows_per_cta > 32:
            varlen_fwd_config_str = "mha_block_64"
        elif avg_rows_per_cta > 16:
            varlen_fwd_config_str = "mha_block_32"
        else:
            varlen_fwd_config_str = "mha_block_16"
        if flag_gems.vendor_name == "mthreads":
            varlen_fwd_config_str = "mha_block_32"

        cfg = runtime.get_heuristic_config(varlen_fwd_config_str)
        cfg_params = {
            "BLOCK_M": cfg["BLOCK_M"](args),
            "BLOCK_N": cfg["BLOCK_N"](args),
            "BLOCK_K": triton.next_power_of_2(head_size),
            "num_warps": cfg["num_warps"](args),
            "num_stages": 1 if not is_paged else cfg["num_stages"](args),
        }

        logger.debug("Running flash_varlen_fwd_kernel with config: %s", cfg_params)
        kernel(*args, **cfg_params)

        if seqlenq_ngroups_swapped:
            # Undo the group/sequence swap on the output.
            out = out.reshape(
                batch_size, max_seqlen_q, num_heads_k, head_size
            ).transpose(1, 2)
            if out_ is not None:
                out_.view(batch_size, num_heads_k, max_seqlen_q, head_size).copy_(out)
                out = out_
            else:
                out = out.reshape(batch_size, num_heads_k * max_seqlen_q, head_size)
            lse = lse.reshape(num_heads_k, batch_size, max_seqlen_q)
            lse = lse.reshape(num_heads_k * max_seqlen_q, batch_size)

        unused = torch.empty((), dtype=torch.int64, device=q_device)
        return out, q, k, v, lse, philox_args, unused, p
def mha_varlan_fwd_opt(
    q,
    k,
    v,
    out,
    lse,
    cu_seqlens_q,
    cu_seqlens_k,
    seqused_k,
    leftpad_k,
    page_table,
    alibi_slopes,
    max_seqlen_q,
    max_seqlen_k,
    p_dropout,
    softmax_scale,
    zero_tensors,
    is_causal,
    window_size_left,
    window_size_right,
    softcap,
    return_softmax,
    gen,
):
    """Launch ``flash_varlen_fwd_kernel`` while avoiding placeholder allocations.

    The caller may supply a preallocated ``lse`` buffer, and ``None`` is used
    in place of dummy tensors for ``philox_args`` (no dropout), ``p`` (no
    softmax return) and the ``unused`` return slot.

    Args:
        q: fp16/bf16 query tensor of shape ``[total_q_tokens, num_heads, head_size]``.
        k: Key tensor, same dtype as ``q``. With a page table it is
            ``[num_pages, block_size, num_heads_k, head_size]``; otherwise
            dimension 1 is read as ``num_heads_k``.
        v: Value tensor; must have the same size as ``k``.
        out: Optional output tensor of shape ``[total_q_tokens, num_heads, head_size]``.
            A new tensor is allocated when ``None``.
        lse: Optional log-sum-exp buffer. When ``None`` a float32 tensor of
            shape ``(num_heads, total_q)`` is allocated. Its shape is not
            validated here.
        cu_seqlens_q: Contiguous int32 tensor with ``batch_size + 1`` entries.
        cu_seqlens_k: Contiguous int32 tensor with ``batch_size + 1`` entries.
        seqused_k: Optional contiguous tensor with ``batch_size`` entries.
        leftpad_k: Must be ``None`` (not supported).
        page_table: Optional page table; its presence selects the paged-KV layout.
        alibi_slopes: Optional float32 tensor of shape ``(num_heads,)`` or
            ``(batch_size, num_heads)``.
        max_seqlen_q: Maximum query sequence length.
        max_seqlen_k: Maximum key sequence length.
        p_dropout: Dropout probability; ``0`` disables dropout.
        softmax_scale: Scale applied to the attention scores.
        zero_tensors: When true, ``out`` is zeroed and ``lse`` filled with ``-inf``
            before the launch.
        is_causal: Requests causal masking (forces ``window_size_right = 0``).
        window_size_left: Left window size; values ``>= max_seqlen_k`` become ``-1``.
        window_size_right: Right window size; values ``>= max_seqlen_k`` become ``-1``.
        softcap: Soft-capping value; ``> 0`` enables it and forbids dropout.
        return_softmax: When true (requires dropout) a buffer ``p`` is allocated
            and handed to the kernel.
        gen: Not used by this function.

    Returns:
        Tuple ``(out, q, k, v, lse, philox_args, unused, p)`` where ``unused``
        is always ``None``, ``philox_args`` is ``None`` without dropout and
        ``p`` is ``None`` unless ``return_softmax`` is set. When the
        single-query group swap is applied, the returned ``q`` is the reshaped
        query and ``lse`` is reshaped to
        ``(num_heads_k * q_groups, batch_size)``.

    Raises:
        AssertionError: If any of the input checks fails.
    """
    CHECK_DEVICE(q)
    CHECK_DEVICE(k)
    CHECK_DEVICE(v)
    q_device = q.device
    q_dtype = q.dtype
    assert q_dtype in (
        torch.float16,
        torch.bfloat16,
    ), "FlashAttention only support fp16 and bf16 data type"
    assert q_dtype == k.dtype
    assert q_dtype == v.dtype
    assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension"

    assert cu_seqlens_q.dtype == torch.int32
    assert cu_seqlens_q.is_contiguous()

    assert cu_seqlens_k.dtype == torch.int32
    assert cu_seqlens_k.is_contiguous()

    is_paged = page_table is not None
    if not is_paged:
        # Zero-sized placeholder so a tensor is always passed to the kernel.
        # (The original spelled this `torch.emtpty`, which raised
        # AttributeError on every non-paged call.)
        page_table = torch.empty((0, 0), device=q_device, dtype=torch.int32)

    # q shape: [total_q_tokens, num_heads, head_size]
    # k shape:
    #   paged_kv: [num_pages, block_size, num_heads_k, head_size]
    # batch_size, number of sentences
    total_q, num_heads, head_size = q.size()
    num_heads_k = k.size(2) if is_paged else k.size(1)
    batch_size = cu_seqlens_q.numel() - 1
    block_size = k.size(1) if is_paged else 1
    num_pages = k.size(0) if is_paged else 0
    k_batch_size = num_pages
    # max_num_pages_per_seq = page_table.size(1)
    page_table_batch_stride = page_table.stride(0)

    assert k.size() == v.size()
    assert cu_seqlens_q.size() == (batch_size + 1,)
    assert cu_seqlens_k.size() == (batch_size + 1,)

    # Check output shape
    if out is not None:
        assert out.stride(-1) == 1
        assert out.dtype == q.dtype
        assert out.size() == (total_q, num_heads, head_size)

    if seqused_k is not None:
        assert seqused_k.is_contiguous()
        assert seqused_k.size() == (batch_size,)

    if max_seqlen_q == 1 and alibi_slopes is None:
        is_causal = False

    if is_causal:
        window_size_right = 0

    # check disable swa
    if window_size_left >= max_seqlen_k:
        window_size_left = -1
    if window_size_right >= max_seqlen_k:
        window_size_right = -1

    is_local = window_size_left >= 0

    # Optimize all single-query sequences by swapping the query-group and sequence dimensions
    seqlenq_ngroups_swapped = (
        max_seqlen_q == 1
        and alibi_slopes is None
        and num_heads > num_heads_k
        and window_size_left < 0
        and window_size_right < 0
        and p_dropout == 0
    )
    q_groups = num_heads // num_heads_k
    if seqlenq_ngroups_swapped:
        logger.debug("Swapping query groups and sequence dimensions")
        q = (
            q.reshape((batch_size, num_heads_k, q_groups, head_size))
            .transpose(1, 2)
            .reshape(batch_size * q_groups, num_heads_k, head_size)
        )
        max_seqlen_q = q_groups
        num_heads = num_heads_k
        cu_seqlens_q = None
        q_batch_stride = q.stride(0) * max_seqlen_q
        k_batch_stride = k.stride(0)
        v_batch_stride = v.stride(0)
        # o_batch_stride depends on the output buffer and is set once it exists.
    else:
        q_batch_stride = 0
        k_batch_stride = 0
        v_batch_stride = 0
        o_batch_stride = 0

    total_q = q.size(0)

    assert leftpad_k is None, "leftpad_k is not supported."
    assert (
        head_size <= 256
    ), "FlashAttention forward only supports head dimension at most 256"
    assert (
        head_size % 8 == 0
    ), "head_size must be a multiple of 8, this is ensured by padding!"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"

    assert q.shape == (total_q, num_heads, head_size)
    if is_paged:
        assert k.shape == (num_pages, block_size, num_heads_k, head_size)
        assert v.shape == (num_pages, block_size, num_heads_k, head_size)
        assert k.stride() == v.stride()

    if softcap > 0.0:
        assert p_dropout == 0, "dropout is not supported if softcap is used."

    def round_multiple(x, m):
        # Round x up to the next multiple of m.
        return (x + m - 1) // m * m

    head_size_rounded = round_multiple(head_size, 32) if head_size <= 192 else 256
    seqlen_q_rounded = round_multiple(max_seqlen_q, 128)
    seqlen_k_rounded = round_multiple(max_seqlen_k, 32)

    M_LOG2E = 1.4426950408889634074
    if softcap > 0.0:
        # With soft-capping, the kernel receives softcap as the softmax scale
        # and softmax_scale / softcap as the cap parameter.
        is_softcap = True
        adjusted_scale_softmax = softcap
        adjusted_softcap = softmax_scale / softcap
        adjusted_scale_softmax_log2e = softcap * M_LOG2E
    else:
        is_softcap = False
        adjusted_softcap = 0.0
        adjusted_scale_softmax = softmax_scale
        adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E

    # Set alibi params
    if alibi_slopes is not None:
        assert alibi_slopes.device == q_device
        assert alibi_slopes.dtype in (torch.float,)
        assert alibi_slopes.stride(-1) == 1
        assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (
            batch_size,
            num_heads,
        )
        alibi_slopes_batch_stride = (
            alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
        )
        is_alibi = True
    else:
        alibi_slopes_batch_stride = 0
        is_alibi = False

    # Prepare params to kernel
    with torch_device_fn.device(q_device):
        if out is not None:
            out_ = out
            if seqlenq_ngroups_swapped:
                # The kernel writes in the swapped layout; the result is
                # copied back into the caller's buffer after the launch.
                out = torch.empty_like(q, dtype=v.dtype)
        else:
            out_ = None
            out = torch.empty_like(q, dtype=v.dtype)

        if seqlenq_ngroups_swapped:
            o_batch_stride = out.stride(0) * max_seqlen_q

        if lse is None:
            lse = torch.empty((num_heads, total_q), dtype=torch.float, device=q_device)

        if p_dropout > 0:
            is_dropout = True
            increment = batch_size * num_heads * 32
            philox_seed, philox_offset = philox_backend_seed_offset(increment)
            philox_args = torch.tensor(
                [philox_seed, philox_offset], dtype=torch.int64, device=q_device
            )
        else:
            is_dropout = False
            # No placeholder tensor is allocated; None is forwarded instead.
            philox_args = None

        # From here on p_dropout holds the keep probability.
        p_dropout = 1 - p_dropout
        p_dropout_in_uint8_t = math.floor(p_dropout * 255.0)
        rp_dropout = 1.0 / p_dropout

        if return_softmax:
            assert is_dropout, "Only supported with non-zero dropout."
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                device=q_device,
            )
        else:
            # No placeholder tensor is allocated; None is forwarded instead.
            p = None
        if zero_tensors:
            out.zero_()
            lse.fill_(float("-inf"))

        # Keyword arguments tie each value to its slot by name instead of by
        # position in a 59-element list.
        params = fwd_params(
            # pointers and strides
            q_ptr=q,
            k_ptr=k,
            v_ptr=v,
            o_ptr=out,
            p_ptr=p,
            softmax_lse_ptr=lse,
            q_row_stride=q.stride(-3),
            k_row_stride=k.stride(-3),
            v_row_stride=v.stride(-3),
            q_head_stride=q.stride(-2),
            k_head_stride=k.stride(-2),
            v_head_stride=v.stride(-2),
            o_row_stride=out.stride(-3),
            o_head_stride=out.stride(-2),
            q_batch_stride=q_batch_stride,
            k_batch_stride=k_batch_stride,
            v_batch_stride=v_batch_stride,
            o_batch_stride=o_batch_stride,
            is_cu_seqlens_q=cu_seqlens_q is not None,
            cu_seqlens_q_ptr=cu_seqlens_q,
            # NOTE(review): always True here because cu_seqlens_k was already
            # dereferenced above; mha_varlan_fwd passes `seqused_k is None`
            # instead. Confirm which one the kernel expects.
            is_cu_seqlens_k=cu_seqlens_k is not None,
            cu_seqlens_k_ptr=cu_seqlens_k,
            is_seqused_k=seqused_k is not None,
            seqused_k_ptr=seqused_k,
            # sizes
            b=batch_size,
            bk=k_batch_size,
            h=num_heads,
            hk=num_heads_k,
            h_hk_ratio=num_heads // num_heads_k,
            seqlen_q=max_seqlen_q,
            seqlen_k=max_seqlen_k,
            seqlen_q_rounded=seqlen_q_rounded,
            seqlen_k_rounded=seqlen_k_rounded,
            d=head_size,
            d_rounded=head_size_rounded,
            # scaling factors
            is_softcap=is_softcap,
            softcap=adjusted_softcap,
            scale_softmax=adjusted_scale_softmax,
            scale_softmax_log2=adjusted_scale_softmax_log2e,
            # dropout
            is_dropout=is_dropout,
            p_dropout=p_dropout,
            rp_dropout=rp_dropout,
            p_dropout_in_uint8_t=p_dropout_in_uint8_t,
            philox_args=philox_args,
            return_softmax=return_softmax,
            # causal and swa
            is_causal=is_causal,
            is_local=is_local,
            window_size_left=window_size_left,
            window_size_right=window_size_right,
            seqlenq_ngroups_swapped=seqlenq_ngroups_swapped,
            is_paged=is_paged,
            # alibi
            is_alibi=is_alibi,
            alibi_slopes_ptr=alibi_slopes,
            alibi_slopes_batch_stride=alibi_slopes_batch_stride,
            # block table params
            total_q=total_q,
            page_table_ptr=page_table,
            page_table_batch_stride=page_table_batch_stride,
            block_size=block_size,
            k_page_stride=k.stride(0) if is_paged else 0,
        )

        if flag_gems.vendor_name == "iluvatar":
            # On this vendor k/v are passed with all trailing dims flattened.
            params.k_ptr = k.view(k.shape[0], k.shape[1], -1)
            params.v_ptr = v.view(v.shape[0], v.shape[1], -1)
        logger.debug("kernel: flash_varlen_fwd")
        grid = lambda meta: (
            triton.cdiv(max_seqlen_q, meta["BLOCK_M"]),
            batch_size,
            num_heads,
        )
        kernel = flash_varlen_fwd_kernel[grid]
        # Must be taken after the vendor-specific overrides above.
        args = params.args()

        # We assess which phase the requests are likely to be in and set the config accordingly.
        total_rows = total_q * num_heads
        num_sms = torch_device_fn.get_device_properties(
            flag_gems.device
        ).multi_processor_count
        avg_rows_per_sm = total_rows / num_sms
        avg_rows_per_batch = total_q / batch_size
        avg_rows_per_cta = min(avg_rows_per_batch, avg_rows_per_sm)
        # Heuristic: the more rows a CTA is expected to handle, the larger the
        # block config (thresholds 64 / 32 / 16). Large values presumably
        # indicate prefill, small values decode; this is a rough estimate and
        # may not be accurate for all scenarios.
        if avg_rows_per_cta > 64:
            varlen_fwd_config_str = "mha_block_128"
        elif avg_rows_per_cta > 32:
            varlen_fwd_config_str = "mha_block_64"
        elif avg_rows_per_cta > 16:
            varlen_fwd_config_str = "mha_block_32"
        else:
            varlen_fwd_config_str = "mha_block_16"
        if flag_gems.vendor_name == "mthreads":
            varlen_fwd_config_str = "mha_block_32"

        cfg = runtime.get_heuristic_config(varlen_fwd_config_str)
        cfg_params = {
            "BLOCK_M": cfg["BLOCK_M"](args),
            "BLOCK_N": cfg["BLOCK_N"](args),
            "BLOCK_K": triton.next_power_of_2(head_size),
            "num_warps": cfg["num_warps"](args),
            "num_stages": 1 if not is_paged else cfg["num_stages"](args),
        }

        logger.debug("Running flash_varlen_fwd_kernel with config: %s", cfg_params)
        kernel(*args, **cfg_params)

        if seqlenq_ngroups_swapped:
            # Undo the group/sequence swap on the output.
            out = out.reshape(
                batch_size, max_seqlen_q, num_heads_k, head_size
            ).transpose(1, 2)
            if out_ is not None:
                out_.view(batch_size, num_heads_k, max_seqlen_q, head_size).copy_(out)
                out = out_
            else:
                out = out.reshape(batch_size, num_heads_k * max_seqlen_q, head_size)
            lse = lse.reshape(num_heads_k, batch_size, max_seqlen_q)
            lse = lse.reshape(num_heads_k * max_seqlen_q, batch_size)

        # The slot is kept for return-shape compatibility; no tensor is allocated.
        unused = None
        return out, q, k, v, lse, philox_args, unused, p
934def mha_fwd(
935 q,
936 k,
937 v,
938 out,
939 alibi_slopes,
940 p_dropout,
941 softmax_scale,
942 is_causal,
943 window_size_left,
944 window_size_right,
945 softcap,
946 return_softmax,
947 disable_splitkv=False,
948):
949 CHECK_DEVICE(q), CHECK_DEVICE(k), CHECK_DEVICE(v)
950 q_dtype = q.dtype
951 q_device = q.device
952 assert q_dtype in (
953 torch.float16,
954 torch.bfloat16,
955 ), "FlashAttention only support fp16 and bf16 data type"
956 assert q_dtype == k.dtype
957 assert q_dtype == v.dtype
958 assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension"
959 assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension"
960 assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension"
961 batch_size, seqlen_q, num_heads, head_size = q.size()
962 _, seqlen_k, num_heads_k, _ = k.size()
964 # Check output shape
965 if out is not None:
966 assert out.stride(-1) == 1
967 assert out.dtype == q.dtype
968 assert out.size() == (batch_size, seqlen_q, num_heads, head_size)
969 CHECK_DEVICE(out)
971 assert (
972 head_size % 8 == 0
973 ), "head_size must be a multiple of 8, this is ensured by padding!"
974 assert (
975 num_heads % num_heads_k == 0
976 ), "Number of heads in key/value must divide number of heads in query"
977 if window_size_left >= seqlen_k:
978 window_size_left = -1
979 if window_size_right >= seqlen_k:
980 window_size_right = -1
981 if seqlen_q == 1 and alibi_slopes is None:
982 is_causal = False
983 if is_causal:
984 window_size_right = 0
986 is_causal = window_size_left < 0 and window_size_right == 0
987 is_local = window_size_left >= 0 and window_size_right >= 0
989 seqlenq_ngroups_swapped = (
990 seqlen_q == 1
991 and alibi_slopes is None
992 and num_heads > num_heads_k
993 and window_size_left < 0
994 and window_size_right < 0
995 and p_dropout == 0
996 )
997 q_groups = num_heads // num_heads_k
999 if seqlenq_ngroups_swapped:
1000 logger.debug("q_kg swapped.")
1001 q = q.reshape(batch_size, num_heads_k, q_groups, head_size).transpose(1, 2)
1002 seqlen_q = q_groups
1003 num_heads = num_heads_k
1005 round_multiple = lambda x, m: (x + m - 1) // m * m
1006 head_size_rounded = round_multiple(head_size, 32)
1007 seqlen_q_rounded = round_multiple(seqlen_q, 128)
1008 seqlen_k_rounded = round_multiple(seqlen_k, 32)
1010 assert (
1011 head_size <= 256
1012 ), "FlashAttention forward only supports head dimension at most 256"
1013 assert head_size == head_size_rounded, "head_size must be rounded to 32"
1015 def splits_heuristic(num_tasks, num_sms, n_blocks):
1016 # splits when wave efficiency is low
1017 n_waves = triton.cdiv(num_tasks, num_sms)
1018 eff = (num_tasks / num_sms) / n_waves
1019 if eff > 0.8 or n_waves > 1:
1020 return 1
1022 min_blocks_per_split = 2
1023 best_splits = min(
1024 triton.cdiv(n_blocks, min_blocks_per_split),
1025 int(math.floor(1.0 / eff)),
1026 num_sms,
1027 )
1029 return best_splits
1031 with torch_device_fn.device(q_device):
1032 # Set softmax params
1033 lse = torch.empty(
1034 (batch_size, num_heads, seqlen_q), dtype=torch.float, device=q_device
1035 )
1037 if out is not None:
1038 if seqlenq_ngroups_swapped:
1039 out = out.reshape(
1040 batch_size, num_heads_k, q_groups, head_size
1041 ).transpose(1, 2)
1042 else:
1043 out = torch.empty_like(q, dtype=v.dtype)
1045 # Set dropout params
1046 if p_dropout > 0:
1047 is_dropout = True
1048 increment = batch_size * num_heads * 32
1049 philox_seed, philox_offset = philox_backend_seed_offset(increment)
1050 philox_args = torch.tensor(
1051 [philox_seed, philox_offset], dtype=torch.int64, device=q_device
1052 )
1053 else:
1054 is_dropout = False
1055 philox_args = torch.empty((2,), dtype=torch.int64, device=q_device)
1057 p_dropout = 1 - p_dropout
1058 p_dropout_in_uint8_t = math.floor(p_dropout * 255.0)
1059 rp_dropout = 1.0 / p_dropout
1061 if return_softmax:
1062 assert is_dropout, "Only supported with non-zero dropout."
1063 p = torch.empty(
1064 (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
1065 device=q_device,
1066 )
1067 else:
1068 p = torch.empty((), device=q_device)
1070 M_LOG2E = 1.4426950408889634074
1071 if softcap > 0.0:
1072 is_softcap = True
1073 adjusted_scale_softmax = softcap
1074 adjusted_softcap = softmax_scale / softcap
1075 adjusted_scale_softmax_log2e = softcap * M_LOG2E
1076 else:
1077 is_softcap = False
1078 adjusted_softcap = 0.0
1079 adjusted_scale_softmax = softmax_scale
1080 adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E
1082 # Set alibi params
1083 if alibi_slopes is not None:
1084 assert alibi_slopes.device == q_device
1085 assert alibi_slopes.dtype in (torch.float,)
1086 assert alibi_slopes.stride(-1) == 1
1087 assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (
1088 batch_size,
1089 num_heads,
1090 )
1091 alibi_slopes_batch_stride = (
1092 alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
1093 )
1094 is_alibi = True
1095 else:
1096 alibi_slopes_batch_stride = 0
1097 is_alibi = False
1099 # ONLY EVEN_K IS SUPPORTED
1100 assert head_size == head_size_rounded
1102 # Do kernel dispatching
def dispatch(B, H, Q, K, D, params):
    """Select and launch the forward attention kernel for this problem size.

    The split-KV kernel (followed by its combine kernel) is tried first; when
    it is not applicable, or the heuristic asks for a single split, the plain
    ``flash_fwd_kernel`` is launched instead.

    Args:
        B: batch size.
        H: number of query heads.
        Q: query sequence length.
        K: key sequence length.
        D: head size.
        params: the ``fwd_params`` object whose ``args()`` are forwarded to the
            kernel. On the split-KV path its ``o_ptr`` and ``softmax_lse_ptr``
            fields are overwritten with per-split scratch buffers.

    Returns:
        The object returned by the main kernel launch (the split-KV kernel or
        ``flash_fwd_kernel``); the combine kernel's result is not returned.
    """
    # NOTE(review): the device string is hard-coded to "cuda" -- confirm this
    # is what every torch_device_fn backend expects.
    num_sms = torch_device_fn.get_device_properties("cuda").multi_processor_count

    # Try bh parallel
    # if B * H > 0.8 * num_sms:
    #     kernel = flash_fwd_bh_parallel_kernel[(H, B)]
    #     # Yield kernel and prefilled args
    #     return kernel, default_args, None, None

    # Try splitkv
    if not is_dropout and not is_local and not disable_splitkv:
        block_m = block_m_splitkv_heuristic(D)
        n_tasks = B * H * triton.cdiv(Q, block_m)
        block_n = block_n_splitkv_heuristic(D)
        n_blocks = triton.cdiv(K, block_n)
        n_splits = splits_heuristic(n_tasks, num_sms, n_blocks)

        if n_splits > 1:
            logger.debug("kernel: flash_fwd_splitkv")
            lse_splits = torch.empty(
                (n_splits, B, H, Q), dtype=torch.float, device=q_device
            )
            out_splits = torch.empty(
                (n_splits, B, H, Q, D), dtype=torch.float, device=q_device
            )

            def splitkv_grid(meta):
                # One program per (query block, split, batch*head).
                return (triton.cdiv(Q, meta["BLOCK_M"]), n_splits, B * H)

            # The split-KV kernel writes partial results into the float32
            # scratch buffers; the combine kernel below reduces them into
            # the real ``out`` / ``lse`` tensors.
            params.o_ptr = out_splits
            params.softmax_lse_ptr = lse_splits
            kernel = flash_fwd_splitkv_kernel[splitkv_grid](
                *params.args(),
                blocks_per_split=triton.cdiv(n_blocks, n_splits),
            )

            # Row-block size of the combine kernel shrinks as the head size
            # grows.
            if D >= 128:
                combine_block_m = 4
            elif D >= 64:
                combine_block_m = 8
            else:
                combine_block_m = 16
            combine_block_k = triton.next_power_of_2(D)
            q_total = B * H * Q

            def combine_grid(meta):
                # 1-D launch over all (batch, head, query) rows.
                return (triton.cdiv(q_total, combine_block_m),)

            # NOTE(review): out_h_stride is taken from out.stride(-1) while
            # o_head_stride in fwd_params uses out.stride(-2) -- confirm
            # against the combine kernel's indexing.
            flash_fwd_splitkv_combine_kernel[combine_grid](
                out_ptr=out,
                lse_ptr=lse,
                head_size=D,
                out_split_stride=out_splits.stride(0),
                lse_split_stride=lse_splits.stride(0),
                out_b_stride=out.stride(0),
                out_s_stride=out.stride(-3),
                out_h_stride=out.stride(-1),
                out_splits_ptr=out_splits,
                lse_splits_ptr=lse_splits,
                n_splits=n_splits,
                BLOCK_M=combine_block_m,
                BLOCK_K=combine_block_k,
                q_total=q_total,
                MAX_N_SPLITS=triton.next_power_of_2(n_splits),
            )
            return kernel

    # Last option: flash_fwd
    logger.debug("kernel: flash_fwd")

    def fwd_grid(meta):
        # One program per (query block, batch*head).
        return (triton.cdiv(Q, meta["BLOCK_M"]), H * B)

    return flash_fwd_kernel[fwd_grid](*params.args())
1180 if _debug:
1181 p = torch.empty(
1182 (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
1183 dtype=torch.float32,
1184 device=q_device,
1185 )
1186 return_softmax = True
1188 params = fwd_params(
1189 q, # q_ptr,
1190 k, # k_ptr,
1191 v, # v_ptr,
1192 out, # o_ptr,
1193 p, # p_ptr,
1194 lse, # softmax_lse_ptr,
1195 q.stride(-3), # q_row_stride,
1196 k.stride(-3), # k_row_stride,
1197 v.stride(-3), # v_row_stride,
1198 q.stride(-2), # q_head_stride,
1199 k.stride(-2), # k_head_stride,
1200 v.stride(-2), # v_head_stride,
1201 out.stride(-3), # o_row_stride,
1202 out.stride(-2), # o_head_stride,
1203 q.stride(0), # q_batch_stride,
1204 k.stride(0), # k_batch_stride,
1205 v.stride(0), # v_batch_stride,
1206 out.stride(0), # o_batch_stride,
1207 False, # is_cu_seqlens_q,
1208 None, # cu_seqlens_q_ptr,
1209 False, # is_cu_seqlens_k,
1210 None, # cu_seqlens_k_ptr,
1211 False, # is_seqused_k,
1212 None, # seqused_k_ptr,
1213 # sizes
1214 batch_size, # b,
1215 0, # bk,
1216 num_heads, # h,
1217 num_heads_k, # hk,
1218 num_heads // num_heads_k, # h_hk_ratio,
1219 seqlen_q, # seqlen_q,
1220 seqlen_k, # seqlen_k,
1221 seqlen_q_rounded, # seqlen_q_rounded,
1222 seqlen_k_rounded, # seqlen_k_rounded,
1223 head_size, # d,
1224 head_size_rounded, # d_rounded,
1225 # scaling factors
1226 is_softcap,
1227 adjusted_softcap, # softcap,
1228 adjusted_scale_softmax, # scale_softmax,
1229 adjusted_scale_softmax_log2e, # scale_softmax_log2,
1230 # dropout
1231 is_dropout,
1232 p_dropout,
1233 rp_dropout,
1234 p_dropout_in_uint8_t,
1235 philox_args,
1236 return_softmax,
1237 # causal and swa
1238 is_causal, # is_causal,
1239 is_local, # is_local,
1240 window_size_left, # window_size_left,
1241 window_size_right, # window_size_right,
1242 seqlenq_ngroups_swapped, # seqlenq_ngroups_swapped,
1243 False, # is_paged,
1244 # alibi
1245 is_alibi, #
1246 alibi_slopes, # alibi_slopes_ptr,
1247 alibi_slopes_batch_stride, # alibi_slopes_batch_stride,
1248 # block table params
1249 0, # total_q,
1250 None, # page_table_ptr,
1251 0, # page_table_batch_stride,
1252 0, # block_size,
1253 0, # k_page_stride,
1254 )
1256 # Move TxD to last dims for correct stride in Triton tt.load
1257 if flag_gems.vendor_name == "iluvatar":
1258 params.q_ptr = q.transpose(1, 2)
1259 params.k_ptr = k.transpose(1, 2)
1260 params.v_ptr = v.transpose(1, 2)
1261 kernel = dispatch(batch_size, num_heads, seqlen_q, seqlen_k, head_size, params)
1263 if _debug:
1264 print(f"{kernel.name} shared memory:", kernel.metadata.shared)
1265 print(f"{kernel.name} num_warps:", kernel.metadata.num_warps)
1266 print(f"{kernel.name} num_stages:", kernel.metadata.num_stages)
1267 # print(kernel.asm['ttgir'])
1269 if seqlenq_ngroups_swapped:
1270 out = out.transpose(1, 2).reshape(
1271 (batch_size, 1, num_heads_k * seqlen_q, head_size)
1272 )
1273 q = q.transpose(1, 2).reshape(
1274 (batch_size, 1, num_heads_k * seqlen_q, head_size)
1275 )
1276 lse = lse.reshape((batch_size, num_heads_k * seqlen_q, 1))
1278 unused = torch.empty((), dtype=torch.int64, device=q_device)
1280 return out, q, k, v, lse, philox_args, unused, p