Coverage for src/flag_gems/fused/fused_marlin_moe.py: 34%
326 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1# SPDX-License-Identifier: Apache-2.0
2"""
Fused Marlin MoE for FlagGems.

Aligns with the interface of vLLM v0.20.0:
    vllm/model_executor/layers/fused_moe/fused_marlin_moe.py :: fused_marlin_moe

PHASE 2 (this file): bypass `fused_experts_impl`'s dequant-then-FP16-GEMM
shortcut and dispatch directly to the wna16 Triton kernel
(`fused_moe_kernel_gptq_awq`) for true fused-dequant W4A16/W8A16 GEMM.

The local helper `_fused_marlin_moe_impl` mirrors `fused_experts_impl`'s
orchestration (chunk loop, moe_align, two GEMMs, activation, reduction)
but deletes the INT4/INT8 dequant branch and forwards `block_shape` so
the wna16 path is actually taken.

MVP scope:
  - quant_type: GPTQ uint4b8 (INT4) and uint8b128 (INT8)
  - activation: SwiGLU / SiLU
  - act_order: NOT supported (g_idx / sort_indices must be None)
  - FP8 input: NOT supported
  - LoRA, clamp_limit: NOT supported
  - expert_map: forwarded to the generic wna16 path only; the Hopper W4A16
    fast path is skipped whenever an expert_map is given
23"""
24import functools
25from typing import Any, Callable, Optional, Tuple
27import torch
28import triton
29import triton.language as tl
30from torch.utils.weak import WeakTensorKeyDictionary
32from flag_gems.fused.fused_moe import (
33 MoEActivation,
34 _get_config_dtype_str,
35 _get_config_quant_dtype,
36 apply_moe_activation,
37 dispatch_fused_moe_kernel,
38 moe_kernel_quantize_input,
39 try_get_optimal_moe_config,
40 write_zeros_to_output,
41)
42from flag_gems.fused.moe_align_block_size import moe_align_block_size
43from flag_gems.fused.moe_sum import moe_sum
44from flag_gems.fused.silu_and_mul import silu_and_mul_out
# ----------------------------------------------------------------------------
# Identifiers accepted as `quant_type_id`; they mirror a subset of the vLLM
# scalar_types ids.
# ----------------------------------------------------------------------------
QUANT_TYPE_UINT4B8 = 0  # GPTQ INT4: stored as w + 8, dequant subtracts 8
QUANT_TYPE_UINT8B128 = 1  # INT8: stored as w + 128

# Ids grouped by weight bit width, plus the union of everything handled here.
_QUANT_TYPE_INT4 = set((QUANT_TYPE_UINT4B8,))
_QUANT_TYPE_INT8 = set((QUANT_TYPE_UINT8B128,))
_SUPPORTED_QUANT_TYPES = _QUANT_TYPE_INT4.union(_QUANT_TYPE_INT8)
@functools.lru_cache(maxsize=1)
def _is_hopper() -> bool:
    """Return True when CUDA is usable and the current device is sm_90 or newer.

    The answer is computed once and memoized for the life of the process.
    """
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 9
# ============================================================================
# W4A16 (GPTQ uint4b8) fast path: tile-B + nibble-interleaved weight packing
# fed to a magic-number SIMD INT4->bf16/fp16 dequant + tl.dot kernel. This is
# the Hopper-gated short path taken by fused_marlin_moe for plain GPTQ uint4b8.
# ============================================================================
# Memoized packing results, keyed by the *source tensor object*:
#   _W_PACK_CACHE:     weight tensor -> {block_size_k: packed int32 tensor}
#   _SCALE_PACK_CACHE: scale tensor  -> transposed contiguous copy
# NOTE(review): keys are held weakly, so entries presumably vanish once the
# source tensor is garbage collected; in-place edits of a cached tensor are
# not detected — confirm weights are immutable after loading.
_W_PACK_CACHE: WeakTensorKeyDictionary = WeakTensorKeyDictionary()
_SCALE_PACK_CACHE: WeakTensorKeyDictionary = WeakTensorKeyDictionary()
def _pack_w_interleave(w: torch.Tensor, block_size_k: int) -> torch.Tensor:
    """Repack byte-packed INT4 weights into nibble-interleaved int32 words.

    Args:
        w: uint8 tensor of shape (E, N_out, K // 2); every byte carries two
            4-bit values, low nibble first.
        block_size_k: logical K tile size of the consuming kernel; must be a
            multiple of 8 and must divide K.

    Returns:
        int32 tensor of shape (E, K // 8, N_out). Within each K tile of
        ``block_size_k`` values, word ``b`` gathers the eight values at
        offsets ``j * (block_size_k // 8) + b`` (j = 0..7) and stores value
        ``j`` in nibble ``(0, 4, 1, 5, 2, 6, 3, 7)[j]``.
    """
    assert w.dtype == torch.uint8
    assert w.ndim == 3
    assert (
        block_size_k % 8 == 0
    ), f"BLOCK_SIZE_K={block_size_k} must be multiple of 8 (8 logical K per int32)"
    num_experts, n_out, k_bytes = w.shape
    K = 2 * k_bytes
    B = block_size_k // 8
    assert K % (8 * B) == 0, f"K={K} must be divisible by BLOCK_SIZE_K={block_size_k}"
    tiles_per_row = K // (8 * B)
    words_per_row = K // 8

    # Bit offset of each of the eight values that share one int32 word.
    nibble_slots = (0, 4, 1, 5, 2, 6, 3, 7)
    shifts = torch.tensor(
        tuple(4 * slot for slot in nibble_slots),
        dtype=torch.int32,
        device=w.device,
    )
    out = torch.empty(
        num_experts, words_per_row, n_out, dtype=torch.int32, device=w.device
    )

    # One expert at a time keeps the unpacked temporaries small.
    for expert, expert_bytes in enumerate(w):  # expert_bytes: (N_out, K // 2)
        lo = (expert_bytes & 0xF).to(torch.uint8)
        hi = ((expert_bytes >> 4) & 0xF).to(torch.uint8)
        nibbles = torch.stack([lo, hi], dim=-1).reshape(n_out, K)
        # (N_out, tiles, 8, B) -> (N_out, tiles, B, 8): the last axis now holds
        # the eight values that end up in the same word.
        grouped = nibbles.reshape(n_out, tiles_per_row, 8, B).transpose(-1, -2)
        shifted = grouped.to(torch.int32) << shifts
        # The shifted fields occupy disjoint bits, so summing assembles the word.
        words = shifted.sum(dim=-1, dtype=torch.int32).reshape(n_out, words_per_row)
        out[expert].copy_(words.transpose(0, 1))
    return out  # (E, K // 8, N_out)
def _pack_scale_transpose(s: torch.Tensor) -> torch.Tensor:
    """Return a contiguous copy of a 3-D scale tensor with its last two axes swapped."""
    assert s.ndim == 3
    swapped = s.permute(0, 2, 1)
    return swapped.contiguous()
def _cached_pack_w(w: torch.Tensor, block_size_k: int, cached: bool) -> torch.Tensor:
    """Pack `w` with `_pack_w_interleave`, optionally memoizing the result.

    With ``cached=True`` the packed tensor is stored in `_W_PACK_CACHE` under
    the tensor object `w` and, beneath that, under `block_size_k`; later calls
    with the same pair reuse it. With ``cached=False`` the cache is neither
    read nor written.
    """
    if cached:
        variants = _W_PACK_CACHE.get(w)
        if variants is None:
            # First time this weight tensor is seen: open its per-tile-size table.
            variants = {}
            _W_PACK_CACHE[w] = variants
        result = variants.get(block_size_k)
        if result is None:
            result = _pack_w_interleave(w, block_size_k)
            variants[block_size_k] = result
        return result
    return _pack_w_interleave(w, block_size_k)
def _cached_pack_scale(s: torch.Tensor, cached: bool) -> torch.Tensor:
    """Transpose `s` with `_pack_scale_transpose`, optionally memoizing the result.

    With ``cached=True`` the result is stored in `_SCALE_PACK_CACHE` under the
    tensor object `s` and reused on later calls; with ``cached=False`` the
    cache is bypassed entirely.
    """
    if cached:
        hit = _SCALE_PACK_CACHE.get(s)
        if hit is not None:
            return hit
        fresh = _pack_scale_transpose(s)
        _SCALE_PACK_CACHE[s] = fresh
        return fresh
    return _pack_scale_transpose(s)
def w4a16_pack(
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    *,
    cached: bool = True,
    pack_strategy: str = "interleave",
    block_size_k: int = 16,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Pack both expert weight tensors (and their scales) for the W4A16 kernel.

    Args:
        w1, w2: weight tensors handed to `_cached_pack_w`.
        w1_scale, w2_scale: optional scale tensors handed to
            `_cached_pack_scale`; a ``None`` scale stays ``None``.
        cached: reuse / populate the module-level packing caches.
        pack_strategy: only ``"interleave"`` is implemented.
        block_size_k: logical K tile size used for the weight packing.

    Returns:
        ``(w1_packed, w2_packed, w1_scale_packed, w2_scale_packed)``.

    Raises:
        NotImplementedError: for any `pack_strategy` other than "interleave".
    """
    if pack_strategy != "interleave":
        raise NotImplementedError(
            f"pack_strategy={pack_strategy!r} not supported (only 'interleave')"
        )

    # Weights are packed first, then scales, each in (w1, w2) order.
    packed_weights = [
        _cached_pack_w(weight, block_size_k, cached=cached) for weight in (w1, w2)
    ]
    packed_scales = [
        None if scale is None else _cached_pack_scale(scale, cached=cached)
        for scale in (w1_scale, w2_scale)
    ]
    return (*packed_weights, *packed_scales)
@triton.jit
def _dequant_int4_fp16(b, scales):
    # Expand every 32-bit word of `b` into eight fp16 values, each equal to
    # (nibble - 8) * scale, using inline PTX.
    #
    # Operands (see `constraints`): $0..$7 are the eight 16-bit outputs,
    # $8 is the 32-bit packed word, $9 is the 16-bit scale.
    #
    # Method: OR-ing a masked nibble into 0x6400 yields the fp16 value
    # 1024 + nibble (nibble in bits 0-3) or 1024 + 16 * nibble (nibble in
    # bits 4-7). The first form is fixed up with a subtraction of 1032, the
    # second with an fma by 1/16 and -72; both give nibble - 8.
    #
    # Output order by source bit range:
    #   x1, x2 <- bits 0-3, 16-19      x3, x4 <- bits 4-7, 20-23
    #   x5, x6 <- bits 8-11, 24-27     x7, x8 <- bits 12-15, 28-31
    # which matches the (0, 4, 1, 5, 2, 6, 3, 7) nibble slots written by
    # `_pack_w_interleave`.
    x1, x2, x3, x4, x5, x6, x7, x8 = tl.inline_asm_elementwise(
        asm="""
        {
            .reg .b32 r0, r1, r2, r3, r4, r5, r6, r8, r9, r10, r11, r12;
            .reg .b16 h0, h1, h2, h3, h4, h5, h6, h7;
            .reg .b16 s;
            mov.u32 r0, $8;
            shr.u32 r1, r0, 8;
            lop3.b32 r2, r0, 983055, 1677747200, 234; // (r0 & 0x000F000F) | 0x64006400
            lop3.b32 r3, r0, 15728880, 1677747200, 234; // (r0 & 0x00F000F0) | 0x64006400
            lop3.b32 r4, r1, 983055, 1677747200, 234;
            lop3.b32 r5, r1, 15728880, 1677747200, 234;
            mov.u32 r6, 1678271496; // 0x64086408 = (1032,1032)
            mov.u32 r8, 738208768; // 0x2C002C00 = (1/16,1/16)
            mov.u32 r9, -729754496; // 0xD480D480 = (-72,-72)
            sub.f16x2 r10, r2, r6;
            sub.f16x2 r12, r4, r6;
            fma.rn.f16x2 r11, r3, r8, r9;
            fma.rn.f16x2 r4, r5, r8, r9;
            mov.b32 {h0, h1}, r10;
            mov.b32 {h2, h3}, r11;
            mov.b32 {h4, h5}, r12;
            mov.b32 {h6, h7}, r4;
            mov.b16 s, $9;
            mul.f16 h0, h0, s;
            mul.f16 h1, h1, s;
            mul.f16 h2, h2, s;
            mul.f16 h3, h3, s;
            mul.f16 h4, h4, s;
            mul.f16 h5, h5, s;
            mul.f16 h6, h6, s;
            mul.f16 h7, h7, s;
            mov.b16 $0, h0;
            mov.b16 $1, h1;
            mov.b16 $2, h2;
            mov.b16 $3, h3;
            mov.b16 $4, h4;
            mov.b16 $5, h5;
            mov.b16 $6, h6;
            mov.b16 $7, h7;
        }
        """,
        constraints="=h,=h,=h,=h,=h,=h,=h,=h,r,h",
        args=[b, scales],
        dtype=(tl.float16,) * 8,
        is_pure=True,
        pack=1,
    )
    return x1, x2, x3, x4, x5, x6, x7, x8
@triton.jit
def _dequant_int4_bf16(b, scales):
    # bf16 counterpart of `_dequant_int4_fp16`: expand every 32-bit word of `b`
    # into eight bf16 values, each equal to (nibble - 8) * scale.
    #
    # Operands (see `constraints`): $0..$7 are the eight 16-bit outputs,
    # $8 is the 32-bit packed word, $9 is the 16-bit scale.
    #
    # Method: the word is shifted right by 0 / 4 / 8 / 12 bits so that each
    # nibble pair lands in bits 0-3 and 16-19; OR-ing those into 0x4300 gives
    # the bf16 value 128 + nibble, and subtracting 136 leaves nibble - 8.
    #
    # Output order is the same as the fp16 variant:
    #   x1, x2 <- bits 0-3, 16-19      x3, x4 <- bits 4-7, 20-23
    #   x5, x6 <- bits 8-11, 24-27     x7, x8 <- bits 12-15, 28-31
    x1, x2, x3, x4, x5, x6, x7, x8 = tl.inline_asm_elementwise(
        asm="""
        {
            .reg .b32 r0, r1, r2, r3, q0, q1, q2, q3, s0, s1, s2, s3, magic;
            .reg .b16 h0, h1, h2, h3, h4, h5, h6, h7;
            .reg .b16 s;
            mov.u32 r0, $8;
            shr.u32 r1, r0, 4; // high nibble of bytes 0,2 -> bits 0-3
            shr.u32 r2, r0, 8; // low nibble of bytes 1,3 -> bits 0-3
            shr.u32 r3, r0, 12; // high nibble of bytes 1,3 -> bits 0-3
            // (x & 0x000F000F) | 0x43004300 -> bf16x2 of (128+nibble, 128+nibble)
            lop3.b32 q0, r0, 983055, 1124090624, 234;
            lop3.b32 q1, r1, 983055, 1124090624, 234;
            lop3.b32 q2, r2, 983055, 1124090624, 234;
            lop3.b32 q3, r3, 983055, 1124090624, 234;
            mov.u32 magic, 1124614920; // 0x43084308 = (136,136)
            sub.rn.bf16x2 s0, q0, magic;
            sub.rn.bf16x2 s1, q1, magic;
            sub.rn.bf16x2 s2, q2, magic;
            sub.rn.bf16x2 s3, q3, magic;
            mov.b32 {h0, h1}, s0; // (n0-8, n4-8)
            mov.b32 {h2, h3}, s1; // (n1-8, n5-8)
            mov.b32 {h4, h5}, s2; // (n2-8, n6-8)
            mov.b32 {h6, h7}, s3; // (n3-8, n7-8)
            mov.b16 s, $9;
            mul.rn.bf16 h0, h0, s;
            mul.rn.bf16 h1, h1, s;
            mul.rn.bf16 h2, h2, s;
            mul.rn.bf16 h3, h3, s;
            mul.rn.bf16 h4, h4, s;
            mul.rn.bf16 h5, h5, s;
            mul.rn.bf16 h6, h6, s;
            mul.rn.bf16 h7, h7, s;
            mov.b16 $0, h0;
            mov.b16 $1, h1;
            mov.b16 $2, h2;
            mov.b16 $3, h3;
            mov.b16 $4, h4;
            mov.b16 $5, h5;
            mov.b16 $6, h6;
            mov.b16 $7, h7;
        }
        """,
        constraints="=h,=h,=h,=h,=h,=h,=h,=h,r,h",
        args=[b, scales],
        dtype=(tl.bfloat16,) * 8,
        is_pure=True,
        pack=1,
    )
    return x1, x2, x3, x4, x5, x6, x7, x8
@triton.jit
def _stack_along_dim0(a, b, X: tl.constexpr, Y: tl.constexpr):
    # Block-concatenate two (X, Y) tiles into one (2X, Y) tile: rows of `a`
    # first, then rows of `b`.
    # join -> (X, Y, 2); permute -> (2, X, Y); reshape -> (2X, Y).
    pair_major = tl.permute(tl.join(a, b), (2, 0, 1))
    return tl.reshape(pair_major, (2 * X, Y))
@triton.jit
def _stack_8(bs, K_PACK: tl.constexpr, N: tl.constexpr):
    # Concatenate the eight (K_PACK, N) tiles bs[0..7] along dim 0, in index
    # order, by merging neighbours pairwise in three rounds.
    # Round 1: four tiles of (2 * K_PACK, N).
    pair_a = _stack_along_dim0(bs[0], bs[1], K_PACK, N)
    pair_b = _stack_along_dim0(bs[2], bs[3], K_PACK, N)
    pair_c = _stack_along_dim0(bs[4], bs[5], K_PACK, N)
    pair_d = _stack_along_dim0(bs[6], bs[7], K_PACK, N)
    # Round 2: two tiles of (4 * K_PACK, N).
    lower_half = _stack_along_dim0(pair_a, pair_b, 2 * K_PACK, N)
    upper_half = _stack_along_dim0(pair_c, pair_d, 2 * K_PACK, N)
    # Round 3: the final (8 * K_PACK, N) tile.
    return _stack_along_dim0(lower_half, upper_half, 4 * K_PACK, N)
# Fused-dequant W4A16 MoE GEMM. Each program computes one
# (BLOCK_SIZE_M tokens) x (BLOCK_SIZE_N output channels) tile of C for the
# expert that owns that token block. B holds nibble-interleaved int32 words
# that are dequantized to `compute_type` in registers just before tl.dot.
# BLOCK_SIZE_N and GROUP_SIZE_M are autotuned per (N, K); the remaining
# constexprs are supplied by the launcher.
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_N": 64, "GROUP_SIZE_M": 1}, num_warps=4, num_stages=4
        ),
        triton.Config(
            {"BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 1}, num_warps=4, num_stages=4
        ),
        triton.Config(
            {"BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 4}, num_warps=4, num_stages=4
        ),
        triton.Config(
            {"BLOCK_SIZE_N": 128, "GROUP_SIZE_M": 4}, num_warps=8, num_stages=3
        ),
        triton.Config(
            {"BLOCK_SIZE_N": 256, "GROUP_SIZE_M": 4}, num_warps=8, num_stages=3
        ),
        triton.Config(
            {"BLOCK_SIZE_N": 256, "GROUP_SIZE_M": 4}, num_warps=8, num_stages=2
        ),
    ],
    key=["N", "K"],
)
@triton.jit
def _w4a16_moe_gemm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    N: tl.constexpr,
    K: tl.constexpr,
    EM,
    num_valid_tokens,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bse,
    stride_bsg,
    stride_bsn,
    BLOCK_SIZE_M: tl.constexpr,  # token tile (MMA M-dim, or N-dim if SWAP_AB)
    BLOCK_SIZE_N: tl.constexpr,  # weight tile (MMA N-dim, or M-dim if SWAP_AB)
    BLOCK_SIZE_K: tl.constexpr,  # logical-K tile (must match packing)
    GROUP_SIZE_M: tl.constexpr,
    GROUP_SIZE_K: tl.constexpr,  # = quant group_size (e.g. 128)
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    SWAP_AB: tl.constexpr,
):
    # Eight logical K values share one packed int32 word.
    BLOCK_SIZE_K_PACK: tl.constexpr = BLOCK_SIZE_K // 8

    # Map the 1-D program id to a (pid_m, pid_n) tile; programs are grouped so
    # that GROUP_SIZE_M consecutive M tiles sweep the N tiles together.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Token blocks that start at or beyond the padded token count have no work.
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return

    # Token slots handled by this tile. Slots >= num_valid_tokens are masked
    # out of every A load and C store (presumably alignment padding — confirm
    # against moe_align_block_size).
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id).to(tl.int64)
    token_mask = offs_token < num_valid_tokens

    # One expert per token block. Id -1 gets an all-zero output tile and no
    # GEMM (NOTE(review): presumably "expert not available here" — confirm).
    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        if SWAP_AB:
            offs_cn0 = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            c_ptrs0 = (
                c_ptr + stride_cm * offs_token[None, :] + stride_cn * offs_cn0[:, None]
            )
            c_mask0 = token_mask[None, :] & (offs_cn0[:, None] < N)
            tl.store(
                c_ptrs0,
                tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_M), dtype=compute_type),
                mask=c_mask0,
            )
        else:
            write_zeros_to_output(
                c_ptr,
                stride_cm,
                stride_cn,
                pid_n,
                N,
                offs_token,
                token_mask,
                BLOCK_SIZE_M,
                BLOCK_SIZE_N,
                compute_type,
            )
        return

    # Output-channel offsets wrap modulo N so the unmasked B / scale loads stay
    # in range; the final store re-derives unwrapped offsets and masks them.
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_ak_pack = tl.arange(0, BLOCK_SIZE_K_PACK)
    offs_bk = tl.arange(0, BLOCK_SIZE_K_PACK)

    # `offs_token // top_k` turns a (token, expert-slot) index into a row of A.
    # With SWAP_AB the weight tile is the left dot operand, so every tile is
    # laid out transposed: B as (N, K_PACK) and the accumulator as (N, M).
    if SWAP_AB:
        a_base = a_ptr + (offs_token[None, :] // top_k * stride_am)
        b_ptrs = (
            b_ptr
            + off_experts * stride_be
            + offs_bn[:, None] * stride_bn
            + offs_bk[None, :] * stride_bk
        )
        accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_M), dtype=tl.float32)
    else:
        a_base = a_ptr + (offs_token[:, None] // top_k * stride_am)
        b_ptrs = (
            b_ptr
            + off_experts * stride_be
            + offs_bk[:, None] * stride_bk
            + offs_bn[None, :] * stride_bn
        )
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    scale_base = b_scale_ptr + off_experts * stride_bse + offs_bn * stride_bsn

    # K loop. Neither the B load nor the A loads are masked along K, so K is
    # expected to be a multiple of BLOCK_SIZE_K.
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        b_packed = tl.load(b_ptrs)
        # One scale per output channel, taken from the quant group that
        # contains the first K element of this tile.
        scale_idx = k * BLOCK_SIZE_K // GROUP_SIZE_K
        scale = tl.load(scale_base + scale_idx * stride_bsg)
        scale_bc = scale[:, None] if SWAP_AB else scale[None, :]

        # Each packed word expands to eight scaled values -> eight sub-tiles.
        if compute_type == tl.float16:
            bs = _dequant_int4_fp16(b_packed, scale_bc)
        else:
            bs = _dequant_int4_bf16(b_packed, scale_bc)

        # Sub-tile j covers logical K rows
        # [k_base + j * K_PACK, k_base + (j + 1) * K_PACK).
        k_logical_base = k * BLOCK_SIZE_K
        for j in tl.static_range(8):
            k_off = k_logical_base + j * BLOCK_SIZE_K_PACK
            if SWAP_AB:
                a_j_ptrs = a_base + (k_off + offs_ak_pack[:, None]) * stride_ak
                a_j = tl.load(
                    a_j_ptrs, mask=token_mask[None, :], other=0.0
                )  # (K_PACK, M)
                accumulator = tl.dot(bs[j], a_j, acc=accumulator)  # (N, M)
            else:
                a_j_ptrs = a_base + (k_off + offs_ak_pack[None, :]) * stride_ak
                a_j = tl.load(
                    a_j_ptrs, mask=token_mask[:, None], other=0.0
                )  # (M, K_PACK)
                accumulator = tl.dot(a_j, bs[j], acc=accumulator)  # (M, N)

        # Advance to the packed rows of the next K tile.
        b_ptrs += BLOCK_SIZE_K_PACK * stride_bk

    # Optionally scale each token's row by its routing weight.
    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0)
        accumulator = accumulator * (
            moe_weight[None, :] if SWAP_AB else moe_weight[:, None]
        )

    accumulator = accumulator.to(compute_type)

    # Store the tile, masking padded tokens and columns past N.
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    if SWAP_AB:
        c_ptrs = c_ptr + stride_cm * offs_token[None, :] + stride_cn * offs_cn[:, None]
        c_mask = token_mask[None, :] & (offs_cn[:, None] < N)
    else:
        c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
        c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
def _invoke_w4a16_moe_gemm(
    A: torch.Tensor,  # (M, K) for GEMM1, (M*top_k, K) for GEMM2
    B: torch.Tensor,  # (E, K//8, N) int32
    C: torch.Tensor,  # (M, top_k, N) or (M*top_k, N) view
    B_scale: torch.Tensor,  # (E, K/gs, N) fp16/bf16
    topk_weights: Optional[torch.Tensor],
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    *,
    mul_routed_weight: bool,
    top_k: int,
    block_m: int,
    block_size_k: int,
    group_size: int,
    compute_type,  # tl.float16 or tl.bfloat16
    swap_ab: bool = False,
):
    """Launch `_w4a16_moe_gemm_kernel` for one MoE GEMM.

    Sizes and strides are read from the tensors. The grid has one program per
    (token block, output-channel block) pair, with BLOCK_SIZE_N chosen by the
    kernel's autotuner. Results are written into `C`; nothing is returned.
    """
    rows_a, K = A.size(0), A.size(1)
    N = B.size(2)

    # Upper bound on padded token slots; tightened when A has fewer rows than
    # one token block.
    EM = sorted_token_ids.size(0)
    if rows_a < block_m:
        EM = min(EM, rows_a * top_k * block_m)

    # A 3-D C is indexed by flattened (token, slot), so its row stride is the
    # stride of dim 1; otherwise rows are dim 0.
    row_dim = 1 if C.ndim == 3 else 0
    stride_cm = C.stride(row_dim)
    stride_cn = C.stride(row_dim + 1)

    def grid(meta):
        m_blocks = triton.cdiv(EM, meta["BLOCK_SIZE_M"])
        n_blocks = triton.cdiv(N, meta["BLOCK_SIZE_N"])
        return (m_blocks * n_blocks,)

    _w4a16_moe_gemm_kernel[grid](
        A,
        B,
        C,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        N,
        K,
        EM,
        rows_a * top_k,  # num_valid_tokens
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(1),
        B.stride(2),
        stride_cm,
        stride_cn,
        B_scale.stride(0),
        B_scale.stride(1),
        B_scale.stride(2),
        BLOCK_SIZE_M=block_m,
        BLOCK_SIZE_K=block_size_k,
        GROUP_SIZE_K=group_size,
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        SWAP_AB=swap_ab,
    )
def fused_moe_w4a16_gptq(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    *,
    activation: str = "silu",
    group_size: int = 128,
    apply_router_weight_on_input: bool = False,
    inplace: bool = False,
    swap_ab: bool = True,
) -> torch.Tensor:
    """Run a full W4A16 MoE layer (GEMM1 -> SiLU-and-mul -> GEMM2 -> top-k sum).

    Args:
        hidden_states: contiguous (M, K) fp16/bf16 activations.
        w1: uint8 (E, 2 * I, K // 2) packed INT4 gate/up weights.
        w2: uint8 (E, K, I // 2) packed INT4 down weights.
        w1_scale: (E, 2 * I, K // group_size), same dtype as `hidden_states`.
        w2_scale: (E, K, I // group_size), same dtype as `hidden_states`.
        topk_weights, topk_ids: routing weights / expert ids, identical shapes.
        activation: must be "silu".
        group_size: quantization group size; also used as the kernel K tile.
        apply_router_weight_on_input: multiply by the routing weights in GEMM1
            instead of GEMM2.
        inplace: write the result into `hidden_states`.
        swap_ab: forwarded to the kernel; also moves the small-batch cutoff
            used to pick the token block size.

    Returns:
        (M, K) tensor of combined expert outputs (`hidden_states` itself when
        `inplace` is set).
    """
    assert activation == "silu"
    assert hidden_states.dtype in (torch.float16, torch.bfloat16)
    assert hidden_states.is_contiguous()
    assert w1.dtype == torch.uint8 and w2.dtype == torch.uint8
    assert w1.stride(-1) == 1 and w2.stride(-1) == 1

    M = hidden_states.size(0)
    K = hidden_states.size(1)
    E = w1.size(0)
    intermediate_size = w1.size(1) // 2
    top_k_num = topk_ids.size(1)

    assert w1.shape == (E, 2 * intermediate_size, K // 2)
    assert w2.shape == (E, K, intermediate_size // 2)
    assert K % group_size == 0
    assert intermediate_size % group_size == 0
    assert w1_scale.shape == (E, 2 * intermediate_size, K // group_size)
    assert w2_scale.shape == (E, K, intermediate_size // group_size)
    assert w1_scale.dtype == hidden_states.dtype
    assert w2_scale.dtype == hidden_states.dtype
    assert topk_weights.shape == topk_ids.shape

    # One K tile per quantization group, so each tile needs exactly one scale.
    block_size_k = group_size
    compute_type = tl.float16 if hidden_states.dtype == torch.float16 else tl.bfloat16

    w1_packed, w2_packed, w1_scale_packed, w2_scale_packed = w4a16_pack(
        w1,
        w2,
        w1_scale,
        w2_scale,
        block_size_k=block_size_k,
        cached=True,
    )

    device, dtype = hidden_states.device, hidden_states.dtype
    routed_rows = M * top_k_num
    gate_up_width = 2 * intermediate_size

    # GEMM1's output and GEMM2's output are never live together, so both are
    # views over one flat scratch buffer.
    scratch = torch.empty(
        routed_rows * max(gate_up_width, K), device=device, dtype=dtype
    )
    intermediate_cache1 = scratch[: routed_rows * gate_up_width].view(
        routed_rows, gate_up_width
    )
    intermediate_cache3 = scratch[: routed_rows * K].view(M, top_k_num, K)
    intermediate_cache2 = torch.empty(
        (routed_rows, intermediate_size), device=device, dtype=dtype
    )

    # Choose the token block size from the average number of tokens per expert.
    avg_tokens = max(M * top_k_num // max(E, 1), 1)
    small_cutoff = 8 if swap_ab else 16
    if avg_tokens <= small_cutoff:
        block_m = 16
    elif avg_tokens <= 64:
        block_m = 32
    else:
        block_m = 64

    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
        topk_ids=topk_ids,
        block_size=block_m,
        num_experts=E,
        expert_map=None,
    )

    # Arguments common to both GEMM launches.
    shared_kwargs = dict(
        sorted_token_ids=sorted_token_ids,
        expert_ids=expert_ids,
        num_tokens_post_padded=num_tokens_post_padded,
        block_m=block_m,
        block_size_k=block_size_k,
        group_size=group_size,
        compute_type=compute_type,
        swap_ab=swap_ab,
    )
    weight_in_gemm1 = apply_router_weight_on_input
    weight_in_gemm2 = not apply_router_weight_on_input

    # GEMM 1: hidden_states @ w1 -> (M * top_k, 2 * I)
    _invoke_w4a16_moe_gemm(
        A=hidden_states,
        B=w1_packed,
        C=intermediate_cache1,
        B_scale=w1_scale_packed,
        topk_weights=topk_weights if weight_in_gemm1 else None,
        mul_routed_weight=weight_in_gemm1,
        top_k=top_k_num,
        **shared_kwargs,
    )

    # Activation: the first I columns are the gate, the last I columns the up.
    gate = intermediate_cache1[:, :intermediate_size]
    up = intermediate_cache1[:, intermediate_size:]
    silu_and_mul_out(gate, up, intermediate_cache2)

    # GEMM 2: activation @ w2 -> (M, top_k, K); A already has one row per slot.
    _invoke_w4a16_moe_gemm(
        A=intermediate_cache2,
        B=w2_packed,
        C=intermediate_cache3,
        B_scale=w2_scale_packed,
        topk_weights=topk_weights if weight_in_gemm2 else None,
        mul_routed_weight=weight_in_gemm2,
        top_k=1,
        **shared_kwargs,
    )

    # Reduce over the top-k axis into the output buffer.
    out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)
    moe_sum(intermediate_cache3, out_hidden_states)

    return out_hidden_states
654# ----------------------------------------------------------------------------
655# Phase-2 impl: copy of fused_experts_impl but with the dequant shortcut
656# removed so the wna16 Triton kernel is actually invoked for W4A16/W8A16.
657# ----------------------------------------------------------------------------
def _fused_marlin_moe_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: torch.Tensor | None = None,
    w2_zp: torch.Tensor | None = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Generic W4A16 / W8A16 MoE forward pass on the wna16 Triton kernel.

    Follows the orchestration of `fused_experts_impl` (chunk loop, token
    alignment, GEMM1, activation, GEMM2, top-k reduction) with these changes:
      - only the weight-only-quantized paths remain (no fp8, int8_w8a8, mxfp);
      - weights are NOT dequantized up front; the quantized weights and their
        scales are handed to `dispatch_fused_moe_kernel` as-is;
      - `block_shape` is forwarded so the kernel sees the quant group size.

    Returns `hidden_states` itself when `inplace` is set, else a new tensor of
    the same shape.
    """
    assert (
        activation == "silu"
    ), f"Only 'silu' activation is supported, got {activation}"
    assert (
        use_int4_w4a16 or use_int8_w8a16
    ), "_fused_marlin_moe_impl expects a quantized path"

    activation_enum = MoEActivation.from_str(activation)

    # Packed-aware shape check: INT4 stores two values per byte along K
    # (w1.size(2) == K // 2), INT8 stores one (w1.size(2) == K).
    hidden_size = hidden_states.size(1)
    if use_int4_w4a16:
        expected_packed_k = hidden_size // 2
    else:
        expected_packed_k = hidden_size
    assert w1.size(2) == expected_packed_k, (
        f"w1 packed K mismatch: hidden_size={hidden_size}, "
        f"use_int4_w4a16={use_int4_w4a16}, expected w1.size(2)={expected_packed_k}, "
        f"got {w1.size(2)}"
    )

    assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

    num_tokens = hidden_states.size(0)
    E, N, _ = w1.size()
    K = w2.size(1)
    if global_num_experts == -1:
        global_num_experts = E
    top_k_num = topk_ids.size(1)

    # Tokens are processed in chunks of at most CHUNK_SIZE rows.
    CHUNK_SIZE: int = 16 * 1024
    M = min(num_tokens, CHUNK_SIZE)

    config_dtype = _get_config_dtype_str(
        use_fp8_w8a8=False,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        ocp_mx_scheme=None,
        dtype=hidden_states.dtype,
    )
    quant_dtype = _get_config_quant_dtype(
        use_fp8_w8a8=False,
        use_int8_w8a8=False,
        ocp_mx_scheme=None,
    )

    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.size(),
        w2.size(),
        top_k_num,
        config_dtype,
        block_shape=block_shape,
        E=E,
    )
    config = get_config_func(M)
    config["SPLIT_K"] = 1

    device, dtype = hidden_states.device, hidden_states.dtype

    # GEMM1's output and GEMM2's output have non-overlapping lifetimes, so
    # both are views over one scratch buffer.
    cache13 = torch.empty(M * top_k_num * max(N, K), device=device, dtype=dtype)
    intermediate_cache1 = cache13[: M * top_k_num * N].view(M, top_k_num, N)
    intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)

    activation_out_dim = MoEActivation.adjust_N_for_activation(N, activation_enum)
    intermediate_cache2 = torch.empty(
        (M * top_k_num, activation_out_dim), device=device, dtype=dtype
    )

    compute_type = {
        torch.bfloat16: tl.bfloat16,
        torch.float16: tl.float16,
        torch.float32: tl.float32,
    }.get(hidden_states.dtype)
    if compute_type is None:
        raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

    out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)

    # Activations are not quantized on these paths; the call is kept so the
    # flow matches `fused_experts_impl`.
    quantize_input = functools.partial(
        moe_kernel_quantize_input,
        A_scale=None,
        quant_dtype=quant_dtype,
        per_act_token_quant=per_channel_quant,
        block_shape=block_shape,
        ocp_mx_scheme=None,
    )
    # Keyword arguments common to both GEMM dispatches.
    gemm_kwargs = dict(
        compute_type=compute_type,
        use_fp8_w8a8=False,
        use_int8_w8a8=False,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        per_channel_quant=per_channel_quant,
        block_shape=block_shape,
    )

    # NOTE: unlike `fused_experts_impl`, no up-front weight dequantization
    # happens before this loop.
    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
        begin_chunk_idx = chunk * CHUNK_SIZE
        end_chunk_idx = min(begin_chunk_idx + CHUNK_SIZE, num_tokens)
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.size()

        if tokens_in_chunk == 0:
            break

        if chunk > 0 and tokens_in_chunk < CHUNK_SIZE:
            # Short trailing chunk: shrink the scratch views and re-query the
            # kernel config for the smaller M.
            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
            intermediate_cache2 = intermediate_cache2[: tokens_in_chunk * top_k_num]
            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
            config = get_config_func(tokens_in_chunk)
            config["SPLIT_K"] = 1

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

        qcurr_hidden_states, a1q_scale = quantize_input(A=curr_hidden_states)

        # Always take the routed path (token slots sorted / padded per expert).
        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            curr_topk_ids,
            config["BLOCK_SIZE_M"],
            global_num_experts,
            expert_map,
        )

        # ----- GEMM 1: hidden @ w1 (weights dequantized inside the kernel) -----
        dispatch_fused_moe_kernel(
            qcurr_hidden_states,
            w1,
            intermediate_cache1,
            a1q_scale,
            w1_scale,
            w1_zp,
            curr_topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            apply_router_weight_on_input,
            top_k_num,
            config,
            B_bias=w1_bias,
            **gemm_kwargs,
        )

        # ----- Activation: SwiGLU = silu(gate) * up -----
        apply_moe_activation(
            activation_enum, intermediate_cache2, intermediate_cache1.view(-1, N)
        )

        qintermediate_cache2, a2q_scale = quantize_input(A=intermediate_cache2)

        if expert_map is not None:
            intermediate_cache3.zero_()

        # ----- GEMM 2: act @ w2 (weights dequantized inside the kernel) -----
        dispatch_fused_moe_kernel(
            qintermediate_cache2,
            w2,
            intermediate_cache3,
            a2q_scale,
            w2_scale,
            w2_zp,
            curr_topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            not apply_router_weight_on_input,
            1,
            config,
            B_bias=w2_bias,
            **gemm_kwargs,
        )

        # ----- Reduce: sum the top-k expert outputs of every token -----
        moe_sum(
            intermediate_cache3.view(*intermediate_cache3.size()),
            out_hidden_states[begin_chunk_idx:end_chunk_idx],
        )

    return out_hidden_states
897# ----------------------------------------------------------------------------
898# Public entry point: vLLM-aligned wrapper.
899# ----------------------------------------------------------------------------
def _resolve_activation_name(activation: Any) -> str:
    """Return the lower-cased activation name carried by `activation`.

    Accepts ``None`` (meaning "silu"), a plain string, or an object whose
    ``value`` or ``name`` attribute is a string (e.g. an enum member).

    Raises:
        NotImplementedError: if no name can be extracted. Such objects used to
            be silently treated as "silu".
    """
    if activation is None:
        return "silu"
    if isinstance(activation, str):
        return activation.lower()
    for attr in ("value", "name"):
        candidate = getattr(activation, attr, None)
        if isinstance(candidate, str):
            return candidate.lower()
    raise NotImplementedError(
        f"Cannot determine activation from {activation!r}; "
        "MVP only supports SiLU/SwiGLU activation"
    )


def fused_marlin_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    bias1: Optional[torch.Tensor],
    bias2: Optional[torch.Tensor],
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    quant_type_id: int,
    apply_router_weight_on_input: bool = False,
    global_num_experts: int = -1,
    activation: Any = None,
    activation_func: Optional[Callable] = None,
    moe_sum: Optional[Callable] = None,
    expert_map: Optional[torch.Tensor] = None,
    input_global_scale1: Optional[torch.Tensor] = None,
    input_global_scale2: Optional[torch.Tensor] = None,
    global_scale1: Optional[torch.Tensor] = None,
    global_scale2: Optional[torch.Tensor] = None,
    g_idx1: Optional[torch.Tensor] = None,
    g_idx2: Optional[torch.Tensor] = None,
    sort_indices1: Optional[torch.Tensor] = None,
    sort_indices2: Optional[torch.Tensor] = None,
    w1_zeros: Optional[torch.Tensor] = None,
    w2_zeros: Optional[torch.Tensor] = None,
    workspace: Optional[torch.Tensor] = None,
    intermediate_cache13: Optional[torch.Tensor] = None,
    intermediate_cache2: Optional[torch.Tensor] = None,
    is_k_full: bool = True,
    output: Optional[torch.Tensor] = None,
    input_dtype: Optional[torch.dtype] = None,
    inplace: bool = False,
    clamp_limit: Optional[float] = None,
    group_size: int = 128,
) -> torch.Tensor:
    """Phase-2 entry point: dispatch to the local wna16-using implementations.

    Validates the MVP restrictions, then runs one of two paths:
      * `fused_moe_w4a16_gptq` — the packed W4A16 kernel, used only when every
        fast-path condition below holds;
      * `_fused_marlin_moe_impl` — the generic wna16 path, for everything else.

    Args:
        quant_type_id: one of the ids in `_SUPPORTED_QUANT_TYPES`.
        activation: ``None``, a string, or an object exposing a string
            ``value`` / ``name``; must resolve to "silu".
        output: optional destination tensor; the result is copied into it and
            it is returned. Mutually exclusive with ``inplace=True``.
        group_size: quantization group size forwarded to both paths.
        activation_func, moe_sum, workspace, intermediate_cache13,
        intermediate_cache2, is_k_full: accepted for signature compatibility
            and not used by this implementation.
        Remaining arguments are forwarded to the selected implementation or
        must be left at ``None`` (see Raises).

    Returns:
        `output` when it was supplied, otherwise the tensor produced by the
        selected implementation.

    Raises:
        NotImplementedError: unsupported `quant_type_id`, act_order inputs
            (`g_idx*`, `sort_indices*`), `input_dtype`, `clamp_limit`,
            `input_global_scale*`, `global_scale*`, or an activation that is
            not (or cannot be identified as) SiLU.
        ValueError: both ``inplace=True`` and `output` were given.
    """
    # ---- MVP guardrails --------------------------------------------------
    if quant_type_id not in _SUPPORTED_QUANT_TYPES:
        raise NotImplementedError(
            f"MVP supports quant_type_id in {_SUPPORTED_QUANT_TYPES}, "
            f"got {quant_type_id}"
        )
    if g_idx1 is not None or g_idx2 is not None:
        raise NotImplementedError("act_order (g_idx) not yet supported in MVP")
    if sort_indices1 is not None or sort_indices2 is not None:
        raise NotImplementedError("act_order (sort_indices) not yet supported in MVP")
    if input_dtype is not None:
        raise NotImplementedError("FP8 / INT8 input quantization not supported")
    if clamp_limit is not None:
        raise NotImplementedError("clamp_limit (GLM-4 swiglu) not supported")
    if input_global_scale1 is not None or input_global_scale2 is not None:
        raise NotImplementedError("input_global_scale not supported in MVP")
    if global_scale1 is not None or global_scale2 is not None:
        raise NotImplementedError("global_scale not supported in MVP")

    use_int4_w4a16 = quant_type_id in _QUANT_TYPE_INT4
    use_int8_w8a16 = quant_type_id in _QUANT_TYPE_INT8

    # Unidentifiable activation objects are rejected here rather than being
    # assumed to be SiLU.
    activation_str = _resolve_activation_name(activation)
    if activation_str != "silu":
        raise NotImplementedError(
            f"MVP only supports SiLU/SwiGLU activation, got {activation_str}"
        )

    if inplace and output is not None:
        raise ValueError("Cannot pass both inplace=True and output")

    use_w4a16_fast_path = (
        # The magic-trick kernel's bf16 dequant uses sub.bf16x2/mul.bf16 PTX,
        # which require sm_90+; on pre-Hopper fall back to the generic wna16 kernel.
        _is_hopper()
        and use_int4_w4a16
        and hidden_states.dtype in (torch.float16, torch.bfloat16)
        and w1.dtype == torch.uint8
        and w2.dtype == torch.uint8
        and bias1 is None
        and bias2 is None
        and w1_zeros is None
        and w2_zeros is None
        and expert_map is None
        and (global_num_experts == -1 or global_num_experts == w1.size(0))
        and group_size >= 128
        and w1_scale.dtype == hidden_states.dtype
        and w2_scale.dtype == hidden_states.dtype
    )

    if use_w4a16_fast_path:
        result = fused_moe_w4a16_gptq(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation_str,
            group_size=group_size,
            apply_router_weight_on_input=apply_router_weight_on_input,
            inplace=inplace,
        )
    else:
        result = _fused_marlin_moe_impl(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=inplace,
            activation=activation_str,
            apply_router_weight_on_input=apply_router_weight_on_input,
            use_int4_w4a16=use_int4_w4a16,
            use_int8_w8a16=use_int8_w8a16,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            w1_zp=w1_zeros,
            w2_zp=w2_zeros,
            w1_bias=bias1,
            w2_bias=bias2,
            # Critical for Phase 2: block_shape=[0, group_size] makes the
            # wna16 Triton kernel use the per-group scales correctly.
            block_shape=[0, group_size],
        )

    if output is None:
        return result
    output.copy_(result)
    return output
# Public API: the vLLM-aligned entry point plus the quant_type_id constants
# it accepts.
__all__ = ["fused_marlin_moe", "QUANT_TYPE_UINT4B8", "QUANT_TYPE_UINT8B128"]