Coverage for src/flag_gems/runtime/backend/_sunrise/fused/fused_moe.py: 0%
30 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import contextlib
2import threading
3from typing import Any
5from flag_gems.fused import fused_moe as generic_fused_moe
# Guards every swap of ``generic_fused_moe.get_default_config``. An RLock is
# re-entrant, so the thread that already holds it can acquire it again.
_PATCH_LOCK = threading.RLock()

# Reference to the generic implementation captured at import time, so the
# Sunrise override can delegate to it even while the module attribute is
# temporarily replaced by the override itself.
_GENERIC_GET_DEFAULT_CONFIG = generic_fused_moe.get_default_config

# dtype strings for which the narrower N tile is applied.
_PLAIN_HALF_CONFIG_DTYPES = (
    "fp16",
    "bf16",
)
def _sunrise_get_default_config(
    M: int,
    E: int,
    N: int,
    K: int,
    topk: int,
    dtype: str | None,
    block_shape: list[int] | None = None,
    gemm_stage: str = "gemm1",
    enable_gemm_fast_path: bool = False,
) -> dict[str, Any]:
    """Return the generic fused-MoE default config, adjusted for Sunrise.

    All arguments are forwarded positionally and unchanged to the generic
    ``get_default_config`` captured at import time. The result is returned
    as-is unless ``dtype`` is one of ``_PLAIN_HALF_CONFIG_DTYPES`` and
    ``N >= 4096``; in that case a shallow copy is returned whose
    ``"BLOCK_SIZE_N"`` entry is capped at 64.

    Raises:
        KeyError: if the adjustment applies and the generic config has no
            ``"BLOCK_SIZE_N"`` entry.
    """
    generic_config = _GENERIC_GET_DEFAULT_CONFIG(
        M, E, N, K, topk, dtype, block_shape, gemm_stage, enable_gemm_fast_path
    )

    needs_narrow_n_tile = dtype in _PLAIN_HALF_CONFIG_DTYPES and N >= 4096
    if not needs_narrow_n_tile:
        return generic_config

    # Sunrise/PTPU can run out of registers in the generic fused MoE kernel
    # when large-N half-precision tiles keep BLOCK_SIZE_N at 128. Capping the
    # N tile at 64 avoids the inline-asm register overflow seen on PT200.
    # Work on a copy so the dict handed back by the generic helper is not
    # mutated.
    narrowed_config = generic_config.copy()
    narrowed_config["BLOCK_SIZE_N"] = min(narrowed_config["BLOCK_SIZE_N"], 64)
    return narrowed_config
@contextlib.contextmanager
def _sunrise_moe_config_patch():
    """Temporarily install the Sunrise ``get_default_config`` override.

    While the context is active, ``generic_fused_moe.get_default_config``
    points at ``_sunrise_get_default_config``. On exit, whether normal or via
    an exception, the attribute is reset to whatever it held on entry.
    ``_PATCH_LOCK`` is held for the whole duration of the context, so the
    swap and the restore of one context are never interleaved with those of
    a context opened from another thread.
    """
    with _PATCH_LOCK, contextlib.ExitStack() as restore_stack:
        previous_get_default_config = generic_fused_moe.get_default_config
        generic_fused_moe.get_default_config = _sunrise_get_default_config
        # The callback runs when the ExitStack unwinds, which happens before
        # the lock is released, and also when the body raises.
        restore_stack.callback(
            setattr,
            generic_fused_moe,
            "get_default_config",
            previous_get_default_config,
        )
        yield
def fused_experts_impl(*args, **kwargs):
    """Call the generic ``fused_experts_impl`` under the Sunrise config patch.

    All positional and keyword arguments are forwarded unchanged, and the
    generic function's return value is passed back to the caller.
    """
    with _sunrise_moe_config_patch():
        result = generic_fused_moe.fused_experts_impl(*args, **kwargs)
    return result
def inplace_fused_experts(*args, **kwargs):
    """Call the generic ``inplace_fused_experts`` under the Sunrise config patch.

    All positional and keyword arguments are forwarded unchanged, and the
    generic function's return value is passed back to the caller.
    """
    with _sunrise_moe_config_patch():
        result = generic_fused_moe.inplace_fused_experts(*args, **kwargs)
    return result
def outplace_fused_experts(*args, **kwargs):
    """Call the generic ``outplace_fused_experts`` under the Sunrise config patch.

    All positional and keyword arguments are forwarded unchanged, and the
    generic function's return value is passed back to the caller.
    """
    with _sunrise_moe_config_patch():
        result = generic_fused_moe.outplace_fused_experts(*args, **kwargs)
    return result