Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/flash_api.py: 0%
349 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import math
4import torch
5import triton
7import flag_gems
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils.random_utils import philox_backend_seed_offset
12from .flash_kernel import (
13 block_m_splitkv_heuristic,
14 block_n_splitkv_heuristic,
15 flash_fwd_kernel,
16 flash_fwd_splitkv_combine_kernel,
17 flash_fwd_splitkv_kernel,
18 flash_varlen_fwd_kernel,
19)
# Core count reported by the device runtime, queried once when this module is
# imported. mha_fwd's dispatch() uses it as the SM count for the split-KV
# heuristic.
TOTAL_CORE_NUM = torch_device_fn.get_device_properties().multi_processor_count

logger = logging.getLogger(__name__)
# Developer switch: when True, mha_fwd forces return_softmax with a float32
# probability buffer and prints metadata of the launched kernel.
_debug = False
def CHECK_DEVICE(x):
    """Assert that ``x`` lives on the device type configured for flag_gems.

    Args:
        x: Object exposing ``.device.type`` (callers in this file pass tensors).

    Raises:
        AssertionError: If ``x.device.type`` differs from ``flag_gems.device``.
    """
    assert x.device.type == flag_gems.device, (
        f"expected a tensor on device type {flag_gems.device!r}, "
        f"got {x.device.type!r}"
    )
class fwd_params:
    """Slot-based bundle of every argument passed to the flash-attention kernels.

    The order of ``__slots__`` is significant: ``args()`` returns the field
    values in exactly this order, and the callers in this module unpack that
    tuple positionally when launching a kernel.
    """

    __slots__ = (
        # pointers and strides
        "q_ptr",
        "k_ptr",
        "v_ptr",
        "o_ptr",
        "p_ptr",
        "softmax_lse_ptr",
        "q_row_stride",
        "k_row_stride",
        "v_row_stride",
        "q_head_stride",
        "k_head_stride",
        "v_head_stride",
        "o_row_stride",
        "o_head_stride",
        "q_batch_stride",
        "k_batch_stride",
        "v_batch_stride",
        "o_batch_stride",
        "is_cu_seqlens_q",
        "cu_seqlens_q_ptr",
        "is_cu_seqlens_k",
        "cu_seqlens_k_ptr",
        "is_seqused_k",
        "seqused_k_ptr",
        # sizes
        "b",
        "bk",
        "h",
        "hk",
        "h_hk_ratio",
        "seqlen_q",
        "seqlen_k",
        "seqlen_q_rounded",
        "seqlen_k_rounded",
        "d",
        "d_rounded",
        # scaling factors
        "is_softcap",
        "softcap",
        "scale_softmax",
        "scale_softmax_log2",
        # dropout
        "is_dropout",
        "p_dropout",
        "rp_dropout",
        "p_dropout_in_uint8_t",
        "philox_args",
        "return_softmax",
        # masking
        "is_causal",
        "is_local",
        "window_size_left",
        "window_size_right",
        "seqlenq_ngroups_swapped",
        # alibi
        "is_alibi",
        "alibi_slopes_ptr",
        "alibi_slopes_batch_stride",
        # block table
        "total_q",
        "page_table_ptr",
        "page_table_batch_stride",
        "block_size",
    )

    def __init__(
        self,
        # pointers and strides
        q_ptr,
        k_ptr,
        v_ptr,
        o_ptr,
        p_ptr,
        softmax_lse_ptr,
        q_row_stride,
        k_row_stride,
        v_row_stride,
        q_head_stride,
        k_head_stride,
        v_head_stride,
        o_row_stride,
        o_head_stride,
        q_batch_stride,
        k_batch_stride,
        v_batch_stride,
        o_batch_stride,
        is_cu_seqlens_q,
        cu_seqlens_q_ptr,
        is_cu_seqlens_k,
        cu_seqlens_k_ptr,
        is_seqused_k,
        seqused_k_ptr,
        # sizes
        b,
        bk,
        h,
        hk,
        h_hk_ratio,
        seqlen_q,
        seqlen_k,
        seqlen_q_rounded,
        seqlen_k_rounded,
        d,
        d_rounded,
        # scaling factors
        is_softcap,
        softcap,
        scale_softmax,
        scale_softmax_log2,
        # dropout
        is_dropout,
        p_dropout,
        rp_dropout,
        p_dropout_in_uint8_t,
        philox_args,
        return_softmax,
        # masking
        is_causal,
        is_local,
        window_size_left,
        window_size_right,
        seqlenq_ngroups_swapped,
        # alibi
        is_alibi,
        alibi_slopes_ptr,
        alibi_slopes_batch_stride,
        # block table
        total_q,
        page_table_ptr,
        page_table_batch_stride,
        block_size,
    ):
        """Store every constructor argument in the slot of the same name."""
        # Snapshot the call arguments first, while the local namespace holds
        # nothing but `self` and the parameters; each parameter name matches
        # one entry of __slots__.
        supplied = locals()
        for field in self.__slots__:
            setattr(self, field, supplied[field])

    def args(self):
        """Return all field values as a tuple ordered like ``__slots__``."""
        return tuple([getattr(self, field) for field in self.__slots__])
def mha_varlan_fwd(
    q,
    k,
    v,
    out,
    cu_seqlens_q,
    cu_seqlens_k,
    seqused_k,
    leftpad_k,
    page_table,
    alibi_slopes,
    max_seqlen_q,
    max_seqlen_k,
    p_dropout,
    softmax_scale,
    zero_tensors,
    is_causal,
    window_size_left,
    window_size_right,
    softcap,
    return_softmax,
    gen,
):
    """Variable-length flash-attention forward pass over a paged KV cache.

    Validates the inputs, packs everything into ``fwd_params`` and launches
    ``flash_varlen_fwd_kernel`` with the "mha_varlen_fwd" heuristic config.
    (The function name is spelled ``varlan``; it is kept for compatibility.)

    Args:
        q: fp16/bf16 tensor of shape ``(total_q, num_heads, head_size)`` with a
            contiguous last dimension.
        k: Paged keys of shape ``(num_pages, block_size, num_heads_k,
            head_size)``, same dtype as ``q``, contiguous last dimension.
        v: Paged values; same shape, strides and dtype as ``k``.
        out: Optional output tensor of shape ``(total_q, num_heads,
            head_size)`` and the dtype of ``q``; allocated when ``None``.
        cu_seqlens_q: Contiguous int32 tensor of size ``(batch_size + 1,)``.
        cu_seqlens_k: Contiguous int32 tensor of size ``(batch_size + 1,)``.
        seqused_k: Optional contiguous tensor of size ``(batch_size,)``.
        leftpad_k: Must be ``None`` (not supported).
        page_table: Required page table; only its batch stride is read here.
        alibi_slopes: Optional float32 tensor of shape ``(num_heads,)`` or
            ``(batch_size, num_heads)``.
        max_seqlen_q: Maximum query sequence length.
        max_seqlen_k: Maximum key sequence length.
        p_dropout: Dropout probability; must be smaller than 1.
        softmax_scale: Scale applied before softmax.
        zero_tensors: If true, zero ``out`` and fill ``lse`` with ``-inf``
            before the launch.
        is_causal: Request causal masking.
        window_size_left: Left sliding-window size; disabled (set to -1) when
            it is at least ``max_seqlen_k``.
        window_size_right: Right sliding-window size; same disabling rule.
        softcap: Soft-capping value; values ``> 0`` enable it and forbid
            dropout.
        return_softmax: Allocate the probability buffer ``p``; only allowed
            with non-zero dropout.
        gen: Not used by this implementation.

    Returns:
        Tuple ``(out, q, k, v, lse, philox_args, unused, p)``.

    Raises:
        AssertionError: If any of the input checks fails.
    """
    for tensor in (q, k, v):
        CHECK_DEVICE(tensor)
    q_device = q.device
    q_dtype = q.dtype
    assert q_dtype in (
        torch.float16,
        torch.bfloat16,
    ), "FlashAttention only support fp16 and bf16 data type"
    assert q_dtype == k.dtype
    assert q_dtype == v.dtype
    assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension"

    assert cu_seqlens_q.dtype == torch.int32
    assert cu_seqlens_q.is_contiguous()

    assert cu_seqlens_k.dtype == torch.int32
    assert cu_seqlens_k.is_contiguous()

    assert page_table is not None

    # q shape: [total_q_tokens, num_heads, head_size]
    # k shape:
    #   paged_kv: [num_pages, block_size, num_heads_k, head_size]
    # batch_size, number of sentences
    total_q, num_heads, head_size = q.size()
    num_heads_k = k.size(2)
    batch_size = cu_seqlens_q.numel() - 1
    block_size = k.size(1)
    num_pages = k.size(0)
    k_batch_size = num_pages
    page_table_batch_stride = page_table.stride(0)

    assert k.size() == v.size()
    assert cu_seqlens_q.size() == (batch_size + 1,)
    assert cu_seqlens_k.size() == (batch_size + 1,)

    # Check output shape
    if out is not None:
        assert out.stride(-1) == 1
        assert out.dtype == q.dtype
        assert out.size() == (total_q, num_heads, head_size)

    if seqused_k is not None:
        assert seqused_k.is_contiguous()
        assert seqused_k.size() == (batch_size,)

    # A single query token per sequence needs no causal mask (unless alibi is
    # in use).
    if max_seqlen_q == 1 and alibi_slopes is None:
        is_causal = False

    if is_causal:
        window_size_right = 0

    # check disable swa
    if window_size_left >= max_seqlen_k:
        window_size_left = -1
    if window_size_right >= max_seqlen_k:
        window_size_right = -1

    is_local = window_size_left >= 0

    # Optimize all single-query sequences by swapping the query-group and
    # sequence dimensions
    seqlenq_ngroups_swapped = (
        max_seqlen_q == 1
        and alibi_slopes is None
        and num_heads > num_heads_k
        and window_size_left < 0
        and window_size_right < 0
        and p_dropout == 0
    )
    q_groups = num_heads // num_heads_k
    if seqlenq_ngroups_swapped:
        q = (
            q.reshape((batch_size, num_heads_k, q_groups, head_size))
            .transpose(1, 2)
            .reshape(batch_size * q_groups, num_heads_k, head_size)
        )
        max_seqlen_q = q_groups
        num_heads = num_heads_k
        # The swapped layout has a fixed q length per batch entry, so the
        # cumulative offsets are dropped and batch strides are used instead.
        cu_seqlens_q = None
        q_batch_stride = q.stride(0) * max_seqlen_q
        k_batch_stride = k.stride(0)
        v_batch_stride = v.stride(0)
        # o_batch_stride is set further down, once `out` has been allocated.
    else:
        q_batch_stride = 0
        k_batch_stride = 0
        v_batch_stride = 0
        o_batch_stride = 0

    total_q = q.size(0)

    assert leftpad_k is None, "leftpad_k is not supported."
    assert (
        head_size <= 256
    ), "FlashAttention forward only supports head dimension at most 256"
    assert (
        head_size % 8 == 0
    ), "head_size must be a multiple of 8, this is ensured by padding!"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"

    assert q.shape == (total_q, num_heads, head_size)
    assert k.shape == (num_pages, block_size, num_heads_k, head_size)
    assert v.shape == (num_pages, block_size, num_heads_k, head_size)
    assert k.stride() == v.stride()

    if softcap > 0.0:
        assert p_dropout == 0, "dropout is not supported if softcap is used."
    # The keep probability 1 - p_dropout is inverted below; reject 1 here so
    # the caller gets a clear error instead of a ZeroDivisionError.
    assert p_dropout < 1, "p_dropout must be smaller than 1"

    def round_multiple(x, m):
        # Round x up to the next multiple of m.
        return (x + m - 1) // m * m

    head_size_rounded = round_multiple(head_size, 32) if head_size < 192 else 256
    seqlen_q_rounded = round_multiple(max_seqlen_q, 128)
    seqlen_k_rounded = round_multiple(max_seqlen_k, 32)

    M_LOG2E = 1.4426950408889634074
    if softcap > 0.0:
        is_softcap = True
        adjusted_scale_softmax = softcap
        adjusted_softcap = softmax_scale / softcap
        adjusted_scale_softmax_log2e = softcap * M_LOG2E
    else:
        is_softcap = False
        adjusted_softcap = 0.0
        adjusted_scale_softmax = softmax_scale
        adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E

    # Set alibi params
    if alibi_slopes is not None:
        assert alibi_slopes.device == q_device
        assert alibi_slopes.dtype in (torch.float,)
        assert alibi_slopes.stride(-1) == 1
        assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (
            batch_size,
            num_heads,
        )
        alibi_slopes_batch_stride = (
            alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
        )
        is_alibi = True
    else:
        alibi_slopes_batch_stride = 0
        is_alibi = False

    # Prepare params to kernel
    with torch_device_fn.device(q_device):
        if out is not None:
            # Keep the caller's tensor; in the swapped layout the kernel
            # writes to a scratch buffer that is copied back afterwards.
            out_ = out
            if seqlenq_ngroups_swapped:
                out = torch.empty_like(q, dtype=v.dtype)
        else:
            out_ = None
            out = torch.empty_like(q, dtype=v.dtype)

        if seqlenq_ngroups_swapped:
            o_batch_stride = out.stride(0) * max_seqlen_q

        lse = torch.empty((num_heads, total_q), dtype=torch.float, device=q_device)

        if p_dropout > 0:
            is_dropout = True
            increment = batch_size * num_heads * 32
            philox_seed, philox_offset = philox_backend_seed_offset(increment)
            philox_args = torch.tensor(
                [philox_seed, philox_offset], dtype=torch.int64, device=q_device
            )
        else:
            is_dropout = False
            philox_args = torch.empty((2,), dtype=torch.int64, device=q_device)

        # From here on p_dropout is the probability of KEEPING an element.
        p_dropout = 1 - p_dropout
        p_dropout_in_uint8_t = math.floor(p_dropout * 255.0)
        rp_dropout = 1.0 / p_dropout

        if return_softmax:
            assert is_dropout, "Only supported with non-zero dropout."
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                device=q_device,
            )
        else:
            p = torch.empty((), device=q_device)

        if zero_tensors:
            out.zero_()
            lse.fill_(float("-inf"))

        params = fwd_params(
            q,  # q_ptr,
            k,  # k_ptr,
            v,  # v_ptr,
            out,  # o_ptr,
            p,  # p_ptr,
            lse,  # softmax_lse_ptr,
            q.stride(-3),  # q_row_stride,
            k.stride(-3),  # k_row_stride,
            v.stride(-3),  # v_row_stride,
            q.stride(-2),  # q_head_stride,
            k.stride(-2),  # k_head_stride,
            v.stride(-2),  # v_head_stride,
            out.stride(-3),  # o_row_stride,
            out.stride(-2),  # o_head_stride,
            q_batch_stride,  # q_batch_stride,
            k_batch_stride,  # k_batch_stride,
            v_batch_stride,  # v_batch_stride,
            o_batch_stride,  # o_batch_stride,
            cu_seqlens_q is not None,  # is_cu_seqlens_q,
            cu_seqlens_q,  # cu_seqlens_q_ptr,
            seqused_k is None,  # is_cu_seqlens_k,
            cu_seqlens_k,  # cu_seqlens_k_ptr,
            seqused_k is not None,  # is_seqused_k,
            seqused_k,  # seqused_k_ptr,
            # sizes
            batch_size,  # b,
            k_batch_size,  # bk,
            num_heads,  # h,
            num_heads_k,  # hk,
            num_heads // num_heads_k,  # h_hk_ratio,
            max_seqlen_q,  # seqlen_q,
            max_seqlen_k,  # seqlen_k,
            seqlen_q_rounded,  # seqlen_q_rounded,
            seqlen_k_rounded,  # seqlen_k_rounded,
            head_size,  # d,
            head_size_rounded,  # d_rounded,
            # scaling factors
            is_softcap,
            adjusted_softcap,  # softcap,
            adjusted_scale_softmax,  # scale_softmax,
            adjusted_scale_softmax_log2e,  # scale_softmax_log2,
            # dropout
            is_dropout,
            p_dropout,
            rp_dropout,
            p_dropout_in_uint8_t,
            philox_args,
            return_softmax,
            # causal and swa
            is_causal,  # is_causal,
            is_local,  # is_local,
            window_size_left,  # window_size_left,
            window_size_right,  # window_size_right,
            seqlenq_ngroups_swapped,  # seqlenq_ngroups_swapped,
            # alibi
            is_alibi,  #
            alibi_slopes,  # alibi_slopes_ptr,
            alibi_slopes_batch_stride,  # alibi_slopes_batch_stride,
            # block table params
            total_q,  # total_q,
            page_table,  # page_table_ptr,
            page_table_batch_stride,  # page_table_batch_stride,
            block_size,  # block_size,
        )

        logger.debug("kernel: flash_varlen_fwd")
        grid = lambda args: (
            triton.cdiv(max_seqlen_q, args["BLOCK_M"]),
            batch_size,
            num_heads,
        )
        kernel = flash_varlen_fwd_kernel[grid]
        args = params.args()

        # We have to forego parameter autotuning and particularly fix BLOCK_N
        # to avoid breaking a kv block onto multiple cache pages.
        cfg = runtime.get_heuristic_config("mha_varlen_fwd")
        cfg_params = {
            "BLOCK_M": cfg["BLOCK_M"](params),
            "BLOCK_N": cfg["BLOCK_N"](params),
            "num_warps": cfg["num_warps"](params),
            "num_stages": cfg["num_stages"](params),
        }
        assert (
            block_size % cfg_params["BLOCK_N"] == 0
        ), f"block_size must be divisible by {cfg_params['BLOCK_N']}."
        kernel(*args, **cfg_params)

    if seqlenq_ngroups_swapped:
        # Undo the query-group/sequence swap on the output.
        out = out.reshape(
            batch_size, max_seqlen_q, num_heads_k, head_size
        ).transpose(1, 2)
        if out_ is not None:
            out_.view(batch_size, num_heads_k, max_seqlen_q, head_size).copy_(out)
            out = out_
        else:
            out = out.reshape(batch_size, num_heads_k * max_seqlen_q, head_size)
        # NOTE(review): `q` is returned in its swapped layout here, whereas
        # mha_fwd restores it before returning - confirm this is intended.
        lse = lse.reshape(num_heads_k, batch_size, max_seqlen_q)
        lse = lse.reshape(num_heads_k * max_seqlen_q, batch_size)

    unused = torch.empty((), dtype=torch.int64, device=q_device)
    return out, q, k, v, lse, philox_args, unused, p
def mha_fwd(
    q,
    k,
    v,
    out,
    alibi_slopes,
    p_dropout,
    softmax_scale,
    is_causal,
    window_size_left,
    window_size_right,
    softcap,
    return_softmax,
    disable_splitkv=False,
):
    """Dense (batched) flash-attention forward pass.

    Validates the inputs, packs them into ``fwd_params`` and launches either
    the split-KV kernel pair (``flash_fwd_splitkv_kernel`` followed by
    ``flash_fwd_splitkv_combine_kernel``) or the plain ``flash_fwd_kernel``.

    Args:
        q: fp16/bf16 tensor of shape ``(batch_size, seqlen_q, num_heads,
            head_size)`` with a contiguous last dimension.
        k: 4-D tensor whose dims 1 and 2 are ``seqlen_k`` and ``num_heads_k``;
            same dtype as ``q``, contiguous last dimension.
        v: Same dtype as ``q``, contiguous last dimension.
        out: Optional output tensor of shape ``(batch_size, seqlen_q,
            num_heads, head_size)`` and the dtype of ``q``; allocated when
            ``None``.
        alibi_slopes: Optional float32 tensor of shape ``(num_heads,)`` or
            ``(batch_size, num_heads)``.
        p_dropout: Dropout probability; must be smaller than 1.
        softmax_scale: Scale applied before softmax.
        is_causal: Request causal masking.
        window_size_left: Left sliding-window size; disabled (set to -1) when
            it is at least ``seqlen_k``.
        window_size_right: Right sliding-window size; same disabling rule.
        softcap: Soft-capping value; values ``> 0`` enable it.
        return_softmax: Allocate the probability buffer ``p``; only allowed
            with non-zero dropout.
        disable_splitkv: When true, the split-KV path is never taken.

    Returns:
        Tuple ``(out, q, k, v, lse, philox_args, unused, p)``.

    Raises:
        AssertionError: If any input check fails; ``head_size`` must be a
            multiple of 32 because only the even-K case is supported.
    """
    for tensor in (q, k, v):
        CHECK_DEVICE(tensor)
    q_dtype = q.dtype
    q_device = q.device
    assert q_dtype in (
        torch.float16,
        torch.bfloat16,
    ), "FlashAttention only support fp16 and bf16 data type"
    assert q_dtype == k.dtype
    assert q_dtype == v.dtype
    assert q.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert k.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    assert v.stride(-1) == 1, "Input tensor must have contiguous last dimension"
    batch_size, seqlen_q, num_heads, head_size = q.size()
    _, seqlen_k, num_heads_k, _ = k.size()

    # Check output shape
    if out is not None:
        assert out.stride(-1) == 1
        assert out.dtype == q.dtype
        assert out.size() == (batch_size, seqlen_q, num_heads, head_size)
        CHECK_DEVICE(out)

    assert (
        head_size % 8 == 0
    ), "head_size must be a multiple of 8, this is ensured by padding!"
    assert (
        num_heads % num_heads_k == 0
    ), "Number of heads in key/value must divide number of heads in query"
    # The keep probability 1 - p_dropout is inverted below; reject 1 here so
    # the caller gets a clear error instead of a ZeroDivisionError.
    assert p_dropout < 1, "p_dropout must be smaller than 1"

    if window_size_left >= seqlen_k:
        window_size_left = -1
    if window_size_right >= seqlen_k:
        window_size_right = -1
    # A single query token needs no causal mask (unless alibi is in use).
    if seqlen_q == 1 and alibi_slopes is None:
        is_causal = False
    if is_causal:
        window_size_right = 0

    # Re-derive the masking mode from the (possibly adjusted) window sizes.
    is_causal = window_size_left < 0 and window_size_right == 0
    is_local = window_size_left >= 0 and window_size_right >= 0

    # Single-query case with grouped heads: treat the query groups as the
    # sequence dimension.
    seqlenq_ngroups_swapped = (
        seqlen_q == 1
        and alibi_slopes is None
        and num_heads > num_heads_k
        and window_size_left < 0
        and window_size_right < 0
        and p_dropout == 0
    )
    q_groups = num_heads // num_heads_k

    if seqlenq_ngroups_swapped:
        logger.debug("q_kg swapped.")
        q = q.reshape(batch_size, num_heads_k, q_groups, head_size).transpose(1, 2)
        seqlen_q = q_groups
        num_heads = num_heads_k

    def round_multiple(x, m):
        # Round x up to the next multiple of m.
        return (x + m - 1) // m * m

    head_size_rounded = round_multiple(head_size, 32)
    seqlen_q_rounded = round_multiple(seqlen_q, 128)
    seqlen_k_rounded = round_multiple(seqlen_k, 32)

    def splits_heuristic(num_tasks, num_sms, n_blocks):
        """Return how many KV splits to use; 1 means "do not split"."""
        # splits when wave efficiency is low
        n_waves = triton.cdiv(num_tasks, num_sms)
        eff = (num_tasks / num_sms) / n_waves
        if eff > 0.8 or n_waves > 1:
            return 1

        min_blocks_per_split = 2
        best_splits = min(
            triton.cdiv(n_blocks, min_blocks_per_split),
            int(math.floor(1.0 / eff)),
            num_sms,
        )

        return best_splits

    with torch_device_fn.device(q_device):
        # Set softmax params
        lse = torch.empty(
            (batch_size, num_heads, seqlen_q), dtype=torch.float, device=q_device
        )

        if out is not None:
            if seqlenq_ngroups_swapped:
                # View the caller's buffer in the same swapped layout as q.
                out = out.reshape(
                    batch_size, num_heads_k, q_groups, head_size
                ).transpose(1, 2)
        else:
            out = torch.empty_like(q, dtype=v.dtype)

        # Set dropout params
        if p_dropout > 0:
            is_dropout = True
            increment = batch_size * num_heads * 32
            philox_seed, philox_offset = philox_backend_seed_offset(increment)
            philox_args = torch.tensor(
                [philox_seed, philox_offset], dtype=torch.int64, device=q_device
            )
        else:
            is_dropout = False
            philox_args = torch.empty((2,), dtype=torch.int64, device=q_device)

        # From here on p_dropout is the probability of KEEPING an element.
        p_dropout = 1 - p_dropout
        p_dropout_in_uint8_t = math.floor(p_dropout * 255.0)
        rp_dropout = 1.0 / p_dropout

        if return_softmax:
            assert is_dropout, "Only supported with non-zero dropout."
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                device=q_device,
            )
        else:
            p = torch.empty((), device=q_device)

        M_LOG2E = 1.4426950408889634074
        if softcap > 0.0:
            is_softcap = True
            adjusted_scale_softmax = softcap
            adjusted_softcap = softmax_scale / softcap
            adjusted_scale_softmax_log2e = softcap * M_LOG2E
        else:
            is_softcap = False
            adjusted_softcap = 0.0
            adjusted_scale_softmax = softmax_scale
            adjusted_scale_softmax_log2e = softmax_scale * M_LOG2E

        # Set alibi params
        if alibi_slopes is not None:
            assert alibi_slopes.device == q_device
            assert alibi_slopes.dtype in (torch.float,)
            assert alibi_slopes.stride(-1) == 1
            assert alibi_slopes.shape == (num_heads,) or alibi_slopes.shape == (
                batch_size,
                num_heads,
            )
            alibi_slopes_batch_stride = (
                alibi_slopes.stride(0) if alibi_slopes.ndim == 2 else 0
            )
            is_alibi = True
        else:
            alibi_slopes_batch_stride = 0
            is_alibi = False

        # ONLY EVEN_K IS SUPPORTED
        assert head_size == head_size_rounded

        # Do kernel dispatching
        def dispatch(B, H, Q, K, D, params):
            """Launch the best-suited kernel and return the launch result.

            B, H, Q, K, D are batch size, head count, query length, key
            length and head size; ``params`` is the prepared ``fwd_params``.
            """
            num_sms = TOTAL_CORE_NUM

            # Try splitkv
            if not is_dropout and not is_local and not disable_splitkv:
                BM = block_m_splitkv_heuristic(D)
                n_tasks = B * H * triton.cdiv(Q, BM)
                BN = block_n_splitkv_heuristic(D)
                n_blocks = triton.cdiv(K, BN)
                n_splits = splits_heuristic(n_tasks, num_sms, n_blocks)

                if n_splits > 1:
                    logger.debug("kernel: flash_fwd_splitkv")
                    lse_splits = torch.empty(
                        (n_splits, B, H, Q), dtype=torch.float, device=q_device
                    )
                    out_splits = torch.empty(
                        (n_splits, B, H, Q, D), dtype=torch.float, device=q_device
                    )
                    grid = lambda args: (
                        triton.cdiv(Q, args["BLOCK_M"]),
                        n_splits,
                        B * H,
                    )
                    splitkv_kernel = flash_fwd_splitkv_kernel[grid]
                    # Redirect the kernel output to the per-split scratch
                    # buffers; the combine kernel reduces them into out/lse.
                    params.o_ptr = out_splits
                    params.softmax_lse_ptr = lse_splits
                    extra_args = {"blocks_per_split": triton.cdiv(n_blocks, n_splits)}
                    kernel = splitkv_kernel(*params.args(), **extra_args)

                    if D % 128 == 0:
                        BLOCK_M = 4
                    elif D % 64 == 0:
                        BLOCK_M = 8
                    else:
                        BLOCK_M = 16
                    grid = lambda args: (triton.cdiv(B * H * Q, BLOCK_M),)
                    combine_kernel = flash_fwd_splitkv_combine_kernel[grid]
                    combine_args = {
                        "out_ptr": out,
                        "lse_ptr": lse,
                        "head_size": head_size,
                        "out_b_stride": out.stride(0),
                        "out_s_stride": out.stride(-3),
                        # NOTE(review): stride(-1) is the element stride, not
                        # the head stride (stride(-2)) - confirm against the
                        # combine kernel before changing.
                        "out_h_stride": out.stride(-1),
                        "out_splits_ptr": out_splits,
                        "lse_splits_ptr": lse_splits,
                        "n_splits": n_splits,
                        "BLOCK_M": BLOCK_M,
                        "q_total": B * H * Q,
                        "MAX_N_SPLITS": triton.next_power_of_2(n_splits),
                    }
                    combine_kernel(**combine_args)
                    return kernel

            # Last option: flash_fwd
            logger.debug("kernel: flash_fwd")
            grid = lambda args: (
                triton.cdiv(Q, args["BLOCK_M"]),
                H * B,
            )
            kernel = flash_fwd_kernel[grid]
            kernel = kernel(*params.args())
            return kernel

        if _debug:
            # Debug mode always dumps the softmax probabilities.
            p = torch.empty(
                (batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded),
                dtype=torch.float32,
                device=q_device,
            )
            return_softmax = True

        params = fwd_params(
            q,  # q_ptr,
            k,  # k_ptr,
            v,  # v_ptr,
            out,  # o_ptr,
            p,  # p_ptr,
            lse,  # softmax_lse_ptr,
            q.stride(-3),  # q_row_stride,
            k.stride(-3),  # k_row_stride,
            v.stride(-3),  # v_row_stride,
            q.stride(-2),  # q_head_stride,
            k.stride(-2),  # k_head_stride,
            v.stride(-2),  # v_head_stride,
            out.stride(-3),  # o_row_stride,
            out.stride(-2),  # o_head_stride,
            q.stride(0),  # q_batch_stride,
            k.stride(0),  # k_batch_stride,
            v.stride(0),  # v_batch_stride,
            out.stride(0),  # o_batch_stride,
            False,  # is_cu_seqlens_q,
            None,  # cu_seqlens_q_ptr,
            False,  # is_cu_seqlens_k,
            None,  # cu_seqlens_k_ptr,
            False,  # is_seqused_k,
            None,  # seqused_k_ptr,
            # sizes
            batch_size,  # b,
            0,  # bk,
            num_heads,  # h,
            num_heads_k,  # hk,
            num_heads // num_heads_k,  # h_hk_ratio,
            seqlen_q,  # seqlen_q,
            seqlen_k,  # seqlen_k,
            seqlen_q_rounded,  # seqlen_q_rounded,
            seqlen_k_rounded,  # seqlen_k_rounded,
            head_size,  # d,
            head_size_rounded,  # d_rounded,
            # scaling factors
            is_softcap,
            adjusted_softcap,  # softcap,
            adjusted_scale_softmax,  # scale_softmax,
            adjusted_scale_softmax_log2e,  # scale_softmax_log2,
            # dropout
            is_dropout,
            p_dropout,
            rp_dropout,
            p_dropout_in_uint8_t,
            philox_args,
            return_softmax,
            # causal and swa
            is_causal,  # is_causal,
            is_local,  # is_local,
            window_size_left,  # window_size_left,
            window_size_right,  # window_size_right,
            seqlenq_ngroups_swapped,  # seqlenq_ngroups_swapped,
            # alibi
            is_alibi,  #
            alibi_slopes,  # alibi_slopes_ptr,
            alibi_slopes_batch_stride,  # alibi_slopes_batch_stride,
            # block table params
            0,  # total_q,
            None,  # page_table_ptr,
            0,  # page_table_batch_stride,
            0,  # block_size,
        )

        kernel = dispatch(batch_size, num_heads, seqlen_q, seqlen_k, head_size, params)

        if _debug:
            print(f"{kernel.name} shared memory:", kernel.metadata.shared)
            print(f"{kernel.name} num_warps:", kernel.metadata.num_warps)
            print(f"{kernel.name} num_stages:", kernel.metadata.num_stages)

    if seqlenq_ngroups_swapped:
        # Undo the query-group/sequence swap on everything that is returned.
        out = out.transpose(1, 2).reshape(
            (batch_size, 1, num_heads_k * seqlen_q, head_size)
        )
        q = q.transpose(1, 2).reshape(
            (batch_size, 1, num_heads_k * seqlen_q, head_size)
        )
        lse = lse.reshape((batch_size, num_heads_k * seqlen_q, 1))

    unused = torch.empty((), dtype=torch.int64, device=q_device)

    return out, q, k, v, lse, philox_args, unused, p