Coverage for src/flag_gems/runtime/backend/_ascend/fused/fused_moe.py: 0%

687 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1# SPDX-License-Identifier: Apache-2.0 

2# SPDX-FileCopyrightText: Copyright contributors to the vLLM project 

3# 

4# Adapted from the vLLM project (https://github.com/vllm-project/vllm). 

5# Source files under vllm/model_executor/layers/: 

6# fused_moe/fused_moe.py – Triton kernels, dispatch, fused_experts_impl 

7# fused_moe/activation.py – MoEActivation enum, apply_moe_activation 

8# fused_moe/utils.py – _fp8_quantize, _int8_quantize, moe_kernel_quantize_input 

9# fused_moe/config.py – _get_config_dtype_str 

10# quantization/utils/mxfp4_utils.py – dequant_mxfp4 

11# quantization/utils/mxfp6_utils.py – dequant_mxfp6 

12# quantization/utils/ocp_mx_utils.py – OCP_MX_BLOCK_SIZE 

13 

14 

15import functools 

16import logging 

17import os 

18from enum import Enum 

19from typing import Any, Optional 

20 

21import torch 

22import torch.nn.functional as F 

23import triton 

24import triton.language as tl 

25import yaml 

26 

27# Using relative imports will cause the module to be not found. 

28from flag_gems.runtime.backend._ascend.fused.moe_align_block_size import ( 

29 moe_align_block_size, 

30) 

31from flag_gems.runtime.backend._ascend.fused.moe_sum import moe_sum 

32from flag_gems.utils import pointwise_dynamic 

33 

34logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

35 

36# OCP MX quantization helpers (requires amd-quark) 

37 

38OCP_MX_BLOCK_SIZE = 32 

39 

40 

41@functools.lru_cache(maxsize=1) 

42def get_embedded_moe_configs(): 

43 config_path = os.path.join( 

44 os.path.dirname(__file__), "..", "utils", "configs", "fused_moe_config.yaml" 

45 ) 

46 if not os.path.exists(config_path): 

47 return {}, {} 

48 with open(config_path, "r") as f: 

49 # JSON keys are strings, values are dicts where keys are M and values are configs 

50 data = yaml.safe_load(f) 

51 

52 fallback = data.get("_FALLBACK", {}) 

53 

54 # We need to convert the innermost keys (which are stringified integers for M) back to integers. 

55 # Ensure we map the lists back to config dicts. 

56 keys_order = [ 

57 "BLOCK_SIZE_M", 

58 "BLOCK_SIZE_N", 

59 "BLOCK_SIZE_K", 

60 "GROUP_SIZE_M", 

61 "num_warps", 

62 "num_stages", 

63 ] 

64 parsed_data = {} 

65 for dev, configs in data.items(): 

66 if dev == "_FALLBACK": 

67 continue 

68 parsed_data[dev] = {} 

69 for k, m_dict in configs.items(): 

70 parsed_dict = {} 

71 for m, v in m_dict.items(): 

72 if isinstance(v, list): 

73 parsed_dict[int(m)] = dict(zip(keys_order, v)) 

74 else: 

75 parsed_dict[int(m)] = v 

76 parsed_data[dev][k] = parsed_dict 

77 

78 return parsed_data, fallback 

79 

80 

81def dequant_mxfp4( 

82 x: torch.Tensor, 

83 scale: torch.Tensor, 

84 float_dtype: torch.dtype, 

85) -> torch.Tensor: 

86 """Dequantize MXFP4 tensor via quark.torch.kernel.mx.dq_mxfp4.""" 

87 try: 

88 from quark.torch.kernel import mx 

89 except ImportError as err: 

90 raise ImportError("amd-quark is required for MX-FP4") from err 

91 

92 return mx.dq_mxfp4(x, scale, float_dtype) 

93 

94 

95def dequant_mxfp6( 

96 x: torch.Tensor, 

97 scale: torch.Tensor, 

98 float_dtype: torch.dtype, 

99 quant_dtype: str, 

100) -> torch.Tensor: 

101 """Dequantize MXFP6 tensor via quark hw_emulation.""" 

102 try: 

103 from quark.torch.kernel.hw_emulation.hw_emulation_interface import ( 

104 dequantize_fp4_fp6_per_group, 

105 ) 

106 from quark.torch.utils.pack import create_pack_method 

107 except ImportError as err: 

108 raise ImportError("amd-quark is required for MX-FP6") from err 

109 

110 pack_method = create_pack_method(None, dtype=quant_dtype) 

111 unpacked_x = pack_method.unpack(x, reorder=False) 

112 

113 scale = 2 ** (scale.view(torch.uint8).to(torch.int16) - 127).to(float_dtype) 

114 

115 return dequantize_fp4_fp6_per_group( 

116 unpacked_x, 

117 scale, 

118 axis=-1, 

119 group_size=OCP_MX_BLOCK_SIZE, 

120 quant_dtype=quant_dtype, 

121 ).to(float_dtype) 

122 

123 

124# Activation quantization helpers 

125 

126 

127@functools.lru_cache(maxsize=1) 

128def _get_device_name() -> str: 

129 """Return the normalised CUDA device name (spaces replaced by underscores). 

130 

131 Matches the naming convention used by vLLM for its per-device config files. 

132 H800 falls back to H100_80GB_HBM3 (same SM 9.0 architecture). 

133 """ 

134 name = torch.npu.get_device_name().replace(" ", "_") 

135 # Normalise the H200 product family to a single key, following vLLM. 

136 if "H200" in name.split("_"): 

137 name = "NVIDIA_H200" 

138 # H800 has the same SM 9.0 as H100; use H100 configs as fallback. 

139 embedded_configs, fallback_mapping = get_embedded_moe_configs() 

140 if name in embedded_configs: 

141 return name 

142 # Fallback mapping for devices whose tuning profiles are equivalent. 

143 fallback = fallback_mapping.get(name) 

144 if fallback and fallback in embedded_configs: 

145 logger.info("Device %s not in config table, falling back to %s", name, fallback) 

146 return fallback 

147 return name 

148 

149 

150def get_moe_configs( 

151 E: int, 

152 N: int, 

153 dtype: str | None, 

154 block_n: int | None = None, 

155 block_k: int | None = None, 

156) -> dict[int, Any] | None: 

157 """ 

158 Return optimized configurations for the fused MoE kernel. 

159 

160 Looks up pre-tuned configs from the embedded table (ported from vLLM) 

161 for the current GPU device. Returns None if no matching config is found. 

162 """ 

163 device_name = _get_device_name() 

164 embedded_configs, _ = get_embedded_moe_configs() 

165 device_table = embedded_configs.get(device_name) 

166 if device_table is None: 

167 logger.warning( 

168 "No embedded MoE configs for device %s. Will use default config.", 

169 device_name, 

170 ) 

171 return None 

172 

173 _block_n = block_n if block_n else 0 

174 _block_k = block_k if block_k else 0 

175 key = f"{E},{N},{dtype},{_block_n},{_block_k}" 

176 configs = device_table.get(key) 

177 if configs is not None: 

178 logger.info("Using embedded MoE config for device=%s, key=%s", device_name, key) 

179 return configs 

180 logger.warning( 

181 "No embedded MoE config for device=%s, key=%s. Will use default config.", 

182 device_name, 

183 key, 

184 ) 

185 return None 

186 

187 

188def try_get_optimal_moe_config( 

189 w1_shape: tuple[int, ...], 

190 w2_shape: tuple[int, ...], 

191 top_k: int, 

192 dtype: str | None, 

193 M: int, 

194 block_shape: list[int] | None = None, 

195) -> dict[str, int]: 

196 override_config: Optional[dict[str, Any]] = None 

197 if override_config: 

198 config = override_config 

199 else: 

200 # First try to load optimal config from the file 

201 E, _, N = w2_shape 

202 if dtype == "int4_w4a16": 

203 N = N * 2 

204 block_n = block_shape[0] if block_shape else 0 

205 block_k = block_shape[1] if block_shape else 0 

206 configs = get_moe_configs(E, N, dtype, block_n, block_k) 

207 

208 if configs: 

209 config = configs[min(configs.keys(), key=lambda x: abs(x - M))] 

210 else: 

211 config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, block_shape) 

212 return config 

213 

214 

215def _get_config_quant_dtype( 

216 use_fp8_w8a8: bool, 

217 use_int8_w8a8: bool, 

218 ocp_mx_scheme: str | None, 

219) -> None | torch.dtype | str: 

220 """Map quantization flags to the corresponding dtype.""" 

221 if use_fp8_w8a8: 

222 return torch.float8_e4m3fn 

223 elif use_int8_w8a8: 

224 return torch.int8 

225 elif ocp_mx_scheme == "w_mxfp4_a_mxfp4": 

226 return "mxfp4" 

227 elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e3m2", "w_mxfp6_e3m2_a_mxfp6_e3m2"}: 

228 return "mxfp6_e3m2" 

229 elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e2m3", "w_mxfp6_e2m3_a_mxfp6_e2m3"}: 

230 return "mxfp6_e2m3" 

231 elif ocp_mx_scheme in {"w_mxfp4", "w_mxfp6_e3m2", "w_mxfp6_e2m3"}: 

232 return torch.bfloat16 

233 elif ocp_mx_scheme in {"w_mxfp4_a_fp8", "w_mxfp6_e3m2_a_fp8", "w_mxfp6_e2m3_a_fp8"}: 

234 return torch.float8_e4m3fn 

235 

236 return None 

237 

238 

239def get_moe_wna16_block_config( 

240 config: dict[str, int], 

241 use_moe_wna16_cuda: bool, 

242 num_valid_tokens: int, 

243 size_k: int, 

244 size_n: int, 

245 num_experts: int, 

246 group_size: int, 

247 real_top_k: int, 

248 block_size_m: int, 

249): 

250 if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config: 

251 return {} 

252 if not use_moe_wna16_cuda: 

253 if num_valid_tokens // real_top_k == 1: 

254 return {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64} 

255 else: 

256 return {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32} 

257 else: 

258 block_size_n = 128 

259 block_size_k = 128 

260 if block_size_k <= group_size: 

261 block_size_k = group_size 

262 

263 num_n_blocks = size_k // block_size_k 

264 num_k_blocks = size_n // block_size_k 

265 num_m_blocks = ( 

266 num_valid_tokens + block_size_m - 1 

267 ) / block_size_m + num_experts 

268 if num_valid_tokens // real_top_k <= block_size_m: 

269 num_m_blocks = min(num_m_blocks, num_valid_tokens) 

270 num_blocks = num_m_blocks * num_n_blocks * num_k_blocks 

271 

272 if size_k % 256 == 0 and num_blocks >= 256 and block_size_k < 256: 

273 block_size_k = 256 

274 num_blocks = num_blocks // (256 // block_size_k) 

275 

276 if ( 

277 num_m_blocks <= 16 

278 and size_k % (block_size_k * 2) == 0 

279 and size_k % (block_size_k * 2) == 0 

280 and block_size_k <= 512 

281 and num_blocks >= 512 

282 ): 

283 block_size_k = block_size_k * 2 

284 num_blocks = num_blocks // 2 

285 

286 if num_blocks > 1024: 

287 block_size_n = 256 

288 num_n_blocks = num_n_blocks // 2 

289 num_blocks = num_blocks // 2 

290 

291 if size_n <= 1024 and num_blocks >= 1024: 

292 block_size_n = 1024 

293 

294 block_size_k = _ensure_block_size_k_divisible(size_k, block_size_k, group_size) 

295 

296 return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k} 

297 

298 

299def get_default_config( 

300 M: int, 

301 E: int, 

302 N: int, 

303 K: int, 

304 topk: int, 

305 dtype: str | None, 

306 block_shape: list[int] | None = None, 

307) -> dict[str, int]: 

308 """Default Triton config for fused MoE kernel. 

309 

310 Heuristic selection aligned with vLLM v0.17.0 defaults, tuned on H20/H100. 

311 Key insight: for high-expert-count MoE (e.g. DeepSeek-V3 E=256), each 

312 expert sees very few tokens, so small BLOCK_SIZE_M (16) is critical. 

313 """ 

314 if dtype == "fp8_w8a8" and block_shape is not None: 

315 config = { 

316 "BLOCK_SIZE_M": 16 if M <= 64 else 64, 

317 "BLOCK_SIZE_N": block_shape[0], 

318 "BLOCK_SIZE_K": block_shape[1], 

319 "GROUP_SIZE_M": 1 if M <= 16 else 32, 

320 "num_warps": 4, 

321 "num_stages": 3, 

322 } 

323 else: 

324 # tokens_per_expert drives block_m: use M//E (not M*topk//E) to 

325 # estimate the actual per-expert token count after routing. 

326 tokens_per_expert = M // max(E, 1) 

327 

328 if tokens_per_expert <= 2: 

329 block_m = 16 

330 elif tokens_per_expert <= 4: 

331 block_m = 32 

332 elif tokens_per_expert <= 16: 

333 block_m = 64 

334 else: 

335 block_m = 128 

336 

337 # Tile sizing 

338 if N >= 4096: 

339 block_n = 128 if M <= 128 else 256 

340 elif N >= 1024: 

341 block_n = 64 if M <= 64 else 128 

342 else: 

343 block_n = 64 if M <= 64 else 128 

344 

345 if dtype == "fp8_w8a8": 

346 block_k = 128 

347 else: 

348 # Cap BLOCK_SIZE_K at 32: BK=64 with BM≥64 triggers 

349 # triton-ascend compiler errors ('vsel' unsupported) on 

350 # large shapes (e.g. Mixtral N=28672). 

351 block_k = 32 

352 

353 if tokens_per_expert > 128: 

354 group_m = 16 

355 elif tokens_per_expert > 32: 

356 group_m = 8 

357 else: 

358 group_m = 1 

359 

360 # Adaptive stages: optimize for different M sizes 

361 # Small M: use more stages for better data reuse 

362 # Large M: use fewer stages to reduce memory overhead 

363 if M <= 64: 

364 num_stages = 2 

365 num_warps = 2 

366 elif M <= 256: 

367 num_stages = 2 

368 num_warps = 4 

369 else: 

370 num_stages = 2 

371 num_warps = 4 

372 

373 # UB budget check for Ascend NPU (192KB = 196608 bytes) 

374 # Account for: A tile (BM*BK*2) + B tile (BK*BN*2) + Accumulator (BM*BN*4) 

375 # Use tighter safety factor for small M 

376 UB_LIMIT = 196608 

377 SAFETY_FACTOR = 0.65 if M <= 64 else 0.70 

378 

379 ub_per_stage = ( 

380 block_m * block_k * 2 # A tile (bf16) 

381 + block_k * block_n * 2 # B tile (bf16) 

382 + block_m * block_n * 4 # Accumulator (fp32) 

383 ) 

384 

385 # Reduce stages if needed 

386 while num_stages > 1 and ub_per_stage * num_stages > UB_LIMIT * SAFETY_FACTOR: 

387 num_stages -= 1 

388 

389 # Reduce block sizes if still over budget 

390 while num_stages >= 1 and ub_per_stage * num_stages > UB_LIMIT * SAFETY_FACTOR: 

391 # Reduce in order: BN (biggest impact), BM, BK 

392 if block_n > 64: 

393 block_n = max(64, block_n // 2) 

394 elif block_m > 16: 

395 block_m = max(16, block_m // 2) 

396 elif block_k > 16: 

397 block_k = max(16, block_k // 2) 

398 else: 

399 break 

400 ub_per_stage = ( 

401 block_m * block_k * 2 + block_k * block_n * 2 + block_m * block_n * 4 

402 ) 

403 

404 config = { 

405 "BLOCK_SIZE_M": block_m, 

406 "BLOCK_SIZE_N": block_n, 

407 "BLOCK_SIZE_K": block_k, 

408 "GROUP_SIZE_M": group_m, 

409 "num_warps": num_warps, 

410 "num_stages": num_stages, 

411 } 

412 return config 

413 

414 

415def _get_config_dtype_str( 

416 dtype: Optional[torch.dtype] = None, 

417 use_fp8_w8a8: bool = False, 

418 use_fp8_w8a16: bool = False, 

419 use_int8_w8a16: bool = False, 

420 use_int4_w4a16: bool = False, 

421 ocp_mx_scheme: str | None = None, 

422) -> str | None: 

423 """Return dtype string for kernel config lookup.""" 

424 if use_fp8_w8a8: 

425 return "fp8_w8a8" 

426 elif use_fp8_w8a16: 

427 return "fp8_w8a16" 

428 elif use_int8_w8a16: 

429 return "int8_w8a16" 

430 elif use_int4_w4a16: 

431 return "int4_w4a16" 

432 elif ocp_mx_scheme is not None: 

433 return None 

434 elif dtype == torch.float: 

435 return "float32" 

436 return None 

437 

438 

439# MoE activation enum 

440 

441 

442class MoEActivation(Enum): 

443 """Activation functions for MoE layers.""" 

444 

445 # Gated: gate * activation(up), input [..., 2*d] -> output [..., d] 

446 SILU = "silu" 

447 GELU = "gelu" 

448 RELU2 = "relu2" 

449 SWIGLUOAI = "swigluoai" 

450 SWIGLUSTEP = "swiglustep" 

451 

452 # Non-gated: input [..., d] -> output [..., d] 

453 SILU_NO_MUL = "silu_no_mul" 

454 GELU_NO_MUL = "gelu_no_mul" 

455 RELU2_NO_MUL = "relu2_no_mul" 

456 

457 @property 

458 def is_gated(self) -> bool: 

459 return not self.value.endswith("_no_mul") 

460 

461 def without_mul(self) -> "MoEActivation": 

462 """Return the non-gated variant.""" 

463 _without_mul: dict[MoEActivation, MoEActivation] = { 

464 MoEActivation.SILU: MoEActivation.SILU_NO_MUL, 

465 MoEActivation.GELU: MoEActivation.GELU_NO_MUL, 

466 MoEActivation.RELU2: MoEActivation.RELU2_NO_MUL, 

467 } 

468 return _without_mul.get(self, self) 

469 

470 @classmethod 

471 def from_str(cls, s: str) -> "MoEActivation": 

472 for member in cls: 

473 if member.value == s: 

474 return member 

475 valid = [m.value for m in cls] 

476 raise ValueError(f"Unknown MoE activation: {s!r}. Valid activations: {valid}") 

477 

478 @staticmethod 

479 def adjust_N_for_activation(N: int, activation: "MoEActivation") -> int: 

480 """Return N for non-gated, N // 2 for gated activations.""" 

481 return N if not activation.is_gated else N // 2 

482 

483 

484def apply_moe_activation( 

485 activation: MoEActivation, 

486 output: torch.Tensor, 

487 input: torch.Tensor, 

488) -> torch.Tensor: 

489 """Apply MoE activation (pure PyTorch / FlagGems Triton).""" 

490 assert input.dim() == 2, "Input must be 2D" 

491 assert output.dim() == 2, "Output must be 2D" 

492 if activation.is_gated: 

493 assert output.size(-1) * 2 == input.size( 

494 -1 

495 ), f"{activation.value} expects 2x ratio: {output.size(-1) * 2} vs {input.size(-1)}" 

496 else: 

497 assert output.size(-1) == input.size( 

498 -1 

499 ), f"{activation.value} expects equal sizes: {output.size(-1)} vs {input.size(-1)}" 

500 

501 if activation in (MoEActivation.SILU, MoEActivation.SWIGLUOAI): 

502 N = output.size(-1) 

503 x, y = input[:, :N], input[:, N:] 

504 _silu_and_mul_kernel(x, y, out0=output) 

505 elif activation == MoEActivation.GELU: 

506 N = output.size(-1) 

507 gate, up = input[:, :N], input[:, N:] 

508 output.copy_(F.gelu(gate) * up) 

509 elif activation == MoEActivation.SWIGLUSTEP: 

510 N = output.size(-1) 

511 gate, up = input[:, :N], input[:, N:] 

512 output.copy_(torch.sigmoid(gate) * up) 

513 elif activation == MoEActivation.RELU2: 

514 N = output.size(-1) 

515 gate, up = input[:, :N], input[:, N:] 

516 output.copy_(F.relu(gate).square() * up) 

517 

518 elif activation == MoEActivation.SILU_NO_MUL: 

519 output.copy_(F.silu(input)) 

520 elif activation == MoEActivation.GELU_NO_MUL: 

521 output.copy_(F.gelu(input)) 

522 elif activation == MoEActivation.RELU2_NO_MUL: 

523 F.relu(input, inplace=True) 

524 torch.square(input, out=output) 

525 else: 

526 raise ValueError(f"Unsupported FusedMoe activation: {activation}") 

527 

528 return output 

529 

530 

531def _fp8_quantize( 

532 A: torch.Tensor, 

533 A_scale: Optional[torch.Tensor], 

534 per_act_token: bool, 

535 block_shape: Optional[list[int]] = None, 

536) -> tuple[torch.Tensor, torch.Tensor]: 

537 """FP8 E4M3 quantization: per-tensor, per-token, or block-wise.""" 

538 fp8_dtype = torch.float8_e4m3fn 

539 finfo = torch.finfo(fp8_dtype) 

540 fp8_max = finfo.max 

541 fp8_min = finfo.min 

542 eps = 1e-10 

543 

544 if block_shape is not None: 

545 assert not per_act_token 

546 assert len(block_shape) == 2 

547 block_k = block_shape[1] 

548 assert A.size(-1) % block_k == 0 

549 orig_shape = A.shape 

550 A_flat = A.reshape(-1, A.size(-1)) 

551 M, K = A_flat.shape 

552 A_groups = A_flat.reshape(M * (K // block_k), block_k) 

553 amax = ( 

554 A_groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32) 

555 ) 

556 scale = amax / fp8_max 

557 A_q = (A_groups.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype) 

558 A_q = A_q.reshape(orig_shape) 

559 scale = scale.reshape(M, K // block_k) 

560 return A_q, scale 

561 

562 elif per_act_token: 

563 A_flat = A.reshape(-1, A.size(-1)) 

564 amax = A_flat.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32) 

565 scale = amax / fp8_max 

566 min_scale = torch.tensor( 

567 1.0 / (fp8_max * 512.0), dtype=torch.float32, device=A.device 

568 ) 

569 scale = scale.clamp(min=min_scale) 

570 A_q = (A_flat.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype) 

571 A_q = A_q.reshape(A.shape) 

572 scale = scale.reshape(A.shape[:-1] + (1,)) 

573 return A_q, scale 

574 

575 else: 

576 if A_scale is not None: 

577 scale = ( 

578 A_scale.float().view(1, 1) if A_scale.numel() == 1 else A_scale.float() 

579 ) 

580 A_q = (A.float() / scale).clamp(fp8_min, fp8_max).to(fp8_dtype) 

581 return A_q, A_scale 

582 else: 

583 amax = A.abs().amax().clamp(min=eps).to(torch.float32) 

584 scale = amax / fp8_max 

585 iscale = 1.0 / scale 

586 A_q = (A.float() * iscale).clamp(fp8_min, fp8_max).to(fp8_dtype) 

587 return A_q, scale.view(1) 

588 

589 

590def _int8_quantize( 

591 A: torch.Tensor, 

592 A_scale: Optional[torch.Tensor], 

593 per_act_token: bool, 

594 block_shape: Optional[list[int]] = None, 

595) -> tuple[torch.Tensor, torch.Tensor]: 

596 """INT8 quantization: per-tensor, per-token, or block-wise.""" 

597 iinfo = torch.iinfo(torch.int8) 

598 int8_max = iinfo.max 

599 int8_min = iinfo.min 

600 eps = 1e-10 

601 

602 if block_shape is not None: 

603 assert not per_act_token 

604 assert len(block_shape) == 2 

605 block_k = block_shape[1] 

606 assert A.size(-1) % block_k == 0 

607 orig_shape = A.shape 

608 A_flat = A.reshape(-1, A.size(-1)) 

609 M, K = A_flat.shape 

610 A_groups = A_flat.reshape(M * (K // block_k), block_k) 

611 amax = ( 

612 A_groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32) 

613 ) 

614 scale = amax / int8_max 

615 A_q = ( 

616 (A_groups.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8) 

617 ) 

618 A_q = A_q.reshape(orig_shape) 

619 scale = scale.reshape(M, K // block_k) 

620 return A_q, scale 

621 

622 elif per_act_token: 

623 A_flat = A.reshape(-1, A.size(-1)) 

624 amax = A_flat.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float32) 

625 scale = amax / int8_max 

626 A_q = (A_flat.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8) 

627 A_q = A_q.reshape(A.shape) 

628 scale = scale.reshape(A.shape[:-1] + (1,)) 

629 return A_q, scale 

630 

631 else: 

632 assert A_scale is not None, "int8 per-tensor requires A_scale" 

633 scale = A_scale.float().view(1, 1) if A_scale.numel() == 1 else A_scale.float() 

634 A_q = (A.float() / scale).round().clamp(int8_min, int8_max).to(torch.int8) 

635 return A_q, A_scale 

636 

637 

638def moe_kernel_quantize_input( 

639 A: torch.Tensor, 

640 A_scale: Optional[torch.Tensor], 

641 quant_dtype: None | torch.dtype | str, 

642 per_act_token_quant: bool, 

643 block_shape: Optional[list[int]] = None, 

644 ocp_mx_scheme: str | None = None, 

645) -> tuple[torch.Tensor, Optional[torch.Tensor]]: 

646 """Quantize MoE input activations before GEMM.""" 

647 if ocp_mx_scheme is not None: 

648 if ocp_mx_scheme in {"w_mxfp4", "w_mxfp4_a_mxfp4"}: 

649 pass 

650 elif ocp_mx_scheme.endswith("a_fp8"): 

651 qA, qA_scale = _fp8_quantize(A, A_scale, per_act_token=False) 

652 A = (qA.float() * qA_scale.float()).to(A.dtype) 

653 return A, None 

654 

655 if quant_dtype is None: 

656 return A, A_scale 

657 elif quant_dtype == torch.float8_e4m3fn: 

658 return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape) 

659 elif quant_dtype == torch.int8: 

660 return _int8_quantize(A, A_scale, per_act_token_quant, block_shape) 

661 else: 

662 return A, A_scale 

663 

664 

665def _ensure_block_size_k_divisible( 

666 size_k: int, block_size_k: int, group_size: int 

667) -> int: 

668 """Find largest block_size_k that divides size_k and is divisible by group_size.""" 

669 if size_k % block_size_k == 0 and block_size_k % group_size == 0: 

670 return block_size_k 

671 

672 max_search = min(block_size_k, size_k) 

673 start = (max_search // group_size) * group_size 

674 for candidate in range(start, group_size - 1, -group_size): 

675 if size_k % candidate == 0: 

676 return candidate 

677 

678 if size_k % group_size == 0: 

679 return group_size 

680 

681 return size_k 

682 

683 

684@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) 

685@triton.jit 

686def _silu_and_mul_kernel(x, y): 

687 x_fp32 = x.to(tl.float32) 

688 x_silu = tl.fdiv(x_fp32, (1.0 + tl.exp(-x_fp32))) 

689 return x_silu * y 

690 

691 

692@triton.jit 

693def write_zeros_to_output( 

694 c_ptr, 

695 stride_cm, 

696 stride_cn, 

697 pid_n, 

698 N, 

699 offs_token, 

700 token_mask, 

701 BLOCK_SIZE_M, 

702 BLOCK_SIZE_N, 

703 compute_type, 

704): 

705 accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type) 

706 offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 

707 c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] 

708 c_mask = token_mask[:, None] & (offs_cn[None, :] < N) 

709 tl.store(c_ptrs, accumulator, mask=c_mask) 

710 

711 

712@triton.jit 

713def fused_moe_kernel_gptq_awq( 

714 # Pointers to matrices 

715 a_ptr, 

716 b_ptr, 

717 c_ptr, 

718 b_scale_ptr, 

719 b_zp_ptr, 

720 topk_weights_ptr, 

721 sorted_token_ids_ptr, 

722 expert_ids_ptr, 

723 num_tokens_post_padded_ptr, 

724 # Matrix dimensions 

725 N: tl.constexpr, 

726 K: tl.constexpr, 

727 EM, 

728 num_valid_tokens, 

729 # The stride variables represent how much to increase the ptr by when 

730 # moving by 1 element in a particular dimension. E.g. `stride_am` is 

731 # how much to increase `a_ptr` by to get the element one row down 

732 # (A has M rows). 

733 stride_am, 

734 stride_ak, 

735 stride_be, 

736 stride_bk, 

737 stride_bn, 

738 stride_cm, 

739 stride_cn, 

740 stride_bse, 

741 stride_bsk, 

742 stride_bsn, 

743 stride_bze, 

744 stride_bzk, 

745 stride_bzn, 

746 block_k_diviable: tl.constexpr, 

747 group_size: tl.constexpr, 

748 # Meta-parameters 

749 BLOCK_SIZE_M: tl.constexpr, 

750 BLOCK_SIZE_N: tl.constexpr, 

751 BLOCK_SIZE_K: tl.constexpr, 

752 GROUP_SIZE_M: tl.constexpr, 

753 SPLIT_K: tl.constexpr, 

754 MUL_ROUTED_WEIGHT: tl.constexpr, 

755 top_k: tl.constexpr, 

756 compute_type: tl.constexpr, 

757 has_zp: tl.constexpr, 

758 use_int4_w4a16: tl.constexpr, 

759 use_int8_w8a16: tl.constexpr, 

760): 

761 """Fused MoE kernel for GPTQ/AWQ (WNA16) quantized weights.""" 

762 # Map pid to C block (grouped ordering for L2 reuse) 

763 pid = tl.program_id(axis=0) 

764 num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) 

765 num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) 

766 num_pid_in_group = GROUP_SIZE_M * num_pid_n 

767 group_id = pid // num_pid_in_group 

768 first_pid_m = group_id * GROUP_SIZE_M 

769 group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) 

770 pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) 

771 pid_n = (pid % num_pid_in_group) // group_size_m 

772 

773 # Create pointers for first blocks of A and B 

774 num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) 

775 if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: 

776 return 

777 offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) 

778 # Cast to int64 to prevent overflow in stride*offset products 

779 offs_token = tl.load(sorted_token_ids_ptr + offs_token_id).to(tl.int64) 

780 token_mask = offs_token < num_valid_tokens 

781 

782 off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) 

783 if off_experts == -1: 

784 # ----------------------------------------------------------- 

785 # Write back zeros to the output when the expert is not 

786 # in the current expert parallel rank. 

787 write_zeros_to_output( 

788 c_ptr, 

789 stride_cm, 

790 stride_cn, 

791 pid_n, 

792 N, 

793 offs_token, 

794 token_mask, 

795 BLOCK_SIZE_M, 

796 BLOCK_SIZE_N, 

797 compute_type, 

798 ) 

799 return 

800 

801 offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N 

802 offs_k = tl.arange(0, BLOCK_SIZE_K) 

803 a_ptrs = a_ptr + ( 

804 offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak 

805 ) 

806 

807 if use_int4_w4a16: 

808 b_ptrs = ( 

809 b_ptr 

810 + off_experts * stride_be 

811 + (offs_k[:, None] // 2) * stride_bk 

812 + offs_bn[None, :] * stride_bn 

813 ) 

814 b_shifter = (offs_k[:, None] % 2) * 4 

815 elif use_int8_w8a16: 

816 b_ptrs = ( 

817 b_ptr 

818 + off_experts * stride_be 

819 + offs_k[:, None] * stride_bk 

820 + offs_bn[None, :] * stride_bn 

821 ) 

822 

823 if not has_zp and use_int4_w4a16: 

824 b_zp_num = 8 

825 if not has_zp and use_int8_w8a16: 

826 b_zp_num = 128 

827 elif has_zp and use_int4_w4a16: 

828 b_zp_shifter = (offs_bn[None, :] % 2) * 4 

829 

830 # Accumulate C block in fp32 

831 accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 

832 for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): 

833 if not block_k_diviable: 

834 k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K 

835 k_other = 0.0 

836 else: 

837 k_mask = None 

838 k_other = None 

839 

840 a = tl.load( 

841 a_ptrs, 

842 mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), 

843 other=0.0, 

844 ) 

845 b = tl.load(b_ptrs) 

846 if use_int4_w4a16: 

847 b = (b >> b_shifter) & 0xF 

848 

849 b_scale_ptrs = ( 

850 b_scale_ptr 

851 + off_experts * stride_bse 

852 + offs_bn[None, :] * stride_bsn 

853 + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk 

854 ) 

855 b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) 

856 b_scale = b_scale.to(tl.float32) 

857 

858 if has_zp and use_int4_w4a16: 

859 offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size 

860 b_zp_ptrs = ( 

861 b_zp_ptr 

862 + off_experts * stride_bze 

863 + (offs_bn[None, :] // 2) * stride_bzn 

864 + offs_k_true * stride_bzk 

865 ) 

866 b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) 

867 b_zp = (b_zp >> b_zp_shifter) & 0xF 

868 b_zp = b_zp.to(tl.float32) 

869 elif has_zp and use_int8_w8a16: 

870 offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size 

871 b_zp_ptrs = ( 

872 b_zp_ptr 

873 + off_experts * stride_bze 

874 + offs_bn[None, :] * stride_bzn 

875 + offs_k_true * stride_bzk 

876 ) 

877 b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) 

878 b_zp = b_zp.to(tl.float32) 

879 

880 if has_zp: 

881 b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type) 

882 else: 

883 b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type) 

884 accumulator = tl.dot(a, b, acc=accumulator) 

885 

886 a_ptrs += BLOCK_SIZE_K * stride_ak 

887 if use_int4_w4a16: 

888 b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk 

889 else: 

890 b_ptrs += BLOCK_SIZE_K * stride_bk 

891 

892 if MUL_ROUTED_WEIGHT: 

893 moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) 

894 accumulator = accumulator * moe_weight[:, None] 

895 

896 accumulator = accumulator.to(compute_type) 

897 # Write back output 

898 offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 

899 c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] 

900 c_mask = token_mask[:, None] & (offs_cn[None, :] < N) 

901 tl.store(c_ptrs, accumulator, mask=c_mask) 

902 

903 

904@triton.jit 

905def fused_moe_kernel( 

906 # Pointers to matrices 

907 a_ptr, 

908 b_ptr, 

909 c_ptr, 

910 b_bias_ptr, 

911 a_scale_ptr, 

912 b_scale_ptr, 

913 topk_weights_ptr, 

914 sorted_token_ids_ptr, 

915 expert_ids_ptr, 

916 num_tokens_post_padded_ptr, 

917 # Matrix dimensions 

918 N, 

919 K, 

920 EM, 

921 num_valid_tokens, 

922 stride_am, 

923 stride_ak, 

924 stride_be, 

925 stride_bk, 

926 stride_bn, 

927 stride_cm, 

928 stride_cn, 

929 stride_asm, 

930 stride_ask, 

931 stride_bse, 

932 stride_bsk, 

933 stride_bsn, 

934 stride_bbe, # bias expert stride 

935 stride_bbn, # bias N stride 

936 # Block size for block-wise quantization 

937 group_n: tl.constexpr, 

938 group_k: tl.constexpr, 

939 naive_block_assignment: tl.constexpr, 

940 # Meta-parameters 

941 BLOCK_SIZE_M: tl.constexpr, 

942 BLOCK_SIZE_N: tl.constexpr, 

943 BLOCK_SIZE_K: tl.constexpr, 

944 GROUP_SIZE_M: tl.constexpr, 

945 SPLIT_K: tl.constexpr, 

946 MUL_ROUTED_WEIGHT: tl.constexpr, 

947 top_k: tl.constexpr, 

948 compute_type: tl.constexpr, 

949 use_fp8_w8a8: tl.constexpr, 

950 use_int8_w8a8: tl.constexpr, 

951 use_int8_w8a16: tl.constexpr, 

952 per_channel_quant: tl.constexpr, 

953 HAS_BIAS: tl.constexpr, 

954): 

955 """Fused MoE kernel: token × expert GEMM with quantization support.""" 

956 # Map pid to C block (grouped ordering for L2 reuse) 

957 pid = tl.program_id(axis=0) 

958 num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) 

959 num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) 

960 num_pid_in_group = GROUP_SIZE_M * num_pid_n 

961 group_id = pid // num_pid_in_group 

962 first_pid_m = group_id * GROUP_SIZE_M 

963 group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) 

964 pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) 

965 pid_n = (pid % num_pid_in_group) // group_size_m 

966 

967 # Create pointers for first blocks of A and B 

968 offs = tl.arange(0, BLOCK_SIZE_M).to(tl.int64) 

969 num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) 

970 if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: 

971 return 

972 if not naive_block_assignment: 

973 offs_token_id = pid_m * BLOCK_SIZE_M + offs 

974 offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) 

975 else: 

976 offs_token = tl.where( 

977 offs == 0, 

978 pid_m, # first element = pid_m 

979 num_valid_tokens, # remaining elements = constant 

980 ) 

981 offs_token = offs_token.to(tl.int64) # prevent int32 overflow 

982 

983 token_mask = offs_token < num_valid_tokens 

984 

985 offs_token = tl.where(token_mask, offs_token, 0) 

986 

987 off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) 

988 if off_experts == -1: 

989 # Expert not in current EP rank, write zeros 

990 write_zeros_to_output( 

991 c_ptr, 

992 stride_cm, 

993 stride_cn, 

994 pid_n, 

995 N, 

996 offs_token, 

997 token_mask, 

998 BLOCK_SIZE_M, 

999 BLOCK_SIZE_N, 

1000 compute_type, 

1001 ) 

1002 return 

1003 

1004 offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N 

1005 offs_k = tl.arange(0, BLOCK_SIZE_K) 

1006 a_ptrs = a_ptr + ( 

1007 offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak 

1008 ) 

1009 

1010 b_ptrs = ( 

1011 b_ptr 

1012 + off_experts * stride_be 

1013 + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) 

1014 ) 

1015 if use_int8_w8a16: 

1016 b_scale_ptrs = ( 

1017 b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn 

1018 ) 

1019 b_scale = tl.load(b_scale_ptrs) 

1020 

1021 if use_fp8_w8a8 or use_int8_w8a8: 

1022 if group_k > 0 and group_n > 0: # block-wise 

1023 a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm 

1024 offs_bsn = offs_bn // group_n 

1025 b_scale_ptrs = ( 

1026 b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn 

1027 ) 

1028 elif per_channel_quant: # channel-wise 

1029 b_scale_ptrs = ( 

1030 b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn 

1031 ) 

1032 b_scale = tl.load(b_scale_ptrs) 

1033 a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm 

1034 a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None] 

1035 else: # tensor-wise 

1036 a_scale = tl.load(a_scale_ptr) 

1037 b_scale = tl.load(b_scale_ptr + off_experts) 

1038 if HAS_BIAS: 

1039 bias_ptrs = b_bias_ptr + off_experts * stride_bbe + offs_bn * stride_bbn 

1040 bias = tl.load(bias_ptrs, mask=(offs_bn < N), other=0.0) 

1041 # Accumulate C block in fp32 

1042 accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 

1043 k_total = K 

1044 for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): 

1045 # Pre-compute remaining K for this iteration 

1046 k_offset = k * BLOCK_SIZE_K 

1047 k_remaining = k_total - k_offset 

1048 # Use other=0.0 for proper tail block handling - compatible with all batch sizes 

1049 a = tl.load( 

1050 a_ptrs, 

1051 mask=token_mask[:, None] & (offs_k[None, :] < k_remaining), 

1052 other=0.0, 

1053 ) 

1054 b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0) 

1055 if use_int8_w8a16: 

1056 accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) 

1057 elif use_fp8_w8a8 or use_int8_w8a8: 

1058 if group_k > 0 and group_n > 0: 

1059 k_start = k * BLOCK_SIZE_K 

1060 offs_ks = k_start // group_k 

1061 a_scale = tl.load( 

1062 a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 

1063 ) 

1064 b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) 

1065 

1066 accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] 

1067 else: 

1068 if use_fp8_w8a8: 

1069 accumulator = tl.dot(a, b, acc=accumulator) 

1070 else: 

1071 accumulator += tl.dot(a, b) 

1072 else: 

1073 accumulator += tl.dot(a, b) 

1074 # Update pointers for next iteration 

1075 a_ptrs += BLOCK_SIZE_K * stride_ak 

1076 b_ptrs += BLOCK_SIZE_K * stride_bk 

1077 

1078 # Dequantization 

1079 if use_int8_w8a16: 

1080 accumulator = accumulator * b_scale 

1081 elif (use_fp8_w8a8 or use_int8_w8a8) and not (group_k > 0 and group_n > 0): 

1082 accumulator = accumulator * a_scale * b_scale 

1083 

1084 if HAS_BIAS: 

1085 accumulator += bias[None, :] 

1086 

1087 # Router weight multiplication (must be in fp32) 

1088 if MUL_ROUTED_WEIGHT: 

1089 moe_weight = tl.load( 

1090 topk_weights_ptr + offs_token, 

1091 mask=token_mask, 

1092 other=0, 

1093 ) 

1094 accumulator *= moe_weight[:, None] 

1095 

1096 accumulator = accumulator.to(compute_type) 

1097 

1098 # Write back output 

1099 offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 

1100 c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] 

1101 c_mask = token_mask[:, None] & (offs_cn[None, :] < N) 

1102 tl.store(c_ptrs, accumulator, mask=c_mask) 

1103 

1104 

def invoke_fused_moe_wna16_triton_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    B_scale: torch.Tensor | None,
    B_zp: torch.Tensor | None,
    topk_weights: torch.Tensor | None,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    block_shape: list[int] | None,
):
    """Launch the GPTQ/AWQ-style weight-only (WNA16) fused MoE Triton kernel.

    Args:
        A: Activations; ``A.size(0)`` is the token count and ``A.size(1)``
            is passed to the kernel as K.
        B: Expert weights. Dim 0 is the expert axis (``B.stride(0)`` is the
            per-expert stride) and ``B.size(1)`` is passed as N.
        C: Output buffer; ``C.stride(1)`` / ``C.stride(2)`` are passed as the
            output row/column strides.
        B_scale: Group-wise weight scales; must be 3-D.
        B_zp: Optional weight zero points; must be 3-D when given.
        topk_weights: Router weights (used when ``mul_routed_weight``).
        sorted_token_ids: Block-aligned token ids; must not be ``None`` here
            (its length is read unconditionally).
        expert_ids: Expert id for each BLOCK_SIZE_M row-block.
        num_tokens_post_padded: Number of padded rows, read by the kernel.
        mul_routed_weight: Scale each output row by its router weight.
        top_k: Experts selected per token.
        config: Tile configuration; must contain ``BLOCK_SIZE_M``. A copy
            is extended with block sizes from ``get_moe_wna16_block_config``;
            the caller's dict is not mutated.
        compute_type: Triton dtype the kernel writes.
        use_int8_w8a16: Weights are INT8.
        use_int4_w4a16: Weights are INT4.
        block_shape: ``[0, group_size]``; ``block_shape[1]`` is the group size.
    """
    assert B_scale is not None and B_scale.ndim == 3
    assert B_zp is None or B_zp.ndim == 3
    assert block_shape is not None and block_shape[0] == 0

    M = A.size(0)
    num_tokens = M * top_k

    EM = sorted_token_ids.size(0)
    if A.size(0) < config["BLOCK_SIZE_M"]:
        # optimize for small batch_size.
        # We assume that top_ids of each token is unique,
        # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
        # and we can skip some invalid blocks.
        EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
    grid = lambda META: (
        triton.cdiv(EM, META["BLOCK_SIZE_M"])
        * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
    )
    config = config.copy()
    config.update(
        get_moe_wna16_block_config(
            config=config,
            use_moe_wna16_cuda=False,
            num_valid_tokens=num_tokens,
            size_k=A.size(1),
            size_n=B.size(1),
            # Dim 0 of B is the expert axis (its stride is the expert stride
            # below); B.size(1) is N and is already passed as size_n.
            num_experts=B.size(0),
            group_size=block_shape[1],
            real_top_k=top_k,
            block_size_m=config["BLOCK_SIZE_M"],
        )
    )

    fused_moe_kernel_gptq_awq[grid](
        A,
        B,
        C,
        B_scale,
        B_zp,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.size(1),
        A.size(1),
        EM,
        num_tokens,
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        B_scale.stride(0),
        B_scale.stride(2),
        B_scale.stride(1),
        B_zp.stride(0) if B_zp is not None else 0,
        B_zp.stride(2) if B_zp is not None else 0,
        B_zp.stride(1) if B_zp is not None else 0,
        block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
        group_size=block_shape[1],
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        has_zp=B_zp is not None,
        use_int4_w4a16=use_int4_w4a16,
        use_int8_w8a16=use_int8_w8a16,
        **config,
    )

1193 

1194 

def invoke_fused_moe_triton_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    topk_weights: Optional[torch.Tensor],
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    block_shape: Optional[list[int]] = None,
    B_bias: torch.Tensor | None = None,
) -> None:
    """Launch the fused_moe_kernel Triton kernel.

    Args:
        A: Activations; ``A.size(0)`` is the token count.
        B: Expert weights indexed ``(expert, N, K)``: ``B.size(1)`` is
            passed as N, ``B.size(2)`` as K and ``B.stride(0)`` as the
            expert stride.
        C: Output buffer; ``C.stride(1)`` / ``C.stride(2)`` are the output
            row/column strides.
        A_scale: Activation scales (quantized modes only). A 2-D tensor
            supplies per-row/per-block strides; otherwise strides are 0.
        B_scale: Weight scales (quantized modes only).
        topk_weights: Router weights; required when ``mul_routed_weight``.
        sorted_token_ids: Block-aligned token ids, or ``None`` to use the
            kernel's naive one-(token, expert)-pair-per-row-block mode.
        expert_ids: Expert id per row-block (``-1`` makes the kernel write
            zeros for that block).
        num_tokens_post_padded: Padded row count read by the kernel.
        mul_routed_weight: Scale output rows by their router weight.
        top_k: Experts selected per token.
        config: Tile configuration (must contain ``BLOCK_SIZE_M`` and
            ``BLOCK_SIZE_K``). A copy is used; the caller's dict is untouched.
        compute_type: Triton dtype the kernel writes.
        use_fp8_w8a8 / use_int8_w8a8: Activation+weight quantization.
        use_int8_w8a16 / use_int4_w4a16: Weight-only quantization.
        per_channel_quant: Channel-wise scales for the W8A8 modes.
        block_shape: ``[block_n, block_k]`` for block-wise quantization; for
            the weight-only modes ``block_shape[0]`` must be 0.
        B_bias: Optional per-expert bias added to the GEMM output.
    """
    assert topk_weights is not None or not mul_routed_weight
    assert topk_weights is None or topk_weights.stride(1) == 1
    assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1

    if use_fp8_w8a8 or use_int8_w8a8:
        assert B_scale is not None
        assert block_shape is None or triton.cdiv(
            B.size(-2), block_shape[0]
        ) == B_scale.size(-2)
        assert block_shape is None or triton.cdiv(
            B.size(-1), block_shape[1]
        ) == B_scale.size(-1)
    elif use_int8_w8a16 or use_int4_w4a16:
        assert B_scale is not None
        assert block_shape is None or block_shape[0] == 0
    else:
        assert A_scale is None
        assert B_scale is None

    M = A.size(0)
    num_tokens = M * top_k
    if sorted_token_ids is not None:
        EM = sorted_token_ids.size(0)
        if A.size(0) < config["BLOCK_SIZE_M"]:
            # Small batch: skip row-blocks that cannot hold valid tokens.
            EM = min(
                sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"]
            )
    else:
        # Naive assignment: one BLOCK_SIZE_M row-block per (token, expert) slot.
        EM = num_tokens * config["BLOCK_SIZE_M"]
    grid = lambda META: (
        triton.cdiv(EM, META["BLOCK_SIZE_M"])
        * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
    )
    HAS_BIAS = B_bias is not None

    config = config.copy()
    config["SPLIT_K"] = 1
    BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
    # Clamp the K tile to the quantization block only for genuine block-wise
    # shapes. The weight-only convention allows block_shape[0] == 0 (asserted
    # above); clamping then would produce an invalid BLOCK_SIZE_K of 0.
    if block_shape is not None and min(block_shape[0], block_shape[1]) > 0:
        BLOCK_SIZE_K = min(BLOCK_SIZE_K, block_shape[0], block_shape[1])

    fused_moe_kernel[grid](
        A,
        B,
        C,
        B_bias,
        A_scale,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.size(1),  # N
        B.size(2),  # K
        EM,
        num_tokens,
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
        A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
        B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
        B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
        B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
        B_bias.stride(0) if B_bias is not None else 0,
        B_bias.stride(1) if B_bias is not None else 0,
        0 if block_shape is None else block_shape[0],
        0 if block_shape is None else block_shape[1],
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        per_channel_quant=per_channel_quant,
        naive_block_assignment=(sorted_token_ids is None),
        HAS_BIAS=HAS_BIAS,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        **config,
    )

1302 

1303 

def dispatch_fused_moe_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    B_zp: Optional[torch.Tensor],
    topk_weights: Optional[torch.Tensor],
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: dict[str, Any],
    compute_type: tl.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a8: bool,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    per_channel_quant: bool,
    block_shape: Optional[list[int]] = None,
    B_bias: Optional[torch.Tensor] = None,
) -> None:
    """Dispatch to the appropriate fused MoE kernel based on quantization flags.

    Weight-only INT8/INT4 quantization with a positive group size
    (``block_shape[1] > 0``) is sent to the GPTQ/AWQ WNA16 launcher. Every
    other configuration (unquantized, FP8/INT8 W8A8, ungrouped weight-only)
    is sent to the generic ``fused_moe_kernel`` launcher.

    Args:
        A, B, C: Activations, expert weights and output buffer.
        A_scale, B_scale: Activation / weight scales (quantized modes).
        B_zp: Weight zero points; only used by the WNA16 launcher.
        topk_weights: Router weights; required when ``mul_routed_weight``.
        sorted_token_ids, expert_ids, num_tokens_post_padded: Routing
            metadata produced by block alignment (``sorted_token_ids`` may
            be ``None`` for naive assignment on the generic path).
        mul_routed_weight: Scale output rows by their router weight.
        top_k: Experts selected per token.
        config: Kernel tile configuration.
        compute_type: Triton dtype the kernel writes.
        use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
        per_channel_quant, block_shape: Quantization settings.
        B_bias: Optional per-expert bias (not supported by the WNA16 path).

    Raises:
        AssertionError: If router weights are required but missing, if
            ``topk_weights`` / ``sorted_token_ids`` are not unit-stride
            along their last dimension, or if a bias is passed to the WNA16
            path.
    """
    assert topk_weights is not None or not mul_routed_weight
    assert topk_weights is None or topk_weights.stride(1) == 1
    assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1

    # TODO: dedicated precision-specific implementations for
    # use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16 and use_int4_w4a16.
    use_wna16_kernel = (use_int8_w8a16 or use_int4_w4a16) and (
        block_shape is not None and block_shape[1] > 0
    )

    if use_wna16_kernel:
        assert B_bias is None
        invoke_fused_moe_wna16_triton_kernel(
            A,
            B,
            C,
            B_scale,
            B_zp,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            mul_routed_weight,
            top_k,
            config,
            compute_type,
            use_int8_w8a16,
            use_int4_w4a16,
            block_shape,
        )
    else:
        invoke_fused_moe_triton_kernel(
            A,
            B,
            C,
            A_scale,
            B_scale,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            mul_routed_weight,
            top_k,
            config,
            compute_type,
            use_fp8_w8a8,
            use_int8_w8a8,
            use_int8_w8a16,
            use_int4_w4a16,
            per_channel_quant,
            block_shape,
            B_bias,
        )

1387 

1388 

def fused_experts_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    ocp_mx_scheme: str | None = None,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: torch.Tensor | None = None,
    w2_zp: torch.Tensor | None = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run a complete fused MoE layer with the Triton kernels.

    For each token the routed experts apply ``w1`` (first GEMM), the SiLU
    activation and ``w2`` (second GEMM); the per-expert outputs are then
    reduced with ``moe_sum``. Tokens are processed in chunks of at most
    16K rows so the intermediate buffers stay bounded.

    Args:
        hidden_states: 2-D ``(num_tokens, hidden_size)`` contiguous
            activations in float32, float16 or bfloat16.
        w1: First-layer expert weights, size ``(E, N, K1)``; ``K1`` must
            match ``hidden_size`` (or its packed width for OCP MX schemes).
        w2: Second-layer expert weights; ``w2.size(1)`` is the output width.
        topk_weights: Router weights, same shape as ``topk_ids``.
        topk_ids: ``(num_tokens, top_k)`` selected expert ids.
        inplace: Write the result into ``hidden_states`` instead of a new
            tensor.
        activation: Only ``"silu"`` (or an object whose ``.value`` is
            ``"silu"``) is supported.
        apply_router_weight_on_input: Apply router weights in the first GEMM
            instead of the second.
        use_fp8_w8a8 / use_int8_w8a8: Activation+weight quantization modes.
        use_int8_w8a16 / use_int4_w4a16: Weight-only modes; the weights are
            dequantized to the activation dtype before the GEMMs.
        ocp_mx_scheme: ``"w_mxfp4*"``, ``"w_mxfp6_e3m2*"`` or
            ``"w_mxfp6_e2m3*"``; the weights are dequantized before the GEMMs.
        per_channel_quant: Per-token activation / per-channel weight scales.
        global_num_experts: Total number of experts; ``-1`` means ``E``.
        expert_map: Optional expert-id mapping, forwarded to
            ``moe_align_block_size`` (presumably global -> local ids with
            ``-1`` for experts not on this rank; confirm with the caller).
        w1_scale / w2_scale: Weight scales for the quantized modes.
        w1_zp / w2_zp: Weight zero points, forwarded to the dispatcher. Note
            that the up-front INT8/INT4 dequantization does not apply them.
        a1_scale / a2_scale: Optional static activation scales.
        block_shape: Block-quantization shape ``[block_n, block_k]``.
        w1_bias / w2_bias: Optional per-expert biases for each GEMM.

    Returns:
        The MoE output (``hidden_states`` itself when ``inplace`` is set).

    Raises:
        AssertionError: On an unsupported activation or a shape, stride or
            dtype mismatch.
        NotImplementedError: On an unknown ``ocp_mx_scheme``.
        ValueError: If the activation dtype has no Triton compute type.
    """
    logger.debug("GEMS_ASCEND FUSED MOE")
    # Accept either a plain string or an object carrying it in ``.value``.
    if hasattr(activation, "value"):
        activation = activation.value
    assert (
        activation == "silu"
    ), f"Only 'silu' activation is supported, got {activation}"

    activation_enum = MoEActivation.from_str(activation)

    # Check constraints
    if use_int4_w4a16:
        # INT4 stored unpacked in INT8 containers (full K dim)
        assert hidden_states.size(1) == w1.size(
            2
        ), f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"
    elif ocp_mx_scheme is not None:
        if ocp_mx_scheme.startswith("w_mxfp4"):
            # Packed storage: each stored element of w1's last dim holds two values.
            assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch"
        elif ocp_mx_scheme.startswith("w_mxfp6"):
            # Packed storage: every 3 stored elements hold 4 values.
            assert (
                hidden_states.size(1) == (w1.size(2) * 4) // 3
            ), "hidden size mismatch"
        else:
            raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
    else:
        assert hidden_states.size(1) == w1.size(
            2
        ), f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"

    assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

    num_tokens = hidden_states.size(0)
    E, N, _ = w1.size()
    K = w2.size(1)
    if global_num_experts == -1:
        global_num_experts = E
    top_k_num = topk_ids.size(1)

    CHUNK_SIZE: int = 16 * 1024
    # Naive assignment is also chosen when routed pairs are this sparse
    # relative to the expert count.
    SPARSITY_FACTOR = 4
    M = min(num_tokens, CHUNK_SIZE)

    config_dtype = _get_config_dtype_str(
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        ocp_mx_scheme=ocp_mx_scheme,
        dtype=hidden_states.dtype,
    )

    quant_dtype = _get_config_quant_dtype(
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        ocp_mx_scheme=ocp_mx_scheme,
    )

    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.size(),
        w2.size(),
        top_k_num,
        config_dtype,
        block_shape=block_shape,
    )

    config = get_config_func(M)

    # cache1 and cache3 share memory (non-overlapping lifetime)
    cache13 = torch.empty(
        M * top_k_num * max(N, K),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache1 = cache13[: M * top_k_num * N].view(M, top_k_num, N)
    intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)

    # cache2 needs separate memory (concurrent with cache1)
    activation_out_dim = MoEActivation.adjust_N_for_activation(N, activation_enum)
    intermediate_cache2 = torch.empty(
        (M * top_k_num, activation_out_dim),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    if hidden_states.dtype == torch.bfloat16:
        compute_type = tl.bfloat16
    elif hidden_states.dtype == torch.float16:
        compute_type = tl.float16
    elif hidden_states.dtype == torch.float32:
        compute_type = tl.float32
    else:
        raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

    out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)

    if ocp_mx_scheme is not None:
        # Dequantize OCP MX weights (TODO: skip on platforms with native MX)
        if ocp_mx_scheme.startswith("w_mxfp4"):
            w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
            w1_scale = None
            w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
            w2_scale = None
        elif ocp_mx_scheme.startswith("w_mxfp6_e3m2"):
            w1 = dequant_mxfp6(
                w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
            )
            w1_scale = None
            w2 = dequant_mxfp6(
                w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
            )
            w2_scale = None
        elif ocp_mx_scheme.startswith("w_mxfp6_e2m3"):
            w1 = dequant_mxfp6(
                w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
            )
            w1_scale = None
            w2 = dequant_mxfp6(
                w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
            )
            w2_scale = None
        else:
            raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")

    # Dequant INT8/INT4 weights (Triton can't do mixed-dtype dot).
    # NOTE(review): unsqueeze(-1) assumes scales shaped (E, N), i.e. one scale
    # per output channel; group-wise scales would not broadcast. Zero points
    # are not applied here — confirm callers never pass them with these modes.
    if use_int8_w8a16 or use_int4_w4a16:
        w1 = w1.to(hidden_states.dtype) * w1_scale.unsqueeze(-1).to(hidden_states.dtype)
        w1_scale = None
        w2 = w2.to(hidden_states.dtype) * w2_scale.unsqueeze(-1).to(hidden_states.dtype)
        w2_scale = None
        use_int8_w8a16 = False
        use_int4_w4a16 = False

    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
        begin_chunk_idx, end_chunk_idx = (
            chunk * CHUNK_SIZE,
            min((chunk + 1) * CHUNK_SIZE, num_tokens),
        )
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.size()

        if tokens_in_chunk == 0:
            break

        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
            # Adjust cache size for last chunk
            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
            intermediate_cache2 = intermediate_cache2[
                : tokens_in_chunk * topk_ids.size(1)
            ]
            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
            config = get_config_func(tokens_in_chunk)

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
        qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
            A=curr_hidden_states,
            A_scale=a1_scale,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_channel_quant,
            block_shape=block_shape,
            ocp_mx_scheme=ocp_mx_scheme,
        )

        # Naive assignment hands the raw ids in curr_topk_ids to the kernel as
        # expert indices; only moe_align_block_size receives expert_map. With
        # an expert map those ids are untranslated (presumably global ids that
        # can exceed the local expert count), so it must use the aligned path.
        naive_block_assignment = expert_map is None and (
            # For small tokens (< 32), use naive assignment to skip alignment overhead
            tokens_in_chunk < 32
            or (
                tokens_in_chunk * top_k_num * SPARSITY_FACTOR <= global_num_experts
                and not (
                    (use_int8_w8a16 or use_int4_w4a16)
                    and block_shape is not None
                    and block_shape[1] > 0
                )
            )
        )

        if not naive_block_assignment:
            sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
                curr_topk_ids,
                config["BLOCK_SIZE_M"],
                global_num_experts,
                expert_map,
                # ignore_invalid_experts=True,
            )
        else:
            # Upper bound: one BLOCK_SIZE_M row-block per routed pair.
            max_num_tokens_padded = topk_ids.numel() * config["BLOCK_SIZE_M"]
            expert_ids = curr_topk_ids.view(-1)
            num_tokens_post_padded = torch.empty(
                (1), dtype=torch.int32, device=topk_ids.device
            )
            num_tokens_post_padded.fill_(max_num_tokens_padded)
            sorted_token_ids = None

        dispatch_fused_moe_kernel(
            qcurr_hidden_states,
            w1,
            intermediate_cache1,
            a1q_scale,
            w1_scale,
            w1_zp,
            curr_topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            apply_router_weight_on_input,
            top_k_num,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            block_shape=block_shape,
            B_bias=w1_bias,
        )

        apply_moe_activation(
            activation_enum, intermediate_cache2, intermediate_cache1.view(-1, N)
        )

        qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
            A=intermediate_cache2,
            A_scale=a2_scale,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_channel_quant,
            block_shape=block_shape,
            ocp_mx_scheme=ocp_mx_scheme,
        )

        if expert_map is not None:
            # Rows routed to experts skipped by the kernel must sum as zero.
            intermediate_cache3.zero_()

        dispatch_fused_moe_kernel(
            qintermediate_cache2,
            w2,
            intermediate_cache3,
            a2q_scale,
            w2_scale,
            w2_zp,
            curr_topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            not apply_router_weight_on_input,
            1,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            block_shape=block_shape,
            B_bias=w2_bias,
        )

        moe_sum(
            intermediate_cache3,
            out_hidden_states[begin_chunk_idx:end_chunk_idx],
        )

    return out_hidden_states

1682 

1683 

def inplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> None:
    """
    Compute the fused MoE layer and store the result in ``hidden_states``.

    Thin wrapper around ``fused_experts_impl`` with ``inplace=True``; every
    other argument is forwarded unchanged. Nothing is returned — read the
    result from ``hidden_states`` after the call.
    """
    fused_experts_impl(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=True,
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        # quantization mode flags
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        per_channel_quant=per_channel_quant,
        global_num_experts=global_num_experts,
        # scales, block shape and biases
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
    )

1735 

1736 

def outplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Compute the fused MoE layer into a newly allocated output tensor.

    Delegates to ``fused_experts_impl`` with ``inplace=False``; all other
    arguments are passed through untouched.

    Returns:
        The output tensor produced by ``fused_experts_impl``.
    """
    quant_flags = dict(
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a8=use_int8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        per_channel_quant=per_channel_quant,
    )
    scale_and_bias = dict(
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        block_shape=block_shape,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
    )
    result = fused_experts_impl(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=False,
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        global_num_experts=global_num_experts,
        **quant_flags,
        **scale_and_bias,
    )
    return result