Coverage for src/flag_gems/runtime/backend/_ascend/fused/fused_moe.py: 0%
687 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1# SPDX-License-Identifier: Apache-2.0
2# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3#
4# Adapted from the vLLM project (https://github.com/vllm-project/vllm).
5# Source files under vllm/model_executor/layers/:
6# fused_moe/fused_moe.py – Triton kernels, dispatch, fused_experts_impl
7# fused_moe/activation.py – MoEActivation enum, apply_moe_activation
8# fused_moe/utils.py – _fp8_quantize, _int8_quantize, moe_kernel_quantize_input
9# fused_moe/config.py – _get_config_dtype_str
10# quantization/utils/mxfp4_utils.py – dequant_mxfp4
11# quantization/utils/mxfp6_utils.py – dequant_mxfp6
12# quantization/utils/ocp_mx_utils.py – OCP_MX_BLOCK_SIZE
15import functools
16import logging
17import os
18from enum import Enum
19from typing import Any, Optional
21import torch
22import torch.nn.functional as F
23import triton
24import triton.language as tl
25import yaml
27# Absolute imports are used deliberately: relative imports cause this module to not be found.
28from flag_gems.runtime.backend._ascend.fused.moe_align_block_size import (
29 moe_align_block_size,
30)
31from flag_gems.runtime.backend._ascend.fused.moe_sum import moe_sum
32from flag_gems.utils import pointwise_dynamic
# Module logger, created as a child of the shared "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# OCP MX quantization helpers (requires amd-quark)

# Group size handed to the MXFP6 dequantization routine (``group_size=``).
OCP_MX_BLOCK_SIZE = 32
@functools.lru_cache(maxsize=1)
def get_embedded_moe_configs():
    """Load the embedded fused-MoE tuning table and the device fallback map.

    The table is read from ``../utils/configs/fused_moe_config.yaml`` relative
    to this file. The result is cached, so the file is parsed at most once.

    Returns:
        A ``(parsed_data, fallback)`` tuple.

        * ``parsed_data`` maps ``device -> config key -> M -> config`` where
          ``M`` is converted to ``int``. A config stored as a list is expanded
          into a dict using the fixed key order defined below; any other value
          is kept as-is.
        * ``fallback`` is the content of the top-level ``_FALLBACK`` entry.

        Both are empty dicts when the file is missing or empty.
    """
    config_path = os.path.join(
        os.path.dirname(__file__), "..", "utils", "configs", "fused_moe_config.yaml"
    )
    if not os.path.exists(config_path):
        return {}, {}
    with open(config_path, "r", encoding="utf-8") as f:
        # An empty YAML document loads as None; treat it like a missing file
        # instead of failing on ``None.get`` below.
        data = yaml.safe_load(f) or {}

    # ``_FALLBACK:`` with no body also loads as None.
    fallback = data.get("_FALLBACK") or {}

    # Positional layout of configs that are stored as compact lists.
    keys_order = (
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "num_warps",
        "num_stages",
    )

    parsed_data = {}
    for dev, configs in data.items():
        if dev == "_FALLBACK":
            continue
        device_table = {}
        # A device entry without a body loads as None; keep it as an empty table.
        for key, m_dict in (configs or {}).items():
            # The innermost keys are the M values; normalise them to int.
            device_table[key] = {
                int(m): dict(zip(keys_order, v)) if isinstance(v, list) else v
                for m, v in m_dict.items()
            }
        parsed_data[dev] = device_table

    return parsed_data, fallback
def dequant_mxfp4(
    x: torch.Tensor,
    scale: torch.Tensor,
    float_dtype: torch.dtype,
) -> torch.Tensor:
    """Dequantize an MXFP4 tensor by delegating to ``quark``'s ``mx.dq_mxfp4``.

    Raises:
        ImportError: if the ``quark`` package cannot be imported.
    """
    try:
        from quark.torch.kernel import mx as quark_mx
    except ImportError as import_failure:
        raise ImportError("amd-quark is required for MX-FP4") from import_failure

    dequantized = quark_mx.dq_mxfp4(x, scale, float_dtype)
    return dequantized
def dequant_mxfp6(
    x: torch.Tensor,
    scale: torch.Tensor,
    float_dtype: torch.dtype,
    quant_dtype: str,
) -> torch.Tensor:
    """Dequantize an MXFP6 tensor using quark's hardware-emulation routines.

    ``x`` is unpacked with quark's pack method for ``quant_dtype`` and then
    dequantized per group of ``OCP_MX_BLOCK_SIZE`` elements along the last
    axis. The result is cast to ``float_dtype``.

    Raises:
        ImportError: if the ``quark`` package cannot be imported.
    """
    try:
        from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
            dequantize_fp4_fp6_per_group,
        )
        from quark.torch.utils.pack import create_pack_method
    except ImportError as import_failure:
        raise ImportError("amd-quark is required for MX-FP6") from import_failure

    unpacker = create_pack_method(None, dtype=quant_dtype)
    values = unpacker.unpack(x, reorder=False)

    # The scale bytes are treated as biased exponents: value = 2 ** (byte - 127).
    exponents = scale.view(torch.uint8).to(torch.int16) - 127
    scale = 2 ** exponents.to(float_dtype)

    dequantized = dequantize_fp4_fp6_per_group(
        values,
        scale,
        axis=-1,
        group_size=OCP_MX_BLOCK_SIZE,
        quant_dtype=quant_dtype,
    )
    return dequantized.to(float_dtype)
124# Activation quantization helpers
@functools.lru_cache(maxsize=1)
def _get_device_name() -> str:
    """Return the key used to look up this device in the embedded config table.

    The name reported by ``torch.npu.get_device_name()`` is normalised by
    replacing spaces with underscores. Names containing an ``H200`` token are
    collapsed to ``NVIDIA_H200`` (a normalisation carried over from vLLM).

    If the normalised name is absent from the embedded table but the
    ``_FALLBACK`` mapping points it at a device that is present, that device's
    name is returned instead. Otherwise the normalised name is returned
    unchanged. The result is cached after the first call.
    """
    device_key = torch.npu.get_device_name().replace(" ", "_")
    if "H200" in device_key.split("_"):
        device_key = "NVIDIA_H200"

    known_devices, aliases = get_embedded_moe_configs()
    if device_key in known_devices:
        return device_key

    # Devices with equivalent tuning profiles may be aliased to a known one.
    alias = aliases.get(device_key)
    if not alias or alias not in known_devices:
        return device_key

    logger.info(
        "Device %s not in config table, falling back to %s", device_key, alias
    )
    return alias
def get_moe_configs(
    E: int,
    N: int,
    dtype: str | None,
    block_n: int | None = None,
    block_k: int | None = None,
) -> dict[int, Any] | None:
    """Look up pre-tuned fused-MoE kernel configs for the current device.

    The lookup key is ``"E,N,dtype,block_n,block_k"``, where a missing or
    falsy block size is written as ``0``.

    Returns:
        The table entry for that key (a mapping keyed by M), or ``None`` when
        either the device or the key has no entry. A warning is logged in
        both ``None`` cases.
    """
    device_name = _get_device_name()
    tables, _ = get_embedded_moe_configs()

    table = tables.get(device_name)
    if table is None:
        logger.warning(
            "No embedded MoE configs for device %s. Will use default config.",
            device_name,
        )
        return None

    key = f"{E},{N},{dtype},{block_n or 0},{block_k or 0}"
    tuned = table.get(key)
    if tuned is None:
        logger.warning(
            "No embedded MoE config for device=%s, key=%s. Will use default config.",
            device_name,
            key,
        )
        return None

    logger.info("Using embedded MoE config for device=%s, key=%s", device_name, key)
    return tuned
def try_get_optimal_moe_config(
    w1_shape: tuple[int, ...],
    w2_shape: tuple[int, ...],
    top_k: int,
    dtype: str | None,
    M: int,
    block_shape: list[int] | None = None,
) -> dict[str, int]:
    """Pick a kernel config for a fused-MoE launch with ``M`` tokens.

    The embedded table is consulted first; among its entries the one whose M
    key is numerically closest to ``M`` is used. When no table entry exists,
    ``get_default_config`` supplies a heuristic config.

    Args:
        w1_shape: shape of the first expert weight; ``w1_shape[2]`` is passed
            to ``get_default_config`` as K.
        w2_shape: shape of the second expert weight, unpacked as ``(E, _, N)``.
        top_k: number of experts selected per token.
        dtype: config dtype string (e.g. ``"int4_w4a16"``) or ``None``.
        M: number of tokens.
        block_shape: optional ``[block_n, block_k]`` quantization block shape.

    Returns:
        The selected config dict.
    """
    E, _, N = w2_shape
    if dtype == "int4_w4a16":
        # NOTE(review): presumably compensates for two int4 values being
        # packed per stored element — confirm against the weight layout.
        N = N * 2
    block_n = block_shape[0] if block_shape else 0
    block_k = block_shape[1] if block_shape else 0

    configs = get_moe_configs(E, N, dtype, block_n, block_k)
    if configs:
        # Use the tuned entry whose M is closest to the actual token count.
        return configs[min(configs, key=lambda tuned_m: abs(tuned_m - M))]

    return get_default_config(M, E, N, w1_shape[2], top_k, dtype, block_shape)
215def _get_config_quant_dtype(
216 use_fp8_w8a8: bool,
217 use_int8_w8a8: bool,
218 ocp_mx_scheme: str | None,
219) -> None | torch.dtype | str:
220 """Map quantization flags to the corresponding dtype."""
221 if use_fp8_w8a8:
222 return torch.float8_e4m3fn
223 elif use_int8_w8a8:
224 return torch.int8
225 elif ocp_mx_scheme == "w_mxfp4_a_mxfp4":
226 return "mxfp4"
227 elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e3m2", "w_mxfp6_e3m2_a_mxfp6_e3m2"}:
228 return "mxfp6_e3m2"
229 elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e2m3", "w_mxfp6_e2m3_a_mxfp6_e2m3"}:
230 return "mxfp6_e2m3"
231 elif ocp_mx_scheme in {"w_mxfp4", "w_mxfp6_e3m2", "w_mxfp6_e2m3"}:
232 return torch.bfloat16
233 elif ocp_mx_scheme in {"w_mxfp4_a_fp8", "w_mxfp6_e3m2_a_fp8", "w_mxfp6_e2m3_a_fp8"}:
234 return torch.float8_e4m3fn
236 return None
def get_moe_wna16_block_config(
    config: dict[str, int],
    use_moe_wna16_cuda: bool,
    num_valid_tokens: int,
    size_k: int,
    size_n: int,
    num_experts: int,
    group_size: int,
    real_top_k: int,
    block_size_m: int,
):
    """Choose ``BLOCK_SIZE_N`` / ``BLOCK_SIZE_K`` for the WNA16 MoE kernel.

    Returns:
        ``{}`` when ``config`` already contains both block sizes; otherwise a
        dict with the two keys. Without ``use_moe_wna16_cuda`` one of two
        fixed pairs is returned; with it, the sizes are derived from an
        estimated block count.
    """
    if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config:
        return {}

    if not use_moe_wna16_cuda:
        if num_valid_tokens // real_top_k == 1:
            return {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64}
        return {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}

    tile_n = 128
    tile_k = 128
    if tile_k <= group_size:
        tile_k = group_size

    # NOTE(review): the n-count is derived from size_k and the k-count from
    # size_n / tile_k, exactly as in the original; only their product is used.
    n_tiles = size_k // tile_k
    k_tiles = size_n // tile_k
    # NOTE(review): true division is kept on purpose, so the counts below may
    # be floats, matching the original arithmetic.
    m_tiles = (num_valid_tokens + block_size_m - 1) / block_size_m + num_experts
    if num_valid_tokens // real_top_k <= block_size_m:
        m_tiles = min(m_tiles, num_valid_tokens)
    total_tiles = m_tiles * n_tiles * k_tiles

    if size_k % 256 == 0 and total_tiles >= 256 and tile_k < 256:
        tile_k = 256
        # NOTE(review): tile_k is already 256 here, so the divisor is 1.
        total_tiles = total_tiles // (256 // tile_k)

    can_double_k = (
        m_tiles <= 16
        and size_k % (tile_k * 2) == 0
        and tile_k <= 512
        and total_tiles >= 512
    )
    if can_double_k:
        tile_k = tile_k * 2
        total_tiles = total_tiles // 2

    if total_tiles > 1024:
        tile_n = 256
        total_tiles = total_tiles // 2

    if size_n <= 1024 and total_tiles >= 1024:
        tile_n = 1024

    tile_k = _ensure_block_size_k_divisible(size_k, tile_k, group_size)

    return {"BLOCK_SIZE_N": tile_n, "BLOCK_SIZE_K": tile_k}
299def get_default_config(
300 M: int,
301 E: int,
302 N: int,
303 K: int,
304 topk: int,
305 dtype: str | None,
306 block_shape: list[int] | None = None,
307) -> dict[str, int]:
308 """Default Triton config for fused MoE kernel.
310 Heuristic selection aligned with vLLM v0.17.0 defaults, tuned on H20/H100.
311 Key insight: for high-expert-count MoE (e.g. DeepSeek-V3 E=256), each
312 expert sees very few tokens, so small BLOCK_SIZE_M (16) is critical.
313 """
314 if dtype == "fp8_w8a8" and block_shape is not None:
315 config = {
316 "BLOCK_SIZE_M": 16 if M <= 64 else 64,
317 "BLOCK_SIZE_N": block_shape[0],
318 "BLOCK_SIZE_K": block_shape[1],
319 "GROUP_SIZE_M": 1 if M <= 16 else 32,
320 "num_warps": 4,
321 "num_stages": 3,
322 }
323 else:
324 # tokens_per_expert drives block_m: use M//E (not M*topk//E) to
325 # estimate the actual per-expert token count after routing.
326 tokens_per_expert = M // max(E, 1)
328 if tokens_per_expert <= 2:
329 block_m = 16
330 elif tokens_per_expert <= 4:
331 block_m = 32
332 elif tokens_per_expert <= 16:
333 block_m = 64
334 else:
335 block_m = 128
337 # Tile sizing
338 if N >= 4096:
339 block_n = 128 if M <= 128 else 256
340 elif N >= 1024:
341 block_n = 64 if M <= 64 else 128
342 else:
343 block_n = 64 if M <= 64 else 128
345 if dtype == "fp8_w8a8":
346 block_k = 128
347 else:
348 # Cap BLOCK_SIZE_K at 32: BK=64 with BM≥64 triggers
349 # triton-ascend compiler errors ('vsel' unsupported) on
350 # large shapes (e.g. Mixtral N=28672).
351 block_k = 32
353 if tokens_per_expert > 128:
354 group_m = 16
355 elif tokens_per_expert > 32:
356 group_m = 8
357 else:
358 group_m = 1
360 # Adaptive stages: optimize for different M sizes
361 # Small M: use more stages for better data reuse
362 # Large M: use fewer stages to reduce memory overhead
363 if M <= 64:
364 num_stages = 2
365 num_warps = 2
366 elif M <= 256:
367 num_stages = 2
368 num_warps = 4
369 else:
370 num_stages = 2
371 num_warps = 4
373 # UB budget check for Ascend NPU (192KB = 196608 bytes)
374 # Account for: A tile (BM*BK*2) + B tile (BK*BN*2) + Accumulator (BM*BN*4)
375 # Use tighter safety factor for small M
376 UB_LIMIT = 196608
377 SAFETY_FACTOR = 0.65 if M <= 64 else 0.70
379 ub_per_stage = (
380 block_m * block_k * 2 # A tile (bf16)
381 + block_k * block_n * 2 # B tile (bf16)
382 + block_m * block_n * 4 # Accumulator (fp32)
383 )
385 # Reduce stages if needed
386 while num_stages > 1 and ub_per_stage * num_stages > UB_LIMIT * SAFETY_FACTOR:
387 num_stages -= 1
389 # Reduce block sizes if still over budget
390 while num_stages >= 1 and ub_per_stage * num_stages > UB_LIMIT * SAFETY_FACTOR:
391 # Reduce in order: BN (biggest impact), BM, BK
392 if block_n > 64:
393 block_n = max(64, block_n // 2)
394 elif block_m > 16:
395 block_m = max(16, block_m // 2)
396 elif block_k > 16:
397 block_k = max(16, block_k // 2)
398 else:
399 break
400 ub_per_stage = (
401 block_m * block_k * 2 + block_k * block_n * 2 + block_m * block_n * 4
402 )
404 config = {
405 "BLOCK_SIZE_M": block_m,
406 "BLOCK_SIZE_N": block_n,
407 "BLOCK_SIZE_K": block_k,
408 "GROUP_SIZE_M": group_m,
409 "num_warps": num_warps,
410 "num_stages": num_stages,
411 }
412 return config
415def _get_config_dtype_str(
416 dtype: Optional[torch.dtype] = None,
417 use_fp8_w8a8: bool = False,
418 use_fp8_w8a16: bool = False,
419 use_int8_w8a16: bool = False,
420 use_int4_w4a16: bool = False,
421 ocp_mx_scheme: str | None = None,
422) -> str | None:
423 """Return dtype string for kernel config lookup."""
424 if use_fp8_w8a8:
425 return "fp8_w8a8"
426 elif use_fp8_w8a16:
427 return "fp8_w8a16"
428 elif use_int8_w8a16:
429 return "int8_w8a16"
430 elif use_int4_w4a16:
431 return "int4_w4a16"
432 elif ocp_mx_scheme is not None:
433 return None
434 elif dtype == torch.float:
435 return "float32"
436 return None
439# MoE activation enum
class MoEActivation(Enum):
    """Activation functions selectable for MoE layers.

    Members whose value ends in ``_no_mul`` are the non-gated variants; all
    other members are gated.
    """

    # Gated: gate * activation(up), input [..., 2*d] -> output [..., d]
    SILU = "silu"
    GELU = "gelu"
    RELU2 = "relu2"
    SWIGLUOAI = "swigluoai"
    SWIGLUSTEP = "swiglustep"

    # Non-gated: input [..., d] -> output [..., d]
    SILU_NO_MUL = "silu_no_mul"
    GELU_NO_MUL = "gelu_no_mul"
    RELU2_NO_MUL = "relu2_no_mul"

    @property
    def is_gated(self) -> bool:
        """Whether this activation is a gated one (value lacks ``_no_mul``)."""
        has_no_mul_suffix = self.value.endswith("_no_mul")
        return not has_no_mul_suffix

    def without_mul(self) -> "MoEActivation":
        """Return the non-gated counterpart, or ``self`` when there is none."""
        if self is MoEActivation.SILU:
            return MoEActivation.SILU_NO_MUL
        if self is MoEActivation.GELU:
            return MoEActivation.GELU_NO_MUL
        if self is MoEActivation.RELU2:
            return MoEActivation.RELU2_NO_MUL
        return self

    @classmethod
    def from_str(cls, s: str) -> "MoEActivation":
        """Return the member whose value equals ``s``.

        Raises:
            ValueError: if no member has that value.
        """
        found = next((member for member in cls if member.value == s), None)
        if found is not None:
            return found
        valid = [m.value for m in cls]
        raise ValueError(f"Unknown MoE activation: {s!r}. Valid activations: {valid}")

    @staticmethod
    def adjust_N_for_activation(N: int, activation: "MoEActivation") -> int:
        """Return ``N // 2`` for gated activations and ``N`` otherwise."""
        return N // 2 if activation.is_gated else N
def apply_moe_activation(
    activation: MoEActivation,
    output: torch.Tensor,
    input: torch.Tensor,
) -> torch.Tensor:
    """Apply an MoE activation to ``input`` and write the result to ``output``.

    Gated activations treat the first half of the last dimension as the gate
    and the second half as the up projection, so ``input`` must be twice as
    wide as ``output``. Non-gated activations require equal widths.

    ``input`` is never modified by this function.

    Args:
        activation: which activation to apply.
        output: 2D destination tensor, written in place.
        input: 2D source tensor.

    Returns:
        ``output``.

    Raises:
        ValueError: if ``activation`` has no implementation here.
    """
    assert input.dim() == 2, "Input must be 2D"
    assert output.dim() == 2, "Output must be 2D"
    if activation.is_gated:
        assert output.size(-1) * 2 == input.size(
            -1
        ), f"{activation.value} expects 2x ratio: {output.size(-1) * 2} vs {input.size(-1)}"
        N = output.size(-1)
        gate, up = input[:, :N], input[:, N:]
    else:
        assert output.size(-1) == input.size(
            -1
        ), f"{activation.value} expects equal sizes: {output.size(-1)} vs {input.size(-1)}"

    # NOTE(review): SWIGLUOAI shares the plain SiLU-and-mul path and
    # SWIGLUSTEP uses sigmoid(gate) * up; confirm these match the intended
    # definitions of those activations.
    if activation in (MoEActivation.SILU, MoEActivation.SWIGLUOAI):
        _silu_and_mul_kernel(gate, up, out0=output)
    elif activation == MoEActivation.GELU:
        output.copy_(F.gelu(gate) * up)
    elif activation == MoEActivation.SWIGLUSTEP:
        output.copy_(torch.sigmoid(gate) * up)
    elif activation == MoEActivation.RELU2:
        output.copy_(F.relu(gate).square() * up)

    elif activation == MoEActivation.SILU_NO_MUL:
        output.copy_(F.silu(input))
    elif activation == MoEActivation.GELU_NO_MUL:
        output.copy_(F.gelu(input))
    elif activation == MoEActivation.RELU2_NO_MUL:
        # Out-of-place ReLU: an in-place ReLU here would overwrite the
        # caller's input, unlike every other branch.
        output.copy_(F.relu(input).square())
    else:
        raise ValueError(f"Unsupported FusedMoe activation: {activation}")

    return output
531def _fp8_quantize(
532 A: torch.Tensor,
533 A_scale: Optional[torch.Tensor],
534 per_act_token: bool,
535 block_shape: Optional[list[int]] = None,
536) -> tuple[torch.Tensor, torch.Tensor]:
537 """FP8 E4M3 quantization: per-tensor, per-token, or block-wise."""
538 fp8_dtype = torch.float8_e4m3fn
539 finfo = torch.finfo(fp8_dtype)
540 fp8_max = finfo.max
541 fp8_min = finfo.min
542 eps = 1e-10
544 if block_shape is not None:
545 assert not per_act_token
546 assert len(block_shape) == 2
547 block_k = block_shape[1]
548 assert A.size(-1) % block_k == 0
549 orig_shape = A.shape
550 A_flat = A.reshape(-1, A.size(-1))
551 M, K = A_flat.shape
552 A_groups = A_flat.reshape(M * (K // block_k), block_k)
553 amax = (
554 A_groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32)
555 )
556 scale = amax / fp8_max
557 A_q = (A_groups.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype)
558 A_q = A_q.reshape(orig_shape)
559 scale = scale.reshape(M, K // block_k)
560 return A_q, scale
562 elif per_act_token:
563 A_flat = A.reshape(-1, A.size(-1))
564 amax = A_flat.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32)
565 scale = amax / fp8_max
566 min_scale = torch.tensor(
567 1.0 / (fp8_max * 512.0), dtype=torch.float32, device=A.device
568 )
569 scale = scale.clamp(min=min_scale)
570 A_q = (A_flat.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype)
571 A_q = A_q.reshape(A.shape)
572 scale = scale.reshape(A.shape[:-1] + (1,))
573 return A_q, scale
575 else:
576 if A_scale is not None:
577 scale = (
578 A_scale.float().view(1, 1) if A_scale.numel() == 1 else A_scale.float()
579 )
580 A_q = (A.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype)
581 return A_q, A_scale
582 else:
583 amax = A.abs().amax().clamp(min=eps).to(torch.float32)
584 scale = amax / fp8_max
585 iscale = 1.0 / scale
586 A_q = (A.float() * iscale).clamp(fp8_min, fp8_max).to(fp8_dtype)
587 return A_q, scale.view(1)
590def _int8_quantize(
591 A: torch.Tensor,
592 A_scale: Optional[torch.Tensor],
593 per_act_token: bool,
594 block_shape: Optional[list[int]] = None,
595) -> tuple[torch.Tensor, torch.Tensor]:
596 """INT8 quantization: per-tensor, per-token, or block-wise."""
597 iinfo = torch.iinfo(torch.int8)
598 int8_max = iinfo.max
599 int8_min = iinfo.min
600 eps = 1e-10
602 if block_shape is not None:
603 assert not per_act_token
604 assert len(block_shape) == 2
605 block_k = block_shape[1]
606 assert A.size(-1) % block_k == 0
607 orig_shape = A.shape
608 A_flat = A.reshape(-1, A.size(-1))
609 M, K = A_flat.shape
610 A_groups = A_flat.reshape(M * (K // block_k), block_k)
611 amax = (
612 A_groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32)
613 )
614 scale = amax / int8_max
615 A_q = (
616 (A_groups.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8)
617 )
618 A_q = A_q.reshape(orig_shape)
619 scale = scale.reshape(M, K // block_k)
620 return A_q, scale
622 elif per_act_token:
623 A_flat = A.reshape(-1, A.size(-1))
624 amax = A_flat.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32)
625 scale = amax / int8_max
626 A_q = (A_flat.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8)
627 A_q = A_q.reshape(A.shape)
628 scale = scale.reshape(A.shape[:-1] + (1,))
629 return A_q, scale
631 else:
632 assert A_scale is not None, "int8 per-tensor requires A_scale"
633 scale = A_scale.float().view(1, 1) if A_scale.numel() == 1 else A_scale.float()
634 A_q = (A.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8)
635 return A_q, A_scale
638def moe_kernel_quantize_input(
639 A: torch.Tensor,
640 A_scale: Optional[torch.Tensor],
641 quant_dtype: None | torch.dtype | str,
642 per_act_token_quant: bool,
643 block_shape: Optional[list[int]] = None,
644 ocp_mx_scheme: str | None = None,
645) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
646 """Quantize MoE input activations before GEMM."""
647 if ocp_mx_scheme is not None:
648 if ocp_mx_scheme in {"w_mxfp4", "w_mxfp4_a_mxfp4"}:
649 pass
650 elif ocp_mx_scheme.endswith("a_fp8"):
651 qA, qA_scale = _fp8_quantize(A, A_scale, per_act_token=False)
652 A = (qA.float() * qA_scale.float()).to(A.dtype)
653 return A, None
655 if quant_dtype is None:
656 return A, A_scale
657 elif quant_dtype == torch.float8_e4m3fn:
658 return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
659 elif quant_dtype == torch.int8:
660 return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
661 else:
662 return A, A_scale
665def _ensure_block_size_k_divisible(
666 size_k: int, block_size_k: int, group_size: int
667) -> int:
668 """Find largest block_size_k that divides size_k and is divisible by group_size."""
669 if size_k % block_size_k == 0 and block_size_k % group_size == 0:
670 return block_size_k
672 max_search = min(block_size_k, size_k)
673 start = (max_search // group_size) * group_size
674 for candidate in range(start, group_size - 1, -group_size):
675 if size_k % candidate == 0:
676 return candidate
678 if size_k % group_size == 0:
679 return group_size
681 return size_k
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def _silu_and_mul_kernel(x, y):
    # Elementwise SiLU(x) * y, with x upcast to fp32 for the SiLU part.
    x_fp32 = x.to(tl.float32)
    # x / (1 + exp(-x)) equals x * sigmoid(x).
    x_silu = tl.fdiv(x_fp32, (1.0 + tl.exp(-x_fp32)))
    return x_silu * y
@triton.jit
def write_zeros_to_output(
    c_ptr,
    stride_cm,
    stride_cn,
    pid_n,
    N,
    offs_token,
    token_mask,
    BLOCK_SIZE_M,
    BLOCK_SIZE_N,
    compute_type,
):
    # Store a zero tile of shape (BLOCK_SIZE_M, BLOCK_SIZE_N) into C for
    # column block `pid_n`. Rows are addressed through `offs_token`; rows
    # where `token_mask` is false and columns at or beyond N are not written.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
@triton.jit
def fused_moe_kernel_gptq_awq(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    b_scale_ptr,
    b_zp_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N: tl.constexpr,
    K: tl.constexpr,
    EM,
    num_valid_tokens,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bse,
    stride_bsk,
    stride_bsn,
    stride_bze,
    stride_bzk,
    stride_bzn,
    block_k_diviable: tl.constexpr,
    group_size: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    has_zp: tl.constexpr,
    use_int4_w4a16: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
):
    """Fused MoE kernel for GPTQ/AWQ (WNA16) quantized weights.

    Each program computes one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C. Rows of
    the tile are token slots read from `sorted_token_ids_ptr`; the expert for
    the whole row block is read from `expert_ids_ptr`. Weights are dequantized
    inside the K loop as (b - zero_point) * scale before the dot product.

    `SPLIT_K` is accepted but not referenced in the kernel body.
    """
    # Map pid to C block (grouped ordering for L2 reuse)
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Create pointers for first blocks of A and B
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    # Row blocks that start at or beyond the padded token count have no work.
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    # Cast to int64 to prevent overflow in stride*offset products
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id).to(tl.int64)
    # Slots whose id is >= num_valid_tokens are masked out of loads and stores.
    token_mask = offs_token < num_valid_tokens

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        # -----------------------------------------------------------
        # Write back zeros to the output when the expert is not
        # in the current expert parallel rank.
        write_zeros_to_output(
            c_ptr,
            stride_cm,
            stride_cn,
            pid_n,
            N,
            offs_token,
            token_mask,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            compute_type,
        )
        return

    # Column offsets wrap modulo N; out-of-range columns are masked at the store.
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # The row of A for a token slot is slot // top_k.
    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )

    if use_int4_w4a16:
        # Two 4-bit values share one stored element along K: index with k // 2
        # and select the nibble with a shift of 0 or 4 bits.
        b_ptrs = (
            b_ptr
            + off_experts * stride_be
            + (offs_k[:, None] // 2) * stride_bk
            + offs_bn[None, :] * stride_bn
        )
        b_shifter = (offs_k[:, None] % 2) * 4
    elif use_int8_w8a16:
        b_ptrs = (
            b_ptr
            + off_experts * stride_be
            + offs_k[:, None] * stride_bk
            + offs_bn[None, :] * stride_bn
        )

    # Without explicit zero points a fixed zero point is used:
    # 8 for int4 weights and 128 for int8 weights.
    if not has_zp and use_int4_w4a16:
        b_zp_num = 8
    if not has_zp and use_int8_w8a16:
        b_zp_num = 128
    elif has_zp and use_int4_w4a16:
        # int4 zero points are indexed with n // 2 below; this shift selects
        # the nibble for column n.
        b_zp_shifter = (offs_bn[None, :] % 2) * 4

    # Accumulate C block in fp32
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # A K-tail mask for the scale / zero-point loads is only built when
        # block_k_diviable is false.
        if not block_k_diviable:
            k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
            k_other = 0.0
        else:
            k_mask = None
            k_other = None

        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        # NOTE(review): the weight tile is loaded without a mask — confirm the
        # K tail is safe to read when block_k_diviable is false.
        b = tl.load(b_ptrs)
        if use_int4_w4a16:
            b = (b >> b_shifter) & 0xF

        # One scale per `group_size` elements along K.
        b_scale_ptrs = (
            b_scale_ptr
            + off_experts * stride_bse
            + offs_bn[None, :] * stride_bsn
            + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
        )
        b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
        b_scale = b_scale.to(tl.float32)

        if has_zp and use_int4_w4a16:
            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
            b_zp_ptrs = (
                b_zp_ptr
                + off_experts * stride_bze
                + (offs_bn[None, :] // 2) * stride_bzn
                + offs_k_true * stride_bzk
            )
            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
            b_zp = (b_zp >> b_zp_shifter) & 0xF
            b_zp = b_zp.to(tl.float32)
        elif has_zp and use_int8_w8a16:
            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
            b_zp_ptrs = (
                b_zp_ptr
                + off_experts * stride_bze
                + offs_bn[None, :] * stride_bzn
                + offs_k_true * stride_bzk
            )
            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
            b_zp = b_zp.to(tl.float32)

        # Dequantize the weight tile: (b - zero_point) * scale.
        if has_zp:
            b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
        else:
            b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
        accumulator = tl.dot(a, b, acc=accumulator)

        # Advance to the next K tile; int4 weights advance half as far
        # because two values share one stored element.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        if use_int4_w4a16:
            b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
        else:
            b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        # Scale each row by the weight stored for its token slot.
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]

    accumulator = accumulator.to(compute_type)
    # Write back output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
@triton.jit
def fused_moe_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    b_bias_ptr,
    a_scale_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_asm,
    stride_ask,
    stride_bse,
    stride_bsk,
    stride_bsn,
    stride_bbe,  # bias expert stride
    stride_bbn,  # bias N stride
    # Block size for block-wise quantization
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    naive_block_assignment: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a8: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
    per_channel_quant: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    """Fused MoE kernel: token × expert GEMM with quantization support.

    Each program computes one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C for a
    block of token slots that share one expert (read from `expert_ids_ptr`).
    Scales are applied inside the K loop for block-wise quantization
    (group_k > 0 and group_n > 0) and once after the loop otherwise. Bias and
    routed weights are applied last, in fp32, before the cast to
    `compute_type`.

    `SPLIT_K` is accepted but not referenced in the kernel body.
    """
    # Map pid to C block (grouped ordering for L2 reuse)
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Create pointers for first blocks of A and B
    offs = tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    # Row blocks that start at or beyond the padded token count have no work.
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    if not naive_block_assignment:
        offs_token_id = pid_m * BLOCK_SIZE_M + offs
        offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    else:
        # Naive assignment: only row 0 of the block is a real slot (pid_m);
        # the other rows get num_valid_tokens and are masked out below.
        offs_token = tl.where(
            offs == 0,
            pid_m,  # first element = pid_m
            num_valid_tokens,  # remaining elements = constant
        )
    offs_token = offs_token.to(tl.int64)  # prevent int32 overflow

    token_mask = offs_token < num_valid_tokens

    # Replace masked-out slot ids with 0 so every address computed from
    # offs_token uses an in-range slot; the mask still guards the accesses.
    offs_token = tl.where(token_mask, offs_token, 0)

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        # Expert not in current EP rank, write zeros
        write_zeros_to_output(
            c_ptr,
            stride_cm,
            stride_cn,
            pid_n,
            N,
            offs_token,
            token_mask,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            compute_type,
        )
        return

    # Column offsets wrap modulo N; out-of-range columns are masked at the store.
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # The row of A for a token slot is slot // top_k.
    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )

    b_ptrs = (
        b_ptr
        + off_experts * stride_be
        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    )
    if use_int8_w8a16:
        # Weight-only int8: per-column scale, applied after the K loop.
        b_scale_ptrs = (
            b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
        )
        b_scale = tl.load(b_scale_ptrs)

    if use_fp8_w8a8 or use_int8_w8a8:
        if group_k > 0 and group_n > 0:  # block-wise
            # Only base pointers are set up here; the scales are loaded per
            # K tile inside the loop.
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            offs_bsn = offs_bn // group_n
            b_scale_ptrs = (
                b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
            )
        elif per_channel_quant:  # channel-wise
            b_scale_ptrs = (
                b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
            )
            b_scale = tl.load(b_scale_ptrs)
            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
            a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
        else:  # tensor-wise
            a_scale = tl.load(a_scale_ptr)
            b_scale = tl.load(b_scale_ptr + off_experts)
    if HAS_BIAS:
        bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn
        bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0)
    # Accumulate C block in fp32
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    k_total = K
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Pre-compute remaining K for this iteration
        k_offset = k * BLOCK_SIZE_K
        k_remaining = k_total - k_offset
        # Use other=0.0 for proper tail block handling - compatible with all batch sizes
        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
            other=0.0,
        )
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
        if use_int8_w8a16:
            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
        elif use_fp8_w8a8 or use_int8_w8a8:
            if group_k > 0 and group_n > 0:
                # Block-wise: load the scales for the group containing this
                # K tile and apply them to the partial product.
                k_start = k * BLOCK_SIZE_K
                offs_ks = k_start // group_k
                a_scale = tl.load(
                    a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
                )
                b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)

                accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
            else:
                if use_fp8_w8a8:
                    accumulator = tl.dot(a, b, acc=accumulator)
                else:
                    accumulator += tl.dot(a, b)
        else:
            accumulator += tl.dot(a, b)
        # Update pointers for next iteration
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Dequantization
    if use_int8_w8a16:
        accumulator = accumulator * b_scale
    elif (use_fp8_w8a8 or use_int8_w8a8) and not (group_k > 0 and group_n > 0):
        # Channel-wise / tensor-wise scales are applied once, after the loop.
        accumulator = accumulator * a_scale * b_scale

    if HAS_BIAS:
        accumulator += bias[None, :]

    # Router weight multiplication (must be in fp32)
    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(
            topk_weights_ptr + offs_token,
            mask=token_mask,
            other=0,
        )
        accumulator *= moe_weight[:, None]

    accumulator = accumulator.to(compute_type)

    # Write back output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
def invoke_fused_moe_wna16_triton_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    B_scale: torch.Tensor | None,
    B_zp: torch.Tensor | None,
    topk_weights: torch.Tensor | None,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    block_shape: list[int] | None,
):
    """Launch ``fused_moe_kernel_gptq_awq`` for weight-only (WNA16) quantized experts.

    Args:
        A: Activations; ``A.size(0)`` is the token count and ``A.size(1)`` the
            reduction size handed to the kernel.
        B: Expert weights; ``B.size(0)`` is the expert count and ``B.size(1)``
            the output size handed to the kernel.
        C: Output buffer; only its strides along dims 1 and 2 are forwarded.
        B_scale: Weight scales, must be a 3-D tensor.
        B_zp: Optional weight zero points, 3-D when given.
        topk_weights: Router weights, forwarded to the kernel unchanged.
        sorted_token_ids: Block-aligned token ids; its length bounds the grid.
        expert_ids: Expert id per block of tokens.
        num_tokens_post_padded: Padded token count, forwarded unchanged.
        mul_routed_weight: Forwarded as ``MUL_ROUTED_WEIGHT``.
        top_k: Number of experts selected per token.
        config: Tile configuration; must contain ``BLOCK_SIZE_M``. It is copied
            before being extended, so the caller's dict is not modified.
        compute_type: Triton dtype forwarded to the kernel.
        use_int8_w8a16: Forwarded to the kernel.
        use_int4_w4a16: Forwarded to the kernel.
        block_shape: ``[0, group_size]``; the first entry must be 0.

    Raises:
        AssertionError: If the scale/zero-point ranks or ``block_shape`` do not
            satisfy the requirements above.
    """
    assert B_scale is not None and B_scale.ndim == 3
    assert B_zp is None or B_zp.ndim == 3
    assert block_shape is not None and block_shape[0] == 0

    M = A.size(0)
    num_tokens = M * top_k

    EM = sorted_token_ids.size(0)
    if M < config["BLOCK_SIZE_M"]:
        # optimize for small batch_size.
        # We assume that top_ids of each token is unique,
        # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
        # and we can skip some invalid blocks.
        EM = min(sorted_token_ids.size(0), M * top_k * config["BLOCK_SIZE_M"])
    grid = lambda META: (
        triton.cdiv(EM, META["BLOCK_SIZE_M"])
        * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
    )
    config = config.copy()
    config.update(
        get_moe_wna16_block_config(
            config=config,
            use_moe_wna16_cuda=False,
            num_valid_tokens=num_tokens,
            size_k=A.size(1),
            size_n=B.size(1),
            # The expert count is the leading dim of B; dim 1 is the output size
            # (already passed as size_n above).
            num_experts=B.size(0),
            group_size=block_shape[1],
            real_top_k=top_k,
            block_size_m=config["BLOCK_SIZE_M"],
        )
    )

    # NOTE(review): B / B_scale / B_zp strides are passed in (0, 2, 1) order,
    # presumably to match the kernel's (expert, k, n) stride parameters —
    # confirm against the kernel signature.
    fused_moe_kernel_gptq_awq[grid](
        A,
        B,
        C,
        B_scale,
        B_zp,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.size(1),
        A.size(1),
        EM,
        num_tokens,
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        B_scale.stride(0),
        B_scale.stride(2),
        B_scale.stride(1),
        B_zp.stride(0) if B_zp is not None else 0,
        B_zp.stride(2) if B_zp is not None else 0,
        B_zp.stride(1) if B_zp is not None else 0,
        # Evaluated after config.update() so it sees the final BLOCK_SIZE_K.
        block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
        group_size=block_shape[1],
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        has_zp=B_zp is not None,
        use_int4_w4a16=use_int4_w4a16,
        use_int8_w8a16=use_int8_w8a16,
        **config,
    )
def invoke_fused_moe_triton_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    topk_weights: Optional[torch.Tensor],
    sorted_token_ids: Optional[torch.Tensor],
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    block_shape: Optional[list[int]] = None,
    B_bias: torch.Tensor | None = None,
) -> None:
    """Launch the fused_moe_kernel Triton kernel.

    Args:
        A: Activations; ``A.size(0)`` is the token count.
        B: Expert weights; ``B.size(1)`` is passed as N and ``B.size(2)`` as K.
        C: Output buffer; only its strides along dims 1 and 2 are forwarded.
        A_scale: Activation scales. Must be None when no quantization flag is
            set. Strides are forwarded only for 2-D scales, otherwise 0.
        B_scale: Weight scales. Required for every quantized mode, must be None
            otherwise.
        topk_weights: Router weights; required when ``mul_routed_weight`` is
            set, and must have unit stride along dim 1.
        sorted_token_ids: Block-aligned token ids, or None to request naive
            block assignment (the kernel is then launched with
            ``naive_block_assignment=True``).
        expert_ids: Expert id per block of tokens.
        num_tokens_post_padded: Padded token count, forwarded unchanged.
        mul_routed_weight: Forwarded as ``MUL_ROUTED_WEIGHT``.
        top_k: Number of experts selected per token.
        config: Tile configuration containing ``BLOCK_SIZE_M`` and
            ``BLOCK_SIZE_K``. It is copied, so the caller's dict is untouched.
        compute_type: Triton dtype forwarded to the kernel.
        use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16:
            Quantization mode flags; they select which scale checks apply.
        per_channel_quant: Forwarded to the kernel.
        block_shape: Optional two-entry block-quantization shape. Its entries
            are forwarded as the kernel's two group sizes (0 when None).
        B_bias: Optional bias; enables ``HAS_BIAS`` in the kernel.

    Raises:
        AssertionError: If the scale / stride / block-shape preconditions for
            the selected mode are not met.
    """
    assert topk_weights is not None or not mul_routed_weight
    assert topk_weights is None or topk_weights.stride(1) == 1
    assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1

    if use_fp8_w8a8 or use_int8_w8a8:
        assert B_scale is not None
        # With block quantization the scale grid must tile B exactly.
        assert block_shape is None or triton.cdiv(
            B.size(-2), block_shape[0]
        ) == B_scale.size(-2)
        assert block_shape is None or triton.cdiv(
            B.size(-1), block_shape[1]
        ) == B_scale.size(-1)
    elif use_int8_w8a16 or use_int4_w4a16:
        assert B_scale is not None
        assert block_shape is None or block_shape[0] == 0
    else:
        assert A_scale is None
        assert B_scale is None

    M = A.size(0)
    num_tokens = M * top_k
    if sorted_token_ids is not None:
        EM = sorted_token_ids.size(0)
        if M < config["BLOCK_SIZE_M"]:
            # Small batch: at most M * top_k blocks can hold real tokens.
            EM = min(sorted_token_ids.size(0), M * top_k * config["BLOCK_SIZE_M"])
    else:
        # Naive assignment: one BLOCK_SIZE_M-sized block per (token, expert) pair.
        EM = num_tokens * config["BLOCK_SIZE_M"]
    grid = lambda META: (
        triton.cdiv(EM, META["BLOCK_SIZE_M"])
        * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
    )
    HAS_BIAS = B_bias is not None

    config = config.copy()
    config["SPLIT_K"] = 1
    BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
    if block_shape is not None:
        # Clamp the K tile to the quantization block, but ignore 0 entries:
        # 0 is accepted above (block_shape[0] == 0) and is the same value used
        # for "no block quantization" in the group-size arguments below, so it
        # must not shrink BLOCK_SIZE_K to 0.
        positive_dims = [dim for dim in block_shape[:2] if dim > 0]
        if positive_dims:
            BLOCK_SIZE_K = min(BLOCK_SIZE_K, *positive_dims)

    fused_moe_kernel[grid](
        A,
        B,
        C,
        B_bias,
        A_scale,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.size(1),  # N
        B.size(2),  # K
        EM,
        num_tokens,
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
        A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
        B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
        B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
        B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
        B_bias.stride(0) if B_bias is not None else 0,
        B_bias.stride(1) if B_bias is not None else 0,
        0 if block_shape is None else block_shape[0],
        0 if block_shape is None else block_shape[1],
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        per_channel_quant=per_channel_quant,
        naive_block_assignment=(sorted_token_ids is None),
        HAS_BIAS=HAS_BIAS,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        **config,
    )
def dispatch_fused_moe_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    B_zp: Optional[torch.Tensor],
    topk_weights: Optional[torch.Tensor],
    sorted_token_ids: Optional[torch.Tensor],
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a8: bool,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    per_channel_quant: bool,
    block_shape: Optional[list[int]] = None,
    B_bias: Optional[torch.Tensor] = None,
) -> None:
    """Dispatch to the appropriate fused MoE kernel based on quantization flags.

    Weight-only int8/int4 with a positive ``block_shape[1]`` goes to
    ``invoke_fused_moe_wna16_triton_kernel`` (which takes ``B_zp`` and does not
    accept a bias). Every other combination goes to
    ``invoke_fused_moe_triton_kernel`` (which takes ``A_scale``,
    ``per_channel_quant`` and ``B_bias`` but not ``B_zp``).

    Raises:
        AssertionError: If ``topk_weights`` is missing while
            ``mul_routed_weight`` is set, if ``topk_weights`` /
            ``sorted_token_ids`` do not have unit stride, or if a bias is
            supplied on the WNA16 path.
    """
    assert topk_weights is not None or not mul_routed_weight
    assert topk_weights is None or topk_weights.stride(1) == 1
    assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1

    # TODO: Other precision-specific implementations
    # (use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16).

    is_weight_only = use_int8_w8a16 or use_int4_w4a16
    is_grouped = block_shape is not None and block_shape[1] > 0

    if is_weight_only and is_grouped:
        assert B_bias is None
        invoke_fused_moe_wna16_triton_kernel(
            A,
            B,
            C,
            B_scale,
            B_zp,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            mul_routed_weight,
            top_k,
            config,
            compute_type,
            use_int8_w8a16,
            use_int4_w4a16,
            block_shape,
        )
    else:
        invoke_fused_moe_triton_kernel(
            A,
            B,
            C,
            A_scale,
            B_scale,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            mul_routed_weight,
            top_k,
            config,
            compute_type,
            use_fp8_w8a8,
            use_int8_w8a8,
            use_int8_w8a16,
            use_int4_w4a16,
            per_channel_quant,
            block_shape,
            B_bias,
        )
def fused_experts_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    ocp_mx_scheme: str | None = None,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: torch.Tensor | None = None,
    w2_zp: torch.Tensor | None = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the two-GEMM fused MoE pipeline over ``hidden_states``.

    For each chunk of at most ``CHUNK_SIZE`` tokens the function quantizes the
    input, runs the first expert GEMM with ``w1``, applies the activation,
    re-quantizes, runs the second expert GEMM with ``w2`` and finally calls
    ``moe_sum`` to write the chunk's result into the output tensor.

    Args:
        hidden_states: Contiguous 2-D input of dtype float32, float16 or
            bfloat16; dim 0 is the token count.
        w1: 3-D first-stage expert weights; ``w1.size()`` is read as
            ``(E, N, _)`` and the last stride must be 1.
        w2: Second-stage expert weights; ``w2.size(1)`` is used as the output
            width ``K`` and the last stride must be 1.
        topk_weights: Router weights, same shape as ``topk_ids``.
        topk_ids: Selected expert ids; ``topk_ids.size(1)`` is top-k.
        inplace: Write the result into ``hidden_states`` instead of a new
            tensor.
        activation: Activation name (or an object with a ``.value``). Only
            ``"silu"`` is accepted.
        apply_router_weight_on_input: When True the router weights are applied
            in the first GEMM, otherwise in the second.
        use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16:
            Quantization mode flags.
        ocp_mx_scheme: Optional OCP MX scheme name; ``w_mxfp4*`` and
            ``w_mxfp6_e3m2*`` / ``w_mxfp6_e2m3*`` weights are dequantized
            up front.
        per_channel_quant: Forwarded to input quantization and the kernels.
        global_num_experts: Total expert count; ``-1`` means ``E``.
        expert_map: Optional expert mapping handed to ``moe_align_block_size``.
        w1_scale, w2_scale: Weight scales for the quantized modes.
        w1_zp, w2_zp: Optional weight zero points.
        a1_scale, a2_scale: Optional activation scales for the two GEMM inputs.
        block_shape: Optional block-quantization shape.
        w1_bias, w2_bias: Optional per-expert biases.

    Returns:
        ``hidden_states`` itself when ``inplace`` is True, otherwise a newly
        allocated tensor of the same shape and dtype.

    Raises:
        AssertionError: On unsupported activation or shape / layout / dtype
            mismatches.
        NotImplementedError: On an unsupported ``ocp_mx_scheme``.
        ValueError: If no Triton compute type matches ``hidden_states.dtype``.
    """
    logger.debug("GEMS_ASCEND FUSED MOE")
    if hasattr(activation, "value"):
        activation = activation.value
    assert (
        activation == "silu"
    ), f"Only 'silu' activation is supported, got {activation}"

    activation_enum = MoEActivation.from_str(activation)

    # Check constraints
    if use_int4_w4a16:
        # INT4 stored unpacked in INT8 containers (full K dim)
        assert hidden_states.size(1) == w1.size(
            2
        ), f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"
    elif ocp_mx_scheme is not None:
        if ocp_mx_scheme.startswith("w_mxfp4"):
            assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch"
        elif ocp_mx_scheme.startswith("w_mxfp6"):
            assert (
                hidden_states.size(1) == (w1.size(2) * 4) // 3
            ), "hidden size mismatch"
        else:
            raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
    else:
        assert hidden_states.size(1) == w1.size(
            2
        ), f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"

    assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

    num_tokens = hidden_states.size(0)
    E, N, _ = w1.size()
    K = w2.size(1)
    if global_num_experts == -1:
        global_num_experts = E
    top_k_num = topk_ids.size(1)

    CHUNK_SIZE: int = 16 * 1024
    M = min(num_tokens, CHUNK_SIZE)

    config_dtype = _get_config_dtype_str(
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        ocp_mx_scheme=ocp_mx_scheme,
        dtype=hidden_states.dtype,
    )

    quant_dtype = _get_config_quant_dtype(
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        ocp_mx_scheme=ocp_mx_scheme,
    )

    # Bound to everything except the token count so the config can be
    # re-queried for a shorter final chunk.
    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.size(),
        w2.size(),
        top_k_num,
        config_dtype,
        block_shape=block_shape,
    )

    config = get_config_func(M)

    # cache1 and cache3 share memory (non-overlapping lifetime)
    cache13 = torch.empty(
        M * top_k_num * max(N, K),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache1 = cache13[: M * top_k_num * N].view(M, top_k_num, N)
    intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)

    # cache2 needs separate memory (concurrent with cache1)
    activation_out_dim = MoEActivation.adjust_N_for_activation(N, activation_enum)
    intermediate_cache2 = torch.empty(
        (M * top_k_num, activation_out_dim),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    if hidden_states.dtype == torch.bfloat16:
        compute_type = tl.bfloat16
    elif hidden_states.dtype == torch.float16:
        compute_type = tl.float16
    elif hidden_states.dtype == torch.float32:
        compute_type = tl.float32
    else:
        raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

    out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)

    if ocp_mx_scheme is not None:
        # Dequantize OCP MX weights (TODO: skip on platforms with native MX)
        if ocp_mx_scheme.startswith("w_mxfp4"):
            w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
            w1_scale = None
            w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
            w2_scale = None
        elif ocp_mx_scheme.startswith("w_mxfp6_e3m2"):
            w1 = dequant_mxfp6(
                w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
            )
            w1_scale = None
            w2 = dequant_mxfp6(
                w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
            )
            w2_scale = None
        elif ocp_mx_scheme.startswith("w_mxfp6_e2m3"):
            w1 = dequant_mxfp6(
                w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
            )
            w1_scale = None
            w2 = dequant_mxfp6(
                w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
            )
            w2_scale = None
        else:
            raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")

    # Dequant INT8/INT4 weights (Triton can't do mixed-dtype dot)
    if use_int8_w8a16 or use_int4_w4a16:
        w1 = w1.to(hidden_states.dtype) * w1_scale.unsqueeze(-1).to(hidden_states.dtype)
        w1_scale = None
        w2 = w2.to(hidden_states.dtype) * w2_scale.unsqueeze(-1).to(hidden_states.dtype)
        w2_scale = None
        # From here on the weights are plain floating point, so the kernels
        # must be launched as unquantized.
        use_int8_w8a16 = False
        use_int4_w4a16 = False

    SPARSITY_FACTOR = 4

    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
        begin_chunk_idx, end_chunk_idx = (
            chunk * CHUNK_SIZE,
            min((chunk + 1) * CHUNK_SIZE, num_tokens),
        )
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.size()

        # Reached when num_tokens is 0 or an exact multiple of CHUNK_SIZE.
        if tokens_in_chunk == 0:
            break

        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
            # Adjust cache size for last chunk
            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
            intermediate_cache2 = intermediate_cache2[
                : tokens_in_chunk * topk_ids.size(1)
            ]
            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
            config = get_config_func(tokens_in_chunk)

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
        qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
            A=curr_hidden_states,
            A_scale=a1_scale,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_channel_quant,
            block_shape=block_shape,
            ocp_mx_scheme=ocp_mx_scheme,
        )

        # For small tokens (< 32), use naive assignment to skip alignment
        # overhead. This is only valid without an expert_map: the naive path
        # feeds the raw top-k ids to the kernel as expert indices and never
        # applies expert_map, which is consumed solely by moe_align_block_size.
        use_naive_small = expert_map is None and tokens_in_chunk < 32
        naive_block_assignment = use_naive_small or (
            expert_map is None
            and tokens_in_chunk * top_k_num * SPARSITY_FACTOR <= global_num_experts
            and not (
                (use_int8_w8a16 or use_int4_w4a16)
                and block_shape is not None
                and block_shape[1] > 0
            )
        )

        if not naive_block_assignment:
            sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
                curr_topk_ids,
                config["BLOCK_SIZE_M"],
                global_num_experts,
                expert_map,
                # ignore_invalid_experts=True,
            )
        else:
            max_num_tokens_padded = topk_ids.numel() * config["BLOCK_SIZE_M"]
            expert_ids = curr_topk_ids.view(-1)
            num_tokens_post_padded = torch.empty(
                (1), dtype=torch.int32, device=topk_ids.device
            )
            num_tokens_post_padded.fill_(max_num_tokens_padded)
            # None tells the launcher to use naive block assignment.
            sorted_token_ids = None

        # First GEMM: hidden_states x w1 -> intermediate_cache1.
        dispatch_fused_moe_kernel(
            qcurr_hidden_states,
            w1,
            intermediate_cache1,
            a1q_scale,
            w1_scale,
            w1_zp,
            curr_topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            apply_router_weight_on_input,
            top_k_num,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            block_shape=block_shape,
            B_bias=w1_bias,
        )

        apply_moe_activation(
            activation_enum, intermediate_cache2, intermediate_cache1.view(-1, N)
        )

        qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
            A=intermediate_cache2,
            A_scale=a2_scale,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_channel_quant,
            block_shape=block_shape,
            ocp_mx_scheme=ocp_mx_scheme,
        )

        # cache3 aliases cache1's storage; clear it before the second GEMM when
        # an expert_map is in use.
        if expert_map is not None:
            intermediate_cache3.zero_()

        # Second GEMM: activation output x w2 -> intermediate_cache3. Each row
        # of cache2 is already one (token, expert) pair, hence top_k == 1, and
        # the router weight is applied here unless it was applied on the input.
        dispatch_fused_moe_kernel(
            qintermediate_cache2,
            w2,
            intermediate_cache3,
            a2q_scale,
            w2_scale,
            w2_zp,
            curr_topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            not apply_router_weight_on_input,
            1,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            block_shape=block_shape,
            B_bias=w2_bias,
        )

        moe_sum(
            intermediate_cache3.view(*intermediate_cache3.size()),
            out_hidden_states[begin_chunk_idx:end_chunk_idx],
        )

    return out_hidden_states
def inplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> None:
    """
    Fused MoE that overwrites ``hidden_states`` with its own result.

    Thin wrapper over ``fused_experts_impl`` called with ``inplace=True``;
    every other argument is forwarded unchanged. Nothing is returned because
    the output lives in ``hidden_states``.
    """
    forwarded_options = dict(
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        per_channel_quant=per_channel_quant,
        global_num_experts=global_num_experts,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
    )
    # The implementation's return value is the same tensor as hidden_states in
    # this mode, so it is intentionally discarded.
    fused_experts_impl(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=True,
        **forwarded_options,
    )
def outplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Fused MoE that leaves ``hidden_states`` untouched and returns the result.

    Thin wrapper over ``fused_experts_impl`` called with ``inplace=False``;
    every other argument is forwarded unchanged.
    """
    forwarded_options = dict(
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        per_channel_quant=per_channel_quant,
        global_num_experts=global_num_experts,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
    )
    result = fused_experts_impl(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=False,
        **forwarded_options,
    )
    return result