Coverage for src/flag_gems/fused/fused_moe.py: 42%
692 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1# SPDX-License-Identifier: Apache-2.0
2# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3#
4# Adapted from the vLLM project (https://github.com/vllm-project/vllm).
5# Source files under vllm/model_executor/layers/:
6# fused_moe/fused_moe.py – Triton kernels, dispatch, fused_experts_impl
7# fused_moe/activation.py – MoEActivation enum, apply_moe_activation
8# fused_moe/utils.py – _fp8_quantize, _int8_quantize, moe_kernel_quantize_input
9# fused_moe/config.py – _get_config_dtype_str
10# quantization/utils/mxfp4_utils.py – dequant_mxfp4
11# quantization/utils/mxfp6_utils.py – dequant_mxfp6
12# quantization/utils/ocp_mx_utils.py – OCP_MX_BLOCK_SIZE
15import functools
16import logging
17import os
18from enum import Enum
19from typing import Any, Optional
21import torch
22import torch.nn.functional as F
23import triton
24import triton.language as tl
25import yaml
27from flag_gems.fused.moe_align_block_size import moe_align_block_size
28from flag_gems.fused.moe_sum import moe_sum
29from flag_gems.runtime import device, torch_device_fn
30from flag_gems.utils import pointwise_dynamic
32logger = logging.getLogger(__name__)
34# OCP MX quantization helpers (requires amd-quark)
36OCP_MX_BLOCK_SIZE = 32
@functools.lru_cache(maxsize=1)
def get_embedded_moe_configs():
    """Load the embedded fused-MoE tuning table that ships next to this module.

    The table is read from ``../utils/configs/fused_moe_config.yaml`` (relative
    to this file). The top-level ``_FALLBACK`` entry is split off and returned
    separately; every other top-level entry is treated as a device name whose
    value maps a config key to an ``{M: config}`` table.

    Returns:
        A ``(parsed_data, fallback)`` tuple:

        * ``parsed_data[device][key][M]`` is a kernel config. The ``M`` keys
          are converted with ``int()``; list-valued configs are expanded into
          dicts using the fixed field order below, other values are kept as-is.
        * ``fallback`` is the raw ``_FALLBACK`` mapping (``{}`` when absent).

        Both are empty dicts when the file is missing or empty. The result is
        cached, so all callers share the same objects and must not mutate them.
    """
    config_path = os.path.join(
        os.path.dirname(__file__), "..", "utils", "configs", "fused_moe_config.yaml"
    )
    if not os.path.exists(config_path):
        return {}, {}
    with open(config_path, "r", encoding="utf-8") as f:
        # YAML mapping: device name -> config key -> {M: config}.
        data = yaml.safe_load(f)

    # An empty YAML document loads as None; treat it like a missing file
    # instead of failing on ``None.get`` below.
    if not data:
        return {}, {}

    # ``or {}`` also covers an explicit ``_FALLBACK: null`` entry.
    fallback = data.get("_FALLBACK") or {}

    # Positional field order used when a config is stored as a compact list.
    keys_order = [
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "num_warps",
        "num_stages",
    ]
    parsed_data = {}
    for dev, configs in data.items():
        if dev == "_FALLBACK":
            continue
        parsed_data[dev] = {
            k: {
                # Normalise M to int so numeric nearest-M lookups work even if
                # the file stores the key as a string.
                int(m): dict(zip(keys_order, v)) if isinstance(v, list) else v
                for m, v in m_dict.items()
            }
            for k, m_dict in configs.items()
        }

    return parsed_data, fallback
def dequant_mxfp4(
    x: torch.Tensor,
    scale: torch.Tensor,
    float_dtype: torch.dtype,
) -> torch.Tensor:
    """Dequantize an MXFP4 tensor by delegating to amd-quark.

    The work is forwarded unchanged to ``quark.torch.kernel.mx.dq_mxfp4``.

    Raises:
        ImportError: if the optional ``amd-quark`` package is not installed.
    """
    try:
        from quark.torch.kernel import mx as quark_mx
    except ImportError as exc:
        raise ImportError("amd-quark is required for MX-FP4") from exc

    dequantized = quark_mx.dq_mxfp4(x, scale, float_dtype)
    return dequantized
def dequant_mxfp6(
    x: torch.Tensor,
    scale: torch.Tensor,
    float_dtype: torch.dtype,
    quant_dtype: str,
) -> torch.Tensor:
    """Dequantize an MXFP6 tensor using amd-quark's hardware-emulation path.

    ``x`` is unpacked with quark's pack method for ``quant_dtype``; ``scale``
    is reinterpreted as uint8 exponents with a bias of 127 and turned into
    ``2 ** (e - 127)`` before the per-group dequantization (group size
    ``OCP_MX_BLOCK_SIZE`` along the last axis).

    Raises:
        ImportError: if the optional ``amd-quark`` package is not installed.
    """
    try:
        from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
            dequantize_fp4_fp6_per_group as _dequant_per_group,
        )
        from quark.torch.utils.pack import create_pack_method as _make_packer
    except ImportError as exc:
        raise ImportError("amd-quark is required for MX-FP6") from exc

    packer = _make_packer(None, dtype=quant_dtype)
    unpacked = packer.unpack(x, reorder=False)

    # Biased exponent byte -> signed exponent -> power-of-two scale.
    exponent = scale.view(torch.uint8).to(torch.int16) - 127
    decoded_scale = 2 ** exponent.to(float_dtype)

    result = _dequant_per_group(
        unpacked,
        decoded_scale,
        axis=-1,
        group_size=OCP_MX_BLOCK_SIZE,
        quant_dtype=quant_dtype,
    )
    return result.to(float_dtype)
122# Activation quantization helpers
@functools.lru_cache(maxsize=1)
def _get_device_name() -> str:
    """Return the device name used as a key into the embedded config table.

    The name comes from ``torch_device_fn.get_device_name()`` with spaces
    replaced by underscores, or from ``device.name`` when that call raises
    ``AttributeError``. Any name containing an ``H200`` component is
    collapsed to ``NVIDIA_H200``. If the resulting name has no entry in the
    embedded table but the table's fallback mapping points to a device that
    does, the fallback device name is returned instead. The result is cached.
    """
    try:
        resolved = torch_device_fn.get_device_name().replace(" ", "_")
    except AttributeError:
        resolved = device.name

    # Collapse the whole H200 product family onto one key, following vLLM.
    if "H200" in resolved.split("_"):
        resolved = "NVIDIA_H200"

    table, aliases = get_embedded_moe_configs()
    if resolved in table:
        return resolved

    # Devices with equivalent tuning profiles may be aliased to a known one.
    alias = aliases.get(resolved)
    if not alias or alias not in table:
        return resolved
    logger.info("Device %s not in config table, falling back to %s", resolved, alias)
    return alias
def get_moe_configs(
    E: int,
    N: int,
    dtype: str | None,
    block_n: int | None = None,
    block_k: int | None = None,
) -> dict[int, Any] | None:
    """Look up pre-tuned fused-MoE kernel configs for the current device.

    The lookup key is ``"E,N,dtype,block_n,block_k"`` where a missing or zero
    block size is written as ``0``.

    Returns:
        The ``{M: config}`` table stored under that key for the current
        device, or ``None`` (after logging a warning) when either the device
        or the key is absent from the embedded table.
    """
    device_name = _get_device_name()
    table_by_device, _ = get_embedded_moe_configs()
    device_table = table_by_device.get(device_name)
    if device_table is None:
        logger.warning(
            "No embedded MoE configs for device %s. Will use default config.",
            device_name,
        )
        return None

    _block_n = block_n or 0
    _block_k = block_k or 0
    key = f"{E},{N},{dtype},{_block_n},{_block_k}"

    found = device_table.get(key)
    if found is None:
        logger.warning(
            "No embedded MoE config for device=%s, key=%s. Will use default config.",
            device_name,
            key,
        )
        return None

    logger.info("Using embedded MoE config for device=%s, key=%s", device_name, key)
    return found
def try_get_optimal_moe_config(
    w1_shape: tuple[int, ...],
    w2_shape: tuple[int, ...],
    top_k: int,
    dtype: str | None,
    M: int,
    E: int,
    block_shape: list[int] | None = None,
) -> dict[str, int]:
    """Pick a kernel launch config for the fused MoE GEMM.

    Selection order:

    1. On a CUDA device with compute capability (9, 0), for ``fp8_w8a8`` with
       a two-element ``block_shape``: a heuristic config keyed on the average
       number of routed tokens per expert, with ``SWAP_AB`` enabled.
    2. Otherwise the embedded tuning table: the entry whose ``M`` is closest
       to the requested ``M``.
    3. Otherwise ``get_default_config``.
    """
    on_hopper = torch.cuda.is_available() and torch.cuda.get_device_capability() == (
        9,
        0,
    )
    use_heuristic = (
        on_hopper
        and dtype == "fp8_w8a8"
        and block_shape is not None
        and len(block_shape) == 2
    )
    if use_heuristic:
        # Heuristic config in the style of hpc-ops: BLOCK_SIZE_M follows the
        # average token count per expert, capped at 64.
        # NOTE(review): 48 is not a power of two — confirm the launched kernel
        # and the block-alignment step accept that tile height.
        avg_tokens_per_expert = M * top_k // E
        for limit in (16, 32, 48):
            if avg_tokens_per_expert <= limit:
                block_size_m = limit
                break
        else:
            block_size_m = 64
        return {
            "BLOCK_SIZE_M": block_size_m,
            "BLOCK_SIZE_N": block_shape[0],
            "BLOCK_SIZE_K": block_shape[1],
            "GROUP_SIZE_M": 1,
            "num_warps": 4,
            "num_stages": 3,
            "SWAP_AB": True,
        }

    # Table lookup: E and N are taken from w2's shape (this overrides the E
    # argument); int4 packs two values per element so N is doubled.
    E, _, N = w2_shape
    if dtype == "int4_w4a16":
        N = N * 2
    if block_shape:
        block_n, block_k = block_shape[0], block_shape[1]
    else:
        block_n = block_k = 0

    tuned = get_moe_configs(E, N, dtype, block_n, block_k)
    if not tuned:
        return get_default_config(M, E, N, w1_shape[2], top_k, dtype, block_shape)

    # Use the tuned entry whose batch size is nearest to M.
    nearest_m = min(tuned, key=lambda m: abs(m - M))
    return tuned[nearest_m]
249def _get_config_quant_dtype(
250 use_fp8_w8a8: bool,
251 use_int8_w8a8: bool,
252 ocp_mx_scheme: str | None,
253) -> None | torch.dtype | str:
254 """Map quantization flags to the corresponding dtype."""
255 if use_fp8_w8a8:
256 return torch.float8_e4m3fn
257 elif use_int8_w8a8:
258 return torch.int8
259 elif ocp_mx_scheme == "w_mxfp4_a_mxfp4":
260 return "mxfp4"
261 elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e3m2", "w_mxfp6_e3m2_a_mxfp6_e3m2"}:
262 return "mxfp6_e3m2"
263 elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e2m3", "w_mxfp6_e2m3_a_mxfp6_e2m3"}:
264 return "mxfp6_e2m3"
265 elif ocp_mx_scheme in {"w_mxfp4", "w_mxfp6_e3m2", "w_mxfp6_e2m3"}:
266 return torch.bfloat16
267 elif ocp_mx_scheme in {"w_mxfp4_a_fp8", "w_mxfp6_e3m2_a_fp8", "w_mxfp6_e2m3_a_fp8"}:
268 return torch.float8_e4m3fn
270 return None
def get_moe_wna16_block_config(
    config: dict[str, int],
    use_moe_wna16_cuda: bool,
    num_valid_tokens: int,
    size_k: int,
    size_n: int,
    num_experts: int,
    group_size: int,
    real_top_k: int,
    block_size_m: int,
):
    """Choose ``BLOCK_SIZE_N`` / ``BLOCK_SIZE_K`` for the WNA16 MoE kernels.

    Returns an empty dict when ``config`` already specifies both values.
    For the Triton path (``use_moe_wna16_cuda`` false) one of two fixed tile
    shapes is returned depending on whether there is a single token. For the
    CUDA path the tile sizes are grown from 128x128 based on an estimate of
    the total number of blocks, and the K tile is finally adjusted by
    ``_ensure_block_size_k_divisible``.
    """
    if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config:
        return {}

    if not use_moe_wna16_cuda:
        single_token = num_valid_tokens // real_top_k == 1
        if single_token:
            return {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64}
        return {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}

    n_tile = 128
    # The K tile is at least one quantization group wide.
    k_tile = max(128, group_size)

    # NOTE(review): kept exactly as upstream — the "n" count is derived from
    # size_k and both counts divide by the K tile. Confirm before changing.
    n_blocks = size_k // k_tile
    k_blocks = size_n // k_tile
    # True division: the block estimates below are floats from here on.
    m_blocks = (num_valid_tokens + block_size_m - 1) / block_size_m + num_experts
    if num_valid_tokens // real_top_k <= block_size_m:
        m_blocks = min(m_blocks, num_valid_tokens)
    total_blocks = m_blocks * n_blocks * k_blocks

    if size_k % 256 == 0 and total_blocks >= 256 and k_tile < 256:
        k_tile = 256
        # NOTE(review): k_tile is already 256 here, so the divisor is 1 and
        # this only floors the estimate; preserved as in upstream.
        total_blocks = total_blocks // (256 // k_tile)

    can_double_k = (
        m_blocks <= 16
        and size_k % (k_tile * 2) == 0
        and k_tile <= 512
        and total_blocks >= 512
    )
    if can_double_k:
        k_tile *= 2
        total_blocks = total_blocks // 2

    if total_blocks > 1024:
        n_tile = 256
        total_blocks = total_blocks // 2

    if size_n <= 1024 and total_blocks >= 1024:
        n_tile = 1024

    k_tile = _ensure_block_size_k_divisible(size_k, k_tile, group_size)

    return {"BLOCK_SIZE_N": n_tile, "BLOCK_SIZE_K": k_tile}
333def get_default_config(
334 M: int,
335 E: int,
336 N: int,
337 K: int,
338 topk: int,
339 dtype: str | None,
340 block_shape: list[int] | None = None,
341) -> dict[str, int]:
342 """Default Triton config for fused MoE kernel.
344 Heuristic selection aligned with vLLM v0.17.0 defaults, tuned on H20/H100.
345 Key insight: for high-expert-count MoE (e.g. DeepSeek-V3 E=256), each
346 expert sees very few tokens, so small BLOCK_SIZE_M (16) is critical.
347 """
348 if dtype == "fp8_w8a8" and block_shape is not None:
349 config = {
350 "BLOCK_SIZE_M": 16 if M <= 64 else 64,
351 "BLOCK_SIZE_N": block_shape[0],
352 "BLOCK_SIZE_K": block_shape[1],
353 "GROUP_SIZE_M": 1 if M <= 16 else 32,
354 "num_warps": 4,
355 "num_stages": 3,
356 }
357 else:
358 # tokens_per_expert drives block_m: use M//E (not M*topk//E) to
359 # estimate the actual per-expert token count after routing.
360 tokens_per_expert = M // max(E, 1)
362 if tokens_per_expert <= 2:
363 block_m = 16
364 elif tokens_per_expert <= 4:
365 block_m = 32
366 elif tokens_per_expert <= 16:
367 block_m = 64
368 else:
369 block_m = 128
371 # Tile sizing
372 if N >= 4096:
373 block_n = 128 if M <= 128 else 256
374 elif N >= 1024:
375 block_n = 64 if M <= 64 else 128
376 else:
377 block_n = 64 if M <= 64 else 128
379 if dtype == "fp8_w8a8":
380 block_k = 128
381 elif M <= 64:
382 block_k = 128
383 else:
384 block_k = 64
386 if tokens_per_expert > 128:
387 group_m = 16
388 elif tokens_per_expert > 32:
389 group_m = 8
390 else:
391 group_m = 1
393 # Prefer 4 warps for small tiles; only use 8 for large M
394 num_warps = 4 if M <= 128 else 8
395 num_stages = 3
397 smem_per_stage = (block_m * block_k + block_k * block_n) * 2
398 while num_stages > 2 and smem_per_stage * num_stages > 200_000:
399 num_stages -= 1
401 config = {
402 "BLOCK_SIZE_M": block_m,
403 "BLOCK_SIZE_N": block_n,
404 "BLOCK_SIZE_K": block_k,
405 "GROUP_SIZE_M": group_m,
406 "num_warps": num_warps,
407 "num_stages": num_stages,
408 }
409 return config
412def _get_config_dtype_str(
413 dtype: Optional[torch.dtype] = None,
414 use_fp8_w8a8: bool = False,
415 use_fp8_w8a16: bool = False,
416 use_int8_w8a16: bool = False,
417 use_int4_w4a16: bool = False,
418 ocp_mx_scheme: str | None = None,
419) -> str | None:
420 """Return dtype string for kernel config lookup."""
421 if use_fp8_w8a8:
422 return "fp8_w8a8"
423 elif use_fp8_w8a16:
424 return "fp8_w8a16"
425 elif use_int8_w8a16:
426 return "int8_w8a16"
427 elif use_int4_w4a16:
428 return "int4_w4a16"
429 elif ocp_mx_scheme is not None:
430 return None
431 elif dtype == torch.float:
432 return "float32"
433 return None
436# MoE activation enum
class MoEActivation(Enum):
    """Activation functions for MoE layers."""

    # Gated variants consume [..., 2*d] (gate and up halves) and emit [..., d].
    SILU = "silu"
    GELU = "gelu"
    RELU2 = "relu2"
    SWIGLUOAI = "swigluoai"
    SWIGLUSTEP = "swiglustep"

    # Non-gated variants map [..., d] to [..., d].
    SILU_NO_MUL = "silu_no_mul"
    GELU_NO_MUL = "gelu_no_mul"
    RELU2_NO_MUL = "relu2_no_mul"

    @property
    def is_gated(self) -> bool:
        """Whether this activation multiplies by a second (up) half."""
        has_no_mul_suffix = self.value.endswith("_no_mul")
        return not has_no_mul_suffix

    def without_mul(self) -> "MoEActivation":
        """Return the non-gated counterpart, or ``self`` if there is none."""
        # A counterpart exists exactly when a member named <NAME>_NO_MUL does.
        return type(self).__members__.get(f"{self.name}_NO_MUL", self)

    @classmethod
    def from_str(cls, s: str) -> "MoEActivation":
        """Return the member whose value equals ``s``.

        Raises:
            ValueError: if no member has that value.
        """
        match = next((member for member in cls if member.value == s), None)
        if match is not None:
            return match
        valid = [m.value for m in cls]
        raise ValueError(f"Unknown MoE activation: {s!r}. Valid activations: {valid}")

    @staticmethod
    def adjust_N_for_activation(N: int, activation: "MoEActivation") -> int:
        """Return ``N // 2`` for gated activations and ``N`` otherwise."""
        return N // 2 if activation.is_gated else N
def apply_moe_activation(
    activation: MoEActivation,
    output: torch.Tensor,
    input: torch.Tensor,
) -> torch.Tensor:
    """Apply an MoE activation, writing the result into ``output``.

    Both tensors must be 2D. Gated activations read the gate from the first
    half of ``input``'s last dimension and the up projection from the second
    half, so ``input`` must be twice as wide as ``output``; non-gated
    activations require equal widths.

    Note that ``RELU2_NO_MUL`` applies ReLU to ``input`` in place before
    squaring into ``output``, so the caller's ``input`` is modified.

    Returns:
        ``output``.

    Raises:
        ValueError: if ``activation`` is not handled here.
    """
    assert input.dim() == 2, "Input must be 2D"
    assert output.dim() == 2, "Output must be 2D"
    gated = activation.is_gated
    if gated:
        assert output.size(-1) * 2 == input.size(-1), (
            f"{activation.value} expects 2x ratio: "
            f"{output.size(-1) * 2} vs {input.size(-1)}"
        )
    else:
        assert output.size(-1) == input.size(-1), (
            f"{activation.value} expects equal sizes: "
            f"{output.size(-1)} vs {input.size(-1)}"
        )

    if gated:
        width = output.size(-1)
        gate = input[:, :width]
        up = input[:, width:]
        if activation in (MoEActivation.SILU, MoEActivation.SWIGLUOAI):
            # Both variants are routed to the fused silu-and-mul kernel here.
            _silu_and_mul_kernel(gate, up, out0=output)
        elif activation == MoEActivation.GELU:
            output.copy_(F.gelu(gate) * up)
        elif activation == MoEActivation.SWIGLUSTEP:
            output.copy_(torch.sigmoid(gate) * up)
        elif activation == MoEActivation.RELU2:
            output.copy_(F.relu(gate).square() * up)
        else:
            raise ValueError(f"Unsupported FusedMoe activation: {activation}")
    elif activation == MoEActivation.SILU_NO_MUL:
        output.copy_(F.silu(input))
    elif activation == MoEActivation.GELU_NO_MUL:
        output.copy_(F.gelu(input))
    elif activation == MoEActivation.RELU2_NO_MUL:
        # In-place ReLU avoids a temporary but overwrites the caller's input.
        F.relu(input, inplace=True)
        torch.square(input, out=output)
    else:
        raise ValueError(f"Unsupported FusedMoe activation: {activation}")

    return output
530def _fp8_quantize(
531 A: torch.Tensor,
532 A_scale: Optional[torch.Tensor],
533 per_act_token: bool,
534 block_shape: Optional[list[int]] = None,
535) -> tuple[torch.Tensor, torch.Tensor]:
536 """FP8 E4M3 quantization: per-tensor, per-token, or block-wise."""
537 fp8_dtype = torch.float8_e4m3fn
538 finfo = torch.finfo(fp8_dtype)
539 fp8_max = finfo.max
540 fp8_min = finfo.min
541 eps = 1e-10
543 if block_shape is not None:
544 assert not per_act_token
545 assert len(block_shape) == 2
546 block_k = block_shape[1]
547 assert A.size(-1) % block_k == 0
548 if A.ndim == 2 and A.stride(-1) == 1:
549 from flag_gems.ops.per_token_group_quant_fp8 import (
550 per_token_group_quant_fp8,
551 )
553 return per_token_group_quant_fp8(
554 A,
555 group_size=block_k,
556 eps=eps,
557 dtype=fp8_dtype,
558 column_major_scales=False,
559 scale_ue8m0=False,
560 )
561 orig_shape = A.shape
562 A_flat = A.reshape(-1, A.size(-1))
563 M, K = A_flat.shape
564 A_groups = A_flat.reshape(M * (K // block_k), block_k)
565 amax = (
566 A_groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32)
567 )
568 scale = amax / fp8_max
569 A_q = (A_groups.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype)
570 A_q = A_q.reshape(orig_shape)
571 scale = scale.reshape(M, K // block_k)
572 return A_q, scale
574 elif per_act_token:
575 A_flat = A.reshape(-1, A.size(-1))
576 amax = A_flat.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32)
577 scale = amax / fp8_max
578 min_scale = torch.tensor(
579 1.0 / (fp8_max * 512.0), dtype=torch.float32, device=A.device
580 )
581 scale = scale.clamp(min=min_scale)
582 A_q = (A_flat.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype)
583 A_q = A_q.reshape(A.shape)
584 scale = scale.reshape(A.shape[:-1] + (1,))
585 return A_q, scale
587 else:
588 if A_scale is not None:
589 scale = (
590 A_scale.float().view(1, 1) if A_scale.numel() == 1 else A_scale.float()
591 )
592 A_q = (A.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype)
593 return A_q, A_scale
594 else:
595 amax = A.abs().amax().clamp(min=eps).to(torch.float32)
596 scale = amax / fp8_max
597 iscale = 1.0 / scale
598 A_q = (A.float() * iscale).clamp(fp8_min, fp8_max).to(fp8_dtype)
599 return A_q, scale.view(1)
602def _int8_quantize(
603 A: torch.Tensor,
604 A_scale: Optional[torch.Tensor],
605 per_act_token: bool,
606 block_shape: Optional[list[int]] = None,
607) -> tuple[torch.Tensor, torch.Tensor]:
608 """INT8 quantization: per-tensor, per-token, or block-wise."""
609 iinfo = torch.iinfo(torch.int8)
610 int8_max = iinfo.max
611 int8_min = iinfo.min
612 eps = 1e-10
614 if block_shape is not None:
615 assert not per_act_token
616 assert len(block_shape) == 2
617 block_k = block_shape[1]
618 assert A.size(-1) % block_k == 0
619 orig_shape = A.shape
620 A_flat = A.reshape(-1, A.size(-1))
621 M, K = A_flat.shape
622 A_groups = A_flat.reshape(M * (K // block_k), block_k)
623 amax = (
624 A_groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32)
625 )
626 scale = amax / int8_max
627 A_q = (
628 (A_groups.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8)
629 )
630 A_q = A_q.reshape(orig_shape)
631 scale = scale.reshape(M, K // block_k)
632 return A_q, scale
634 elif per_act_token:
635 A_flat = A.reshape(-1, A.size(-1))
636 amax = A_flat.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32)
637 scale = amax / int8_max
638 A_q = (A_flat.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8)
639 A_q = A_q.reshape(A.shape)
640 scale = scale.reshape(A.shape[:-1] + (1,))
641 return A_q, scale
643 else:
644 assert A_scale is not None, "int8 per-tensor requires A_scale"
645 scale = A_scale.float().view(1, 1) if A_scale.numel() == 1 else A_scale.float()
646 A_q = (A.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8)
647 return A_q, A_scale
def moe_kernel_quantize_input(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    quant_dtype: None | torch.dtype | str,
    per_act_token_quant: bool,
    block_shape: Optional[list[int]] = None,
    ocp_mx_scheme: str | None = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Quantize MoE input activations before GEMM.

    Args:
        A: Activations to quantize.
        A_scale: Optional pre-computed activation scale, forwarded to the
            fp8/int8 quantizers or returned unchanged.
        quant_dtype: Target activation dtype. ``torch.float8_e4m3fn`` and
            ``torch.int8`` trigger quantization; ``None`` and any other value
            return the inputs untouched.
        per_act_token_quant: Forwarded to the fp8/int8 quantizers.
        block_shape: Forwarded to the fp8/int8 quantizers.
        ocp_mx_scheme: When set, the OCP MX handling below takes precedence
            over ``quant_dtype``.

    Returns:
        ``(A, scale)``; ``scale`` is ``None`` on the OCP MX path.
    """
    if ocp_mx_scheme is not None:
        if ocp_mx_scheme in {"w_mxfp4", "w_mxfp4_a_mxfp4"}:
            # These schemes leave the activations untouched here.
            pass
        elif ocp_mx_scheme.endswith("a_fp8"):
            # Quantize to fp8 per-tensor and immediately dequantize, so A
            # keeps its dtype but carries fp8 rounding.
            qA, qA_scale = _fp8_quantize(A, A_scale, per_act_token=False)
            A = (qA.float() * qA_scale.float()).to(A.dtype)
        # NOTE(review): the nesting of this return is reconstructed from
        # upstream vLLM (it applies to every OCP MX scheme) — confirm against
        # the original file.
        return A, None

    if quant_dtype is None:
        return A, A_scale
    elif quant_dtype == torch.float8_e4m3fn:
        return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
    elif quant_dtype == torch.int8:
        return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
    else:
        # Any other quant_dtype (e.g. the MX string tags) is a pass-through.
        return A, A_scale
677def _ensure_block_size_k_divisible(
678 size_k: int, block_size_k: int, group_size: int
679) -> int:
680 """Find largest block_size_k that divides size_k and is divisible by group_size."""
681 if size_k % block_size_k == 0 and block_size_k % group_size == 0:
682 return block_size_k
684 max_search = min(block_size_k, size_k)
685 start = (max_search // group_size) * group_size
686 for candidate in range(start, group_size - 1, -group_size):
687 if size_k % candidate == 0:
688 return candidate
690 if size_k % group_size == 0:
691 return group_size
693 return size_k
# Element-wise silu(x) * y, used for the gated SiLU activation.
# Called in this file as `_silu_and_mul_kernel(x, y, out0=output)`.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def _silu_and_mul_kernel(x, y):
    # SiLU is evaluated in float32: x / (1 + exp(-x)).
    x_fp32 = x.to(tl.float32)
    x_silu = tl.fdiv(x_fp32, (1.0 + tl.exp(-x_fp32)))
    return x_silu * y
@triton.jit
def write_zeros_to_output(
    c_ptr,
    stride_cm,
    stride_cn,
    pid_n,
    N,
    offs_token,
    token_mask,
    BLOCK_SIZE_M,
    BLOCK_SIZE_N,
    compute_type,
):
    """Store a zero (BLOCK_SIZE_M, BLOCK_SIZE_N) tile into C.

    Rows are addressed by ``offs_token`` and columns by the ``pid_n``-th
    N block. Only rows enabled in ``token_mask`` and columns below ``N`` are
    written. The MoE kernels in this file call this when a block's expert id
    is -1.
    """
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    # Mask out padded rows and columns past the end of N.
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
@triton.jit
def fused_moe_kernel_gptq_awq(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    b_scale_ptr,
    b_zp_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N: tl.constexpr,
    K: tl.constexpr,
    EM,
    num_valid_tokens,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bse,
    stride_bsk,
    stride_bsn,
    stride_bze,
    stride_bzk,
    stride_bzn,
    block_k_diviable: tl.constexpr,
    group_size: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    has_zp: tl.constexpr,
    use_int4_w4a16: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
):
    """Fused MoE kernel for GPTQ/AWQ (WNA16) quantized weights.

    Each program computes one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C for the
    expert assigned to its M block. Weights are dequantized on the fly as
    ``(b - zero_point) * scale`` with one scale / zero point per
    ``group_size`` elements along K; for int4, two values are packed per
    stored element. ``SPLIT_K`` is accepted but not referenced in the body.
    """
    # Map pid to C block (grouped ordering for L2 reuse)
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Create pointers for first blocks of A and B
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    # Blocks that start past the padded token count have nothing to do.
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    # Cast to int64 to prevent overflow in stride*offset products
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id).to(tl.int64)
    # Entries >= num_valid_tokens are masked out of all loads and stores.
    token_mask = offs_token < num_valid_tokens

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        # -----------------------------------------------------------
        # Write back zeros to the output when the expert is not
        # in the current expert parallel rank.
        write_zeros_to_output(
            c_ptr,
            stride_cm,
            stride_cn,
            pid_n,
            N,
            offs_token,
            token_mask,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            compute_type,
        )
        return

    # Column offsets wrap modulo N so B loads stay in range; the final store
    # is masked with the unwrapped offsets instead.
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # Rows of A are indexed by offs_token // top_k.
    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )

    if use_int4_w4a16:
        # Two 4-bit values share one stored element along K: address with
        # k // 2 and select the nibble with a shift of (k % 2) * 4.
        b_ptrs = (
            b_ptr
            + off_experts * stride_be
            + (offs_k[:, None] // 2) * stride_bk
            + offs_bn[None, :] * stride_bn
        )
        b_shifter = (offs_k[:, None] % 2) * 4
    elif use_int8_w8a16:
        b_ptrs = (
            b_ptr
            + off_experts * stride_be
            + offs_k[:, None] * stride_bk
            + offs_bn[None, :] * stride_bn
        )

    # Without explicit zero points a fixed one is used: 8 for int4, 128 for
    # int8. With int4 zero points, they are packed two per element along N.
    if not has_zp and use_int4_w4a16:
        b_zp_num = 8
    if not has_zp and use_int8_w8a16:
        b_zp_num = 128
    elif has_zp and use_int4_w4a16:
        b_zp_shifter = (offs_bn[None, :] % 2) * 4

    # Accumulate C block in fp32
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Scale / zero-point loads only need a K mask when K is not a multiple
        # of BLOCK_SIZE_K.
        if not block_k_diviable:
            k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
            k_other = 0.0
        else:
            k_mask = None
            k_other = None

        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        # NOTE(review): B is loaded without a K mask; out-of-range rows are
        # multiplied by zeros from A — confirm the read itself is in bounds.
        b = tl.load(b_ptrs)
        if use_int4_w4a16:
            b = (b >> b_shifter) & 0xF

        # One scale per group_size elements along K.
        b_scale_ptrs = (
            b_scale_ptr
            + off_experts * stride_bse
            + offs_bn[None, :] * stride_bsn
            + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
        )
        b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
        b_scale = b_scale.to(tl.float32)

        if has_zp and use_int4_w4a16:
            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
            b_zp_ptrs = (
                b_zp_ptr
                + off_experts * stride_bze
                + (offs_bn[None, :] // 2) * stride_bzn
                + offs_k_true * stride_bzk
            )
            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
            b_zp = (b_zp >> b_zp_shifter) & 0xF
            b_zp = b_zp.to(tl.float32)
        elif has_zp and use_int8_w8a16:
            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
            b_zp_ptrs = (
                b_zp_ptr
                + off_experts * stride_bze
                + offs_bn[None, :] * stride_bzn
                + offs_k_true * stride_bzk
            )
            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
            b_zp = b_zp.to(tl.float32)

        # Dequantize in fp32, then cast to the compute type for the dot.
        if has_zp:
            b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
        else:
            b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
        accumulator = tl.dot(a, b, acc=accumulator)

        # Advance to the next K block; packed int4 moves half as many elements.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        if use_int4_w4a16:
            b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
        else:
            b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]

    accumulator = accumulator.to(compute_type)
    # Write back output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
@triton.jit
def fused_moe_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    b_bias_ptr,
    a_scale_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_asm,
    stride_ask,
    stride_bse,
    stride_bsk,
    stride_bsn,
    stride_bbe,  # bias expert stride
    stride_bbn,  # bias N stride
    # Block size for block-wise quantization
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    naive_block_assignment: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a8: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
    per_channel_quant: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    SWAP_AB: tl.constexpr,
):
    """Fused MoE kernel: token × expert GEMM with quantization support.

    Each program computes one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C for the
    expert assigned to its M block. Scales are applied per K block for
    block-wise quantization (``group_k > 0 and group_n > 0``) and once after
    the K loop for the per-channel, per-tensor and int8_w8a16 modes. Bias and
    routing weights are applied in fp32 before the cast to ``compute_type``.
    ``SPLIT_K`` is accepted but not referenced in the body.
    """
    # Map pid to C block (grouped ordering for L2 reuse)
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Create pointers for first blocks of A and B
    offs = tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    # Blocks that start past the padded token count have nothing to do.
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    if not naive_block_assignment:
        offs_token_id = pid_m * BLOCK_SIZE_M + offs
        offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    else:
        # One token per block: lane 0 handles token pid_m, every other lane
        # gets num_valid_tokens and is therefore rejected by token_mask below.
        offs_token = tl.where(
            offs == 0,
            pid_m,  # first element = pid_m
            num_valid_tokens,  # remaining elements = constant
        )
    offs_token = offs_token.to(tl.int64)  # prevent int32 overflow

    token_mask = offs_token < num_valid_tokens

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        # Expert not in current EP rank, write zeros
        write_zeros_to_output(
            c_ptr,
            stride_cm,
            stride_cn,
            pid_n,
            N,
            offs_token,
            token_mask,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            compute_type,
        )
        return

    # Column offsets wrap modulo N so B loads stay in range; the final store
    # is masked with the unwrapped offsets instead.
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # Rows of A are indexed by offs_token // top_k.
    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )

    b_ptrs = (
        b_ptr
        + off_experts * stride_be
        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    )
    if use_int8_w8a16:
        # Weight-only int8: one scale per output column, applied after the loop.
        b_scale_ptrs = (
            b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
        )
        b_scale = tl.load(b_scale_ptrs)

    if use_fp8_w8a8 or use_int8_w8a8:
        if group_k > 0 and group_n > 0:  # block-wise
            # Only base pointers here; the per-K-block scales are loaded
            # inside the loop.
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            offs_bsn = offs_bn // group_n
            b_scale_ptrs = (
                b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
            )
        elif per_channel_quant:  # channel-wise
            b_scale_ptrs = (
                b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
            )
            b_scale = tl.load(b_scale_ptrs)
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
        else:  # tensor-wise
            a_scale = tl.load(a_scale_ptr)
            b_scale = tl.load(b_scale_ptr + off_experts)
    if HAS_BIAS:
        bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn
        bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0)
    # Accumulate C block in fp32
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    if SWAP_AB:
        # Transposed (N, M) accumulator for the B^T @ A^T formulation.
        accumulator_nm = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_M), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        if use_int8_w8a16:
            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
        elif use_fp8_w8a8 or use_int8_w8a8:
            if group_k > 0 and group_n > 0:
                # Block-wise: scale each partial product with the scales of
                # the K group this block starts in.
                k_start = k * BLOCK_SIZE_K
                offs_ks = k_start // group_k
                a_scale = tl.load(
                    a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
                )
                if SWAP_AB:
                    b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
                    accumulator_nm += (
                        tl.dot(tl.trans(b), tl.trans(a))
                        * b_scale[:, None]
                        * a_scale[None, :]
                    )
                else:
                    b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
                    accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
            else:
                if use_fp8_w8a8:
                    accumulator = tl.dot(a, b, acc=accumulator)
                else:
                    accumulator += tl.dot(a, b)
        else:
            accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # NOTE(review): accumulator_nm is only accumulated in the block-wise
    # branch above, yet it replaces `accumulator` whenever SWAP_AB is set —
    # confirm SWAP_AB is only enabled together with block-wise quantization.
    if SWAP_AB:
        accumulator = tl.trans(accumulator_nm)

    # Dequantization
    if use_int8_w8a16:
        accumulator = accumulator * b_scale
    elif (use_fp8_w8a8 or use_int8_w8a8) and not (group_k > 0 and group_n > 0):
        accumulator = accumulator * a_scale * b_scale

    if HAS_BIAS:
        accumulator += bias[None, :]

    # Router weight multiplication (must be in fp32)
    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(
            topk_weights_ptr + offs_token,
            mask=token_mask,
            other=0,
        )
        accumulator *= moe_weight[:, None]

    accumulator = accumulator.to(compute_type)

    # Write back output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
def invoke_fused_moe_wna16_triton_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    B_scale: torch.Tensor | None,
    B_zp: torch.Tensor | None,
    topk_weights: torch.Tensor | None,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    block_shape: list[int] | None,
):
    """Launch ``fused_moe_kernel_gptq_awq`` for weight-only quantized experts.

    Args:
        A: Activations; ``A.size(0)`` is the token count and ``A.size(1)`` is
            passed on as ``size_k``.
        B: Expert weights; ``B.size(1)`` is passed on as ``size_n``.
            ``B.size(0)`` is taken to be the number of experts
            (assumes an (experts, N, K) layout -- TODO confirm).
        C: Output buffer; only its strides along dims 1 and 2 are read here.
        B_scale: Weight scales, required and 3-D.
        B_zp: Optional 3-D weight zero points; when ``None`` the kernel is
            launched with ``has_zp=False`` and zero strides.
        topk_weights: Router weights forwarded to the kernel.
        sorted_token_ids: Token ids grouped into blocks; its length bounds the
            number of M-blocks launched.
        expert_ids: Per-block expert ids forwarded to the kernel.
        num_tokens_post_padded: Device scalar forwarded to the kernel.
        mul_routed_weight: Forwarded as ``MUL_ROUTED_WEIGHT``.
        top_k: Number of experts selected per token.
        config: Launch configuration; must contain ``BLOCK_SIZE_M``. It is
            copied before being extended, so the caller's dict is untouched.
        compute_type: Triton dtype forwarded to the kernel.
        use_int8_w8a16: Forwarded to the kernel.
        use_int4_w4a16: Forwarded to the kernel.
        block_shape: Required; ``block_shape[0]`` must be 0 and
            ``block_shape[1]`` is used as the quantization group size.
    """
    assert B_scale is not None and B_scale.ndim == 3
    assert B_zp is None or B_zp.ndim == 3
    assert block_shape is not None and block_shape[0] == 0

    M = A.size(0)
    num_tokens = M * top_k

    EM = sorted_token_ids.size(0)
    if M < config["BLOCK_SIZE_M"]:
        # optimize for small batch_size.
        # We assume that top_ids of each token is unique,
        # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
        # and we can skip some invalid blocks.
        EM = min(EM, M * top_k * config["BLOCK_SIZE_M"])

    # The grid reads the block sizes from META, i.e. from the final launch
    # configuration, so defining it before ``config`` is extended is fine.
    grid = lambda META: (
        triton.cdiv(EM, META["BLOCK_SIZE_M"])
        * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
    )

    config = config.copy()
    config.update(
        get_moe_wna16_block_config(
            config=config,
            use_moe_wna16_cuda=False,
            num_valid_tokens=num_tokens,
            size_k=A.size(1),
            size_n=B.size(1),
            # Expert count is the leading dim of B; B.size(1) is N and is
            # already reported as ``size_n`` above.
            num_experts=B.size(0),
            group_size=block_shape[1],
            real_top_k=top_k,
            block_size_m=config["BLOCK_SIZE_M"],
        )
    )

    fused_moe_kernel_gptq_awq[grid](
        A,
        B,
        C,
        B_scale,
        B_zp,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.size(1),
        A.size(1),
        EM,
        num_tokens,
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        B_scale.stride(0),
        B_scale.stride(2),
        B_scale.stride(1),
        B_zp.stride(0) if B_zp is not None else 0,
        B_zp.stride(2) if B_zp is not None else 0,
        B_zp.stride(1) if B_zp is not None else 0,
        block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
        group_size=block_shape[1],
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        has_zp=B_zp is not None,
        use_int4_w4a16=use_int4_w4a16,
        use_int8_w8a16=use_int8_w8a16,
        **config,
    )
def invoke_fused_moe_triton_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    topk_weights: Optional[torch.Tensor],
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    block_shape: Optional[list[int]] = None,
    B_bias: torch.Tensor | None = None,
) -> None:
    """Validate the operands and launch the ``fused_moe_kernel`` Triton kernel.

    ``sorted_token_ids`` may be ``None``; the kernel is then launched with
    ``naive_block_assignment=True`` and one M-block per (token, expert) pair.
    ``config`` is copied before ``SPLIT_K`` is forced to 1 and
    ``BLOCK_SIZE_K`` / ``SWAP_AB`` are extracted, so the caller's dict is not
    modified. When ``block_shape`` is given, ``BLOCK_SIZE_K`` is capped by
    both of its entries.
    """
    assert topk_weights is not None or not mul_routed_weight
    assert topk_weights is None or topk_weights.stride(1) == 1
    assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1

    if use_fp8_w8a8 or use_int8_w8a8:
        assert B_scale is not None
        if block_shape is not None:
            # Block-wise scales must tile the last two weight dimensions.
            assert triton.cdiv(B.size(-2), block_shape[0]) == B_scale.size(-2)
            assert triton.cdiv(B.size(-1), block_shape[1]) == B_scale.size(-1)
    elif use_int8_w8a16 or use_int4_w4a16:
        assert B_scale is not None
        assert block_shape is None or block_shape[0] == 0
    else:
        assert A_scale is None
        assert B_scale is None

    M = A.size(0)
    num_tokens = M * top_k
    if sorted_token_ids is None:
        EM = num_tokens * config["BLOCK_SIZE_M"]
    else:
        EM = sorted_token_ids.size(0)
        if M < config["BLOCK_SIZE_M"]:
            EM = min(EM, M * top_k * config["BLOCK_SIZE_M"])

    def grid(META):
        num_m_blocks = triton.cdiv(EM, META["BLOCK_SIZE_M"])
        num_n_blocks = triton.cdiv(B.size(1), META["BLOCK_SIZE_N"])
        return (num_m_blocks * num_n_blocks,)

    has_bias = B_bias is not None

    launch_config = config.copy()
    launch_config["SPLIT_K"] = 1
    block_size_k = launch_config.pop("BLOCK_SIZE_K")
    if block_shape is not None:
        block_size_k = min(block_size_k, block_shape[0], block_shape[1])
    swap_ab = launch_config.pop("SWAP_AB", False)

    # Missing tensors, or tensors of lower rank, contribute a stride of 0.
    a_scale_is_2d = A_scale is not None and A_scale.ndim == 2
    a_scale_stride0 = A_scale.stride(0) if a_scale_is_2d else 0
    a_scale_stride1 = A_scale.stride(1) if a_scale_is_2d else 0

    b_scale_has_2_dims = B_scale is not None and B_scale.ndim >= 2
    b_scale_stride0 = B_scale.stride(0) if b_scale_has_2_dims else 0
    b_scale_stride1 = B_scale.stride(1) if b_scale_has_2_dims else 0
    b_scale_stride2 = (
        B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0
    )

    bias_stride0 = B_bias.stride(0) if has_bias else 0
    bias_stride1 = B_bias.stride(1) if has_bias else 0

    block_dim0 = 0 if block_shape is None else block_shape[0]
    block_dim1 = 0 if block_shape is None else block_shape[1]

    fused_moe_kernel[grid](
        A,
        B,
        C,
        B_bias,
        A_scale,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.size(1),  # N
        B.size(2),  # K
        EM,
        num_tokens,
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        a_scale_stride0,
        a_scale_stride1,
        b_scale_stride0,
        b_scale_stride2,
        b_scale_stride1,
        bias_stride0,
        bias_stride1,
        block_dim0,
        block_dim1,
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        per_channel_quant=per_channel_quant,
        naive_block_assignment=(sorted_token_ids is None),
        HAS_BIAS=has_bias,
        BLOCK_SIZE_K=block_size_k,
        SWAP_AB=swap_ab,
        **launch_config,
    )
def dispatch_fused_moe_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    B_zp: Optional[torch.Tensor],
    topk_weights: Optional[torch.Tensor],
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a8: bool,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    per_channel_quant: bool,
    block_shape: Optional[list[int]] = None,
    B_bias: Optional[torch.Tensor] = None,
) -> None:
    """Pick the Triton launcher that matches the quantization flags.

    Weight-only INT8/INT4 with a positive group size (``block_shape[1] > 0``)
    goes to ``invoke_fused_moe_wna16_triton_kernel``, which takes zero points
    but no bias. Every other combination goes to
    ``invoke_fused_moe_triton_kernel``, which takes activation scales and a
    bias but no zero points.
    """
    assert topk_weights is not None or not mul_routed_weight
    assert topk_weights is None or topk_weights.stride(1) == 1
    assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1

    # TODO: Other precision-specific implementations
    # (use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16).

    # ``block_shape[1]`` is only inspected for weight-only quantization.
    grouped_weight_only = (
        (use_int8_w8a16 or use_int4_w4a16)
        and block_shape is not None
        and block_shape[1] > 0
    )

    if not grouped_weight_only:
        invoke_fused_moe_triton_kernel(
            A,
            B,
            C,
            A_scale,
            B_scale,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            mul_routed_weight,
            top_k,
            config,
            compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            block_shape=block_shape,
            B_bias=B_bias,
        )
        return

    assert B_bias is None
    invoke_fused_moe_wna16_triton_kernel(
        A,
        B,
        C,
        B_scale,
        B_zp,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        mul_routed_weight,
        top_k,
        config,
        compute_type,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        block_shape=block_shape,
    )
def fused_experts_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    ocp_mx_scheme: str | None = None,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: torch.Tensor | None = None,
    w2_zp: torch.Tensor | None = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the fused mixture-of-experts MLP over ``hidden_states``.

    Tokens are processed in chunks of at most 16384. Each chunk goes through:
    optional activation quantization, the first expert GEMM with ``w1``, the
    activation, optional re-quantization, the second expert GEMM with ``w2``,
    and ``moe_sum`` into the output rows of that chunk.

    Args:
        hidden_states: Contiguous 2-D input of dtype float32, float16 or
            bfloat16; dim 0 is the token count.
        w1: 3-D weights of the first GEMM; the last stride must be 1.
        w2: Weights of the second GEMM; the last stride must be 1.
        topk_weights: Router weights, same shape as ``topk_ids``.
        topk_ids: Selected expert ids; ``topk_ids.size(1)`` is top-k.
        inplace: Write the result into ``hidden_states`` instead of a new
            tensor.
        activation: Only ``"silu"`` is accepted.
        apply_router_weight_on_input: Multiply by the router weights in the
            first GEMM instead of the second one.
        use_fp8_w8a8, use_int8_w8a8: Quantization flags forwarded to the
            kernels.
        use_int8_w8a16, use_int4_w4a16: Weight-only quantization; the weights
            are dequantized up front and the flags are cleared before launch.
        ocp_mx_scheme: Optional OCP MX scheme name (``w_mxfp4*``,
            ``w_mxfp6_e3m2*`` or ``w_mxfp6_e2m3*``); the weights are
            dequantized up front.
        per_channel_quant: Forwarded to input quantization and the kernels.
        global_num_experts: Total expert count; ``-1`` means ``w1.size(0)``.
        expert_map: Optional expert map forwarded to
            ``moe_align_block_size``; when set, the second GEMM's output
            buffer is zeroed first.
        w1_scale, w2_scale: Weight scales (required for INT8/INT4 weight-only
            quantization).
        w1_zp, w2_zp: Weight zero points, forwarded to the dispatcher.
            NOTE(review): the dequantization below multiplies by the scale
            only and clears the int8/int4 flags, so the zero points do not
            appear to be applied on this path -- confirm.
        a1_scale, a2_scale: Activation scales for the two GEMM inputs.
        block_shape: Optional quantization block shape.
        w1_bias, w2_bias: Optional biases for the two GEMMs.

    Returns:
        ``hidden_states`` itself when ``inplace`` is true, otherwise a new
        tensor of the same shape and dtype.

    Raises:
        NotImplementedError: For an unrecognised ``ocp_mx_scheme``.
        AssertionError: When a shape, layout or dtype precondition fails.
    """
    logger.debug("GEMS FUSED MOE")
    assert (
        activation == "silu"
    ), f"Only 'silu' activation is supported, got {activation}"

    activation_enum = MoEActivation.from_str(activation)

    # Check constraints. INT4 is stored unpacked in INT8 containers (full K
    # dim), so it shares the plain hidden-size check and takes precedence
    # over any OCP MX scheme.
    if ocp_mx_scheme is not None and not use_int4_w4a16:
        if ocp_mx_scheme.startswith("w_mxfp4"):
            assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch"
        elif ocp_mx_scheme.startswith("w_mxfp6"):
            assert (
                hidden_states.size(1) == (w1.size(2) * 4) // 3
            ), "hidden size mismatch"
        else:
            raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
    else:
        assert hidden_states.size(1) == w1.size(
            2
        ), f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"

    assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

    num_tokens = hidden_states.size(0)
    E, N, _ = w1.size()
    K = w2.size(1)
    if global_num_experts == -1:
        global_num_experts = E
    top_k_num = topk_ids.size(1)

    CHUNK_SIZE: int = 16 * 1024
    M = min(num_tokens, CHUNK_SIZE)

    config_dtype = _get_config_dtype_str(
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        ocp_mx_scheme=ocp_mx_scheme,
        dtype=hidden_states.dtype,
    )

    quant_dtype = _get_config_quant_dtype(
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        ocp_mx_scheme=ocp_mx_scheme,
    )

    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.size(),
        w2.size(),
        top_k_num,
        config_dtype,
        block_shape=block_shape,
        E=E,
    )

    config = get_config_func(M)

    # cache1 and cache3 share memory (non-overlapping lifetime)
    cache13 = torch.empty(
        M * top_k_num * max(N, K),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache1 = cache13[: M * top_k_num * N].view(M, top_k_num, N)
    intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)

    # cache2 needs separate memory (concurrent with cache1)
    activation_out_dim = MoEActivation.adjust_N_for_activation(N, activation_enum)
    intermediate_cache2 = torch.empty(
        (M * top_k_num, activation_out_dim),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    if hidden_states.dtype == torch.bfloat16:
        compute_type = tl.bfloat16
    elif hidden_states.dtype == torch.float16:
        compute_type = tl.float16
    elif hidden_states.dtype == torch.float32:
        compute_type = tl.float32
    else:
        raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

    out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)

    if ocp_mx_scheme is not None:
        # Dequantize OCP MX weights (TODO: skip on platforms with native MX)
        if ocp_mx_scheme.startswith("w_mxfp4"):
            w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
            w1_scale = None
            w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
            w2_scale = None
        elif ocp_mx_scheme.startswith("w_mxfp6_e3m2"):
            w1 = dequant_mxfp6(
                w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
            )
            w1_scale = None
            w2 = dequant_mxfp6(
                w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
            )
            w2_scale = None
        elif ocp_mx_scheme.startswith("w_mxfp6_e2m3"):
            w1 = dequant_mxfp6(
                w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
            )
            w1_scale = None
            w2 = dequant_mxfp6(
                w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
            )
            w2_scale = None
        else:
            raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")

    # Dequant INT8/INT4 weights (Triton can't do mixed-dtype dot)
    if use_int8_w8a16 or use_int4_w4a16:
        # Fail with a clear message instead of an AttributeError on None.
        assert (
            w1_scale is not None and w2_scale is not None
        ), "w1_scale and w2_scale are required for INT8/INT4 weight-only quantization"
        # NOTE(review): unsqueeze(-1) broadcasts one scale per leading index
        # over the last weight dim; group-wise scales with a trailing group
        # dimension would presumably not broadcast here -- confirm callers.
        w1 = w1.to(hidden_states.dtype) * w1_scale.unsqueeze(-1).to(hidden_states.dtype)
        w1_scale = None
        w2 = w2.to(hidden_states.dtype) * w2_scale.unsqueeze(-1).to(hidden_states.dtype)
        w2_scale = None
        # The weights are now in the activation dtype, so the kernels run
        # the unquantized path.
        use_int8_w8a16 = False
        use_int4_w4a16 = False

    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
        begin_chunk_idx, end_chunk_idx = (
            chunk * CHUNK_SIZE,
            min((chunk + 1) * CHUNK_SIZE, num_tokens),
        )
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.size()

        # Reached when num_tokens is 0 or an exact multiple of CHUNK_SIZE.
        if tokens_in_chunk == 0:
            break

        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
            # Adjust cache size for last chunk
            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
            intermediate_cache2 = intermediate_cache2[: tokens_in_chunk * top_k_num]
            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
            config = get_config_func(tokens_in_chunk)

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
        qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
            A=curr_hidden_states,
            A_scale=a1_scale,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_channel_quant,
            block_shape=block_shape,
            ocp_mx_scheme=ocp_mx_scheme,
        )

        # Skip the sort/align step when the chunk is small relative to the
        # number of experts and no expert map is in use.
        SPARSITY_FACTOR = 4
        naive_block_assignment = (
            expert_map is None
            and tokens_in_chunk * top_k_num * SPARSITY_FACTOR <= global_num_experts
            and not (
                (use_int8_w8a16 or use_int4_w4a16)
                and block_shape is not None
                and block_shape[1] > 0
            )
        )

        if not naive_block_assignment:
            sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
                curr_topk_ids,
                config["BLOCK_SIZE_M"],
                global_num_experts,
                expert_map,
                # ignore_invalid_experts=True,
            )
        else:
            # Size the padded token count from the current chunk so that it
            # matches ``expert_ids`` and the launch grid, both per-chunk.
            max_num_tokens_padded = curr_topk_ids.numel() * config["BLOCK_SIZE_M"]
            expert_ids = curr_topk_ids.view(-1)
            num_tokens_post_padded = torch.empty(
                (1), dtype=torch.int32, device=topk_ids.device
            )
            num_tokens_post_padded.fill_(max_num_tokens_padded)
            sorted_token_ids = None

        # First GEMM: quantized hidden states x w1 -> intermediate_cache1.
        dispatch_fused_moe_kernel(
            qcurr_hidden_states,
            w1,
            intermediate_cache1,
            a1q_scale,
            w1_scale,
            w1_zp,
            curr_topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            apply_router_weight_on_input,
            top_k_num,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            block_shape=block_shape,
            B_bias=w1_bias,
        )

        apply_moe_activation(
            activation_enum, intermediate_cache2, intermediate_cache1.view(-1, N)
        )

        qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
            A=intermediate_cache2,
            A_scale=a2_scale,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_channel_quant,
            block_shape=block_shape,
            ocp_mx_scheme=ocp_mx_scheme,
        )

        if expert_map is not None:
            intermediate_cache3.zero_()

        # Second GEMM: activated rows x w2 -> intermediate_cache3. Each row of
        # cache2 is already one (token, expert) pair, hence top_k == 1. The
        # router weights are applied here unless they were applied on input.
        dispatch_fused_moe_kernel(
            qintermediate_cache2,
            w2,
            intermediate_cache3,
            a2q_scale,
            w2_scale,
            w2_zp,
            curr_topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            not apply_router_weight_on_input,
            1,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            block_shape=block_shape,
            B_bias=w2_bias,
        )

        moe_sum(
            intermediate_cache3.view(*intermediate_cache3.size()),
            out_hidden_states[begin_chunk_idx:end_chunk_idx],
        )

    return out_hidden_states
def inplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> None:
    """
    Fused MoE that overwrites ``hidden_states`` with the result.

    Forwards every argument to ``fused_experts_impl`` with ``inplace=True``
    and discards the returned tensor, so this function returns None.
    """
    # Options passed through unchanged to the shared implementation.
    passthrough = dict(
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        per_channel_quant=per_channel_quant,
        global_num_experts=global_num_experts,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
    )
    fused_experts_impl(
        hidden_states, w1, w2, topk_weights, topk_ids, inplace=True, **passthrough
    )
def outplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Fused MoE that leaves ``hidden_states`` untouched.

    Forwards every argument to ``fused_experts_impl`` with ``inplace=False``
    and returns the tensor it produces.
    """
    # Options passed through unchanged to the shared implementation.
    passthrough = dict(
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        per_channel_quant=per_channel_quant,
        global_num_experts=global_num_experts,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
    )
    result = fused_experts_impl(
        hidden_states, w1, w2, topk_weights, topk_ids, inplace=False, **passthrough
    )
    return result