Coverage for src/flag_gems/fused/fused_marlin_moe.py: 20%
109 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1# SPDX-License-Identifier: Apache-2.0
2"""
3Fused Marlin MoE for FlagGems.
5Aligns with the interface of vLLM v0.20.0:
6 vllm/model_executor/layers/fused_moe/fused_marlin_moe.py :: fused_marlin_moe
8PHASE 2 (this file): bypass `fused_experts_impl`'s dequant-then-FP16-GEMM
9shortcut and dispatch directly to the wna16 Triton kernel
10(`fused_moe_kernel_gptq_awq`) for true fused-dequant W4A16/W8A16 GEMM.
12The local helper `_fused_marlin_moe_impl` mirrors `fused_experts_impl`'s
13orchestration (chunk loop, moe_align, two GEMMs, activation, reduction)
14but deletes the INT4/INT8 dequant branch and forwards `block_shape` so
15the wna16 path is actually taken.
17MVP scope:
18 - quant_type: GPTQ uint4b8 (INT4) and uint8b128 (INT8)
19 - activation: SwiGLU / SiLU
20 - act_order: NOT supported (g_idx / sort_indices must be None)
21 - FP8 input: NOT supported
22 - LoRA, clamp_limit: NOT supported (expert_map is only passed through to moe_align_block_size)
23"""
24import functools
25from typing import Any, Callable, Optional
27import torch
28import triton.language as tl
30from flag_gems.fused.fused_moe import (
31 MoEActivation,
32 _get_config_dtype_str,
33 _get_config_quant_dtype,
34 apply_moe_activation,
35 dispatch_fused_moe_kernel,
36 moe_kernel_quantize_input,
37 try_get_optimal_moe_config,
38)
39from flag_gems.fused.moe_align_block_size import moe_align_block_size
40from flag_gems.fused.moe_sum import moe_sum
# ----------------------------------------------------------------------------
# quant_type_id constants for the subset of vLLM scalar types handled here.
# NOTE(review): these are small local ids; they do not appear to equal vLLM's
# bit-packed `ScalarType.id` values, so callers must pass these constants.
# ----------------------------------------------------------------------------
# GPTQ INT4: each weight is stored as w + 8; dequantisation subtracts 8.
QUANT_TYPE_UINT4B8 = 0
# INT8: each weight is stored as w + 128.
QUANT_TYPE_UINT8B128 = 1

# Ids grouped by weight bit-width (selects the int4 vs. int8 kernel path),
# plus the union of everything the MVP accepts.
_QUANT_TYPE_INT4 = {QUANT_TYPE_UINT4B8}
_QUANT_TYPE_INT8 = {QUANT_TYPE_UINT8B128}
_SUPPORTED_QUANT_TYPES = {*_QUANT_TYPE_INT4, *_QUANT_TYPE_INT8}
55# ----------------------------------------------------------------------------
56# Phase-2 impl: copy of fused_experts_impl but with the dequant shortcut
57# removed so the wna16 Triton kernel is actually invoked for W4A16/W8A16.
58# ----------------------------------------------------------------------------
def _fused_marlin_moe_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: torch.Tensor | None = None,
    w2_zp: torch.Tensor | None = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run a quantized (W4A16 / W8A16) MoE forward pass via the wna16 kernel.

    Mirrors ``fused_experts_impl``'s orchestration (chunk loop, moe_align,
    two GEMMs, activation, reduction) but:
      - drops all paths irrelevant to W4A16/W8A16 (no fp8, int8_w8a8, mxfp);
      - REMOVES the ``w = w.to(fp16) * scale.unsqueeze(-1)`` dequant shortcut,
        so the packed weights are handed to the kernel unchanged;
      - forwards ``block_shape`` so the wna16 kernel uses the right group size.

    Args:
        hidden_states: 2-D ``(num_tokens, hidden_size)`` activations; must be
            contiguous and float32 / float16 / bfloat16.
        w1: ``(E, N, hidden_size // pack)`` packed GEMM-1 weights, where
            ``pack`` is 2 for INT4 and 1 for INT8; last dim contiguous.
        w2: ``(E, hidden_size, A // pack)`` packed GEMM-2 weights, where
            ``A = MoEActivation.adjust_N_for_activation(N, activation)``;
            last dim contiguous.
        topk_weights: ``(num_tokens, top_k)`` router weights.
        topk_ids: ``(num_tokens, top_k)`` selected expert ids.
        inplace: write the result into ``hidden_states`` instead of a new
            tensor.
        activation: must be ``"silu"``.
        apply_router_weight_on_input: passed as the routed-weight flag of
            GEMM 1; its negation is passed to GEMM 2.
        use_int8_w8a16: weights are INT8 (pack factor 1).
        use_int4_w4a16: weights are INT4 (pack factor 2). One of the two
            ``use_*`` flags must be set.
        per_channel_quant: forwarded to input quantization and both GEMMs.
        global_num_experts: expert count for ``moe_align_block_size``;
            ``-1`` means ``E`` (the local expert count from ``w1``).
        expert_map: forwarded to ``moe_align_block_size``; when set, the
            GEMM-2 output buffer is zeroed before each GEMM 2.
        w1_scale, w2_scale: weight scales for GEMM 1 / GEMM 2.
        w1_zp, w2_zp: optional weight zero points for GEMM 1 / GEMM 2.
        block_shape: forwarded to the config lookup, input quantization and
            both GEMMs (the public wrapper passes ``[0, group_size]``).
        w1_bias, w2_bias: optional biases for GEMM 1 / GEMM 2.

    Returns:
        ``hidden_states`` itself when ``inplace`` is true, otherwise a new
        tensor of the same shape, holding the per-token ``moe_sum`` of the
        top-k expert outputs.

    Raises:
        AssertionError: unsupported activation, non-quantized call, or a
            shape/stride/dtype mismatch. These are ``assert`` statements and
            are therefore skipped under ``python -O``.
        ValueError: unsupported compute dtype (only reachable when the dtype
            assert is disabled).
    """
    assert (
        activation == "silu"
    ), f"Only 'silu' activation is supported, got {activation}"
    assert (
        use_int4_w4a16 or use_int8_w8a16
    ), "_fused_marlin_moe_impl expects a quantized path"

    activation_enum = MoEActivation.from_str(activation)

    # Packed-aware shape check on GEMM 1's reduction dim.
    #   W4A16 (pack_factor=2): w1.size(2) == K // 2
    #   W8A16 (pack_factor=1): w1.size(2) == K
    pack_factor = 2 if use_int4_w4a16 else 1
    expected_packed_k = hidden_states.size(1) // pack_factor
    assert w1.size(2) == expected_packed_k, (
        f"w1 packed K mismatch: hidden_size={hidden_states.size(1)}, "
        f"use_int4_w4a16={use_int4_w4a16}, expected w1.size(2)={expected_packed_k}, "
        f"got {w1.size(2)}"
    )

    assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

    num_tokens = hidden_states.size(0)
    E, N, _ = w1.size()
    K = w2.size(1)
    if global_num_experts == -1:
        global_num_experts = E
    top_k_num = topk_ids.size(1)

    # GEMM 2 reduces over the activation output (activation_out_dim wide, same
    # packing convention as w1) and produces K-wide rows that moe_sum adds
    # into hidden_size-wide output rows, so w2 must agree with both.
    activation_out_dim = MoEActivation.adjust_N_for_activation(N, activation_enum)
    assert w2.size(0) == E, f"w1/w2 expert count mismatch: {E} vs {w2.size(0)}"
    assert K == hidden_states.size(1), (
        f"w2 output dim mismatch: hidden_size={hidden_states.size(1)}, "
        f"got w2.size(1)={K}"
    )
    expected_packed_n = activation_out_dim // pack_factor
    assert w2.size(2) == expected_packed_n, (
        f"w2 packed N mismatch: activation_out_dim={activation_out_dim}, "
        f"use_int4_w4a16={use_int4_w4a16}, expected w2.size(2)={expected_packed_n}, "
        f"got {w2.size(2)}"
    )

    CHUNK_SIZE: int = 16 * 1024
    M = min(num_tokens, CHUNK_SIZE)

    config_dtype = _get_config_dtype_str(
        use_fp8_w8a8=False,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        ocp_mx_scheme=None,
        dtype=hidden_states.dtype,
    )
    quant_dtype = _get_config_quant_dtype(
        use_fp8_w8a8=False,
        use_int8_w8a8=False,
        ocp_mx_scheme=None,
    )

    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.size(),
        w2.size(),
        top_k_num,
        config_dtype,
        block_shape=block_shape,
        E=E,
    )

    def get_chunk_config(num_rows: int) -> dict[str, Any]:
        # Return a *copy* with SPLIT_K forced to 1. Whether
        # try_get_optimal_moe_config hands back a shared/cached dict is not
        # visible from here, so never mutate its result in place.
        return {**get_config_func(num_rows), "SPLIT_K": 1}

    config = get_chunk_config(M)

    # cache1 and cache3 share one buffer: cache1 is last read by the
    # activation before GEMM 2 writes cache3, so their lifetimes don't overlap.
    cache13 = torch.empty(
        M * top_k_num * max(N, K),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache1 = cache13[: M * top_k_num * N].view(M, top_k_num, N)
    intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)
    intermediate_cache2 = torch.empty(
        (M * top_k_num, activation_out_dim),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    compute_type = {
        torch.bfloat16: tl.bfloat16,
        torch.float16: tl.float16,
        torch.float32: tl.float32,
    }.get(hidden_states.dtype)
    if compute_type is None:
        raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

    out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)

    # Phase-2 KEY DIFFERENCE: the W4A16/W8A16 dequant shortcut that lives here
    # in `fused_experts_impl` is intentionally absent. The wna16 Triton kernel
    # consumes the packed weights + scales directly.
    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
        begin_chunk_idx = chunk * CHUNK_SIZE
        end_chunk_idx = min((chunk + 1) * CHUNK_SIZE, num_tokens)
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.size()

        # Happens when num_tokens is an exact multiple of CHUNK_SIZE.
        if tokens_in_chunk == 0:
            break

        # A short trailing chunk: shrink the work buffers and re-tune the config.
        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
            intermediate_cache2 = intermediate_cache2[: tokens_in_chunk * top_k_num]
            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
            config = get_chunk_config(tokens_in_chunk)

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

        # Activation quantization is a no-op for W4A16/W8A16 (no input quant).
        qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
            A=curr_hidden_states,
            A_scale=None,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_channel_quant,
            block_shape=block_shape,
            ocp_mx_scheme=None,
        )

        # Always use the routed path (the SPARSITY_FACTOR shortcut is
        # disabled for quantized + block_shape configs anyway).
        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            curr_topk_ids,
            config["BLOCK_SIZE_M"],
            global_num_experts,
            expert_map,
        )

        # ----- GEMM 1: hidden @ w1 (fused dequant on B inside the kernel) -----
        dispatch_fused_moe_kernel(
            qcurr_hidden_states,
            w1,
            intermediate_cache1,
            a1q_scale,
            w1_scale,
            w1_zp,
            curr_topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            apply_router_weight_on_input,
            top_k_num,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=False,
            use_int8_w8a8=False,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            block_shape=block_shape,
            B_bias=w1_bias,
        )

        # ----- Activation: SwiGLU = silu(gate) * up -----
        apply_moe_activation(
            activation_enum, intermediate_cache2, intermediate_cache1.view(-1, N)
        )

        qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
            A=intermediate_cache2,
            A_scale=None,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_channel_quant,
            block_shape=block_shape,
            ocp_mx_scheme=None,
        )

        if expert_map is not None:
            intermediate_cache3.zero_()

        # ----- GEMM 2: act @ w2 (fused dequant on B inside the kernel) -----
        dispatch_fused_moe_kernel(
            qintermediate_cache2,
            w2,
            intermediate_cache3,
            a2q_scale,
            w2_scale,
            w2_zp,
            curr_topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            not apply_router_weight_on_input,
            1,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=False,
            use_int8_w8a8=False,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            block_shape=block_shape,
            B_bias=w2_bias,
        )

        # ----- Reduce: sum the top-k expert outputs back per token -----
        moe_sum(
            intermediate_cache3,
            out_hidden_states[begin_chunk_idx:end_chunk_idx],
        )

    return out_hidden_states
298# ----------------------------------------------------------------------------
299# Public entry point: vLLM-aligned wrapper.
300# ----------------------------------------------------------------------------
def fused_marlin_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    bias1: Optional[torch.Tensor],
    bias2: Optional[torch.Tensor],
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    quant_type_id: int,
    apply_router_weight_on_input: bool = False,
    global_num_experts: int = -1,
    activation: Any = None,
    activation_func: Optional[Callable] = None,
    moe_sum: Optional[Callable] = None,
    expert_map: Optional[torch.Tensor] = None,
    input_global_scale1: Optional[torch.Tensor] = None,
    input_global_scale2: Optional[torch.Tensor] = None,
    global_scale1: Optional[torch.Tensor] = None,
    global_scale2: Optional[torch.Tensor] = None,
    g_idx1: Optional[torch.Tensor] = None,
    g_idx2: Optional[torch.Tensor] = None,
    sort_indices1: Optional[torch.Tensor] = None,
    sort_indices2: Optional[torch.Tensor] = None,
    w1_zeros: Optional[torch.Tensor] = None,
    w2_zeros: Optional[torch.Tensor] = None,
    workspace: Optional[torch.Tensor] = None,
    intermediate_cache13: Optional[torch.Tensor] = None,
    intermediate_cache2: Optional[torch.Tensor] = None,
    is_k_full: bool = True,
    output: Optional[torch.Tensor] = None,
    input_dtype: Optional[torch.dtype] = None,
    inplace: bool = False,
    clamp_limit: Optional[float] = None,
    group_size: int = 128,
) -> torch.Tensor:
    """Phase-2 entry point: validate MVP limits, then run the wna16 impl.

    Maps ``quant_type_id`` onto the INT4 / INT8 weight path and dispatches to
    ``_fused_marlin_moe_impl`` with ``block_shape=[0, group_size]`` so the
    wna16 Triton kernel uses per-group scales.

    ``activation_func``, ``moe_sum``, ``workspace``, ``intermediate_cache13``,
    ``intermediate_cache2`` and ``is_k_full`` are accepted only for signature
    compatibility with vLLM and are not used here. (Inside this function the
    ``moe_sum`` parameter shadows the module-level ``moe_sum`` import.)

    Args:
        hidden_states, w1, w2, topk_weights, topk_ids: forwarded unchanged to
            the impl (packed W4A16/W8A16 weights; see the impl's checks).
        bias1, bias2: optional GEMM-1 / GEMM-2 biases.
        w1_scale, w2_scale: weight scales for GEMM 1 / GEMM 2.
        quant_type_id: one of the module's ``QUANT_TYPE_*`` constants.
        apply_router_weight_on_input, global_num_experts, expert_map:
            forwarded to the impl.
        activation: ``None`` (SiLU), a string, or an object with a string
            ``value``/``name`` (e.g. an enum member); must resolve to "silu".
        w1_zeros, w2_zeros: optional zero points forwarded to the kernels.
        output: optional destination; the result is copied into it.
        inplace: write the result into ``hidden_states``.
        group_size: quantization group size along the reduction dim; must be
            positive.
        g_idx*, sort_indices*, input_dtype, clamp_limit,
        input_global_scale*, global_scale*: must be left at ``None``.

    Returns:
        ``output`` when provided, otherwise the impl's result.

    Raises:
        NotImplementedError: an argument outside the MVP scope was supplied.
        ValueError: both ``inplace=True`` and ``output`` were supplied.
    """
    # ---- MVP guardrails --------------------------------------------------
    if quant_type_id not in _SUPPORTED_QUANT_TYPES:
        raise NotImplementedError(
            f"MVP supports quant_type_id in {_SUPPORTED_QUANT_TYPES}, "
            f"got {quant_type_id}"
        )
    if g_idx1 is not None or g_idx2 is not None:
        raise NotImplementedError("act_order (g_idx) not yet supported in MVP")
    if sort_indices1 is not None or sort_indices2 is not None:
        raise NotImplementedError("act_order (sort_indices) not yet supported in MVP")
    if input_dtype is not None:
        raise NotImplementedError("FP8 / INT8 input quantization not supported")
    if clamp_limit is not None:
        raise NotImplementedError("clamp_limit (GLM-4 swiglu) not supported")
    if input_global_scale1 is not None or input_global_scale2 is not None:
        raise NotImplementedError("input_global_scale not supported in MVP")
    if global_scale1 is not None or global_scale2 is not None:
        raise NotImplementedError("global_scale not supported in MVP")
    if group_size <= 0:
        # group_size becomes block_shape[1], the wna16 kernel's group size.
        # 0 or a sentinel such as GPTQ's -1 ("channelwise") is not a usable
        # group size there, so reject it instead of launching the kernel.
        raise NotImplementedError(
            f"MVP requires a positive group_size, got {group_size}"
        )

    use_int4_w4a16 = quant_type_id in _QUANT_TYPE_INT4
    use_int8_w8a16 = quant_type_id in _QUANT_TYPE_INT8

    activation_str = _resolve_activation_name(activation)
    if activation_str != "silu":
        raise NotImplementedError(
            f"MVP only supports SiLU/SwiGLU activation, got {activation_str}"
        )

    if inplace and output is not None:
        raise ValueError("Cannot pass both inplace=True and output")

    result = _fused_marlin_moe_impl(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=inplace,
        activation=activation_str,
        apply_router_weight_on_input=apply_router_weight_on_input,
        use_int4_w4a16=use_int4_w4a16,
        use_int8_w8a16=use_int8_w8a16,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        w1_zp=w1_zeros,
        w2_zp=w2_zeros,
        w1_bias=bias1,
        w2_bias=bias2,
        # Critical for Phase 2: block_shape=[0, group_size] makes the
        # wna16 Triton kernel use the per-group scales correctly.
        block_shape=[0, group_size],
    )

    if output is not None:
        output.copy_(result)
        return output
    return result


def _resolve_activation_name(activation: Any) -> str:
    """Return the lower-cased activation name for ``fused_marlin_moe``.

    ``None`` maps to the default ``"silu"``. A ``str`` (including str-based
    enum members) is lower-cased directly. Otherwise a string ``value``
    attribute is preferred, then a string ``name`` attribute (enum members).
    Anything else is stringified so the caller's "silu" check rejects it
    rather than silently treating it as SiLU.
    """
    if activation is None:
        return "silu"
    if isinstance(activation, str):
        return activation.lower()
    for attr in ("value", "name"):
        candidate = getattr(activation, attr, None)
        if isinstance(candidate, str):
            return candidate.lower()
    return str(activation).lower()
# Public surface: the entry point plus the quant_type_id constants it accepts.
__all__ = [
    "fused_marlin_moe",
    "QUANT_TYPE_UINT4B8",
    "QUANT_TYPE_UINT8B128",
]