Coverage for src/flag_gems/ops/flash_api.py: 86%
536 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import logging
2import math
4import torch
5import triton
7import flag_gems
8from flag_gems import runtime
9from flag_gems.ops.flash_kernel import (
10 block_m_splitkv_heuristic,
11 block_n_splitkv_heuristic,
12 flash_fwd_kernel,
13 flash_fwd_splitkv_combine_kernel,
14 flash_fwd_splitkv_kernel,
15 flash_varlen_fwd_kernel,
16)
17from flag_gems.runtime import torch_device_fn
18from flag_gems.utils.random_utils import philox_backend_seed_offset
20logger = logging.getLogger(__name__)
21_debug = False
def CHECK_DEVICE(x):
    """Assert that ``x`` lives on the device type configured in ``flag_gems``.

    Raises:
        AssertionError: if ``x.device.type`` differs from ``flag_gems.device``.
    """
    device_kind = x.device.type
    assert device_kind == flag_gems.device
class fwd_params:
    """Flat record of the arguments passed to the flash-attention forward kernels.

    The order of ``__slots__`` is significant: ``args()`` returns the stored
    values in exactly this order and callers unpack that tuple positionally
    into the kernel launch.  The constructor takes one parameter per slot,
    in the same order and under the same name.
    """

    __slots__ = (
        # pointers and strides
        "q_ptr",
        "k_ptr",
        "v_ptr",
        "o_ptr",
        "p_ptr",
        "softmax_lse_ptr",
        "q_row_stride",
        "k_row_stride",
        "v_row_stride",
        "q_head_stride",
        "k_head_stride",
        "v_head_stride",
        "o_row_stride",
        "o_head_stride",
        "q_batch_stride",
        "k_batch_stride",
        "v_batch_stride",
        "o_batch_stride",
        "is_cu_seqlens_q",
        "cu_seqlens_q_ptr",
        "is_cu_seqlens_k",
        "cu_seqlens_k_ptr",
        "is_seqused_k",
        "seqused_k_ptr",
        # sizes
        "b",
        "bk",
        "h",
        "hk",
        "h_hk_ratio",
        "seqlen_q",
        "seqlen_k",
        "seqlen_q_rounded",
        "seqlen_k_rounded",
        "d",
        "d_rounded",
        # scaling factors
        "is_softcap",
        "softcap",
        "scale_softmax",
        "scale_softmax_log2",
        # dropout
        "is_dropout",
        "p_dropout",
        "rp_dropout",
        "p_dropout_in_uint8_t",
        "philox_args",
        "return_softmax",
        # masking
        "is_causal",
        "is_local",
        "window_size_left",
        "window_size_right",
        "seqlenq_ngroups_swapped",
        "is_paged",
        # alibi
        "is_alibi",
        "alibi_slopes_ptr",
        "alibi_slopes_batch_stride",
        # block table
        "total_q",
        "page_table_ptr",
        "page_table_batch_stride",
        "block_size",
    )

    def __init__(
        self,
        # pointers and strides
        q_ptr,
        k_ptr,
        v_ptr,
        o_ptr,
        p_ptr,
        softmax_lse_ptr,
        q_row_stride,
        k_row_stride,
        v_row_stride,
        q_head_stride,
        k_head_stride,
        v_head_stride,
        o_row_stride,
        o_head_stride,
        q_batch_stride,
        k_batch_stride,
        v_batch_stride,
        o_batch_stride,
        is_cu_seqlens_q,
        cu_seqlens_q_ptr,
        is_cu_seqlens_k,
        cu_seqlens_k_ptr,
        is_seqused_k,
        seqused_k_ptr,
        # sizes
        b,
        bk,
        h,
        hk,
        h_hk_ratio,
        seqlen_q,
        seqlen_k,
        seqlen_q_rounded,
        seqlen_k_rounded,
        d,
        d_rounded,
        # scaling factors
        is_softcap,
        softcap,
        scale_softmax,
        scale_softmax_log2,
        # dropout
        is_dropout,
        p_dropout,
        rp_dropout,
        p_dropout_in_uint8_t,
        philox_args,
        return_softmax,
        # masking
        is_causal,
        is_local,
        window_size_left,
        window_size_right,
        seqlenq_ngroups_swapped,
        is_paged,
        # alibi
        is_alibi,
        alibi_slopes_ptr,
        alibi_slopes_batch_stride,
        # block table
        total_q,
        page_table_ptr,
        page_table_batch_stride,
        block_size,
    ):
        """Store every argument on the slot that carries the same name."""
        # Each parameter is named after its slot, so the values are copied
        # over by name.  The snapshot is taken before any other local exists.
        supplied = locals()
        for field in self.__slots__:
            setattr(self, field, supplied[field])

    def args(self):
        """Return the stored values as a tuple ordered like ``__slots__``."""
        return tuple([getattr(self, field) for field in self.__slots__])
def mha_varlan_fwd(
    q,
    k,
    v,
    out,
    cu_seqlens_q,
    cu_seqlens_k,
    seqused_k,
    leftpad_k,
    page_table,
    alibi_slopes,
    max_seqlen_q,
    max_seqlen_k,
    p_dropout,
    softmax_scale,
    zero_tensors,
    is_causal,
    window_size_left,
    window_size_right,
    softcap,
    return_softmax,
    gen,
):
    """Validate the inputs, pack the kernel arguments and launch ``flash_varlen_fwd_kernel``.

    ``q`` is ``[total_q_tokens, num_heads, head_size]``.  When ``page_table``
    is given, ``k``/``v`` are ``[num_pages, block_size, num_heads_k, head_size]``;
    otherwise the key head count is read from ``k.size(1)`` and an empty
    page table is substituted.  ``leftpad_k`` must be ``None``; ``gen`` is
    accepted but never read.

    Returns:
        ``(out, q, k, v, lse, philox_args, unused, p)``.  ``q`` is the
        regrouped tensor when the single-query optimisation was applied,
        ``unused`` is an empty int64 scalar tensor and ``p`` is an empty
        scalar tensor unless ``return_softmax`` is set.

    Raises:
        AssertionError: if any dtype, layout, shape or option check fails.
    """
    for tensor in (q, k, v):
        CHECK_DEVICE(tensor)
    q_device, q_dtype = q.device, q.dtype
    assert q_dtype in (
        torch.float16,
        torch.bfloat16,
    ), "FlashAttention only support fp16 and bf16 data type"
    assert q_dtype == k.dtype
    assert q_dtype == v.dtype
    for tensor in (q, k, v):
        assert (
            tensor.stride(-1) == 1
        ), "Input tensor must have contiguous last dimension"

    for cu_seqlens in (cu_seqlens_q, cu_seqlens_k):
        assert cu_seqlens.dtype == torch.int32
        assert cu_seqlens.is_contiguous()

    is_paged = page_table is not None
    if not is_paged:
        # The kernel always receives a page-table argument; use an empty one.
        page_table = torch.empty((0, 0), device=q_device, dtype=torch.int32)

    # q shape: [total_q_tokens, num_heads, head_size]
    # k shape (paged_kv): [num_pages, block_size, num_heads_k, head_size]
    # batch_size is the number of sequences described by cu_seqlens_q.
    total_q, num_heads, head_size = q.size()
    if is_paged:
        num_heads_k, block_size, num_pages = k.size(2), k.size(1), k.size(0)
    else:
        num_heads_k, block_size, num_pages = k.size(1), 1, 0
    batch_size = cu_seqlens_q.numel() - 1
    k_batch_size = num_pages
    page_table_batch_stride = page_table.stride(0)

    assert k.size() == v.size()
    assert cu_seqlens_q.size() == (batch_size + 1,)
    assert cu_seqlens_k.size() == (batch_size + 1,)

    # Check output shape
    if out is not None:
        assert out.stride(-1) == 1
        assert out.dtype == q.dtype
        assert out.size() == (total_q, num_heads, head_size)

    if seqused_k is not None:
        assert seqused_k.is_contiguous()
        assert seqused_k.size() == (batch_size,)

    if max_seqlen_q == 1 and alibi_slopes is None:
        is_causal = False

    if is_causal:
        window_size_right = 0

    # A window at least as long as the key sequence disables sliding-window
    # attention on that side.
    if window_size_left >= max_seqlen_k:
        window_size_left = -1
    if window_size_right >= max_seqlen_k:
        window_size_right = -1

    is_local = window_size_left >= 0

    # Optimize all single-query sequences by swapping the query-group and
    # sequence dimensions.
    seqlenq_ngroups_swapped = (
        max_seqlen_q == 1
        and alibi_slopes is None
        and num_heads > num_heads_k
        and window_size_left < 0
        and window_size_right < 0
        and p_dropout == 0
    )
    q_groups = num_heads // num_heads_k
    if seqlenq_ngroups_swapped:
        logger.debug("Swapping query groups and sequence dimensions")
        grouped_q = q.reshape((batch_size, num_heads_k, q_groups, head_size))
        q = grouped_q.transpose(1, 2).reshape(
            batch_size * q_groups, num_heads_k, head_size
        )
        max_seqlen_q = q_groups
        num_heads = num_heads_k
        cu_seqlens_q = None
        q_batch_stride = q.stride(0) * max_seqlen_q
        k_batch_stride = k.stride(0)
        v_batch_stride = v.stride(0)
        # o_batch_stride is filled in once the output buffer exists.
    else:
        q_batch_stride = k_batch_stride = v_batch_stride = o_batch_stride = 0

    total_q = q.size(0)

    assert leftpad_k is None, "leftpad_k is not supported."
    assert (
        head_size <= 256
    ), "FlashAttention forward only supports head dimension at most 256"
    assert (
        head_size % 8 == 0
    ), "head_size must be a multiple of 8, this is ensured by padding!"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"

    assert q.shape == (total_q, num_heads, head_size)
    if is_paged:
        paged_shape = (num_pages, block_size, num_heads_k, head_size)
        assert k.shape == paged_shape
        assert v.shape == paged_shape
        assert k.stride() == v.stride()

    if softcap > 0.0:
        assert p_dropout == 0, "dropout is not supported if softcap is used."

    def round_up(value, multiple):
        # Smallest multiple of `multiple` that is >= `value`.
        return (value + multiple - 1) // multiple * multiple

    head_size_rounded = round_up(head_size, 32) if head_size <= 192 else 256
    seqlen_q_rounded = round_up(max_seqlen_q, 128)
    seqlen_k_rounded = round_up(max_seqlen_k, 32)

    M_LOG2E = 1.4426950408889634074
    if softcap > 0.0:
        # With soft-capping the kernel's scale slot carries `softcap` and its
        # softcap slot carries `softmax_scale / softcap`.
        is_softcap = True
        adjusted_softcap = softmax_scale / softcap
        adjusted_scale_softmax = softcap
        adjusted_scale_softmax_log2e = softcap * M_LOG2E
    else:
        is_softcap = False
        adjusted_softcap = 0.0
        adjusted_scale_softmax = softmax_scale
        adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E

    # Set alibi params
    if alibi_slopes is None:
        is_alibi = False
        alibi_slopes_batch_stride = 0
    else:
        assert alibi_slopes.device == q_device
        assert alibi_slopes.dtype in (torch.float,)
        assert alibi_slopes.stride(-1) == 1
        assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (
            batch_size,
            num_heads,
        )
        # Per-head slopes (1-D) are shared by every batch entry: stride 0.
        alibi_slopes_batch_stride = (
            alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
        )
        is_alibi = True

    # Prepare params to kernel
    with torch_device_fn.device(q_device):
        # `out_` remembers the caller's buffer; the kernel writes into `out`.
        out_ = out
        if out is None or seqlenq_ngroups_swapped:
            out = torch.empty_like(q, dtype=v.dtype)

        if seqlenq_ngroups_swapped:
            o_batch_stride = out.stride(0) * max_seqlen_q

        lse = torch.empty((num_heads, total_q), dtype=torch.float, device=q_device)

        if p_dropout > 0:
            is_dropout = True
            increment = batch_size * num_heads * 32
            philox_seed, philox_offset = philox_backend_seed_offset(increment)
            philox_args = torch.tensor(
                [philox_seed, philox_offset], dtype=torch.int64, device=q_device
            )
        else:
            is_dropout = False
            philox_args = torch.empty((2,), dtype=torch.int64, device=q_device)

        # From here on the kernel is handed the keep probability.
        keep_prob = 1 - p_dropout
        p_dropout_in_uint8_t = math.floor(keep_prob * 255.0)
        rp_dropout = 1.0 / keep_prob

        if return_softmax:
            assert is_dropout, "Only supported with non-zero dropout."
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                device=q_device,
            )
        else:
            p = torch.empty((), device=q_device)

        if zero_tensors:
            out.zero_()
            lse.fill_(float("-inf"))

        params = fwd_params(
            # pointers and strides
            q_ptr=q,
            k_ptr=k,
            v_ptr=v,
            o_ptr=out,
            p_ptr=p,
            softmax_lse_ptr=lse,
            q_row_stride=q.stride(-3),
            k_row_stride=k.stride(-3),
            v_row_stride=v.stride(-3),
            q_head_stride=q.stride(-2),
            k_head_stride=k.stride(-2),
            v_head_stride=v.stride(-2),
            o_row_stride=out.stride(-3),
            o_head_stride=out.stride(-2),
            q_batch_stride=q_batch_stride,
            k_batch_stride=k_batch_stride,
            v_batch_stride=v_batch_stride,
            o_batch_stride=o_batch_stride,
            is_cu_seqlens_q=cu_seqlens_q is not None,
            cu_seqlens_q_ptr=cu_seqlens_q,
            is_cu_seqlens_k=seqused_k is None,
            cu_seqlens_k_ptr=cu_seqlens_k,
            is_seqused_k=seqused_k is not None,
            seqused_k_ptr=seqused_k,
            # sizes
            b=batch_size,
            bk=k_batch_size,
            h=num_heads,
            hk=num_heads_k,
            h_hk_ratio=num_heads // num_heads_k,
            seqlen_q=max_seqlen_q,
            seqlen_k=max_seqlen_k,
            seqlen_q_rounded=seqlen_q_rounded,
            seqlen_k_rounded=seqlen_k_rounded,
            d=head_size,
            d_rounded=head_size_rounded,
            # scaling factors
            is_softcap=is_softcap,
            softcap=adjusted_softcap,
            scale_softmax=adjusted_scale_softmax,
            scale_softmax_log2=adjusted_scale_softmax_log2e,
            # dropout
            is_dropout=is_dropout,
            p_dropout=keep_prob,
            rp_dropout=rp_dropout,
            p_dropout_in_uint8_t=p_dropout_in_uint8_t,
            philox_args=philox_args,
            return_softmax=return_softmax,
            # causal and swa
            is_causal=is_causal,
            is_local=is_local,
            window_size_left=window_size_left,
            window_size_right=window_size_right,
            seqlenq_ngroups_swapped=seqlenq_ngroups_swapped,
            is_paged=is_paged,
            # alibi
            is_alibi=is_alibi,
            alibi_slopes_ptr=alibi_slopes,
            alibi_slopes_batch_stride=alibi_slopes_batch_stride,
            # block table params
            total_q=total_q,
            page_table_ptr=page_table,
            page_table_batch_stride=page_table_batch_stride,
            block_size=block_size,
        )

        if flag_gems.vendor_name == "iluvatar":
            # This vendor gets k/v with the trailing dimensions merged.
            params.k_ptr = k.view(k.shape[0], k.shape[1], -1)
            params.v_ptr = v.view(v.shape[0], v.shape[1], -1)
        logger.debug("kernel: flash_varlen_fwd")

        def grid(meta):
            # One program per (query block, sequence, head).
            return (
                triton.cdiv(max_seqlen_q, meta["BLOCK_M"]),
                batch_size,
                num_heads,
            )

        kernel = flash_varlen_fwd_kernel[grid]
        args = params.args()

        # We assess which phase the requests are likely to be in and set the
        # config accordingly.
        total_rows = total_q * num_heads
        num_sms = torch_device_fn.get_device_properties(
            flag_gems.device
        ).multi_processor_count
        avg_rows_per_sm = total_rows / num_sms
        avg_rows_per_batch = total_q / batch_size
        avg_rows_per_cta = min(avg_rows_per_batch, avg_rows_per_sm)
        # Rough heuristic: the more query rows a CTA handles on average, the
        # larger the block config selected.  It may not suit every scenario.
        config_tiers = (
            (64, "mha_block_128"),
            (32, "mha_block_64"),
            (16, "mha_block_32"),
        )
        varlen_fwd_config_str = next(
            (name for floor, name in config_tiers if avg_rows_per_cta > floor),
            "mha_block_16",
        )
        # NOTE(review): treated as an unconditional override for this vendor;
        # confirm against the original indentation.
        if flag_gems.vendor_name == "mthreads":
            varlen_fwd_config_str = "mha_block_32"

        cfg = runtime.get_heuristic_config(varlen_fwd_config_str)
        cfg_params = {
            "BLOCK_M": cfg["BLOCK_M"](args),
            "BLOCK_N": cfg["BLOCK_N"](args),
            "BLOCK_K": triton.next_power_of_2(head_size),
            "num_warps": cfg["num_warps"](args),
            "num_stages": cfg["num_stages"](args) if is_paged else 1,
        }

        logger.debug("Running flash_varlen_fwd_kernel with config: %s", cfg_params)
        kernel(*args, **cfg_params)

    if seqlenq_ngroups_swapped:
        # Undo the query-group/sequence swap on the results.
        out = out.reshape(
            batch_size, max_seqlen_q, num_heads_k, head_size
        ).transpose(1, 2)
        if out_ is None:
            out = out.reshape(batch_size, num_heads_k * max_seqlen_q, head_size)
        else:
            out_.view(batch_size, num_heads_k, max_seqlen_q, head_size).copy_(out)
            out = out_
        lse = lse.reshape(num_heads_k, batch_size, max_seqlen_q).reshape(
            num_heads_k * max_seqlen_q, batch_size
        )

    unused = torch.empty((), dtype=torch.int64, device=q_device)
    return out, q, k, v, lse, philox_args, unused, p
def mha_varlan_fwd_opt(
    q,
    k,
    v,
    out,
    lse,
    cu_seqlens_q,
    cu_seqlens_k,
    seqused_k,
    leftpad_k,
    page_table,
    alibi_slopes,
    max_seqlen_q,
    max_seqlen_k,
    p_dropout,
    softmax_scale,
    zero_tensors,
    is_causal,
    window_size_left,
    window_size_right,
    softcap,
    return_softmax,
    gen,
):
    """Variant of the varlen forward launcher that avoids placeholder allocations.

    Validates the inputs, packs the kernel arguments and launches
    ``flash_varlen_fwd_kernel``.  A caller-provided ``lse`` buffer is used
    when given (one is allocated only when ``lse`` is ``None``), and ``None``
    is passed to the kernel for the philox state and the softmax tensor when
    dropout / ``return_softmax`` are off.

    ``q`` is ``[total_q_tokens, num_heads, head_size]``.  When ``page_table``
    is given, ``k``/``v`` are ``[num_pages, block_size, num_heads_k, head_size]``;
    otherwise the key head count is read from ``k.size(1)`` and an empty
    page table is substituted.  ``leftpad_k`` must be ``None``; ``gen`` is
    accepted but never read.

    Returns:
        ``(out, q, k, v, lse, philox_args, unused, p)``.  ``unused`` is always
        ``None``; ``philox_args`` is ``None`` without dropout and ``p`` is
        ``None`` unless ``return_softmax`` is set.

    Raises:
        AssertionError: if any dtype, layout, shape or option check fails.
    """
    for tensor in (q, k, v):
        CHECK_DEVICE(tensor)
    q_device, q_dtype = q.device, q.dtype
    assert q_dtype in (
        torch.float16,
        torch.bfloat16,
    ), "FlashAttention only support fp16 and bf16 data type"
    assert q_dtype == k.dtype
    assert q_dtype == v.dtype
    for tensor in (q, k, v):
        assert (
            tensor.stride(-1) == 1
        ), "Input tensor must have contiguous last dimension"

    for cu_seqlens in (cu_seqlens_q, cu_seqlens_k):
        assert cu_seqlens.dtype == torch.int32
        assert cu_seqlens.is_contiguous()

    is_paged = page_table is not None
    if not is_paged:
        # The kernel always receives a page-table argument; use an empty one.
        # (This previously called the misspelled `torch.emtpty`, so every
        # non-paged call failed with AttributeError.)
        page_table = torch.empty((0, 0), device=q_device, dtype=torch.int32)

    # q shape: [total_q_tokens, num_heads, head_size]
    # k shape (paged_kv): [num_pages, block_size, num_heads_k, head_size]
    # batch_size is the number of sequences described by cu_seqlens_q.
    total_q, num_heads, head_size = q.size()
    if is_paged:
        num_heads_k, block_size, num_pages = k.size(2), k.size(1), k.size(0)
    else:
        num_heads_k, block_size, num_pages = k.size(1), 1, 0
    batch_size = cu_seqlens_q.numel() - 1
    k_batch_size = num_pages
    page_table_batch_stride = page_table.stride(0)

    assert k.size() == v.size()
    assert cu_seqlens_q.size() == (batch_size + 1,)
    assert cu_seqlens_k.size() == (batch_size + 1,)

    # Check output shape
    if out is not None:
        assert out.stride(-1) == 1
        assert out.dtype == q.dtype
        assert out.size() == (total_q, num_heads, head_size)

    if seqused_k is not None:
        assert seqused_k.is_contiguous()
        assert seqused_k.size() == (batch_size,)

    if max_seqlen_q == 1 and alibi_slopes is None:
        is_causal = False

    if is_causal:
        window_size_right = 0

    # A window at least as long as the key sequence disables sliding-window
    # attention on that side.
    if window_size_left >= max_seqlen_k:
        window_size_left = -1
    if window_size_right >= max_seqlen_k:
        window_size_right = -1

    is_local = window_size_left >= 0

    # Optimize all single-query sequences by swapping the query-group and
    # sequence dimensions.
    seqlenq_ngroups_swapped = (
        max_seqlen_q == 1
        and alibi_slopes is None
        and num_heads > num_heads_k
        and window_size_left < 0
        and window_size_right < 0
        and p_dropout == 0
    )
    q_groups = num_heads // num_heads_k
    if seqlenq_ngroups_swapped:
        logger.debug("Swapping query groups and sequence dimensions")
        grouped_q = q.reshape((batch_size, num_heads_k, q_groups, head_size))
        q = grouped_q.transpose(1, 2).reshape(
            batch_size * q_groups, num_heads_k, head_size
        )
        max_seqlen_q = q_groups
        num_heads = num_heads_k
        cu_seqlens_q = None
        q_batch_stride = q.stride(0) * max_seqlen_q
        k_batch_stride = k.stride(0)
        v_batch_stride = v.stride(0)
        # o_batch_stride is filled in once the output buffer exists.
    else:
        q_batch_stride = k_batch_stride = v_batch_stride = o_batch_stride = 0

    total_q = q.size(0)

    assert leftpad_k is None, "leftpad_k is not supported."
    assert (
        head_size <= 256
    ), "FlashAttention forward only supports head dimension at most 256"
    assert (
        head_size % 8 == 0
    ), "head_size must be a multiple of 8, this is ensured by padding!"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"

    assert q.shape == (total_q, num_heads, head_size)
    if is_paged:
        paged_shape = (num_pages, block_size, num_heads_k, head_size)
        assert k.shape == paged_shape
        assert v.shape == paged_shape
        assert k.stride() == v.stride()

    if softcap > 0.0:
        assert p_dropout == 0, "dropout is not supported if softcap is used."

    def round_up(value, multiple):
        # Smallest multiple of `multiple` that is >= `value`.
        return (value + multiple - 1) // multiple * multiple

    head_size_rounded = round_up(head_size, 32) if head_size <= 192 else 256
    seqlen_q_rounded = round_up(max_seqlen_q, 128)
    seqlen_k_rounded = round_up(max_seqlen_k, 32)

    M_LOG2E = 1.4426950408889634074
    if softcap > 0.0:
        # With soft-capping the kernel's scale slot carries `softcap` and its
        # softcap slot carries `softmax_scale / softcap`.
        is_softcap = True
        adjusted_softcap = softmax_scale / softcap
        adjusted_scale_softmax = softcap
        adjusted_scale_softmax_log2e = softcap * M_LOG2E
    else:
        is_softcap = False
        adjusted_softcap = 0.0
        adjusted_scale_softmax = softmax_scale
        adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E

    # Set alibi params
    if alibi_slopes is None:
        is_alibi = False
        alibi_slopes_batch_stride = 0
    else:
        assert alibi_slopes.device == q_device
        assert alibi_slopes.dtype in (torch.float,)
        assert alibi_slopes.stride(-1) == 1
        assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (
            batch_size,
            num_heads,
        )
        # Per-head slopes (1-D) are shared by every batch entry: stride 0.
        alibi_slopes_batch_stride = (
            alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
        )
        is_alibi = True

    # Prepare params to kernel
    with torch_device_fn.device(q_device):
        # `out_` remembers the caller's buffer; the kernel writes into `out`.
        out_ = out
        if out is None or seqlenq_ngroups_swapped:
            out = torch.empty_like(q, dtype=v.dtype)

        if seqlenq_ngroups_swapped:
            o_batch_stride = out.stride(0) * max_seqlen_q

        if lse is None:
            lse = torch.empty((num_heads, total_q), dtype=torch.float, device=q_device)

        if p_dropout > 0:
            is_dropout = True
            increment = batch_size * num_heads * 32
            philox_seed, philox_offset = philox_backend_seed_offset(increment)
            philox_args = torch.tensor(
                [philox_seed, philox_offset], dtype=torch.int64, device=q_device
            )
        else:
            # No dropout: pass None rather than allocating a placeholder.
            is_dropout = False
            philox_args = None

        # From here on the kernel is handed the keep probability.
        keep_prob = 1 - p_dropout
        p_dropout_in_uint8_t = math.floor(keep_prob * 255.0)
        rp_dropout = 1.0 / keep_prob

        if return_softmax:
            assert is_dropout, "Only supported with non-zero dropout."
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                device=q_device,
            )
        else:
            # Softmax probabilities not requested: no placeholder tensor.
            p = None

        if zero_tensors:
            out.zero_()
            lse.fill_(float("-inf"))

        params = fwd_params(
            # pointers and strides
            q_ptr=q,
            k_ptr=k,
            v_ptr=v,
            o_ptr=out,
            p_ptr=p,
            softmax_lse_ptr=lse,
            q_row_stride=q.stride(-3),
            k_row_stride=k.stride(-3),
            v_row_stride=v.stride(-3),
            q_head_stride=q.stride(-2),
            k_head_stride=k.stride(-2),
            v_head_stride=v.stride(-2),
            o_row_stride=out.stride(-3),
            o_head_stride=out.stride(-2),
            q_batch_stride=q_batch_stride,
            k_batch_stride=k_batch_stride,
            v_batch_stride=v_batch_stride,
            o_batch_stride=o_batch_stride,
            is_cu_seqlens_q=cu_seqlens_q is not None,
            cu_seqlens_q_ptr=cu_seqlens_q,
            is_cu_seqlens_k=cu_seqlens_k is not None,
            cu_seqlens_k_ptr=cu_seqlens_k,
            is_seqused_k=seqused_k is not None,
            seqused_k_ptr=seqused_k,
            # sizes
            b=batch_size,
            bk=k_batch_size,
            h=num_heads,
            hk=num_heads_k,
            h_hk_ratio=num_heads // num_heads_k,
            seqlen_q=max_seqlen_q,
            seqlen_k=max_seqlen_k,
            seqlen_q_rounded=seqlen_q_rounded,
            seqlen_k_rounded=seqlen_k_rounded,
            d=head_size,
            d_rounded=head_size_rounded,
            # scaling factors
            is_softcap=is_softcap,
            softcap=adjusted_softcap,
            scale_softmax=adjusted_scale_softmax,
            scale_softmax_log2=adjusted_scale_softmax_log2e,
            # dropout
            is_dropout=is_dropout,
            p_dropout=keep_prob,
            rp_dropout=rp_dropout,
            p_dropout_in_uint8_t=p_dropout_in_uint8_t,
            philox_args=philox_args,
            return_softmax=return_softmax,
            # causal and swa
            is_causal=is_causal,
            is_local=is_local,
            window_size_left=window_size_left,
            window_size_right=window_size_right,
            seqlenq_ngroups_swapped=seqlenq_ngroups_swapped,
            is_paged=is_paged,
            # alibi
            is_alibi=is_alibi,
            alibi_slopes_ptr=alibi_slopes,
            alibi_slopes_batch_stride=alibi_slopes_batch_stride,
            # block table params
            total_q=total_q,
            page_table_ptr=page_table,
            page_table_batch_stride=page_table_batch_stride,
            block_size=block_size,
        )

        if flag_gems.vendor_name == "iluvatar":
            # This vendor gets k/v with the trailing dimensions merged.
            params.k_ptr = k.view(k.shape[0], k.shape[1], -1)
            params.v_ptr = v.view(v.shape[0], v.shape[1], -1)
        logger.debug("kernel: flash_varlen_fwd")

        def grid(meta):
            # One program per (query block, sequence, head).
            return (
                triton.cdiv(max_seqlen_q, meta["BLOCK_M"]),
                batch_size,
                num_heads,
            )

        kernel = flash_varlen_fwd_kernel[grid]
        args = params.args()

        # We assess which phase the requests are likely to be in and set the
        # config accordingly.
        total_rows = total_q * num_heads
        num_sms = torch_device_fn.get_device_properties(
            flag_gems.device
        ).multi_processor_count
        avg_rows_per_sm = total_rows / num_sms
        avg_rows_per_batch = total_q / batch_size
        avg_rows_per_cta = min(avg_rows_per_batch, avg_rows_per_sm)
        # Rough heuristic: the more query rows a CTA handles on average, the
        # larger the block config selected.  It may not suit every scenario.
        config_tiers = (
            (64, "mha_block_128"),
            (32, "mha_block_64"),
            (16, "mha_block_32"),
        )
        varlen_fwd_config_str = next(
            (name for floor, name in config_tiers if avg_rows_per_cta > floor),
            "mha_block_16",
        )
        # NOTE(review): treated as an unconditional override for this vendor;
        # confirm against the original indentation.
        if flag_gems.vendor_name == "mthreads":
            varlen_fwd_config_str = "mha_block_32"

        cfg = runtime.get_heuristic_config(varlen_fwd_config_str)
        cfg_params = {
            "BLOCK_M": cfg["BLOCK_M"](args),
            "BLOCK_N": cfg["BLOCK_N"](args),
            "BLOCK_K": triton.next_power_of_2(head_size),
            "num_warps": cfg["num_warps"](args),
            "num_stages": cfg["num_stages"](args) if is_paged else 1,
        }

        logger.debug("Running flash_varlen_fwd_kernel with config: %s", cfg_params)
        kernel(*args, **cfg_params)

    if seqlenq_ngroups_swapped:
        # Undo the query-group/sequence swap on the results.
        out = out.reshape(
            batch_size, max_seqlen_q, num_heads_k, head_size
        ).transpose(1, 2)
        if out_ is None:
            out = out.reshape(batch_size, num_heads_k * max_seqlen_q, head_size)
        else:
            out_.view(batch_size, num_heads_k, max_seqlen_q, head_size).copy_(out)
            out = out_
        lse = lse.reshape(num_heads_k, batch_size, max_seqlen_q).reshape(
            num_heads_k * max_seqlen_q, batch_size
        )

    # Kept only so the returned tuple has the same arity as mha_varlan_fwd.
    unused = None
    return out, q, k, v, lse, philox_args, unused, p
929def mha_fwd(
930 q,
931 k,
932 v,
933 out,
934 alibi_slopes,
935 p_dropout,
936 softmax_scale,
937 is_causal,
938 window_size_left,
939 window_size_right,
940 softcap,
941 return_softmax,
942 disable_splitkv=False,
943):
944 CHECK_DEVICE(q), CHECK_DEVICE(k), CHECK_DEVICE(v)
945 q_dtype = q.dtype
946 q_device = q.device
947 assert q_dtype in (
948 torch.float16,
949 torch.bfloat16,
950 ), "FlashAttention only support fp16 and bf16 data type"
951 assert q_dtype == k.dtype
952 assert q_dtype == v.dtype
953 assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension"
954 assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension"
955 assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension"
956 batch_size, seqlen_q, num_heads, head_size = q.size()
957 _, seqlen_k, num_heads_k, _ = k.size()
959 # Check output shape
960 if out is not None:
961 assert out.stride(-1) == 1
962 assert out.dtype == q.dtype
963 assert out.size() == (batch_size, seqlen_q, num_heads, head_size)
964 CHECK_DEVICE(out)
966 assert (
967 head_size % 8 == 0
968 ), "head_size must be a multiple of 8, this is ensured by padding!"
969 assert (
970 num_heads % num_heads_k == 0
971 ), "Number of heads in key/value must divide number of heads in query"
972 if window_size_left >= seqlen_k:
973 window_size_left = -1
974 if window_size_right >= seqlen_k:
975 window_size_right = -1
976 if seqlen_q == 1 and alibi_slopes is None:
977 is_causal = False
978 if is_causal:
979 window_size_right = 0
981 is_causal = window_size_left < 0 and window_size_right == 0
982 is_local = window_size_left >= 0 and window_size_right >= 0
984 seqlenq_ngroups_swapped = (
985 seqlen_q == 1
986 and alibi_slopes is None
987 and num_heads > num_heads_k
988 and window_size_left < 0
989 and window_size_right < 0
990 and p_dropout == 0
991 )
992 q_groups = num_heads // num_heads_k
994 if seqlenq_ngroups_swapped:
995 logger.debug("q_kg swapped.")
996 q = q.reshape(batch_size, num_heads_k, q_groups, head_size).transpose(1, 2)
997 seqlen_q = q_groups
998 num_heads = num_heads_k
1000 round_multiple = lambda x, m: (x + m - 1) // m * m
1001 head_size_rounded = round_multiple(head_size, 32)
1002 seqlen_q_rounded = round_multiple(seqlen_q, 128)
1003 seqlen_k_rounded = round_multiple(seqlen_k, 32)
1005 assert (
1006 head_size <= 256
1007 ), "FlashAttention forward only supports head dimension at most 256"
1008 assert head_size == head_size_rounded, "head_size must be rounded to 32"
1010 def splits_heuristic(num_tasks, num_sms, n_blocks):
1011 # splits when wave efficiency is low
1012 n_waves = triton.cdiv(num_tasks, num_sms)
1013 eff = (num_tasks / num_sms) / n_waves
1014 if eff > 0.8 or n_waves > 1:
1015 return 1
1017 min_blocks_per_split = 2
1018 best_splits = min(
1019 triton.cdiv(n_blocks, min_blocks_per_split),
1020 int(math.floor(1.0 / eff)),
1021 num_sms,
1022 )
1024 return best_splits
1026 with torch_device_fn.device(q_device):
1027 # Set softmax params
1028 lse = torch.empty(
1029 (batch_size, num_heads, seqlen_q), dtype=torch.float, device=q_device
1030 )
1032 if out is not None:
1033 if seqlenq_ngroups_swapped:
1034 out = out.reshape(
1035 batch_size, num_heads_k, q_groups, head_size
1036 ).transpose(1, 2)
1037 else:
1038 out = torch.empty_like(q, dtype=v.dtype)
1040 # Set dropout params
1041 if p_dropout > 0:
1042 is_dropout = True
1043 increment = batch_size * num_heads * 32
1044 philox_seed, philox_offset = philox_backend_seed_offset(increment)
1045 philox_args = torch.tensor(
1046 [philox_seed, philox_offset], dtype=torch.int64, device=q_device
1047 )
1048 else:
1049 is_dropout = False
1050 philox_args = torch.empty((2,), dtype=torch.int64, device=q_device)
1052 p_dropout = 1 - p_dropout
1053 p_dropout_in_uint8_t = math.floor(p_dropout * 255.0)
1054 rp_dropout = 1.0 / p_dropout
1056 if return_softmax:
1057 assert is_dropout, "Only supported with non-zero dropout."
1058 p = torch.empty(
1059 (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
1060 device=q_device,
1061 )
1062 else:
1063 p = torch.empty((), device=q_device)
1065 M_LOG2E = 1.4426950408889634074
1066 if softcap > 0.0:
1067 is_softcap = True
1068 adjusted_scale_softmax = softcap
1069 adjusted_softcap = softmax_scale / softcap
1070 adjusted_scale_softmax_log2e = softcap * M_LOG2E
1071 else:
1072 is_softcap = False
1073 adjusted_softcap = 0.0
1074 adjusted_scale_softmax = softmax_scale
1075 adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E
1077 # Set alibi params
1078 if alibi_slopes is not None:
1079 assert alibi_slopes.device == q_device
1080 assert alibi_slopes.dtype in (torch.float,)
1081 assert alibi_slopes.stride(-1) == 1
1082 assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (
1083 batch_size,
1084 num_heads,
1085 )
1086 alibi_slopes_batch_stride = (
1087 alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
1088 )
1089 is_alibi = True
1090 else:
1091 alibi_slopes_batch_stride = 0
1092 is_alibi = False
1094 # ONLY EVEN_K IS SUPPORTED
1095 assert head_size == head_size_rounded
1097 # Do kernel dispatching
def dispatch(B, H, Q, K, D, params):
    """Choose a forward attention kernel for this problem and launch it.

    The split-KV kernel is tried first; it is only eligible when dropout,
    local (sliding-window) masking and ``disable_splitkv`` are all off, and
    it is only used when ``splits_heuristic`` asks for more than one split.
    Otherwise the plain ``flash_fwd_kernel`` is launched.

    Args:
        B: Batch size.
        H: Number of query heads.
        Q: Query sequence length.
        K: Key sequence length.
        D: Head size.
        params: ``fwd_params`` instance whose ``args()`` are forwarded to the
            kernel. On the split-KV path its ``o_ptr`` and
            ``softmax_lse_ptr`` are redirected to per-split scratch buffers.

    Returns:
        The handle returned by launching the selected attention kernel (the
        split-KV kernel or ``flash_fwd_kernel``), not the combine kernel.
    """
    # Query the device the inputs actually live on instead of a hard-coded
    # "cuda" string, so the SM count matches q_device on multi-device hosts
    # and on backends whose device type is not named "cuda".
    num_sms = torch_device_fn.get_device_properties(q_device).multi_processor_count

    # Try bh parallel
    # if B * H > 0.8 * num_sms:
    #     kernel = flash_fwd_bh_parallel_kernel[(H, B)]
    #     # Yield kernel and prefilled args
    #     return kernel, default_args, None, None

    # Try splitkv
    if not is_dropout and not is_local and not disable_splitkv:
        BM = block_m_splitkv_heuristic(D)
        n_tasks = B * H * triton.cdiv(Q, BM)
        BN = block_n_splitkv_heuristic(D)
        n_blocks = triton.cdiv(K, BN)
        n_splits = splits_heuristic(n_tasks, num_sms, n_blocks)

        if n_splits > 1:
            logger.debug("kernel: flash_fwd_splitkv")
            # Per-split scratch buffers, always float32 regardless of the
            # dtype of the final output tensor.
            lse_splits = torch.empty(
                (n_splits, B, H, Q), dtype=torch.float, device=q_device
            )
            out_splits = torch.empty(
                (n_splits, B, H, Q, D), dtype=torch.float, device=q_device
            )
            grid = lambda args: (
                triton.cdiv(Q, args["BLOCK_M"]),
                n_splits,
                B * H,
            )
            splitkv_kernel = flash_fwd_splitkv_kernel[grid]
            # The split kernel writes into the scratch buffers; the final
            # `out` / `lse` tensors are passed to the combine kernel below.
            params.o_ptr = out_splits
            params.softmax_lse_ptr = lse_splits
            extra_args = {"blocks_per_split": triton.cdiv(n_blocks, n_splits)}
            kernel = splitkv_kernel(*params.args(), **extra_args)

            # Fewer rows per program as the head size grows.
            if D >= 128:
                BLOCK_M = 4
            elif D >= 64:
                BLOCK_M = 8
            else:
                BLOCK_M = 16
            BLOCK_K = triton.next_power_of_2(D)
            grid = lambda args: (triton.cdiv(B * H * Q, BLOCK_M),)
            combine_kernel = flash_fwd_splitkv_combine_kernel[grid]
            # The combine kernel is handed the per-split buffers together
            # with the final `out` / `lse` tensors.
            combine_args = {
                "out_ptr": out,
                "lse_ptr": lse,
                "head_size": D,
                "out_split_stride": out_splits.stride(0),
                "lse_split_stride": lse_splits.stride(0),
                "out_b_stride": out.stride(0),
                "out_s_stride": out.stride(-3),
                "out_h_stride": out.stride(-1),
                "out_splits_ptr": out_splits,
                "lse_splits_ptr": lse_splits,
                "n_splits": n_splits,
                "BLOCK_M": BLOCK_M,
                "BLOCK_K": BLOCK_K,
                "q_total": B * H * Q,
                "MAX_N_SPLITS": triton.next_power_of_2(n_splits),
            }
            combine_kernel(**combine_args)
            return kernel

    # Last option: flash_fwd
    logger.debug("kernel: flash_fwd")
    grid = lambda args: (
        triton.cdiv(Q, args["BLOCK_M"]),
        H * B,
    )
    kernel = flash_fwd_kernel[grid]
    kernel = kernel(*params.args())
    return kernel
1175 if _debug:
1176 p = torch.empty(
1177 (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
1178 dtype=torch.float32,
1179 device=q_device,
1180 )
1181 return_softmax = True
1183 params = fwd_params(
1184 q, # q_ptr,
1185 k, # k_ptr,
1186 v, # v_ptr,
1187 out, # o_ptr,
1188 p, # p_ptr,
1189 lse, # softmax_lse_ptr,
1190 q.stride(-3), # q_row_stride,
1191 k.stride(-3), # k_row_stride,
1192 v.stride(-3), # v_row_stride,
1193 q.stride(-2), # q_head_stride,
1194 k.stride(-2), # k_head_stride,
1195 v.stride(-2), # v_head_stride,
1196 out.stride(-3), # o_row_stride,
1197 out.stride(-2), # o_head_stride,
1198 q.stride(0), # q_batch_stride,
1199 k.stride(0), # k_batch_stride,
1200 v.stride(0), # v_batch_stride,
1201 out.stride(0), # o_batch_stride,
1202 False, # is_cu_seqlens_q,
1203 None, # cu_seqlens_q_ptr,
1204 False, # is_cu_seqlens_k,
1205 None, # cu_seqlens_k_ptr,
1206 False, # is_seqused_k,
1207 None, # seqused_k_ptr,
1208 # sizes
1209 batch_size, # b,
1210 0, # bk,
1211 num_heads, # h,
1212 num_heads_k, # hk,
1213 num_heads // num_heads_k, # h_hk_ratio,
1214 seqlen_q, # seqlen_q,
1215 seqlen_k, # seqlen_k,
1216 seqlen_q_rounded, # seqlen_q_rounded,
1217 seqlen_k_rounded, # seqlen_k_rounded,
1218 head_size, # d,
1219 head_size_rounded, # d_rounded,
1220 # scaling factors
1221 is_softcap,
1222 adjusted_softcap, # softcap,
1223 adjusted_scale_softmax, # scale_softmax,
1224 adjusted_scale_softmax_log2e, # scale_softmax_log2,
1225 # dropout
1226 is_dropout,
1227 p_dropout,
1228 rp_dropout,
1229 p_dropout_in_uint8_t,
1230 philox_args,
1231 return_softmax,
1232 # causal and swa
1233 is_causal, # is_causal,
1234 is_local, # is_local,
1235 window_size_left, # window_size_left,
1236 window_size_right, # window_size_right,
1237 seqlenq_ngroups_swapped, # seqlenq_ngroups_swapped,
1238 False, # is_paged,
1239 # alibi
1240 is_alibi, #
1241 alibi_slopes, # alibi_slopes_ptr,
1242 alibi_slopes_batch_stride, # alibi_slopes_batch_stride,
1243 # block table params
1244 0, # total_q,
1245 None, # page_table_ptr,
1246 0, # page_table_batch_stride,
1247 0, # block_size,
1248 )
1250 # Move TxD to last dims for correct stride in Triton tt.load
1251 if flag_gems.vendor_name == "iluvatar":
1252 params.q_ptr = q.transpose(1, 2)
1253 params.k_ptr = k.transpose(1, 2)
1254 params.v_ptr = v.transpose(1, 2)
1255 kernel = dispatch(batch_size, num_heads, seqlen_q, seqlen_k, head_size, params)
1257 if _debug:
1258 print(f"{kernel.name} shared memory:", kernel.metadata.shared)
1259 print(f"{kernel.name} num_warps:", kernel.metadata.num_warps)
1260 print(f"{kernel.name} num_stages:", kernel.metadata.num_stages)
1261 # print(kernel.asm['ttgir'])
1263 if seqlenq_ngroups_swapped:
1264 out = out.transpose(1, 2).reshape(
1265 (batch_size, 1, num_heads_k * seqlen_q, head_size)
1266 )
1267 q = q.transpose(1, 2).reshape(
1268 (batch_size, 1, num_heads_k * seqlen_q, head_size)
1269 )
1270 lse = lse.reshape((batch_size, num_heads_k * seqlen_q, 1))
1272 unused = torch.empty((), dtype=torch.int64, device=q_device)
1274 return out, q, k, v, lse, philox_args, unused, p