Coverage for src/flag_gems/patches/patch_vllm_all.py: 15%

200 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import os 

2from typing import Optional, Tuple 

3 

4import torch 

5import torch.nn.functional as F 

6 

7import flag_gems 

8from flag_gems.fused import top_k_per_row_prefill 

9from flag_gems.patches.patch_util import ( 

10 init_vllm_libraries, 

11 patch_module_method, 

12 patch_vllm_lib, 

13) 

14 

15 

def custom_gems_rms_forward_cuda(self, x, residual=None):
    """Run the layer's RMS norm through the FlagGems implementation.

    Forwards ``x``, the optional ``residual``, and the layer's ``weight`` and
    ``variance_epsilon`` attributes to ``gems_rms_forward`` and returns
    whatever that call returns.
    """
    # Resolved at call time rather than at module import.
    from flag_gems.modules.normalization import gems_rms_forward

    weight, eps = self.weight, self.variance_epsilon
    return gems_rms_forward(x, residual, weight, eps)

20 

21 

def custom_gems_rope_forward_cuda(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to ``query`` and ``key`` via FlagGems.

    The first ``self.rotary_dim`` features of every head are rotated by
    ``gems_rope_forward`` (called with ``inplace=True``); any remaining
    features up to ``self.head_size`` are concatenated back unchanged.

    Args:
        positions: Position indices; flattened before use.
        query: Query tensor, viewed as ``(num_tokens, -1, head_size)``.
        key: Key tensor, viewed as ``(num_tokens, -1, head_size)``.
        offsets: Optional tensor added to ``positions`` before flattening.

    Returns:
        ``(query, key)`` reshaped back to the shapes they arrived with.

    Side effects:
        ``self.cos_sin_cache`` is reassigned to a copy on ``positions.device``.
    """
    from flag_gems.modules.rotary_embedding import gems_rope_forward

    # Keep the cos/sin table on the same device as the incoming positions.
    self.cos_sin_cache = self.cos_sin_cache.to(positions.device)

    shifted = positions if offsets is None else positions + offsets
    flat_positions = shifted.flatten()
    token_count = flat_positions.shape[0]

    q_shape_in, k_shape_in = query.shape, key.shape
    rot_dim, head_size = self.rotary_dim, self.head_size
    partial_rotary = rot_dim < head_size

    q_heads = query.view(token_count, -1, head_size)
    k_heads = key.view(token_count, -1, head_size)

    q_rot_in = q_heads[..., :rot_dim]
    k_rot_in = k_heads[..., :rot_dim]
    if partial_rotary:
        # Features beyond rotary_dim bypass the rotation entirely.
        q_tail = q_heads[..., rot_dim:]
        k_tail = k_heads[..., rot_dim:]

    # The cache stores cos and sin side by side along the last dimension.
    cos, sin = self.cos_sin_cache.chunk(2, dim=-1)

    q_rot, k_rot = gems_rope_forward(
        q_rot_in,
        k_rot_in,
        cos,
        sin,
        position_ids=flat_positions,
        rotary_interleaved=not self.is_neox_style,
        inplace=True,  # set inplace to True for vLLM compatibility
    )

    if not partial_rotary:
        return q_rot.reshape(q_shape_in), k_rot.reshape(k_shape_in)

    q_full = torch.cat((q_rot, q_tail), dim=-1)
    k_full = torch.cat((k_rot, k_tail), dim=-1)
    return q_full.reshape(q_shape_in), k_full.reshape(k_shape_in)

68 

69 

def custom_gems_silu_and_mul(self, x: torch.Tensor) -> torch.Tensor:
    """Split ``x`` in half along its last dimension and call ``gems_silu_and_mul``.

    The split point is ``x.shape[-1] // 2``; the leading slice is passed as the
    first argument and the trailing slice as the second.
    """
    from flag_gems.modules.activation import gems_silu_and_mul

    half = x.shape[-1] // 2
    leading = x[..., :half]
    trailing = x[..., half:]
    return gems_silu_and_mul(leading, trailing)

76 

77 

def custom_gems_write_to_paged_cache(
    key,
    value,
    key_cache,
    value_cache,
    slot_mapping,
    kv_cache_dtype,
    k_scale,
    v_scale,
):
    """Write ``key``/``value`` into the paged caches via FlagGems.

    All arguments are handed to ``reshape_and_cache`` unchanged except
    ``slot_mapping``, which is flattened first. Nothing is returned.
    """
    from flag_gems.fused.reshape_and_cache import reshape_and_cache

    flat_slots = slot_mapping.flatten()
    reshape_and_cache(
        key,
        value,
        key_cache,
        value_cache,
        flat_slots,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )

100 

101 

def custom_gems_flash_mla_forward(
    self,
    q_nope,
    q_pe,
    kv_c_and_k_pe_cache,
    attn_metadata,
) -> torch.Tensor:
    """Decode-path MLA attention computed with the FlagGems ``flash_mla`` kernel.

    ``q_nope`` and ``q_pe`` are concatenated on the last dimension and viewed
    as ``(batch, 1, num_heads, head_dim)``; the cache gets an extra dimension
    at index 2. The kernel output is passed through
    ``self._v_up_proj_and_o_proj`` and returned.

    Raises:
        AssertionError: If the cache is empty or ``attn_metadata.decode`` is
            ``None``.
        NotImplementedError: If ``self.kv_cache_dtype`` starts with ``"fp8"``.
    """
    from flag_gems.fused import flash_mla

    assert kv_c_and_k_pe_cache.numel() > 0
    decode_meta = attn_metadata.decode
    assert decode_meta is not None

    if self.kv_cache_dtype.startswith("fp8"):
        raise NotImplementedError("FP8 Triton MLA not yet supported")

    batch, num_head_q, head_dim_v = q_nope.shape
    seqlen_q = 1  # fixed by this implementation

    q_cat = torch.cat([q_nope, q_pe], dim=-1)
    head_dim = q_cat.shape[-1]
    q_4d = q_cat.view(batch, seqlen_q, num_head_q, head_dim)

    # Add a head dim of 1
    paged_cache = kv_c_and_k_pe_cache.unsqueeze(2)
    page_size = paged_cache.size(1)

    # NOTE(review): the two ``None`` slots and the trailing ``True`` are
    # positional arguments of ``flash_mla`` whose meaning is not visible
    # here; confirm against its signature before reordering anything.
    attn_out = flash_mla(
        q_4d,
        decode_meta.block_table,
        paged_cache,
        None,
        page_size,
        batch,
        seqlen_q,
        decode_meta.seq_lens,
        num_head_q,
        None,
        head_dim,
        head_dim_v,
        True,
    )

    return self._v_up_proj_and_o_proj(attn_out)

147 

148 

def custom_gems_flash_attention_impl_forward(
    self,
    layer: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata,  #: FlashAttentionMetadata,
    output: Optional[torch.Tensor] = None,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """Paged flash-attention forward pass backed by FlagGems kernels.

    Writes ``key``/``value`` into the two halves of ``kv_cache`` with
    ``reshape_and_cache_flash`` and then runs ``flash_attn_varlen_func``,
    storing the result in ``output[:num_actual_tokens]``.

    Every unsupported configuration is rejected *before* the KV cache is
    touched, so a failing call leaves the cache unmodified.

    Args:
        layer: Attention layer providing ``_q_scale``, ``_k_scale`` and
            ``_v_scale``.
        query: Query tensor; only the first ``num_actual_tokens`` rows are used.
        key: Key tensor to cache; ``key.shape[1]`` sizes the descale tensors.
        value: Value tensor to cache.
        kv_cache: Tensor whose dimension 0 unbinds into (key_cache, value_cache).
        attn_metadata: Attention metadata, or ``None`` during a profiling run.
        output: Pre-allocated output tensor (required).
        output_scale: Must be ``None``; fused output quantization is unsupported.
        output_block_scale: Must be ``None``; fused output quantization is
            unsupported.
        **kwargs: Accepted for signature compatibility and ignored.

    Returns:
        ``output``.

    Raises:
        AssertionError: If ``output`` is ``None``.
        NotImplementedError: For fused output quantization, an FP8 KV cache,
            or cascade attention.
    """
    assert output is not None, "Output tensor must be provided."

    # Either scale requests fused output quantization; ignoring one of them
    # would silently return unquantized results.
    if output_scale is not None or output_block_scale is not None:
        raise NotImplementedError(
            "fused output quantization is not yet supported for FlashAttentionImpl"
        )

    if attn_metadata is None:
        # Profiling run.
        return output

    # TODO: Support FP8
    if self.kv_cache_dtype.startswith("fp8"):
        raise NotImplementedError(
            "FP8 quantization is not yet supported for FlashAttentionImpl"
        )

    use_local_attn = (
        getattr(self, "use_irope", False)
        and getattr(attn_metadata, "local_attn_metadata", None) is not None
    )
    if attn_metadata.use_cascade and not use_local_attn:
        # TODO: Support cascade_attention.
        raise NotImplementedError(
            "Cascade attention is not implemented in flag_gems."
        )

    # Imported lazily, only once the request is known to be supported.
    from flag_gems import flash_attn_varlen_func, reshape_and_cache_flash

    num_actual_tokens = attn_metadata.num_actual_tokens
    key_cache, value_cache = kv_cache.unbind(0)

    reshape_and_cache_flash(
        key,
        value,
        key_cache,
        value_cache,
        attn_metadata.slot_mapping,
        self.kv_cache_dtype,
        layer._k_scale,
        layer._v_scale,
    )

    if use_local_attn:
        local_metadata = attn_metadata.local_attn_metadata
        cu_seqlens_q = local_metadata.local_query_start_loc
        seqused_k = local_metadata.local_seqused_k
        max_seqlen_q = local_metadata.local_max_query_len
        max_seqlen_k = local_metadata.local_max_seq_len
        block_table = local_metadata.local_block_table
        scheduler_metadata = local_metadata.local_scheduler_metadata
    else:
        cu_seqlens_q = attn_metadata.query_start_loc
        seqused_k = attn_metadata.seq_lens
        max_seqlen_q = attn_metadata.max_query_len
        max_seqlen_k = attn_metadata.max_seq_len
        block_table = attn_metadata.block_table
        scheduler_metadata = attn_metadata.scheduler_metadata

    # One descale entry per (sequence, key.shape[1]) pair.
    descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

    # Compute attention and update output up to `num_actual_tokens`.
    flash_attn_varlen_func(
        q=query[:num_actual_tokens],
        k=key_cache,
        v=value_cache,
        out=output[:num_actual_tokens],
        cu_seqlens_q=cu_seqlens_q,
        max_seqlen_q=max_seqlen_q,
        seqused_k=seqused_k,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=self.scale,
        causal=True,
        alibi_slopes=self.alibi_slopes,
        window_size=self.sliding_window,
        block_table=block_table,
        softcap=self.logits_soft_cap,
        scheduler_metadata=scheduler_metadata,
        fa_version=2,
        q_descale=layer._q_scale.expand(descale_shape),
        k_descale=layer._k_scale.expand(descale_shape),
        v_descale=layer._v_scale.expand(descale_shape),
        s_aux=None,
        num_splits=0,
        cp_world_size=1,
        cp_rank=0,
        cp_tot_seqused_k=None,
    )
    return output

259 

260 

def custom_silu_and_mul(out: torch.Tensor, input: torch.Tensor):
    """Split ``input`` into chunks of half its last dimension and call
    ``flag_gems.silu_and_mul_out`` with the two chunks and ``out``.

    Nothing is returned; presumably the result lands in ``out`` (confirm
    against ``silu_and_mul_out``).
    """
    chunk = input.size(-1) // 2
    # Unpacking into exactly two names requires exactly two chunks.
    first, second = torch.split(input, chunk, dim=-1)
    flag_gems.silu_and_mul_out(first, second, out)

265 

266 

def custom_silu_and_mul_with_clamp(
    out: torch.Tensor, input: torch.Tensor, limit: float
):
    """Split ``input`` into chunks of half its last dimension and call
    ``flag_gems.silu_and_mul_with_clamp_out`` with the chunks, ``out`` and
    ``limit``.

    Nothing is returned; presumably the result lands in ``out`` (confirm
    against ``silu_and_mul_with_clamp_out``).
    """
    chunk = input.size(-1) // 2
    # Unpacking into exactly two names requires exactly two chunks.
    first, second = torch.split(input, chunk, dim=-1)
    flag_gems.silu_and_mul_with_clamp_out(first, second, out, limit)

273 

274 

def custom_hc_head_fused_kernel(
    hs_flat: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    out: torch.Tensor,
    hidden_size: int,
    rms_eps: float,
    hc_eps: float,
    hc_mult: int,
):
    """Forward every argument, in order, to ``flag_gems.hc_head_fused_kernel``
    and return its result.
    """
    result = flag_gems.hc_head_fused_kernel(
        hs_flat,
        fn,
        hc_scale,
        hc_base,
        out,
        hidden_size,
        rms_eps,
        hc_eps,
        hc_mult,
    )
    return result

289 

290 

def custom_moe_align_block_size(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    experts_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
):
    """Forward every argument, in order, to
    ``flag_gems.moe_align_block_size_triton``.

    Nothing is returned; the three trailing tensors are presumably the
    kernel's outputs (confirm against the kernel).
    """
    positional = (
        topk_ids,
        num_experts,
        block_size,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
    )
    flag_gems.moe_align_block_size_triton(*positional)

307 

308 

def custom_moe_grouped_topk(
    gating_output: torch.Tensor,
    n_group: int,
    topk_group: int,
    topk: int,
    renormalize: bool,
    routed_scaling_factor: float,
    bias: torch.Tensor,
    scoring_func: int = 0,
):
    """Call FlagGems ``grouped_topk`` with keyword arguments and return its result.

    ``gating_output`` is passed as ``scores``; every other parameter is passed
    under its own name.
    """
    from flag_gems.fused import grouped_topk

    routing_kwargs = {
        "scores": gating_output,
        "n_group": n_group,
        "topk_group": topk_group,
        "topk": topk,
        "renormalize": renormalize,
        "routed_scaling_factor": routed_scaling_factor,
        "bias": bias,
        "scoring_func": scoring_func,
    }
    return grouped_topk(**routing_kwargs)

331 

332 

def custom_topk_softmax(
    topk_weights, topk_indices, token_expert_indices, gating_output, renormalize=False
):
    """Forward every argument, in order, to ``flag_gems.topk_softmax``.

    Nothing is returned.
    """
    flag_gems.topk_softmax(
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
    )

339 

340 

def custom_moe_sum(input: torch.Tensor, output: torch.Tensor):
    """Call FlagGems ``moe_sum(input, output)``; nothing is returned."""
    from flag_gems.fused import moe_sum as _gems_moe_sum

    _gems_moe_sum(input, output)

345 

346 

def custom_apply_repetition_penalties(
    logits: torch.Tensor,
    prompt_mask: torch.Tensor,
    output_mask: torch.Tensor,
    repetition_penalties: torch.Tensor,
):
    """Forward every argument, in order, to
    ``flag_gems.apply_repetition_penalties`` and return its result.
    """
    penalized = flag_gems.apply_repetition_penalties(
        logits,
        prompt_mask,
        output_mask,
        repetition_penalties,
    )
    return penalized

356 

357 

def custom_get_scheduler_metadata(
    batch_size: int,
    max_seqlen_q: int,
    max_seqlen_k: int,
    num_heads: int,
    num_heads_k: int,
    headdim: int,
    headdim_v: int,
    qkv_dtype: torch.dtype,
    seqused_k: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new: int = 0,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    has_softcap: bool = False,
    num_splits: int = 0,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
):
    """Delegate to ``flag_gems.get_scheduler_metadata`` and return its result.

    The nine required parameters are passed positionally in declaration order;
    every defaulted parameter is passed by keyword under its own name.
    """
    optional_kwargs = {
        "cu_seqlens_q": cu_seqlens_q,
        "cu_seqlens_k": cu_seqlens_k,
        "cu_seqlens_k_new": cu_seqlens_k_new,
        "seqused_q": seqused_q,
        "leftpad_k": leftpad_k,
        "page_size": page_size,
        "max_seqlen_k_new": max_seqlen_k_new,
        "is_causal": is_causal,
        "window_size_left": window_size_left,
        "window_size_right": window_size_right,
        "has_softcap": has_softcap,
        "num_splits": num_splits,
        "pack_gqa": pack_gqa,
        "sm_margin": sm_margin,
    }
    return flag_gems.get_scheduler_metadata(
        batch_size,
        max_seqlen_q,
        max_seqlen_k,
        num_heads,
        num_heads_k,
        headdim,
        headdim_v,
        qkv_dtype,
        seqused_k,
        **optional_kwargs,
    )

408 

409 

def custom_per_token_group_fp8_quant(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    fp8_min: float,
    fp8_max: float,
    scale_ue8m0: bool = False,
):
    """Quantize ``input`` with FlagGems and copy the results into the out tensors.

    ``per_token_group_quant_fp8`` returns a (quantized, scales) pair which is
    copied in place into ``output_q`` and ``output_s``. Nothing is returned.

    ``fp8_min`` and ``fp8_max`` are accepted for signature compatibility but
    are not used by this implementation.
    """
    from flag_gems.ops import per_token_group_quant_fp8

    # The scale layout is inferred from the strides of the destination:
    # a smaller stride on dim 0 than on dim 1 selects column-major scales.
    # NOTE(review): this reads stride(1), so output_s needs at least two dims.
    wants_column_major = output_s.stride(0) < output_s.stride(1)

    quantized, scales = per_token_group_quant_fp8(
        x=input,
        group_size=group_size,
        eps=eps,
        column_major_scales=wants_column_major,
        scale_ue8m0=scale_ue8m0,
    )

    output_q.copy_(quantized)
    output_s.copy_(scales)

434 

435 

def custom_cutlass_scaled_mm(
    output: torch.Tensor,
    input: torch.Tensor,
    weight: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: torch.Tensor | None = None,
):
    """Forward every argument, in order, to ``flag_gems.cutlass_scaled_mm``
    and return its result.
    """
    result = flag_gems.cutlass_scaled_mm(
        output,
        input,
        weight,
        scale_a,
        scale_b,
        bias,
    )
    return result

445 

446 

def custom_top_k_per_row_prefill(
    logits, row_starts, row_ends, indices, num_rows, stride0, stride1, top_k
):
    """Forward every argument, in order, to the FlagGems
    ``top_k_per_row_prefill`` imported at module level.

    Nothing is returned.
    """
    top_k_per_row_prefill(
        logits,
        row_starts,
        row_ends,
        indices,
        num_rows,
        stride0,
        stride1,
        top_k,
    )

453 

454 

def custom_concat_and_cache_mla(
    kv_c: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    scale: torch.Tensor,
) -> None:
    """Forward every argument, in order, to ``flag_gems.concat_and_cache_mla``
    and return whatever it returns.
    """
    result = flag_gems.concat_and_cache_mla(
        kv_c,
        k_pe,
        kv_cache,
        slot_mapping,
        kv_cache_dtype,
        scale,
    )
    return result

466 

467 

468def custom_gems_flashattn_mla_forward_decode( 

469 self, 

470 q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], 

471 kv_c_and_k_pe_cache: torch.Tensor, 

472 attn_metadata, # FlashAttnMLAMetadata 

473 layer, # AttentionLayer 

474) -> tuple[torch.Tensor, torch.Tensor | None]: 

475 from flag_gems import flash_attn_varlen_func 

476 

477 assert kv_c_and_k_pe_cache.numel() > 0 

478 assert attn_metadata.decode is not None 

479 

480 if type(q) is tuple: 

481 q_nope, q_pe = q 

482 else: 

483 q_nope, q_pe = torch.split( 

484 q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 

485 ) 

486 

487 if self.kv_cache_dtype.startswith("fp8"): 

488 raise NotImplementedError("FP8 FlashAttention MLA not yet supported") 

489 

490 kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank] 

491 k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank :] 

492 

493 # NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the 

494 # kernel uses this to calculate grid dimensions. Ensure it's at least 1 

495 # to prevent invalid grid configuration during graph capture. 

496 max_seqlen_q = max(attn_metadata.decode.max_query_len, 1) 

497 

498 attn_out = flash_attn_varlen_func( 

499 q=q_pe, 

500 k=k_pe_cache.unsqueeze(-2), # Add head dim of 1 

501 v=kv_c_cache.unsqueeze(-2), # Add head dim of 1 

502 q_v=q_nope, 

503 max_seqlen_q=max_seqlen_q, 

504 cu_seqlens_q=attn_metadata.decode.query_start_loc, 

505 max_seqlen_k=attn_metadata.decode.max_seq_len, 

506 seqused_k=attn_metadata.decode.seq_lens, 

507 block_table=attn_metadata.decode.block_table, 

508 softmax_scale=self.scale, 

509 causal=True, 

510 return_softmax_lse=self.need_to_return_lse_for_decode, 

511 fa_version=2, 

512 scheduler_metadata=attn_metadata.decode.scheduler_metadata, 

513 num_splits=0, 

514 cp_world_size=self.dcp_world_size, 

515 cp_rank=self.dcp_rank, 

516 cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens, 

517 ) 

518 

519 if self.need_to_return_lse_for_decode: 

520 o, lse = attn_out 

521 # FA returns LSE in shape [ H, B ] but DCP wants [ B, H ] 

522 return o, lse.transpose(0, 1) # [ H, B ] -> [ B, H ] 

523 else: 

524 o = attn_out 

525 return o, None 

526 

527 

528# use gems flash attention in vit attention 

def patch_vllm_vit_to_attn(vitw):
    """Replace ``vitw.vit_xformers_attn_wrapper`` with a FlagGems-backed wrapper.

    If ``vitw`` has no ``vit_xformers_attn_wrapper`` attribute, nothing happens.
    Otherwise the attribute is rebound to a wrapper that:

    * calls the original function when the ``VIT_ATTN_BACKEND`` environment
      variable equals ``"no-sdpa"`` (checked on every call), and
    * otherwise runs ``flag_gems.ops.attention.flash_attention_forward`` once
      per sequence and concatenates the per-sequence outputs.

    Args:
        vitw: Module (or any object) that may expose ``vit_xformers_attn_wrapper``.

    Returns:
        None.
    """
    if not hasattr(vitw, "vit_xformers_attn_wrapper"):
        return

    _orig_vit = vitw.vit_xformers_attn_wrapper

    def _seqlens_to_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
        # Running int32 total with one leading zero, giving boundaries [0, l0, l0+l1, ...].
        cu_seqlens = torch.cumsum(seqlens, dim=0, dtype=torch.int32)
        return F.pad(cu_seqlens, (1, 0))

    def _torch_sdpa_wrapper_gems(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens: torch.Tensor,
    ):
        import flag_gems.ops.attention as gems_attn

        # Pull all boundaries to the host in one transfer instead of issuing
        # two `.item()` reads per sequence inside the loop.
        bounds = cu_seqlens.tolist()

        outputs = []
        for start, end in zip(bounds, bounds[1:]):
            # Sequences are laid out back to back along dimension 1.
            q_i = q[:, start:end]
            k_i = k[:, start:end]
            v_i = v[:, start:end]

            # NOTE(review): the positional None/0.0/False arguments follow
            # flash_attention_forward's signature, which is not visible here.
            out_i, *_ = gems_attn.flash_attention_forward(
                q_i,
                k_i,
                v_i,
                None,
                None,
                int(q_i.shape[1]),
                int(k_i.shape[1]),
                0.0,
                False,
                False,
                scale=None,
                softcap=0.0,
                window_size_left=None,
                window_size_right=None,
                seqused_k=None,
                alibi_slopes=None,
                disable_splitkv=True,
            )
            outputs.append(out_i)

        context_layer = torch.cat(outputs, dim=1)
        x = context_layer.transpose(0, 1).contiguous()
        # Collapse everything after the first two dimensions into one.
        return x.view(x.shape[0], x.shape[1], -1)

    def _wrapped_vit_xformers_attn_wrapper(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        seqlens: torch.Tensor,
    ) -> torch.Tensor:
        # "no-sdpa" is the opt-out switch that restores the original wrapper.
        if os.getenv("VIT_ATTN_BACKEND", "xformers") == "no-sdpa":
            return _orig_vit(q, k, v, seqlens)

        cu_seqlens = _seqlens_to_cu_seqlens(seqlens)
        return _torch_sdpa_wrapper_gems(q, k, v, cu_seqlens)

    vitw.vit_xformers_attn_wrapper = _wrapped_vit_xformers_attn_wrapper

593 

594 

def custom_rms_norm_out(result, input, weight, epsilon):
    """Call FlagGems ``rms_norm_out`` with ``weight``'s own shape as the
    normalized shape.

    Nothing is returned; presumably the result is written into ``result``
    (confirm against ``rms_norm_out``).
    """
    from flag_gems.ops.rms_norm import rms_norm_out

    normalized_shape = list(weight.size())
    rms_norm_out(result, input, normalized_shape, weight, epsilon)

599 

600 

def apply_gems_patches_to_vllm(verbose=True):
    """Install every FlagGems replacement defined in this module into vLLM.

    Three steps, in order:

    1. Rebind Python methods on vLLM classes via ``patch_module_method``.
    2. Override vLLM custom-op library entries via ``patch_vllm_lib`` using
       the FlagGems runtime dispatch key.
    3. Patch the ViT attention wrapper, if that vLLM module could be imported.

    Args:
        verbose: Passed through to ``patch_module_method`` and ``patch_vllm_lib``.
    """
    import vllm  # noqa: F401
    import vllm._custom_ops as ops  # noqa: F401

    # The ViT wrapper module is optional; every import below is required.
    try:
        from vllm.attention.ops import vit_attn_wrappers as vitw
    except ImportError:  # also covers ModuleNotFoundError, its subclass
        vitw = None
    from vllm.attention.ops.paged_attn import PagedAttention
    from vllm.model_executor.layers.activation import SiluAndMul
    from vllm.model_executor.layers.layernorm import RMSNorm
    from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
    from vllm.v1.attention.backends.flash_attn import FlashAttentionImpl
    from vllm.v1.attention.backends.mla.flashattn_mla import FlashAttnMLAImpl
    from vllm.v1.attention.backends.mla.triton_mla import TritonMLAImpl

    dispatch_key = flag_gems.runtime.device.dispatch_key
    init_vllm_libraries()

    # (class, method name, replacement)
    method_overrides = (
        (RMSNorm, "forward_cuda", custom_gems_rms_forward_cuda),
        (RotaryEmbedding, "forward_cuda", custom_gems_rope_forward_cuda),
        (PagedAttention, "write_to_paged_cache", custom_gems_write_to_paged_cache),
        (SiluAndMul, "forward_cuda", custom_gems_silu_and_mul),
        (TritonMLAImpl, "_forward_decode", custom_gems_flash_mla_forward),
        (FlashAttentionImpl, "forward", custom_gems_flash_attention_impl_forward),
        (FlashAttnMLAImpl, "_forward_decode", custom_gems_flashattn_mla_forward_decode),
    )
    for owner, attr_name, replacement in method_overrides:
        patch_module_method(owner, attr_name, replacement, verbose)

    # (library name, op name, replacement)
    op_overrides = (
        ("_C", "rms_norm", custom_rms_norm_out),
        ("_C", "silu_and_mul", custom_silu_and_mul),
        ("_C", "silu_and_mul_with_clamp", custom_silu_and_mul_with_clamp),
        ("_C", "hc_head_fused_kernel", custom_hc_head_fused_kernel),
        ("_C", "cutlass_scaled_mm", custom_cutlass_scaled_mm),
        ("_moe_C", "moe_align_block_size", custom_moe_align_block_size),
        ("_moe_C", "topk_softmax", custom_topk_softmax),
        ("_moe_C", "moe_sum", custom_moe_sum),
        ("_vllm_fa3_C", "get_scheduler_metadata", custom_get_scheduler_metadata),
        ("_moe_C", "grouped_topk", custom_moe_grouped_topk),
        ("_C", "per_token_group_fp8_quant", custom_per_token_group_fp8_quant),
        ("_C", "apply_repetition_penalties_", custom_apply_repetition_penalties),
        ("_C", "top_k_per_row_prefill", custom_top_k_per_row_prefill),
        ("_C_cache_ops", "concat_and_cache_mla", custom_concat_and_cache_mla),
    )
    for library, op_name, replacement in op_overrides:
        patch_vllm_lib(library, op_name, replacement, dispatch_key, verbose)

    if vitw is None:
        return
    patch_vllm_vit_to_attn(vitw)