Coverage for src/flag_gems/fused/fused_moe.py: 39%
921 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1# SPDX-License-Identifier: Apache-2.0
2# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3#
4# Adapted from the vLLM project (https://github.com/vllm-project/vllm).
5# Source files under vllm/model_executor/layers/:
6# fused_moe/fused_moe.py – Triton kernels, dispatch, fused_experts_impl
7# fused_moe/activation.py – MoEActivation enum, apply_moe_activation
8# fused_moe/utils.py – _fp8_quantize, _int8_quantize, moe_kernel_quantize_input
9# fused_moe/config.py – _get_config_dtype_str
10# quantization/utils/mxfp4_utils.py – dequant_mxfp4
11# quantization/utils/mxfp6_utils.py – dequant_mxfp6
12# quantization/utils/ocp_mx_utils.py – OCP_MX_BLOCK_SIZE
15import functools
16import logging
17import os
18from enum import Enum
19from typing import Any, Optional
21import torch
22import torch.nn.functional as F
23import triton
24import triton.language as tl
25import yaml
27from flag_gems.fused.moe_align_block_size import moe_align_block_size
28from flag_gems.fused.moe_sum import moe_sum
29from flag_gems.runtime import device, torch_device_fn
30from flag_gems.utils import pointwise_dynamic
logger = logging.getLogger(__name__)

# OCP MX quantization helpers (requires amd-quark)

# Group size handed to the per-group MXFP6 dequantizer in dequant_mxfp6.
OCP_MX_BLOCK_SIZE = 32
# H100/Qwen-style MoE tuning thresholds. GEMM tile changes become reliably
# positive from 4096 tokens; direct_sum is kept separate because it is a
# reduction-layout decision even though it currently shares the same cutoff.
MOE_GEMM_TUNING_MIN_TOKENS = 4096
MOE_DIRECT_SUM_MIN_TOKENS = 4096
# BLOCK_SIZE_M / BLOCK_SIZE_K values that get_default_config must have
# selected before it considers the half-precision GEMM fast path.
_HALF_GEMM_TILE_M = 128
_HALF_GEMM_TILE_K = 64
# BLOCK_SIZE_N used by the gemm2 fast path; N must be a multiple of it.
_HALF_GEMM2_TILE_N = 256
# Config dtype strings routed to the plain half-precision heuristic in
# get_default_config.
_PLAIN_HALF_CONFIG_DTYPES = ("fp16", "bf16")
@functools.lru_cache(maxsize=1)
def get_embedded_moe_configs():
    """Load the embedded per-device MoE tuning table.

    The table is read from ``../utils/configs/fused_moe_config.yaml`` relative
    to this module. The result is cached, so the file is parsed at most once
    per process.

    Returns:
        A ``(configs, fallback)`` tuple.

        * ``configs`` maps ``device name -> config key -> {M: config}`` where
          ``M`` is converted to ``int``. A config stored as a list is expanded
          into a dict keyed by BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K,
          GROUP_SIZE_M, num_warps and num_stages (in that order); any other
          value is kept as-is.
        * ``fallback`` is the content of the top-level ``_FALLBACK`` entry.

        Both are empty dicts when the file is missing or empty.
    """
    config_path = os.path.join(
        os.path.dirname(__file__), "..", "utils", "configs", "fused_moe_config.yaml"
    )
    if not os.path.exists(config_path):
        return {}, {}
    with open(config_path, "r", encoding="utf-8") as f:
        # An empty YAML document loads as None; treat it like a missing file
        # instead of failing on ``None.get`` below.
        data = yaml.safe_load(f) or {}

    # A present-but-null ``_FALLBACK`` entry is normalised to an empty mapping
    # so callers can always call ``.get`` on it.
    fallback = data.get("_FALLBACK") or {}

    # Positional layout of a config stored as a compact list.
    keys_order = (
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "num_warps",
        "num_stages",
    )
    parsed_data = {}
    for dev, configs in data.items():
        if dev == "_FALLBACK":
            continue
        parsed_data[dev] = {
            key: {
                # The innermost keys are token counts (M); normalise to int.
                int(m): dict(zip(keys_order, v)) if isinstance(v, list) else v
                for m, v in m_dict.items()
            }
            for key, m_dict in configs.items()
        }

    return parsed_data, fallback
def dequant_mxfp4(
    x: torch.Tensor,
    scale: torch.Tensor,
    float_dtype: torch.dtype,
) -> torch.Tensor:
    """Dequantize an MXFP4 tensor by delegating to ``quark.torch.kernel.mx``.

    Raises:
        ImportError: if the optional ``amd-quark`` package cannot be imported.
    """
    try:
        from quark.torch.kernel import mx as quark_mx
    except ImportError as import_failure:
        raise ImportError("amd-quark is required for MX-FP4") from import_failure
    else:
        dequantized = quark_mx.dq_mxfp4(x, scale, float_dtype)
    return dequantized
def dequant_mxfp6(
    x: torch.Tensor,
    scale: torch.Tensor,
    float_dtype: torch.dtype,
    quant_dtype: str,
) -> torch.Tensor:
    """Dequantize an MXFP6 tensor using quark's hardware-emulation helpers.

    Raises:
        ImportError: if the optional ``amd-quark`` package cannot be imported.
    """
    try:
        from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
            dequantize_fp4_fp6_per_group,
        )
        from quark.torch.utils.pack import create_pack_method
    except ImportError as import_failure:
        raise ImportError("amd-quark is required for MX-FP6") from import_failure

    unpacker = create_pack_method(None, dtype=quant_dtype)
    values = unpacker.unpack(x, reorder=False)

    # Reinterpret the stored scale bytes as unsigned integers, remove the
    # bias of 127 and raise 2 to that exponent (exponent cast to float_dtype
    # first, exactly as the power is evaluated).
    exponent = (scale.view(torch.uint8).to(torch.int16) - 127).to(float_dtype)
    decoded_scale = 2**exponent

    dequantized = dequantize_fp4_fp6_per_group(
        values,
        decoded_scale,
        axis=-1,
        group_size=OCP_MX_BLOCK_SIZE,
        quant_dtype=quant_dtype,
    )
    return dequantized.to(float_dtype)
131# Activation quantization helpers
@functools.lru_cache(maxsize=1)
def _get_device_name() -> str:
    """Return the device name used as the key into the embedded config table.

    The runtime device name has spaces replaced by underscores. Any name with
    an ``H200`` component collapses to ``NVIDIA_H200``. When the resulting
    name has no entry in the embedded table but the ``_FALLBACK`` mapping
    points it at a device that does, the fallback device name is returned
    instead. The answer is cached for the lifetime of the process.
    """
    try:
        reported = torch_device_fn.get_device_name()
        name = reported.replace(" ", "_")
    except AttributeError:
        # Backend without get_device_name(): use the runtime's device name.
        name = device.name

    # Normalise the H200 product family to a single key, following vLLM.
    if "H200" in name.split("_"):
        name = "NVIDIA_H200"

    embedded_configs, fallback_mapping = get_embedded_moe_configs()
    if name not in embedded_configs:
        # Devices with equivalent tuning profiles share another device's table.
        alias = fallback_mapping.get(name)
        if alias and alias in embedded_configs:
            logger.info(
                "Device %s not in config table, falling back to %s", name, alias
            )
            return alias
    return name
def get_moe_configs(
    E: int,
    N: int,
    dtype: str | None,
    block_n: int | None = None,
    block_k: int | None = None,
) -> dict[int, Any] | None:
    """
    Look up pre-tuned fused-MoE kernel configs for the current device.

    The embedded table is indexed first by device name and then by the key
    ``"E,N,dtype,block_n,block_k"`` (missing block sizes are written as 0).
    Returns the ``{M: config}`` mapping on a hit, or None (after logging a
    warning) when either the device or the key is absent.
    """
    device_name = _get_device_name()
    device_table = get_embedded_moe_configs()[0].get(device_name)
    if device_table is None:
        logger.warning(
            "No embedded MoE configs for device %s. Will use default config.",
            device_name,
        )
        return None

    key = f"{E},{N},{dtype},{block_n or 0},{block_k or 0}"
    configs = device_table.get(key)
    if configs is None:
        logger.warning(
            "No embedded MoE config for device=%s, key=%s. Will use default config.",
            device_name,
            key,
        )
        return None

    logger.info("Using embedded MoE config for device=%s, key=%s", device_name, key)
    return configs
198def try_get_optimal_moe_config(
199 w1_shape: tuple[int, ...],
200 w2_shape: tuple[int, ...],
201 top_k: int,
202 dtype: str | None,
203 M: int,
204 E: int,
205 block_shape: list[int] | None = None,
206 gemm_stage: str = "gemm1",
207 enable_gemm_fast_path: bool = False,
208 return_is_embedded: bool = False,
209) -> dict[str, Any] | tuple[dict[str, Any], bool]:
210 if gemm_stage not in ("gemm1", "gemm2"):
211 raise ValueError(f"Unsupported MoE GEMM stage: {gemm_stage}")
212 _, _, config_n = w2_shape
213 if dtype == "int4_w4a16":
214 config_n = config_n * 2
215 block_n = block_shape[0] if block_shape else 0
216 block_k = block_shape[1] if block_shape else 0
217 configs = get_moe_configs(E, config_n, dtype, block_n, block_k)
218 if configs:
219 config = configs[min(configs.keys(), key=lambda x: abs(x - M))].copy()
220 is_embedded = True
221 else:
222 if gemm_stage == "gemm1":
223 _, N, K = w1_shape
224 else:
225 _, N, K = w2_shape
226 config = get_default_config(
227 M,
228 E,
229 N,
230 K,
231 top_k,
232 dtype,
233 block_shape,
234 gemm_stage=gemm_stage,
235 enable_gemm_fast_path=enable_gemm_fast_path,
236 )
237 is_embedded = False
238 if return_is_embedded:
239 return config, is_embedded
240 return config
243def _get_config_quant_dtype(
244 use_fp8_w8a8: bool,
245 use_int8_w8a8: bool,
246 ocp_mx_scheme: str | None,
247) -> None | torch.dtype | str:
248 """Map quantization flags to the corresponding dtype."""
249 if use_fp8_w8a8:
250 return torch.float8_e4m3fn
251 elif use_int8_w8a8:
252 return torch.int8
253 elif ocp_mx_scheme == "w_mxfp4_a_mxfp4":
254 return "mxfp4"
255 elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e3m2", "w_mxfp6_e3m2_a_mxfp6_e3m2"}:
256 return "mxfp6_e3m2"
257 elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e2m3", "w_mxfp6_e2m3_a_mxfp6_e2m3"}:
258 return "mxfp6_e2m3"
259 elif ocp_mx_scheme in {"w_mxfp4", "w_mxfp6_e3m2", "w_mxfp6_e2m3"}:
260 return torch.bfloat16
261 elif ocp_mx_scheme in {"w_mxfp4_a_fp8", "w_mxfp6_e3m2_a_fp8", "w_mxfp6_e2m3_a_fp8"}:
262 return torch.float8_e4m3fn
264 return None
def get_moe_wna16_block_config(
    config: dict[str, int],
    use_moe_wna16_cuda: bool,
    num_valid_tokens: int,
    size_k: int,
    size_n: int,
    num_experts: int,
    group_size: int,
    real_top_k: int,
    block_size_m: int,
):
    """Choose BLOCK_SIZE_N / BLOCK_SIZE_K for the WNA16 MoE kernels.

    Returns an empty dict when ``config`` already carries both block sizes,
    otherwise a dict with exactly ``BLOCK_SIZE_N`` and ``BLOCK_SIZE_K``.

    For the non-CUDA path the choice depends only on whether
    ``num_valid_tokens // real_top_k`` equals 1. For the CUDA path the sizes
    start at 128 and are grown based on an estimate of the number of blocks,
    then BLOCK_SIZE_K is adjusted by ``_ensure_block_size_k_divisible``.
    """
    # Both block sizes already chosen by the caller: nothing to add.
    if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config:
        return {}
    if not use_moe_wna16_cuda:
        # Single-token case gets the narrower N tile and deeper K tile.
        if num_valid_tokens // real_top_k == 1:
            return {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64}
        else:
            return {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}
    else:
        block_size_n = 128
        block_size_k = 128
        # Never use a K tile smaller than one quantization group.
        if block_size_k <= group_size:
            block_size_k = group_size

        # NOTE(review): num_n_blocks is derived from size_k and num_k_blocks
        # from size_n, and both divide by block_size_k. Only their product
        # feeds num_blocks; confirm the naming/operands are intended.
        num_n_blocks = size_k // block_size_k
        num_k_blocks = size_n // block_size_k
        # True division: num_m_blocks (and therefore num_blocks) may be float.
        num_m_blocks = (
            num_valid_tokens + block_size_m - 1
        ) / block_size_m + num_experts
        if num_valid_tokens // real_top_k <= block_size_m:
            num_m_blocks = min(num_m_blocks, num_valid_tokens)
        num_blocks = num_m_blocks * num_n_blocks * num_k_blocks

        if size_k % 256 == 0 and num_blocks >= 256 and block_size_k < 256:
            block_size_k = 256
            # NOTE(review): block_size_k is already 256 here, so the divisor
            # is 1 and num_blocks is effectively unchanged (only floored).
            num_blocks = num_blocks // (256 // block_size_k)

        # NOTE(review): the size_k divisibility condition appears twice.
        if (
            num_m_blocks <= 16
            and size_k % (block_size_k * 2) == 0
            and size_k % (block_size_k * 2) == 0
            and block_size_k <= 512
            and num_blocks >= 512
        ):
            block_size_k = block_size_k * 2
            num_blocks = num_blocks // 2

        if num_blocks > 1024:
            block_size_n = 256
            num_n_blocks = num_n_blocks // 2
            num_blocks = num_blocks // 2

        if size_n <= 1024 and num_blocks >= 1024:
            block_size_n = 1024

        # Final K tile must divide size_k and respect the quantization group.
        block_size_k = _ensure_block_size_k_divisible(size_k, block_size_k, group_size)

        return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k}
327def get_default_config(
328 M: int,
329 E: int,
330 N: int,
331 K: int,
332 topk: int,
333 dtype: str | None,
334 block_shape: list[int] | None = None,
335 gemm_stage: str = "gemm1",
336 enable_gemm_fast_path: bool = False,
337) -> dict[str, Any]:
338 """Default Triton config for fused MoE kernel.
340 Heuristic selection aligned with vLLM v0.17.0 defaults, tuned on H20/H100.
341 Key insight: for high-expert-count MoE (e.g. DeepSeek-V3 E=256), each
342 expert sees very few tokens, so small BLOCK_SIZE_M (16) is critical.
343 """
344 is_fp8_blockwise = dtype == "fp8_w8a8" and block_shape is not None
345 if gemm_stage not in ("gemm1", "gemm2"):
346 raise ValueError(f"Unsupported MoE GEMM stage: {gemm_stage}")
348 if is_fp8_blockwise:
349 avg_tokens_per_expert = M * max(topk, 1) // max(E, 1)
350 is_large_m = M >= 16384
351 if avg_tokens_per_expert <= 16:
352 block_m = 16
353 elif avg_tokens_per_expert <= 32:
354 block_m = 32
355 elif avg_tokens_per_expert <= 64 or not is_large_m:
356 block_m = 64
357 else:
358 block_m = 128
360 config = {
361 "BLOCK_SIZE_M": block_m,
362 "BLOCK_SIZE_N": block_shape[0],
363 "BLOCK_SIZE_K": block_shape[1],
364 "GROUP_SIZE_M": 8 if (is_large_m and avg_tokens_per_expert > 16) else 1,
365 "num_warps": 8 if (is_large_m and block_m > 32) else 4,
366 "num_stages": 4 if M >= 1024 else 3,
367 "SWAP_AB": False,
368 }
369 elif dtype in _PLAIN_HALF_CONFIG_DTYPES:
370 # Routed rows per expert drives block_m. Each token contributes topk
371 # rows to the expert-sorted GEMM input, so M * topk / E is the relevant
372 # density for high-expert-count MoE routing.
373 routed_tokens_per_expert = M * max(topk, 1) // max(E, 1)
374 tokens_per_expert = M // max(E, 1)
376 if routed_tokens_per_expert <= 16:
377 block_m = 16
378 elif routed_tokens_per_expert <= 64:
379 block_m = 64
380 else:
381 block_m = 128
383 if tokens_per_expert > 128:
384 group_m = 16
385 elif tokens_per_expert > 32:
386 group_m = 8
387 else:
388 group_m = 1
390 block_k = 128 if M <= 64 else 64
392 if N >= 4096:
393 block_n = 128 if M <= 128 else 256
394 else:
395 block_n = 64 if M <= 64 else 128
397 can_use_gemm_fast_path = (
398 enable_gemm_fast_path
399 and M >= MOE_GEMM_TUNING_MIN_TOKENS
400 and block_m == _HALF_GEMM_TILE_M
401 and block_k == _HALF_GEMM_TILE_K
402 )
404 use_gemm2_fast_path = (
405 gemm_stage == "gemm2"
406 and can_use_gemm_fast_path
407 and N % _HALF_GEMM2_TILE_N == 0
408 )
409 use_gemm1_fast_path = (
410 gemm_stage == "gemm1" and can_use_gemm_fast_path and N % block_n == 0
411 )
413 if gemm_stage == "gemm2" and enable_gemm_fast_path:
414 block_n = (
415 _HALF_GEMM2_TILE_N if use_gemm2_fast_path else (64 if M <= 64 else 128)
416 )
418 # Prefer 4 warps for small tiles; only use 8 for large M
419 num_warps = 4 if M <= 128 else 8
420 num_stages = 3
422 if use_gemm1_fast_path:
423 group_m = 1
424 num_stages = 4
425 elif use_gemm2_fast_path:
426 group_m = 2
427 num_stages = 4
429 smem_per_stage = (block_m * block_k + block_k * block_n) * 2
430 while num_stages > 2 and smem_per_stage * num_stages > 200_000:
431 num_stages -= 1
433 config = {
434 "BLOCK_SIZE_M": block_m,
435 "BLOCK_SIZE_N": block_n,
436 "BLOCK_SIZE_K": block_k,
437 "GROUP_SIZE_M": group_m,
438 "num_warps": num_warps,
439 "num_stages": num_stages,
440 }
441 if use_gemm1_fast_path:
442 config["PAIR_GATE_UP_DOT"] = True
443 else:
444 tokens_per_expert = M // max(E, 1)
446 if tokens_per_expert <= 2:
447 block_m = 16
448 elif tokens_per_expert <= 4:
449 block_m = 32
450 elif tokens_per_expert <= 16:
451 block_m = 64
452 else:
453 block_m = 128
455 # Tile sizing
456 if N >= 4096:
457 block_n = 128 if M <= 128 else 256
458 elif N >= 1024:
459 block_n = 64 if M <= 64 else 128
460 else:
461 block_n = 64 if M <= 64 else 128
463 if dtype == "fp8_w8a8":
464 block_k = 128
465 elif M <= 64:
466 block_k = 128
467 else:
468 block_k = 64
470 if tokens_per_expert > 128:
471 group_m = 16
472 elif tokens_per_expert > 32:
473 group_m = 8
474 else:
475 group_m = 1
477 # Prefer 4 warps for small tiles; only use 8 for large M
478 num_warps = 4 if M <= 128 else 8
479 num_stages = 3
481 smem_per_stage = (block_m * block_k + block_k * block_n) * 2
482 while num_stages > 2 and smem_per_stage * num_stages > 200_000:
483 num_stages -= 1
485 config = {
486 "BLOCK_SIZE_M": block_m,
487 "BLOCK_SIZE_N": block_n,
488 "BLOCK_SIZE_K": block_k,
489 "GROUP_SIZE_M": group_m,
490 "num_warps": num_warps,
491 "num_stages": num_stages,
492 }
493 return config
496def _get_config_dtype_str(
497 dtype: Optional[torch.dtype] = None,
498 use_fp8_w8a8: bool = False,
499 use_fp8_w8a16: bool = False,
500 use_int8_w8a16: bool = False,
501 use_int4_w4a16: bool = False,
502 ocp_mx_scheme: str | None = None,
503) -> str | None:
504 """Return dtype string for kernel config lookup."""
505 if use_fp8_w8a8:
506 return "fp8_w8a8"
507 elif use_fp8_w8a16:
508 return "fp8_w8a16"
509 elif use_int8_w8a16:
510 return "int8_w8a16"
511 elif use_int4_w4a16:
512 return "int4_w4a16"
513 elif ocp_mx_scheme is not None:
514 return None
515 elif dtype == torch.float16:
516 return "fp16"
517 elif dtype == torch.bfloat16:
518 return "bf16"
519 elif dtype == torch.float:
520 return "float32"
521 return None
524# MoE activation enum
class MoEActivation(Enum):
    """Closed set of activation functions accepted by the MoE layers."""

    # Gated variants: input [..., 2*d] -> output [..., d]
    SILU = "silu"
    GELU = "gelu"
    RELU2 = "relu2"
    SWIGLUOAI = "swigluoai"
    SWIGLUSTEP = "swiglustep"

    # Non-gated variants: input [..., d] -> output [..., d]
    SILU_NO_MUL = "silu_no_mul"
    GELU_NO_MUL = "gelu_no_mul"
    RELU2_NO_MUL = "relu2_no_mul"

    @property
    def is_gated(self) -> bool:
        """True unless the member's value carries the ``_no_mul`` suffix."""
        if self.value.endswith("_no_mul"):
            return False
        return True

    def without_mul(self) -> "MoEActivation":
        """Return the non-gated counterpart, or ``self`` when there is none."""
        if self is MoEActivation.SILU:
            return MoEActivation.SILU_NO_MUL
        if self is MoEActivation.GELU:
            return MoEActivation.GELU_NO_MUL
        if self is MoEActivation.RELU2:
            return MoEActivation.RELU2_NO_MUL
        return self

    @classmethod
    def from_str(cls, s: str) -> "MoEActivation":
        """Return the member whose value equals ``s``.

        Raises:
            ValueError: if no member has that value.
        """
        found = next((member for member in cls if member.value == s), None)
        if found is None:
            valid = [m.value for m in cls]
            raise ValueError(
                f"Unknown MoE activation: {s!r}. Valid activations: {valid}"
            )
        return found

    @staticmethod
    def adjust_N_for_activation(N: int, activation: "MoEActivation") -> int:
        """Return N // 2 for gated activations and N unchanged otherwise."""
        return N // 2 if activation.is_gated else N
def apply_moe_activation(
    activation: MoEActivation,
    output: torch.Tensor,
    input: torch.Tensor,
) -> torch.Tensor:
    """Apply MoE activation (pure PyTorch / FlagGems Triton).

    For gated activations ``input`` is split along the last dimension into a
    first half and a second half of width ``output.size(-1)``; the activation
    is applied to the first half and multiplied by the second. For non-gated
    activations the function is applied element-wise.

    The result is written into ``output``; ``input`` is never modified.

    Args:
        activation: Which activation to apply.
        output: 2D destination tensor.
        input: 2D source tensor; last dim is ``2 * output.size(-1)`` for gated
            activations and ``output.size(-1)`` otherwise.

    Returns:
        ``output``.

    Raises:
        ValueError: if ``activation`` is not handled here.
    """
    assert input.dim() == 2, "Input must be 2D"
    assert output.dim() == 2, "Output must be 2D"
    if activation.is_gated:
        assert output.size(-1) * 2 == input.size(-1), (
            f"{activation.value} expects 2x ratio: "
            f"{output.size(-1) * 2} vs {input.size(-1)}"
        )
    else:
        assert output.size(-1) == input.size(-1), (
            f"{activation.value} expects equal sizes: "
            f"{output.size(-1)} vs {input.size(-1)}"
        )

    if activation in (MoEActivation.SILU, MoEActivation.SWIGLUOAI):
        # SWIGLUOAI shares the plain silu-and-mul kernel here.
        # NOTE(review): confirm this matches the intended swigluoai formula.
        N = output.size(-1)
        x, y = input[:, :N], input[:, N:]
        _silu_and_mul_kernel(x, y, out0=output)
    elif activation == MoEActivation.GELU:
        N = output.size(-1)
        gate, up = input[:, :N], input[:, N:]
        output.copy_(F.gelu(gate) * up)
    elif activation == MoEActivation.SWIGLUSTEP:
        N = output.size(-1)
        gate, up = input[:, :N], input[:, N:]
        output.copy_(torch.sigmoid(gate) * up)
    elif activation == MoEActivation.RELU2:
        N = output.size(-1)
        gate, up = input[:, :N], input[:, N:]
        output.copy_(F.relu(gate).square() * up)

    elif activation == MoEActivation.SILU_NO_MUL:
        output.copy_(F.silu(input))
    elif activation == MoEActivation.GELU_NO_MUL:
        output.copy_(F.gelu(input))
    elif activation == MoEActivation.RELU2_NO_MUL:
        # Out-of-place relu: an in-place relu would silently overwrite the
        # caller's ``input``, unlike every other branch of this function.
        torch.square(F.relu(input), out=output)
    else:
        raise ValueError(f"Unsupported FusedMoe activation: {activation}")

    return output
618def _fp8_quantize(
619 A: torch.Tensor,
620 A_scale: Optional[torch.Tensor],
621 per_act_token: bool,
622 block_shape: Optional[list[int]] = None,
623) -> tuple[torch.Tensor, torch.Tensor]:
624 """FP8 E4M3 quantization: per-tensor, per-token, or block-wise."""
625 fp8_dtype = torch.float8_e4m3fn
626 finfo = torch.finfo(fp8_dtype)
627 fp8_max = finfo.max
628 fp8_min = finfo.min
629 eps = 1e-10
631 if block_shape is not None:
632 assert not per_act_token
633 assert len(block_shape) == 2
634 block_k = block_shape[1]
635 assert A.size(-1) % block_k == 0
636 if A.ndim == 2 and A.stride(-1) == 1:
637 from flag_gems.ops.per_token_group_quant_fp8 import (
638 per_token_group_quant_fp8,
639 )
641 return per_token_group_quant_fp8(
642 A,
643 group_size=block_k,
644 eps=eps,
645 dtype=fp8_dtype,
646 column_major_scales=False,
647 scale_ue8m0=False,
648 )
649 orig_shape = A.shape
650 A_flat = A.reshape(-1, A.size(-1))
651 M, K = A_flat.shape
652 A_groups = A_flat.reshape(M * (K // block_k), block_k)
653 amax = (
654 A_groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32)
655 )
656 scale = amax / fp8_max
657 A_q = (A_groups.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype)
658 A_q = A_q.reshape(orig_shape)
659 scale = scale.reshape(M, K // block_k)
660 return A_q, scale
662 elif per_act_token:
663 A_flat = A.reshape(-1, A.size(-1))
664 amax = A_flat.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32)
665 scale = amax / fp8_max
666 min_scale = torch.tensor(
667 1.0 / (fp8_max * 512.0), dtype=torch.float32, device=A.device
668 )
669 scale = scale.clamp(min=min_scale)
670 A_q = (A_flat.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype)
671 A_q = A_q.reshape(A.shape)
672 scale = scale.reshape(A.shape[:-1] + (1,))
673 return A_q, scale
675 else:
676 if A_scale is not None:
677 scale = (
678 A_scale.float().view(1, 1) if A_scale.numel() == 1 else A_scale.float()
679 )
680 A_q = (A.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype)
681 return A_q, A_scale
682 else:
683 amax = A.abs().amax().clamp(min=eps).to(torch.float32)
684 scale = amax / fp8_max
685 iscale = 1.0 / scale
686 A_q = (A.float() * iscale).clamp(fp8_min, fp8_max).to(fp8_dtype)
687 return A_q, scale.view(1)
690def _int8_quantize(
691 A: torch.Tensor,
692 A_scale: Optional[torch.Tensor],
693 per_act_token: bool,
694 block_shape: Optional[list[int]] = None,
695) -> tuple[torch.Tensor, torch.Tensor]:
696 """INT8 quantization: per-tensor, per-token, or block-wise."""
697 iinfo = torch.iinfo(torch.int8)
698 int8_max = iinfo.max
699 int8_min = iinfo.min
700 eps = 1e-10
702 if block_shape is not None:
703 assert not per_act_token
704 assert len(block_shape) == 2
705 block_k = block_shape[1]
706 assert A.size(-1) % block_k == 0
707 orig_shape = A.shape
708 A_flat = A.reshape(-1, A.size(-1))
709 M, K = A_flat.shape
710 A_groups = A_flat.reshape(M * (K // block_k), block_k)
711 amax = (
712 A_groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32)
713 )
714 scale = amax / int8_max
715 A_q = (
716 (A_groups.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8)
717 )
718 A_q = A_q.reshape(orig_shape)
719 scale = scale.reshape(M, K // block_k)
720 return A_q, scale
722 elif per_act_token:
723 A_flat = A.reshape(-1, A.size(-1))
724 amax = A_flat.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32)
725 scale = amax / int8_max
726 A_q = (A_flat.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8)
727 A_q = A_q.reshape(A.shape)
728 scale = scale.reshape(A.shape[:-1] + (1,))
729 return A_q, scale
731 else:
732 assert A_scale is not None, "int8 per-tensor requires A_scale"
733 scale = A_scale.float().view(1, 1) if A_scale.numel() == 1 else A_scale.float()
734 A_q = (A.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8)
735 return A_q, A_scale
738def moe_kernel_quantize_input(
739 A: torch.Tensor,
740 A_scale: Optional[torch.Tensor],
741 quant_dtype: None | torch.dtype | str,
742 per_act_token_quant: bool,
743 block_shape: Optional[list[int]] = None,
744 ocp_mx_scheme: str | None = None,
745) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
746 """Quantize MoE input activations before GEMM."""
747 if ocp_mx_scheme is not None:
748 if ocp_mx_scheme in {"w_mxfp4", "w_mxfp4_a_mxfp4"}:
749 pass
750 elif ocp_mx_scheme.endswith("a_fp8"):
751 qA, qA_scale = _fp8_quantize(A, A_scale, per_act_token=False)
752 A = (qA.float() * qA_scale.float()).to(A.dtype)
753 return A, None
755 if quant_dtype is None:
756 return A, A_scale
757 elif quant_dtype == torch.float8_e4m3fn:
758 return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
759 elif quant_dtype == torch.int8:
760 return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
761 else:
762 return A, A_scale
765def _ensure_block_size_k_divisible(
766 size_k: int, block_size_k: int, group_size: int
767) -> int:
768 """Find largest block_size_k that divides size_k and is divisible by group_size."""
769 if size_k % block_size_k == 0 and block_size_k % group_size == 0:
770 return block_size_k
772 max_search = min(block_size_k, size_k)
773 start = (max_search // group_size) * group_size
774 for candidate in range(start, group_size - 1, -group_size):
775 if size_k % candidate == 0:
776 return candidate
778 if size_k % group_size == 0:
779 return group_size
781 return size_k
# Element-wise SiLU(x) * y, wrapped by pointwise_dynamic.
# SiLU is evaluated in float32 as x / (1 + exp(-x)) before the multiply.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def _silu_and_mul_kernel(x, y):
    # Upcast so the exponential and division run in float32.
    x_fp32 = x.to(tl.float32)
    x_silu = tl.fdiv(x_fp32, (1.0 + tl.exp(-x_fp32)))
    return x_silu * y
# Device helper: store a BLOCK_SIZE_M x BLOCK_SIZE_N tile of zeros into C.
# Rows are addressed by offs_token, columns by the pid_n-th N tile; stores
# are masked by token_mask and by the column bound N.
@triton.jit
def write_zeros_to_output(
    c_ptr,
    stride_cm,
    stride_cn,
    pid_n,
    N,
    offs_token,
    token_mask,
    BLOCK_SIZE_M,
    BLOCK_SIZE_N,
    compute_type,
):
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    # Skip padded tokens and columns past the end of the output.
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
@triton.jit
def fused_moe_kernel_gptq_awq(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    b_scale_ptr,
    b_zp_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N: tl.constexpr,
    K: tl.constexpr,
    EM,
    num_valid_tokens,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bse,
    stride_bsk,
    stride_bsn,
    stride_bze,
    stride_bzk,
    stride_bzn,
    block_k_diviable: tl.constexpr,
    group_size: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    has_zp: tl.constexpr,
    use_int4_w4a16: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
):
    """Fused MoE kernel for GPTQ/AWQ (WNA16) quantized weights."""
    # Each program computes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C for the
    # expert that owns the tile's row block. Weights are dequantized on the
    # fly as (b - zero_point) * scale, with one scale/zero-point per
    # `group_size` elements along K.
    # SPLIT_K is not referenced in the body.

    # Map pid to C block (grouped ordering for L2 reuse)
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Create pointers for first blocks of A and B
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    # Row blocks entirely past the padded token count have nothing to do.
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    # Cast to int64 to prevent overflow in stride*offset products
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id).to(tl.int64)
    # Entries >= num_valid_tokens are treated as padding.
    token_mask = offs_token < num_valid_tokens

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        # -----------------------------------------------------------
        # Write back zeros to the output when the expert is not
        # in the current expert parallel rank.
        write_zeros_to_output(
            c_ptr,
            stride_cm,
            stride_cn,
            pid_n,
            N,
            offs_token,
            token_mask,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            compute_type,
        )
        return

    # Column offsets wrap modulo N so loads along N never index past N.
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # offs_token indexes (token, top-k slot) pairs; // top_k recovers the A row.
    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )

    # NOTE(review): b_ptrs is only defined for the int4 and int8 variants;
    # presumably callers always set exactly one of the two flags - confirm.
    if use_int4_w4a16:
        # Two 4-bit values share one stored element along K: index by k // 2
        # and pick the nibble with a shift of 0 or 4.
        b_ptrs = (
            b_ptr
            + off_experts * stride_be
            + (offs_k[:, None] // 2) * stride_bk
            + offs_bn[None, :] * stride_bn
        )
        b_shifter = (offs_k[:, None] % 2) * 4
    elif use_int8_w8a16:
        b_ptrs = (
            b_ptr
            + off_experts * stride_be
            + offs_k[:, None] * stride_bk
            + offs_bn[None, :] * stride_bn
        )

    # Without explicit zero points a fixed mid-range offset is subtracted:
    # 8 for 4-bit and 128 for 8-bit weights.
    if not has_zp and use_int4_w4a16:
        b_zp_num = 8
    if not has_zp and use_int8_w8a16:
        b_zp_num = 128
    elif has_zp and use_int4_w4a16:
        # 4-bit zero points are packed two per element along N.
        b_zp_shifter = (offs_bn[None, :] % 2) * 4

    # Accumulate C block in fp32
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Scale / zero-point loads only need a K mask when K is not a
        # multiple of BLOCK_SIZE_K.
        if not block_k_diviable:
            k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
            k_other = 0.0
        else:
            k_mask = None
            k_other = None

        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        # NOTE(review): b is loaded without a K mask; out-of-range K columns
        # of `a` are zeroed above, and masked scales load as 0.0, so the
        # products vanish - confirm the tail read of b itself is in bounds.
        b = tl.load(b_ptrs)
        if use_int4_w4a16:
            b = (b >> b_shifter) & 0xF

        b_scale_ptrs = (
            b_scale_ptr
            + off_experts * stride_bse
            + offs_bn[None, :] * stride_bsn
            + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
        )
        b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
        b_scale = b_scale.to(tl.float32)

        if has_zp and use_int4_w4a16:
            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
            b_zp_ptrs = (
                b_zp_ptr
                + off_experts * stride_bze
                + (offs_bn[None, :] // 2) * stride_bzn
                + offs_k_true * stride_bzk
            )
            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
            b_zp = (b_zp >> b_zp_shifter) & 0xF
            b_zp = b_zp.to(tl.float32)
        elif has_zp and use_int8_w8a16:
            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
            b_zp_ptrs = (
                b_zp_ptr
                + off_experts * stride_bze
                + offs_bn[None, :] * stride_bzn
                + offs_k_true * stride_bzk
            )
            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
            b_zp = b_zp.to(tl.float32)

        # Dequantize in float32, then cast to the compute type for the dot.
        if has_zp:
            b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
        else:
            b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
        accumulator = tl.dot(a, b, acc=accumulator)

        # Advance to the next K tile; packed int4 advances half as far.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        if use_int4_w4a16:
            b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
        else:
            b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        # Scale each row by its routing weight; padded rows load 0.
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]

    accumulator = accumulator.to(compute_type)
    # Write back output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
@triton.jit
def fused_moe_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    b_bias_ptr,
    a_scale_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_asm,
    stride_ask,
    stride_bse,
    stride_bsk,
    stride_bsn,
    stride_bbe,  # bias expert stride
    stride_bbn,  # bias N stride
    # Block size for block-wise quantization
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    naive_block_assignment: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a8: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
    per_channel_quant: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    SWAP_AB: tl.constexpr,
    K_DIVISIBLE_BY_BLOCK_K: tl.constexpr,
    N_DIVISIBLE_BY_BLOCK_N: tl.constexpr,
    PAIR_GATE_UP_DOT: tl.constexpr,
    DIRECT_SUM: tl.constexpr,
    OUT_TOP_K: tl.constexpr,
    FUSE_SILU: tl.constexpr,
):
    """Fused MoE kernel: token x expert GEMM with optional quantization and SiLU.

    Each program produces one (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile for the
    expert read from ``expert_ids_ptr[pid_m]``. Rows of the tile are the token
    slots read from ``sorted_token_ids_ptr`` (or, with
    ``naive_block_assignment``, the single slot ``pid_m``); slots that are
    ``>= num_valid_tokens`` are padding and are masked out.

    Three compute paths are selected by constexpr flags:

    * ``FUSE_SILU and PAIR_GATE_UP_DOT``: a single dot of width
      ``2 * BLOCK_SIZE_N`` computes the gate and up tiles together, then
      ``silu(gate) * up`` is formed. No scale handling is done on this path.
    * ``FUSE_SILU``: gate and up tiles are accumulated in two sequential
      passes over K (with fp8/int8 scale handling), then ``silu(gate) * up``.
    * otherwise: a plain GEMM tile with optional SWAP_AB, dequantization and
      bias.

    With ``FUSE_SILU`` the weight's N dimension holds gate columns
    ``[0, N // 2)`` followed by up columns ``[N // 2, N)`` and the output has
    ``N // 2`` columns.

    When ``MUL_ROUTED_WEIGHT`` is set the router weight scales the GEMM result
    *before* the activation on the fused paths, so that the fused result
    matches "GEMM, scale, then activate" as produced by the unfused path.

    The tile is written with ``tl.store`` to row ``offs_token``, or
    accumulated with ``tl.atomic_add`` into row ``offs_token // OUT_TOP_K``
    when ``DIRECT_SUM`` is set.
    """
    # Map pid to C block (grouped ordering for L2 reuse)
    pid = tl.program_id(axis=0)
    # Adjust N for FUSE_SILU. If fused, the actual output dimension is N // 2
    N_out = N // 2 if FUSE_SILU else N
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N_out, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Resolve the token slots handled by this program.
    offs = tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + offs
    if not naive_block_assignment:
        offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    else:
        offs_token = tl.where(
            offs == 0,
            pid_m,  # first element = pid_m
            num_valid_tokens,  # remaining elements = constant
        )
    offs_token = offs_token.to(tl.int64)  # prevent int32 overflow

    token_mask = offs_token < num_valid_tokens

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)

    # An expert id of -1 marks an expert that is not handled here: write zeros
    # instead of indexing the weights with a negative expert offset. The check
    # is shared by all compute paths below.
    if off_experts == -1:
        write_zeros_to_output(
            c_ptr,
            stride_cm,
            stride_cn,
            pid_n,
            N_out,
            offs_token,
            token_mask,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            compute_type,
        )
        return

    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
    if not N_DIVISIBLE_BY_BLOCK_N:
        # Wrap out-of-range columns so unmasked weight loads stay in bounds;
        # the surplus columns are masked out at the final store.
        offs_bn = offs_bn % N_out
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # Several token slots (top_k of them) share one row of A.
    a_base = a_ptr + (offs_token[:, None] // top_k * stride_am)

    if FUSE_SILU and PAIR_GATE_UP_DOT:
        # Column indices of a (gate | up) tile pair: the first BLOCK_SIZE_N
        # entries address gate columns, the rest the matching up columns.
        offs_pair = tl.arange(0, BLOCK_SIZE_N * 2).to(tl.int64)
        offs_pair_bn = tl.where(
            offs_pair < BLOCK_SIZE_N,
            pid_n * BLOCK_SIZE_N + offs_pair,
            N_out + pid_n * BLOCK_SIZE_N + offs_pair - BLOCK_SIZE_N,
        )
        a_ptrs = a_base + offs_k[None, :] * stride_ak
        b_pair_ptrs = (
            b_ptr
            + off_experts * stride_be
            + (offs_k[:, None] * stride_bk + offs_pair_bn[None, :] * stride_bn)
        )
        pair_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N * 2), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            if K_DIVISIBLE_BY_BLOCK_K:
                a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0)
                if N_DIVISIBLE_BY_BLOCK_N:
                    b_pair = tl.load(b_pair_ptrs)
                else:
                    b_pair = tl.load(
                        b_pair_ptrs, mask=offs_pair_bn[None, :] < N, other=0.0
                    )
            else:
                k_remaining = K - k * BLOCK_SIZE_K
                a = tl.load(
                    a_ptrs,
                    mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
                    other=0.0,
                )
                b_pair = tl.load(
                    b_pair_ptrs,
                    mask=(offs_k[:, None] < k_remaining) & (offs_pair_bn[None, :] < N),
                    other=0.0,
                )
            pair_acc += tl.dot(a, b_pair)
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_pair_ptrs += BLOCK_SIZE_K * stride_bk

        if HAS_BIAS:
            pair_bias_ptrs = (
                b_bias_ptr + off_experts * stride_bbe + (offs_pair_bn * stride_bbn)
            )
            pair_bias = tl.load(pair_bias_ptrs, mask=offs_pair_bn < N, other=0.0)
            pair_acc += pair_bias[None, :]

        # (M, 2N) -> (M, 2, N) -> (M, N, 2), then split the trailing axis into
        # the gate tile and the up tile.
        gate_up = tl.trans(
            tl.reshape(pair_acc, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N)),
            (0, 2, 1),
        )
        gate_acc, up_acc = tl.split(gate_up)

        if MUL_ROUTED_WEIGHT:
            # Scale the GEMM result before the (non-linear) activation so the
            # fused result equals activation(weight * gemm_output).
            pair_weight = tl.load(
                topk_weights_ptr + offs_token,
                mask=token_mask,
                other=0,
            )
            gate_acc = gate_acc * pair_weight[:, None]
            up_acc = up_acc * pair_weight[:, None]

        gate_sig = tl.sigmoid(gate_acc)
        accumulator = (
            gate_acc.to(compute_type)
            * gate_sig.to(compute_type)
            * up_acc.to(compute_type)
        )

    elif FUSE_SILU:
        offs_bn_gate = offs_bn
        offs_bn_up = offs_bn + N_out

        b_expert_base = b_ptr + off_experts * stride_be
        b_ptrs_gate = b_expert_base + (
            offs_k[:, None] * stride_bk + offs_bn_gate[None, :] * stride_bn
        )
        b_ptrs_up = b_expert_base + (
            offs_k[:, None] * stride_bk + offs_bn_up[None, :] * stride_bn
        )

        if use_fp8_w8a8 or use_int8_w8a8:
            if group_k > 0 and group_n > 0:  # block-wise
                a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
                # When the tile fits inside one quantization group a single
                # scalar scale index is enough for the whole tile.
                if BLOCK_SIZE_N <= group_n:
                    offs_bsn_gate_idx = (pid_n * BLOCK_SIZE_N) % N_out // group_n
                    offs_bsn_up_idx = (
                        (pid_n * BLOCK_SIZE_N) % N_out + N_out
                    ) // group_n
                else:
                    offs_bsn_gate_idx = offs_bn_gate // group_n
                    offs_bsn_up_idx = offs_bn_up // group_n
                b_scale_gate_ptrs = (
                    b_scale_ptr
                    + off_experts * stride_bse
                    + offs_bsn_gate_idx * stride_bsn
                )
                b_scale_up_ptrs = (
                    b_scale_ptr
                    + off_experts * stride_bse
                    + offs_bsn_up_idx * stride_bsn
                )
            elif per_channel_quant:  # channel-wise
                b_scale_gate_ptrs = (
                    b_scale_ptr
                    + off_experts * stride_bse
                    + offs_bn_gate[None, :] * stride_bsn
                )
                b_scale_gate = tl.load(b_scale_gate_ptrs)
                b_scale_up_ptrs = (
                    b_scale_ptr
                    + off_experts * stride_bse
                    + offs_bn_up[None, :] * stride_bsn
                )
                b_scale_up = tl.load(b_scale_up_ptrs)
                a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
                a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
            else:  # tensor-wise
                a_scale = tl.load(a_scale_ptr)
                b_scale_gate = tl.load(b_scale_ptr + off_experts)
                b_scale_up = b_scale_gate

        # Pass 1: gate projection. Gate and up are accumulated one after the
        # other rather than together to limit peak register pressure.
        a_ptrs = a_base + offs_k[None, :] * stride_ak
        acc_gate = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            # No K masking is needed when K is a multiple of BLOCK_SIZE_K.
            if K_DIVISIBLE_BY_BLOCK_K:
                a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0)
                b_gate = tl.load(b_ptrs_gate)
            else:
                k_remaining = K - k * BLOCK_SIZE_K
                a = tl.load(
                    a_ptrs,
                    mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
                    other=0.0,
                )
                b_gate = tl.load(
                    b_ptrs_gate, mask=offs_k[:, None] < k_remaining, other=0.0
                )

            if use_fp8_w8a8 or use_int8_w8a8:
                if group_k > 0 and group_n > 0:
                    k_start = k * BLOCK_SIZE_K
                    offs_ks = k_start // group_k
                    a_scale = tl.load(
                        a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
                    )
                    b_scale_val = tl.load(b_scale_gate_ptrs + offs_ks * stride_bsk)

                    # Combine both scales first so the tile is scaled once.
                    if BLOCK_SIZE_N <= group_n:
                        combined_scale = a_scale[:, None] * b_scale_val
                    else:
                        combined_scale = a_scale[:, None] * b_scale_val[None, :]
                    acc_gate += tl.dot(a, b_gate) * combined_scale
                else:
                    if use_fp8_w8a8:
                        acc_gate = tl.dot(a, b_gate, acc=acc_gate)
                    else:
                        acc_gate += tl.dot(a, b_gate)
            else:
                acc_gate += tl.dot(a, b_gate)

            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs_gate += BLOCK_SIZE_K * stride_bk

        # Block-wise scales were already applied inside the loop; channel-wise
        # and tensor-wise scales are applied once here.
        if (use_fp8_w8a8 or use_int8_w8a8) and not (group_k > 0 and group_n > 0):
            acc_gate = acc_gate * a_scale * b_scale_gate

        # Pass 2: up projection; operand A is loaded again from the start.
        a_ptrs = a_base + offs_k[None, :] * stride_ak
        acc_up = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            if K_DIVISIBLE_BY_BLOCK_K:
                a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0)
                b_up = tl.load(b_ptrs_up)
            else:
                k_remaining = K - k * BLOCK_SIZE_K
                a = tl.load(
                    a_ptrs,
                    mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
                    other=0.0,
                )
                b_up = tl.load(b_ptrs_up, mask=offs_k[:, None] < k_remaining, other=0.0)

            if use_fp8_w8a8 or use_int8_w8a8:
                if group_k > 0 and group_n > 0:
                    k_start = k * BLOCK_SIZE_K
                    offs_ks = k_start // group_k
                    a_scale = tl.load(
                        a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
                    )
                    b_scale_val = tl.load(b_scale_up_ptrs + offs_ks * stride_bsk)

                    if BLOCK_SIZE_N <= group_n:
                        combined_scale = a_scale[:, None] * b_scale_val
                    else:
                        combined_scale = a_scale[:, None] * b_scale_val[None, :]
                    acc_up += tl.dot(a, b_up) * combined_scale
                else:
                    if use_fp8_w8a8:
                        acc_up = tl.dot(a, b_up, acc=acc_up)
                    else:
                        acc_up += tl.dot(a, b_up)
            else:
                acc_up += tl.dot(a, b_up)

            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs_up += BLOCK_SIZE_K * stride_bk

        if (use_fp8_w8a8 or use_int8_w8a8) and not (group_k > 0 and group_n > 0):
            acc_up = acc_up * a_scale * b_scale_up

        if MUL_ROUTED_WEIGHT:
            # Scale the GEMM result before the (non-linear) activation so the
            # fused result equals activation(weight * gemm_output).
            fused_weight = tl.load(
                topk_weights_ptr + offs_token,
                mask=token_mask,
                other=0,
            )
            acc_gate = acc_gate * fused_weight[:, None]
            acc_up = acc_up * fused_weight[:, None]

        # SiLU activation fusion: silu(gate) * up
        accumulator = tl.fdiv(acc_gate, (1.0 + tl.exp(-acc_gate))) * acc_up

    else:
        a_ptrs = a_base + offs_k[None, :] * stride_ak
        b_ptrs = (
            b_ptr
            + off_experts * stride_be
            + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
        )

        if use_int8_w8a16:
            b_scale_ptrs = (
                b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
            )
            b_scale = tl.load(b_scale_ptrs)

        if use_fp8_w8a8 or use_int8_w8a8:
            if group_k > 0 and group_n > 0:  # block-wise
                a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
                # When the tile fits inside one quantization group a single
                # scalar scale index is enough for the whole tile.
                if BLOCK_SIZE_N <= group_n:
                    offs_bsn = (pid_n * BLOCK_SIZE_N) % N_out // group_n
                else:
                    offs_bsn = offs_bn // group_n
                b_scale_ptrs = (
                    b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
                )
            elif per_channel_quant:  # channel-wise
                b_scale_ptrs = (
                    b_scale_ptr
                    + off_experts * stride_bse
                    + offs_bn[None, :] * stride_bsn
                )
                b_scale = tl.load(b_scale_ptrs)
                a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
                a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
            else:  # tensor-wise
                a_scale = tl.load(a_scale_ptr)
                b_scale = tl.load(b_scale_ptr + off_experts)

        if HAS_BIAS:
            bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn
            bias = tl.load(bias_ptrs, mask=(offs_bn < N_out), other=0.0)

        # Accumulate C block in fp32
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        if SWAP_AB:
            # With SWAP_AB the product is formed transposed, as (N, M).
            accumulator_nm = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_M), dtype=tl.float32)

        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            # No K masking is needed when K is a multiple of BLOCK_SIZE_K.
            if K_DIVISIBLE_BY_BLOCK_K:
                a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0)
                b = tl.load(b_ptrs)
            else:
                k_remaining = K - k * BLOCK_SIZE_K
                a = tl.load(
                    a_ptrs,
                    mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
                    other=0.0,
                )
                b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)

            if use_int8_w8a16:
                accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
            elif use_fp8_w8a8 or use_int8_w8a8:
                if group_k > 0 and group_n > 0:
                    k_start = k * BLOCK_SIZE_K
                    offs_ks = k_start // group_k
                    a_scale = tl.load(
                        a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
                    )
                    b_scale_val = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
                    # Combine both scales first so the tile is scaled once.
                    if SWAP_AB:
                        if BLOCK_SIZE_N <= group_n:
                            combined_scale_nm = b_scale_val * a_scale[None, :]
                        else:
                            combined_scale_nm = b_scale_val[:, None] * a_scale[None, :]
                        accumulator_nm += (
                            tl.dot(tl.trans(b), tl.trans(a)) * combined_scale_nm
                        )
                    else:
                        if BLOCK_SIZE_N <= group_n:
                            combined_scale = a_scale[:, None] * b_scale_val
                        else:
                            combined_scale = a_scale[:, None] * b_scale_val[None, :]
                        accumulator += tl.dot(a, b) * combined_scale
                else:
                    if use_fp8_w8a8:
                        if SWAP_AB:
                            accumulator_nm = tl.dot(
                                tl.trans(b), tl.trans(a), acc=accumulator_nm
                            )
                        else:
                            accumulator = tl.dot(a, b, acc=accumulator)
                    else:
                        if SWAP_AB:
                            accumulator_nm += tl.dot(tl.trans(b), tl.trans(a))
                        else:
                            accumulator += tl.dot(a, b)
            else:
                if SWAP_AB:
                    accumulator_nm += tl.dot(tl.trans(b), tl.trans(a))
                else:
                    accumulator += tl.dot(a, b)
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk

        if SWAP_AB:
            accumulator = tl.trans(accumulator_nm)

        # Dequantization (block-wise scales were applied inside the loop)
        if use_int8_w8a16:
            accumulator = accumulator * b_scale
        elif (use_fp8_w8a8 or use_int8_w8a8) and not (group_k > 0 and group_n > 0):
            accumulator = accumulator * a_scale * b_scale

        if HAS_BIAS:
            accumulator += bias[None, :]

    # Router weight multiplication (must be in fp32). The fused-SiLU paths
    # have already applied it ahead of the activation.
    if MUL_ROUTED_WEIGHT and not FUSE_SILU:
        moe_weight = tl.load(
            topk_weights_ptr + offs_token,
            mask=token_mask,
            other=0,
        )
        accumulator *= moe_weight[:, None]

    accumulator = accumulator.to(compute_type)

    # Write back output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    if DIRECT_SUM:
        # Every OUT_TOP_K consecutive token slots reduce into one output row.
        offs_c = offs_token // OUT_TOP_K
    else:
        offs_c = offs_token
    c_ptrs = c_ptr + stride_cm * offs_c[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None]
    if not N_DIVISIBLE_BY_BLOCK_N:
        c_mask = c_mask & (offs_cn[None, :] < N_out)
    if DIRECT_SUM:
        # Kernel completion provides the only ordering needed here.
        tl.atomic_add(c_ptrs, accumulator, sem="relaxed", mask=c_mask)
    else:
        tl.store(c_ptrs, accumulator, mask=c_mask)
def invoke_fused_moe_wna16_triton_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    B_scale: torch.Tensor | None,
    B_zp: torch.Tensor | None,
    topk_weights: torch.Tensor | None,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    block_shape: list[int] | None,
) -> None:
    """Launch ``fused_moe_kernel_gptq_awq`` for group-quantized W8A16/W4A16 weights.

    ``B`` is indexed as (expert, N, K): ``B.size(1)`` is passed to the kernel
    as N and ``B.stride(0)`` as the per-expert stride. ``B_scale`` (and
    ``B_zp`` when given) must be 3-D, and ``block_shape`` must be
    ``[0, group_size]``. The caller's ``config`` dict is not modified; a copy
    is completed with the block sizes returned by
    ``get_moe_wna16_block_config`` before the launch.
    """
    assert B_scale is not None and B_scale.ndim == 3
    assert B_zp is None or B_zp.ndim == 3
    assert block_shape is not None and block_shape[0] == 0

    M = A.size(0)
    num_tokens = M * top_k

    EM = sorted_token_ids.size(0)
    if A.size(0) < config["BLOCK_SIZE_M"]:
        # optimize for small batch_size.
        # We assume that top_ids of each token is unique,
        # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
        # and we can skip some invalid blocks.
        EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
    grid = lambda META: (
        triton.cdiv(EM, META["BLOCK_SIZE_M"])
        * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
    )
    config = config.copy()
    config.update(
        get_moe_wna16_block_config(
            config=config,
            use_moe_wna16_cuda=False,
            num_valid_tokens=num_tokens,
            size_k=A.size(1),
            size_n=B.size(1),
            # The expert count is B's leading dimension; B.size(1) is N.
            num_experts=B.size(0),
            group_size=block_shape[1],
            real_top_k=top_k,
            block_size_m=config["BLOCK_SIZE_M"],
        )
    )

    fused_moe_kernel_gptq_awq[grid](
        A,
        B,
        C,
        B_scale,
        B_zp,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.size(1),
        A.size(1),
        EM,
        num_tokens,
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        B_scale.stride(0),
        B_scale.stride(2),
        B_scale.stride(1),
        B_zp.stride(0) if B_zp is not None else 0,
        B_zp.stride(2) if B_zp is not None else 0,
        B_zp.stride(1) if B_zp is not None else 0,
        block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
        group_size=block_shape[1],
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        has_zp=B_zp is not None,
        use_int4_w4a16=use_int4_w4a16,
        use_int8_w8a16=use_int8_w8a16,
        **config,
    )
def invoke_fused_moe_triton_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    topk_weights: Optional[torch.Tensor],
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    block_shape: Optional[list[int]] = None,
    B_bias: torch.Tensor | None = None,
    FUSE_SILU: bool = False,
    direct_sum: bool = False,
    out_top_k: int = 1,
) -> None:
    """Launch the fused_moe_kernel Triton kernel.

    ``B`` is indexed as (expert, N, K). With ``FUSE_SILU`` the N dimension of
    ``B`` holds both the gate and the up projection, so the launch grid and
    the N-divisibility flag are computed for ``N // 2`` output columns and
    ``SWAP_AB`` is forced off.

    ``sorted_token_ids`` may be ``None``; the kernel is then launched with
    ``naive_block_assignment=True`` and ``EM`` is derived from the token count
    instead of the sorted-id length.

    ``BLOCK_SIZE_K`` is clamped to the quantization block only when
    ``block_shape`` describes real block-wise quantization (both entries
    positive). A ``block_shape`` with a zero entry, which the W8A16/W4A16
    assertion below permits, leaves the configured ``BLOCK_SIZE_K`` untouched
    instead of collapsing it to zero.

    The caller's ``config`` dict is not modified.
    """
    assert topk_weights is not None or not mul_routed_weight
    assert topk_weights is None or topk_weights.stride(1) == 1
    assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1

    if use_fp8_w8a8 or use_int8_w8a8:
        assert B_scale is not None
        assert block_shape is None or triton.cdiv(
            B.size(-2), block_shape[0]
        ) == B_scale.size(-2)
        assert block_shape is None or triton.cdiv(
            B.size(-1), block_shape[1]
        ) == B_scale.size(-1)
    elif use_int8_w8a16 or use_int4_w4a16:
        assert B_scale is not None
        assert block_shape is None or block_shape[0] == 0
    else:
        assert A_scale is None
        assert B_scale is None

    M = A.size(0)
    num_tokens = M * top_k
    if sorted_token_ids is not None:
        EM = sorted_token_ids.size(0)
        if A.size(0) < config["BLOCK_SIZE_M"]:
            # Small batch: cap the number of row blocks that get launched.
            EM = min(
                sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"]
            )
    else:
        EM = num_tokens * config["BLOCK_SIZE_M"]

    # FUSE_SILU means B.size(1) contains both Gate and Up. N is halved.
    actual_N = B.size(1) // 2 if FUSE_SILU else B.size(1)
    grid = lambda META: (
        triton.cdiv(EM, META["BLOCK_SIZE_M"])
        * triton.cdiv(actual_N, META["BLOCK_SIZE_N"]),
    )
    HAS_BIAS = B_bias is not None

    config = config.copy()
    config["SPLIT_K"] = 1
    BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
    is_blockwise = (
        block_shape is not None and block_shape[0] > 0 and block_shape[1] > 0
    )
    if is_blockwise:
        # A K tile must not straddle two quantization blocks.
        BLOCK_SIZE_K = min(BLOCK_SIZE_K, block_shape[0], block_shape[1])

    swap_AB = config.pop("SWAP_AB", False)
    pair_gate_up_dot = config.pop("PAIR_GATE_UP_DOT", False)
    # Force disable SWAP_AB in fusion mode
    if FUSE_SILU:
        swap_AB = False

    fused_moe_kernel[grid](
        A,
        B,
        C,
        B_bias,
        A_scale,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.size(1),  # N
        B.size(2),  # K
        EM,
        num_tokens,
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
        A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
        B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
        B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
        B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
        B_bias.stride(0) if B_bias is not None else 0,
        B_bias.stride(1) if B_bias is not None else 0,
        0 if block_shape is None else block_shape[0],
        0 if block_shape is None else block_shape[1],
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        per_channel_quant=per_channel_quant,
        naive_block_assignment=(sorted_token_ids is None),
        HAS_BIAS=HAS_BIAS,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        SWAP_AB=swap_AB,
        K_DIVISIBLE_BY_BLOCK_K=(B.size(2) % BLOCK_SIZE_K == 0),
        N_DIVISIBLE_BY_BLOCK_N=(actual_N % config["BLOCK_SIZE_N"] == 0),
        PAIR_GATE_UP_DOT=pair_gate_up_dot,
        DIRECT_SUM=direct_sum,
        OUT_TOP_K=out_top_k,
        FUSE_SILU=FUSE_SILU,
        **config,
    )
def dispatch_fused_moe_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    B_zp: Optional[torch.Tensor],
    topk_weights: Optional[torch.Tensor],
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a8: bool,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    per_channel_quant: bool,
    block_shape: Optional[list[int]] = None,
    B_bias: Optional[torch.Tensor] = None,
    FUSE_SILU: bool = False,
    direct_sum: bool = False,
    out_top_k: int = 1,
) -> None:
    """Dispatch to the appropriate fused MoE kernel based on quantization flags.

    Group-quantized W8A16/W4A16 weights (``block_shape[1] > 0``) are routed to
    ``invoke_fused_moe_wna16_triton_kernel``; everything else goes to
    ``invoke_fused_moe_triton_kernel``.

    The WNA16 launcher takes no bias, SiLU-fusion or direct-sum arguments, so
    those options are rejected on that route rather than silently dropped.
    """
    assert topk_weights is not None or not mul_routed_weight
    assert topk_weights is None or topk_weights.stride(1) == 1
    assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1

    use_wna16_kernel = (
        (use_int8_w8a16 or use_int4_w4a16)
        and block_shape is not None
        and block_shape[1] > 0
    )
    if use_wna16_kernel:
        assert B_bias is None
        assert (
            not FUSE_SILU and not direct_sum
        ), "FUSE_SILU and direct_sum are not supported by the WNA16 kernel"
        invoke_fused_moe_wna16_triton_kernel(
            A,
            B,
            C,
            B_scale,
            B_zp,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            mul_routed_weight,
            top_k,
            config,
            compute_type,
            use_int8_w8a16,
            use_int4_w4a16,
            block_shape,
        )
    else:
        invoke_fused_moe_triton_kernel(
            A,
            B,
            C,
            A_scale,
            B_scale,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            mul_routed_weight,
            top_k,
            config,
            compute_type,
            use_fp8_w8a8,
            use_int8_w8a8,
            use_int8_w8a16,
            use_int4_w4a16,
            per_channel_quant,
            block_shape,
            B_bias,
            FUSE_SILU=FUSE_SILU,
            direct_sum=direct_sum,
            out_top_k=out_top_k,
        )
1810def fused_experts_impl(
1811 hidden_states: torch.Tensor,
1812 w1: torch.Tensor,
1813 w2: torch.Tensor,
1814 topk_weights: torch.Tensor,
1815 topk_ids: torch.Tensor,
1816 inplace: bool = False,
1817 activation: str = "silu",
1818 apply_router_weight_on_input: bool = False,
1819 use_fp8_w8a8: bool = False,
1820 use_int8_w8a8: bool = False,
1821 use_int8_w8a16: bool = False,
1822 use_int4_w4a16: bool = False,
1823 ocp_mx_scheme: str | None = None,
1824 per_channel_quant: bool = False,
1825 global_num_experts: int = -1,
1826 expert_map: torch.Tensor | None = None,
1827 w1_scale: Optional[torch.Tensor] = None,
1828 w2_scale: Optional[torch.Tensor] = None,
1829 w1_zp: torch.Tensor | None = None,
1830 w2_zp: torch.Tensor | None = None,
1831 a1_scale: Optional[torch.Tensor] = None,
1832 a2_scale: Optional[torch.Tensor] = None,
1833 block_shape: Optional[list[int]] = None,
1834 w1_bias: Optional[torch.Tensor] = None,
1835 w2_bias: Optional[torch.Tensor] = None,
1836) -> torch.Tensor:
1837 logger.debug("GEMS FUSED MOE")
1838 assert (
1839 activation == "silu"
1840 ), f"Only 'silu' activation is supported, got {activation}"
1842 activation_enum = MoEActivation.from_str(activation)
1844 # Check constraints
1845 if use_int4_w4a16:
1846 # INT4 stored unpacked in INT8 containers (full K dim)
1847 assert hidden_states.size(1) == w1.size(
1848 2
1849 ), f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"
1850 elif ocp_mx_scheme is not None:
1851 if ocp_mx_scheme.startswith("w_mxfp4"):
1852 assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch"
1853 elif ocp_mx_scheme.startswith("w_mxfp6"):
1854 assert (
1855 hidden_states.size(1) == (w1.size(2) * 4) // 3
1856 ), "hidden size mismatch"
1857 else:
1858 raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
1859 else:
1860 assert hidden_states.size(1) == w1.size(
1861 2
1862 ), f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"
1864 assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
1865 assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
1866 assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
1867 assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
1868 assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
1870 num_tokens = hidden_states.size(0)
1871 E, N, _ = w1.size()
1872 K = w2.size(1)
1873 if global_num_experts == -1:
1874 global_num_experts = E
1875 top_k_num = topk_ids.size(1)
1877 CHUNK_SIZE: int = 32 * 1024
1878 M = min(num_tokens, CHUNK_SIZE)
1880 config_dtype = _get_config_dtype_str(
1881 use_fp8_w8a8=use_fp8_w8a8,
1882 use_int8_w8a16=use_int8_w8a16,
1883 use_int4_w4a16=use_int4_w4a16,
1884 ocp_mx_scheme=ocp_mx_scheme,
1885 dtype=hidden_states.dtype,
1886 )
1887 is_plain_half_config = config_dtype in _PLAIN_HALF_CONFIG_DTYPES
1888 is_fp8_blockwise = config_dtype == "fp8_w8a8" and block_shape is not None
1890 quant_dtype = _get_config_quant_dtype(
1891 use_fp8_w8a8=use_fp8_w8a8,
1892 use_int8_w8a8=use_int8_w8a8,
1893 ocp_mx_scheme=ocp_mx_scheme,
1894 )
1896 get_moe_config = functools.partial(
1897 try_get_optimal_moe_config,
1898 w1.size(),
1899 w2.size(),
1900 top_k_num,
1901 config_dtype,
1902 block_shape=block_shape,
1903 E=E,
1904 return_is_embedded=True,
1905 )
1907 base_config, is_embedded_config = get_moe_config(M)
1909 # cache1 and cache3 share memory (non-overlapping lifetime)
1910 cache13 = torch.empty(
1911 M * top_k_num * max(N, K),
1912 device=hidden_states.device,
1913 dtype=hidden_states.dtype,
1914 )
1915 intermediate_cache1 = cache13[: M * top_k_num * N].view(M, top_k_num, N)
1916 intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)
1918 # cache2 needs separate memory (concurrent with cache1)
1919 activation_out_dim = MoEActivation.adjust_N_for_activation(N, activation_enum)
1920 intermediate_cache2 = torch.empty(
1921 (M * top_k_num, activation_out_dim),
1922 device=hidden_states.device,
1923 dtype=hidden_states.dtype,
1924 )
1926 if hidden_states.dtype == torch.bfloat16:
1927 compute_type = tl.bfloat16
1928 elif hidden_states.dtype == torch.float16:
1929 compute_type = tl.float16
1930 elif hidden_states.dtype == torch.float32:
1931 compute_type = tl.float32
1932 else:
1933 raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
1935 out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)
1937 if ocp_mx_scheme is not None:
1938 # Dequantize OCP MX weights (TODO: skip on platforms with native MX)
1939 if ocp_mx_scheme.startswith("w_mxfp4"):
1940 w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
1941 w1_scale = None
1942 w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
1943 w2_scale = None
1944 elif ocp_mx_scheme.startswith("w_mxfp6_e3m2"):
1945 w1 = dequant_mxfp6(
1946 w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
1947 )
1948 w1_scale = None
1949 w2 = dequant_mxfp6(
1950 w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
1951 )
1952 w2_scale = None
1953 elif ocp_mx_scheme.startswith("w_mxfp6_e2m3"):
1954 w1 = dequant_mxfp6(
1955 w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
1956 )
1957 w1_scale = None
1958 w2 = dequant_mxfp6(
1959 w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
1960 )
1961 w2_scale = None
1962 else:
1963 raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
1965 # Dequant INT8/INT4 weights (Triton can't do mixed-dtype dot)
1966 if use_int8_w8a16 or use_int4_w4a16:
1967 w1 = w1.to(hidden_states.dtype) * w1_scale.unsqueeze(-1).to(hidden_states.dtype)
1968 w1_scale = None
1969 w2 = w2.to(hidden_states.dtype) * w2_scale.unsqueeze(-1).to(hidden_states.dtype)
1970 w2_scale = None
1971 use_int8_w8a16 = False
1972 use_int4_w4a16 = False
1974 direct_sum_supported = is_plain_half_config or is_fp8_blockwise
1976 # Check if we can safely fuse the activation with the first GEMM pass
1977 can_use_fused_silu = (
1978 activation_enum in (MoEActivation.SILU, MoEActivation.SWIGLUOAI)
1979 and w1_bias is None
1980 and expert_map is None # Fused kernel doesn't handle EP -1 experts
1981 )
1983 for chunk in range((num_tokens // CHUNK_SIZE) + 1):
1984 begin_chunk_idx, end_chunk_idx = (
1985 chunk * CHUNK_SIZE,
1986 min((chunk + 1) * CHUNK_SIZE, num_tokens),
1987 )
1988 curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
1989 tokens_in_chunk, _ = curr_hidden_states.size()
1991 if tokens_in_chunk == 0:
1992 break
1994 if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
1995 # Adjust cache size for last chunk
1996 intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
1997 intermediate_cache2 = intermediate_cache2[
1998 : tokens_in_chunk * topk_ids.size(1)
1999 ]
2000 intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
2001 base_config, is_embedded_config = get_moe_config(tokens_in_chunk)
2003 curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
2004 curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
2005 qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
2006 A=curr_hidden_states,
2007 A_scale=a1_scale,
2008 quant_dtype=quant_dtype,
2009 per_act_token_quant=per_channel_quant,
2010 block_shape=block_shape,
2011 ocp_mx_scheme=ocp_mx_scheme,
2012 )
2014 SPARSITY_FACTOR = 4
2015 naive_block_assignment = (
2016 expert_map is None
2017 and tokens_in_chunk * top_k_num * SPARSITY_FACTOR <= global_num_experts
2018 and not (
2019 (use_int8_w8a16 or use_int4_w4a16)
2020 and block_shape is not None
2021 and block_shape[1] > 0
2022 )
2023 )
2025 if not naive_block_assignment:
2026 sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
2027 curr_topk_ids,
2028 base_config["BLOCK_SIZE_M"],
2029 global_num_experts,
2030 expert_map,
2031 # ignore_invalid_experts=True,
2032 )
2033 else:
2034 max_num_tokens_padded = topk_ids.numel() * base_config["BLOCK_SIZE_M"]
2035 expert_ids = curr_topk_ids.view(-1)
2036 num_tokens_post_padded = torch.empty(
2037 (1), dtype=torch.int32, device=topk_ids.device
2038 )
2039 num_tokens_post_padded.fill_(max_num_tokens_padded)
2040 sorted_token_ids = None
2042 # 1. Extract a unified boolean flag for GEMM1 fusion and select config
2043 do_fuse_silu = can_use_fused_silu and not naive_block_assignment
2044 use_half_gemm_fast_paths = not is_embedded_config and is_plain_half_config
2046 gemm1_config = base_config
2047 if do_fuse_silu and use_half_gemm_fast_paths:
2048 gemm1_config, _ = get_moe_config(
2049 tokens_in_chunk,
2050 gemm_stage="gemm1",
2051 enable_gemm_fast_path=True,
2052 )
2054 # 2. Dynamically determine the differing parameters based on the fusion flag
2055 if do_fuse_silu:
2056 # Output goes directly to cache 2 with adjusted dimensions
2057 out_cache = intermediate_cache2.view(
2058 tokens_in_chunk, top_k_num, activation_out_dim
2059 )
2060 # Fused kernel weight handling depends on apply_router_weight_on_input
2061 if apply_router_weight_on_input:
2062 weights_arg = curr_topk_weights
2063 else:
2064 weights_arg = None
2065 else:
2066 # Standard path outputs to cache 1
2067 out_cache = intermediate_cache1
2068 # Standard path always passes the weights
2069 weights_arg = curr_topk_weights
2071 # 3. Unified GEMM1 dispatch call to eliminate redundant code blocks
2072 dispatch_fused_moe_kernel(
2073 qcurr_hidden_states,
2074 w1,
2075 out_cache, # Dynamically assigned output buffer
2076 a1q_scale,
2077 w1_scale,
2078 w1_zp,
2079 weights_arg, # Dynamically assigned weights argument
2080 sorted_token_ids,
2081 expert_ids,
2082 num_tokens_post_padded,
2083 apply_router_weight_on_input,
2084 top_k_num,
2085 gemm1_config,
2086 compute_type=compute_type,
2087 use_fp8_w8a8=use_fp8_w8a8,
2088 use_int8_w8a8=use_int8_w8a8,
2089 use_int8_w8a16=use_int8_w8a16,
2090 use_int4_w4a16=use_int4_w4a16,
2091 per_channel_quant=per_channel_quant,
2092 block_shape=block_shape,
2093 B_bias=w1_bias,
2094 FUSE_SILU=do_fuse_silu, # Master switch for the kernel
2095 )
2097 # 4. Apply activation separately if the fused path was not taken
2098 if not do_fuse_silu:
2099 apply_moe_activation(
2100 activation_enum, intermediate_cache2, intermediate_cache1.view(-1, N)
2101 )
2103 # 5. Quantize activated intermediate for GEMM2
2104 qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
2105 A=intermediate_cache2,
2106 A_scale=a2_scale,
2107 quant_dtype=quant_dtype,
2108 per_act_token_quant=per_channel_quant,
2109 block_shape=block_shape,
2110 ocp_mx_scheme=ocp_mx_scheme,
2111 )
2113 if expert_map is not None:
2114 intermediate_cache3.zero_()
2116 # 6. Select GEMM2 config and output buffer/reduction path
2117 gemm2_config = base_config
2118 if use_half_gemm_fast_paths:
2119 gemm2_config, _ = get_moe_config(
2120 tokens_in_chunk,
2121 gemm_stage="gemm2",
2122 enable_gemm_fast_path=True,
2123 )
2124 use_direct_sum = (
2125 not is_embedded_config
2126 and direct_sum_supported
2127 and tokens_in_chunk >= MOE_DIRECT_SUM_MIN_TOKENS
2128 and expert_map is None
2129 and not apply_router_weight_on_input
2130 )
2131 if use_direct_sum:
2132 gemm2_output = out_hidden_states[begin_chunk_idx:end_chunk_idx].view(
2133 tokens_in_chunk, 1, K
2134 )
2135 gemm2_output.zero_()
2136 else:
2137 gemm2_output = intermediate_cache3
2139 # 7. Dispatch GEMM2
2140 dispatch_fused_moe_kernel(
2141 qintermediate_cache2,
2142 w2,
2143 gemm2_output,
2144 a2q_scale,
2145 w2_scale,
2146 w2_zp,
2147 curr_topk_weights,
2148 sorted_token_ids,
2149 expert_ids,
2150 num_tokens_post_padded,
2151 not apply_router_weight_on_input,
2152 1,
2153 gemm2_config,
2154 compute_type=compute_type,
2155 use_fp8_w8a8=use_fp8_w8a8,
2156 use_int8_w8a8=use_int8_w8a8,
2157 use_int8_w8a16=use_int8_w8a16,
2158 use_int4_w4a16=use_int4_w4a16,
2159 per_channel_quant=per_channel_quant,
2160 block_shape=block_shape,
2161 B_bias=w2_bias,
2162 FUSE_SILU=False,
2163 direct_sum=use_direct_sum,
2164 out_top_k=top_k_num,
2165 )
2167 # 8. Reduce GEMM2 top-k outputs unless direct_sum wrote final output directly
2168 if not use_direct_sum:
2169 moe_sum(
2170 intermediate_cache3.view(*intermediate_cache3.size()),
2171 out_hidden_states[begin_chunk_idx:end_chunk_idx],
2172 )
2174 return out_hidden_states
def inplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> None:
    """
    Run the fused MoE experts in in-place mode.

    Thin wrapper around ``fused_experts_impl``: the five tensor arguments are
    passed positionally, ``inplace=True`` is forced, and every remaining
    parameter is forwarded unchanged under its own name.

    The value returned by ``fused_experts_impl`` is discarded and this
    function returns None; with ``inplace=True`` the result is expected to be
    written into ``hidden_states``.
    """
    # Collect the pass-through options once so the dispatch below stays short.
    forwarded_options = dict(
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        per_channel_quant=per_channel_quant,
        global_num_experts=global_num_experts,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
    )
    # Return value intentionally ignored: this entry point always yields None.
    fused_experts_impl(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=True,
        **forwarded_options,
    )
def outplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Run the fused MoE experts in out-of-place mode.

    Thin wrapper around ``fused_experts_impl``: the five tensor arguments are
    passed positionally, ``inplace=False`` is forced, and every remaining
    parameter is forwarded unchanged under its own name.

    Returns whatever ``fused_experts_impl`` returns; with ``inplace=False``
    that is expected to be a newly allocated output tensor.
    """
    # Collect the pass-through options once so the dispatch below stays short.
    forwarded_options = dict(
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        per_channel_quant=per_channel_quant,
        global_num_experts=global_num_experts,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
    )
    result = fused_experts_impl(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=False,
        **forwarded_options,
    )
    return result