Coverage for src/flag_gems/runtime/backend/_sunrise/fused/fused_moe.py: 0%

30 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import contextlib 

2import threading 

3from typing import Any 

4 

5from flag_gems.fused import fused_moe as generic_fused_moe 

6 

# Guards the temporary replacement of ``generic_fused_moe.get_default_config``.
# Re-entrant, so the same thread may enter the patch more than once.
_PATCH_LOCK = threading.RLock()

# Reference to the generic implementation taken at import time, so the Sunrise
# override can still call it while the module attribute is patched.
_GENERIC_GET_DEFAULT_CONFIG = generic_fused_moe.get_default_config

# ``dtype`` strings for which the Sunrise override narrows the N tile.
_PLAIN_HALF_CONFIG_DTYPES = ("fp16", "bf16")

10 

11 

def _sunrise_get_default_config(
    M: int,
    E: int,
    N: int,
    K: int,
    topk: int,
    dtype: str | None,
    block_shape: list[int] | None = None,
    gemm_stage: str = "gemm1",
    enable_gemm_fast_path: bool = False,
) -> dict[str, Any]:
    """Return the generic default MoE config, adjusted for Sunrise.

    Every argument is forwarded positionally and unchanged to the generic
    implementation held in ``_GENERIC_GET_DEFAULT_CONFIG``. The result is
    returned as-is unless ``dtype`` is one of ``_PLAIN_HALF_CONFIG_DTYPES``
    and ``N >= 4096``; in that case a copy is returned whose
    ``"BLOCK_SIZE_N"`` entry is capped at 64.

    Returns:
        The config dict from the generic implementation, or a shallow copy
        of it with ``"BLOCK_SIZE_N"`` limited to at most 64.
    """
    generic_config = _GENERIC_GET_DEFAULT_CONFIG(
        M,
        E,
        N,
        K,
        topk,
        dtype,
        block_shape,
        gemm_stage,
        enable_gemm_fast_path,
    )

    is_large_n_half = dtype in _PLAIN_HALF_CONFIG_DTYPES and N >= 4096
    if not is_large_n_half:
        return generic_config

    # Sunrise/PTPU can run out of registers in the generic fused MoE kernel
    # when large-N half-precision tiles keep BLOCK_SIZE_N at 128. Limiting
    # the N tile to 64 avoids the inline-asm register overflow seen on PT200.
    # Work on a copy so the dict handed back by the generic code is untouched.
    narrowed_config = generic_config.copy()
    narrowed_config["BLOCK_SIZE_N"] = min(narrowed_config["BLOCK_SIZE_N"], 64)
    return narrowed_config

43 

44 

@contextlib.contextmanager
def _sunrise_moe_config_patch():
    """Temporarily route ``generic_fused_moe.get_default_config`` to Sunrise.

    While the ``with`` body runs, the module attribute
    ``generic_fused_moe.get_default_config`` is bound to
    ``_sunrise_get_default_config``. On exit, normal or via an exception,
    the attribute is rebound to whatever it held on entry.

    ``_PATCH_LOCK`` is held from before the swap until after the restore.
    """
    with _PATCH_LOCK:
        # Remember whatever is currently installed (not necessarily the
        # generic function) so nested use restores the right value.
        previous_hook = generic_fused_moe.get_default_config
        generic_fused_moe.get_default_config = _sunrise_get_default_config
        try:
            yield
        finally:
            generic_fused_moe.get_default_config = previous_hook

54 

55 

def fused_experts_impl(*args, **kwargs):
    """Call the generic ``fused_experts_impl`` with the Sunrise config patch.

    All positional and keyword arguments are passed through unchanged, and
    the generic function's return value is returned unchanged.
    """
    with _sunrise_moe_config_patch():
        outcome = generic_fused_moe.fused_experts_impl(*args, **kwargs)
    return outcome

59 

60 

def inplace_fused_experts(*args, **kwargs):
    """Call the generic ``inplace_fused_experts`` with the Sunrise config patch.

    All positional and keyword arguments are passed through unchanged, and
    the generic function's return value is returned unchanged.
    """
    with _sunrise_moe_config_patch():
        outcome = generic_fused_moe.inplace_fused_experts(*args, **kwargs)
    return outcome

64 

65 

def outplace_fused_experts(*args, **kwargs):
    """Call the generic ``outplace_fused_experts`` with the Sunrise config patch.

    All positional and keyword arguments are passed through unchanged, and
    the generic function's return value is returned unchanged.
    """
    with _sunrise_moe_config_patch():
        outcome = generic_fused_moe.outplace_fused_experts(*args, **kwargs)
    return outcome