Coverage for src/flag_gems/fused_moe_mxq.py: 0%

310 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1# SPDX-License-Identifier: Apache-2.0 

2# QC-MoE: Quantized Mixture of Experts kernel for FlagGems 

3# Main module integrating MoE kernels with quantization support 

4 

5from dataclasses import dataclass 

6from enum import Enum 

7from typing import Any, List, Optional, Tuple 

8 

9import torch 

10import triton 

11import triton.language as tl 

12 

13# Device detection 

# Evaluated once at import time; the capability probe below is skipped entirely
# when no CUDA runtime was available at that moment.
_is_cuda = torch.cuda.is_available()


def is_sm90_supported():
    """Return True when the current CUDA device reports compute capability major >= 9.

    Always returns False when CUDA was unavailable at import time.
    """
    if not _is_cuda:
        return False
    major = torch.cuda.get_device_capability()[0]
    return major >= 9

26 

27 

28# ============================================================================ 

29# QuantMode and QuantConfig 

30# ============================================================================ 

31 

32 

class QuantMode(Enum):
    """Enumeration of the weight/activation precision modes used by QC-MoE.

    Each member's value is the lowercase string identifier of the mode.
    """

    # Unquantized half-precision weights.
    FP16 = "fp16"
    # 8-bit floating-point mode.
    FP8 = "fp8"
    # 8-bit integer mode.
    INT8 = "int8"
    # INT8 weights with FP16 activations.
    W8A16 = "w8a16"
    # INT4 weights with FP16 activations.
    W4A16 = "w4a16"

41 

42 

@dataclass
class QuantConfig:
    """Settings that describe how MoE expert weights are quantized."""

    # Precision mode; defaults to unquantized FP16.
    mode: QuantMode = QuantMode.FP16
    # Number of consecutive input features that share one scale / zero point.
    group_size: int = 128
    # Whether zero points accompany the scales.
    has_zero_point: bool = True
    # NOTE(review): not read by any code visible in this file — confirm intended use.
    per_channel_quant: bool = False

    @property
    def w_nbits(self) -> int:
        """Weight bit width implied by ``mode``: 4, 8, or 16."""
        current = self.mode
        if current == QuantMode.W4A16:
            return 4
        eight_bit_modes = (QuantMode.W8A16, QuantMode.INT8, QuantMode.FP8)
        return 8 if current in eight_bit_modes else 16

    @property
    def use_int4(self) -> bool:
        """True only for the W4A16 mode."""
        is_w4a16 = self.mode == QuantMode.W4A16
        return is_w4a16

    @property
    def use_int8(self) -> bool:
        """True for the W8A16 and INT8 modes (FP8 is not included)."""
        int8_modes = (QuantMode.W8A16, QuantMode.INT8)
        return self.mode in int8_modes

68 

69 

70# ============================================================================ 

71# Triton Kernels 

72# ============================================================================ 

73 

74 

@triton.jit
def fused_moe_kernel_gptq_awq(
    # Pointers to matrices
    A,
    B,
    C,
    B_scale,
    B_zp,
    topk_weights,
    sorted_token_ids,
    expert_ids,
    num_tokens_post_padded,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    # Strides
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bse,
    stride_bsk,
    stride_bsn,
    stride_bze,
    stride_bzk,
    stride_bzn,
    group_size: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    has_zp: tl.constexpr,
    use_int4_w4a16: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
    even_Ks: tl.constexpr,
    filter_expert: tl.constexpr,
):
    """
    Group-quantized expert projection for one dispatch entry per program.

    Program ``pid`` reads row ``sorted_token_ids[pid]`` of ``A``, multiplies it by
    the dequantized weights of expert ``expert_ids[pid]`` and atomically adds the
    (optionally routing-weighted) result into the same row of ``C``.

    Weight addressing is ``B[expert, n, k]`` via ``stride_be/stride_bn/stride_bk``.
    Scales and zero points are addressed as ``[expert, n, k // group_size]``.
    Dequantization is ``(q - zp) * scale``; without zero points the midpoint of
    the code range (8 for INT4, 128 for INT8) is used as ``zp``.

    In INT4 mode two 4-bit codes share one byte along K: even ``k`` in the low
    nibble, odd ``k`` in the high nibble, so the byte index is ``k // 2``.

    All ``N`` output columns are produced by iterating over column tiles of
    ``BLOCK_SIZE_N``, so ``N`` may exceed ``BLOCK_SIZE_N``.

    ``num_tokens_post_padded``, ``EM``, ``BLOCK_SIZE_M``, ``GROUP_SIZE_M``,
    ``top_k``, ``even_Ks`` and ``filter_expert`` are accepted for interface
    compatibility and are not read.
    """
    pid = tl.program_id(0)

    # Programs beyond the number of dispatch entries have nothing to do.
    if pid >= num_valid_tokens:
        return

    # Per-entry dispatch information.
    token_id = tl.load(sorted_token_ids + pid).to(tl.int64)
    expert_id = tl.load(expert_ids + pid).to(tl.int64)
    weight = tl.load(topk_weights + pid).to(compute_type)

    # Row / expert base pointers are invariant across all tiles.
    a_row = A + token_id * stride_am
    c_row = C + token_id * stride_cm
    b_expert = B + expert_id * stride_be
    scale_expert = B_scale + expert_id * stride_bse
    zp_expert = B_zp + expert_id * stride_bze

    offs_k = tl.arange(0, BLOCK_SIZE_K)

    # Walk over every tile of output columns so no column is dropped when
    # N is larger than BLOCK_SIZE_N.
    for n_block in range(tl.cdiv(N, BLOCK_SIZE_N)):
        offs_n = n_block * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        n_mask = offs_n < N

        accumulator = tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32)

        for k_block in range(tl.cdiv(K, BLOCK_SIZE_K)):
            k_indices = k_block * BLOCK_SIZE_K + offs_k
            k_mask = k_indices < K
            tile_mask = k_mask[:, None] & n_mask[None, :]

            # Activation slice A[token_id, k_indices].
            a = tl.load(
                a_row + k_indices * stride_ak, mask=k_mask, other=0.0
            ).to(tl.float32)

            if use_int4_w4a16:
                # Two codes per byte: fetch byte k // 2, then select the nibble.
                packed = tl.load(
                    b_expert
                    + offs_n[None, :] * stride_bn
                    + (k_indices[:, None] // 2) * stride_bk,
                    mask=tile_mask,
                    other=0,
                )
                nibble_shift = (k_indices[:, None] % 2) * 4
                w = ((packed.to(tl.int32) >> nibble_shift) & 0xF).to(tl.float32)
            else:
                w = tl.load(
                    b_expert
                    + offs_n[None, :] * stride_bn
                    + k_indices[:, None] * stride_bk,
                    mask=tile_mask,
                    other=0.0,
                ).to(tl.float32)

            # One scale (and zero point) per group of `group_size` K elements.
            scale_group = k_indices // group_size
            scales = tl.load(
                scale_expert
                + offs_n[None, :] * stride_bsn
                + scale_group[:, None] * stride_bsk,
                mask=tile_mask,
                other=1.0,
            ).to(tl.float32)

            if use_int4_w4a16:
                if has_zp:
                    zp = tl.load(
                        zp_expert
                        + offs_n[None, :] * stride_bzn
                        + scale_group[:, None] * stride_bzk,
                        mask=tile_mask,
                        other=0.0,
                    ).to(tl.float32)
                    w_dequant = (w - zp) * scales
                else:
                    w_dequant = (w - 8.0) * scales
            elif use_int8_w8a16:
                if has_zp:
                    zp = tl.load(
                        zp_expert
                        + offs_n[None, :] * stride_bzn
                        + scale_group[:, None] * stride_bzk,
                        mask=tile_mask,
                        other=0.0,
                    ).to(tl.float32)
                    w_dequant = (w - zp) * scales
                else:
                    w_dequant = (w - 128.0) * scales
            else:
                # No integer offset: the loaded values are only scaled.
                w_dequant = w * scales

            # [BLOCK_SIZE_K, 1] * [BLOCK_SIZE_K, BLOCK_SIZE_N], reduced over K.
            accumulator = accumulator + tl.sum(a[:, None] * w_dequant, axis=0)

        if MUL_ROUTED_WEIGHT:
            accumulator = accumulator * weight

        # Several dispatch entries may target the same output row, hence atomics.
        tl.atomic_add(
            c_row + offs_n * stride_cn, accumulator.to(compute_type), mask=n_mask
        )

250 

251 

@triton.jit
def fused_moe_kernel_fp16_swiglu(
    A,
    C,
    B_gate,
    B_up,
    B_down,
    topk_weights,
    sorted_token_ids,
    expert_ids,
    num_tokens_post_padded,
    inter_ptr,
    N,
    K,
    EM,
    num_valid_tokens,
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    stride_cm,
    stride_cn,
    stride_gate_e,
    stride_up_e,
    stride_down_e,
    stride_gate_n,
    stride_gate_k,
    stride_up_n,
    stride_up_k,
    stride_down_k,
    stride_down_n,
    stride_inter_m,
    BLOCK_SIZE_K: tl.constexpr,
    top_k: tl.constexpr,
    even_Ks: tl.constexpr,
):
    """
    Unquantized SwiGLU expert FFN for one dispatch entry per program.

    FFN(x) = W2 @ (silu(W1 @ x) * (W3 @ x)),  silu(z) = z * sigmoid(z)

    Program ``pid`` handles token ``sorted_token_ids[pid]`` with expert
    ``expert_ids[pid]``. The activated intermediate vector (length ``N``) is
    staged in row ``pid`` of the scratch buffer ``inter_ptr`` (row pitch
    ``stride_inter_m``), then projected down to ``K`` outputs which are scaled
    by the routing weight and atomically added into row ``token_id`` of ``C``.

    ``num_tokens_post_padded``, ``EM``, ``stride_bn``, ``stride_bk``, ``top_k``
    and ``even_Ks`` are accepted for interface compatibility and are not read.
    """
    pid = tl.program_id(0)
    if pid >= num_valid_tokens:
        return

    token_id = tl.load(sorted_token_ids + pid).to(tl.int64)
    expert_id = tl.load(expert_ids + pid).to(tl.int64)
    weight = tl.load(topk_weights + pid).to(tl.float32)

    # Base pointers that stay fixed for this dispatch entry. The scratch offset
    # is widened to 64 bits so pid * row pitch cannot wrap for large buffers.
    a_row = A + token_id * stride_am
    inter_row = inter_ptr + pid.to(tl.int64) * stride_inter_m
    gate_expert = B_gate + expert_id * stride_gate_e
    up_expert = B_up + expert_id * stride_up_e
    down_expert = B_down + expert_id * stride_down_e

    # ---------- Gate and up projections, fused per intermediate feature ----------
    # Both dot products share the same activation loads, and the gate value never
    # has to round-trip through the scratch buffer before the activation.
    for n in range(N):
        gate_acc = 0.0
        up_acc = 0.0
        for kb in range(tl.cdiv(K, BLOCK_SIZE_K)):
            k_offs = kb * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
            k_mask = k_offs < K
            a_vals = tl.load(
                a_row + k_offs * stride_ak, mask=k_mask, other=0.0
            ).to(tl.float32)
            w_gate = tl.load(
                gate_expert + n * stride_gate_n + k_offs * stride_gate_k,
                mask=k_mask,
                other=0.0,
            ).to(tl.float32)
            w_up = tl.load(
                up_expert + n * stride_up_n + k_offs * stride_up_k,
                mask=k_mask,
                other=0.0,
            ).to(tl.float32)
            gate_acc = gate_acc + tl.sum(a_vals * w_gate)
            up_acc = up_acc + tl.sum(a_vals * w_up)
        # SiLU(gate) * up, with SiLU(z) = z * sigmoid(z).
        act_val = gate_acc * tl.sigmoid(gate_acc) * up_acc
        tl.store(inter_row + n, act_val)

    # The scratch row is written element by element above and read back as
    # vectors below; synchronize the program so every write is visible first.
    tl.debug_barrier()

    # ---------- Down projection: out[k] = sum_n( inter[n] * W2[exp, k, n] ) ----------
    for k in range(K):
        acc = 0.0
        for nb in range(tl.cdiv(N, 32)):
            n_offs = nb * 32 + tl.arange(0, 32)
            n_mask = n_offs < N
            inter_vals = tl.load(inter_row + n_offs, mask=n_mask, other=0.0).to(
                tl.float32
            )
            w_down = tl.load(
                down_expert + k * stride_down_k + n_offs * stride_down_n,
                mask=n_mask,
                other=0.0,
            ).to(tl.float32)
            acc = acc + tl.sum(inter_vals * w_down)
        # Several dispatch entries can target the same token row, so the update
        # must be atomic; the value is cast to the output buffer's element type.
        result = (acc * weight).to(C.dtype.element_ty)
        tl.atomic_add(C + token_id * stride_cm + k * stride_cn, result)

371 

372 

373# ============================================================================ 

374# Helper Functions 

375# ============================================================================ 

376 

377 

def get_num_experts(shape_desc: str) -> int:
    """Map a model/shape description string to an expert count.

    The lookup is substring based:
    - any description containing "Qwen" or "Mixtral" yields 8;
    - otherwise a description containing "Switch" yields 64;
    - everything else falls back to 8.
    """
    default_experts = 8

    # "Qwen" and "Mixtral" take precedence over "Switch" when both appear.
    if "Qwen" in shape_desc or "Mixtral" in shape_desc:
        return 8
    if "Switch" in shape_desc:
        return 64
    return default_experts

396 

397 

def prepare_moe_inputs(
    x: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
    """
    Build weight-ordered dispatch arrays for the fused MoE kernel.

    Args:
        x: Input tensor; only ``x.shape[0]`` (the token count) is used.
        topk_weights: Routing weights, shape (num_tokens, topk).
        topk_ids: Expert indices, shape (num_tokens, topk).
        num_experts: Total number of experts (currently unused).

    Returns:
        sorted_token_ids: Permutation of the flattened (token, slot) dispatch
            indices that orders entries by descending routing weight. Entry
            ``i`` refers to flat position ``token * topk + slot``, not to a
            token row; divide by ``topk`` to recover the token index.
        expert_ids: Expert index of each dispatch entry, in the same order.
        num_tokens_post_padded: ``num_tokens * topk`` rounded up to a multiple
            of ``block_size_m`` (a Python int).
        block_size_m: Padding granularity used above (a Python int, 32).
    """
    num_tokens = x.shape[0]
    topk = topk_ids.shape[1]

    # reshape (rather than view) also accepts non-contiguous routing tensors,
    # e.g. slices or transposes of a larger router output.
    flat_topk_weights = topk_weights.reshape(-1)
    flat_topk_ids = topk_ids.reshape(-1)

    # A stable sort keeps equal-weight entries in their original order so the
    # dispatch layout is deterministic from run to run.
    _, sorted_token_ids = torch.sort(
        flat_topk_weights, dim=0, descending=True, stable=True
    )

    # Expert assignment of each entry, permuted identically.
    expert_ids = flat_topk_ids[sorted_token_ids]

    # Round the number of dispatch entries up to the block size.
    block_size_m = 32
    num_entries = num_tokens * topk
    num_tokens_post_padded = (
        (num_entries + block_size_m - 1) // block_size_m
    ) * block_size_m

    return sorted_token_ids, expert_ids, num_tokens_post_padded, block_size_m

439 

440 

def quantize_weights_moe(
    weights: torch.Tensor,
    num_experts: int,
    quant_config: "QuantConfig",
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Group-quantize MoE expert weights along the input-feature dimension.

    The produced codes satisfy ``w ~= (q - zp) * scale``:

    * ``has_zero_point=True``: asymmetric min/max quantization; the returned
      zero point is ``-min / scale`` per group (kept as a float tensor).
    * ``has_zero_point=False``: symmetric quantization around the midpoint of
      the code range, so ``zp`` is implicitly 8 (INT4) or 128 (8-bit) and no
      zero-point tensor is returned.

    Args:
        weights: Expert weights of shape (num_experts, out_features, in_features).
        num_experts: Number of experts. Kept for interface compatibility; the
            expert count is taken from ``weights.shape[0]``.
        quant_config: Quantization configuration (``mode``, ``group_size``,
            ``has_zero_point``, ``use_int4``).

    Returns:
        W_q: For FP16 mode the input tensor unchanged. Otherwise a uint8 tensor
            of shape (E, out_features, in_features), or
            (E, out_features, in_features // 2) for INT4 where two codes share a
            byte (even index in the low nibble, odd index in the high nibble).
        scales: Per-group scales of shape (E, out_features, num_groups) in the
            dtype of ``weights``; ``None`` for FP16 mode.
        zeros: Per-group zero points with the same shape and dtype as
            ``scales``; ``None`` for FP16 mode or when ``has_zero_point`` is False.

    Raises:
        ValueError: If ``weights`` is not 3-D, if ``in_features`` is not a
            positive multiple of ``group_size``, or if INT4 packing is requested
            with an odd ``group_size``.
    """
    # FP16 mode is a pass-through: nothing to quantize.
    if quant_config.mode.value == "fp16":
        return weights, None, None

    if weights.dim() != 3:
        raise ValueError(
            "weights must have shape (num_experts, out_features, in_features), "
            f"got {tuple(weights.shape)}"
        )

    n_experts, n_out, k_in = weights.shape
    group_size = quant_config.group_size
    if group_size <= 0 or k_in % group_size != 0:
        raise ValueError(
            f"in_features ({k_in}) must be a positive multiple of "
            f"group_size ({group_size})"
        )

    use_int4 = quant_config.use_int4
    if use_int4 and group_size % 2 != 0:
        raise ValueError(
            f"group_size ({group_size}) must be even to pack two INT4 values per byte"
        )

    w_bits = 4 if use_int4 else 8
    q_max = 2**w_bits - 1
    num_groups = k_in // group_size
    has_zp = quant_config.has_zero_point

    # (E, n_out, k_in) -> (E, n_out, num_groups, group_size); the statistics are
    # computed in float32 so half-precision inputs do not lose range.
    grouped = weights.reshape(n_experts, n_out, num_groups, group_size).to(
        torch.float32
    )

    if has_zp:
        w_min = grouped.amin(dim=-1, keepdim=True)
        w_max = grouped.amax(dim=-1, keepdim=True)
        raw_scale = (w_max - w_min) / q_max
    else:
        q_mid = 2 ** (w_bits - 1)
        abs_max = grouped.abs().amax(dim=-1, keepdim=True)
        raw_scale = abs_max / (q_mid - 1)

    # Quantize with exactly the scale that will be stored (rounded to the weight
    # dtype); constant groups get a unit scale to avoid dividing by zero.
    scale = raw_scale.to(weights.dtype).to(torch.float32)
    scale = torch.where(scale > 0, scale, torch.ones_like(scale))

    if has_zp:
        levels = torch.round((grouped - w_min) / scale)
        zeros = (-w_min / scale).squeeze(-1).to(weights.dtype)
    else:
        levels = torch.round(grouped / scale) + q_mid
        zeros = None

    W_q = levels.clamp(0, q_max).to(torch.uint8)

    if use_int4:
        # Pack neighbouring codes: even index -> low nibble, odd -> high nibble.
        pairs = W_q.view(n_experts, n_out, num_groups, group_size // 2, 2)
        W_q = (pairs[..., 0] & 0xF) | (pairs[..., 1] << 4)

    W_q = W_q.reshape(n_experts, n_out, -1)
    scales = scale.squeeze(-1).to(weights.dtype)

    return W_q, scales, zeros

503 

504 

def get_default_config(block_size_m=1, block_size_n=128, block_size_k=64):
    """Return the kernel tile sizes as a dict keyed by the kernel's meta-parameter names."""
    return dict(
        BLOCK_SIZE_M=block_size_m,
        BLOCK_SIZE_N=block_size_n,
        BLOCK_SIZE_K=block_size_k,
    )

512 

513 

def get_autotune_config():
    """Return the candidate ``triton.Config`` list for autotuning the MoE kernel.

    Two tilings are offered, differing only in BLOCK_SIZE_N (128, then 64);
    both use BLOCK_SIZE_K=64 with two pipeline stages and four warps.
    """
    column_tiles = (128, 64)
    return [
        triton.Config(
            {"BLOCK_SIZE_N": tile_n, "BLOCK_SIZE_K": 64}, num_stages=2, num_warps=4
        )
        for tile_n in column_tiles
    ]

524 

525 

526# ============================================================================ 

527# Kernel Invocation 

528# ============================================================================ 

529 

# Module-level placeholder initialised to None. No code visible in this file
# reads or reassigns it. NOTE(review): confirm whether anything outside this
# file relies on it before removing.
_fp16_intermediate_buf = None

531 

532 

def _launch_quantized_gemm(
    a,
    w_q,
    w_scales,
    w_zeros,
    out,
    row_ids,
    expert_ids,
    routing_weights,
    num_tokens_post_padded,
    top_k,
    quant_config,
    mul_routed_weight,
):
    """Accumulate ``a[row] @ dequant(w_q[expert]).T`` into ``out[row]`` per dispatch entry.

    The kernel is launched once per tile of output columns, each launch seeing
    only the matching column slice of the weights, scales, zero points and
    output. Every launch therefore handles at most one BLOCK_SIZE_N tile, so
    the complete output width is covered regardless of how many column tiles a
    single kernel launch processes.
    """
    if w_scales is None:
        raise ValueError("Quantized MoE path requires weight scales")

    n_entries = row_ids.shape[0]
    n_out = w_q.shape[1]
    k_in = a.shape[1]
    if n_entries == 0 or n_out == 0:
        return

    # tl.arange needs power-of-two extents; masks inside the kernel trim the excess.
    block_n = triton.next_power_of_2(min(128, n_out))
    block_k = triton.next_power_of_2(max(1, min(64, k_in)))

    # Only ask the kernel to read zero points when a tensor actually exists.
    use_zp = bool(quant_config.has_zero_point) and w_zeros is not None
    empty_zp = a.new_tensor([])

    for n_start in range(0, n_out, block_n):
        n_stop = min(n_start + block_n, n_out)
        w_tile = w_q[:, n_start:n_stop]
        scale_tile = w_scales[:, n_start:n_stop]
        zp_tile = w_zeros[:, n_start:n_stop] if use_zp else empty_zp
        out_tile = out[:, n_start:n_stop]

        fused_moe_kernel_gptq_awq[(n_entries,)](
            a,
            w_tile,
            out_tile,
            scale_tile,
            zp_tile,
            routing_weights,
            row_ids,
            expert_ids,
            num_tokens_post_padded,
            N=n_stop - n_start,
            K=k_in,
            EM=n_entries,
            num_valid_tokens=n_entries,
            stride_am=a.stride(0),
            stride_ak=a.stride(1),
            stride_be=w_tile.stride(0),
            stride_bk=w_tile.stride(2),
            stride_bn=w_tile.stride(1),
            stride_cm=out_tile.stride(0),
            stride_cn=out_tile.stride(1),
            stride_bse=scale_tile.stride(0),
            stride_bsk=scale_tile.stride(2),
            stride_bsn=scale_tile.stride(1),
            stride_bze=zp_tile.stride(0) if use_zp else 0,
            stride_bzk=zp_tile.stride(2) if use_zp else 0,
            stride_bzn=zp_tile.stride(1) if use_zp else 0,
            group_size=quant_config.group_size,
            BLOCK_SIZE_M=1,
            BLOCK_SIZE_N=block_n,
            BLOCK_SIZE_K=block_k,
            GROUP_SIZE_M=1,
            MUL_ROUTED_WEIGHT=mul_routed_weight,
            top_k=top_k,
            compute_type=tl.float16,
            has_zp=use_zp,
            use_int4_w4a16=quant_config.use_int4,
            use_int8_w8a16=quant_config.use_int8,
            even_Ks=(k_in % block_k) == 0,
            filter_expert=False,
        )


def invoke_fused_moe(
    x: torch.Tensor,
    W1_q: torch.Tensor,
    W2_q: torch.Tensor,
    W3_q: Optional[torch.Tensor],
    output: torch.Tensor,
    W1_scales: torch.Tensor,
    W1_zeros: Optional[torch.Tensor],
    W2_scales: torch.Tensor,
    W2_zeros: Optional[torch.Tensor],
    W3_scales: Optional[torch.Tensor],
    W3_zeros: Optional[torch.Tensor],
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    topk_weights: torch.Tensor,
    top_k: int,
    quant_config: Any,
    block_shape: List[int],
) -> None:
    """
    Run the MoE computation for a prepared dispatch list, writing into ``output``.

    ``output`` is zeroed and then filled in place. Four paths exist:

    * FP16 with ``W2_q``: dedicated SwiGLU kernel (gate = W1, up = W3 or W1
      when W3 is missing, down = W2).
    * FP16 without ``W2_q``: routing-weighted W1 projection computed with
      torch matmul per expert.
    * Quantized without ``W2_q``: routing-weighted W1 projection through
      ``fused_moe_kernel_gptq_awq``.
    * Quantized with ``W2_q``: SwiGLU assembled from three quantized
      projections (gate, up, down) with the activation applied in between;
      as in the FP16 path, W1 doubles as the up projection when W3 is missing.

    Args:
        x: Activations, shape (num_tokens, hidden_dim).
        W1_q, W2_q, W3_q: Expert weights (quantized codes, or raw weights in
            FP16 mode); ``W2_q`` / ``W3_q`` may be None.
        output: Destination buffer; (num_tokens, inter_dim) for W1-only,
            (num_tokens, hidden_dim) for SwiGLU.
        W*_scales, W*_zeros: Per-group quantization parameters.
        sorted_token_ids: Token row of each dispatch entry.
        expert_ids: Expert index of each dispatch entry.
        num_tokens_post_padded: Forwarded to the kernels, which do not read it.
        topk_weights: Routing weight of each dispatch entry (flattened if needed).
        top_k: Forwarded to the kernels as a compile-time constant.
        quant_config: Object exposing ``mode``, ``group_size``,
            ``has_zero_point``, ``use_int4`` and ``use_int8``.
        block_shape: Currently unused.

    Raises:
        ValueError: If a quantized path is requested without scales, or if the
            SwiGLU output width does not match ``W2_q``.
    """
    hidden_dim = x.shape[1]
    inter_dim = W1_q.shape[1]
    num_valid_tokens = sorted_token_ids.shape[0]

    if topk_weights.dim() > 1:
        topk_weights = topk_weights.reshape(-1)

    if not x.is_contiguous():
        x = x.contiguous()

    output.zero_()

    # Nothing to dispatch: leave the zeroed output and avoid empty launches.
    if num_valid_tokens == 0:
        return

    is_fp16 = quant_config.mode.value == "fp16"

    # ---- FP16 SwiGLU: gate(W1) * up(W3), then W2 @ act, in one kernel ----
    if is_fp16 and W2_q is not None:
        # float32 scratch keeps the staged activation at accumulator precision.
        inter_buf = torch.empty(
            num_valid_tokens * inter_dim, dtype=torch.float32, device=x.device
        )
        up_weights = W3_q if W3_q is not None else W1_q  # use W1 if W3 missing
        block_k = triton.next_power_of_2(max(1, min(64, hidden_dim)))

        fused_moe_kernel_fp16_swiglu[(num_valid_tokens,)](
            x,
            output,
            W1_q,  # gate
            up_weights,  # up
            W2_q,  # down
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            inter_buf,
            N=inter_dim,
            K=hidden_dim,
            EM=num_valid_tokens,
            num_valid_tokens=num_valid_tokens,
            stride_am=x.stride(0),
            stride_ak=x.stride(1),
            stride_bn=W1_q.stride(1),
            stride_bk=W1_q.stride(2),
            stride_cm=output.stride(0),
            stride_cn=output.stride(1),
            stride_gate_e=W1_q.stride(0),
            stride_up_e=up_weights.stride(0),
            stride_down_e=W2_q.stride(0),
            stride_gate_n=W1_q.stride(1),
            stride_gate_k=W1_q.stride(2),
            stride_up_n=up_weights.stride(1),
            stride_up_k=up_weights.stride(2),
            stride_down_k=W2_q.stride(1),
            stride_down_n=W2_q.stride(2),
            stride_inter_m=inter_dim,
            BLOCK_SIZE_K=block_k,
            top_k=top_k,
            even_Ks=(hidden_dim % block_k) == 0,
        )
        return

    # ---- FP16 W1-only: batched matmul per expert, accumulated per token ----
    if is_fp16:
        for expert in range(W1_q.shape[0]):
            entry_idx = (expert_ids == expert).nonzero(as_tuple=True)[0]
            if entry_idx.numel() == 0:
                continue

            token_rows = sorted_token_ids[entry_idx]
            # [num_selections, k_in] @ [k_in, n_out] -> [num_selections, n_out]
            projected = torch.matmul(x[token_rows], W1_q[expert].t())
            weighted = projected * topk_weights[entry_idx].unsqueeze(1)
            # index_add_ requires matching dtypes; routing weights may be wider.
            output.index_add_(0, token_rows, weighted.to(output.dtype))
        return

    # ---- Quantized W1-only projection ----
    if W2_q is None:
        _launch_quantized_gemm(
            x,
            W1_q,
            W1_scales,
            W1_zeros,
            output,
            sorted_token_ids,
            expert_ids,
            topk_weights,
            num_tokens_post_padded,
            top_k,
            quant_config,
            True,
        )
        return

    # ---- Quantized SwiGLU: gate / up / activation / down ----
    if output.shape[1] != W2_q.shape[1]:
        raise ValueError(
            f"output width ({output.shape[1]}) must match W2 out_features "
            f"({W2_q.shape[1]})"
        )

    # Work per dispatch entry: row i of every buffer belongs to entry i, so the
    # kernel's row ids are simply 0..num_valid_tokens-1 over gathered inputs.
    entry_rows = torch.arange(num_valid_tokens, device=x.device, dtype=torch.int64)
    x_entries = x[sorted_token_ids]

    gate = torch.zeros(num_valid_tokens, inter_dim, dtype=x.dtype, device=x.device)
    _launch_quantized_gemm(
        x_entries,
        W1_q,
        W1_scales,
        W1_zeros,
        gate,
        entry_rows,
        expert_ids,
        topk_weights,
        num_tokens_post_padded,
        top_k,
        quant_config,
        False,
    )

    if W3_q is not None:
        up = torch.zeros(num_valid_tokens, inter_dim, dtype=x.dtype, device=x.device)
        _launch_quantized_gemm(
            x_entries,
            W3_q,
            W3_scales,
            W3_zeros,
            up,
            entry_rows,
            expert_ids,
            topk_weights,
            num_tokens_post_padded,
            top_k,
            quant_config,
            False,
        )
    else:
        up = gate  # mirror the FP16 path: W1 doubles as the up projection

    # silu(gate) * up, evaluated in float32 before returning to the activation dtype.
    activated = (torch.nn.functional.silu(gate.float()) * up.float()).to(x.dtype)

    down = torch.zeros(
        num_valid_tokens, W2_q.shape[1], dtype=output.dtype, device=x.device
    )
    _launch_quantized_gemm(
        activated,
        W2_q,
        W2_scales,
        W2_zeros,
        down,
        entry_rows,
        expert_ids,
        topk_weights,
        num_tokens_post_padded,
        top_k,
        quant_config,
        True,  # routing weight is applied once, on the final projection
    )

    # Sum the per-entry results back onto their token rows.
    output.index_add_(0, sorted_token_ids, down)

762 

763 

764# ============================================================================ 

765# Main fused_moe Function 

766# ============================================================================ 

767 

768 

def _resolve_expert_weights(w, w_q, w_scales, w_zeros, num_experts, quant_config):
    """Pick pre-quantized tensors when supplied, else quantize ``w``, else return Nones."""
    if w_q is not None and w_scales is not None:
        zeros = w_zeros.contiguous() if w_zeros is not None else None
        return w_q.contiguous(), w_scales.contiguous(), zeros
    if w is not None:
        return quantize_weights_moe(w, num_experts, quant_config)
    return None, None, None


def fused_moe(
    x: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w3: Optional[torch.Tensor] = None,
    topk_weights: Optional[torch.Tensor] = None,
    topk_ids: Optional[torch.Tensor] = None,
    quant_config: QuantConfig = None,
    num_experts: int = 8,
    top_k: int = 2,
    block_shape: Optional[List[int]] = None,
    # Optional pre-quantized weights (from benchmark)
    w1_q: Optional[torch.Tensor] = None,
    w1_scales: Optional[torch.Tensor] = None,
    w1_zeros: Optional[torch.Tensor] = None,
    w2_q: Optional[torch.Tensor] = None,
    w2_scales: Optional[torch.Tensor] = None,
    w2_zeros: Optional[torch.Tensor] = None,
    w3_q: Optional[torch.Tensor] = None,
    w3_scales: Optional[torch.Tensor] = None,
    w3_zeros: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Fused Mixture of Experts computation with quantization support.

    Computes ``y = sum_i(topk_weights_i * FFN_{topk_ids_i}(x))``. When W2 is
    available the expert FFN is SwiGLU, ``W2 @ (silu(W1 @ x) * (W3 @ x))``;
    when neither ``w2`` nor ``w2_q`` is given only the W1 projection is applied.

    Args:
        x: Input of shape (..., hidden_dim); leading dimensions are flattened
            into tokens.
        w1: W1 weights, shape (num_experts, inter_dim, hidden_dim).
        w2: W2 weights, shape (num_experts, hidden_dim, inter_dim), or None for
            a W1-only projection.
        w3: Optional W3 weights with the same layout as ``w1``.
        topk_weights: Routing weights, one per (token, selected expert).
        topk_ids: Expert indices with the same shape as ``topk_weights``. When
            given, the number of experts per token is taken from its last
            dimension rather than from ``top_k``.
        quant_config: Quantization configuration; defaults to ``QuantConfig()``.
        num_experts: Number of experts (used for the dummy routing and passed
            to the quantizer).
        top_k: Experts per token; only used when routing is not supplied.
        block_shape: Forwarded to ``invoke_fused_moe``; defaults to [128, 128].
        w1_q, w1_scales, w1_zeros: Pre-quantized W1 (skips quantization when
            both ``w1_q`` and ``w1_scales`` are given).
        w2_q, w2_scales, w2_zeros: Pre-quantized W2.
        w3_q, w3_scales, w3_zeros: Pre-quantized W3.

    Returns:
        Tensor with the leading dimensions of ``x``. The last dimension is
        ``hidden_dim`` for SwiGLU and ``inter_dim`` for a W1-only projection.

    Raises:
        ValueError: If neither ``w1`` nor ``w1_q`` (with scales) is provided,
            or if the routing tensors do not describe the same number of
            entries per token.
    """
    if quant_config is None:
        quant_config = QuantConfig()

    # Flatten any leading dimensions into a token axis; reshape (not view) so
    # non-contiguous inputs are accepted too.
    original_shape = x.shape
    if x.dim() > 2:
        x = x.reshape(-1, x.shape[-1])  # (B*S, H)

    num_tokens = x.shape[0]

    # Routing information.
    if topk_weights is None or topk_ids is None:
        # Dummy routing for testing: uniform weights, random experts.
        topk_weights = (
            torch.ones(num_tokens, top_k, device=x.device, dtype=x.dtype) / top_k
        )
        topk_ids = torch.randint(0, num_experts, (num_tokens, top_k), device=x.device)
    else:
        # Trust the routing tensors over the keyword so the two cannot disagree.
        top_k = topk_ids.shape[-1]

    # One dispatch entry per (token, selected expert):
    #   experts [e0_0, e0_1, ..., e1_0, ...], weights [w0_0, w0_1, ..., w1_0, ...]
    flat_expert_ids = topk_ids.reshape(-1)
    flat_weights = topk_weights.reshape(-1)
    if (
        flat_expert_ids.numel() != num_tokens * top_k
        or flat_weights.numel() != flat_expert_ids.numel()
    ):
        raise ValueError(
            "topk_weights and topk_ids must both hold num_tokens * top_k entries "
            f"(tokens={num_tokens}, top_k={top_k}, "
            f"ids={flat_expert_ids.numel()}, weights={flat_weights.numel()})"
        )

    # Token row of each entry: [0, 0, 1, 1, ...] with each token repeated top_k times.
    token_rows = torch.arange(
        num_tokens, device=x.device, dtype=torch.int64
    ).repeat_interleave(top_k)

    # Group entries of the same expert together so consecutive programs reuse
    # the same expert weights; the stable sort keeps the order deterministic.
    order = torch.sort(flat_expert_ids, dim=0, stable=True).indices
    sorted_token_ids = token_rows[order]
    sorted_expert_ids = flat_expert_ids[order]
    sorted_weights = flat_weights[order]

    # Number of dispatch entries rounded up to the block size.
    block_size_m = 32
    num_tokens_post_padded = (
        (num_tokens * top_k + block_size_m - 1) // block_size_m
    ) * block_size_m

    # Resolve weights: pre-quantized tensors win, otherwise quantize on the fly.
    W1_q, W1_scales, W1_zeros = _resolve_expert_weights(
        w1, w1_q, w1_scales, w1_zeros, num_experts, quant_config
    )
    if W1_q is None:
        raise ValueError("Either w1 or w1_q must be provided")

    # W2 missing means a W1-only projection.
    W2_q, W2_scales, W2_zeros = _resolve_expert_weights(
        w2, w2_q, w2_scales, w2_zeros, num_experts, quant_config
    )
    W3_q, W3_scales, W3_zeros = _resolve_expert_weights(
        w3, w3_q, w3_scales, w3_zeros, num_experts, quant_config
    )

    # Output: (num_tokens, inter_dim) for W1-only, same as x for SwiGLU.
    if W2_q is None:
        output = torch.zeros(
            num_tokens, W1_q.shape[1], dtype=x.dtype, device=x.device
        )
    else:
        output = torch.zeros_like(x)

    if block_shape is None:
        block_shape = [128, 128]

    invoke_fused_moe(
        x,
        W1_q,
        W2_q,
        W3_q,
        output,
        W1_scales,
        W1_zeros,
        W2_scales,
        W2_zeros,
        W3_scales,
        W3_zeros,
        sorted_token_ids,
        sorted_expert_ids,
        num_tokens_post_padded,
        sorted_weights,
        top_k,
        quant_config,
        block_shape,
    )

    # Restore the leading dimensions; the feature dimension comes from the
    # output itself because a W1-only projection changes its width.
    if len(original_shape) > 2:
        output = output.reshape(*original_shape[:-1], output.shape[-1])

    return output

946 

947 

948# ============================================================================ 

949# FusedMoELinear Module 

950# ============================================================================ 

951 

952 

class FusedMoELinear(torch.nn.Module):
    """
    Fused MoE Linear layer with quantization support.

    Holds the three SwiGLU expert weight tensors and forwards to ``fused_moe``.
    Quantized copies of the weights are built lazily on the first forward pass
    and reused afterwards.
    """

    def __init__(
        self,
        hidden_dim: int,
        inter_dim: int,
        num_experts: int = 8,
        top_k: int = 2,
        quant_config: Optional["QuantConfig"] = None,
        bias: bool = False,
    ):
        """
        Args:
            hidden_dim: Model (input/output) feature size.
            inter_dim: Expert intermediate feature size.
            num_experts: Number of experts.
            top_k: Experts selected per token.
            quant_config: Quantization configuration; a default ``QuantConfig``
                is created when omitted.
            bias: Accepted for interface compatibility; no bias is created.
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        self.inter_dim = inter_dim
        self.num_experts = num_experts
        self.top_k = top_k
        self.quant_config = quant_config if quant_config is not None else QuantConfig()

        # SwiGLU MoE weights. requires_grad must be set on the Parameter itself:
        # Parameter() defaults to requires_grad=True whatever the wrapped tensor says.
        self.w1 = torch.nn.Parameter(
            torch.randn(num_experts, inter_dim, hidden_dim), requires_grad=False
        )
        self.w3 = torch.nn.Parameter(
            torch.randn(num_experts, inter_dim, hidden_dim), requires_grad=False
        )
        self.w2 = torch.nn.Parameter(
            torch.randn(num_experts, hidden_dim, inter_dim), requires_grad=False
        )

        self._packed = False
        # (device, dtype) of w1 at the time of the last pack(), used to detect
        # stale packed tensors after the module has been moved or cast.
        self._packed_key = None

    def pack(self):
        """Prepare weights for quantized computation."""
        self.W1_q, self.W1_scales, self.W1_zeros = quantize_weights_moe(
            self.w1.data, self.num_experts, self.quant_config
        )
        self.W3_q, self.W3_scales, self.W3_zeros = quantize_weights_moe(
            self.w3.data, self.num_experts, self.quant_config
        )
        self.W2_q, self.W2_scales, self.W2_zeros = quantize_weights_moe(
            self.w2.data, self.num_experts, self.quant_config
        )
        self._packed_key = (self.w1.device, self.w1.dtype)
        self._packed = True

    def forward(
        self,
        x: torch.Tensor,
        topk_weights: Optional[torch.Tensor] = None,
        topk_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass for MoE.

        Args:
            x: Input tensor (B, S, H) or (T, H)
            topk_weights: Expert weights (B, S, K) or (T, K)
            topk_ids: Expert indices (B, S, K) or (T, K)

        Returns:
            Output tensor same shape as x
        """
        # Re-pack when never packed, or when the parameters have since changed
        # device or dtype (the packed tensors are plain attributes and do not
        # follow .to()/.half()).
        current_key = (self.w1.device, self.w1.dtype)
        if not self._packed or getattr(self, "_packed_key", None) != current_key:
            self.pack()

        # Hand over the packed tensors so the weights are not re-quantized on
        # every call; the raw weights remain the fallback (e.g. FP16 mode,
        # where no scales exist).
        return fused_moe(
            x,
            self.w1,
            self.w2,
            self.w3,
            topk_weights,
            topk_ids,
            self.quant_config,
            self.num_experts,
            self.top_k,
            w1_q=self.W1_q,
            w1_scales=self.W1_scales,
            w1_zeros=self.W1_zeros,
            w2_q=self.W2_q,
            w2_scales=self.W2_scales,
            w2_zeros=self.W2_zeros,
            w3_q=self.W3_q,
            w3_scales=self.W3_scales,
            w3_zeros=self.W3_zeros,
        )

    def set_weights(self, w1: torch.Tensor, w3: torch.Tensor, w2: torch.Tensor):
        """Set weights from external source (e.g., model loading).

        Note the argument order (w1, w3, w2). Packed tensors are invalidated
        and rebuilt on the next forward pass.
        """
        self.w1.data = w1
        self.w3.data = w3
        self.w2.data = w2
        self._packed = False

1041 

1042 

1043# ============================================================================ 

1044# Exports 

1045# ============================================================================ 

1046 

# Names exported by `from <module> import *`: the public entry point, both
# Triton kernels, the launcher, the nn.Module wrapper, the configuration types
# and the helper functions defined above.
__all__ = [
    "fused_moe",
    "fused_moe_kernel_gptq_awq",
    "fused_moe_kernel_fp16_swiglu",
    "invoke_fused_moe",
    "FusedMoELinear",
    "QuantConfig",
    "QuantMode",
    "quantize_weights_moe",
    "prepare_moe_inputs",
    "get_num_experts",
    "get_default_config",
    "get_autotune_config",
]