Coverage for src/flag_gems/patches/patch_vllm_all.py: 15%

199 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import os 

2from typing import Optional, Tuple 

3 

4import torch 

5import torch.nn.functional as F 

6 

7import flag_gems 

8from flag_gems.fused import top_k_per_row_prefill 

9from flag_gems.patches.patch_util import patch_module_method, patch_vllm_lib 

10 

11 

def custom_gems_rms_forward_cuda(self, x, residual=None):
    """Replacement for ``RMSNorm.forward_cuda`` that delegates to flag_gems.

    Args:
        x: Input handed straight to ``gems_rms_forward``.
        residual: Optional second input, forwarded unchanged (``None`` by
            default).

    Returns:
        Whatever ``gems_rms_forward`` returns for ``x``, ``residual`` and this
        layer's ``weight`` / ``variance_epsilon``.
    """
    # Resolved at call time, not when this module is imported.
    from flag_gems.modules.normalization import gems_rms_forward as _rms_forward

    weight = self.weight
    eps = self.variance_epsilon
    return _rms_forward(x, residual, weight, eps)

16 

17 

def custom_gems_rope_forward_cuda(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Replacement for ``RotaryEmbedding.forward_cuda`` built on flag_gems.

    Args:
        positions: Token positions; flattened before use.
        query: Query tensor, viewed as ``(num_tokens, -1, head_size)``.
        key: Key tensor, viewed as ``(num_tokens, -1, head_size)``.
        offsets: Optional tensor added to ``positions`` before flattening.

    Returns:
        ``(query, key)`` reshaped back to the shapes they were passed in with.
        Only the first ``rotary_dim`` features of each head go through
        ``gems_rope_forward``; any remaining features are concatenated back
        unchanged.
    """
    from flag_gems.modules.rotary_embedding import gems_rope_forward

    # Keep the cached cos/sin table on the same device as the positions.
    self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)

    token_positions = positions if offsets is None else positions + offsets
    token_positions = token_positions.flatten()
    num_tokens = token_positions.shape[0]

    # Remember the caller's layouts so the results can be restored to them.
    q_shape_in, k_shape_in = query.shape, key.shape
    q_heads = query.view(num_tokens, -1, self.head_size)
    k_heads = key.view(num_tokens, -1, self.head_size)

    q_rot = q_heads[..., : self.rotary_dim]
    k_rot = k_heads[..., : self.rotary_dim]

    # When rotary_dim covers only part of each head, keep the tail aside.
    partial_rotary = self.rotary_dim < self.head_size
    q_tail = q_heads[..., self.rotary_dim :] if partial_rotary else None
    k_tail = k_heads[..., self.rotary_dim :] if partial_rotary else None

    cos, sin = self.cos_sin_cache.chunk(2, dim=-1)

    q_embed, k_embed = gems_rope_forward(
        q_rot,
        k_rot,
        cos,
        sin,
        position_ids=token_positions,
        rotary_interleaved=not self.is_neox_style,
        inplace=True,  # set inplace to True for vLLM compatibility
    )

    if self.rotary_dim >= self.head_size:
        return q_embed.reshape(q_shape_in), k_embed.reshape(k_shape_in)

    q_out = torch.cat((q_embed, q_tail), dim=-1).reshape(q_shape_in)
    k_out = torch.cat((k_embed, k_tail), dim=-1).reshape(k_shape_in)
    return q_out, k_out

64 

65 

def custom_gems_silu_and_mul(self, x: torch.Tensor) -> torch.Tensor:
    """Replacement for ``SiluAndMul.forward_cuda`` that delegates to flag_gems.

    Splits ``x`` along its last dimension at ``x.shape[-1] // 2`` and passes
    the two slices to ``gems_silu_and_mul``, returning its result.
    """
    from flag_gems.modules.activation import gems_silu_and_mul

    half = x.shape[-1] // 2
    first_half = x[..., :half]
    second_half = x[..., half:]
    return gems_silu_and_mul(first_half, second_half)

72 

73 

def custom_gems_write_to_paged_cache(
    key,
    value,
    key_cache,
    value_cache,
    slot_mapping,
    kv_cache_dtype,
    k_scale,
    v_scale,
):
    """Replacement for ``PagedAttention.write_to_paged_cache``.

    Every argument is forwarded unchanged to flag_gems' ``reshape_and_cache``
    except ``slot_mapping``, which is flattened first. Nothing is returned.
    """
    from flag_gems.fused.reshape_and_cache import reshape_and_cache

    flat_slots = slot_mapping.flatten()
    cache_args = (
        key,
        value,
        key_cache,
        value_cache,
        flat_slots,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )
    reshape_and_cache(*cache_args)

96 

97 

def custom_gems_flash_mla_forward(
    self,
    q_nope,
    q_pe,
    kv_c_and_k_pe_cache,
    attn_metadata,
) -> torch.Tensor:
    """Replacement for ``TritonMLAImpl._forward_decode`` using ``flash_mla``.

    Args:
        q_nope: 3-D tensor unpacked as ``(batch, num_head_q, head_dim_v)``.
        q_pe: Tensor concatenated to ``q_nope`` along the last dimension.
        kv_c_and_k_pe_cache: Non-empty cache tensor; its dimension 1 is used
            as the page size.
        attn_metadata: Object whose ``decode`` attribute supplies
            ``block_table`` and ``seq_lens``.

    Returns:
        The ``flash_mla`` output passed through ``self._v_up_proj_and_o_proj``.

    Raises:
        NotImplementedError: If ``self.kv_cache_dtype`` starts with ``"fp8"``.
    """
    from flag_gems.fused import flash_mla

    assert kv_c_and_k_pe_cache.numel() > 0
    assert attn_metadata.decode is not None

    if self.kv_cache_dtype.startswith("fp8"):
        raise NotImplementedError("FP8 Triton MLA not yet supported")

    decode_meta = attn_metadata.decode
    batch, num_head_q, head_dim_v = q_nope.shape
    seqlen_q = 1

    # Join both query parts, then view as (batch, seqlen_q, heads, head_dim).
    q_full = torch.cat([q_nope, q_pe], dim=-1)
    head_dim = q_full.shape[-1]
    q_full = q_full.view(batch, seqlen_q, num_head_q, head_dim)

    # Add a head dim of 1
    paged_cache = kv_c_and_k_pe_cache.unsqueeze(2)
    page_size = paged_cache.size(1)

    block_table = decode_meta.block_table
    attn_out = flash_mla(
        q_full,
        block_table,
        paged_cache,
        None,
        page_size,
        batch,
        seqlen_q,
        decode_meta.seq_lens,
        num_head_q,
        None,
        head_dim,
        head_dim_v,
        True,
    )

    return self._v_up_proj_and_o_proj(attn_out)

143 

144 

def custom_gems_flash_attention_impl_forward(
    self,
    layer: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata,  #: FlashAttentionMetadata,
    output: Optional[torch.Tensor] = None,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """Replacement for ``FlashAttentionImpl.forward`` backed by flag_gems.

    Writes ``key``/``value`` into the two halves of ``kv_cache`` with
    ``reshape_and_cache_flash`` and then runs ``flash_attn_varlen_func`` on
    the first ``attn_metadata.num_actual_tokens`` rows, storing the result in
    ``output``.

    All unsupported configurations are rejected *before* the KV cache is
    touched, so a failing call leaves the cache unmodified.

    Args:
        layer: Attention layer providing ``_q_scale``, ``_k_scale`` and
            ``_v_scale``.
        query: Query tensor; only ``query[:num_actual_tokens]`` is used.
        key: Key tensor written to the cache; ``key.shape[1]`` sizes the
            descale tensors.
        value: Value tensor written to the cache.
        kv_cache: Tensor unbound along dimension 0 into key and value caches.
        attn_metadata: Attention metadata, or ``None`` for a profiling run.
        output: Pre-allocated output tensor (required).
        output_scale: Must be ``None``; fused output quantization is not
            supported.
        output_block_scale: Must be ``None`` for the same reason.
        **kwargs: Accepted for signature compatibility and ignored.

    Returns:
        ``output`` (untouched when ``attn_metadata`` is ``None``).

    Raises:
        NotImplementedError: For fused output quantization, an FP8 KV cache,
            or cascade attention.
    """
    assert output is not None, "Output tensor must be provided."

    # Either scale means the caller expects fused output quantization, which
    # this path cannot provide; ignoring one of them would silently return
    # unquantized data.
    if output_scale is not None or output_block_scale is not None:
        raise NotImplementedError(
            "fused output quantization is not yet supported for FlashAttentionImpl"
        )

    if attn_metadata is None:
        # Profiling run.
        return output

    # TODO: Support FP8. The intended approach is to view both caches as
    # torch.float8_e4m3fn and quantize the (num_tokens, num_heads * head_size)
    # query with layer._q_scale before attention.
    if self.kv_cache_dtype.startswith("fp8"):
        raise NotImplementedError(
            "FP8 quantization is not yet supported for FlashAttentionImpl"
        )

    use_local_attn = (
        getattr(self, "use_irope", False)
        and getattr(attn_metadata, "local_attn_metadata", None) is not None
    )

    # TODO: Support cascade_attention.
    if attn_metadata.use_cascade and not use_local_attn:
        raise NotImplementedError("Cascade attention is not implemented in flag_gems.")

    # Imported lazily, and only once the call is known to be supported.
    from flag_gems import flash_attn_varlen_func, reshape_and_cache_flash

    num_actual_tokens = attn_metadata.num_actual_tokens
    key_cache, value_cache = kv_cache.unbind(0)

    reshape_and_cache_flash(
        key,
        value,
        key_cache,
        value_cache,
        attn_metadata.slot_mapping,
        self.kv_cache_dtype,
        layer._k_scale,
        layer._v_scale,
    )

    # Compute attention and update output up to `num_actual_tokens`.
    if use_local_attn:
        assert attn_metadata.local_attn_metadata is not None
        local_metadata = attn_metadata.local_attn_metadata
        cu_seqlens_q = local_metadata.local_query_start_loc
        seqused_k = local_metadata.local_seqused_k
        max_seqlen_q = local_metadata.local_max_query_len
        max_seqlen_k = local_metadata.local_max_seq_len
        block_table = local_metadata.local_block_table
        scheduler_metadata = local_metadata.local_scheduler_metadata
    else:
        cu_seqlens_q = attn_metadata.query_start_loc
        seqused_k = attn_metadata.seq_lens
        max_seqlen_q = attn_metadata.max_query_len
        max_seqlen_k = attn_metadata.max_seq_len
        block_table = attn_metadata.block_table
        scheduler_metadata = attn_metadata.scheduler_metadata

    # One descale entry per (sequence, key.shape[1]) pair.
    descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

    flash_attn_varlen_func(
        q=query[:num_actual_tokens],
        k=key_cache,
        v=value_cache,
        out=output[:num_actual_tokens],
        cu_seqlens_q=cu_seqlens_q,
        max_seqlen_q=max_seqlen_q,
        seqused_k=seqused_k,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=self.scale,
        causal=True,
        alibi_slopes=self.alibi_slopes,
        window_size=self.sliding_window,
        block_table=block_table,
        softcap=self.logits_soft_cap,
        scheduler_metadata=scheduler_metadata,
        fa_version=2,
        q_descale=layer._q_scale.expand(descale_shape),
        k_descale=layer._k_scale.expand(descale_shape),
        v_descale=layer._v_scale.expand(descale_shape),
        s_aux=None,
        num_splits=0,
        cp_world_size=1,
        cp_rank=0,
        cp_tot_seqused_k=None,
    )
    return output

255 

256 

def custom_silu_and_mul(out: torch.Tensor, input: torch.Tensor):
    """Override for the ``_C.silu_and_mul`` op, implemented with flag_gems.

    Splits ``input`` along its last dimension into chunks of size
    ``input.size(-1) // 2`` (exactly two chunks are expected) and passes them,
    together with ``out``, to ``flag_gems.silu_and_mul_out``. Returns nothing.
    """
    half = input.size(-1) // 2
    first_half, second_half = torch.split(input, half, dim=-1)
    flag_gems.silu_and_mul_out(first_half, second_half, out)

261 

262 

def custom_silu_and_mul_with_clamp(
    out: torch.Tensor, input: torch.Tensor, limit: float
):
    """Override for the ``_C.silu_and_mul_with_clamp`` op using flag_gems.

    Splits ``input`` along its last dimension into chunks of size
    ``input.size(-1) // 2`` (exactly two chunks are expected) and passes them
    with ``out`` and ``limit`` to ``flag_gems.silu_and_mul_with_clamp_out``.
    Returns nothing.
    """
    half = input.size(-1) // 2
    first_half, second_half = torch.split(input, half, dim=-1)
    flag_gems.silu_and_mul_with_clamp_out(first_half, second_half, out, limit)

269 

270 

def custom_hc_head_fused_kernel(
    hs_flat: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    out: torch.Tensor,
    hidden_size: int,
    rms_eps: float,
    hc_eps: float,
    hc_mult: int,
):
    """Override for the ``_C.hc_head_fused_kernel`` op.

    Passes all nine arguments, in declaration order, to
    ``flag_gems.hc_head_fused_kernel`` and returns its result.
    """
    kernel_args = (
        hs_flat,
        fn,
        hc_scale,
        hc_base,
        out,
        hidden_size,
        rms_eps,
        hc_eps,
        hc_mult,
    )
    return flag_gems.hc_head_fused_kernel(*kernel_args)

285 

286 

def custom_moe_align_block_size(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    experts_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
):
    """Override for the ``_moe_C.moe_align_block_size`` op.

    Forwards every argument, in order, to
    ``flag_gems.moe_align_block_size_triton``. Returns nothing; any result is
    expected to land in the tensors passed in (presumably
    ``sorted_token_ids``, ``experts_ids`` and ``num_tokens_post_pad`` -
    confirm against the flag_gems kernel).
    """
    align_args = (
        topk_ids,
        num_experts,
        block_size,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
    )
    flag_gems.moe_align_block_size_triton(*align_args)

303 

304 

def custom_moe_grouped_topk(
    gating_output: torch.Tensor,
    n_group: int,
    topk_group: int,
    topk: int,
    renormalize: bool,
    routed_scaling_factor: float,
    bias: torch.Tensor,
    scoring_func: int = 0,
):
    """Override for the ``_moe_C.grouped_topk`` op.

    Calls flag_gems' ``grouped_topk`` with ``gating_output`` passed as
    ``scores`` and every other argument passed under its own name, and
    returns the result.
    """
    from flag_gems.fused import grouped_topk

    topk_kwargs = {
        "scores": gating_output,
        "n_group": n_group,
        "topk_group": topk_group,
        "topk": topk,
        "renormalize": renormalize,
        "routed_scaling_factor": routed_scaling_factor,
        "bias": bias,
        "scoring_func": scoring_func,
    }
    return grouped_topk(**topk_kwargs)

327 

328 

def custom_topk_softmax(
    topk_weights, topk_indices, token_expert_indices, gating_output, renormalize=False
):
    """Override for the ``_moe_C.topk_softmax`` op.

    Forwards all five arguments positionally to ``flag_gems.topk_softmax``.
    Returns nothing.
    """
    softmax_args = (
        topk_weights,
        topk_indices,
        token_expert_indices,
        gating_output,
        renormalize,
    )
    flag_gems.topk_softmax(*softmax_args)

335 

336 

def custom_moe_sum(input: torch.Tensor, output: torch.Tensor):
    """Override for the ``_moe_C.moe_sum`` op.

    Calls flag_gems' ``moe_sum(input, output)`` and returns nothing.
    """
    from flag_gems.fused import moe_sum as _gems_moe_sum

    _gems_moe_sum(input, output)

341 

342 

def custom_apply_repetition_penalties(
    logits: torch.Tensor,
    prompt_mask: torch.Tensor,
    output_mask: torch.Tensor,
    repetition_penalties: torch.Tensor,
):
    """Override for the ``_C.apply_repetition_penalties_`` op.

    Forwards the four tensors, in order, to
    ``flag_gems.apply_repetition_penalties`` and returns its result.
    """
    penalty_args = (logits, prompt_mask, output_mask, repetition_penalties)
    return flag_gems.apply_repetition_penalties(*penalty_args)

352 

353 

def custom_get_scheduler_metadata(
    batch_size: int,
    max_seqlen_q: int,
    max_seqlen_k: int,
    num_heads: int,
    num_heads_k: int,
    headdim: int,
    headdim_v: int,
    qkv_dtype: torch.dtype,
    seqused_k: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    max_seqlen_k_new: int = 0,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    has_softcap: bool = False,
    num_splits: int = 0,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
):
    """Override for the ``_vllm_fa3_C.get_scheduler_metadata`` op.

    The first nine parameters are passed positionally and the remaining ones
    by keyword to ``flag_gems.get_scheduler_metadata``; its result is
    returned unchanged.
    """
    required = (
        batch_size,
        max_seqlen_q,
        max_seqlen_k,
        num_heads,
        num_heads_k,
        headdim,
        headdim_v,
        qkv_dtype,
        seqused_k,
    )
    optional = {
        "cu_seqlens_q": cu_seqlens_q,
        "cu_seqlens_k": cu_seqlens_k,
        "cu_seqlens_k_new": cu_seqlens_k_new,
        "seqused_q": seqused_q,
        "leftpad_k": leftpad_k,
        "page_size": page_size,
        "max_seqlen_k_new": max_seqlen_k_new,
        "is_causal": is_causal,
        "window_size_left": window_size_left,
        "window_size_right": window_size_right,
        "has_softcap": has_softcap,
        "num_splits": num_splits,
        "pack_gqa": pack_gqa,
        "sm_margin": sm_margin,
    }
    return flag_gems.get_scheduler_metadata(*required, **optional)

404 

405 

def custom_per_token_group_fp8_quant(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    fp8_min: float,
    fp8_max: float,
    scale_ue8m0: bool = False,
):
    """Override for the ``_C.per_token_group_fp8_quant`` op.

    Runs flag_gems' ``per_token_group_quant_fp8`` on ``input`` and copies the
    two results into the caller-provided ``output_q`` and ``output_s``.
    Returns nothing.

    NOTE(review): ``fp8_min`` and ``fp8_max`` are accepted but never used
    here - confirm that the flag_gems routine applies an equivalent range.
    """
    from flag_gems.ops import per_token_group_quant_fp8

    # The scale layout is inferred from the strides of the destination
    # buffer, which therefore needs at least two dimensions.
    stride_rows = output_s.stride(0)
    stride_cols = output_s.stride(1)
    column_major_scales = stride_rows < stride_cols

    quantized, scales = per_token_group_quant_fp8(
        x=input,
        group_size=group_size,
        eps=eps,
        column_major_scales=column_major_scales,
        scale_ue8m0=scale_ue8m0,
    )

    output_q.copy_(quantized)
    output_s.copy_(scales)

430 

431 

def custom_cutlass_scaled_mm(
    output: torch.Tensor,
    input: torch.Tensor,
    weight: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: torch.Tensor | None = None,
):
    """Override for the ``_C.cutlass_scaled_mm`` op.

    Forwards all six arguments positionally to
    ``flag_gems.cutlass_scaled_mm`` and returns its result.
    """
    mm_args = (output, input, weight, scale_a, scale_b, bias)
    return flag_gems.cutlass_scaled_mm(*mm_args)

441 

442 

def custom_top_k_per_row_prefill(
    logits, row_starts, row_ends, indices, num_rows, stride0, stride1, top_k
):
    """Override for the ``_C.top_k_per_row_prefill`` op.

    Forwards all eight arguments, in order, to flag_gems'
    ``top_k_per_row_prefill``. Returns nothing.
    """
    prefill_args = (
        logits,
        row_starts,
        row_ends,
        indices,
        num_rows,
        stride0,
        stride1,
        top_k,
    )
    top_k_per_row_prefill(*prefill_args)

449 

450 

def custom_concat_and_cache_mla(
    kv_c: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    scale: torch.Tensor,
) -> None:
    """Override for the ``_C_cache_ops.concat_and_cache_mla`` op.

    Forwards all six arguments, in order, to
    ``flag_gems.concat_and_cache_mla`` and passes its return value through.
    """
    cache_args = (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
    return flag_gems.concat_and_cache_mla(*cache_args)

462 

463 

def custom_gems_flashattn_mla_forward_decode(
    self,
    q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
    kv_c_and_k_pe_cache: torch.Tensor,
    attn_metadata,  # FlashAttnMLAMetadata
    layer,  # AttentionLayer
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Replacement for ``FlashAttnMLAImpl._forward_decode`` using flag_gems.

    Args:
        q: Either a ``(q_nope, q_pe)`` tuple or one tensor that is split along
            its last dimension into ``kv_lora_rank`` and
            ``qk_rope_head_dim`` wide parts.
        kv_c_and_k_pe_cache: Non-empty cache tensor; its last dimension is
            sliced at ``kv_lora_rank``.
        attn_metadata: Object whose ``decode`` attribute carries the decode
            metadata.
        layer: Unused here.

    Returns:
        ``(output, lse)``, where ``lse`` is the transposed log-sum-exp when
        ``self.need_to_return_lse_for_decode`` is true and ``None`` otherwise.

    Raises:
        NotImplementedError: If ``self.kv_cache_dtype`` starts with ``"fp8"``.
    """
    from flag_gems import flash_attn_varlen_func

    assert kv_c_and_k_pe_cache.numel() > 0
    assert attn_metadata.decode is not None

    if type(q) is tuple:
        q_nope, q_pe = q
    else:
        split_sizes = [self.kv_lora_rank, self.qk_rope_head_dim]
        q_nope, q_pe = torch.split(q, split_sizes, dim=-1)

    if self.kv_cache_dtype.startswith("fp8"):
        raise NotImplementedError("FP8 FlashAttention MLA not yet supported")

    kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
    k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank :]

    decode_meta = attn_metadata.decode

    # NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the
    # kernel uses this to calculate grid dimensions. Ensure it's at least 1
    # to prevent invalid grid configuration during graph capture.
    max_seqlen_q = max(decode_meta.max_query_len, 1)

    attn_out = flash_attn_varlen_func(
        q=q_pe,
        k=k_pe_cache.unsqueeze(-2),  # Add head dim of 1
        v=kv_c_cache.unsqueeze(-2),  # Add head dim of 1
        q_v=q_nope,
        max_seqlen_q=max_seqlen_q,
        cu_seqlens_q=decode_meta.query_start_loc,
        max_seqlen_k=decode_meta.max_seq_len,
        seqused_k=decode_meta.seq_lens,
        block_table=decode_meta.block_table,
        softmax_scale=self.scale,
        causal=True,
        return_softmax_lse=self.need_to_return_lse_for_decode,
        fa_version=2,
        scheduler_metadata=decode_meta.scheduler_metadata,
        num_splits=0,
        cp_world_size=self.dcp_world_size,
        cp_rank=self.dcp_rank,
        cp_tot_seqused_k=decode_meta.dcp_tot_seq_lens,
    )

    if not self.need_to_return_lse_for_decode:
        return attn_out, None

    attn_output, lse = attn_out
    # FA returns LSE in shape [ H, B ] but DCP wants [ B, H ]
    return attn_output, lse.transpose(0, 1)  # [ H, B ] -> [ B, H ]

522 

523 

# use gems flash attention in vit attention
def patch_vllm_vit_to_attn(vitw):
    """Route ``vitw.vit_xformers_attn_wrapper`` through flag_gems attention.

    Does nothing when ``vitw`` has no ``vit_xformers_attn_wrapper`` attribute.
    Otherwise the attribute is replaced by a wrapper that runs flag_gems'
    ``flash_attention_forward`` once per sequence. The original function is
    still used when the ``VIT_ATTN_BACKEND`` environment variable equals
    ``"no-sdpa"``; the variable is read on every call.

    Args:
        vitw: Module (or any object) that may expose
            ``vit_xformers_attn_wrapper``.

    Returns:
        None. ``vitw`` is modified in place.
    """
    if not hasattr(vitw, "vit_xformers_attn_wrapper"):
        return

    _orig_vit = vitw.vit_xformers_attn_wrapper

    def _seqlens_to_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
        # Cumulative sum with a leading zero, as int32.
        cu_seqlens = torch.cumsum(seqlens, dim=0, dtype=torch.int32)
        return F.pad(cu_seqlens, (1, 0))

    def _torch_sdpa_wrapper_gems(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens: torch.Tensor,
    ):
        import flag_gems.ops.attention as gems_attn

        # Fetch every boundary in one transfer instead of calling .item()
        # twice per sequence inside the loop.
        bounds = [int(b) for b in cu_seqlens.tolist()]

        outputs = []
        for start, end in zip(bounds, bounds[1:]):
            # Sequences are laid out back to back along dimension 1.
            q_i = q[:, start:end]
            k_i = k[:, start:end]
            v_i = v[:, start:end]

            out_i, *_ = gems_attn.flash_attention_forward(
                q_i,
                k_i,
                v_i,
                None,
                None,
                int(q_i.shape[1]),
                int(k_i.shape[1]),
                0.0,
                False,
                False,
                scale=None,
                softcap=0.0,
                window_size_left=None,
                window_size_right=None,
                seqused_k=None,
                alibi_slopes=None,
                disable_splitkv=True,
            )
            outputs.append(out_i)

        context_layer = torch.cat(outputs, dim=1)
        x = context_layer.transpose(0, 1).contiguous()
        # Collapse everything after the first two dimensions into one.
        return x.view(x.shape[0], x.shape[1], -1)

    def _wrapped_vit_xformers_attn_wrapper(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        seqlens: torch.Tensor,
    ) -> torch.Tensor:
        if os.getenv("VIT_ATTN_BACKEND", "xformers") == "no-sdpa":
            return _orig_vit(q, k, v, seqlens)

        cu_seqlens = _seqlens_to_cu_seqlens(seqlens)
        return _torch_sdpa_wrapper_gems(q, k, v, cu_seqlens)

    vitw.vit_xformers_attn_wrapper = _wrapped_vit_xformers_attn_wrapper

589 

590 

def custom_rms_norm_out(result, input, weight, epsilon):
    """Override for the ``_C.rms_norm`` op, implemented with flag_gems.

    Calls ``rms_norm_out`` with ``result``, ``input``, the size of ``weight``
    as a list, ``weight`` itself and ``epsilon``. Returns nothing.
    """
    from flag_gems.ops.rms_norm import rms_norm_out as _rms_norm_out

    weight_shape = list(weight.size())
    _rms_norm_out(result, input, weight_shape, weight, epsilon)

595 

596 

def apply_gems_patches_to_vllm(verbose=True):
    """Install every flag_gems override defined in this module into vLLM.

    Three steps, in order:

    1. Replace selected methods on vLLM classes via ``patch_module_method``.
    2. Override selected vLLM custom ops via ``patch_vllm_lib``, using
       ``flag_gems.runtime.device.dispatch_key``.
    3. Patch the ViT attention wrapper, if that vLLM module can be imported.

    Args:
        verbose: Passed through to ``patch_module_method`` and
            ``patch_vllm_lib``.

    Raises:
        ImportError: If vLLM or any required vLLM submodule is missing. Only
            ``vit_attn_wrappers`` is optional.
    """
    import vllm  # noqa: F401
    import vllm._custom_ops as ops  # noqa: F401

    # ModuleNotFoundError is a subclass of ImportError, so one clause covers
    # both cases.
    try:
        from vllm.attention.ops import vit_attn_wrappers as vitw
    except ImportError:
        vitw = None
    from vllm.attention.ops.paged_attn import PagedAttention
    from vllm.model_executor.layers.activation import SiluAndMul
    from vllm.model_executor.layers.layernorm import RMSNorm
    from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
    from vllm.v1.attention.backends.flash_attn import FlashAttentionImpl
    from vllm.v1.attention.backends.mla.flashattn_mla import FlashAttnMLAImpl
    from vllm.v1.attention.backends.mla.triton_mla import TritonMLAImpl

    dispatch_key = flag_gems.runtime.device.dispatch_key

    # (class, method name, replacement) triples, applied in this order.
    method_overrides = (
        (RMSNorm, "forward_cuda", custom_gems_rms_forward_cuda),
        (RotaryEmbedding, "forward_cuda", custom_gems_rope_forward_cuda),
        (PagedAttention, "write_to_paged_cache", custom_gems_write_to_paged_cache),
        (SiluAndMul, "forward_cuda", custom_gems_silu_and_mul),
        (TritonMLAImpl, "_forward_decode", custom_gems_flash_mla_forward),
        (FlashAttentionImpl, "forward", custom_gems_flash_attention_impl_forward),
        (FlashAttnMLAImpl, "_forward_decode", custom_gems_flashattn_mla_forward_decode),
    )
    for override in method_overrides:
        patch_module_method(*override, verbose)

    # (library, op name, replacement) triples, applied in this order.
    op_overrides = (
        ("_C", "rms_norm", custom_rms_norm_out),
        ("_C", "silu_and_mul", custom_silu_and_mul),
        ("_C", "silu_and_mul_with_clamp", custom_silu_and_mul_with_clamp),
        ("_C", "hc_head_fused_kernel", custom_hc_head_fused_kernel),
        ("_C", "cutlass_scaled_mm", custom_cutlass_scaled_mm),
        ("_moe_C", "moe_align_block_size", custom_moe_align_block_size),
        ("_moe_C", "topk_softmax", custom_topk_softmax),
        ("_moe_C", "moe_sum", custom_moe_sum),
        ("_vllm_fa3_C", "get_scheduler_metadata", custom_get_scheduler_metadata),
        ("_moe_C", "grouped_topk", custom_moe_grouped_topk),
        ("_C", "per_token_group_fp8_quant", custom_per_token_group_fp8_quant),
        ("_C", "apply_repetition_penalties_", custom_apply_repetition_penalties),
        ("_C", "top_k_per_row_prefill", custom_top_k_per_row_prefill),
        ("_C_cache_ops", "concat_and_cache_mla", custom_concat_and_cache_mla),
    )
    for lib_name, op_name, replacement in op_overrides:
        patch_vllm_lib(lib_name, op_name, replacement, dispatch_key, verbose)

    if vitw is None:
        return
    patch_vllm_vit_to_attn(vitw)