Coverage for src/flag_gems/fused_moe_mxq.py: 0%
310 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1# SPDX-License-Identifier: Apache-2.0
2# QC-MoE: Quantized Mixture of Experts kernel for FlagGems
3# Main module integrating MoE kernels with quantization support
5from dataclasses import dataclass
6from enum import Enum
7from typing import Any, List, Optional, Tuple
9import torch
10import triton
11import triton.language as tl
13# Device detection
_is_cuda = torch.cuda.is_available()


def is_sm90_supported():
    """Return True when CUDA is usable and the current device reports major capability >= 9."""
    if not _is_cuda:
        # Without CUDA there is no device capability to query.
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 9  # H100, H200, etc.
28# ============================================================================
29# QuantMode and QuantConfig
30# ============================================================================
class QuantMode(Enum):
    """Precision schemes that QC-MoE can be configured with."""

    FP16 = "fp16"
    FP8 = "fp8"
    INT8 = "int8"
    # Weight-only schemes: integer weights combined with FP16 activations.
    W8A16 = "w8a16"
    W4A16 = "w4a16"
@dataclass
class QuantConfig:
    """Settings that describe how MoE expert weights are quantized."""

    mode: QuantMode = QuantMode.FP16
    group_size: int = 128
    has_zero_point: bool = True
    per_channel_quant: bool = False

    @property
    def w_nbits(self) -> int:
        """Number of bits per weight implied by ``mode`` (4, 8 or 16)."""
        current = self.mode
        if current == QuantMode.W4A16:
            return 4
        eight_bit_modes = (QuantMode.W8A16, QuantMode.INT8, QuantMode.FP8)
        return 8 if current in eight_bit_modes else 16

    @property
    def use_int4(self) -> bool:
        """True only for the W4A16 mode."""
        return self.mode == QuantMode.W4A16

    @property
    def use_int8(self) -> bool:
        """True for the integer 8-bit modes (W8A16 and INT8); FP8 is excluded."""
        return self.mode == QuantMode.W8A16 or self.mode == QuantMode.INT8
70# ============================================================================
71# Triton Kernels
72# ============================================================================
@triton.jit
def fused_moe_kernel_gptq_awq(
    # Pointers to matrices
    A,
    B,
    C,
    B_scale,
    B_zp,
    topk_weights,
    sorted_token_ids,
    expert_ids,
    num_tokens_post_padded,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    # Strides
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bse,
    stride_bsk,
    stride_bsn,
    stride_bze,
    stride_bzk,
    stride_bzn,
    group_size: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    has_zp: tl.constexpr,
    use_int4_w4a16: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
    even_Ks: tl.constexpr,
    filter_expert: tl.constexpr,
):
    """
    Simplified MoE kernel: one program per dispatch entry.

    Program ``pid`` reads its token row, expert index and routing weight from
    ``sorted_token_ids[pid]``, ``expert_ids[pid]`` and ``topk_weights[pid]``,
    computes ``out[n] = sum_k A[token, k] * dequant(B[expert, n, k])`` for every
    ``n < N`` and atomically adds the (optionally weight-scaled) result into
    ``C[token, :]``.

    Dequantization is ``(q - zp) * scale`` with one scale / zero point per
    ``group_size`` consecutive k values. When ``has_zp`` is false a fixed
    mid-point offset is used (8 for INT4, 128 for INT8). INT4 weights are read
    as two values per byte along k: even k in the low nibble, odd k in the
    high nibble.

    ``num_tokens_post_padded``, ``EM``, ``BLOCK_SIZE_M``, ``GROUP_SIZE_M``,
    ``top_k``, ``even_Ks`` and ``filter_expert`` are accepted for interface
    compatibility and are not read.
    """
    pid = tl.program_id(0)

    # Programs beyond the dispatch list have nothing to do.
    if pid >= num_valid_tokens:
        return

    # Load dispatch information
    token_id = tl.load(sorted_token_ids + pid).to(tl.int64)
    expert_id = tl.load(expert_ids + pid).to(tl.int64)
    weight = tl.load(topk_weights + pid).to(compute_type)

    offs_k = tl.arange(0, BLOCK_SIZE_K)

    # Walk over all of N in BLOCK_SIZE_N tiles; the grid is 1-D over dispatch
    # entries, so a single tile would leave columns >= BLOCK_SIZE_N unwritten.
    for n_block in range(tl.cdiv(N, BLOCK_SIZE_N)):
        offs_n = n_block * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        n_mask = offs_n < N

        accumulator = tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32)

        # Reduce over K in BLOCK_SIZE_K chunks
        for k_block in range(tl.cdiv(K, BLOCK_SIZE_K)):
            k_indices = k_block * BLOCK_SIZE_K + offs_k
            k_mask = k_indices < K
            tile_mask = k_mask[:, None] & n_mask[None, :]

            # Activation slice A[token_id, k_indices]
            a = tl.load(
                A + (token_id * stride_am + k_indices * stride_ak),
                mask=k_mask,
                other=0.0,
            ).to(tl.float32)

            # Packed INT4 stores two k values per byte, so the byte index is k // 2.
            if use_int4_w4a16:
                w_k = k_indices // 2
            else:
                w_k = k_indices

            # Weight tile, laid out as [BLOCK_SIZE_K, BLOCK_SIZE_N]
            w = tl.load(
                B
                + (
                    expert_id * stride_be
                    + offs_n[None, :] * stride_bn
                    + w_k[:, None] * stride_bk
                ),
                mask=tile_mask,
                other=0.0,
            )

            if use_int4_w4a16:
                # Select the nibble that belongs to this k (even -> low, odd -> high).
                shift = ((k_indices % 2) * 4)[:, None]
                w = (w.to(tl.int32) >> shift) & 0xF

            # One scale per (expert, n, k // group_size)
            scale_group = k_indices // group_size
            scales = tl.load(
                B_scale
                + (
                    expert_id * stride_bse
                    + offs_n[None, :] * stride_bsn
                    + scale_group[:, None] * stride_bsk
                ),
                mask=tile_mask,
                other=1.0,
            ).to(tl.float32)

            # Dequantize based on quantization mode
            if use_int4_w4a16:
                if has_zp:
                    zp = tl.load(
                        B_zp
                        + (
                            expert_id * stride_bze
                            + offs_n[None, :] * stride_bzn
                            + scale_group[:, None] * stride_bzk
                        ),
                        mask=tile_mask,
                        other=0.0,
                    ).to(tl.float32)
                    w_dequant = (w.to(tl.float32) - zp) * scales
                else:
                    w_dequant = (w.to(tl.float32) - 8.0) * scales
            elif use_int8_w8a16:
                if has_zp:
                    zp = tl.load(
                        B_zp
                        + (
                            expert_id * stride_bze
                            + offs_n[None, :] * stride_bzn
                            + scale_group[:, None] * stride_bzk
                        ),
                        mask=tile_mask,
                        other=0.0,
                    ).to(tl.float32)
                    w_dequant = (w.to(tl.float32) - zp) * scales
                else:
                    w_dequant = (w.to(tl.float32) - 128.0) * scales
            else:
                # No integer decoding: weights are only multiplied by the scale.
                w_dequant = w.to(tl.float32) * scales

            # Vector-matrix product: broadcast a as [BLOCK_SIZE_K, 1] over the
            # [BLOCK_SIZE_K, BLOCK_SIZE_N] tile and reduce along k.
            accumulator = accumulator + tl.sum(a[:, None] * w_dequant, axis=0)

        # Apply routing weight
        if MUL_ROUTED_WEIGHT:
            accumulator = accumulator * weight

        # Several dispatch entries may target the same token row, hence atomic add.
        output_ptrs = C + (token_id * stride_cm + offs_n * stride_cn)
        tl.atomic_add(output_ptrs, accumulator.to(compute_type), mask=n_mask)
@triton.jit
def fused_moe_kernel_fp16_swiglu(
    A,
    C,
    B_gate,
    B_up,
    B_down,
    topk_weights,
    sorted_token_ids,
    expert_ids,
    num_tokens_post_padded,
    inter_ptr,
    N,
    K,
    EM,
    num_valid_tokens,
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    stride_cm,
    stride_cn,
    stride_gate_e,
    stride_up_e,
    stride_down_e,
    stride_gate_n,
    stride_gate_k,
    stride_up_n,
    stride_up_k,
    stride_down_k,
    stride_down_n,
    stride_inter_m,
    BLOCK_SIZE_K: tl.constexpr,
    top_k: tl.constexpr,
    even_Ks: tl.constexpr,
):
    """
    Unquantized SwiGLU MoE: gate(W1) / up(W3) / down(W2) for one dispatch entry.

    FFN(x) = W2 @ (silu(W1 @ x) * (W3 @ x)), with silu(g) = g * sigmoid(g).
    Each program handles one (token, expert) pair: it writes the N activated
    intermediate values to its own row of ``inter_ptr`` (row ``pid``, row
    stride ``stride_inter_m``), then accumulates the routing-weighted down
    projection into ``C[token, :]``.

    The result is added with ``tl.atomic_add`` because several dispatch entries
    can target the same token row. ``num_tokens_post_padded``, ``EM``,
    ``stride_bn``, ``stride_bk``, ``top_k`` and ``even_Ks`` are accepted for
    interface compatibility and are not read.
    """
    pid = tl.program_id(0)
    if pid >= num_valid_tokens:
        return

    token_id = tl.load(sorted_token_ids + pid).to(tl.int64)
    expert_id = tl.load(expert_ids + pid).to(tl.int64)
    weight = tl.load(topk_weights + pid).to(tl.float32)

    # Start of this program's scratch row (64-bit to avoid overflow on large grids).
    inter_off = pid.to(tl.int64) * stride_inter_m

    # ---------- Stage 1: inter[n] = silu(W1[exp, n, :] . x) * (W3[exp, n, :] . x) ----------
    # Gate and up projections share one pass over K so the activation slice is
    # loaded once and the gate value stays in float32.
    for n in range(N):
        gate_acc = 0.0
        up_acc = 0.0
        for kb in range(tl.cdiv(K, BLOCK_SIZE_K)):
            k_offs = kb * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
            k_mask = k_offs < K
            a_vals = tl.load(
                A + token_id * stride_am + k_offs * stride_ak,
                mask=k_mask,
                other=0.0,
            ).to(tl.float32)
            w_gate = tl.load(
                B_gate
                + expert_id * stride_gate_e
                + n * stride_gate_n
                + k_offs * stride_gate_k,
                mask=k_mask,
                other=0.0,
            ).to(tl.float32)
            w_up = tl.load(
                B_up + expert_id * stride_up_e + n * stride_up_n + k_offs * stride_up_k,
                mask=k_mask,
                other=0.0,
            ).to(tl.float32)
            gate_acc = gate_acc + tl.sum(a_vals * w_gate)
            up_acc = up_acc + tl.sum(a_vals * w_up)
        # SiLU(gate) * up; silu(g) = g * sigmoid(g)
        act_val = gate_acc * tl.sigmoid(gate_acc) * up_acc
        tl.store(inter_ptr + inter_off + n, act_val)

    # NOTE(review): barrier so the scalar stores above are complete before the
    # vector loads below read them back within this program - confirm this is
    # sufficient for global-memory visibility on the target backend.
    tl.debug_barrier()

    # ---------- Stage 2: out[k] = weight * sum_n( inter[n] * W2[exp, k, n] ) ----------
    for k in range(K):
        acc = 0.0
        for nb in range(tl.cdiv(N, 32)):
            n_offs = nb * 32 + tl.arange(0, 32)
            n_mask = n_offs < N
            inter_vals = tl.load(
                inter_ptr + inter_off + n_offs, mask=n_mask, other=0.0
            ).to(tl.float32)
            w_down = tl.load(
                B_down
                + expert_id * stride_down_e
                + k * stride_down_k
                + n_offs * stride_down_n,
                mask=n_mask,
                other=0.0,
            ).to(tl.float32)
            acc = acc + tl.sum(inter_vals * w_down)
        result = (acc * weight).to(C.dtype.element_ty)
        # Atomic accumulate: a plain load/add/store would race when the same
        # token is routed to more than one expert.
        tl.atomic_add(C + token_id * stride_cm + k * stride_cn, result)
373# ============================================================================
374# Helper Functions
375# ============================================================================
def get_num_experts(shape_desc: str) -> int:
    """Map a model/shape description string to an expert count.

    The description is checked for the markers "Qwen", "Mixtral" and "Switch"
    in that order and the first marker found decides the result (8, 8 and 64
    respectively). Descriptions containing none of them yield the default 8.
    """
    marker_table = (
        ("Qwen", 8),  # every Qwen variant resolves to 8, including unlisted sizes
        ("Mixtral", 8),
        ("Switch", 64),
    )
    for marker, expert_count in marker_table:
        if marker in shape_desc:
            return expert_count
    return 8  # default
def prepare_moe_inputs(
    x: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build the dispatch ordering used by the fused MoE kernel.

    Args:
        x: Input tensor; only ``x.shape[0]`` (the token count) is read.
        topk_weights: Routing weights, flattened with ``view(-1)``.
        topk_ids: Expert indices; ``topk_ids.shape[1]`` is taken as top-k.
        num_experts: Total number of experts (currently not read).

    Returns:
        sorted_token_ids: Positions into the flattened routing tensors, ordered
            by descending routing weight. These are flat (token, slot)
            positions, not token row numbers.
        expert_ids: Expert index of each entry, in the same order.
        num_tokens_post_padded: ``num_tokens * topk`` rounded up to a multiple
            of the block size (a Python int).
        block_size_m: The block size used for padding (always 32).
    """
    block_size_m = 32  # Default block size
    entries_per_token = topk_ids.shape[1]
    total_entries = x.shape[0] * entries_per_token

    weights_1d = topk_weights.view(-1)
    experts_1d = topk_ids.view(-1)

    # Dispatch order: heaviest routing weight first.
    sorted_token_ids = torch.argsort(weights_1d, dim=0, descending=True)
    expert_ids = experts_1d[sorted_token_ids]

    # Round the entry count up to the next multiple of the block size.
    num_tokens_post_padded = total_entries + (-total_entries) % block_size_m

    return sorted_token_ids, expert_ids, num_tokens_post_padded, block_size_m
def quantize_weights_moe(
    weights: torch.Tensor,
    num_experts: int,
    quant_config: "QuantConfig",
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    Quantize MoE expert weights group-wise along the input dimension.

    The output follows the dequantization used by ``fused_moe_kernel_gptq_awq``:
    ``w ~= (q - zero) * scale``. With ``has_zero_point`` the zero point is
    ``-w_min / scale`` (kept as a float so the affine mapping is exact); without
    it, codes are symmetric around the fixed mid-point the kernel assumes
    (8 for 4-bit, 128 for 8-bit) and no zero-point tensor is returned.

    Args:
        weights: Expert weights of shape (num_experts, out_features, in_features).
        num_experts: Kept for API compatibility; the expert count is taken from
            ``weights.shape[0]``.
        quant_config: Quantization configuration (``mode``, ``group_size``,
            ``has_zero_point``, ``use_int4`` are read).

    Returns:
        W_q: uint8 codes of shape (E, out_features, in_features), or
            (E, out_features, in_features // 2) for 4-bit, where element 2j is
            stored in the low nibble and element 2j + 1 in the high nibble.
        scales: Per-group scales of shape (E, out_features, num_groups).
        zeros: Per-group zero points with the same shape as ``scales``, or
            ``None`` when ``has_zero_point`` is false.
        In FP16 mode the input is returned unchanged as ``(weights, None, None)``.

    Raises:
        ValueError: If ``weights`` is not 3-D, ``group_size`` is not positive,
            ``in_features`` is not a multiple of ``group_size``, or 4-bit
            packing is requested with an odd ``group_size``.
    """
    if quant_config.mode.value == "fp16":
        return weights, None, None

    if weights.dim() != 3:
        raise ValueError(
            f"weights must be 3-D (experts, out, in), got shape {tuple(weights.shape)}"
        )

    n_experts, n_out, k_in = weights.shape
    group_size = quant_config.group_size
    if group_size <= 0 or k_in % group_size != 0:
        raise ValueError(
            f"in_features ({k_in}) must be a positive multiple of group_size ({group_size})"
        )
    if quant_config.use_int4 and group_size % 2 != 0:
        raise ValueError("4-bit packing requires an even group_size")

    num_groups = k_in // group_size
    w_bits = 4 if quant_config.use_int4 else 8
    q_max = 2**w_bits - 1

    # (E, n_out, k_in) -> (E, n_out, num_groups, group_size); statistics are
    # computed in float32 so half-precision inputs do not lose resolution.
    grouped = weights.reshape(n_experts, n_out, num_groups, group_size).float()

    if quant_config.has_zero_point:
        w_min = grouped.amin(dim=-1, keepdim=True)
        w_max = grouped.amax(dim=-1, keepdim=True)
        scale = (w_max - w_min) / q_max
        # Constant groups would give scale 0; use 1 so the division is defined.
        scale = torch.where(scale > 0, scale, torch.ones_like(scale))
        codes = torch.round((grouped - w_min) / scale)
        # (q - zero) * scale == q * scale + w_min
        zero = -w_min / scale
    else:
        mid_point = 2 ** (w_bits - 1)
        max_abs = grouped.abs().amax(dim=-1, keepdim=True)
        scale = max_abs / (mid_point - 1)
        scale = torch.where(scale > 0, scale, torch.ones_like(scale))
        codes = torch.round(grouped / scale) + mid_point
        zero = None

    W_q = codes.clamp(0, q_max).to(torch.uint8).reshape(n_experts, n_out, k_in)

    if quant_config.use_int4:
        # Pack two 4-bit codes per byte: even index -> low nibble, odd -> high.
        W_q = (W_q[..., 0::2] & 0xF) | (W_q[..., 1::2] << 4)

    scales = scale.squeeze(-1).to(weights.dtype)
    zeros = None if zero is None else zero.squeeze(-1).to(weights.dtype)

    return W_q, scales, zeros
def get_default_config(block_size_m=1, block_size_n=128, block_size_k=64):
    """Return the kernel tile sizes as a dict keyed by the Triton meta-parameter names."""
    return dict(
        BLOCK_SIZE_M=block_size_m,
        BLOCK_SIZE_N=block_size_n,
        BLOCK_SIZE_K=block_size_k,
    )
def get_autotune_config():
    """Return the candidate Triton configs for the MoE kernel.

    Two candidates are produced, in this order: (BLOCK_SIZE_N, BLOCK_SIZE_K) of
    (128, 64) and (64, 64), each with ``num_stages=2`` and ``num_warps=4``.
    """
    tile_shapes = ((128, 64), (64, 64))
    return [
        triton.Config(
            {"BLOCK_SIZE_N": tile_n, "BLOCK_SIZE_K": tile_k},
            num_stages=2,
            num_warps=4,
        )
        for tile_n, tile_k in tile_shapes
    ]
526# ============================================================================
527# Kernel Invocation
528# ============================================================================
# NOTE(review): module-level placeholder; nothing in the visible code reads or
# assigns it. Presumably reserved for caching the FP16 intermediate buffer -
# confirm before removing.
_fp16_intermediate_buf = None
def _fp16_w1_projection(
    x: torch.Tensor,
    W1: torch.Tensor,
    output: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    topk_weights: torch.Tensor,
) -> None:
    """Accumulate ``weight * (x[token] @ W1[expert].T)`` into ``output`` with plain torch ops."""
    num_entries = sorted_token_ids.shape[0]
    for expert in range(W1.shape[0]):
        entry_idx = (expert_ids == expert).nonzero(as_tuple=True)[0]
        # Ignore any entries beyond the dispatch list (e.g. padding).
        entry_idx = entry_idx[entry_idx < num_entries]
        if entry_idx.numel() == 0:
            continue

        token_rows = sorted_token_ids[entry_idx]
        # [num_selected, k_in] @ [k_in, n_out] -> [num_selected, n_out]
        contrib = torch.matmul(x[token_rows], W1[expert].t())
        contrib = contrib * topk_weights[entry_idx].unsqueeze(1)
        # index_add_ requires matching dtypes; routing weights may be wider than x.
        output.index_add_(0, token_rows, contrib.to(output.dtype))


def invoke_fused_moe(
    x: torch.Tensor,
    W1_q: torch.Tensor,
    W2_q: torch.Tensor,
    W3_q: Optional[torch.Tensor],
    output: torch.Tensor,
    W1_scales: torch.Tensor,
    W1_zeros: Optional[torch.Tensor],
    W2_scales: torch.Tensor,
    W2_zeros: Optional[torch.Tensor],
    W3_scales: Optional[torch.Tensor],
    W3_zeros: Optional[torch.Tensor],
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    topk_weights: torch.Tensor,
    top_k: int,
    quant_config: Any,
    block_shape: List[int],
) -> None:
    """
    Run the fused MoE computation and accumulate the result into ``output``.

    ``output`` is zeroed first. Three paths exist:

    * FP16 mode with ``W2_q`` given: full SwiGLU via
      ``fused_moe_kernel_fp16_swiglu`` (``W1_q`` is used as the up projection
      when ``W3_q`` is None). ``output`` needs ``hidden_dim`` columns.
    * FP16 mode with ``W2_q`` None: W1-only projection computed with torch
      matmul per expert. ``output`` needs ``inter_dim`` columns.
    * Any other mode: W1-only projection via ``fused_moe_kernel_gptq_awq``.
      ``W2_q`` / ``W3_q`` and their scales are not used on this path, so the
      result is the routed W1 projection only. ``output`` needs ``inter_dim``
      columns.

    Args:
        x: Activations of shape (num_tokens, hidden_dim).
        W1_q: Gate weights of shape (num_experts, inter_dim, k).
        sorted_token_ids / expert_ids / topk_weights: One entry per dispatch
            (token row, expert index, routing weight), all of equal length.
        quant_config: Object exposing ``mode.value`` and, for quantized modes,
            ``group_size``, ``has_zero_point``, ``use_int4`` and ``use_int8``.
        num_tokens_post_padded, top_k, block_shape: Forwarded to the kernels or
            unused; kept for interface compatibility.

    Raises:
        ValueError: If ``output`` has too few columns for the selected path, or
            a quantized mode is requested without ``W1_scales``.
    """
    num_tokens, hidden_dim = x.shape
    num_experts, inter_dim, _ = W1_q.shape
    num_valid_tokens = sorted_token_ids.shape[0]

    K = hidden_dim
    N = inter_dim

    is_fp16 = quant_config.mode.value == "fp16"
    is_swiglu = is_fp16 and W2_q is not None

    topk_weights = topk_weights.reshape(-1)
    x = x.contiguous()

    # The kernels index output columns without knowing its width, so check here
    # instead of writing past the end of a row.
    required_cols = K if is_swiglu else N
    if output.shape[-1] < required_cols:
        raise ValueError(
            f"output has {output.shape[-1]} columns but this path writes {required_cols}"
        )

    output.zero_()
    if num_valid_tokens == 0:
        return

    # FP16 W1-only: torch reference path, no kernel launch.
    if is_fp16 and W2_q is None:
        _fp16_w1_projection(
            x, W1_q, output, sorted_token_ids, expert_ids, topk_weights
        )
        return

    # tl.arange requires power-of-two extents; masks handle the remainder.
    BLOCK_SIZE_N = min(128, triton.next_power_of_2(N))
    BLOCK_SIZE_K = min(64, triton.next_power_of_2(K))
    grid = (num_valid_tokens,)

    # FP16 SwiGLU: gate(W1) * up(W3), then down(W2).
    if is_swiglu:
        # float32 scratch so intermediate activations are not rounded to x.dtype.
        inter_buf = torch.empty(
            num_valid_tokens * N, dtype=torch.float32, device=x.device
        )
        _W3 = W3_q if W3_q is not None else W1_q  # use W1 if W3 missing

        fused_moe_kernel_fp16_swiglu[grid](
            x,
            output,
            W1_q,  # gate
            _W3,  # up
            W2_q,  # down
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            inter_buf,
            N=N,
            K=K,
            EM=num_valid_tokens,
            num_valid_tokens=num_valid_tokens,
            stride_am=x.stride(0),
            stride_ak=x.stride(1),
            stride_bn=W1_q.stride(1),
            stride_bk=W1_q.stride(2),
            stride_cm=output.stride(0),
            stride_cn=output.stride(1),
            stride_gate_e=W1_q.stride(0),
            stride_up_e=_W3.stride(0),
            stride_down_e=W2_q.stride(0),
            stride_gate_n=W1_q.stride(1),
            stride_gate_k=W1_q.stride(2),
            stride_up_n=_W3.stride(1),
            stride_up_k=_W3.stride(2),
            stride_down_k=W2_q.stride(1),
            stride_down_n=W2_q.stride(2),
            stride_inter_m=N,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
            top_k=top_k,
            even_Ks=(K % BLOCK_SIZE_K) == 0,
        )
        return

    # Quantized path (W8A16 / W4A16 / ...): W1 projection only.
    if W1_scales is None:
        raise ValueError("quantized modes require W1_scales")

    # Only ask the kernel to read zero points when a tensor is actually there.
    has_zp = bool(quant_config.has_zero_point) and W1_zeros is not None
    if has_zp:
        zp_tensor = W1_zeros
        zp_strides = (W1_zeros.stride(0), W1_zeros.stride(1), W1_zeros.stride(2))
    else:
        # Placeholder pointer; the kernel never dereferences it when has_zp is false.
        zp_tensor = W1_scales
        zp_strides = (0, 0, 0)

    fused_moe_kernel_gptq_awq[grid](
        x,
        W1_q,
        output,
        W1_scales,
        zp_tensor,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        N=N,
        K=K,
        EM=num_valid_tokens,
        num_valid_tokens=num_valid_tokens,
        stride_am=x.stride(0),
        stride_ak=x.stride(1),
        stride_be=W1_q.stride(0),
        stride_bk=W1_q.stride(2),
        stride_bn=W1_q.stride(1),
        stride_cm=output.stride(0),
        stride_cn=output.stride(1),
        stride_bse=W1_scales.stride(0),
        stride_bsk=W1_scales.stride(2),
        stride_bsn=W1_scales.stride(1),
        stride_bze=zp_strides[0],
        stride_bzk=zp_strides[2],
        stride_bzn=zp_strides[1],
        group_size=quant_config.group_size,
        BLOCK_SIZE_M=1,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        GROUP_SIZE_M=1,
        MUL_ROUTED_WEIGHT=True,
        top_k=top_k,
        compute_type=tl.float16,
        has_zp=has_zp,
        use_int4_w4a16=quant_config.use_int4,
        use_int8_w8a16=quant_config.use_int8,
        even_Ks=(K % BLOCK_SIZE_K) == 0,
        filter_expert=False,
    )
764# ============================================================================
765# Main fused_moe Function
766# ============================================================================
def _resolve_expert_weights(raw, pre_q, pre_scales, pre_zeros, quant_config):
    """Return (weights, scales, zeros), preferring pre-quantized tensors over ``raw``."""
    if pre_q is not None and pre_scales is not None:
        zeros = pre_zeros.contiguous() if pre_zeros is not None else None
        return pre_q.contiguous(), pre_scales.contiguous(), zeros
    if raw is not None:
        # Use the tensor's own expert count so a stale num_experts cannot break the reshape.
        return quantize_weights_moe(raw, raw.shape[0], quant_config)
    return None, None, None


def fused_moe(
    x: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w3: Optional[torch.Tensor] = None,
    topk_weights: Optional[torch.Tensor] = None,
    topk_ids: Optional[torch.Tensor] = None,
    quant_config: QuantConfig = None,
    num_experts: int = 8,
    top_k: int = 2,
    block_shape: Optional[List[int]] = None,
    # Optional pre-quantized weights (from benchmark)
    w1_q: Optional[torch.Tensor] = None,
    w1_scales: Optional[torch.Tensor] = None,
    w1_zeros: Optional[torch.Tensor] = None,
    w2_q: Optional[torch.Tensor] = None,
    w2_scales: Optional[torch.Tensor] = None,
    w2_zeros: Optional[torch.Tensor] = None,
    w3_q: Optional[torch.Tensor] = None,
    w3_scales: Optional[torch.Tensor] = None,
    w3_zeros: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Fused Mixture of Experts computation with quantization support.

    Computes ``y[t] = sum_i topk_weights[t, i] * f(expert topk_ids[t, i], x[t])``.
    In FP16 mode with W2 available, ``f`` is the SwiGLU FFN
    ``W2 @ (silu(W1 @ x) * (W3 @ x))``. When W2 is absent, and in every
    quantized mode, ``invoke_fused_moe`` computes the W1 projection only.

    Args:
        x: Input of shape (batch, seq, hidden_dim) or (num_tokens, hidden_dim).
        w1: Gate weights, (num_experts, inter_dim, hidden_dim); may be None when
            ``w1_q`` and ``w1_scales`` are given.
        w2: Down weights, (num_experts, hidden_dim, inter_dim); None selects the
            W1-only projection.
        w3: Optional up weights, (num_experts, inter_dim, hidden_dim).
        topk_weights: Routing weights with ``num_tokens * top_k`` elements.
        topk_ids: Expert indices with the same number of elements. If either
            routing tensor is None, uniform weights and random expert ids are
            generated (testing aid).
        quant_config: Quantization configuration; defaults to ``QuantConfig()``.
        num_experts: Upper bound for randomly generated expert ids.
        top_k: Experts per token; overridden by ``topk_ids.shape[-1]`` when
            routing tensors are supplied.
        block_shape: Block shape forwarded to ``invoke_fused_moe``
            (default ``[128, 128]``).
        w1_q, w1_scales, w1_zeros: Pre-quantized W1 (skips quantization when
            both ``w1_q`` and ``w1_scales`` are given).
        w2_q, w2_scales, w2_zeros: Pre-quantized W2.
        w3_q, w3_scales, w3_zeros: Pre-quantized W3.

    Returns:
        Tensor with the leading dimensions of ``x``. The last dimension is
        ``hidden_dim`` when W2 is available and ``inter_dim`` for the W1-only
        projection.

    Raises:
        ValueError: If neither ``w1`` nor pre-quantized W1 is provided, or the
            routing tensors do not hold ``num_tokens * top_k`` elements.
    """
    if quant_config is None:
        quant_config = QuantConfig()

    # Collapse (B, S, H) to (B*S, H); reshape also accepts non-contiguous input.
    original_shape = x.shape
    if len(original_shape) == 3:
        x = x.reshape(-1, original_shape[-1])

    num_tokens = x.shape[0]

    # Resolve weights first so routing can be validated against the real expert count.
    W1_q, W1_scales, W1_zeros = _resolve_expert_weights(
        w1, w1_q, w1_scales, w1_zeros, quant_config
    )
    if W1_q is None:
        raise ValueError("Either w1 or w1_q must be provided")
    W2_q, W2_scales, W2_zeros = _resolve_expert_weights(
        w2, w2_q, w2_scales, w2_zeros, quant_config
    )
    W3_q, W3_scales, W3_zeros = _resolve_expert_weights(
        w3, w3_q, w3_scales, w3_zeros, quant_config
    )

    # Prepare routing information
    if topk_weights is None or topk_ids is None:
        # Dummy routing for testing; never draw an id beyond the weight tensor.
        id_bound = min(num_experts, W1_q.shape[0])
        topk_weights = (
            torch.ones(num_tokens, top_k, device=x.device, dtype=x.dtype) / top_k
        )
        topk_ids = torch.randint(0, id_bound, (num_tokens, top_k), device=x.device)
    else:
        # The routing tensors define how many experts each token uses.
        top_k = topk_ids.shape[-1]

    flat_expert_ids = topk_ids.reshape(-1)
    flat_weights = topk_weights.reshape(-1)
    num_entries = num_tokens * top_k
    if flat_expert_ids.numel() != num_entries or flat_weights.numel() != num_entries:
        raise ValueError(
            "topk_weights and topk_ids must each hold num_tokens * top_k elements"
        )

    # One dispatch entry per (token, expert) pair: token rows [0,0,...,1,1,...].
    token_rows = torch.arange(
        num_tokens, device=x.device, dtype=torch.int64
    ).repeat_interleave(top_k)

    # Order entries by descending routing weight (ordering only; results are accumulated).
    order = torch.argsort(flat_weights, dim=0, descending=True)
    sorted_token_ids = token_rows[order]
    sorted_expert_ids = flat_expert_ids[order]
    sorted_weights = flat_weights[order]

    # Entry count rounded up to the block size
    block_size_m = 32
    num_tokens_post_padded = (
        (num_entries + block_size_m - 1) // block_size_m
    ) * block_size_m

    # W1-only projection produces inter_dim columns; otherwise output mirrors x.
    if W2_q is None:
        output = torch.zeros(
            num_tokens, W1_q.shape[1], dtype=x.dtype, device=x.device
        )
    else:
        output = torch.zeros_like(x)

    # Default block shape
    if block_shape is None:
        block_shape = [128, 128]

    invoke_fused_moe(
        x,
        W1_q,
        W2_q,
        W3_q,
        output,
        W1_scales,
        W1_zeros,
        W2_scales,
        W2_zeros,
        W3_scales,
        W3_zeros,
        sorted_token_ids,
        sorted_expert_ids,
        num_tokens_post_padded,
        sorted_weights,
        top_k,
        quant_config,
        block_shape,
    )

    # Restore the leading (batch, seq) dims; the feature dim follows the output.
    if len(original_shape) == 3:
        output = output.view(original_shape[0], original_shape[1], output.shape[-1])

    return output
948# ============================================================================
949# FusedMoELinear Module
950# ============================================================================
class FusedMoELinear(torch.nn.Module):
    """
    Fused MoE layer with quantization support.

    Holds the three SwiGLU expert weight tensors (``w1`` gate, ``w3`` up,
    ``w2`` down) as non-trainable parameters, quantizes them once via
    ``pack()`` and reuses the packed tensors on every ``forward`` call.
    """

    def __init__(
        self,
        hidden_dim: int,
        inter_dim: int,
        num_experts: int = 8,
        top_k: int = 2,
        quant_config: Optional["QuantConfig"] = None,
        bias: bool = False,
    ):
        """
        Args:
            hidden_dim: Model (input/output) feature size.
            inter_dim: Expert intermediate feature size.
            num_experts: Number of experts.
            top_k: Experts selected per token.
            quant_config: Quantization settings; defaults to ``QuantConfig()``.
            bias: Accepted for interface compatibility; no bias is created.
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        self.inter_dim = inter_dim
        self.num_experts = num_experts
        self.top_k = top_k
        self.quant_config = quant_config or QuantConfig()

        # SwiGLU MoE weights. requires_grad must be set on the Parameter itself:
        # nn.Parameter defaults to requires_grad=True regardless of the tensor.
        self.w1 = torch.nn.Parameter(
            torch.randn(num_experts, inter_dim, hidden_dim), requires_grad=False
        )
        self.w3 = torch.nn.Parameter(
            torch.randn(num_experts, inter_dim, hidden_dim), requires_grad=False
        )
        self.w2 = torch.nn.Parameter(
            torch.randn(num_experts, hidden_dim, inter_dim), requires_grad=False
        )

        self._packed = False

    def pack(self):
        """Quantize w1 / w3 / w2 with the layer's config and cache the results."""
        self.W1_q, self.W1_scales, self.W1_zeros = quantize_weights_moe(
            self.w1.data, self.num_experts, self.quant_config
        )
        self.W3_q, self.W3_scales, self.W3_zeros = quantize_weights_moe(
            self.w3.data, self.num_experts, self.quant_config
        )
        self.W2_q, self.W2_scales, self.W2_zeros = quantize_weights_moe(
            self.w2.data, self.num_experts, self.quant_config
        )
        self._packed = True

    def forward(
        self,
        x: torch.Tensor,
        topk_weights: Optional[torch.Tensor] = None,
        topk_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass for MoE.

        Args:
            x: Input tensor (B, S, H) or (T, H)
            topk_weights: Expert weights (B, S, K) or (T, K)
            topk_ids: Expert indices (B, S, K) or (T, K)

        Returns:
            The tensor returned by ``fused_moe`` for these inputs.
        """
        # Re-pack when never packed, or when the parameters were moved to another
        # device after packing (the packed tensors are plain attributes).
        if not self._packed or self.W1_q.device != self.w1.device:
            self.pack()

        # Hand over the cached quantized tensors so fused_moe does not quantize
        # again on every call. In FP16 mode the scales are None and fused_moe
        # falls back to the raw weights.
        return fused_moe(
            x,
            self.w1,
            self.w2,
            self.w3,
            topk_weights,
            topk_ids,
            self.quant_config,
            self.num_experts,
            self.top_k,
            w1_q=self.W1_q,
            w1_scales=self.W1_scales,
            w1_zeros=self.W1_zeros,
            w2_q=self.W2_q,
            w2_scales=self.W2_scales,
            w2_zeros=self.W2_zeros,
            w3_q=self.W3_q,
            w3_scales=self.W3_scales,
            w3_zeros=self.W3_zeros,
        )

    def set_weights(self, w1: torch.Tensor, w3: torch.Tensor, w2: torch.Tensor):
        """Set weights from external source (e.g., model loading) and invalidate the packed cache."""
        self.w1.data = w1
        self.w3.data = w3
        self.w2.data = w2
        self._packed = False
1043# ============================================================================
1044# Exports
1045# ============================================================================
# Public API of this module: the names exported by `from ... import *`.
__all__ = [
    "fused_moe",
    "fused_moe_kernel_gptq_awq",
    "fused_moe_kernel_fp16_swiglu",
    "invoke_fused_moe",
    "FusedMoELinear",
    "QuantConfig",
    "QuantMode",
    "quantize_weights_moe",
    "prepare_moe_inputs",
    "get_num_experts",
    "get_default_config",
    "get_autotune_config",
]