Coverage for src/flag_gems/patches/patch_vllm_all.py: 15%
200 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import os
2from typing import Optional, Tuple
4import torch
5import torch.nn.functional as F
7import flag_gems
8from flag_gems.fused import top_k_per_row_prefill
9from flag_gems.patches.patch_util import (
10 init_vllm_libraries,
11 patch_module_method,
12 patch_vllm_lib,
13)
def custom_gems_rms_forward_cuda(self, x, residual=None):
    """FlagGems replacement for vLLM ``RMSNorm.forward_cuda``.

    Delegates to ``gems_rms_forward`` with the layer's ``weight`` and
    ``variance_epsilon``; ``residual`` is passed through unchanged and the
    helper's result is returned as-is.
    """
    from flag_gems.modules.normalization import gems_rms_forward

    weight, eps = self.weight, self.variance_epsilon
    return gems_rms_forward(x, residual, weight, eps)
def custom_gems_rope_forward_cuda(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """FlagGems replacement for vLLM ``RotaryEmbedding.forward_cuda``.

    Applies rotary embedding to the first ``self.rotary_dim`` channels of every
    head of ``query`` and ``key`` using ``gems_rope_forward``. When
    ``rotary_dim < head_size``, the untouched tail channels are concatenated
    back after the rotated part. Both results are reshaped to the original
    input shapes.

    Args:
        positions: Position ids; ``offsets`` (if given) is added, then the
            result is flattened and its length is the token count.
        query, key: Tensors viewable as ``(num_tokens, -1, self.head_size)``.
        offsets: Optional additive position offsets.

    Returns:
        ``(query, key)`` after rotation, in their original shapes.
    """
    from flag_gems.modules.rotary_embedding import gems_rope_forward

    # Keep the cached cos/sin table on the positions' device.
    self.cos_sin_cache = self.cos_sin_cache.to(positions.device)

    pos = positions if offsets is None else positions + offsets
    pos = pos.flatten()
    n_tokens = pos.shape[0]

    q_orig_shape, k_orig_shape = query.shape, key.shape
    head_size, rot_dim = self.head_size, self.rotary_dim
    q_heads = query.view(n_tokens, -1, head_size)
    k_heads = key.view(n_tokens, -1, head_size)

    cos, sin = self.cos_sin_cache.chunk(2, dim=-1)
    # inplace=True for vLLM compatibility: the rotated channels are written
    # back into the q_heads/k_heads views.
    q_rot, k_rot = gems_rope_forward(
        q_heads[..., :rot_dim],
        k_heads[..., :rot_dim],
        cos,
        sin,
        position_ids=pos,
        rotary_interleaved=not self.is_neox_style,
        inplace=True,
    )

    if rot_dim < head_size:
        # Partial rotary: re-attach the channels that were not rotated.
        q_out = torch.cat((q_rot, q_heads[..., rot_dim:]), dim=-1)
        k_out = torch.cat((k_rot, k_heads[..., rot_dim:]), dim=-1)
    else:
        q_out, k_out = q_rot, k_rot

    return q_out.reshape(q_orig_shape), k_out.reshape(k_orig_shape)
def custom_gems_silu_and_mul(self, x: torch.Tensor) -> torch.Tensor:
    """FlagGems replacement for vLLM ``SiluAndMul.forward_cuda``.

    Splits the last dimension of ``x`` into two halves and hands them, in
    order, to ``gems_silu_and_mul``.
    """
    from flag_gems.modules.activation import gems_silu_and_mul

    half = x.shape[-1] // 2
    first_half = x[..., :half]
    second_half = x[..., half:]
    return gems_silu_and_mul(first_half, second_half)
def custom_gems_write_to_paged_cache(
    key,
    value,
    key_cache,
    value_cache,
    slot_mapping,
    kv_cache_dtype,
    k_scale,
    v_scale,
):
    """FlagGems replacement for vLLM ``PagedAttention.write_to_paged_cache``.

    Forwards every argument to FlagGems ``reshape_and_cache``, flattening
    ``slot_mapping`` first. Takes no ``self``; presumably it replaces a static
    method on ``PagedAttention`` (confirm against the vLLM version in use).
    """
    from flag_gems.fused.reshape_and_cache import reshape_and_cache

    flat_slots = slot_mapping.flatten()
    reshape_and_cache(
        key, value, key_cache, value_cache, flat_slots, kv_cache_dtype, k_scale, v_scale
    )
def custom_gems_flash_mla_forward(
    self,
    q_nope,
    q_pe,
    kv_c_and_k_pe_cache,
    attn_metadata,
) -> torch.Tensor:
    """FlagGems replacement for vLLM ``TritonMLAImpl._forward_decode``.

    Concatenates ``q_nope`` and ``q_pe`` on the last dim, runs FlagGems
    ``flash_mla`` over the paged latent cache with one query position per
    request, and projects the result with ``self._v_up_proj_and_o_proj``.

    Raises:
        NotImplementedError: if ``self.kv_cache_dtype`` starts with ``"fp8"``.
    """
    from flag_gems.fused import flash_mla

    assert kv_c_and_k_pe_cache.numel() > 0
    assert attn_metadata.decode is not None

    if self.kv_cache_dtype.startswith("fp8"):
        raise NotImplementedError("FP8 Triton MLA not yet supported")

    # q_nope is unpacked as (batch, num_heads, head_dim_v); this path uses
    # a query length of 1 per request.
    batch, num_head_q, head_dim_v = q_nope.shape
    seqlen_q = 1

    q_full = torch.cat([q_nope, q_pe], dim=-1)
    head_dim = q_full.shape[-1]
    q_full = q_full.view(batch, seqlen_q, num_head_q, head_dim)

    # Insert a KV-head axis of size 1 at dim 2; dim 1 is read as the page size.
    paged_cache = kv_c_and_k_pe_cache.unsqueeze(2)
    page_size = paged_cache.size(1)

    decode_meta = attn_metadata.decode
    attn_out = flash_mla(
        q_full,
        decode_meta.block_table,
        paged_cache,
        None,
        page_size,
        batch,
        seqlen_q,
        decode_meta.seq_lens,
        num_head_q,
        None,
        head_dim,
        head_dim_v,
        True,
    )
    return self._v_up_proj_and_o_proj(attn_out)
def custom_gems_flash_attention_impl_forward(
    self,
    layer: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata,  # FlashAttentionMetadata
    output: Optional[torch.Tensor] = None,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """FlagGems replacement for vLLM ``FlashAttentionImpl.forward``.

    Writes ``key``/``value`` into the paged KV cache with
    ``reshape_and_cache_flash``, then runs ``flash_attn_varlen_func`` over the
    first ``attn_metadata.num_actual_tokens`` query rows, writing the result
    into ``output`` in place.

    Args:
        layer: Attention layer whose ``_q_scale``/``_k_scale``/``_v_scale`` are
            used for the cache write and the descale arguments.
        query, key, value: Current-step projections. ``key.shape[1]``
            (presumably the number of KV heads) sizes the descale tensors.
        kv_cache: Stacked cache split into key/value caches via ``unbind(0)``.
        attn_metadata: vLLM FlashAttention metadata, or ``None`` during the
            profiling run.
        output: Preallocated result buffer (required).
        output_scale, output_block_scale: Fused output-quantization scales;
            not supported, must be ``None``.
        **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        ``output`` (returned untouched when ``attn_metadata`` is ``None``).

    Raises:
        NotImplementedError: for fused output quantization, an FP8 KV cache,
            or cascade attention. The FP8 and cascade cases are detected
            before the KV cache is written, so the cache is left untouched.
    """
    assert output is not None, "Output tensor must be provided."

    # Both scales request fused output quantization; silently ignoring
    # output_block_scale would produce unquantized results.
    if output_scale is not None or output_block_scale is not None:
        raise NotImplementedError(
            "fused output quantization is not yet supported for FlashAttentionImpl"
        )

    if attn_metadata is None:
        # Profiling run.
        return output

    # TODO: Support FP8 KV cache. Checked before the cache write so the
    # unsupported configuration fails without mutating kv_cache.
    if self.kv_cache_dtype.startswith("fp8"):
        raise NotImplementedError(
            "FP8 quantization is not yet supported for FlashAttentionImpl"
        )

    # getattr keeps this working when the impl/metadata lack these attributes.
    use_local_attn = (
        getattr(self, "use_irope", False)
        and getattr(attn_metadata, "local_attn_metadata", None) is not None
    )
    if attn_metadata.use_cascade and not use_local_attn:
        # TODO: Support cascade_attention.
        raise NotImplementedError("Cascade attention is not implemented in flag_gems.")

    from flag_gems import flash_attn_varlen_func, reshape_and_cache_flash

    num_actual_tokens = attn_metadata.num_actual_tokens
    key_cache, value_cache = kv_cache.unbind(0)

    reshape_and_cache_flash(
        key,
        value,
        key_cache,
        value_cache,
        attn_metadata.slot_mapping,
        self.kv_cache_dtype,
        layer._k_scale,
        layer._v_scale,
    )

    if use_local_attn:
        local_metadata = attn_metadata.local_attn_metadata
        cu_seqlens_q = local_metadata.local_query_start_loc
        seqused_k = local_metadata.local_seqused_k
        max_seqlen_q = local_metadata.local_max_query_len
        max_seqlen_k = local_metadata.local_max_seq_len
        block_table = local_metadata.local_block_table
        scheduler_metadata = local_metadata.local_scheduler_metadata
    else:
        cu_seqlens_q = attn_metadata.query_start_loc
        seqused_k = attn_metadata.seq_lens
        max_seqlen_q = attn_metadata.max_query_len
        max_seqlen_k = attn_metadata.max_seq_len
        block_table = attn_metadata.block_table
        scheduler_metadata = attn_metadata.scheduler_metadata

    # One descale row per request (cu_seqlens_q has num_requests + 1 entries).
    descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

    # Compute attention and update output up to `num_actual_tokens`.
    flash_attn_varlen_func(
        q=query[:num_actual_tokens],
        k=key_cache,
        v=value_cache,
        out=output[:num_actual_tokens],
        cu_seqlens_q=cu_seqlens_q,
        max_seqlen_q=max_seqlen_q,
        seqused_k=seqused_k,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=self.scale,
        causal=True,
        alibi_slopes=self.alibi_slopes,
        window_size=self.sliding_window,
        block_table=block_table,
        softcap=self.logits_soft_cap,
        scheduler_metadata=scheduler_metadata,
        fa_version=2,
        q_descale=layer._q_scale.expand(descale_shape),
        k_descale=layer._k_scale.expand(descale_shape),
        v_descale=layer._v_scale.expand(descale_shape),
        s_aux=None,
        num_splits=0,
        cp_world_size=1,
        cp_rank=0,
        cp_tot_seqused_k=None,
    )
    return output
def custom_silu_and_mul(out: torch.Tensor, input: torch.Tensor):
    """Implementation of ``_C.silu_and_mul`` backed by FlagGems.

    Splits the last dimension of ``input`` into two equal halves and lets
    ``flag_gems.silu_and_mul_out`` write the result into ``out``.
    """
    half = input.size(-1) // 2
    lhs, rhs = input.split(half, dim=-1)
    flag_gems.silu_and_mul_out(lhs, rhs, out)
def custom_silu_and_mul_with_clamp(
    out: torch.Tensor, input: torch.Tensor, limit: float
):
    """Implementation of ``_C.silu_and_mul_with_clamp`` backed by FlagGems.

    Same split as ``silu_and_mul``; ``limit`` is forwarded unchanged to
    ``flag_gems.silu_and_mul_with_clamp_out``, which writes into ``out``.
    """
    half = input.size(-1) // 2
    lhs, rhs = input.split(half, dim=-1)
    flag_gems.silu_and_mul_with_clamp_out(lhs, rhs, out, limit)
def custom_hc_head_fused_kernel(
    hs_flat: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    out: torch.Tensor,
    hidden_size: int,
    rms_eps: float,
    hc_eps: float,
    hc_mult: int,
):
    """Implementation of ``_C.hc_head_fused_kernel`` backed by FlagGems.

    Pure pass-through: all arguments are forwarded positionally and the
    FlagGems result is returned unchanged.
    """
    kernel = flag_gems.hc_head_fused_kernel
    return kernel(
        hs_flat,
        fn,
        hc_scale,
        hc_base,
        out,
        hidden_size,
        rms_eps,
        hc_eps,
        hc_mult,
    )
def custom_moe_align_block_size(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    experts_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
):
    """Implementation of ``_moe_C.moe_align_block_size`` via FlagGems' Triton op.

    Returns ``None``. ``sorted_token_ids``, ``experts_ids`` and
    ``num_tokens_post_pad`` are presumably filled in place by the FlagGems
    kernel (confirm against ``flag_gems.moe_align_block_size_triton``).
    """
    align = flag_gems.moe_align_block_size_triton
    align(topk_ids, num_experts, block_size, sorted_token_ids, experts_ids, num_tokens_post_pad)
def custom_moe_grouped_topk(
    gating_output: torch.Tensor,
    n_group: int,
    topk_group: int,
    topk: int,
    renormalize: bool,
    routed_scaling_factor: float,
    bias: torch.Tensor,
    scoring_func: int = 0,
):
    """Implementation of ``_moe_C.grouped_topk`` backed by FlagGems.

    ``gating_output`` is passed as FlagGems' ``scores`` argument; every other
    argument is forwarded by name and the FlagGems result is returned.
    """
    from flag_gems.fused import grouped_topk as gems_grouped_topk

    kwargs = {
        "scores": gating_output,
        "n_group": n_group,
        "topk_group": topk_group,
        "topk": topk,
        "renormalize": renormalize,
        "routed_scaling_factor": routed_scaling_factor,
        "bias": bias,
        "scoring_func": scoring_func,
    }
    return gems_grouped_topk(**kwargs)
def custom_topk_softmax(
    topk_weights, topk_indices, token_expert_indices, gating_output, renormalize=False
):
    """Implementation of ``_moe_C.topk_softmax`` backed by FlagGems.

    Returns ``None``; the output buffers are presumably filled in place by
    ``flag_gems.topk_softmax`` (confirm against FlagGems).
    """
    args = (topk_weights, topk_indices, token_expert_indices, gating_output, renormalize)
    flag_gems.topk_softmax(*args)
def custom_moe_sum(input: torch.Tensor, output: torch.Tensor):
    """Implementation of ``_moe_C.moe_sum`` backed by FlagGems.

    Returns ``None``; ``output`` presumably receives the reduction of
    ``input`` (confirm against ``flag_gems.fused.moe_sum``).
    """
    from flag_gems.fused import moe_sum as gems_moe_sum

    gems_moe_sum(input, output)
def custom_apply_repetition_penalties(
    logits: torch.Tensor,
    prompt_mask: torch.Tensor,
    output_mask: torch.Tensor,
    repetition_penalties: torch.Tensor,
):
    """Implementation of ``_C.apply_repetition_penalties_`` backed by FlagGems.

    Forwards all tensors positionally and returns whatever FlagGems returns.
    """
    apply_penalties = flag_gems.apply_repetition_penalties
    return apply_penalties(logits, prompt_mask, output_mask, repetition_penalties)
def custom_get_scheduler_metadata(
    batch_size: int,
    max_seqlen_q: int,
    max_seqlen_k: int,
    num_heads: int,
    num_heads_k: int,
    headdim: int,
    headdim_v: int,
    qkv_dtype: torch.dtype,
    seqused_k: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new: int = 0,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    has_softcap: bool = False,
    num_splits: int = 0,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
):
    """Implementation of ``_vllm_fa3_C.get_scheduler_metadata`` via FlagGems.

    The nine required arguments are forwarded positionally, the optional
    ones by keyword; the FlagGems result is returned unchanged.
    """
    required = (
        batch_size,
        max_seqlen_q,
        max_seqlen_k,
        num_heads,
        num_heads_k,
        headdim,
        headdim_v,
        qkv_dtype,
        seqused_k,
    )
    optional = dict(
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        cu_seqlens_k_new=cu_seqlens_k_new,
        seqused_q=seqused_q,
        leftpad_k=leftpad_k,
        page_size=page_size,
        max_seqlen_k_new=max_seqlen_k_new,
        is_causal=is_causal,
        window_size_left=window_size_left,
        window_size_right=window_size_right,
        has_softcap=has_softcap,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
    )
    return flag_gems.get_scheduler_metadata(*required, **optional)
def custom_per_token_group_fp8_quant(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    fp8_min: float,
    fp8_max: float,
    scale_ue8m0: bool = False,
):
    """Implementation of ``_C.per_token_group_fp8_quant`` backed by FlagGems.

    Quantizes ``input`` with FlagGems ``per_token_group_quant_fp8`` into fresh
    tensors, then copies the quantized values into ``output_q`` and the scales
    into ``output_s``.

    NOTE(review): ``fp8_min`` and ``fp8_max`` are accepted for signature
    compatibility but are not forwarded; presumably the FlagGems kernel uses
    its own FP8 range — confirm it matches the caller's.
    """
    from flag_gems.ops import per_token_group_quant_fp8

    # Match the layout of the caller's scale buffer: a smaller stride on dim 0
    # than on dim 1 means the scales are stored column-major.
    row_stride, col_stride = output_s.stride(0), output_s.stride(1)
    wants_column_major = row_stride < col_stride

    quantized, scales = per_token_group_quant_fp8(
        x=input,
        group_size=group_size,
        eps=eps,
        column_major_scales=wants_column_major,
        scale_ue8m0=scale_ue8m0,
    )

    output_q.copy_(quantized)
    output_s.copy_(scales)
def custom_cutlass_scaled_mm(
    output: torch.Tensor,
    input: torch.Tensor,
    weight: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: torch.Tensor | None = None,
):
    """Implementation of ``_C.cutlass_scaled_mm`` backed by FlagGems.

    Forwards all arguments (``bias`` may be ``None``) and returns the FlagGems
    result.
    """
    scaled_mm = flag_gems.cutlass_scaled_mm
    return scaled_mm(output, input, weight, scale_a, scale_b, bias)
def custom_top_k_per_row_prefill(
    logits, row_starts, row_ends, indices, num_rows, stride0, stride1, top_k
):
    """Implementation of ``_C.top_k_per_row_prefill`` backed by FlagGems.

    Returns ``None``; ``indices`` is presumably filled in place by the
    FlagGems ``top_k_per_row_prefill`` kernel (confirm against FlagGems).
    """
    args = (logits, row_starts, row_ends, indices, num_rows, stride0, stride1, top_k)
    top_k_per_row_prefill(*args)
def custom_concat_and_cache_mla(
    kv_c: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    scale: torch.Tensor,
) -> None:
    """Implementation of ``_C_cache_ops.concat_and_cache_mla`` via FlagGems.

    Forwards all arguments positionally and returns the FlagGems result
    (annotated as ``None``).
    """
    concat_and_cache = flag_gems.concat_and_cache_mla
    return concat_and_cache(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
468def custom_gems_flashattn_mla_forward_decode(
469 self,
470 q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
471 kv_c_and_k_pe_cache: torch.Tensor,
472 attn_metadata, # FlashAttnMLAMetadata
473 layer, # AttentionLayer
474) -> tuple[torch.Tensor, torch.Tensor | None]:
475 from flag_gems import flash_attn_varlen_func
477 assert kv_c_and_k_pe_cache.numel() > 0
478 assert attn_metadata.decode is not None
480 if type(q) is tuple:
481 q_nope, q_pe = q
482 else:
483 q_nope, q_pe = torch.split(
484 q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
485 )
487 if self.kv_cache_dtype.startswith("fp8"):
488 raise NotImplementedError("FP8 FlashAttention MLA not yet supported")
490 kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
491 k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank :]
493 # NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the
494 # kernel uses this to calculate grid dimensions. Ensure it's at least 1
495 # to prevent invalid grid configuration during graph capture.
496 max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)
498 attn_out = flash_attn_varlen_func(
499 q=q_pe,
500 k=k_pe_cache.unsqueeze(-2), # Add head dim of 1
501 v=kv_c_cache.unsqueeze(-2), # Add head dim of 1
502 q_v=q_nope,
503 max_seqlen_q=max_seqlen_q,
504 cu_seqlens_q=attn_metadata.decode.query_start_loc,
505 max_seqlen_k=attn_metadata.decode.max_seq_len,
506 seqused_k=attn_metadata.decode.seq_lens,
507 block_table=attn_metadata.decode.block_table,
508 softmax_scale=self.scale,
509 causal=True,
510 return_softmax_lse=self.need_to_return_lse_for_decode,
511 fa_version=2,
512 scheduler_metadata=attn_metadata.decode.scheduler_metadata,
513 num_splits=0,
514 cp_world_size=self.dcp_world_size,
515 cp_rank=self.dcp_rank,
516 cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens,
517 )
519 if self.need_to_return_lse_for_decode:
520 o, lse = attn_out
521 # FA returns LSE in shape [ H, B ] but DCP wants [ B, H ]
522 return o, lse.transpose(0, 1) # [ H, B ] -> [ B, H ]
523 else:
524 o = attn_out
525 return o, None
# Route vLLM's ViT attention wrapper through FlagGems flash attention.
def patch_vllm_vit_to_attn(vitw):
    """Replace ``vitw.vit_xformers_attn_wrapper`` with a FlagGems-backed version.

    Does nothing when ``vitw`` has no ``vit_xformers_attn_wrapper`` attribute.
    The replacement calls the original wrapper when the environment variable
    ``VIT_ATTN_BACKEND`` equals ``"no-sdpa"`` (checked on every call).
    Otherwise it runs FlagGems ``flash_attention_forward`` once per sequence
    segment delimited by the cumulative ``seqlens``.
    """
    if not hasattr(vitw, "vit_xformers_attn_wrapper"):
        return

    _orig_vit = vitw.vit_xformers_attn_wrapper

    def _seqlens_to_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
        # Cumulative offsets with a leading 0: [0, s0, s0+s1, ...].
        cu_seqlens = torch.cumsum(seqlens, dim=0, dtype=torch.int32)
        return F.pad(cu_seqlens, (1, 0))

    def _torch_sdpa_wrapper_gems(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens: torch.Tensor,
    ):
        import flag_gems.ops.attention as gems_attn

        # Copy the boundaries to the host once; calling .item() per boundary
        # forces a device synchronization on every loop iteration.
        bounds = cu_seqlens.tolist()
        outputs = []
        for start, end in zip(bounds[:-1], bounds[1:]):
            # Segments are sliced along dim 1 of q/k/v.
            q_i = q[:, start:end]
            k_i = k[:, start:end]
            v_i = v[:, start:end]

            out_i, *_ = gems_attn.flash_attention_forward(
                q_i,
                k_i,
                v_i,
                None,
                None,
                int(q_i.shape[1]),
                int(k_i.shape[1]),
                0.0,
                False,
                False,
                scale=None,
                softcap=0.0,
                window_size_left=None,
                window_size_right=None,
                seqused_k=None,
                alibi_slopes=None,
                disable_splitkv=True,
            )
            outputs.append(out_i)

        context_layer = torch.cat(outputs, dim=1)
        # Swap dims 0/1 and merge every trailing dim into one.
        x = context_layer.transpose(0, 1).contiguous()
        return x.view(x.shape[0], x.shape[1], -1)

    def _wrapped_vit_xformers_attn_wrapper(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        seqlens: torch.Tensor,
    ) -> torch.Tensor:
        if os.getenv("VIT_ATTN_BACKEND", "xformers") == "no-sdpa":
            return _orig_vit(q, k, v, seqlens)

        cu_seqlens = _seqlens_to_cu_seqlens(seqlens)
        return _torch_sdpa_wrapper_gems(q, k, v, cu_seqlens)

    vitw.vit_xformers_attn_wrapper = _wrapped_vit_xformers_attn_wrapper
def custom_rms_norm_out(result, input, weight, epsilon):
    """Implementation of ``_C.rms_norm`` backed by FlagGems ``rms_norm_out``.

    The full shape of ``weight`` is passed as the normalized shape, and
    ``result`` is handed to FlagGems as the destination.
    """
    from flag_gems.ops.rms_norm import rms_norm_out

    normalized_shape = list(weight.size())
    rms_norm_out(result, input, normalized_shape, weight, epsilon)
def apply_gems_patches_to_vllm(verbose=True):
    """Redirect selected vLLM layer methods and custom ops to FlagGems.

    Patches several vLLM module methods via ``patch_module_method``. It also
    registers FlagGems implementations for vLLM library ops via
    ``patch_vllm_lib`` under the current device's dispatch key. When vLLM's
    ``vit_attn_wrappers`` module is importable, its ViT attention wrapper is
    patched as well. ``verbose`` is forwarded to both patch helpers.
    """
    import vllm  # noqa: F401

    # Imported only for its side effects; presumably this loads vLLM's custom
    # op libraries so they exist before patching (confirm).
    import vllm._custom_ops as ops  # noqa: F401

    try:
        from vllm.attention.ops import vit_attn_wrappers as vitw
    except ImportError:  # also covers ModuleNotFoundError (a subclass)
        vitw = None
    from vllm.attention.ops.paged_attn import PagedAttention
    from vllm.model_executor.layers.activation import SiluAndMul
    from vllm.model_executor.layers.layernorm import RMSNorm
    from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
    from vllm.v1.attention.backends.flash_attn import FlashAttentionImpl
    from vllm.v1.attention.backends.mla.flashattn_mla import FlashAttnMLAImpl
    from vllm.v1.attention.backends.mla.triton_mla import TritonMLAImpl

    dispatch_key = flag_gems.runtime.device.dispatch_key
    init_vllm_libraries()

    # (class, method name, replacement) — applied in this order.
    method_table = (
        (RMSNorm, "forward_cuda", custom_gems_rms_forward_cuda),
        (RotaryEmbedding, "forward_cuda", custom_gems_rope_forward_cuda),
        (PagedAttention, "write_to_paged_cache", custom_gems_write_to_paged_cache),
        (SiluAndMul, "forward_cuda", custom_gems_silu_and_mul),
        (TritonMLAImpl, "_forward_decode", custom_gems_flash_mla_forward),
        (FlashAttentionImpl, "forward", custom_gems_flash_attention_impl_forward),
        (FlashAttnMLAImpl, "_forward_decode", custom_gems_flashattn_mla_forward_decode),
    )
    for owner, attr, replacement in method_table:
        patch_module_method(owner, attr, replacement, verbose)

    # (library, op name, implementation) — applied in this order.
    op_table = (
        ("_C", "rms_norm", custom_rms_norm_out),
        ("_C", "silu_and_mul", custom_silu_and_mul),
        ("_C", "silu_and_mul_with_clamp", custom_silu_and_mul_with_clamp),
        ("_C", "hc_head_fused_kernel", custom_hc_head_fused_kernel),
        ("_C", "cutlass_scaled_mm", custom_cutlass_scaled_mm),
        ("_moe_C", "moe_align_block_size", custom_moe_align_block_size),
        ("_moe_C", "topk_softmax", custom_topk_softmax),
        ("_moe_C", "moe_sum", custom_moe_sum),
        ("_vllm_fa3_C", "get_scheduler_metadata", custom_get_scheduler_metadata),
        ("_moe_C", "grouped_topk", custom_moe_grouped_topk),
        ("_C", "per_token_group_fp8_quant", custom_per_token_group_fp8_quant),
        ("_C", "apply_repetition_penalties_", custom_apply_repetition_penalties),
        ("_C", "top_k_per_row_prefill", custom_top_k_per_row_prefill),
        ("_C_cache_ops", "concat_and_cache_mla", custom_concat_and_cache_mla),
    )
    for lib_name, op_name, impl in op_table:
        patch_vllm_lib(lib_name, op_name, impl, dispatch_key, verbose)

    if vitw is not None:
        patch_vllm_vit_to_attn(vitw)