Coverage for src/flag_gems/fused/fused_moe.py: 39%

921 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1# SPDX-License-Identifier: Apache-2.0 

2# SPDX-FileCopyrightText: Copyright contributors to the vLLM project 

3# 

4# Adapted from the vLLM project (https://github.com/vllm-project/vllm). 

5# Source files under vllm/model_executor/layers/: 

6# fused_moe/fused_moe.py – Triton kernels, dispatch, fused_experts_impl 

7# fused_moe/activation.py – MoEActivation enum, apply_moe_activation 

8# fused_moe/utils.py – _fp8_quantize, _int8_quantize, moe_kernel_quantize_input 

9# fused_moe/config.py – _get_config_dtype_str 

10# quantization/utils/mxfp4_utils.py – dequant_mxfp4 

11# quantization/utils/mxfp6_utils.py – dequant_mxfp6 

12# quantization/utils/ocp_mx_utils.py – OCP_MX_BLOCK_SIZE 

13 

14 

15import functools 

16import logging 

17import os 

18from enum import Enum 

19from typing import Any, Optional 

20 

21import torch 

22import torch.nn.functional as F 

23import triton 

24import triton.language as tl 

25import yaml 

26 

27from flag_gems.fused.moe_align_block_size import moe_align_block_size 

28from flag_gems.fused.moe_sum import moe_sum 

29from flag_gems.runtime import device, torch_device_fn 

30from flag_gems.utils import pointwise_dynamic 

31 

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# OCP MX quantization helpers (requires amd-quark)

# Group size handed to the per-group MXFP6 dequantization routine.
OCP_MX_BLOCK_SIZE = 32
# H100/Qwen-style MoE tuning thresholds. GEMM tile changes become reliably
# positive from 4096 tokens; direct_sum is kept separate because it is a
# reduction-layout decision even though it currently shares the same cutoff.
MOE_GEMM_TUNING_MIN_TOKENS = 4096
MOE_DIRECT_SUM_MIN_TOKENS = 4096
# M/K tile sizes the plain fp16/bf16 default config must already have chosen
# before the GEMM fast path is considered.
_HALF_GEMM_TILE_M = 128
_HALF_GEMM_TILE_K = 64
# N tile used by the GEMM2 fast path; N has to be a multiple of it.
_HALF_GEMM2_TILE_N = 256
# Config dtype strings that take the plain half-precision default-config branch.
_PLAIN_HALF_CONFIG_DTYPES = ("fp16", "bf16")

46 

47 

@functools.lru_cache(maxsize=1)
def get_embedded_moe_configs():
    """Load the embedded fused-MoE tuning table that ships with the package.

    The table is read from ``../utils/configs/fused_moe_config.yaml`` (relative
    to this file) and the parsed result is cached, so the file is read at most
    once per process.

    Returns:
        A ``(configs, fallback)`` tuple.

        * ``configs`` maps ``device name -> config key -> {M: kernel config}``.
          The ``M`` keys are converted to ``int``; entries stored as lists are
          expanded into dicts keyed by ``BLOCK_SIZE_M``, ``BLOCK_SIZE_N``,
          ``BLOCK_SIZE_K``, ``GROUP_SIZE_M``, ``num_warps`` and ``num_stages``.
        * ``fallback`` is the ``_FALLBACK`` mapping taken verbatim from the file.

        Both are empty dicts when the file is missing or contains no document.
    """
    config_path = os.path.join(
        os.path.dirname(__file__), "..", "utils", "configs", "fused_moe_config.yaml"
    )
    if not os.path.exists(config_path):
        return {}, {}
    with open(config_path, "r", encoding="utf-8") as f:
        # safe_load yields None for an empty document; treat that as "no configs"
        # instead of failing on ``None.get`` below.
        data = yaml.safe_load(f) or {}

    # A present-but-null ``_FALLBACK`` entry is normalised to an empty mapping
    # so callers can always call ``.get`` on it.
    fallback = data.get("_FALLBACK") or {}

    # The M keys may be stored as strings and the configs as positional lists;
    # convert both back to ``int`` keys and named config dicts.
    keys_order = (
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "num_warps",
        "num_stages",
    )
    parsed_data = {}
    for dev, configs in data.items():
        if dev == "_FALLBACK":
            continue
        device_table = {}
        for key, m_dict in configs.items():
            device_table[key] = {
                int(m): dict(zip(keys_order, v)) if isinstance(v, list) else v
                for m, v in m_dict.items()
            }
        parsed_data[dev] = device_table

    return parsed_data, fallback

86 

87 

def dequant_mxfp4(
    x: torch.Tensor,
    scale: torch.Tensor,
    float_dtype: torch.dtype,
) -> torch.Tensor:
    """Dequantize an MXFP4 tensor by delegating to amd-quark's ``mx.dq_mxfp4``.

    Args:
        x: quantized tensor, forwarded unchanged to quark.
        scale: scale tensor, forwarded unchanged to quark.
        float_dtype: dtype argument forwarded unchanged to quark.

    Raises:
        ImportError: if ``quark.torch.kernel`` cannot be imported.
    """
    try:
        from quark.torch.kernel import mx
    except ImportError as err:
        raise ImportError("amd-quark is required for MX-FP4") from err
    else:
        dequantized = mx.dq_mxfp4(x, scale, float_dtype)
    return dequantized

100 

101 

def dequant_mxfp6(
    x: torch.Tensor,
    scale: torch.Tensor,
    float_dtype: torch.dtype,
    quant_dtype: str,
) -> torch.Tensor:
    """Dequantize an MXFP6 tensor through amd-quark's hardware-emulation path.

    The packed input is unpacked with quark's pack method for ``quant_dtype``,
    the scale bytes are turned into powers of two (``2 ** (byte - 127)``), and
    quark's per-group dequantizer is applied along the last axis with groups of
    ``OCP_MX_BLOCK_SIZE`` elements. The result is cast to ``float_dtype``.

    Raises:
        ImportError: if the required quark modules cannot be imported.
    """
    try:
        from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
            dequantize_fp4_fp6_per_group,
        )
        from quark.torch.utils.pack import create_pack_method
    except ImportError as err:
        raise ImportError("amd-quark is required for MX-FP6") from err

    unpacker = create_pack_method(None, dtype=quant_dtype)
    unpacked_x = unpacker.unpack(x, reorder=False)

    # Reinterpret the scale storage as bytes and remove the 127 bias; the
    # exponent is converted to float_dtype *before* the power is taken.
    exponent = (scale.view(torch.uint8).to(torch.int16) - 127).to(float_dtype)
    scale = 2**exponent

    dequantized = dequantize_fp4_fp6_per_group(
        unpacked_x,
        scale,
        axis=-1,
        group_size=OCP_MX_BLOCK_SIZE,
        quant_dtype=quant_dtype,
    )
    return dequantized.to(float_dtype)

129 

130 

# Device lookup and tuned-config selection helpers

132 

133 

@functools.lru_cache(maxsize=1)
def _get_device_name() -> str:
    """Return the device name used as the key into the embedded config table.

    The runtime device name has its spaces replaced by underscores; if the
    runtime query raises ``AttributeError`` the configured ``device.name`` is
    used instead. Any name containing an ``H200`` token collapses to
    ``NVIDIA_H200``. When the resulting name has no entry in the embedded
    table but the ``_FALLBACK`` mapping points it at a device that does, the
    fallback device name is returned (and logged). Otherwise the name is
    returned unchanged. The result is cached for the process lifetime.
    """
    try:
        name = torch_device_fn.get_device_name().replace(" ", "_")
    except AttributeError:
        name = device.name

    # Normalise the H200 product family to a single key, following vLLM.
    name_tokens = name.split("_")
    if "H200" in name_tokens:
        name = "NVIDIA_H200"

    table, aliases = get_embedded_moe_configs()
    if name not in table:
        # Devices with equivalent tuning profiles may be aliased to another entry.
        alias = aliases.get(name)
        if alias and alias in table:
            logger.info(
                "Device %s not in config table, falling back to %s", name, alias
            )
            return alias
    return name

158 

159 

def get_moe_configs(
    E: int,
    N: int,
    dtype: str | None,
    block_n: int | None = None,
    block_k: int | None = None,
) -> dict[int, Any] | None:
    """Look up the pre-tuned fused-MoE kernel configs for the current device.

    The embedded table is indexed first by device name and then by the key
    ``"{E},{N},{dtype},{block_n},{block_k}"``, where a missing or zero block
    size is written as ``0``.

    Returns:
        The ``{M: config}`` mapping for that key, or ``None`` (after logging a
        warning) when either the device or the key has no entry.
    """
    device_name = _get_device_name()
    tables, _ = get_embedded_moe_configs()

    device_table = tables.get(device_name)
    if device_table is None:
        logger.warning(
            "No embedded MoE configs for device %s. Will use default config.",
            device_name,
        )
        return None

    key = f"{E},{N},{dtype},{block_n or 0},{block_k or 0}"
    configs = device_table.get(key)
    if configs is None:
        logger.warning(
            "No embedded MoE config for device=%s, key=%s. Will use default config.",
            device_name,
            key,
        )
        return None

    logger.info("Using embedded MoE config for device=%s, key=%s", device_name, key)
    return configs

196 

197 

198def try_get_optimal_moe_config( 

199 w1_shape: tuple[int, ...], 

200 w2_shape: tuple[int, ...], 

201 top_k: int, 

202 dtype: str | None, 

203 M: int, 

204 E: int, 

205 block_shape: list[int] | None = None, 

206 gemm_stage: str = "gemm1", 

207 enable_gemm_fast_path: bool = False, 

208 return_is_embedded: bool = False, 

209) -> dict[str, Any] | tuple[dict[str, Any], bool]: 

210 if gemm_stage not in ("gemm1", "gemm2"): 

211 raise ValueError(f"Unsupported MoE GEMM stage: {gemm_stage}") 

212 _, _, config_n = w2_shape 

213 if dtype == "int4_w4a16": 

214 config_n = config_n * 2 

215 block_n = block_shape[0] if block_shape else 0 

216 block_k = block_shape[1] if block_shape else 0 

217 configs = get_moe_configs(E, config_n, dtype, block_n, block_k) 

218 if configs: 

219 config = configs[min(configs.keys(), key=lambda x: abs(x - M))].copy() 

220 is_embedded = True 

221 else: 

222 if gemm_stage == "gemm1": 

223 _, N, K = w1_shape 

224 else: 

225 _, N, K = w2_shape 

226 config = get_default_config( 

227 M, 

228 E, 

229 N, 

230 K, 

231 top_k, 

232 dtype, 

233 block_shape, 

234 gemm_stage=gemm_stage, 

235 enable_gemm_fast_path=enable_gemm_fast_path, 

236 ) 

237 is_embedded = False 

238 if return_is_embedded: 

239 return config, is_embedded 

240 return config 

241 

242 

243def _get_config_quant_dtype( 

244 use_fp8_w8a8: bool, 

245 use_int8_w8a8: bool, 

246 ocp_mx_scheme: str | None, 

247) -> None | torch.dtype | str: 

248 """Map quantization flags to the corresponding dtype.""" 

249 if use_fp8_w8a8: 

250 return torch.float8_e4m3fn 

251 elif use_int8_w8a8: 

252 return torch.int8 

253 elif ocp_mx_scheme == "w_mxfp4_a_mxfp4": 

254 return "mxfp4" 

255 elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e3m2", "w_mxfp6_e3m2_a_mxfp6_e3m2"}: 

256 return "mxfp6_e3m2" 

257 elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e2m3", "w_mxfp6_e2m3_a_mxfp6_e2m3"}: 

258 return "mxfp6_e2m3" 

259 elif ocp_mx_scheme in {"w_mxfp4", "w_mxfp6_e3m2", "w_mxfp6_e2m3"}: 

260 return torch.bfloat16 

261 elif ocp_mx_scheme in {"w_mxfp4_a_fp8", "w_mxfp6_e3m2_a_fp8", "w_mxfp6_e2m3_a_fp8"}: 

262 return torch.float8_e4m3fn 

263 

264 return None 

265 

266 

def get_moe_wna16_block_config(
    config: dict[str, int],
    use_moe_wna16_cuda: bool,
    num_valid_tokens: int,
    size_k: int,
    size_n: int,
    num_experts: int,
    group_size: int,
    real_top_k: int,
    block_size_m: int,
):
    """Choose ``BLOCK_SIZE_N`` / ``BLOCK_SIZE_K`` for the WNA16 MoE kernel.

    Returns an empty dict when ``config`` already carries both keys. Without
    the CUDA kernel, one of two fixed tile pairs is returned depending on
    whether ``num_valid_tokens // real_top_k`` equals 1. With the CUDA kernel,
    the tiles start at 128x128 (K raised to ``group_size`` if that is larger)
    and are grown according to an estimate of the number of blocks; the final
    K tile is passed through ``_ensure_block_size_k_divisible``.
    """
    if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config:
        return {}

    if not use_moe_wna16_cuda:
        single_token = num_valid_tokens // real_top_k == 1
        tile_n, tile_k = (32, 64) if single_token else (64, 32)
        return {"BLOCK_SIZE_N": tile_n, "BLOCK_SIZE_K": tile_k}

    tile_n = 128
    tile_k = group_size if group_size >= 128 else 128

    # Both counts divide by the K tile, mirroring the original heuristic.
    n_blocks = size_k // tile_k
    k_blocks = size_n // tile_k
    m_blocks = (num_valid_tokens + block_size_m - 1) / block_size_m + num_experts
    if num_valid_tokens // real_top_k <= block_size_m:
        m_blocks = min(m_blocks, num_valid_tokens)
    total_blocks = m_blocks * n_blocks * k_blocks

    if size_k % 256 == 0 and total_blocks >= 256 and tile_k < 256:
        tile_k = 256
        # NOTE(review): tile_k is already 256 here, so the divisor is always 1
        # and this only floors total_blocks. The intent was presumably to
        # divide by the ratio to the previous tile size - confirm before
        # changing, since it alters the tuned thresholds below.
        total_blocks = total_blocks // (256 // tile_k)

    can_double_k = (
        m_blocks <= 16
        and size_k % (tile_k * 2) == 0
        and tile_k <= 512
        and total_blocks >= 512
    )
    if can_double_k:
        tile_k = tile_k * 2
        total_blocks = total_blocks // 2

    if total_blocks > 1024:
        tile_n = 256
        total_blocks = total_blocks // 2

    if size_n <= 1024 and total_blocks >= 1024:
        tile_n = 1024

    tile_k = _ensure_block_size_k_divisible(size_k, tile_k, group_size)

    return {"BLOCK_SIZE_N": tile_n, "BLOCK_SIZE_K": tile_k}

325 

326 

327def get_default_config( 

328 M: int, 

329 E: int, 

330 N: int, 

331 K: int, 

332 topk: int, 

333 dtype: str | None, 

334 block_shape: list[int] | None = None, 

335 gemm_stage: str = "gemm1", 

336 enable_gemm_fast_path: bool = False, 

337) -> dict[str, Any]: 

338 """Default Triton config for fused MoE kernel. 

339 

340 Heuristic selection aligned with vLLM v0.17.0 defaults, tuned on H20/H100. 

341 Key insight: for high-expert-count MoE (e.g. DeepSeek-V3 E=256), each 

342 expert sees very few tokens, so small BLOCK_SIZE_M (16) is critical. 

343 """ 

344 is_fp8_blockwise = dtype == "fp8_w8a8" and block_shape is not None 

345 if gemm_stage not in ("gemm1", "gemm2"): 

346 raise ValueError(f"Unsupported MoE GEMM stage: {gemm_stage}") 

347 

348 if is_fp8_blockwise: 

349 avg_tokens_per_expert = M * max(topk, 1) // max(E, 1) 

350 is_large_m = M >= 16384 

351 if avg_tokens_per_expert <= 16: 

352 block_m = 16 

353 elif avg_tokens_per_expert <= 32: 

354 block_m = 32 

355 elif avg_tokens_per_expert <= 64 or not is_large_m: 

356 block_m = 64 

357 else: 

358 block_m = 128 

359 

360 config = { 

361 "BLOCK_SIZE_M": block_m, 

362 "BLOCK_SIZE_N": block_shape[0], 

363 "BLOCK_SIZE_K": block_shape[1], 

364 "GROUP_SIZE_M": 8 if (is_large_m and avg_tokens_per_expert > 16) else 1, 

365 "num_warps": 8 if (is_large_m and block_m > 32) else 4, 

366 "num_stages": 4 if M >= 1024 else 3, 

367 "SWAP_AB": False, 

368 } 

369 elif dtype in _PLAIN_HALF_CONFIG_DTYPES: 

370 # Routed rows per expert drives block_m. Each token contributes topk 

371 # rows to the expert-sorted GEMM input, so M * topk / E is the relevant 

372 # density for high-expert-count MoE routing. 

373 routed_tokens_per_expert = M * max(topk, 1) // max(E, 1) 

374 tokens_per_expert = M // max(E, 1) 

375 

376 if routed_tokens_per_expert <= 16: 

377 block_m = 16 

378 elif routed_tokens_per_expert <= 64: 

379 block_m = 64 

380 else: 

381 block_m = 128 

382 

383 if tokens_per_expert > 128: 

384 group_m = 16 

385 elif tokens_per_expert > 32: 

386 group_m = 8 

387 else: 

388 group_m = 1 

389 

390 block_k = 128 if M <= 64 else 64 

391 

392 if N >= 4096: 

393 block_n = 128 if M <= 128 else 256 

394 else: 

395 block_n = 64 if M <= 64 else 128 

396 

397 can_use_gemm_fast_path = ( 

398 enable_gemm_fast_path 

399 and M >= MOE_GEMM_TUNING_MIN_TOKENS 

400 and block_m == _HALF_GEMM_TILE_M 

401 and block_k == _HALF_GEMM_TILE_K 

402 ) 

403 

404 use_gemm2_fast_path = ( 

405 gemm_stage == "gemm2" 

406 and can_use_gemm_fast_path 

407 and N % _HALF_GEMM2_TILE_N == 0 

408 ) 

409 use_gemm1_fast_path = ( 

410 gemm_stage == "gemm1" and can_use_gemm_fast_path and N % block_n == 0 

411 ) 

412 

413 if gemm_stage == "gemm2" and enable_gemm_fast_path: 

414 block_n = ( 

415 _HALF_GEMM2_TILE_N if use_gemm2_fast_path else (64 if M <= 64 else 128) 

416 ) 

417 

418 # Prefer 4 warps for small tiles; only use 8 for large M 

419 num_warps = 4 if M <= 128 else 8 

420 num_stages = 3 

421 

422 if use_gemm1_fast_path: 

423 group_m = 1 

424 num_stages = 4 

425 elif use_gemm2_fast_path: 

426 group_m = 2 

427 num_stages = 4 

428 

429 smem_per_stage = (block_m * block_k + block_k * block_n) * 2 

430 while num_stages > 2 and smem_per_stage * num_stages > 200_000: 

431 num_stages -= 1 

432 

433 config = { 

434 "BLOCK_SIZE_M": block_m, 

435 "BLOCK_SIZE_N": block_n, 

436 "BLOCK_SIZE_K": block_k, 

437 "GROUP_SIZE_M": group_m, 

438 "num_warps": num_warps, 

439 "num_stages": num_stages, 

440 } 

441 if use_gemm1_fast_path: 

442 config["PAIR_GATE_UP_DOT"] = True 

443 else: 

444 tokens_per_expert = M // max(E, 1) 

445 

446 if tokens_per_expert <= 2: 

447 block_m = 16 

448 elif tokens_per_expert <= 4: 

449 block_m = 32 

450 elif tokens_per_expert <= 16: 

451 block_m = 64 

452 else: 

453 block_m = 128 

454 

455 # Tile sizing 

456 if N >= 4096: 

457 block_n = 128 if M <= 128 else 256 

458 elif N >= 1024: 

459 block_n = 64 if M <= 64 else 128 

460 else: 

461 block_n = 64 if M <= 64 else 128 

462 

463 if dtype == "fp8_w8a8": 

464 block_k = 128 

465 elif M <= 64: 

466 block_k = 128 

467 else: 

468 block_k = 64 

469 

470 if tokens_per_expert > 128: 

471 group_m = 16 

472 elif tokens_per_expert > 32: 

473 group_m = 8 

474 else: 

475 group_m = 1 

476 

477 # Prefer 4 warps for small tiles; only use 8 for large M 

478 num_warps = 4 if M <= 128 else 8 

479 num_stages = 3 

480 

481 smem_per_stage = (block_m * block_k + block_k * block_n) * 2 

482 while num_stages > 2 and smem_per_stage * num_stages > 200_000: 

483 num_stages -= 1 

484 

485 config = { 

486 "BLOCK_SIZE_M": block_m, 

487 "BLOCK_SIZE_N": block_n, 

488 "BLOCK_SIZE_K": block_k, 

489 "GROUP_SIZE_M": group_m, 

490 "num_warps": num_warps, 

491 "num_stages": num_stages, 

492 } 

493 return config 

494 

495 

496def _get_config_dtype_str( 

497 dtype: Optional[torch.dtype] = None, 

498 use_fp8_w8a8: bool = False, 

499 use_fp8_w8a16: bool = False, 

500 use_int8_w8a16: bool = False, 

501 use_int4_w4a16: bool = False, 

502 ocp_mx_scheme: str | None = None, 

503) -> str | None: 

504 """Return dtype string for kernel config lookup.""" 

505 if use_fp8_w8a8: 

506 return "fp8_w8a8" 

507 elif use_fp8_w8a16: 

508 return "fp8_w8a16" 

509 elif use_int8_w8a16: 

510 return "int8_w8a16" 

511 elif use_int4_w4a16: 

512 return "int4_w4a16" 

513 elif ocp_mx_scheme is not None: 

514 return None 

515 elif dtype == torch.float16: 

516 return "fp16" 

517 elif dtype == torch.bfloat16: 

518 return "bf16" 

519 elif dtype == torch.float: 

520 return "float32" 

521 return None 

522 

523 

524# MoE activation enum 

525 

526 

class MoEActivation(Enum):
    """Activation functions supported by the MoE layers.

    Members whose value ends in ``"_no_mul"`` are the non-gated variants; all
    other members are gated.
    """

    # Gated: gate * activation(up), input [..., 2*d] -> output [..., d]
    SILU = "silu"
    GELU = "gelu"
    RELU2 = "relu2"
    SWIGLUOAI = "swigluoai"
    SWIGLUSTEP = "swiglustep"

    # Non-gated: input [..., d] -> output [..., d]
    SILU_NO_MUL = "silu_no_mul"
    GELU_NO_MUL = "gelu_no_mul"
    RELU2_NO_MUL = "relu2_no_mul"

    @property
    def is_gated(self) -> bool:
        """Whether this activation is a gated one (value lacks ``_no_mul``)."""
        if self.value.endswith("_no_mul"):
            return False
        return True

    def without_mul(self) -> "MoEActivation":
        """Return the non-gated counterpart, or ``self`` if there is none."""
        if self is MoEActivation.SILU:
            return MoEActivation.SILU_NO_MUL
        if self is MoEActivation.GELU:
            return MoEActivation.GELU_NO_MUL
        if self is MoEActivation.RELU2:
            return MoEActivation.RELU2_NO_MUL
        return self

    @classmethod
    def from_str(cls, s: str) -> "MoEActivation":
        """Return the member whose value equals ``s``.

        Raises:
            ValueError: if no member has that value.
        """
        match = next((member for member in cls if member.value == s), None)
        if match is not None:
            return match
        valid = [m.value for m in cls]
        raise ValueError(f"Unknown MoE activation: {s!r}. Valid activations: {valid}")

    @staticmethod
    def adjust_N_for_activation(N: int, activation: "MoEActivation") -> int:
        """Return ``N // 2`` for gated activations and ``N`` otherwise."""
        return N // 2 if activation.is_gated else N

567 

568 

def apply_moe_activation(
    activation: MoEActivation,
    output: torch.Tensor,
    input: torch.Tensor,
) -> torch.Tensor:
    """Apply a MoE activation, writing the result into ``output``.

    Both tensors must be 2D. For gated activations the last dimension of
    ``input`` must be twice that of ``output``: the first half is the gate and
    the second half is multiplied in. For non-gated activations the last
    dimensions must match. ``SILU`` and ``SWIGLUOAI`` both go through the
    Triton SiLU-and-mul kernel; the others use PyTorch ops.

    Note that ``RELU2_NO_MUL`` applies ReLU to ``input`` in place before
    squaring it into ``output``.

    Returns:
        ``output``.

    Raises:
        ValueError: if the activation is not handled.
    """
    assert input.dim() == 2, "Input must be 2D"
    assert output.dim() == 2, "Output must be 2D"

    if activation.is_gated:
        assert output.size(-1) * 2 == input.size(-1), (
            f"{activation.value} expects 2x ratio: "
            f"{output.size(-1) * 2} vs {input.size(-1)}"
        )
        half = output.size(-1)
        gate, up = input[:, :half], input[:, half:]

        if activation in (MoEActivation.SILU, MoEActivation.SWIGLUOAI):
            _silu_and_mul_kernel(gate, up, out0=output)
        elif activation == MoEActivation.GELU:
            output.copy_(F.gelu(gate) * up)
        elif activation == MoEActivation.SWIGLUSTEP:
            output.copy_(torch.sigmoid(gate) * up)
        elif activation == MoEActivation.RELU2:
            output.copy_(F.relu(gate).square() * up)
        else:
            raise ValueError(f"Unsupported FusedMoe activation: {activation}")
        return output

    assert output.size(-1) == input.size(-1), (
        f"{activation.value} expects equal sizes: "
        f"{output.size(-1)} vs {input.size(-1)}"
    )

    if activation == MoEActivation.SILU_NO_MUL:
        output.copy_(F.silu(input))
    elif activation == MoEActivation.GELU_NO_MUL:
        output.copy_(F.gelu(input))
    elif activation == MoEActivation.RELU2_NO_MUL:
        # ReLU is done in place on the input to avoid a temporary.
        F.relu(input, inplace=True)
        torch.square(input, out=output)
    else:
        raise ValueError(f"Unsupported FusedMoe activation: {activation}")

    return output

616 

617 

618def _fp8_quantize( 

619 A: torch.Tensor, 

620 A_scale: Optional[torch.Tensor], 

621 per_act_token: bool, 

622 block_shape: Optional[list[int]] = None, 

623) -> tuple[torch.Tensor, torch.Tensor]: 

624 """FP8 E4M3 quantization: per-tensor, per-token, or block-wise.""" 

625 fp8_dtype = torch.float8_e4m3fn 

626 finfo = torch.finfo(fp8_dtype) 

627 fp8_max = finfo.max 

628 fp8_min = finfo.min 

629 eps = 1e-10 

630 

631 if block_shape is not None: 

632 assert not per_act_token 

633 assert len(block_shape) == 2 

634 block_k = block_shape[1] 

635 assert A.size(-1) % block_k == 0 

636 if A.ndim == 2 and A.stride(-1) == 1: 

637 from flag_gems.ops.per_token_group_quant_fp8 import ( 

638 per_token_group_quant_fp8, 

639 ) 

640 

641 return per_token_group_quant_fp8( 

642 A, 

643 group_size=block_k, 

644 eps=eps, 

645 dtype=fp8_dtype, 

646 column_major_scales=False, 

647 scale_ue8m0=False, 

648 ) 

649 orig_shape = A.shape 

650 A_flat = A.reshape(-1, A.size(-1)) 

651 M, K = A_flat.shape 

652 A_groups = A_flat.reshape(M * (K // block_k), block_k) 

653 amax = ( 

654 A_groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32) 

655 ) 

656 scale = amax / fp8_max 

657 A_q = (A_groups.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype) 

658 A_q = A_q.reshape(orig_shape) 

659 scale = scale.reshape(M, K // block_k) 

660 return A_q, scale 

661 

662 elif per_act_token: 

663 A_flat = A.reshape(-1, A.size(-1)) 

664 amax = A_flat.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32) 

665 scale = amax / fp8_max 

666 min_scale = torch.tensor( 

667 1.0 / (fp8_max * 512.0), dtype=torch.float32, device=A.device 

668 ) 

669 scale = scale.clamp(min=min_scale) 

670 A_q = (A_flat.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype) 

671 A_q = A_q.reshape(A.shape) 

672 scale = scale.reshape(A.shape[:-1] + (1,)) 

673 return A_q, scale 

674 

675 else: 

676 if A_scale is not None: 

677 scale = ( 

678 A_scale.float().view(1, 1) if A_scale.numel() == 1 else A_scale.float() 

679 ) 

680 A_q = (A.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype) 

681 return A_q, A_scale 

682 else: 

683 amax = A.abs().amax().clamp(min=eps).to(torch.float32) 

684 scale = amax / fp8_max 

685 iscale = 1.0 / scale 

686 A_q = (A.float() * iscale).clamp(fp8_min, fp8_max).to(fp8_dtype) 

687 return A_q, scale.view(1) 

688 

689 

690def _int8_quantize( 

691 A: torch.Tensor, 

692 A_scale: Optional[torch.Tensor], 

693 per_act_token: bool, 

694 block_shape: Optional[list[int]] = None, 

695) -> tuple[torch.Tensor, torch.Tensor]: 

696 """INT8 quantization: per-tensor, per-token, or block-wise.""" 

697 iinfo = torch.iinfo(torch.int8) 

698 int8_max = iinfo.max 

699 int8_min = iinfo.min 

700 eps = 1e-10 

701 

702 if block_shape is not None: 

703 assert not per_act_token 

704 assert len(block_shape) == 2 

705 block_k = block_shape[1] 

706 assert A.size(-1) % block_k == 0 

707 orig_shape = A.shape 

708 A_flat = A.reshape(-1, A.size(-1)) 

709 M, K = A_flat.shape 

710 A_groups = A_flat.reshape(M * (K // block_k), block_k) 

711 amax = ( 

712 A_groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32) 

713 ) 

714 scale = amax / int8_max 

715 A_q = ( 

716 (A_groups.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8) 

717 ) 

718 A_q = A_q.reshape(orig_shape) 

719 scale = scale.reshape(M, K // block_k) 

720 return A_q, scale 

721 

722 elif per_act_token: 

723 A_flat = A.reshape(-1, A.size(-1)) 

724 amax = A_flat.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32) 

725 scale = amax / int8_max 

726 A_q = (A_flat.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8) 

727 A_q = A_q.reshape(A.shape) 

728 scale = scale.reshape(A.shape[:-1] + (1,)) 

729 return A_q, scale 

730 

731 else: 

732 assert A_scale is not None, "int8 per-tensor requires A_scale" 

733 scale = A_scale.float().view(1, 1) if A_scale.numel() == 1 else A_scale.float() 

734 A_q = (A.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8) 

735 return A_q, A_scale 

736 

737 

738def moe_kernel_quantize_input( 

739 A: torch.Tensor, 

740 A_scale: Optional[torch.Tensor], 

741 quant_dtype: None | torch.dtype | str, 

742 per_act_token_quant: bool, 

743 block_shape: Optional[list[int]] = None, 

744 ocp_mx_scheme: str | None = None, 

745) -> tuple[torch.Tensor, Optional[torch.Tensor]]: 

746 """Quantize MoE input activations before GEMM.""" 

747 if ocp_mx_scheme is not None: 

748 if ocp_mx_scheme in {"w_mxfp4", "w_mxfp4_a_mxfp4"}: 

749 pass 

750 elif ocp_mx_scheme.endswith("a_fp8"): 

751 qA, qA_scale = _fp8_quantize(A, A_scale, per_act_token=False) 

752 A = (qA.float() * qA_scale.float()).to(A.dtype) 

753 return A, None 

754 

755 if quant_dtype is None: 

756 return A, A_scale 

757 elif quant_dtype == torch.float8_e4m3fn: 

758 return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape) 

759 elif quant_dtype == torch.int8: 

760 return _int8_quantize(A, A_scale, per_act_token_quant, block_shape) 

761 else: 

762 return A, A_scale 

763 

764 

765def _ensure_block_size_k_divisible( 

766 size_k: int, block_size_k: int, group_size: int 

767) -> int: 

768 """Find largest block_size_k that divides size_k and is divisible by group_size.""" 

769 if size_k % block_size_k == 0 and block_size_k % group_size == 0: 

770 return block_size_k 

771 

772 max_search = min(block_size_k, size_k) 

773 start = (max_search // group_size) * group_size 

774 for candidate in range(start, group_size - 1, -group_size): 

775 if size_k % candidate == 0: 

776 return candidate 

777 

778 if size_k % group_size == 0: 

779 return group_size 

780 

781 return size_k 

782 

783 

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def _silu_and_mul_kernel(x, y):
    # Element-wise SiLU(x) * y, with SiLU evaluated in float32.
    gate = x.to(tl.float32)
    denom = 1.0 + tl.exp(-gate)
    return tl.fdiv(gate, denom) * y

790 

791 

@triton.jit
def write_zeros_to_output(
    c_ptr,
    stride_cm,
    stride_cn,
    pid_n,
    N,
    offs_token,
    token_mask,
    BLOCK_SIZE_M,
    BLOCK_SIZE_N,
    compute_type,
):
    # Store a (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of zeros into C for the rows in
    # offs_token and the column block pid_n, masked by token_mask and by N.
    col_offsets = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    out_mask = token_mask[:, None] & (col_offsets[None, :] < N)
    out_ptrs = (
        c_ptr + stride_cm * offs_token[:, None] + stride_cn * col_offsets[None, :]
    )
    zero_tile = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
    tl.store(out_ptrs, zero_tile, mask=out_mask)

810 

811 

@triton.jit
def fused_moe_kernel_gptq_awq(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    b_scale_ptr,
    b_zp_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N: tl.constexpr,
    K: tl.constexpr,
    EM,
    num_valid_tokens,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bse,
    stride_bsk,
    stride_bsn,
    stride_bze,
    stride_bzk,
    stride_bzn,
    block_k_diviable: tl.constexpr,
    group_size: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    has_zp: tl.constexpr,
    use_int4_w4a16: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
):
    """Fused MoE kernel for GPTQ/AWQ (WNA16) quantized weights.

    Each program computes one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C. The rows
    of the tile are the token ids read from ``sorted_token_ids_ptr`` and the
    expert is read from ``expert_ids_ptr`` at the tile's M-block index. The
    quantized weights are dequantized on the fly as
    ``(b - zero_point) * scale`` and accumulated in float32; the result is
    optionally multiplied by the routing weight and stored as ``compute_type``.

    ``SPLIT_K`` is accepted but not referenced in this kernel body.
    """
    # Map pid to C block (grouped ordering for L2 reuse)
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Create pointers for first blocks of A and B
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    # Tiles that start at or beyond the padded token count have nothing to do.
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    # Cast to int64 to prevent overflow in stride*offset products
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id).to(tl.int64)
    # Token ids at or above num_valid_tokens are masked out of loads and stores.
    token_mask = offs_token < num_valid_tokens

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        # -----------------------------------------------------------
        # Write back zeros to the output when the expert is not
        # in the current expert parallel rank.
        write_zeros_to_output(
            c_ptr,
            stride_cm,
            stride_cn,
            pid_n,
            N,
            offs_token,
            token_mask,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            compute_type,
        )
        return

    # Column offsets are wrapped modulo N so the B loads stay inside [0, N).
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # The row of A for a sorted token id is that id divided by top_k.
    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )

    if use_int4_w4a16:
        # Two consecutive K indices share one stored element of B; the shift
        # (0 or 4 bits) selects the nibble from the parity of the K index.
        b_ptrs = (
            b_ptr
            + off_experts * stride_be
            + (offs_k[:, None] // 2) * stride_bk
            + offs_bn[None, :] * stride_bn
        )
        b_shifter = (offs_k[:, None] % 2) * 4
    elif use_int8_w8a16:
        b_ptrs = (
            b_ptr
            + off_experts * stride_be
            + offs_k[:, None] * stride_bk
            + offs_bn[None, :] * stride_bn
        )

    # Without stored zero points a constant is subtracted instead:
    # 8 for int4 weights and 128 for int8 weights.
    if not has_zp and use_int4_w4a16:
        b_zp_num = 8
    if not has_zp and use_int8_w8a16:
        b_zp_num = 128
    elif has_zp and use_int4_w4a16:
        # int4 zero points are packed two per element along N.
        b_zp_shifter = (offs_bn[None, :] % 2) * 4

    # Accumulate C block in fp32
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # A K-tail mask is only needed for the scale / zero-point loads when K
        # is not a multiple of BLOCK_SIZE_K.
        if not block_k_diviable:
            k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
            k_other = 0.0
        else:
            k_mask = None
            k_other = None

        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        # NOTE(review): B is loaded without a K mask; rows past K are
        # neutralised by the zero-filled A tile and the zero-filled scales, but
        # the load itself is unguarded - confirm the weight buffer allows it.
        b = tl.load(b_ptrs)
        if use_int4_w4a16:
            b = (b >> b_shifter) & 0xF

        # One scale per quantization group along K.
        b_scale_ptrs = (
            b_scale_ptr
            + off_experts * stride_bse
            + offs_bn[None, :] * stride_bsn
            + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
        )
        b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
        b_scale = b_scale.to(tl.float32)

        if has_zp and use_int4_w4a16:
            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
            b_zp_ptrs = (
                b_zp_ptr
                + off_experts * stride_bze
                + (offs_bn[None, :] // 2) * stride_bzn
                + offs_k_true * stride_bzk
            )
            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
            b_zp = (b_zp >> b_zp_shifter) & 0xF
            b_zp = b_zp.to(tl.float32)
        elif has_zp and use_int8_w8a16:
            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
            b_zp_ptrs = (
                b_zp_ptr
                + off_experts * stride_bze
                + offs_bn[None, :] * stride_bzn
                + offs_k_true * stride_bzk
            )
            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
            b_zp = b_zp.to(tl.float32)

        # Dequantize the weight tile in float32, then cast for the dot product.
        if has_zp:
            b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
        else:
            b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
        accumulator = tl.dot(a, b, acc=accumulator)

        # Advance to the next K block (half as many stored elements for int4).
        a_ptrs += BLOCK_SIZE_K * stride_ak
        if use_int4_w4a16:
            b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
        else:
            b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]

    accumulator = accumulator.to(compute_type)
    # Write back output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)

1002 

1003 

@triton.jit
def fused_moe_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    b_bias_ptr,
    a_scale_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_asm,
    stride_ask,
    stride_bse,
    stride_bsk,
    stride_bsn,
    stride_bbe,  # bias expert stride
    stride_bbn,  # bias N stride
    # Block size for block-wise quantization
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    naive_block_assignment: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a8: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
    per_channel_quant: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    SWAP_AB: tl.constexpr,
    K_DIVISIBLE_BY_BLOCK_K: tl.constexpr,
    N_DIVISIBLE_BY_BLOCK_N: tl.constexpr,
    PAIR_GATE_UP_DOT: tl.constexpr,
    DIRECT_SUM: tl.constexpr,
    OUT_TOP_K: tl.constexpr,
    FUSE_SILU: tl.constexpr,
):
    """Fused MoE kernel: token x expert GEMM with quantization support and optional SiLU fusion.

    Each program computes one ``(BLOCK_SIZE_M, BLOCK_SIZE_N)`` output tile.
    The tile rows are token slots (read from ``sorted_token_ids_ptr``, or
    synthesized from ``pid_m`` when ``naive_block_assignment`` is set) and the
    whole tile uses the single expert id read from ``expert_ids_ptr[pid_m]``.

    Three compile-time code paths exist:

    * ``FUSE_SILU and PAIR_GATE_UP_DOT``: a single dot of width
      ``2 * BLOCK_SIZE_N`` covers the gate columns (``[0, N // 2)``) and the
      matching up columns (``[N // 2, N)``); the tile result is
      ``gate * sigmoid(gate) * up``. No quantization scales are applied on
      this path.
    * ``FUSE_SILU`` only: gate and up are accumulated in two sequential K
      loops (with fp8/int8 w8a8 scaling when enabled) and combined as
      ``gate / (1 + exp(-gate)) * up``.
    * otherwise: a plain GEMM with optional int8-w8a16 / fp8-w8a8 / int8-w8a8
      dequantization, optional ``SWAP_AB`` (transposed dot) and optional bias.

    Args:
        a_ptr, b_ptr, c_ptr: input activations, expert weights and output.
        b_bias_ptr: per-expert bias, read only when ``HAS_BIAS``.
        a_scale_ptr, b_scale_ptr: quantization scales, read only on the
            quantized paths.
        topk_weights_ptr: router weights, read only when ``MUL_ROUTED_WEIGHT``.
        sorted_token_ids_ptr: token slot per padded row (unused when
            ``naive_block_assignment``).
        expert_ids_ptr: one expert id per M tile; ``-1`` makes the tile write
            zeros on the paths that check for it.
        num_tokens_post_padded_ptr: scalar; tiles starting at or beyond it
            return immediately.
        N, K: columns of B (gate + up when ``FUSE_SILU``) and reduction size.
        EM: padded row count used to derive the number of M tiles.
        num_valid_tokens: token slots ``>=`` this value are masked out.
        stride_*: element strides of the corresponding tensors.
        group_n, group_k: block-wise quantization group sizes; the block-wise
            path is taken only when both are ``> 0``.
        top_k: rows of A are addressed as ``offs_token // top_k``.
        compute_type: dtype the accumulator is cast to before the store.
        DIRECT_SUM, OUT_TOP_K: when ``DIRECT_SUM`` the output row is
            ``offs_token // OUT_TOP_K`` and the tile is accumulated with a
            relaxed atomic add instead of a store.
        SPLIT_K: accepted for config compatibility; not referenced in the body.
    """
    # Map pid to C block (grouped ordering for L2 reuse)
    pid = tl.program_id(axis=0)
    # Adjust N for FUSE_SILU. If fused, the actual output dimension is N // 2
    N_out = N // 2 if FUSE_SILU else N
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N_out, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group may hold fewer than GROUP_SIZE_M row tiles.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Create pointers for first blocks of A and B
    offs = tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    # Tiles that start past the padded token count have nothing to compute.
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + offs
    if not naive_block_assignment:
        offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    else:
        # One real token slot per tile; the other rows get num_valid_tokens,
        # which fails the token_mask test below and is therefore masked out.
        offs_token = tl.where(
            offs == 0,
            pid_m,  # first element = pid_m
            num_valid_tokens,  # remaining elements = constant
        )
    offs_token = offs_token.to(tl.int64)  # prevent int32 overflow

    token_mask = offs_token < num_valid_tokens

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)

    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
    if not N_DIVISIBLE_BY_BLOCK_N:
        # Wrap out-of-range columns so unmasked B loads stay in bounds; the
        # wrapped columns are discarded by the store mask at the end.
        offs_bn = offs_bn % N_out
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # Each token slot maps back to its source row of A via // top_k.
    a_base = a_ptr + (offs_token[:, None] // top_k * stride_am)

    if FUSE_SILU and PAIR_GATE_UP_DOT:
        if off_experts == -1:
            write_zeros_to_output(
                c_ptr,
                stride_cm,
                stride_cn,
                pid_n,
                N_out,
                offs_token,
                token_mask,
                BLOCK_SIZE_M,
                BLOCK_SIZE_N,
                compute_type,
            )
            return

        # Column indices for a double-width tile: the first BLOCK_SIZE_N entries
        # address the gate half of B, the remaining ones the up half (+ N_out).
        offs_pair = tl.arange(0, BLOCK_SIZE_N * 2).to(tl.int64)
        offs_pair_bn = tl.where(
            offs_pair < BLOCK_SIZE_N,
            pid_n * BLOCK_SIZE_N + offs_pair,
            N_out + pid_n * BLOCK_SIZE_N + offs_pair - BLOCK_SIZE_N,
        )
        a_ptrs = a_base + offs_k[None, :] * stride_ak
        b_pair_ptrs = (
            b_ptr
            + off_experts * stride_be
            + (offs_k[:, None] * stride_bk + offs_pair_bn[None, :] * stride_bn)
        )
        pair_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N * 2), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            if K_DIVISIBLE_BY_BLOCK_K:
                a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0)
                if N_DIVISIBLE_BY_BLOCK_N:
                    b_pair = tl.load(b_pair_ptrs)
                else:
                    b_pair = tl.load(
                        b_pair_ptrs, mask=offs_pair_bn[None, :] < N, other=0.0
                    )
            else:
                k_remaining = K - k * BLOCK_SIZE_K
                a = tl.load(
                    a_ptrs,
                    mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
                    other=0.0,
                )
                b_pair = tl.load(
                    b_pair_ptrs,
                    mask=(offs_k[:, None] < k_remaining) & (offs_pair_bn[None, :] < N),
                    other=0.0,
                )
            pair_acc += tl.dot(a, b_pair)
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_pair_ptrs += BLOCK_SIZE_K * stride_bk

        if HAS_BIAS:
            pair_bias_ptrs = (
                b_bias_ptr + off_experts * stride_bbe + (offs_pair_bn * stride_bbn)
            )
            pair_bias = tl.load(pair_bias_ptrs, mask=offs_pair_bn < N, other=0.0)
            pair_acc += pair_bias[None, :]

        # Split the double-width accumulator back into its gate and up halves:
        # (M, 2N) -> (M, 2, N) -> (M, N, 2) -> two (M, N) tensors.
        gate_up = tl.trans(
            tl.reshape(pair_acc, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N)),
            (0, 2, 1),
        )
        gate_acc, up_acc = tl.split(gate_up)
        gate_sig = tl.sigmoid(gate_acc)
        # Each factor is cast to compute_type before multiplying, so the
        # products are evaluated in compute_type rather than fp32.
        accumulator = (
            gate_acc.to(compute_type)
            * gate_sig.to(compute_type)
            * up_acc.to(compute_type)
        )

    elif FUSE_SILU:
        # NOTE(review): unlike the other two paths, this branch has no
        # `off_experts == -1` check — callers presumably guarantee that no
        # tile carries a -1 expert id when this path is selected; confirm.
        offs_bn_gate = offs_bn
        offs_bn_up = offs_bn + N_out

        b_expert_base = b_ptr + off_experts * stride_be
        b_ptrs_gate = b_expert_base + (
            offs_k[:, None] * stride_bk + offs_bn_gate[None, :] * stride_bn
        )
        b_ptrs_up = b_expert_base + (
            offs_k[:, None] * stride_bk + offs_bn_up[None, :] * stride_bn
        )

        if use_fp8_w8a8 or use_int8_w8a8:
            if group_k > 0 and group_n > 0:  # block-wise
                a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
                # Use scalar scale load for hardware broadcast when block size fits within quantization group.
                if BLOCK_SIZE_N <= group_n:
                    offs_bsn_gate_idx = (pid_n * BLOCK_SIZE_N) % N_out // group_n
                    offs_bsn_up_idx = (
                        (pid_n * BLOCK_SIZE_N) % N_out + N_out
                    ) // group_n
                else:
                    offs_bsn_gate_idx = offs_bn_gate // group_n
                    offs_bsn_up_idx = offs_bn_up // group_n
                b_scale_gate_ptrs = (
                    b_scale_ptr
                    + off_experts * stride_bse
                    + offs_bsn_gate_idx * stride_bsn
                )
                b_scale_up_ptrs = (
                    b_scale_ptr
                    + off_experts * stride_bse
                    + offs_bsn_up_idx * stride_bsn
                )
            elif per_channel_quant:  # channel-wise
                b_scale_gate_ptrs = (
                    b_scale_ptr
                    + off_experts * stride_bse
                    + offs_bn_gate[None, :] * stride_bsn
                )
                b_scale_gate = tl.load(b_scale_gate_ptrs)
                b_scale_up_ptrs = (
                    b_scale_ptr
                    + off_experts * stride_bse
                    + offs_bn_up[None, :] * stride_bsn
                )
                b_scale_up = tl.load(b_scale_up_ptrs)
                a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
                a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
            else:  # tensor-wise
                a_scale = tl.load(a_scale_ptr)
                b_scale_gate = tl.load(b_scale_ptr + off_experts)
                # A single per-expert scale covers both halves of B.
                b_scale_up = b_scale_gate

        # Pass 1: Sequential execution of gate projection to minimize peak register pressure.
        a_ptrs = a_base + offs_k[None, :] * stride_ak
        acc_gate = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            # Eliminate masking overhead when K is perfectly aligned with BLOCK_SIZE_K.
            if K_DIVISIBLE_BY_BLOCK_K:
                a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0)
                b_gate = tl.load(b_ptrs_gate)
            else:
                k_remaining = K - k * BLOCK_SIZE_K
                a = tl.load(
                    a_ptrs,
                    mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
                    other=0.0,
                )
                b_gate = tl.load(
                    b_ptrs_gate, mask=offs_k[:, None] < k_remaining, other=0.0
                )

            if use_fp8_w8a8 or use_int8_w8a8:
                if group_k > 0 and group_n > 0:
                    # Block-wise scales change along K, so they are applied
                    # to every partial product inside the loop.
                    k_start = k * BLOCK_SIZE_K
                    offs_ks = k_start // group_k
                    a_scale = tl.load(
                        a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
                    )
                    b_scale_val = tl.load(b_scale_gate_ptrs + offs_ks * stride_bsk)

                    # Pre-compute combined scale to reduce arithmetic overhead via the associative property.
                    if BLOCK_SIZE_N <= group_n:
                        combined_scale = a_scale[:, None] * b_scale_val
                    else:
                        combined_scale = a_scale[:, None] * b_scale_val[None, :]
                    acc_gate += tl.dot(a, b_gate) * combined_scale
                else:
                    if use_fp8_w8a8:
                        acc_gate = tl.dot(a, b_gate, acc=acc_gate)
                    else:
                        acc_gate += tl.dot(a, b_gate)
            else:
                acc_gate += tl.dot(a, b_gate)

            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs_gate += BLOCK_SIZE_K * stride_bk

        # Channel-wise and tensor-wise scales are constant along K and are
        # applied once after the loop; block-wise was already applied above.
        if use_fp8_w8a8 or use_int8_w8a8:
            if group_k > 0 and group_n > 0:
                pass
            elif per_channel_quant:
                acc_gate = acc_gate * a_scale * b_scale_gate
            else:
                acc_gate = acc_gate * a_scale * b_scale_gate

        # Pass 2: Sequential up projection; operand A is reloaded with high L1 hit rate.
        a_ptrs = a_base + offs_k[None, :] * stride_ak
        acc_up = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            # Apply mask elimination during the up projection stage.
            if K_DIVISIBLE_BY_BLOCK_K:
                a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0)
                b_up = tl.load(b_ptrs_up)
            else:
                k_remaining = K - k * BLOCK_SIZE_K
                a = tl.load(
                    a_ptrs,
                    mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
                    other=0.0,
                )
                b_up = tl.load(b_ptrs_up, mask=offs_k[:, None] < k_remaining, other=0.0)

            if use_fp8_w8a8 or use_int8_w8a8:
                if group_k > 0 and group_n > 0:
                    k_start = k * BLOCK_SIZE_K
                    offs_ks = k_start // group_k
                    a_scale = tl.load(
                        a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
                    )
                    b_scale_val = tl.load(b_scale_up_ptrs + offs_ks * stride_bsk)

                    # Apply pre-computed scale merging to reduce multiplication overhead.
                    if BLOCK_SIZE_N <= group_n:
                        combined_scale = a_scale[:, None] * b_scale_val
                    else:
                        combined_scale = a_scale[:, None] * b_scale_val[None, :]
                    acc_up += tl.dot(a, b_up) * combined_scale
                else:
                    if use_fp8_w8a8:
                        acc_up = tl.dot(a, b_up, acc=acc_up)
                    else:
                        acc_up += tl.dot(a, b_up)
            else:
                acc_up += tl.dot(a, b_up)

            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs_up += BLOCK_SIZE_K * stride_bk

        if use_fp8_w8a8 or use_int8_w8a8:
            if group_k > 0 and group_n > 0:
                pass
            elif per_channel_quant:
                acc_up = acc_up * a_scale * b_scale_up
            else:
                acc_up = acc_up * a_scale * b_scale_up

        # SiLU activation fusion
        accumulator = tl.fdiv(acc_gate, (1.0 + tl.exp(-acc_gate))) * acc_up

    else:
        if off_experts == -1:
            # Expert not in current EP rank, write zeros
            write_zeros_to_output(
                c_ptr,
                stride_cm,
                stride_cn,
                pid_n,
                N_out,
                offs_token,
                token_mask,
                BLOCK_SIZE_M,
                BLOCK_SIZE_N,
                compute_type,
            )
            return
        a_ptrs = a_base + offs_k[None, :] * stride_ak
        b_ptrs = (
            b_ptr
            + off_experts * stride_be
            + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
        )

        if use_int8_w8a16:
            b_scale_ptrs = (
                b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
            )
            b_scale = tl.load(b_scale_ptrs)

        if use_fp8_w8a8 or use_int8_w8a8:
            if group_k > 0 and group_n > 0:  # block-wise
                a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
                # Use scalar scale load for hardware broadcast when block size fits within quantization group.
                if BLOCK_SIZE_N <= group_n:
                    offs_bsn = (pid_n * BLOCK_SIZE_N) % N_out // group_n
                else:
                    offs_bsn = offs_bn // group_n
                b_scale_ptrs = (
                    b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
                )
            elif per_channel_quant:  # channel-wise
                b_scale_ptrs = (
                    b_scale_ptr
                    + off_experts * stride_bse
                    + offs_bn[None, :] * stride_bsn
                )
                b_scale = tl.load(b_scale_ptrs)
                a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
                a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
            else:  # tensor-wise
                a_scale = tl.load(a_scale_ptr)
                b_scale = tl.load(b_scale_ptr + off_experts)

        if HAS_BIAS:
            bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn
            bias = tl.load(bias_ptrs, mask=(offs_bn < N_out), other=0.0)

        # Accumulate C block in fp32
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        if SWAP_AB:
            # With SWAP_AB the dot is computed as B^T @ A^T into an (N, M)
            # accumulator that is transposed back after the K loop.
            accumulator_nm = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_M), dtype=tl.float32)

        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            # Eliminate masking overhead when K is perfectly aligned with BLOCK_SIZE_K.
            if K_DIVISIBLE_BY_BLOCK_K:
                a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0)
                b = tl.load(b_ptrs)
            else:
                k_remaining = K - k * BLOCK_SIZE_K
                a = tl.load(
                    a_ptrs,
                    mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
                    other=0.0,
                )
                b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)

            if use_int8_w8a16:
                accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
            elif use_fp8_w8a8 or use_int8_w8a8:
                if group_k > 0 and group_n > 0:
                    k_start = k * BLOCK_SIZE_K
                    offs_ks = k_start // group_k
                    a_scale = tl.load(
                        a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
                    )
                    if SWAP_AB:
                        b_scale_val = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
                        if BLOCK_SIZE_N <= group_n:
                            combined_scale_nm = b_scale_val * a_scale[None, :]
                        else:
                            combined_scale_nm = b_scale_val[:, None] * a_scale[None, :]
                        accumulator_nm += (
                            tl.dot(tl.trans(b), tl.trans(a)) * combined_scale_nm
                        )
                    else:
                        b_scale_val = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
                        # Pre-compute combined scale to reduce arithmetic overhead via the associative property.
                        if BLOCK_SIZE_N <= group_n:
                            combined_scale = a_scale[:, None] * b_scale_val
                        else:
                            combined_scale = a_scale[:, None] * b_scale_val[None, :]
                        accumulator += tl.dot(a, b) * combined_scale
                else:
                    if use_fp8_w8a8:
                        if SWAP_AB:
                            accumulator_nm = tl.dot(
                                tl.trans(b), tl.trans(a), acc=accumulator_nm
                            )
                        else:
                            accumulator = tl.dot(a, b, acc=accumulator)
                    else:
                        if SWAP_AB:
                            accumulator_nm += tl.dot(tl.trans(b), tl.trans(a))
                        else:
                            accumulator += tl.dot(a, b)
            else:
                if SWAP_AB:
                    accumulator_nm += tl.dot(tl.trans(b), tl.trans(a))
                else:
                    accumulator += tl.dot(a, b)
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk

        if SWAP_AB:
            accumulator = tl.trans(accumulator_nm)

        # NOTE(review): the dequantization and bias steps are kept inside this
        # branch because `b_scale`, `a_scale` and `bias` are only defined here;
        # the nesting was reconstructed from those dependencies — confirm.
        # Dequantization
        if use_int8_w8a16:
            accumulator = accumulator * b_scale
        elif (use_fp8_w8a8 or use_int8_w8a8) and not (group_k > 0 and group_n > 0):
            accumulator = accumulator * a_scale * b_scale

        if HAS_BIAS:
            accumulator += bias[None, :]

    # Router weight multiplication (must be in fp32)
    # NOTE(review): on the paired FUSE_SILU path the accumulator is already in
    # compute_type at this point, so the fp32 remark holds only for the other
    # two paths — confirm this is intended.
    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(
            topk_weights_ptr + offs_token,
            mask=token_mask,
            other=0,
        )
        accumulator *= moe_weight[:, None]

    accumulator = accumulator.to(compute_type)

    # Write back output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    if DIRECT_SUM:
        # Collapse the OUT_TOP_K slots of a token onto one output row.
        offs_c = offs_token // OUT_TOP_K
    else:
        offs_c = offs_token
    c_ptrs = c_ptr + stride_cm * offs_c[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None]
    if not N_DIVISIBLE_BY_BLOCK_N:
        c_mask = c_mask & (offs_cn[None, :] < N_out)
    if DIRECT_SUM:
        # Kernel completion provides the only ordering needed here.
        tl.atomic_add(c_ptrs, accumulator, sem="relaxed", mask=c_mask)
    else:
        tl.store(c_ptrs, accumulator, mask=c_mask)

1499 

1500 

def invoke_fused_moe_wna16_triton_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    B_scale: torch.Tensor | None,
    B_zp: torch.Tensor | None,
    topk_weights: torch.Tensor | None,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    block_shape: list[int] | None,
):
    """Launch the weight-only (GPTQ/AWQ style) fused MoE Triton kernel.

    Args:
        A: input activations; ``A.size(0)`` is the token count and
            ``A.size(1)`` the reduction size passed to the kernel as K.
        B: quantized expert weights; dim 0 is the expert axis and
            ``B.size(1)`` is passed to the kernel as N.
        C: output buffer; strides of dims 1 and 2 are passed to the kernel.
        B_scale: 3-D weight scales (required).
        B_zp: optional 3-D zero points.
        topk_weights: router weights, used when ``mul_routed_weight``.
        sorted_token_ids, expert_ids, num_tokens_post_padded: routing
            metadata forwarded unchanged to the kernel.
        mul_routed_weight: multiply the result by the router weight.
        top_k: number of experts selected per token.
        config: launch configuration; copied, never mutated in place.
        compute_type: Triton dtype used for the dot operands and output.
        use_int8_w8a16, use_int4_w4a16: weight format selectors.
        block_shape: ``[0, group_size]``; ``block_shape[1]`` is the
            quantization group size along K.
    """
    assert B_scale is not None and B_scale.ndim == 3
    assert B_zp is None or B_zp.ndim == 3
    assert block_shape is not None and block_shape[0] == 0

    M = A.size(0)
    num_tokens = M * top_k

    EM = sorted_token_ids.size(0)
    if M < config["BLOCK_SIZE_M"]:
        # optimize for small batch_size.
        # We assume that top_ids of each token is unique,
        # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
        # and we can skip some invalid blocks.
        EM = min(EM, M * top_k * config["BLOCK_SIZE_M"])
    grid = lambda META: (
        triton.cdiv(EM, META["BLOCK_SIZE_M"])
        * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
    )
    config = config.copy()
    config.update(
        get_moe_wna16_block_config(
            config=config,
            use_moe_wna16_cuda=False,
            num_valid_tokens=num_tokens,
            size_k=A.size(1),
            size_n=B.size(1),
            # The expert axis of B is dim 0 (its stride is passed below as the
            # expert stride); dim 1 is N and was previously passed by mistake.
            num_experts=B.size(0),
            group_size=block_shape[1],
            real_top_k=top_k,
            block_size_m=config["BLOCK_SIZE_M"],
        )
    )

    fused_moe_kernel_gptq_awq[grid](
        A,
        B,
        C,
        B_scale,
        B_zp,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.size(1),
        A.size(1),
        EM,
        num_tokens,
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        B_scale.stride(0),
        B_scale.stride(2),
        B_scale.stride(1),
        B_zp.stride(0) if B_zp is not None else 0,
        B_zp.stride(2) if B_zp is not None else 0,
        B_zp.stride(1) if B_zp is not None else 0,
        block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
        group_size=block_shape[1],
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        has_zp=B_zp is not None,
        use_int4_w4a16=use_int4_w4a16,
        use_int8_w8a16=use_int8_w8a16,
        **config,
    )

1589 

1590 

def invoke_fused_moe_triton_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    topk_weights: Optional[torch.Tensor],
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    block_shape: Optional[list[int]] = None,
    B_bias: torch.Tensor | None = None,
    FUSE_SILU: bool = False,
    direct_sum: bool = False,
    out_top_k: int = 1,
) -> None:
    """Launch the fused_moe_kernel Triton kernel.

    Args:
        A: input activations; ``A.size(0)`` is the token count.
        B: expert weights; ``B.size(1)`` is passed as N and ``B.size(2)``
            as K. With ``FUSE_SILU`` the N axis holds gate and up halves.
        C: output buffer; strides of dims 1 and 2 are passed to the kernel.
        A_scale, B_scale: quantization scales; both must be ``None`` when no
            quantized mode is selected.
        topk_weights: router weights, required when ``mul_routed_weight``.
        sorted_token_ids: routing order, or ``None`` to let the kernel assign
            one block per token slot (``naive_block_assignment``).
        expert_ids, num_tokens_post_padded: routing metadata forwarded to the
            kernel.
        mul_routed_weight: multiply the result by the router weight.
        top_k: number of experts selected per token.
        config: launch configuration; copied, never mutated in place.
            ``SPLIT_K`` is forced to 1, ``SWAP_AB`` and ``PAIR_GATE_UP_DOT``
            are optional keys.
        compute_type: Triton dtype the result is cast to.
        use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16:
            quantization mode selectors.
        per_channel_quant: use per-channel scales on the w8a8 paths.
        block_shape: block-wise quantization group sizes, forwarded to the
            kernel as ``group_n`` / ``group_k``.
        B_bias: optional per-expert bias.
        FUSE_SILU: fuse the SiLU-and-multiply activation into the GEMM.
        direct_sum: accumulate outputs atomically into rows
            ``token // out_top_k`` instead of storing per-slot rows.
        out_top_k: divisor used for the output row when ``direct_sum``.
    """
    assert topk_weights is not None or not mul_routed_weight
    assert topk_weights is None or topk_weights.stride(1) == 1
    assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1

    if use_fp8_w8a8 or use_int8_w8a8:
        assert B_scale is not None
        assert block_shape is None or triton.cdiv(
            B.size(-2), block_shape[0]
        ) == B_scale.size(-2)
        assert block_shape is None or triton.cdiv(
            B.size(-1), block_shape[1]
        ) == B_scale.size(-1)
    elif use_int8_w8a16 or use_int4_w4a16:
        assert B_scale is not None
        assert block_shape is None or block_shape[0] == 0
    else:
        assert A_scale is None
        assert B_scale is None

    M = A.size(0)
    num_tokens = M * top_k
    if sorted_token_ids is not None:
        EM = sorted_token_ids.size(0)
        if M < config["BLOCK_SIZE_M"]:
            # Small batch: cap the padded row count so empty tiles are skipped.
            EM = min(EM, M * top_k * config["BLOCK_SIZE_M"])
    else:
        # Naive assignment: one M tile per token slot.
        EM = num_tokens * config["BLOCK_SIZE_M"]

    # FUSE_SILU means B.size(1) contains both Gate and Up. N is halved.
    actual_N = B.size(1) // 2 if FUSE_SILU else B.size(1)
    grid = lambda META: (
        triton.cdiv(EM, META["BLOCK_SIZE_M"])
        * triton.cdiv(actual_N, META["BLOCK_SIZE_N"]),
    )
    HAS_BIAS = B_bias is not None

    config = config.copy()
    config["SPLIT_K"] = 1
    BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
    if block_shape is not None:
        # Keep the K tile within one quantization group. Entries equal to 0
        # (allowed above for the w8a16 modes) mean "no grouping on that axis"
        # and must not clamp the tile to 0, which would make the modulo below
        # divide by zero.
        positive_dims = [dim for dim in block_shape[:2] if dim > 0]
        if positive_dims:
            BLOCK_SIZE_K = min(BLOCK_SIZE_K, *positive_dims)

    swap_AB = config.pop("SWAP_AB", False)
    pair_gate_up_dot = config.pop("PAIR_GATE_UP_DOT", False)
    # Force disable SWAP_AB in fusion mode
    if FUSE_SILU:
        swap_AB = False

    fused_moe_kernel[grid](
        A,
        B,
        C,
        B_bias,
        A_scale,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.size(1),  # N
        B.size(2),  # K
        EM,
        num_tokens,
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
        A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
        B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
        B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
        B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
        B_bias.stride(0) if B_bias is not None else 0,
        B_bias.stride(1) if B_bias is not None else 0,
        0 if block_shape is None else block_shape[0],
        0 if block_shape is None else block_shape[1],
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        per_channel_quant=per_channel_quant,
        naive_block_assignment=(sorted_token_ids is None),
        HAS_BIAS=HAS_BIAS,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        SWAP_AB=swap_AB,
        K_DIVISIBLE_BY_BLOCK_K=(B.size(2) % BLOCK_SIZE_K == 0),
        N_DIVISIBLE_BY_BLOCK_N=(actual_N % config["BLOCK_SIZE_N"] == 0),
        PAIR_GATE_UP_DOT=pair_gate_up_dot,
        DIRECT_SUM=direct_sum,
        OUT_TOP_K=out_top_k,
        FUSE_SILU=FUSE_SILU,
        **config,
    )

1717 

1718 

def dispatch_fused_moe_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    B_zp: Optional[torch.Tensor],
    topk_weights: Optional[torch.Tensor],
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a8: bool,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    per_channel_quant: bool,
    block_shape: Optional[list[int]] = None,
    B_bias: Optional[torch.Tensor] = None,
    FUSE_SILU: bool = False,
    direct_sum: bool = False,
    out_top_k: int = 1,
) -> None:
    """Dispatch to the appropriate fused MoE kernel based on quantization flags.

    Weight-only int8/int4 inputs with a positive group size
    (``block_shape[1] > 0``) go to the WNA16 launcher; everything else goes
    to the general launcher. The WNA16 launcher takes no bias, SiLU-fusion or
    direct-sum arguments, so those options are rejected on that path instead
    of being dropped silently.

    Args:
        A, B, C: activations, expert weights and output buffer.
        A_scale, B_scale: quantization scales (``A_scale`` is only used by
            the general launcher).
        B_zp: zero points, only used by the WNA16 launcher.
        topk_weights: router weights, required when ``mul_routed_weight``.
        sorted_token_ids, expert_ids, num_tokens_post_padded: routing
            metadata forwarded to the launcher.
        mul_routed_weight: multiply the result by the router weight.
        top_k: number of experts selected per token.
        config: launch configuration forwarded to the launcher.
        compute_type: Triton dtype of the result.
        use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16:
            quantization mode selectors.
        per_channel_quant: use per-channel scales on the w8a8 paths.
        block_shape: quantization block / group sizes.
        B_bias: optional per-expert bias (general launcher only).
        FUSE_SILU, direct_sum, out_top_k: forwarded to the general launcher.
    """
    assert topk_weights is not None or not mul_routed_weight
    assert topk_weights is None or topk_weights.stride(1) == 1
    assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1

    uses_wna16_kernel = (
        (use_int8_w8a16 or use_int4_w4a16)
        and block_shape is not None
        and block_shape[1] > 0
    )
    if uses_wna16_kernel:
        assert B_bias is None
        # These options have no counterpart in the WNA16 launcher; forwarding
        # the call anyway would write an output laid out differently from
        # what the caller asked for.
        assert not FUSE_SILU, "FUSE_SILU is not supported by the WNA16 kernel"
        assert not direct_sum, "direct_sum is not supported by the WNA16 kernel"
        invoke_fused_moe_wna16_triton_kernel(
            A,
            B,
            C,
            B_scale,
            B_zp,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            mul_routed_weight,
            top_k,
            config,
            compute_type,
            use_int8_w8a16,
            use_int4_w4a16,
            block_shape,
        )
    else:
        invoke_fused_moe_triton_kernel(
            A,
            B,
            C,
            A_scale,
            B_scale,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            mul_routed_weight,
            top_k,
            config,
            compute_type,
            use_fp8_w8a8,
            use_int8_w8a8,
            use_int8_w8a16,
            use_int4_w4a16,
            per_channel_quant,
            block_shape,
            B_bias,
            FUSE_SILU=FUSE_SILU,
            direct_sum=direct_sum,
            out_top_k=out_top_k,
        )

1808 

1809 

1810def fused_experts_impl( 

1811 hidden_states: torch.Tensor, 

1812 w1: torch.Tensor, 

1813 w2: torch.Tensor, 

1814 topk_weights: torch.Tensor, 

1815 topk_ids: torch.Tensor, 

1816 inplace: bool = False, 

1817 activation: str = "silu", 

1818 apply_router_weight_on_input: bool = False, 

1819 use_fp8_w8a8: bool = False, 

1820 use_int8_w8a8: bool = False, 

1821 use_int8_w8a16: bool = False, 

1822 use_int4_w4a16: bool = False, 

1823 ocp_mx_scheme: str | None = None, 

1824 per_channel_quant: bool = False, 

1825 global_num_experts: int = -1, 

1826 expert_map: torch.Tensor | None = None, 

1827 w1_scale: Optional[torch.Tensor] = None, 

1828 w2_scale: Optional[torch.Tensor] = None, 

1829 w1_zp: torch.Tensor | None = None, 

1830 w2_zp: torch.Tensor | None = None, 

1831 a1_scale: Optional[torch.Tensor] = None, 

1832 a2_scale: Optional[torch.Tensor] = None, 

1833 block_shape: Optional[list[int]] = None, 

1834 w1_bias: Optional[torch.Tensor] = None, 

1835 w2_bias: Optional[torch.Tensor] = None, 

1836) -> torch.Tensor: 

1837 logger.debug("GEMS FUSED MOE") 

1838 assert ( 

1839 activation == "silu" 

1840 ), f"Only 'silu' activation is supported, got {activation}" 

1841 

1842 activation_enum = MoEActivation.from_str(activation) 

1843 

1844 # Check constraints 

1845 if use_int4_w4a16: 

1846 # INT4 stored unpacked in INT8 containers (full K dim) 

1847 assert hidden_states.size(1) == w1.size( 

1848 2 

1849 ), f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}" 

1850 elif ocp_mx_scheme is not None: 

1851 if ocp_mx_scheme.startswith("w_mxfp4"): 

1852 assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch" 

1853 elif ocp_mx_scheme.startswith("w_mxfp6"): 

1854 assert ( 

1855 hidden_states.size(1) == (w1.size(2) * 4) // 3 

1856 ), "hidden size mismatch" 

1857 else: 

1858 raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}") 

1859 else: 

1860 assert hidden_states.size(1) == w1.size( 

1861 2 

1862 ), f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}" 

1863 

1864 assert topk_weights.size() == topk_ids.size(), "topk shape mismatch" 

1865 assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" 

1866 assert w1.stride(-1) == 1, "Stride of last dimension must be 1" 

1867 assert w2.stride(-1) == 1, "Stride of last dimension must be 1" 

1868 assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] 

1869 

1870 num_tokens = hidden_states.size(0) 

1871 E, N, _ = w1.size() 

1872 K = w2.size(1) 

1873 if global_num_experts == -1: 

1874 global_num_experts = E 

1875 top_k_num = topk_ids.size(1) 

1876 

1877 CHUNK_SIZE: int = 32 * 1024 

1878 M = min(num_tokens, CHUNK_SIZE) 

1879 

1880 config_dtype = _get_config_dtype_str( 

1881 use_fp8_w8a8=use_fp8_w8a8, 

1882 use_int8_w8a16=use_int8_w8a16, 

1883 use_int4_w4a16=use_int4_w4a16, 

1884 ocp_mx_scheme=ocp_mx_scheme, 

1885 dtype=hidden_states.dtype, 

1886 ) 

1887 is_plain_half_config = config_dtype in _PLAIN_HALF_CONFIG_DTYPES 

1888 is_fp8_blockwise = config_dtype == "fp8_w8a8" and block_shape is not None 

1889 

1890 quant_dtype = _get_config_quant_dtype( 

1891 use_fp8_w8a8=use_fp8_w8a8, 

1892 use_int8_w8a8=use_int8_w8a8, 

1893 ocp_mx_scheme=ocp_mx_scheme, 

1894 ) 

1895 

1896 get_moe_config = functools.partial( 

1897 try_get_optimal_moe_config, 

1898 w1.size(), 

1899 w2.size(), 

1900 top_k_num, 

1901 config_dtype, 

1902 block_shape=block_shape, 

1903 E=E, 

1904 return_is_embedded=True, 

1905 ) 

1906 

1907 base_config, is_embedded_config = get_moe_config(M) 

1908 

1909 # cache1 and cache3 share memory (non-overlapping lifetime) 

1910 cache13 = torch.empty( 

1911 M * top_k_num * max(N, K), 

1912 device=hidden_states.device, 

1913 dtype=hidden_states.dtype, 

1914 ) 

1915 intermediate_cache1 = cache13[: M * top_k_num * N].view(M, top_k_num, N) 

1916 intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K) 

1917 

1918 # cache2 needs separate memory (concurrent with cache1) 

1919 activation_out_dim = MoEActivation.adjust_N_for_activation(N, activation_enum) 

1920 intermediate_cache2 = torch.empty( 

1921 (M * top_k_num, activation_out_dim), 

1922 device=hidden_states.device, 

1923 dtype=hidden_states.dtype, 

1924 ) 

1925 

1926 if hidden_states.dtype == torch.bfloat16: 

1927 compute_type = tl.bfloat16 

1928 elif hidden_states.dtype == torch.float16: 

1929 compute_type = tl.float16 

1930 elif hidden_states.dtype == torch.float32: 

1931 compute_type = tl.float32 

1932 else: 

1933 raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") 

1934 

1935 out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states) 

1936 

1937 if ocp_mx_scheme is not None: 

1938 # Dequantize OCP MX weights (TODO: skip on platforms with native MX) 

1939 if ocp_mx_scheme.startswith("w_mxfp4"): 

1940 w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype) 

1941 w1_scale = None 

1942 w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype) 

1943 w2_scale = None 

1944 elif ocp_mx_scheme.startswith("w_mxfp6_e3m2"): 

1945 w1 = dequant_mxfp6( 

1946 w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype 

1947 ) 

1948 w1_scale = None 

1949 w2 = dequant_mxfp6( 

1950 w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype 

1951 ) 

1952 w2_scale = None 

1953 elif ocp_mx_scheme.startswith("w_mxfp6_e2m3"): 

1954 w1 = dequant_mxfp6( 

1955 w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype 

1956 ) 

1957 w1_scale = None 

1958 w2 = dequant_mxfp6( 

1959 w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype 

1960 ) 

1961 w2_scale = None 

1962 else: 

1963 raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}") 

1964 

1965 # Dequant INT8/INT4 weights (Triton can't do mixed-dtype dot) 

1966 if use_int8_w8a16 or use_int4_w4a16: 

1967 w1 = w1.to(hidden_states.dtype) * w1_scale.unsqueeze(-1).to(hidden_states.dtype) 

1968 w1_scale = None 

1969 w2 = w2.to(hidden_states.dtype) * w2_scale.unsqueeze(-1).to(hidden_states.dtype) 

1970 w2_scale = None 

1971 use_int8_w8a16 = False 

1972 use_int4_w4a16 = False 

1973 

1974 direct_sum_supported = is_plain_half_config or is_fp8_blockwise 

1975 

1976 # Check if we can safely fuse the activation with the first GEMM pass 

1977 can_use_fused_silu = ( 

1978 activation_enum in (MoEActivation.SILU, MoEActivation.SWIGLUOAI) 

1979 and w1_bias is None 

1980 and expert_map is None # Fused kernel doesn't handle EP -1 experts 

1981 ) 

1982 

1983 for chunk in range((num_tokens // CHUNK_SIZE) + 1): 

1984 begin_chunk_idx, end_chunk_idx = ( 

1985 chunk * CHUNK_SIZE, 

1986 min((chunk + 1) * CHUNK_SIZE, num_tokens), 

1987 ) 

1988 curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] 

1989 tokens_in_chunk, _ = curr_hidden_states.size() 

1990 

1991 if tokens_in_chunk == 0: 

1992 break 

1993 

1994 if tokens_in_chunk < CHUNK_SIZE and chunk > 0: 

1995 # Adjust cache size for last chunk 

1996 intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] 

1997 intermediate_cache2 = intermediate_cache2[ 

1998 : tokens_in_chunk * topk_ids.size(1) 

1999 ] 

2000 intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] 

2001 base_config, is_embedded_config = get_moe_config(tokens_in_chunk) 

2002 

2003 curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] 

2004 curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] 

2005 qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input( 

2006 A=curr_hidden_states, 

2007 A_scale=a1_scale, 

2008 quant_dtype=quant_dtype, 

2009 per_act_token_quant=per_channel_quant, 

2010 block_shape=block_shape, 

2011 ocp_mx_scheme=ocp_mx_scheme, 

2012 ) 

2013 

2014 SPARSITY_FACTOR = 4 

2015 naive_block_assignment = ( 

2016 expert_map is None 

2017 and tokens_in_chunk * top_k_num * SPARSITY_FACTOR <= global_num_experts 

2018 and not ( 

2019 (use_int8_w8a16 or use_int4_w4a16) 

2020 and block_shape is not None 

2021 and block_shape[1] > 0 

2022 ) 

2023 ) 

2024 

2025 if not naive_block_assignment: 

2026 sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( 

2027 curr_topk_ids, 

2028 base_config["BLOCK_SIZE_M"], 

2029 global_num_experts, 

2030 expert_map, 

2031 # ignore_invalid_experts=True, 

2032 ) 

2033 else: 

2034 max_num_tokens_padded = topk_ids.numel() * base_config["BLOCK_SIZE_M"] 

2035 expert_ids = curr_topk_ids.view(-1) 

2036 num_tokens_post_padded = torch.empty( 

2037 (1), dtype=torch.int32, device=topk_ids.device 

2038 ) 

2039 num_tokens_post_padded.fill_(max_num_tokens_padded) 

2040 sorted_token_ids = None 

2041 

2042 # 1. Extract a unified boolean flag for GEMM1 fusion and select config 

2043 do_fuse_silu = can_use_fused_silu and not naive_block_assignment 

2044 use_half_gemm_fast_paths = not is_embedded_config and is_plain_half_config 

2045 

2046 gemm1_config = base_config 

2047 if do_fuse_silu and use_half_gemm_fast_paths: 

2048 gemm1_config, _ = get_moe_config( 

2049 tokens_in_chunk, 

2050 gemm_stage="gemm1", 

2051 enable_gemm_fast_path=True, 

2052 ) 

2053 

2054 # 2. Dynamically determine the differing parameters based on the fusion flag 

2055 if do_fuse_silu: 

2056 # Output goes directly to cache 2 with adjusted dimensions 

2057 out_cache = intermediate_cache2.view( 

2058 tokens_in_chunk, top_k_num, activation_out_dim 

2059 ) 

2060 # Fused kernel weight handling depends on apply_router_weight_on_input 

2061 if apply_router_weight_on_input: 

2062 weights_arg = curr_topk_weights 

2063 else: 

2064 weights_arg = None 

2065 else: 

2066 # Standard path outputs to cache 1 

2067 out_cache = intermediate_cache1 

2068 # Standard path always passes the weights 

2069 weights_arg = curr_topk_weights 

2070 

2071 # 3. Unified GEMM1 dispatch call to eliminate redundant code blocks 

2072 dispatch_fused_moe_kernel( 

2073 qcurr_hidden_states, 

2074 w1, 

2075 out_cache, # Dynamically assigned output buffer 

2076 a1q_scale, 

2077 w1_scale, 

2078 w1_zp, 

2079 weights_arg, # Dynamically assigned weights argument 

2080 sorted_token_ids, 

2081 expert_ids, 

2082 num_tokens_post_padded, 

2083 apply_router_weight_on_input, 

2084 top_k_num, 

2085 gemm1_config, 

2086 compute_type=compute_type, 

2087 use_fp8_w8a8=use_fp8_w8a8, 

2088 use_int8_w8a8=use_int8_w8a8, 

2089 use_int8_w8a16=use_int8_w8a16, 

2090 use_int4_w4a16=use_int4_w4a16, 

2091 per_channel_quant=per_channel_quant, 

2092 block_shape=block_shape, 

2093 B_bias=w1_bias, 

2094 FUSE_SILU=do_fuse_silu, # Master switch for the kernel 

2095 ) 

2096 

2097 # 4. Apply activation separately if the fused path was not taken 

2098 if not do_fuse_silu: 

2099 apply_moe_activation( 

2100 activation_enum, intermediate_cache2, intermediate_cache1.view(-1, N) 

2101 ) 

2102 

2103 # 5. Quantize activated intermediate for GEMM2 

2104 qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( 

2105 A=intermediate_cache2, 

2106 A_scale=a2_scale, 

2107 quant_dtype=quant_dtype, 

2108 per_act_token_quant=per_channel_quant, 

2109 block_shape=block_shape, 

2110 ocp_mx_scheme=ocp_mx_scheme, 

2111 ) 

2112 

2113 if expert_map is not None: 

2114 intermediate_cache3.zero_() 

2115 

2116 # 6. Select GEMM2 config and output buffer/reduction path 

2117 gemm2_config = base_config 

2118 if use_half_gemm_fast_paths: 

2119 gemm2_config, _ = get_moe_config( 

2120 tokens_in_chunk, 

2121 gemm_stage="gemm2", 

2122 enable_gemm_fast_path=True, 

2123 ) 

2124 use_direct_sum = ( 

2125 not is_embedded_config 

2126 and direct_sum_supported 

2127 and tokens_in_chunk >= MOE_DIRECT_SUM_MIN_TOKENS 

2128 and expert_map is None 

2129 and not apply_router_weight_on_input 

2130 ) 

2131 if use_direct_sum: 

2132 gemm2_output = out_hidden_states[begin_chunk_idx:end_chunk_idx].view( 

2133 tokens_in_chunk, 1, K 

2134 ) 

2135 gemm2_output.zero_() 

2136 else: 

2137 gemm2_output = intermediate_cache3 

2138 

2139 # 7. Dispatch GEMM2 

2140 dispatch_fused_moe_kernel( 

2141 qintermediate_cache2, 

2142 w2, 

2143 gemm2_output, 

2144 a2q_scale, 

2145 w2_scale, 

2146 w2_zp, 

2147 curr_topk_weights, 

2148 sorted_token_ids, 

2149 expert_ids, 

2150 num_tokens_post_padded, 

2151 not apply_router_weight_on_input, 

2152 1, 

2153 gemm2_config, 

2154 compute_type=compute_type, 

2155 use_fp8_w8a8=use_fp8_w8a8, 

2156 use_int8_w8a8=use_int8_w8a8, 

2157 use_int8_w8a16=use_int8_w8a16, 

2158 use_int4_w4a16=use_int4_w4a16, 

2159 per_channel_quant=per_channel_quant, 

2160 block_shape=block_shape, 

2161 B_bias=w2_bias, 

2162 FUSE_SILU=False, 

2163 direct_sum=use_direct_sum, 

2164 out_top_k=top_k_num, 

2165 ) 

2166 

2167 # 8. Reduce GEMM2 top-k outputs unless direct_sum wrote final output directly 

2168 if not use_direct_sum: 

2169 moe_sum( 

2170 intermediate_cache3.view(*intermediate_cache3.size()), 

2171 out_hidden_states[begin_chunk_idx:end_chunk_idx], 

2172 ) 

2173 

2174 return out_hidden_states 

2175 

2176 

def inplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> None:
    """
    Run the fused MoE experts and store the result in ``hidden_states``.

    Thin wrapper: every argument is forwarded unchanged to
    ``fused_experts_impl`` with ``inplace=True``. Whatever the
    implementation returns is discarded, so this function returns None and
    callers read the output from ``hidden_states``.
    """
    # Quantization switches and layout, forwarded verbatim.
    quant_options = {
        "use_fp8_w8a8": use_fp8_w8a8,
        "use_int8_w8a8": use_int8_w8a8,
        "use_int8_w8a16": use_int8_w8a16,
        "use_int4_w4a16": use_int4_w4a16,
        "per_channel_quant": per_channel_quant,
        "block_shape": block_shape,
    }
    # Optional scale and bias tensors, forwarded verbatim.
    scale_and_bias = {
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "a1_scale": a1_scale,
        "a2_scale": a2_scale,
        "w1_bias": w1_bias,
        "w2_bias": w2_bias,
    }
    fused_experts_impl(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=True,
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        global_num_experts=global_num_experts,
        **quant_options,
        **scale_and_bias,
    )

2228 

2229 

def outplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Run the fused MoE experts and return the result.

    Thin wrapper: every argument is forwarded unchanged to
    ``fused_experts_impl`` with ``inplace=False``, and the value the
    implementation returns is handed back to the caller as-is.
    """
    # Everything besides the five positional tensors is passed by keyword.
    forwarded_options = {
        "inplace": False,
        "activation": activation,
        "apply_router_weight_on_input": apply_router_weight_on_input,
        "use_fp8_w8a8": use_fp8_w8a8,
        "use_int8_w8a8": use_int8_w8a8,
        "use_int8_w8a16": use_int8_w8a16,
        "use_int4_w4a16": use_int4_w4a16,
        "per_channel_quant": per_channel_quant,
        "global_num_experts": global_num_experts,
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "a1_scale": a1_scale,
        "a2_scale": a2_scale,
        "block_shape": block_shape,
        "w1_bias": w1_bias,
        "w2_bias": w2_bias,
    }
    result = fused_experts_impl(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        **forwarded_options,
    )
    return result