Coverage for src/flag_gems/fused/fused_moe.py: 42%

692 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1# SPDX-License-Identifier: Apache-2.0 

2# SPDX-FileCopyrightText: Copyright contributors to the vLLM project 

3# 

4# Adapted from the vLLM project (https://github.com/vllm-project/vllm). 

5# Source files under vllm/model_executor/layers/: 

6# fused_moe/fused_moe.py – Triton kernels, dispatch, fused_experts_impl 

7# fused_moe/activation.py – MoEActivation enum, apply_moe_activation 

8# fused_moe/utils.py – _fp8_quantize, _int8_quantize, moe_kernel_quantize_input 

9# fused_moe/config.py – _get_config_dtype_str 

10# quantization/utils/mxfp4_utils.py – dequant_mxfp4 

11# quantization/utils/mxfp6_utils.py – dequant_mxfp6 

12# quantization/utils/ocp_mx_utils.py – OCP_MX_BLOCK_SIZE 

13 

14 

15import functools 

16import logging 

17import os 

18from enum import Enum 

19from typing import Any, Optional 

20 

21import torch 

22import torch.nn.functional as F 

23import triton 

24import triton.language as tl 

25import yaml 

26 

27from flag_gems.fused.moe_align_block_size import moe_align_block_size 

28from flag_gems.fused.moe_sum import moe_sum 

29from flag_gems.runtime import device, torch_device_fn 

30from flag_gems.utils import pointwise_dynamic 

31 

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# OCP MX quantization helpers (requires amd-quark)

# Group size handed to the quark per-group FP4/FP6 dequantizer (see
# dequant_mxfp6, which passes it as ``group_size``).
OCP_MX_BLOCK_SIZE = 32

37 

38 

@functools.lru_cache(maxsize=1)
def get_embedded_moe_configs():
    """Load the embedded fused-MoE tuning table from the packaged YAML file.

    The file is looked up at ``../utils/configs/fused_moe_config.yaml``
    relative to this module. The result is cached, so the file is parsed at
    most once per process.

    Returns:
        A ``(parsed_data, fallback)`` tuple.

        * ``parsed_data`` maps a device name to a dict of lookup key ->
          ``{M: config}``, where ``M`` has been converted to ``int``. A config
          stored as a list is expanded into a dict keyed by ``BLOCK_SIZE_M``,
          ``BLOCK_SIZE_N``, ``BLOCK_SIZE_K``, ``GROUP_SIZE_M``, ``num_warps``
          and ``num_stages`` (in that order); any other value is kept as-is.
        * ``fallback`` is the content of the top-level ``_FALLBACK`` entry
          (``{}`` when it is absent or null).

        ``({}, {})`` is returned when the file does not exist or is empty.
    """
    config_path = os.path.join(
        os.path.dirname(__file__), "..", "utils", "configs", "fused_moe_config.yaml"
    )
    if not os.path.exists(config_path):
        return {}, {}
    with open(config_path, "r", encoding="utf-8") as f:
        # Top-level keys are device names (plus "_FALLBACK"); the innermost
        # keys are M values and the innermost values are kernel configs.
        data = yaml.safe_load(f)

    # yaml.safe_load returns None for an empty document; treat that the same
    # as a missing file instead of failing on ``None.get``.
    if not data:
        return {}, {}

    # A present-but-null "_FALLBACK:" entry must not leak None to callers,
    # which call ``.get`` on the returned mapping.
    fallback = data.get("_FALLBACK") or {}

    # Positional layout of a config that is stored as a plain list.
    keys_order = [
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "num_warps",
        "num_stages",
    ]
    parsed_data = {}
    for dev, configs in data.items():
        if dev == "_FALLBACK":
            continue
        # Convert the innermost M keys back to integers and expand list
        # configs into named dicts.
        parsed_data[dev] = {
            k: {
                int(m): dict(zip(keys_order, v)) if isinstance(v, list) else v
                for m, v in m_dict.items()
            }
            for k, m_dict in configs.items()
        }

    return parsed_data, fallback

77 

78 

def dequant_mxfp4(
    x: torch.Tensor,
    scale: torch.Tensor,
    float_dtype: torch.dtype,
) -> torch.Tensor:
    """Dequantize an MXFP4 tensor by delegating to amd-quark.

    Args:
        x: Quantized tensor, forwarded unchanged to ``mx.dq_mxfp4``.
        scale: Scale tensor, forwarded unchanged.
        float_dtype: Target floating-point dtype, forwarded unchanged.

    Returns:
        Whatever ``quark.torch.kernel.mx.dq_mxfp4`` returns.

    Raises:
        ImportError: If amd-quark cannot be imported.
    """
    # quark is an optional dependency, so it is imported only on demand.
    try:
        from quark.torch.kernel import mx as quark_mx
    except ImportError as exc:
        raise ImportError("amd-quark is required for MX-FP4") from exc

    dequantized = quark_mx.dq_mxfp4(x, scale, float_dtype)
    return dequantized

91 

92 

def dequant_mxfp6(
    x: torch.Tensor,
    scale: torch.Tensor,
    float_dtype: torch.dtype,
    quant_dtype: str,
) -> torch.Tensor:
    """Dequantize an MXFP6 tensor using amd-quark's hardware-emulation path.

    Args:
        x: Packed quantized tensor; unpacked with quark's pack method for
            ``quant_dtype`` (``reorder=False``).
        scale: Scale tensor; reinterpreted as uint8 exponents biased by 127.
        float_dtype: Floating-point dtype of the result.
        quant_dtype: Quantized format name passed through to quark.

    Returns:
        The dequantized tensor cast to ``float_dtype``.

    Raises:
        ImportError: If amd-quark cannot be imported.
    """
    # quark is an optional dependency, so it is imported only on demand.
    try:
        from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
            dequantize_fp4_fp6_per_group,
        )
        from quark.torch.utils.pack import create_pack_method
    except ImportError as exc:
        raise ImportError("amd-quark is required for MX-FP6") from exc

    unpacker = create_pack_method(None, dtype=quant_dtype)
    unpacked_x = unpacker.unpack(x, reorder=False)

    # Turn the biased uint8 exponents into real scales: 2 ** (e - 127).
    # int16 is used for the subtraction so that it cannot wrap around.
    exponent = (scale.view(torch.uint8).to(torch.int16) - 127).to(float_dtype)
    scale = 2**exponent

    dequantized = dequantize_fp4_fp6_per_group(
        unpacked_x,
        scale,
        axis=-1,
        group_size=OCP_MX_BLOCK_SIZE,
        quant_dtype=quant_dtype,
    )
    return dequantized.to(float_dtype)

120 

121 

122# Activation quantization helpers 

123 

124 

@functools.lru_cache(maxsize=1)
def _get_device_name() -> str:
    """Return the device name used as a key into the embedded config table.

    The name reported by ``torch_device_fn.get_device_name()`` has its spaces
    replaced by underscores; if that call is unavailable (AttributeError),
    ``device.name`` is used instead. Any name containing an ``H200``
    component is collapsed to ``NVIDIA_H200``.

    If the resulting name has no entry in the embedded table but the
    ``_FALLBACK`` mapping names an alias that does, the alias is returned.
    Otherwise the normalised name is returned unchanged, even when no table
    exists for it. The result is cached for the lifetime of the process.
    """
    try:
        name = torch_device_fn.get_device_name().replace(" ", "_")
    except AttributeError:
        name = device.name

    # Normalise the H200 product family to a single key, following vLLM.
    if "H200" in name.split("_"):
        name = "NVIDIA_H200"

    tables, aliases = get_embedded_moe_configs()
    if name not in tables:
        # Devices with equivalent tuning profiles may be aliased to another
        # device's table through the fallback mapping.
        alias = aliases.get(name)
        if alias and alias in tables:
            logger.info(
                "Device %s not in config table, falling back to %s", name, alias
            )
            return alias
    return name

149 

150 

def get_moe_configs(
    E: int,
    N: int,
    dtype: str | None,
    block_n: int | None = None,
    block_k: int | None = None,
) -> dict[int, Any] | None:
    """Look up pre-tuned fused-MoE kernel configs for the current device.

    The lookup key is ``"{E},{N},{dtype},{block_n},{block_k}"``, with a
    missing or zero ``block_n`` / ``block_k`` rendered as ``0``.

    Returns:
        The ``{M: config}`` mapping stored under that key in the embedded
        table of the current device, or ``None`` (after logging a warning)
        when either the device or the key has no entry.
    """
    device_name = _get_device_name()
    tables, _ = get_embedded_moe_configs()

    table = tables.get(device_name)
    if table is None:
        logger.warning(
            "No embedded MoE configs for device %s. Will use default config.",
            device_name,
        )
        return None

    key = f"{E},{N},{dtype},{block_n or 0},{block_k or 0}"
    configs = table.get(key)
    if configs is None:
        logger.warning(
            "No embedded MoE config for device=%s, key=%s. Will use default config.",
            device_name,
            key,
        )
        return None

    logger.info("Using embedded MoE config for device=%s, key=%s", device_name, key)
    return configs

187 

188 

def try_get_optimal_moe_config(
    w1_shape: tuple[int, ...],
    w2_shape: tuple[int, ...],
    top_k: int,
    dtype: str | None,
    M: int,
    E: int,
    block_shape: list[int] | None = None,
) -> dict[str, int]:
    """Pick a Triton launch config for the fused MoE kernel.

    Selection order:

    1. On a CUDA device with compute capability (9, 0), for block-quantized
       ``fp8_w8a8`` (two-element ``block_shape``), a heuristic config is built
       from the average number of routed tokens per expert.
    2. Otherwise the embedded tuning table is consulted and the entry whose
       ``M`` is closest to the requested ``M`` is used.
    3. If no table entry exists, ``get_default_config`` supplies the config.

    Args:
        w1_shape: Shape of the first expert weight; index 2 is passed as ``K``
            to the default-config heuristic.
        w2_shape: Three-element shape of the second expert weight; its first
            and last entries are used as expert count and ``N`` for the table
            lookup and default config.
        top_k: Number of experts each token is routed to.
        dtype: Config dtype string (for example ``"fp8_w8a8"``) or ``None``.
        M: Number of tokens.
        E: Number of experts; only used by the Hopper heuristic.
        block_shape: Optional ``[block_n, block_k]`` quantization block shape.

    Returns:
        A kernel config dict.
    """
    on_hopper = torch.cuda.is_available() and torch.cuda.get_device_capability() == (
        9,
        0,
    )
    use_heuristic = (
        on_hopper
        and dtype == "fp8_w8a8"
        and block_shape is not None
        and len(block_shape) == 2
    )
    if use_heuristic:
        # Use heuristic config like hpc-ops: the smallest of 16/32/48 that
        # covers the average tokens per expert, capped at 64.
        avg_tokens_per_expert = M * top_k // E
        block_size_m = next(
            (limit for limit in (16, 32, 48) if avg_tokens_per_expert <= limit),
            64,
        )
        return {
            "BLOCK_SIZE_M": block_size_m,
            "BLOCK_SIZE_N": block_shape[0],
            "BLOCK_SIZE_K": block_shape[1],
            "GROUP_SIZE_M": 1,
            "num_warps": 4,
            "num_stages": 3,
            "SWAP_AB": True,
        }

    # From here on the expert count comes from w2_shape, not from ``E``.
    num_experts, _, N = w2_shape
    if dtype == "int4_w4a16":
        N = N * 2
    if block_shape:
        block_n, block_k = block_shape[0], block_shape[1]
    else:
        block_n = block_k = 0

    # First try to load the optimal config from the embedded table.
    tuned = get_moe_configs(num_experts, N, dtype, block_n, block_k)
    if not tuned:
        return get_default_config(
            M, num_experts, N, w1_shape[2], top_k, dtype, block_shape
        )

    closest_m = min(tuned, key=lambda m: abs(m - M))
    return tuned[closest_m]

247 

248 

249def _get_config_quant_dtype( 

250 use_fp8_w8a8: bool, 

251 use_int8_w8a8: bool, 

252 ocp_mx_scheme: str | None, 

253) -> None | torch.dtype | str: 

254 """Map quantization flags to the corresponding dtype.""" 

255 if use_fp8_w8a8: 

256 return torch.float8_e4m3fn 

257 elif use_int8_w8a8: 

258 return torch.int8 

259 elif ocp_mx_scheme == "w_mxfp4_a_mxfp4": 

260 return "mxfp4" 

261 elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e3m2", "w_mxfp6_e3m2_a_mxfp6_e3m2"}: 

262 return "mxfp6_e3m2" 

263 elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e2m3", "w_mxfp6_e2m3_a_mxfp6_e2m3"}: 

264 return "mxfp6_e2m3" 

265 elif ocp_mx_scheme in {"w_mxfp4", "w_mxfp6_e3m2", "w_mxfp6_e2m3"}: 

266 return torch.bfloat16 

267 elif ocp_mx_scheme in {"w_mxfp4_a_fp8", "w_mxfp6_e3m2_a_fp8", "w_mxfp6_e2m3_a_fp8"}: 

268 return torch.float8_e4m3fn 

269 

270 return None 

271 

272 

def get_moe_wna16_block_config(
    config: dict[str, int],
    use_moe_wna16_cuda: bool,
    num_valid_tokens: int,
    size_k: int,
    size_n: int,
    num_experts: int,
    group_size: int,
    real_top_k: int,
    block_size_m: int,
):
    """Choose ``BLOCK_SIZE_N`` / ``BLOCK_SIZE_K`` for the WNA16 MoE kernels.

    Args:
        config: Existing kernel config; if it already contains both
            ``BLOCK_SIZE_N`` and ``BLOCK_SIZE_K`` nothing is overridden.
        use_moe_wna16_cuda: Selects the block-count heuristic instead of the
            two fixed Triton tile shapes.
        num_valid_tokens: Number of routed (token, expert) pairs.
        size_k: K dimension of the weight.
        size_n: N dimension of the weight.
        num_experts: Number of experts.
        group_size: Quantization group size along K.
        real_top_k: Experts per token.
        block_size_m: M tile size already chosen.

    Returns:
        ``{}`` when ``config`` already fixes both block sizes, otherwise a
        dict with ``BLOCK_SIZE_N`` and ``BLOCK_SIZE_K``.
    """
    if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config:
        return {}

    if not use_moe_wna16_cuda:
        # Triton path: a single token per call favours a deeper K tile.
        if num_valid_tokens // real_top_k == 1:
            return {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64}
        return {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}

    # Start from 128x128 and grow the tiles while the estimated number of
    # blocks stays large.
    block_size_n = 128
    block_size_k = 128
    if block_size_k <= group_size:
        block_size_k = group_size

    # NOTE(review): these two names and divisors are kept exactly as in the
    # original (K blocks are stored in ``num_n_blocks`` and N is divided by
    # the K tile); only their product is used for the estimate below.
    num_n_blocks = size_k // block_size_k
    num_k_blocks = size_n // block_size_k
    # True division is kept from the original, so the estimate may be a float.
    num_m_blocks = (
        num_valid_tokens + block_size_m - 1
    ) / block_size_m + num_experts
    if num_valid_tokens // real_top_k <= block_size_m:
        num_m_blocks = min(num_m_blocks, num_valid_tokens)
    num_blocks = num_m_blocks * num_n_blocks * num_k_blocks

    if size_k % 256 == 0 and num_blocks >= 256 and block_size_k < 256:
        # Rescale the estimate by the growth factor of the K tile BEFORE
        # overwriting block_size_k. The original assigned 256 first, which
        # made the divisor ``256 // 256 == 1`` and left num_blocks untouched.
        num_blocks = num_blocks // (256 // block_size_k)
        block_size_k = 256

    if (
        num_m_blocks <= 16
        and size_k % (block_size_k * 2) == 0
        and block_size_k <= 512
        and num_blocks >= 512
    ):
        block_size_k = block_size_k * 2
        num_blocks = num_blocks // 2

    if num_blocks > 1024:
        block_size_n = 256
        num_n_blocks = num_n_blocks // 2
        num_blocks = num_blocks // 2

    if size_n <= 1024 and num_blocks >= 1024:
        block_size_n = 1024

    # Make sure the final K tile divides size_k and respects group_size.
    block_size_k = _ensure_block_size_k_divisible(size_k, block_size_k, group_size)

    return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k}

331 

332 

333def get_default_config( 

334 M: int, 

335 E: int, 

336 N: int, 

337 K: int, 

338 topk: int, 

339 dtype: str | None, 

340 block_shape: list[int] | None = None, 

341) -> dict[str, int]: 

342 """Default Triton config for fused MoE kernel. 

343 

344 Heuristic selection aligned with vLLM v0.17.0 defaults, tuned on H20/H100. 

345 Key insight: for high-expert-count MoE (e.g. DeepSeek-V3 E=256), each 

346 expert sees very few tokens, so small BLOCK_SIZE_M (16) is critical. 

347 """ 

348 if dtype == "fp8_w8a8" and block_shape is not None: 

349 config = { 

350 "BLOCK_SIZE_M": 16 if M <= 64 else 64, 

351 "BLOCK_SIZE_N": block_shape[0], 

352 "BLOCK_SIZE_K": block_shape[1], 

353 "GROUP_SIZE_M": 1 if M <= 16 else 32, 

354 "num_warps": 4, 

355 "num_stages": 3, 

356 } 

357 else: 

358 # tokens_per_expert drives block_m: use M//E (not M*topk//E) to 

359 # estimate the actual per-expert token count after routing. 

360 tokens_per_expert = M // max(E, 1) 

361 

362 if tokens_per_expert <= 2: 

363 block_m = 16 

364 elif tokens_per_expert <= 4: 

365 block_m = 32 

366 elif tokens_per_expert <= 16: 

367 block_m = 64 

368 else: 

369 block_m = 128 

370 

371 # Tile sizing 

372 if N >= 4096: 

373 block_n = 128 if M <= 128 else 256 

374 elif N >= 1024: 

375 block_n = 64 if M <= 64 else 128 

376 else: 

377 block_n = 64 if M <= 64 else 128 

378 

379 if dtype == "fp8_w8a8": 

380 block_k = 128 

381 elif M <= 64: 

382 block_k = 128 

383 else: 

384 block_k = 64 

385 

386 if tokens_per_expert > 128: 

387 group_m = 16 

388 elif tokens_per_expert > 32: 

389 group_m = 8 

390 else: 

391 group_m = 1 

392 

393 # Prefer 4 warps for small tiles; only use 8 for large M 

394 num_warps = 4 if M <= 128 else 8 

395 num_stages = 3 

396 

397 smem_per_stage = (block_m * block_k + block_k * block_n) * 2 

398 while num_stages > 2 and smem_per_stage * num_stages > 200_000: 

399 num_stages -= 1 

400 

401 config = { 

402 "BLOCK_SIZE_M": block_m, 

403 "BLOCK_SIZE_N": block_n, 

404 "BLOCK_SIZE_K": block_k, 

405 "GROUP_SIZE_M": group_m, 

406 "num_warps": num_warps, 

407 "num_stages": num_stages, 

408 } 

409 return config 

410 

411 

412def _get_config_dtype_str( 

413 dtype: Optional[torch.dtype] = None, 

414 use_fp8_w8a8: bool = False, 

415 use_fp8_w8a16: bool = False, 

416 use_int8_w8a16: bool = False, 

417 use_int4_w4a16: bool = False, 

418 ocp_mx_scheme: str | None = None, 

419) -> str | None: 

420 """Return dtype string for kernel config lookup.""" 

421 if use_fp8_w8a8: 

422 return "fp8_w8a8" 

423 elif use_fp8_w8a16: 

424 return "fp8_w8a16" 

425 elif use_int8_w8a16: 

426 return "int8_w8a16" 

427 elif use_int4_w4a16: 

428 return "int4_w4a16" 

429 elif ocp_mx_scheme is not None: 

430 return None 

431 elif dtype == torch.float: 

432 return "float32" 

433 return None 

434 

435 

436# MoE activation enum 

437 

438 

class MoEActivation(Enum):
    """Activation functions supported by the MoE layers.

    Members whose value ends in ``_no_mul`` are non-gated; all others are
    gated and consume an input that is twice as wide as their output.
    """

    # Gated: gate * activation(up), input [..., 2*d] -> output [..., d]
    SILU = "silu"
    GELU = "gelu"
    RELU2 = "relu2"
    SWIGLUOAI = "swigluoai"
    SWIGLUSTEP = "swiglustep"

    # Non-gated: input [..., d] -> output [..., d]
    SILU_NO_MUL = "silu_no_mul"
    GELU_NO_MUL = "gelu_no_mul"
    RELU2_NO_MUL = "relu2_no_mul"

    @property
    def is_gated(self) -> bool:
        """Whether this activation multiplies by a gate (no ``_no_mul`` suffix)."""
        suffix = "_no_mul"
        return self.value[-len(suffix):] != suffix

    def without_mul(self) -> "MoEActivation":
        """Return the non-gated counterpart, or ``self`` if there is none."""
        # Only silu / gelu / relu2 have a "<name>_no_mul" member; for every
        # other member the lookup by value fails and the member is returned
        # unchanged.
        try:
            return type(self)(self.value + "_no_mul")
        except ValueError:
            return self

    @classmethod
    def from_str(cls, s: str) -> "MoEActivation":
        """Return the member whose value equals ``s``.

        Raises:
            ValueError: If no member has that value.
        """
        found = next((member for member in cls if member.value == s), None)
        if found is None:
            valid = [m.value for m in cls]
            raise ValueError(
                f"Unknown MoE activation: {s!r}. Valid activations: {valid}"
            )
        return found

    @staticmethod
    def adjust_N_for_activation(N: int, activation: "MoEActivation") -> int:
        """Return N for non-gated, N // 2 for gated activations."""
        return N // 2 if activation.is_gated else N

479 

480 

def apply_moe_activation(
    activation: MoEActivation,
    output: torch.Tensor,
    input: torch.Tensor,
) -> torch.Tensor:
    """Apply an MoE activation to ``input`` and write the result to ``output``.

    Gated activations split ``input`` along the last dimension into two
    halves of width ``output.size(-1)`` and combine them; non-gated
    activations act element-wise on the whole input.

    Args:
        activation: Activation to apply.
        output: 2-D destination tensor, written in place.
        input: 2-D source tensor. Its last dimension must be twice that of
            ``output`` for gated activations and equal otherwise. Note that
            ``RELU2_NO_MUL`` applies ReLU to ``input`` in place.

    Returns:
        ``output``.

    Raises:
        ValueError: If the activation is not handled here.
    """
    assert input.dim() == 2, "Input must be 2D"
    assert output.dim() == 2, "Output must be 2D"

    gated = activation.is_gated
    out_width = output.size(-1)
    in_width = input.size(-1)

    if gated:
        assert out_width * 2 == in_width, (
            f"{activation.value} expects 2x ratio: " f"{out_width * 2} vs {in_width}"
        )
        # First half is the gate, second half is the value it multiplies.
        gate = input[:, :out_width]
        up = input[:, out_width:]

        if activation in (MoEActivation.SILU, MoEActivation.SWIGLUOAI):
            # Fused Triton kernel writes silu(gate) * up straight into output.
            _silu_and_mul_kernel(gate, up, out0=output)
        elif activation == MoEActivation.GELU:
            output.copy_(F.gelu(gate) * up)
        elif activation == MoEActivation.SWIGLUSTEP:
            output.copy_(torch.sigmoid(gate) * up)
        elif activation == MoEActivation.RELU2:
            output.copy_(F.relu(gate).square() * up)
        else:
            raise ValueError(f"Unsupported FusedMoe activation: {activation}")
        return output

    assert out_width == in_width, (
        f"{activation.value} expects equal sizes: " f"{out_width} vs {in_width}"
    )
    if activation == MoEActivation.SILU_NO_MUL:
        output.copy_(F.silu(input))
    elif activation == MoEActivation.GELU_NO_MUL:
        output.copy_(F.gelu(input))
    elif activation == MoEActivation.RELU2_NO_MUL:
        # ReLU is applied to the input buffer itself, then squared into output.
        F.relu(input, inplace=True)
        torch.square(input, out=output)
    else:
        raise ValueError(f"Unsupported FusedMoe activation: {activation}")

    return output

528 

529 

530def _fp8_quantize( 

531 A: torch.Tensor, 

532 A_scale: Optional[torch.Tensor], 

533 per_act_token: bool, 

534 block_shape: Optional[list[int]] = None, 

535) -> tuple[torch.Tensor, torch.Tensor]: 

536 """FP8 E4M3 quantization: per-tensor, per-token, or block-wise.""" 

537 fp8_dtype = torch.float8_e4m3fn 

538 finfo = torch.finfo(fp8_dtype) 

539 fp8_max = finfo.max 

540 fp8_min = finfo.min 

541 eps = 1e-10 

542 

543 if block_shape is not None: 

544 assert not per_act_token 

545 assert len(block_shape) == 2 

546 block_k = block_shape[1] 

547 assert A.size(-1) % block_k == 0 

548 if A.ndim == 2 and A.stride(-1) == 1: 

549 from flag_gems.ops.per_token_group_quant_fp8 import ( 

550 per_token_group_quant_fp8, 

551 ) 

552 

553 return per_token_group_quant_fp8( 

554 A, 

555 group_size=block_k, 

556 eps=eps, 

557 dtype=fp8_dtype, 

558 column_major_scales=False, 

559 scale_ue8m0=False, 

560 ) 

561 orig_shape = A.shape 

562 A_flat = A.reshape(-1, A.size(-1)) 

563 M, K = A_flat.shape 

564 A_groups = A_flat.reshape(M * (K // block_k), block_k) 

565 amax = ( 

566 A_groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32) 

567 ) 

568 scale = amax / fp8_max 

569 A_q = (A_groups.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype) 

570 A_q = A_q.reshape(orig_shape) 

571 scale = scale.reshape(M, K // block_k) 

572 return A_q, scale 

573 

574 elif per_act_token: 

575 A_flat = A.reshape(-1, A.size(-1)) 

576 amax = A_flat.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32) 

577 scale = amax / fp8_max 

578 min_scale = torch.tensor( 

579 1.0 / (fp8_max * 512.0), dtype=torch.float32, device=A.device 

580 ) 

581 scale = scale.clamp(min=min_scale) 

582 A_q = (A_flat.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype) 

583 A_q = A_q.reshape(A.shape) 

584 scale = scale.reshape(A.shape[:-1] + (1,)) 

585 return A_q, scale 

586 

587 else: 

588 if A_scale is not None: 

589 scale = ( 

590 A_scale.float().view(1, 1) if A_scale.numel() == 1 else A_scale.float() 

591 ) 

592 A_q = (A.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype) 

593 return A_q, A_scale 

594 else: 

595 amax = A.abs().amax().clamp(min=eps).to(torch.float32) 

596 scale = amax / fp8_max 

597 iscale = 1.0 / scale 

598 A_q = (A.float() * iscale).clamp(fp8_min, fp8_max).to(fp8_dtype) 

599 return A_q, scale.view(1) 

600 

601 

602def _int8_quantize( 

603 A: torch.Tensor, 

604 A_scale: Optional[torch.Tensor], 

605 per_act_token: bool, 

606 block_shape: Optional[list[int]] = None, 

607) -> tuple[torch.Tensor, torch.Tensor]: 

608 """INT8 quantization: per-tensor, per-token, or block-wise.""" 

609 iinfo = torch.iinfo(torch.int8) 

610 int8_max = iinfo.max 

611 int8_min = iinfo.min 

612 eps = 1e-10 

613 

614 if block_shape is not None: 

615 assert not per_act_token 

616 assert len(block_shape) == 2 

617 block_k = block_shape[1] 

618 assert A.size(-1) % block_k == 0 

619 orig_shape = A.shape 

620 A_flat = A.reshape(-1, A.size(-1)) 

621 M, K = A_flat.shape 

622 A_groups = A_flat.reshape(M * (K // block_k), block_k) 

623 amax = ( 

624 A_groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32) 

625 ) 

626 scale = amax / int8_max 

627 A_q = ( 

628 (A_groups.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8) 

629 ) 

630 A_q = A_q.reshape(orig_shape) 

631 scale = scale.reshape(M, K // block_k) 

632 return A_q, scale 

633 

634 elif per_act_token: 

635 A_flat = A.reshape(-1, A.size(-1)) 

636 amax = A_flat.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32) 

637 scale = amax / int8_max 

638 A_q = (A_flat.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8) 

639 A_q = A_q.reshape(A.shape) 

640 scale = scale.reshape(A.shape[:-1] + (1,)) 

641 return A_q, scale 

642 

643 else: 

644 assert A_scale is not None, "int8 per-tensor requires A_scale" 

645 scale = A_scale.float().view(1, 1) if A_scale.numel() == 1 else A_scale.float() 

646 A_q = (A.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8) 

647 return A_q, A_scale 

648 

649 

650def moe_kernel_quantize_input( 

651 A: torch.Tensor, 

652 A_scale: Optional[torch.Tensor], 

653 quant_dtype: None | torch.dtype | str, 

654 per_act_token_quant: bool, 

655 block_shape: Optional[list[int]] = None, 

656 ocp_mx_scheme: str | None = None, 

657) -> tuple[torch.Tensor, Optional[torch.Tensor]]: 

658 """Quantize MoE input activations before GEMM.""" 

659 if ocp_mx_scheme is not None: 

660 if ocp_mx_scheme in {"w_mxfp4", "w_mxfp4_a_mxfp4"}: 

661 pass 

662 elif ocp_mx_scheme.endswith("a_fp8"): 

663 qA, qA_scale = _fp8_quantize(A, A_scale, per_act_token=False) 

664 A = (qA.float() * qA_scale.float()).to(A.dtype) 

665 return A, None 

666 

667 if quant_dtype is None: 

668 return A, A_scale 

669 elif quant_dtype == torch.float8_e4m3fn: 

670 return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape) 

671 elif quant_dtype == torch.int8: 

672 return _int8_quantize(A, A_scale, per_act_token_quant, block_shape) 

673 else: 

674 return A, A_scale 

675 

676 

677def _ensure_block_size_k_divisible( 

678 size_k: int, block_size_k: int, group_size: int 

679) -> int: 

680 """Find largest block_size_k that divides size_k and is divisible by group_size.""" 

681 if size_k % block_size_k == 0 and block_size_k % group_size == 0: 

682 return block_size_k 

683 

684 max_search = min(block_size_k, size_k) 

685 start = (max_search // group_size) * group_size 

686 for candidate in range(start, group_size - 1, -group_size): 

687 if size_k % candidate == 0: 

688 return candidate 

689 

690 if size_k % group_size == 0: 

691 return group_size 

692 

693 return size_k 

694 

695 

# Element-wise SiLU-and-multiply: returns silu(x) * y.
# Wrapped by pointwise_dynamic; apply_moe_activation calls it with the two
# halves of the gated input and ``out0=`` set to the destination tensor.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def _silu_and_mul_kernel(x, y):
    # Upcast to fp32 before evaluating the exponential.
    x_fp32 = x.to(tl.float32)
    # silu(x) = x / (1 + exp(-x))
    x_silu = tl.fdiv(x_fp32, (1.0 + tl.exp(-x_fp32)))
    return x_silu * y

702 

703 

@triton.jit
def write_zeros_to_output(
    c_ptr,
    stride_cm,
    stride_cn,
    pid_n,
    N,
    offs_token,
    token_mask,
    BLOCK_SIZE_M,
    BLOCK_SIZE_N,
    compute_type,
):
    """Store a zero tile of shape (BLOCK_SIZE_M, BLOCK_SIZE_N) into C.

    Rows are addressed through ``offs_token`` and columns start at
    ``pid_n * BLOCK_SIZE_N``. Only positions whose row is enabled in
    ``token_mask`` and whose column index is below ``N`` are written.
    The fused MoE kernels call this when the expert id of a block is -1.
    """
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
    # Column offsets covered by this program along N.
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    # Skip masked-out rows and columns past the end of N.
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)

722 

723 

@triton.jit
def fused_moe_kernel_gptq_awq(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    b_scale_ptr,
    b_zp_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N: tl.constexpr,
    K: tl.constexpr,
    EM,
    num_valid_tokens,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bse,
    stride_bsk,
    stride_bsn,
    stride_bze,
    stride_bzk,
    stride_bzn,
    block_k_diviable: tl.constexpr,
    group_size: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    has_zp: tl.constexpr,
    use_int4_w4a16: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
):
    """Fused MoE kernel for GPTQ/AWQ (WNA16) quantized weights.

    Each program computes one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C for the
    expert given by ``expert_ids_ptr[pid_m]``. Weights are dequantized on the
    fly as ``(b - zero_point) * scale``, accumulated in fp32, optionally
    multiplied by the routing weight, and stored as ``compute_type``.

    Notes grounded in the body below:
      * A rows are addressed as ``offs_token // top_k``; C rows, and the
        routing weights, are addressed by ``offs_token`` itself.
      * With ``use_int4_w4a16`` two 4-bit weights share one element of B
        along K, and two 4-bit zero points share one element along N.
      * Without explicit zero points a constant of 8 (int4) or 128 (int8)
        is subtracted instead.
      * Scales and zero points are indexed along K in steps of
        ``group_size``.
      * ``SPLIT_K`` is accepted but not referenced in this kernel body.
    """
    # Map pid to C block (grouped ordering for L2 reuse)
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group may contain fewer than GROUP_SIZE_M row blocks.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Create pointers for first blocks of A and B
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    # Programs whose row block starts at or beyond the padded token count
    # have nothing to do.
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    # Cast to int64 to prevent overflow in stride*offset products
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id).to(tl.int64)
    # Entries >= num_valid_tokens are masked out of every load and store.
    token_mask = offs_token < num_valid_tokens

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        # -----------------------------------------------------------
        # Write back zeros to the output when the expert is not
        # in the current expert parallel rank.
        write_zeros_to_output(
            c_ptr,
            stride_cm,
            stride_cn,
            pid_n,
            N,
            offs_token,
            token_mask,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            compute_type,
        )
        return

    # Column offsets wrap modulo N so that B pointers stay within the N
    # range; the final store masks columns >= N separately.
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )

    if use_int4_w4a16:
        # Two K entries per element of B: index by offs_k // 2 and select
        # the nibble with a shift of 0 or 4 bits.
        b_ptrs = (
            b_ptr
            + off_experts * stride_be
            + (offs_k[:, None] // 2) * stride_bk
            + offs_bn[None, :] * stride_bn
        )
        b_shifter = (offs_k[:, None] % 2) * 4
    elif use_int8_w8a16:
        b_ptrs = (
            b_ptr
            + off_experts * stride_be
            + offs_k[:, None] * stride_bk
            + offs_bn[None, :] * stride_bn
        )

    # Constant zero point when none is supplied (8 for int4, 128 for int8);
    # for int4 with zero points, the nibble shift along N is precomputed.
    if not has_zp and use_int4_w4a16:
        b_zp_num = 8
    if not has_zp and use_int8_w8a16:
        b_zp_num = 128
    elif has_zp and use_int4_w4a16:
        b_zp_shifter = (offs_bn[None, :] % 2) * 4

    # Accumulate C block in fp32
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Scale / zero-point loads are masked along K only when K is not
        # covered exactly by the K tiles.
        if not block_k_diviable:
            k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
            k_other = 0.0
        else:
            k_mask = None
            k_other = None

        # A is always masked: invalid tokens and the K tail read as zero.
        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        # NOTE(review): this load of B carries no mask, even when
        # block_k_diviable is false — presumably the zero-filled A tail makes
        # the extra rows irrelevant; confirm the reads stay in bounds.
        b = tl.load(b_ptrs)
        if use_int4_w4a16:
            # Extract the 4-bit weight selected by b_shifter.
            b = (b >> b_shifter) & 0xF

        # One scale per (K group, N column) of the current expert.
        b_scale_ptrs = (
            b_scale_ptr
            + off_experts * stride_bse
            + offs_bn[None, :] * stride_bsn
            + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
        )
        b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
        b_scale = b_scale.to(tl.float32)

        if has_zp and use_int4_w4a16:
            # Zero points are packed two per element along N.
            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
            b_zp_ptrs = (
                b_zp_ptr
                + off_experts * stride_bze
                + (offs_bn[None, :] // 2) * stride_bzn
                + offs_k_true * stride_bzk
            )
            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
            b_zp = (b_zp >> b_zp_shifter) & 0xF
            b_zp = b_zp.to(tl.float32)
        elif has_zp and use_int8_w8a16:
            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
            b_zp_ptrs = (
                b_zp_ptr
                + off_experts * stride_bze
                + offs_bn[None, :] * stride_bzn
                + offs_k_true * stride_bzk
            )
            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
            b_zp = b_zp.to(tl.float32)

        # Dequantize: (weight - zero_point) * scale, then cast for tl.dot.
        if has_zp:
            b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
        else:
            b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
        accumulator = tl.dot(a, b, acc=accumulator)

        # Advance to the next K tile; packed int4 weights advance by half
        # as many elements of B.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        if use_int4_w4a16:
            b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
        else:
            b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        # Scale each output row by its routing weight.
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]

    accumulator = accumulator.to(compute_type)
    # Write back output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)

914 

915 

916@triton.jit 

917def fused_moe_kernel( 

918 # Pointers to matrices 

919 a_ptr, 

920 b_ptr, 

921 c_ptr, 

922 b_bias_ptr, 

923 a_scale_ptr, 

924 b_scale_ptr, 

925 topk_weights_ptr, 

926 sorted_token_ids_ptr, 

927 expert_ids_ptr, 

928 num_tokens_post_padded_ptr, 

929 # Matrix dimensions 

930 N, 

931 K, 

932 EM, 

933 num_valid_tokens, 

934 stride_am, 

935 stride_ak, 

936 stride_be, 

937 stride_bk, 

938 stride_bn, 

939 stride_cm, 

940 stride_cn, 

941 stride_asm, 

942 stride_ask, 

943 stride_bse, 

944 stride_bsk, 

945 stride_bsn, 

946 stride_bbe, # bias expert stride 

947 stride_bbn, # bias N stride 

948 # Block size for block-wise quantization 

949 group_n: tl.constexpr, 

950 group_k: tl.constexpr, 

951 naive_block_assignment: tl.constexpr, 

952 # Meta-parameters 

953 BLOCK_SIZE_M: tl.constexpr, 

954 BLOCK_SIZE_N: tl.constexpr, 

955 BLOCK_SIZE_K: tl.constexpr, 

956 GROUP_SIZE_M: tl.constexpr, 

957 SPLIT_K: tl.constexpr, 

958 MUL_ROUTED_WEIGHT: tl.constexpr, 

959 top_k: tl.constexpr, 

960 compute_type: tl.constexpr, 

961 use_fp8_w8a8: tl.constexpr, 

962 use_int8_w8a8: tl.constexpr, 

963 use_int8_w8a16: tl.constexpr, 

964 per_channel_quant: tl.constexpr, 

965 HAS_BIAS: tl.constexpr, 

966 SWAP_AB: tl.constexpr, 

967): 

968 """Fused MoE kernel: token × expert GEMM with quantization support.""" 

969 # Map pid to C block (grouped ordering for L2 reuse) 

970 pid = tl.program_id(axis=0) 

971 num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) 

972 num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) 

973 num_pid_in_group = GROUP_SIZE_M * num_pid_n 

974 group_id = pid // num_pid_in_group 

975 first_pid_m = group_id * GROUP_SIZE_M 

976 group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) 

977 pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) 

978 pid_n = (pid % num_pid_in_group) // group_size_m 

979 

980 # Create pointers for first blocks of A and B 

981 offs = tl.arange(0, BLOCK_SIZE_M).to(tl.int64) 

982 num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) 

983 if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: 

984 return 

985 if not naive_block_assignment: 

986 offs_token_id = pid_m * BLOCK_SIZE_M + offs 

987 offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) 

988 else: 

989 offs_token = tl.where( 

990 offs == 0, 

991 pid_m, # first element = pid_m 

992 num_valid_tokens, # remaining elements = constant 

993 ) 

994 offs_token = offs_token.to(tl.int64) # prevent int32 overflow 

995 

996 token_mask = offs_token < num_valid_tokens 

997 

998 off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) 

999 if off_experts == -1: 

1000 # Expert not in current EP rank, write zeros 

1001 write_zeros_to_output( 

1002 c_ptr, 

1003 stride_cm, 

1004 stride_cn, 

1005 pid_n, 

1006 N, 

1007 offs_token, 

1008 token_mask, 

1009 BLOCK_SIZE_M, 

1010 BLOCK_SIZE_N, 

1011 compute_type, 

1012 ) 

1013 return 

1014 

1015 offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N 

1016 offs_k = tl.arange(0, BLOCK_SIZE_K) 

1017 a_ptrs = a_ptr + ( 

1018 offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak 

1019 ) 

1020 

1021 b_ptrs = ( 

1022 b_ptr 

1023 + off_experts * stride_be 

1024 + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) 

1025 ) 

1026 if use_int8_w8a16: 

1027 b_scale_ptrs = ( 

1028 b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn 

1029 ) 

1030 b_scale = tl.load(b_scale_ptrs) 

1031 

1032 if use_fp8_w8a8 or use_int8_w8a8: 

1033 if group_k > 0 and group_n > 0: # block-wise 

1034 a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm 

1035 offs_bsn = offs_bn // group_n 

1036 b_scale_ptrs = ( 

1037 b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn 

1038 ) 

1039 elif per_channel_quant: # channel-wise 

1040 b_scale_ptrs = ( 

1041 b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn 

1042 ) 

1043 b_scale = tl.load(b_scale_ptrs) 

1044 a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm 

1045 a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None] 

1046 else: # tensor-wise 

1047 a_scale = tl.load(a_scale_ptr) 

1048 b_scale = tl.load(b_scale_ptr + off_experts) 

1049 if HAS_BIAS: 

1050 bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn 

1051 bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0) 

1052 # Accumulate C block in fp32 

1053 accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 

1054 if SWAP_AB: 

1055 accumulator_nm = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_M), dtype=tl.float32) 

1056 for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): 

1057 a = tl.load( 

1058 a_ptrs, 

1059 mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), 

1060 other=0.0, 

1061 ) 

1062 b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) 

1063 if use_int8_w8a16: 

1064 accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) 

1065 elif use_fp8_w8a8 or use_int8_w8a8: 

1066 if group_k > 0 and group_n > 0: 

1067 k_start = k * BLOCK_SIZE_K 

1068 offs_ks = k_start // group_k 

1069 a_scale = tl.load( 

1070 a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 

1071 ) 

1072 if SWAP_AB: 

1073 b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) 

1074 accumulator_nm += ( 

1075 tl.dot(tl.trans(b), tl.trans(a)) 

1076 * b_scale[:, None] 

1077 * a_scale[None, :] 

1078 ) 

1079 else: 

1080 b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) 

1081 accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] 

1082 else: 

1083 if use_fp8_w8a8: 

1084 accumulator = tl.dot(a, b, acc=accumulator) 

1085 else: 

1086 accumulator += tl.dot(a, b) 

1087 else: 

1088 accumulator += tl.dot(a, b) 

1089 a_ptrs += BLOCK_SIZE_K * stride_ak 

1090 b_ptrs += BLOCK_SIZE_K * stride_bk 

1091 

1092 if SWAP_AB: 

1093 accumulator = tl.trans(accumulator_nm) 

1094 

1095 # Dequantization 

1096 if use_int8_w8a16: 

1097 accumulator = accumulator * b_scale 

1098 elif (use_fp8_w8a8 or use_int8_w8a8) and not (group_k > 0 and group_n > 0): 

1099 accumulator = accumulator * a_scale * b_scale 

1100 

1101 if HAS_BIAS: 

1102 accumulator += bias[None, :] 

1103 

1104 # Router weight multiplication (must be in fp32) 

1105 if MUL_ROUTED_WEIGHT: 

1106 moe_weight = tl.load( 

1107 topk_weights_ptr + offs_token, 

1108 mask=token_mask, 

1109 other=0, 

1110 ) 

1111 accumulator *= moe_weight[:, None] 

1112 

1113 accumulator = accumulator.to(compute_type) 

1114 

1115 # Write back output 

1116 offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 

1117 c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] 

1118 c_mask = token_mask[:, None] & (offs_cn[None, :] < N) 

1119 tl.store(c_ptrs, accumulator, mask=c_mask) 

1120 

1121 

def invoke_fused_moe_wna16_triton_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    B_scale: torch.Tensor | None,
    B_zp: torch.Tensor | None,
    topk_weights: torch.Tensor | None,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    block_shape: list[int] | None,
):
    """Launch ``fused_moe_kernel_gptq_awq`` for weight-only quantized experts.

    Args:
        A: Activations; ``A.size(0)`` is the token count and ``A.size(1)`` is
            forwarded as the reduction size.
        B: Expert weights, indexed as (expert, N, K-like) by the strides
            passed below.  NOTE(review): the packed layout of the last
            dimension is not visible here -- confirm against the kernel.
        C: Output buffer; its strides along dims 1 and 2 are forwarded.
        B_scale: Weight scales; must be a 3-D tensor.
        B_zp: Optional 3-D zero points.
        topk_weights: Router weights, forwarded to the kernel unchanged.
        sorted_token_ids: Token ids produced by the block-alignment step;
            must not be ``None`` on this path.
        expert_ids: Per-block expert ids.
        num_tokens_post_padded: Device scalar holding the padded token count.
        mul_routed_weight: Forwarded as ``MUL_ROUTED_WEIGHT``.
        top_k: Number of experts selected per token.
        config: Launch meta-parameters; copied before being modified.
        compute_type: Triton dtype forwarded to the kernel.
        use_int8_w8a16: Forwarded to the kernel.
        use_int4_w4a16: Forwarded to the kernel.
        block_shape: ``[0, group_size]``; the first entry must be 0.
    """
    assert B_scale is not None and B_scale.ndim == 3
    assert B_zp is None or B_zp.ndim == 3
    assert block_shape is not None and block_shape[0] == 0

    M = A.size(0)
    num_tokens = M * top_k

    EM = sorted_token_ids.size(0)
    if M < config["BLOCK_SIZE_M"]:
        # Small-batch optimisation: assuming the expert ids selected for a
        # token are unique, the number of valid experts is bounded by the
        # batch size (< BLOCK_SIZE_M), so trailing invalid blocks are skipped.
        EM = min(sorted_token_ids.size(0), M * top_k * config["BLOCK_SIZE_M"])
    grid = lambda META: (
        triton.cdiv(EM, META["BLOCK_SIZE_M"])
        * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
    )
    # Work on a copy so the caller's config dict is left untouched.
    config = config.copy()
    config.update(
        get_moe_wna16_block_config(
            config=config,
            use_moe_wna16_cuda=False,
            num_valid_tokens=num_tokens,
            size_k=A.size(1),
            size_n=B.size(1),
            # The expert dimension is dim 0 of B (B.stride(0) is the expert
            # stride below); dim 1 is N and is already reported as size_n.
            num_experts=B.size(0),
            group_size=block_shape[1],
            real_top_k=top_k,
            block_size_m=config["BLOCK_SIZE_M"],
        )
    )

    fused_moe_kernel_gptq_awq[grid](
        A,
        B,
        C,
        B_scale,
        B_zp,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.size(1),
        A.size(1),
        EM,
        num_tokens,
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        B_scale.stride(0),
        B_scale.stride(2),
        B_scale.stride(1),
        B_zp.stride(0) if B_zp is not None else 0,
        B_zp.stride(2) if B_zp is not None else 0,
        B_zp.stride(1) if B_zp is not None else 0,
        block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
        group_size=block_shape[1],
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        has_zp=B_zp is not None,
        use_int4_w4a16=use_int4_w4a16,
        use_int8_w8a16=use_int8_w8a16,
        **config,
    )

1210 

1211 

def invoke_fused_moe_triton_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    topk_weights: Optional[torch.Tensor],
    sorted_token_ids: Optional[torch.Tensor],
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    block_shape: Optional[list[int]] = None,
    B_bias: torch.Tensor | None = None,
) -> None:
    """Launch the fused_moe_kernel Triton kernel.

    Args:
        A: Activations; ``A.size(0)`` is the token count.
        B: Expert weights; ``B.size(1)`` is passed as N and ``B.size(2)`` as K.
        C: Output buffer; strides along dims 1 and 2 are forwarded.
        A_scale: Optional activation scales (strides are only forwarded when
            the tensor is 2-D).
        B_scale: Optional weight scales.
        topk_weights: Router weights; required when ``mul_routed_weight``.
        sorted_token_ids: Block-aligned token ids, or ``None`` to select the
            kernel's naive block assignment.
        expert_ids: Per-block expert ids.
        num_tokens_post_padded: Device scalar holding the padded token count.
        mul_routed_weight: Forwarded as ``MUL_ROUTED_WEIGHT``.
        top_k: Number of experts selected per token.
        config: Launch meta-parameters; copied before being modified.
        compute_type: Triton dtype of the value written to ``C``.
        use_fp8_w8a8: Forwarded to the kernel.
        use_int8_w8a8: Forwarded to the kernel.
        use_int8_w8a16: Forwarded to the kernel.
        use_int4_w4a16: Only used for argument validation here.
        per_channel_quant: Forwarded to the kernel.
        block_shape: Optional two-entry quantization block shape; entry 0 is
            passed as ``group_n`` and entry 1 as ``group_k``.
        B_bias: Optional per-expert bias.
    """
    assert topk_weights is not None or not mul_routed_weight
    assert topk_weights is None or topk_weights.stride(1) == 1
    assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1

    if use_fp8_w8a8 or use_int8_w8a8:
        assert B_scale is not None
        assert block_shape is None or triton.cdiv(
            B.size(-2), block_shape[0]
        ) == B_scale.size(-2)
        assert block_shape is None or triton.cdiv(
            B.size(-1), block_shape[1]
        ) == B_scale.size(-1)
    elif use_int8_w8a16 or use_int4_w4a16:
        assert B_scale is not None
        assert block_shape is None or block_shape[0] == 0
    else:
        assert A_scale is None
        assert B_scale is None

    M = A.size(0)
    num_tokens = M * top_k
    if sorted_token_ids is not None:
        EM = sorted_token_ids.size(0)
        if M < config["BLOCK_SIZE_M"]:
            # Small batch: cap the launched rows so trailing blocks that
            # cannot contain valid tokens are not scheduled.
            EM = min(sorted_token_ids.size(0), M * top_k * config["BLOCK_SIZE_M"])
    else:
        # Naive block assignment: one BLOCK_SIZE_M block per (token, expert).
        EM = num_tokens * config["BLOCK_SIZE_M"]
    grid = lambda META: (
        triton.cdiv(EM, META["BLOCK_SIZE_M"])
        * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
    )
    HAS_BIAS = B_bias is not None

    # Work on a copy so the caller's config dict is left untouched.
    config = config.copy()
    config["SPLIT_K"] = 1
    BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
    # The kernel reads one scale per K tile in block-wise mode, so the tile
    # must not exceed the quantization block.  Only clamp when both block
    # dimensions are positive (the kernel's own block-wise test); a zero
    # entry, which the int8/int4 branch above explicitly allows, would
    # otherwise force BLOCK_SIZE_K to 0.
    if block_shape is not None and block_shape[0] > 0 and block_shape[1] > 0:
        BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))

    swap_AB = config.pop("SWAP_AB", False)
    fused_moe_kernel[grid](
        A,
        B,
        C,
        B_bias,
        A_scale,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.size(1),  # N
        B.size(2),  # K
        EM,
        num_tokens,
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
        A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
        B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
        B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
        B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
        B_bias.stride(0) if B_bias is not None else 0,
        B_bias.stride(1) if B_bias is not None else 0,
        0 if block_shape is None else block_shape[0],
        0 if block_shape is None else block_shape[1],
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        per_channel_quant=per_channel_quant,
        naive_block_assignment=(sorted_token_ids is None),
        HAS_BIAS=HAS_BIAS,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        SWAP_AB=swap_AB,
        **config,
    )

1321 

1322 

def dispatch_fused_moe_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    B_zp: Optional[torch.Tensor],
    topk_weights: Optional[torch.Tensor],
    sorted_token_ids: Optional[torch.Tensor],
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a8: bool,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    per_channel_quant: bool,
    block_shape: Optional[list[int]] = None,
    B_bias: Optional[torch.Tensor] = None,
) -> None:
    """Dispatch to the appropriate fused MoE kernel based on quantization flags.

    Weight-only INT8/INT4 with a positive group size (``block_shape[1] > 0``)
    goes to ``invoke_fused_moe_wna16_triton_kernel``; that path does not
    accept a bias and does not receive ``A_scale``.  Every other combination
    goes to ``invoke_fused_moe_triton_kernel``, which does not receive
    ``B_zp``.
    """
    assert topk_weights is not None or not mul_routed_weight
    assert topk_weights is None or topk_weights.stride(1) == 1
    assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1

    # TODO: Other precision-specific implementations
    # (use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16).
    group_wise_weight_only = (use_int8_w8a16 or use_int4_w4a16) and (
        block_shape is not None and block_shape[1] > 0
    )
    if group_wise_weight_only:
        assert B_bias is None
        invoke_fused_moe_wna16_triton_kernel(
            A,
            B,
            C,
            B_scale,
            B_zp,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            mul_routed_weight,
            top_k,
            config,
            compute_type,
            use_int8_w8a16,
            use_int4_w4a16,
            block_shape,
        )
    else:
        invoke_fused_moe_triton_kernel(
            A,
            B,
            C,
            A_scale,
            B_scale,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            mul_routed_weight,
            top_k,
            config,
            compute_type,
            use_fp8_w8a8,
            use_int8_w8a8,
            use_int8_w8a16,
            use_int4_w4a16,
            per_channel_quant,
            block_shape,
            B_bias,
        )

1406 

1407 

def fused_experts_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    ocp_mx_scheme: str | None = None,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: torch.Tensor | None = None,
    w2_zp: torch.Tensor | None = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the fused MoE expert computation over ``hidden_states``.

    Tokens are processed in chunks of at most ``CHUNK_SIZE`` rows.  For each
    chunk the function quantizes the activations, runs the first expert GEMM
    (``w1``), applies the activation, quantizes again, runs the second expert
    GEMM (``w2``) and reduces the per-expert results with ``moe_sum``.

    Args:
        hidden_states: Contiguous 2-D activations (tokens, hidden) in
            float32, float16 or bfloat16.
        w1: 3-D first-layer expert weights; last-dim stride must be 1.
        w2: 3-D second-layer expert weights; last-dim stride must be 1.
        topk_weights: Router weights, same shape as ``topk_ids``.
        topk_ids: 2-D selected expert ids per token.
        inplace: When true the result is written into ``hidden_states``.
        activation: Only ``"silu"`` is accepted.
        apply_router_weight_on_input: Multiply router weights in the first
            GEMM instead of the second.
        use_fp8_w8a8: FP8 weight/activation quantization flag.
        use_int8_w8a8: INT8 weight/activation quantization flag.
        use_int8_w8a16: INT8 weight-only flag; weights are dequantized up
            front.
        use_int4_w4a16: INT4 weight-only flag; weights are dequantized up
            front.
        ocp_mx_scheme: Optional OCP MX scheme name (``w_mxfp4*``,
            ``w_mxfp6_e3m2*`` or ``w_mxfp6_e2m3*``); weights are dequantized
            up front.
        per_channel_quant: Per-channel / per-token quantization flag.
        global_num_experts: Total expert count; ``-1`` means ``w1.size(0)``.
        expert_map: Optional expert mapping forwarded to the alignment step.
        w1_scale: Scales for ``w1``.
        w2_scale: Scales for ``w2``.
        w1_zp: Zero points for ``w1``.
        w2_zp: Zero points for ``w2``.
        a1_scale: Activation scale for the first GEMM input.
        a2_scale: Activation scale for the second GEMM input.
        block_shape: Optional quantization block shape.
        w1_bias: Optional bias for the first GEMM.
        w2_bias: Optional bias for the second GEMM.

    Returns:
        ``hidden_states`` itself when ``inplace`` is true, otherwise a newly
        allocated tensor of the same shape and dtype.
    """
    logger.debug("GEMS FUSED MOE")
    assert (
        activation == "silu"
    ), f"Only 'silu' activation is supported, got {activation}"

    activation_enum = MoEActivation.from_str(activation)

    # Check constraints
    if ocp_mx_scheme is not None and not use_int4_w4a16:
        # Packed MX weights: the stored last dimension is smaller than the
        # hidden size by the packing factor.
        if ocp_mx_scheme.startswith("w_mxfp4"):
            assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch"
        elif ocp_mx_scheme.startswith("w_mxfp6"):
            assert (
                hidden_states.size(1) == (w1.size(2) * 4) // 3
            ), "hidden size mismatch"
        else:
            raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
    else:
        # Unpacked weights (including INT4 stored unpacked in INT8
        # containers) carry the full K dimension.
        assert hidden_states.size(1) == w1.size(
            2
        ), f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"

    assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

    num_tokens = hidden_states.size(0)
    E, N, _ = w1.size()
    K = w2.size(1)
    if global_num_experts == -1:
        global_num_experts = E
    top_k_num = topk_ids.size(1)

    CHUNK_SIZE: int = 16 * 1024
    M = min(num_tokens, CHUNK_SIZE)

    config_dtype = _get_config_dtype_str(
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        ocp_mx_scheme=ocp_mx_scheme,
        dtype=hidden_states.dtype,
    )

    quant_dtype = _get_config_quant_dtype(
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        ocp_mx_scheme=ocp_mx_scheme,
    )

    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.size(),
        w2.size(),
        top_k_num,
        config_dtype,
        block_shape=block_shape,
        E=E,
    )

    config = get_config_func(M)

    # cache1 and cache3 share memory (non-overlapping lifetime)
    cache13 = torch.empty(
        M * top_k_num * max(N, K),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache1 = cache13[: M * top_k_num * N].view(M, top_k_num, N)
    intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)

    # cache2 needs separate memory (concurrent with cache1)
    activation_out_dim = MoEActivation.adjust_N_for_activation(N, activation_enum)
    intermediate_cache2 = torch.empty(
        (M * top_k_num, activation_out_dim),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    if hidden_states.dtype == torch.bfloat16:
        compute_type = tl.bfloat16
    elif hidden_states.dtype == torch.float16:
        compute_type = tl.float16
    elif hidden_states.dtype == torch.float32:
        compute_type = tl.float32
    else:
        raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

    out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)

    if ocp_mx_scheme is not None:
        # Dequantize OCP MX weights (TODO: skip on platforms with native MX)
        if ocp_mx_scheme.startswith("w_mxfp4"):
            w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
            w1_scale = None
            w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
            w2_scale = None
        elif ocp_mx_scheme.startswith("w_mxfp6_e3m2"):
            w1 = dequant_mxfp6(
                w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
            )
            w1_scale = None
            w2 = dequant_mxfp6(
                w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
            )
            w2_scale = None
        elif ocp_mx_scheme.startswith("w_mxfp6_e2m3"):
            w1 = dequant_mxfp6(
                w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
            )
            w1_scale = None
            w2 = dequant_mxfp6(
                w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
            )
            w2_scale = None
        else:
            raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")

    # Dequant INT8/INT4 weights (Triton can't do mixed-dtype dot)
    if use_int8_w8a16 or use_int4_w4a16:
        # Fail with a clear message instead of an AttributeError on None.
        assert (
            w1_scale is not None and w2_scale is not None
        ), "w1_scale and w2_scale are required for INT8/INT4 weight-only MoE"
        w1 = w1.to(hidden_states.dtype) * w1_scale.unsqueeze(-1).to(hidden_states.dtype)
        w1_scale = None
        w2 = w2.to(hidden_states.dtype) * w2_scale.unsqueeze(-1).to(hidden_states.dtype)
        w2_scale = None
        # From here on the weights are plain floating point.
        use_int8_w8a16 = False
        use_int4_w4a16 = False

    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
        begin_chunk_idx, end_chunk_idx = (
            chunk * CHUNK_SIZE,
            min((chunk + 1) * CHUNK_SIZE, num_tokens),
        )
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.size()

        # The range above yields one extra iteration when num_tokens is an
        # exact multiple of CHUNK_SIZE (or zero).
        if tokens_in_chunk == 0:
            break

        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
            # Adjust cache size for last chunk
            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
            intermediate_cache2 = intermediate_cache2[: tokens_in_chunk * top_k_num]
            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
            config = get_config_func(tokens_in_chunk)

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
        qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
            A=curr_hidden_states,
            A_scale=a1_scale,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_channel_quant,
            block_shape=block_shape,
            ocp_mx_scheme=ocp_mx_scheme,
        )

        # With very few routed (token, expert) pairs relative to the number
        # of experts, skip the sort/align step and give each pair its own
        # block.
        SPARSITY_FACTOR = 4
        naive_block_assignment = (
            expert_map is None
            and tokens_in_chunk * top_k_num * SPARSITY_FACTOR <= global_num_experts
            and not (
                (use_int8_w8a16 or use_int4_w4a16)
                and block_shape is not None
                and block_shape[1] > 0
            )
        )

        if not naive_block_assignment:
            sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
                curr_topk_ids,
                config["BLOCK_SIZE_M"],
                global_num_experts,
                expert_map,
            )
        else:
            max_num_tokens_padded = topk_ids.numel() * config["BLOCK_SIZE_M"]
            expert_ids = curr_topk_ids.view(-1)
            num_tokens_post_padded = torch.empty(
                1, dtype=torch.int32, device=topk_ids.device
            )
            num_tokens_post_padded.fill_(max_num_tokens_padded)
            sorted_token_ids = None

        # First GEMM: activations x w1 -> intermediate_cache1.
        dispatch_fused_moe_kernel(
            qcurr_hidden_states,
            w1,
            intermediate_cache1,
            a1q_scale,
            w1_scale,
            w1_zp,
            curr_topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            apply_router_weight_on_input,
            top_k_num,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            block_shape=block_shape,
            B_bias=w1_bias,
        )

        apply_moe_activation(
            activation_enum, intermediate_cache2, intermediate_cache1.view(-1, N)
        )

        qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
            A=intermediate_cache2,
            A_scale=a2_scale,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_channel_quant,
            block_shape=block_shape,
            ocp_mx_scheme=ocp_mx_scheme,
        )

        if expert_map is not None:
            # cache3 aliases cache1's storage; clear it before the second
            # GEMM when an expert map is in use.
            intermediate_cache3.zero_()

        # Second GEMM: activated values x w2 -> intermediate_cache3.  Each
        # row of cache2 is already one (token, expert) pair, hence top_k=1.
        dispatch_fused_moe_kernel(
            qintermediate_cache2,
            w2,
            intermediate_cache3,
            a2q_scale,
            w2_scale,
            w2_zp,
            curr_topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            not apply_router_weight_on_input,
            1,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            block_shape=block_shape,
            B_bias=w2_bias,
        )

        # Reduce over the top-k dimension into this chunk's output rows.
        moe_sum(
            intermediate_cache3,
            out_hidden_states[begin_chunk_idx:end_chunk_idx],
        )

    return out_hidden_states

1698 

1699 

def inplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> None:
    """
    Run the fused MoE computation and overwrite ``hidden_states`` with the result.

    Thin wrapper around ``fused_experts_impl`` with ``inplace=True``; every
    other argument is forwarded unchanged.  Nothing is returned.
    """
    # Options forwarded verbatim to the shared implementation.
    forwarded_options = dict(
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        per_channel_quant=per_channel_quant,
        global_num_experts=global_num_experts,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
    )
    fused_experts_impl(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=True,
        **forwarded_options,
    )

1751 

1752 

def outplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Run the fused MoE computation and return the result in a fresh tensor.

    Thin wrapper around ``fused_experts_impl`` with ``inplace=False``; every
    other argument is forwarded unchanged.
    """
    # Options forwarded verbatim to the shared implementation.
    forwarded_options = dict(
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        per_channel_quant=per_channel_quant,
        global_num_experts=global_num_experts,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
    )
    result = fused_experts_impl(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=False,
        **forwarded_options,
    )
    return result