Coverage for src/flag_gems/patches/patch_vllm_all.py: 15%
199 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import os
2from typing import Optional, Tuple
4import torch
5import torch.nn.functional as F
7import flag_gems
8from flag_gems.fused import top_k_per_row_prefill
9from flag_gems.patches.patch_util import patch_module_method, patch_vllm_lib
def custom_gems_rms_forward_cuda(self, x, residual=None):
    """Run the layer's RMS normalization through FlagGems.

    Installed over ``RMSNorm.forward_cuda`` by ``apply_gems_patches_to_vllm``.

    Args:
        self: The patched layer; its ``weight`` and ``variance_epsilon``
            attributes are read.
        x: Input passed through to ``gems_rms_forward``.
        residual: Optional residual passed through unchanged (may be None).

    Returns:
        Whatever ``gems_rms_forward`` returns.
    """
    # Resolved at call time rather than at module import.
    from flag_gems.modules.normalization import gems_rms_forward

    weight = self.weight
    eps = self.variance_epsilon
    return gems_rms_forward(x, residual, weight, eps)
def custom_gems_rope_forward_cuda(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embedding to ``query``/``key`` with the FlagGems kernel.

    Installed over ``RotaryEmbedding.forward_cuda`` by
    ``apply_gems_patches_to_vllm``.

    Args:
        self: The patched layer; reads ``cos_sin_cache``, ``head_size``,
            ``rotary_dim`` and ``is_neox_style``. ``cos_sin_cache`` is
            re-assigned after being moved to ``positions.device``.
        positions: Position indices; flattened before use.
        query: Query tensor, viewed as ``(num_tokens, -1, head_size)``.
        key: Key tensor, viewed as ``(num_tokens, -1, head_size)``.
        offsets: Optional tensor added to ``positions`` before flattening.

    Returns:
        ``(query, key)`` reshaped back to their incoming shapes.
    """
    from flag_gems.modules.rotary_embedding import gems_rope_forward

    # Keep the cached cos/sin table on the same device as the positions.
    self.cos_sin_cache = self.cos_sin_cache.to(positions.device)

    shifted = positions if offsets is None else positions + offsets
    flat_positions = shifted.flatten()
    token_count = flat_positions.shape[0]

    incoming_q_shape = query.shape
    incoming_k_shape = key.shape
    query = query.view(token_count, -1, self.head_size)
    key = key.view(token_count, -1, self.head_size)

    q_rot = query[..., : self.rotary_dim]
    k_rot = key[..., : self.rotary_dim]
    # Only part of each head is rotated when rotary_dim < head_size; the
    # remainder is carried through and concatenated back afterwards.
    partial_rotary = self.rotary_dim < self.head_size
    if partial_rotary:
        q_rest = query[..., self.rotary_dim :]
        k_rest = key[..., self.rotary_dim :]

    cos, sin = self.cos_sin_cache.chunk(2, dim=-1)

    q_embed, k_embed = gems_rope_forward(
        q_rot,
        k_rot,
        cos,
        sin,
        position_ids=flat_positions,
        rotary_interleaved=not self.is_neox_style,
        inplace=True,  # set inplace to True for vLLM compatibility
    )

    if not partial_rotary:
        return q_embed.reshape(incoming_q_shape), k_embed.reshape(incoming_k_shape)

    rebuilt_q = torch.cat((q_embed, q_rest), dim=-1).reshape(incoming_q_shape)
    rebuilt_k = torch.cat((k_embed, k_rest), dim=-1).reshape(incoming_k_shape)
    return rebuilt_q, rebuilt_k
def custom_gems_silu_and_mul(self, x: torch.Tensor) -> torch.Tensor:
    """Split ``x`` in two along its last axis and hand both parts to FlagGems.

    Installed over ``SiluAndMul.forward_cuda`` by
    ``apply_gems_patches_to_vllm``. The split point is
    ``x.shape[-1] // 2``; ``self`` is not read.

    Returns:
        The result of ``gems_silu_and_mul(first_half, second_half)``.
    """
    from flag_gems.modules.activation import gems_silu_and_mul

    half = x.shape[-1] // 2
    first_half = x[..., :half]
    second_half = x[..., half:]
    return gems_silu_and_mul(first_half, second_half)
def custom_gems_write_to_paged_cache(
    key,
    value,
    key_cache,
    value_cache,
    slot_mapping,
    kv_cache_dtype,
    k_scale,
    v_scale,
):
    """Write ``key``/``value`` into the paged caches via FlagGems.

    Installed over ``PagedAttention.write_to_paged_cache`` by
    ``apply_gems_patches_to_vllm``. Every argument is forwarded to
    ``reshape_and_cache`` unchanged except ``slot_mapping``, which is
    flattened first. Nothing is returned.
    """
    from flag_gems.fused.reshape_and_cache import reshape_and_cache

    flat_slots = slot_mapping.flatten()
    reshape_and_cache(
        key,
        value,
        key_cache,
        value_cache,
        flat_slots,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )
def custom_gems_flash_mla_forward(
    self,
    q_nope,
    q_pe,
    kv_c_and_k_pe_cache,
    attn_metadata,
) -> torch.Tensor:
    """Decode-time MLA attention using the FlagGems ``flash_mla`` kernel.

    Installed over ``TritonMLAImpl._forward_decode`` by
    ``apply_gems_patches_to_vllm``.

    Args:
        self: The patched implementation; reads ``kv_cache_dtype`` and calls
            ``_v_up_proj_and_o_proj``.
        q_nope: 3-D tensor unpacked as ``(batch, num_head_q, head_dim_v)``.
        q_pe: Tensor concatenated to ``q_nope`` along the last axis.
        kv_c_and_k_pe_cache: Non-empty cache tensor; a singleton axis is
            inserted at position 2 and axis 1 is used as the page size.
        attn_metadata: Object whose ``decode`` attribute supplies
            ``block_table`` and ``seq_lens``.

    Returns:
        The kernel output after ``self._v_up_proj_and_o_proj``.

    Raises:
        NotImplementedError: If ``self.kv_cache_dtype`` starts with "fp8".
    """
    from flag_gems.fused import flash_mla

    assert kv_c_and_k_pe_cache.numel() > 0
    assert attn_metadata.decode is not None

    if self.kv_cache_dtype.startswith("fp8"):
        raise NotImplementedError("FP8 Triton MLA not yet supported")

    batch, num_head_q, head_dim_v = q_nope.shape
    # The query is viewed with a sequence axis of length one.
    seqlen_q = 1

    joined_q = torch.cat([q_nope, q_pe], dim=-1)
    head_dim = joined_q.shape[-1]
    joined_q = joined_q.view(batch, seqlen_q, num_head_q, head_dim)

    # Add a head dim of 1
    paged_cache = kv_c_and_k_pe_cache.unsqueeze(2)
    page_size = paged_cache.size(1)

    block_table = attn_metadata.decode.block_table
    # Arguments are positional; their order must match flash_mla's signature.
    kernel_out = flash_mla(
        joined_q,
        block_table,
        paged_cache,
        None,
        page_size,
        batch,
        seqlen_q,
        attn_metadata.decode.seq_lens,
        num_head_q,
        None,
        head_dim,
        head_dim_v,
        True,
    )

    return self._v_up_proj_and_o_proj(kernel_out)
def custom_gems_flash_attention_impl_forward(
    self,
    layer: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata,  #: FlashAttentionMetadata,
    output: Optional[torch.Tensor] = None,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """Forward pass of the flash-attention backend using FlagGems kernels.

    Installed over ``FlashAttentionImpl.forward`` by
    ``apply_gems_patches_to_vllm``. Writes ``key``/``value`` into the two
    halves of ``kv_cache`` and then runs variable-length flash attention,
    filling ``output[:num_actual_tokens]`` in place.

    All unsupported configurations are rejected *before* the KV cache is
    written, so a failing call leaves the cache untouched.

    Args:
        self: The patched implementation; reads ``kv_cache_dtype``, ``scale``,
            ``alibi_slopes``, ``sliding_window``, ``logits_soft_cap`` and
            optionally ``use_irope``.
        layer: Attention layer providing ``_q_scale``, ``_k_scale`` and
            ``_v_scale``.
        query: Query tensor; only the first ``num_actual_tokens`` rows are used.
        key: Key tensor written to the cache; ``key.shape[1]`` sizes the
            descale tensors.
        value: Value tensor written to the cache.
        kv_cache: Tensor unbound along axis 0 into key and value caches.
        attn_metadata: Attention metadata, or None for a profiling run.
        output: Pre-allocated output tensor (required).
        output_scale: Must be None; fused output quantization is unsupported.
        output_block_scale: Must be None; fused output quantization is
            unsupported.
        **kwargs: Accepted for signature compatibility and ignored.

    Returns:
        ``output``.

    Raises:
        AssertionError: If ``output`` is None.
        NotImplementedError: If an output scale is supplied, if the KV cache
            dtype starts with "fp8", or if cascade attention is requested
            without local attention.
    """
    assert output is not None, "Output tensor must be provided."

    # Either scale means the caller expects a quantized output; silently
    # ignoring one would hand back unquantized data.
    if output_scale is not None or output_block_scale is not None:
        raise NotImplementedError(
            "fused output quantization is not yet supported for FlashAttentionImpl"
        )

    if attn_metadata is None:
        # Profiling run.
        return output

    # TODO: Support FP8 (view the caches as float8_e4m3fn and quantize the
    # query with layer._q_scale before calling the kernel).
    if self.kv_cache_dtype.startswith("fp8"):
        raise NotImplementedError(
            "FP8 quantization is not yet supported for FlashAttentionImpl"
        )

    use_local_attn = (
        getattr(self, "use_irope", False)
        and getattr(attn_metadata, "local_attn_metadata", None) is not None
    )

    # TODO: Support cascade_attention.
    if attn_metadata.use_cascade and not use_local_attn:
        raise NotImplementedError("Cascade attention is not implemented in flag_gems.")

    # Imported only once the call is known to be serviceable.
    from flag_gems import flash_attn_varlen_func, reshape_and_cache_flash

    num_actual_tokens = attn_metadata.num_actual_tokens
    key_cache, value_cache = kv_cache.unbind(0)

    reshape_and_cache_flash(
        key,
        value,
        key_cache,
        value_cache,
        attn_metadata.slot_mapping,
        self.kv_cache_dtype,
        layer._k_scale,
        layer._v_scale,
    )

    # Compute attention and update output up to `num_actual_tokens`.
    if use_local_attn:
        local_metadata = attn_metadata.local_attn_metadata
        cu_seqlens_q = local_metadata.local_query_start_loc
        seqused_k = local_metadata.local_seqused_k
        max_seqlen_q = local_metadata.local_max_query_len
        max_seqlen_k = local_metadata.local_max_seq_len
        block_table = local_metadata.local_block_table
        scheduler_metadata = local_metadata.local_scheduler_metadata
    else:
        cu_seqlens_q = attn_metadata.query_start_loc
        seqused_k = attn_metadata.seq_lens
        max_seqlen_q = attn_metadata.max_query_len
        max_seqlen_k = attn_metadata.max_seq_len
        block_table = attn_metadata.block_table
        scheduler_metadata = attn_metadata.scheduler_metadata

    # cu_seqlens_q has one more entry than there are sequences.
    descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

    flash_attn_varlen_func(
        q=query[:num_actual_tokens],
        k=key_cache,
        v=value_cache,
        out=output[:num_actual_tokens],
        cu_seqlens_q=cu_seqlens_q,
        max_seqlen_q=max_seqlen_q,
        seqused_k=seqused_k,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=self.scale,
        causal=True,
        alibi_slopes=self.alibi_slopes,
        window_size=self.sliding_window,
        block_table=block_table,
        softcap=self.logits_soft_cap,
        scheduler_metadata=scheduler_metadata,
        fa_version=2,
        q_descale=layer._q_scale.expand(descale_shape),
        k_descale=layer._k_scale.expand(descale_shape),
        v_descale=layer._v_scale.expand(descale_shape),
        s_aux=None,
        num_splits=0,
        cp_world_size=1,
        cp_rank=0,
        cp_tot_seqused_k=None,
    )
    return output
def custom_silu_and_mul(out: torch.Tensor, input: torch.Tensor):
    """Out-variant SiLU-and-multiply routed to FlagGems.

    Registered for the ``_C`` library op ``silu_and_mul`` by
    ``apply_gems_patches_to_vllm``. ``input`` is split along its last axis
    into chunks of ``input.size(-1) // 2`` elements; exactly two chunks are
    expected. The result is written into ``out`` and nothing is returned.
    """
    chunk = input.shape[-1] // 2
    first, second = torch.split(input, chunk, dim=-1)
    flag_gems.silu_and_mul_out(first, second, out)
def custom_silu_and_mul_with_clamp(
    out: torch.Tensor, input: torch.Tensor, limit: float
):
    """Out-variant clamped SiLU-and-multiply routed to FlagGems.

    Registered for the ``_C`` library op ``silu_and_mul_with_clamp`` by
    ``apply_gems_patches_to_vllm``. ``input`` is split along its last axis
    into chunks of ``input.size(-1) // 2`` elements; exactly two chunks are
    expected. ``limit`` is forwarded unchanged. The result is written into
    ``out`` and nothing is returned.
    """
    chunk = input.shape[-1] // 2
    first, second = torch.split(input, chunk, dim=-1)
    flag_gems.silu_and_mul_with_clamp_out(first, second, out, limit)
def custom_hc_head_fused_kernel(
    hs_flat: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    out: torch.Tensor,
    hidden_size: int,
    rms_eps: float,
    hc_eps: float,
    hc_mult: int,
):
    """Forward every argument, in order, to ``flag_gems.hc_head_fused_kernel``.

    Registered for the ``_C`` library op ``hc_head_fused_kernel`` by
    ``apply_gems_patches_to_vllm``.

    Returns:
        Whatever the FlagGems kernel returns.
    """
    tensor_args = (hs_flat, fn, hc_scale, hc_base, out)
    scalar_args = (hidden_size, rms_eps, hc_eps, hc_mult)
    return flag_gems.hc_head_fused_kernel(*tensor_args, *scalar_args)
def custom_moe_align_block_size(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    experts_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
):
    """Forward every argument to ``flag_gems.moe_align_block_size_triton``.

    Registered for the ``_moe_C`` library op ``moe_align_block_size`` by
    ``apply_gems_patches_to_vllm``. The kernel's return value is discarded;
    results are presumably written into the tensor arguments (verify against
    the kernel).
    """
    forwarded = (
        topk_ids,
        num_experts,
        block_size,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
    )
    flag_gems.moe_align_block_size_triton(*forwarded)
def custom_moe_grouped_topk(
    gating_output: torch.Tensor,
    n_group: int,
    topk_group: int,
    topk: int,
    renormalize: bool,
    routed_scaling_factor: float,
    bias: torch.Tensor,
    scoring_func: int = 0,
):
    """Call FlagGems ``grouped_topk`` with keyword arguments.

    Registered for the ``_moe_C`` library op ``grouped_topk`` by
    ``apply_gems_patches_to_vllm``. ``gating_output`` is passed under the
    kernel's ``scores`` keyword; every other argument keeps its name.

    Returns:
        Whatever ``grouped_topk`` returns.
    """
    from flag_gems.fused import grouped_topk

    routing_kwargs = {
        "scores": gating_output,
        "n_group": n_group,
        "topk_group": topk_group,
        "topk": topk,
        "renormalize": renormalize,
        "routed_scaling_factor": routed_scaling_factor,
        "bias": bias,
        "scoring_func": scoring_func,
    }
    return grouped_topk(**routing_kwargs)
def custom_topk_softmax(
    topk_weights, topk_indices, token_expert_indices, gating_output, renormalize=False
):
    """Forward every argument, in order, to ``flag_gems.topk_softmax``.

    Registered for the ``_moe_C`` library op ``topk_softmax`` by
    ``apply_gems_patches_to_vllm``. The kernel's return value is discarded.
    """
    forwarded = (
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
    )
    flag_gems.topk_softmax(*forwarded)
def custom_moe_sum(input: torch.Tensor, output: torch.Tensor):
    """Forward ``input`` and ``output`` to the FlagGems ``moe_sum`` kernel.

    Registered for the ``_moe_C`` library op ``moe_sum`` by
    ``apply_gems_patches_to_vllm``. The kernel's return value is discarded.
    """
    # Resolved at call time rather than at module import.
    from flag_gems.fused import moe_sum as _gems_moe_sum

    _gems_moe_sum(input, output)
def custom_apply_repetition_penalties(
    logits: torch.Tensor,
    prompt_mask: torch.Tensor,
    output_mask: torch.Tensor,
    repetition_penalties: torch.Tensor,
):
    """Forward every argument to ``flag_gems.apply_repetition_penalties``.

    Registered for the ``_C`` library op ``apply_repetition_penalties_`` by
    ``apply_gems_patches_to_vllm``.

    Returns:
        Whatever the FlagGems function returns.
    """
    result = flag_gems.apply_repetition_penalties(
        logits,
        prompt_mask,
        output_mask,
        repetition_penalties,
    )
    return result
def custom_get_scheduler_metadata(
    batch_size: int,
    max_seqlen_q: int,
    max_seqlen_k: int,
    num_heads: int,
    num_heads_k: int,
    headdim: int,
    headdim_v: int,
    qkv_dtype: torch.dtype,
    seqused_k: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new: int = 0,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    has_softcap: bool = False,
    num_splits: int = 0,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
):
    """Forward all arguments to ``flag_gems.get_scheduler_metadata``.

    Registered for the ``_vllm_fa3_C`` library op ``get_scheduler_metadata``
    by ``apply_gems_patches_to_vllm``. The first nine parameters are passed
    positionally; the remaining ones are passed by keyword under the same
    names.

    Returns:
        Whatever the FlagGems function returns.
    """
    required = (
        batch_size,
        max_seqlen_q,
        max_seqlen_k,
        num_heads,
        num_heads_k,
        headdim,
        headdim_v,
        qkv_dtype,
        seqused_k,
    )
    optional = {
        "cu_seqlens_q": cu_seqlens_q,
        "cu_seqlens_k": cu_seqlens_k,
        "cu_seqlens_k_new": cu_seqlens_k_new,
        "seqused_q": seqused_q,
        "leftpad_k": leftpad_k,
        "page_size": page_size,
        "max_seqlen_k_new": max_seqlen_k_new,
        "is_causal": is_causal,
        "window_size_left": window_size_left,
        "window_size_right": window_size_right,
        "has_softcap": has_softcap,
        "num_splits": num_splits,
        "pack_gqa": pack_gqa,
        "sm_margin": sm_margin,
    }
    return flag_gems.get_scheduler_metadata(*required, **optional)
def custom_per_token_group_fp8_quant(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    fp8_min: float,
    fp8_max: float,
    scale_ue8m0: bool = False,
):
    """Quantize ``input`` with FlagGems and copy the results into the outputs.

    Registered for the ``_C`` library op ``per_token_group_fp8_quant`` by
    ``apply_gems_patches_to_vllm``.

    Args:
        input: Tensor to quantize.
        output_q: Destination for the quantized values (filled via ``copy_``).
        output_s: Destination for the scales (filled via ``copy_``). Its
            first two strides are compared, so it needs at least two
            dimensions.
        group_size: Forwarded to the kernel.
        eps: Forwarded to the kernel.
        fp8_min: Accepted for signature compatibility; not forwarded.
        fp8_max: Accepted for signature compatibility; not forwarded.
        scale_ue8m0: Forwarded to the kernel.
    """
    from flag_gems.ops import per_token_group_quant_fp8

    # The scale layout is inferred from the destination: a smaller stride on
    # axis 0 than on axis 1 is treated as column-major.
    stride_axis0 = output_s.stride(0)
    stride_axis1 = output_s.stride(1)
    wants_column_major = stride_axis0 < stride_axis1

    quantized, scales = per_token_group_quant_fp8(
        x=input,
        group_size=group_size,
        eps=eps,
        column_major_scales=wants_column_major,
        scale_ue8m0=scale_ue8m0,
    )

    output_q.copy_(quantized)
    output_s.copy_(scales)
def custom_cutlass_scaled_mm(
    output: torch.Tensor,
    input: torch.Tensor,
    weight: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: torch.Tensor | None = None,
):
    """Forward every argument, in order, to ``flag_gems.cutlass_scaled_mm``.

    Registered for the ``_C`` library op ``cutlass_scaled_mm`` by
    ``apply_gems_patches_to_vllm``.

    Returns:
        Whatever the FlagGems function returns.
    """
    operands = (output, input, weight)
    scales = (scale_a, scale_b)
    return flag_gems.cutlass_scaled_mm(*operands, *scales, bias)
def custom_top_k_per_row_prefill(
    logits, row_starts, row_ends, indices, num_rows, stride0, stride1, top_k
):
    """Forward every argument, in order, to FlagGems ``top_k_per_row_prefill``.

    Registered for the ``_C`` library op ``top_k_per_row_prefill`` by
    ``apply_gems_patches_to_vllm``. The kernel's return value is discarded.
    """
    forwarded = (
        logits,
        row_starts,
        row_ends,
        indices,
        num_rows,
        stride0,
        stride1,
        top_k,
    )
    top_k_per_row_prefill(*forwarded)
def custom_concat_and_cache_mla(
    kv_c: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    scale: torch.Tensor,
) -> None:
    """Forward every argument, in order, to ``flag_gems.concat_and_cache_mla``.

    Registered for the ``_C_cache_ops`` library op ``concat_and_cache_mla``
    by ``apply_gems_patches_to_vllm``. The kernel's return value is passed
    back to the caller as-is, although the annotation declares ``None``.
    """
    forwarded = (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
    return flag_gems.concat_and_cache_mla(*forwarded)
464def custom_gems_flashattn_mla_forward_decode(
465 self,
466 q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
467 kv_c_and_k_pe_cache: torch.Tensor,
468 attn_metadata, # FlashAttnMLAMetadata
469 layer, # AttentionLayer
470) -> tuple[torch.Tensor, torch.Tensor | None]:
471 from flag_gems import flash_attn_varlen_func
473 assert kv_c_and_k_pe_cache.numel() > 0
474 assert attn_metadata.decode is not None
476 if type(q) is tuple:
477 q_nope, q_pe = q
478 else:
479 q_nope, q_pe = torch.split(
480 q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
481 )
483 if self.kv_cache_dtype.startswith("fp8"):
484 raise NotImplementedError("FP8 FlashAttention MLA not yet supported")
486 kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
487 k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank :]
489 # NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the
490 # kernel uses this to calculate grid dimensions. Ensure it's at least 1
491 # to prevent invalid grid configuration during graph capture.
492 max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)
494 attn_out = flash_attn_varlen_func(
495 q=q_pe,
496 k=k_pe_cache.unsqueeze(-2), # Add head dim of 1
497 v=kv_c_cache.unsqueeze(-2), # Add head dim of 1
498 q_v=q_nope,
499 max_seqlen_q=max_seqlen_q,
500 cu_seqlens_q=attn_metadata.decode.query_start_loc,
501 max_seqlen_k=attn_metadata.decode.max_seq_len,
502 seqused_k=attn_metadata.decode.seq_lens,
503 block_table=attn_metadata.decode.block_table,
504 softmax_scale=self.scale,
505 causal=True,
506 return_softmax_lse=self.need_to_return_lse_for_decode,
507 fa_version=2,
508 scheduler_metadata=attn_metadata.decode.scheduler_metadata,
509 num_splits=0,
510 cp_world_size=self.dcp_world_size,
511 cp_rank=self.dcp_rank,
512 cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens,
513 )
515 if self.need_to_return_lse_for_decode:
516 o, lse = attn_out
517 # FA returns LSE in shape [ H, B ] but DCP wants [ B, H ]
518 return o, lse.transpose(0, 1) # [ H, B ] -> [ B, H ]
519 else:
520 o = attn_out
521 return o, None
# use gems flash attention in vit attention
def patch_vllm_vit_to_attn(vitw):
    """Replace ``vitw.vit_xformers_attn_wrapper`` with a FlagGems-backed one.

    Does nothing when ``vitw`` has no ``vit_xformers_attn_wrapper``
    attribute. Otherwise the attribute is rebound to a wrapper that, on each
    call, reads the ``VIT_ATTN_BACKEND`` environment variable: the value
    "no-sdpa" delegates to the original wrapper, any other value (including
    unset, which defaults to "xformers") runs FlagGems flash attention once
    per sequence.

    Args:
        vitw: Module or object exposing ``vit_xformers_attn_wrapper``.
    """
    if not hasattr(vitw, "vit_xformers_attn_wrapper"):
        return

    _orig_vit = vitw.vit_xformers_attn_wrapper

    def _seqlens_to_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
        # Cumulative int32 lengths with a zero prepended on the last axis.
        cu_seqlens = torch.cumsum(seqlens, dim=0, dtype=torch.int32)
        return F.pad(cu_seqlens, (1, 0))

    def _torch_sdpa_wrapper_gems(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens: torch.Tensor,
    ):
        import flag_gems.ops.attention as gems_attn

        # Read all boundaries with a single host transfer rather than two
        # `.item()` reads per sequence.
        bounds = cu_seqlens.tolist()

        outputs = []
        for start, end in zip(bounds, bounds[1:]):
            # Sequences are laid out back to back along axis 1.
            q_i = q[:, start:end]
            k_i = k[:, start:end]
            v_i = v[:, start:end]

            out_i, *_ = gems_attn.flash_attention_forward(
                q_i,
                k_i,
                v_i,
                None,
                None,
                int(q_i.shape[1]),
                int(k_i.shape[1]),
                0.0,
                False,
                False,
                scale=None,
                softcap=0.0,
                window_size_left=None,
                window_size_right=None,
                seqused_k=None,
                alibi_slopes=None,
                disable_splitkv=True,
            )
            outputs.append(out_i)

        context_layer = torch.cat(outputs, dim=1)
        x = context_layer.transpose(0, 1).contiguous()
        # Collapse everything after the first two axes into one.
        return x.view(x.shape[0], x.shape[1], -1)

    def _wrapped_vit_xformers_attn_wrapper(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        seqlens: torch.Tensor,
    ) -> torch.Tensor:
        # The environment is consulted on every call, so the backend can be
        # switched at runtime.
        if os.getenv("VIT_ATTN_BACKEND", "xformers") == "no-sdpa":
            return _orig_vit(q, k, v, seqlens)

        cu_seqlens = _seqlens_to_cu_seqlens(seqlens)
        return _torch_sdpa_wrapper_gems(q, k, v, cu_seqlens)

    vitw.vit_xformers_attn_wrapper = _wrapped_vit_xformers_attn_wrapper
def custom_rms_norm_out(result, input, weight, epsilon):
    """Out-variant RMS norm routed to FlagGems ``rms_norm_out``.

    Registered for the ``_C`` library op ``rms_norm`` by
    ``apply_gems_patches_to_vllm``. The full shape of ``weight`` is passed
    as the third argument, as a list. The kernel's return value is
    discarded.
    """
    from flag_gems.ops.rms_norm import rms_norm_out

    weight_shape = list(weight.shape)
    rms_norm_out(result, input, weight_shape, weight, epsilon)
def apply_gems_patches_to_vllm(verbose=True):
    """Install every FlagGems replacement into vLLM.

    Three groups are applied in order: Python methods on vLLM classes (via
    ``patch_module_method``), compiled library ops (via ``patch_vllm_lib``
    with the FlagGems device dispatch key), and finally the ViT attention
    wrapper when ``vllm.attention.ops.vit_attn_wrappers`` can be imported.

    Args:
        verbose: Forwarded to ``patch_module_method`` and ``patch_vllm_lib``.
    """
    import vllm  # noqa: F401
    import vllm._custom_ops as ops  # noqa: F401

    # The ViT wrappers module is optional; ModuleNotFoundError is a subclass
    # of ImportError, so this single clause covers both.
    try:
        from vllm.attention.ops import vit_attn_wrappers as vitw
    except ImportError:
        vitw = None
    from vllm.attention.ops.paged_attn import PagedAttention
    from vllm.model_executor.layers.activation import SiluAndMul
    from vllm.model_executor.layers.layernorm import RMSNorm
    from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
    from vllm.v1.attention.backends.flash_attn import FlashAttentionImpl
    from vllm.v1.attention.backends.mla.flashattn_mla import FlashAttnMLAImpl
    from vllm.v1.attention.backends.mla.triton_mla import TritonMLAImpl

    dispatch_key = flag_gems.runtime.device.dispatch_key

    # (class, attribute name, replacement callable)
    method_overrides = (
        (RMSNorm, "forward_cuda", custom_gems_rms_forward_cuda),
        (RotaryEmbedding, "forward_cuda", custom_gems_rope_forward_cuda),
        (PagedAttention, "write_to_paged_cache", custom_gems_write_to_paged_cache),
        (SiluAndMul, "forward_cuda", custom_gems_silu_and_mul),
        (TritonMLAImpl, "_forward_decode", custom_gems_flash_mla_forward),
        (FlashAttentionImpl, "forward", custom_gems_flash_attention_impl_forward),
        (FlashAttnMLAImpl, "_forward_decode", custom_gems_flashattn_mla_forward_decode),
    )
    for target_cls, attr_name, replacement in method_overrides:
        patch_module_method(target_cls, attr_name, replacement, verbose)

    # (library name, op name, replacement callable)
    op_overrides = (
        ("_C", "rms_norm", custom_rms_norm_out),
        ("_C", "silu_and_mul", custom_silu_and_mul),
        ("_C", "silu_and_mul_with_clamp", custom_silu_and_mul_with_clamp),
        ("_C", "hc_head_fused_kernel", custom_hc_head_fused_kernel),
        ("_C", "cutlass_scaled_mm", custom_cutlass_scaled_mm),
        ("_moe_C", "moe_align_block_size", custom_moe_align_block_size),
        ("_moe_C", "topk_softmax", custom_topk_softmax),
        ("_moe_C", "moe_sum", custom_moe_sum),
        ("_vllm_fa3_C", "get_scheduler_metadata", custom_get_scheduler_metadata),
        ("_moe_C", "grouped_topk", custom_moe_grouped_topk),
        ("_C", "per_token_group_fp8_quant", custom_per_token_group_fp8_quant),
        ("_C", "apply_repetition_penalties_", custom_apply_repetition_penalties),
        ("_C", "top_k_per_row_prefill", custom_top_k_per_row_prefill),
        ("_C_cache_ops", "concat_and_cache_mla", custom_concat_and_cache_mla),
    )
    for library, op_name, replacement in op_overrides:
        patch_vllm_lib(library, op_name, replacement, dispatch_key, verbose)

    if vitw is None:
        return
    patch_vllm_vit_to_attn(vitw)