Coverage for src/flag_gems/fused/fused_marlin_moe.py: 14%

263 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1# SPDX-License-Identifier: Apache-2.0 

2""" 

3Fused Marlin MoE — v7: Tunable BLOCK_SIZE_K for improved pipelining. 

4 

5Key changes from v6: 

6- BLOCK_SIZE_K is now an autotune parameter (32, 64, 128) instead of fixed at 

7 group_size=128. Smaller K tiles enable more software pipeline stages and 

8 reduce register pressure, improving memory latency hiding for bandwidth-bound 

9 small batch sizes. 

10- GROUP_SIZE_K constexpr correctly indexes scales when BLOCK_SIZE_K < group_size. 

11 Math: accumulating partial sums within a scale group gives identical results. 

12- Transposed B layout [E, K//2, N] from v6 is preserved for coalesced N-loads. 

13- Two-pass GEMM1 (gate/up) with fused SiLU preserved from v6. 

14""" 

15 

16from typing import Any, Callable, Optional 

17 

18import torch 

19import triton 

20import triton.language as tl 

21 

22from flag_gems.fused.fused_moe import write_zeros_to_output 

23from flag_gems.fused.moe_align_block_size import moe_align_block_size 

24from flag_gems.fused.moe_sum import moe_sum 

25 

# Identifiers for the weight formats that ``fused_marlin_moe`` recognizes via
# its ``quant_type_id`` argument.
QUANT_TYPE_UINT4B8 = 0
QUANT_TYPE_UINT8B128 = 1

# Groupings of the identifiers above by weight bit-width.
_QUANT_TYPE_INT4 = set((QUANT_TYPE_UINT4B8,))
_QUANT_TYPE_INT8 = set((QUANT_TYPE_UINT8B128,))

# Every identifier accepted by the up-front quant_type_id check.
_SUPPORTED_QUANT_TYPES = _QUANT_TYPE_INT4.union(_QUANT_TYPE_INT8)

31 

32 

# ---------- Transpose cache ----------
#
# Each cache maps a key describing the *source* tensor (device, dtype, data
# pointer, shape, strides) to ``(weakref_to_source, source_version, transposed)``.
#
# * The weak reference evicts the entry as soon as the source tensor object is
#   garbage-collected. Transposed copies of dead weights are therefore not kept
#   alive forever. A new tensor that the allocator places at a recycled address
#   can also never be handed a previous tensor's stale transpose.
# * The recorded ``_version`` (bumped by in-place writes) invalidates an entry
#   whose source was modified in place after it was cached.

_B_CACHE: dict = {}
_SCALE_CACHE: dict = {}


def _cached_transpose_12(t: torch.Tensor, cache: dict) -> torch.Tensor:
    """Return ``t.transpose(1, 2).contiguous()``, memoized in *cache*.

    A cached result is reused only while the tensor it was computed from is
    still alive and has not been written in place since. While the source is
    alive its storage cannot be reallocated, so any tensor with the same
    device/dtype/pointer/shape/strides views exactly the same memory.

    Args:
        t: tensor with at least three dimensions; dims 1 and 2 are swapped.
        cache: the module-level dict used to memoize results for this kind
            of tensor.

    Returns:
        A contiguous tensor equal to ``t.transpose(1, 2)``.
    """
    import weakref

    key = (t.device, t.dtype, t.data_ptr(), tuple(t.shape), tuple(t.stride()))
    entry = cache.get(key)
    if entry is not None:
        src_ref, version, transposed = entry
        src = src_ref()
        if src is not None and src._version == version:
            return transposed

    transposed = t.transpose(1, 2).contiguous()

    def _evict(ref, key=key, cache=cache):
        # Only drop the entry if it still belongs to this (now dead) source.
        # A newer tensor may have been cached under the same key meanwhile.
        current = cache.get(key)
        if current is not None and current[0] is ref:
            del cache[key]

    cache[key] = (weakref.ref(t, _evict), t._version, transposed)
    return transposed


def _transpose_b(b: torch.Tensor) -> torch.Tensor:
    """Transpose B from [E, N, K//2] to [E, K//2, N] for coalesced N-loads (cached)."""
    return _cached_transpose_12(b, _B_CACHE)


def _transpose_scale(s: torch.Tensor) -> torch.Tensor:
    """Transpose scale from [E, N, K//gs] to [E, K//gs, N] for coalesced loads (cached)."""
    return _cached_transpose_12(s, _SCALE_CACHE)

59 

60 

# ---------- Autotune configs ----------

_AUTOTUNE_CONFIGS = [
    # BLOCK_SIZE_K=128: compute-bound regime (large M)
    triton.Config(
        {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1},
        num_warps=4,
        num_stages=4,
    ),
    triton.Config(
        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1},
        num_warps=4,
        num_stages=3,
    ),
    triton.Config(
        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4},
        num_warps=8,
        num_stages=3,
    ),
    triton.Config(
        {"BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4},
        num_warps=8,
        num_stages=2,
    ),
    # BLOCK_SIZE_K=64: balanced pipelining, reduced register pressure
    triton.Config(
        {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1},
        num_warps=4,
        num_stages=5,
    ),
    triton.Config(
        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1},
        num_warps=4,
        num_stages=5,
    ),
    triton.Config(
        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4},
        num_warps=8,
        num_stages=4,
    ),
    # BLOCK_SIZE_K=32: max pipelining for bandwidth-bound small batches
    triton.Config(
        {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1},
        num_warps=4,
        num_stages=8,
    ),
    triton.Config(
        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1},
        num_warps=4,
        num_stages=8,
    ),
    triton.Config(
        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 4},
        num_warps=8,
        num_stages=6,
    ),
]

# Autotune cache key.
# * GROUP_SIZE_K is included because the set of legal BLOCK_SIZE_K values
#   depends on it (see _prune_configs_for_group_size). A config tuned for
#   group_size=128 must not be replayed for group_size=64 with the same N/K.
# * BLOCK_SIZE_M (picked per batch by _select_block_m) is included so that
#   small- and large-batch launches are tuned independently, rather than
#   whichever batch size runs first deciding the config forever.
_AUTOTUNE_KEY = ["N", "K", "GROUP_SIZE_K", "BLOCK_SIZE_M"]


def _prune_configs_for_group_size(configs, named_args, **kwargs):
    """Drop autotune configs whose K tile is incompatible with the launch.

    Both kernels load a single scale row per K tile
    (``scale_idx = k * BLOCK_SIZE_K // GROUP_SIZE_K``). That is only correct
    when every tile lies inside one quantization group, i.e. when
    ``GROUP_SIZE_K % BLOCK_SIZE_K == 0``. The A and B loads are also unmasked
    along K, so ``K % BLOCK_SIZE_K == 0`` is required to stay in bounds.

    Args:
        configs: candidate ``triton.Config`` objects.
        named_args: positional kernel arguments keyed by parameter name.
        **kwargs: keyword kernel arguments (the constexprs passed by name).

    Returns:
        The subset of *configs* that is valid for this launch.

    Raises:
        ValueError: if no config is valid for the given group size and K.
    """
    group_size = kwargs.get("GROUP_SIZE_K", named_args.get("GROUP_SIZE_K"))
    k_dim = named_args.get("K", kwargs.get("K"))
    valid = [
        cfg
        for cfg in configs
        if group_size > 0
        and group_size % cfg.kwargs["BLOCK_SIZE_K"] == 0
        and k_dim % cfg.kwargs["BLOCK_SIZE_K"] == 0
    ]
    if not valid:
        raise ValueError(
            f"No INT4 MoE autotune config is compatible with "
            f"group_size={group_size} and K={k_dim}: BLOCK_SIZE_K must divide both."
        )
    return valid


_AUTOTUNE_PRUNE = {"early_config_prune": _prune_configs_for_group_size}


def _select_block_m(M, E, top_k):
    """Pick BLOCK_SIZE_M from the average number of routed rows per expert."""
    avg_tokens = max(M * top_k / max(E, 1), 1)
    if avg_tokens <= 4:
        return 16
    elif avg_tokens <= 32:
        return 32
    else:
        return 64


# ---------- GEMM1: two-pass gate/up with fused SiLU ----------
# B layout: [E, K//2, N] (transposed), stride_bk=N, stride_bn=1


@triton.autotune(
    configs=_AUTOTUNE_CONFIGS,
    key=_AUTOTUNE_KEY,
    prune_configs_by=_AUTOTUNE_PRUNE,
)
@triton.jit
def _int4_gemm_silu_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    N: tl.constexpr,
    K: tl.constexpr,
    EM,
    num_valid_tokens,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bse,
    stride_bsk,
    stride_bsn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    GROUP_SIZE_K: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
):
    """Compute ``silu(A @ W_gate) * (A @ W_up)`` for one (M, N_out) tile.

    The packed weight has ``N`` columns: the first ``N // 2`` are the gate
    projection and the last ``N // 2`` the up projection. The two halves are
    accumulated in two separate K passes, and the output has ``N // 2``
    columns. When ``MUL_ROUTED_WEIGHT`` is set, the routing weight scales both
    the gate and up accumulators *before* the activation. That is equivalent
    to scaling the input row, which is what router-weight-on-input means.
    """
    N_out = N // 2

    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N_out, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return

    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id).to(tl.int64)
    token_mask = offs_token < num_valid_tokens

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        # Expert id -1 (presumably an expert filtered out by expert_map):
        # write zeros for this tile instead of computing it.
        write_zeros_to_output(
            c_ptr,
            stride_cm,
            stride_cn,
            pid_n,
            N_out,
            offs_token,
            token_mask,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            compute_type,
        )
        return

    offs_bn_gate = (
        pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
    ) % N_out
    offs_bn_up = offs_bn_gate + N_out
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_base = a_ptr + (offs_token[:, None] // top_k * stride_am)
    b_expert_base = b_ptr + off_experts * stride_be
    # Two int4 values per byte: even k -> low nibble, odd k -> high nibble.
    b_shifter = (offs_k[:, None] % 2) * 4

    # B is transposed: [E, K//2, N], stride_bk=N (between packed K), stride_bn=1
    b_ptrs_gate = (
        b_expert_base
        + (offs_k[:, None] // 2) * stride_bk
        + offs_bn_gate[None, :] * stride_bn
    )
    b_ptrs_up = (
        b_expert_base
        + (offs_k[:, None] // 2) * stride_bk
        + offs_bn_up[None, :] * stride_bn
    )

    # Scale is transposed: [E, K//gs, N], stride_bsk=N, stride_bsn=1
    scale_base_gate = b_scale_ptr + off_experts * stride_bse + offs_bn_gate * stride_bsn
    scale_base_up = b_scale_ptr + off_experts * stride_bse + offs_bn_up * stride_bsn

    # Factored zero point: within one scale group,
    #   sum_k a_k * s * (q_k - 8) == s * (sum_k a_k * q_k - 8 * sum_k a_k).
    # This needs each K tile to lie in a single group, which the autotune prune
    # hook guarantees (GROUP_SIZE_K % BLOCK_SIZE_K == 0).

    # ---- Pass 1: Gate projection ----
    a_ptrs = a_base + offs_k[None, :] * stride_ak
    acc_gate = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0)
        b_g = ((tl.load(b_ptrs_gate) >> b_shifter) & 0xF).to(compute_type)
        raw_dot = tl.dot(a, b_g)
        row_sum = tl.sum(a.to(tl.float32), axis=1)
        scale_idx = k * BLOCK_SIZE_K // GROUP_SIZE_K
        scale_g = tl.load(scale_base_gate + scale_idx * stride_bsk).to(tl.float32)
        acc_gate += scale_g[None, :] * (raw_dot - 8.0 * row_sum[:, None])

        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs_gate += (BLOCK_SIZE_K // 2) * stride_bk

    # ---- Pass 2: Up projection ----
    a_ptrs = a_base + offs_k[None, :] * stride_ak
    acc_up = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0)
        b_u = ((tl.load(b_ptrs_up) >> b_shifter) & 0xF).to(compute_type)
        raw_dot = tl.dot(a, b_u)
        row_sum = tl.sum(a.to(tl.float32), axis=1)
        scale_idx = k * BLOCK_SIZE_K // GROUP_SIZE_K
        scale_u = tl.load(scale_base_up + scale_idx * stride_bsk).to(tl.float32)
        acc_up += scale_u[None, :] * (raw_dot - 8.0 * row_sum[:, None])

        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs_up += (BLOCK_SIZE_K // 2) * stride_bk

    # ---- Router weight on the input: scale BOTH projections pre-activation ----
    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        acc_gate = acc_gate * moe_weight[:, None]
        acc_up = acc_up * moe_weight[:, None]

    # ---- Fused SiLU: silu(gate) * up ----
    accumulator = tl.fdiv(acc_gate, (1.0 + tl.exp(-acc_gate))) * acc_up

    accumulator = accumulator.to(compute_type)

    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N_out)
    tl.store(c_ptrs, accumulator, mask=c_mask)


# ---------- GEMM2: standard INT4 GEMM with factored zero-point ----------
# B layout: [E, K//2, N] (transposed), stride_bk=N, stride_bn=1


@triton.autotune(
    configs=_AUTOTUNE_CONFIGS,
    key=_AUTOTUNE_KEY,
    prune_configs_by=_AUTOTUNE_PRUNE,
)
@triton.jit
def _int4_gemm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    N: tl.constexpr,
    K: tl.constexpr,
    EM,
    num_valid_tokens,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bse,
    stride_bsk,
    stride_bsn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    GROUP_SIZE_K: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
):
    """Compute ``A @ dequant(W)`` for one (M, N) tile of a routed INT4 GEMM.

    Optionally multiplies each output row by its routing weight
    (``MUL_ROUTED_WEIGHT``).
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return

    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id).to(tl.int64)
    token_mask = offs_token < num_valid_tokens

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        # Expert id -1 (presumably an expert filtered out by expert_map):
        # write zeros for this tile instead of computing it.
        write_zeros_to_output(
            c_ptr,
            stride_cm,
            stride_cn,
            pid_n,
            N,
            offs_token,
            token_mask,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            compute_type,
        )
        return

    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )
    # B transposed: [E, K//2, N], stride_bk=N, stride_bn=1
    b_ptrs = (
        b_ptr
        + off_experts * stride_be
        + (offs_k[:, None] // 2) * stride_bk
        + offs_bn[None, :] * stride_bn
    )
    # Two int4 values per byte: even k -> low nibble, odd k -> high nibble.
    b_shifter = (offs_k[:, None] % 2) * 4
    scale_base = b_scale_ptr + off_experts * stride_bse + offs_bn * stride_bsn

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Factored zero point per K tile: s * (a @ q - 8 * rowsum(a)). Valid because
    # the prune hook keeps each K tile inside a single scale group.
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0)
        b_int = ((tl.load(b_ptrs) >> b_shifter) & 0xF).to(compute_type)
        raw_dot = tl.dot(a, b_int)
        row_sum = tl.sum(a.to(tl.float32), axis=1)
        scale_idx = k * BLOCK_SIZE_K // GROUP_SIZE_K
        scale = tl.load(scale_base + scale_idx * stride_bsk).to(tl.float32)
        accumulator += scale[None, :] * (raw_dot - 8.0 * row_sum[:, None])

        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]

    accumulator = accumulator.to(compute_type)

    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)

389 

390 

# ---------- Launch wrappers ----------


def _invoke_gemm1_silu(
    A,
    B,
    C,
    B_scale,
    topk_weights,
    sorted_token_ids,
    expert_ids,
    num_tokens_post_padded,
    mul_routed_weight,
    top_k,
    block_m,
    group_size,
    compute_type,
):
    """Launch the fused gate/up GEMM + SiLU kernel for one token chunk.

    Args:
        A: activations; ``A.size(0)`` is the token count and ``A.size(1)`` the
            reduction length K.
        B: packed weights in the transposed ``[E, K//2, N]`` layout, where ``N``
            is the combined gate+up width. The kernel emits ``N // 2`` columns.
        C: output tensor, addressed through ``C.stride(1)`` / ``C.stride(2)``.
        B_scale: scales in the transposed ``[E, K//gs, N]`` layout.
        topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded:
            routing metadata forwarded to the kernel unchanged.
        mul_routed_weight: whether the kernel applies the routing weight.
        top_k: experts per token (maps sorted ids back to rows of ``A``).
        block_m: BLOCK_SIZE_M for the launch.
        group_size: quantization group size along K (``GROUP_SIZE_K``).
        compute_type: Triton dtype used for the dot-product inputs.
    """
    gate_up_width = B.size(2)
    reduce_len = A.size(1)
    half_width = gate_up_width // 2
    num_rows = A.size(0)

    padded_len = sorted_token_ids.size(0)
    if num_rows < block_m:
        # Tiny batch: cap the number of sorted rows considered by the grid.
        padded_len = min(padded_len, num_rows * top_k * block_m)

    def grid(meta):
        m_tiles = triton.cdiv(padded_len, meta["BLOCK_SIZE_M"])
        n_tiles = triton.cdiv(half_width, meta["BLOCK_SIZE_N"])
        return (m_tiles * n_tiles,)

    _int4_gemm_silu_kernel[grid](
        A,
        B,
        C,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        gate_up_width,
        reduce_len,
        padded_len,
        num_rows * top_k,
        A.stride(0),
        A.stride(1),
        # B [E, K//2, N]: expert, packed-K and N strides
        B.stride(0),
        B.stride(1),
        B.stride(2),
        C.stride(1),
        C.stride(2),
        # B_scale [E, K//gs, N]: expert, K-group and N strides
        B_scale.stride(0),
        B_scale.stride(1),
        B_scale.stride(2),
        BLOCK_SIZE_M=block_m,
        GROUP_SIZE_K=group_size,
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
    )

455 

456 

def _invoke_gemm2(
    A,
    B,
    C,
    B_scale,
    topk_weights,
    sorted_token_ids,
    expert_ids,
    num_tokens_post_padded,
    mul_routed_weight,
    top_k,
    block_m,
    group_size,
    compute_type,
):
    """Launch the plain INT4 GEMM kernel (the second MoE GEMM).

    ``B`` is expected in the transposed ``[E, K//2, N]`` layout, so the output
    width is ``B.size(2)``. The reduction length is ``A.size(1)``. The other
    arguments have the same meaning as for the GEMM1 launcher.
    """
    out_cols = B.size(2)
    reduce_len = A.size(1)
    num_rows = A.size(0)

    padded_len = sorted_token_ids.size(0)
    if num_rows < block_m:
        # Tiny batch: cap the number of sorted rows considered by the grid.
        padded_len = min(padded_len, num_rows * top_k * block_m)

    def grid(meta):
        m_tiles = triton.cdiv(padded_len, meta["BLOCK_SIZE_M"])
        n_tiles = triton.cdiv(out_cols, meta["BLOCK_SIZE_N"])
        return (m_tiles * n_tiles,)

    _int4_gemm_kernel[grid](
        A,
        B,
        C,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        out_cols,
        reduce_len,
        padded_len,
        num_rows * top_k,
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(1),
        B.stride(2),
        C.stride(1),
        C.stride(2),
        B_scale.stride(0),
        B_scale.stride(1),
        B_scale.stride(2),
        BLOCK_SIZE_M=block_m,
        GROUP_SIZE_K=group_size,
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
    )

514 

515 

# ---------- Implementation ----------


def _fused_marlin_moe_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: torch.Tensor | None = None,
    w2_zp: torch.Tensor | None = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the INT4 MoE forward pass: GEMM1 + SiLU, GEMM2, then expert sum.

    Tokens are processed in chunks of at most 16K rows. For each chunk this:
    aligns the routing with ``moe_align_block_size``; runs the fused gate/up
    GEMM with SiLU into a scratch buffer; runs the down GEMM into a second
    scratch buffer; and reduces over ``top_k`` with ``moe_sum`` into the output.

    ``block_shape[1]`` is used as the quantization group size along K.
    ``per_channel_quant``, ``use_int8_w8a16``, ``w1_bias`` and ``w2_bias`` are
    accepted but not read here.

    Returns:
        ``hidden_states`` itself when ``inplace`` is true, else a new tensor
        of the same shape and dtype.
    """
    # Only the SiLU / INT4 / zero-point-free configuration is implemented.
    assert activation == "silu"
    assert use_int4_w4a16
    assert w1_zp is None and w2_zp is None

    # Layout preconditions: w1 packs two int4 values per byte along K.
    assert w1.size(2) == hidden_states.size(1) // 2
    assert topk_weights.size() == topk_ids.size()
    assert hidden_states.is_contiguous()
    assert w1.stride(-1) == 1
    assert w2.stride(-1) == 1
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

    num_tokens = hidden_states.size(0)
    num_experts, gate_up_width, _ = w1.size()
    hidden_size = w2.size(1)
    if global_num_experts == -1:
        global_num_experts = num_experts
    top_k = topk_ids.size(1)
    group_size = block_shape[1]

    # Weights and scales are re-laid out as [E, K_packed, N] (memoized).
    w1_t = _transpose_b(w1)
    w2_t = _transpose_b(w2)
    w1_scale_t = _transpose_scale(w1_scale)
    w2_scale_t = _transpose_scale(w2_scale)

    chunk_size: int = 16 * 1024
    rows = min(num_tokens, chunk_size)
    inter_width = gate_up_width // 2
    block_m = _select_block_m(rows, num_experts, top_k)

    device = hidden_states.device
    dtype = hidden_states.dtype
    down_out = torch.empty((rows, top_k, hidden_size), device=device, dtype=dtype)
    act_out = torch.empty((rows * top_k, inter_width), device=device, dtype=dtype)

    compute_type = {
        torch.bfloat16: tl.bfloat16,
        torch.float16: tl.float16,
        torch.float32: tl.float32,
    }.get(dtype)
    if compute_type is None:
        raise ValueError(f"Unsupported dtype: {hidden_states.dtype}")

    out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)

    for begin_idx in range(0, num_tokens + 1, chunk_size):
        end_idx = min(begin_idx + chunk_size, num_tokens)
        chunk_hidden = hidden_states[begin_idx:end_idx]
        chunk_rows = chunk_hidden.size(0)
        if chunk_rows == 0:
            break

        if begin_idx > 0 and chunk_rows < chunk_size:
            # Shorter tail chunk: shrink the scratch views, re-pick BLOCK_SIZE_M.
            act_out = act_out[: chunk_rows * top_k]
            down_out = down_out[:chunk_rows]
            block_m = _select_block_m(chunk_rows, num_experts, top_k)

        chunk_ids = topk_ids[begin_idx:end_idx]
        chunk_weights = topk_weights[begin_idx:end_idx]

        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            chunk_ids,
            block_m,
            global_num_experts,
            expert_map,
        )

        # GEMM1: gate/up projections with fused SiLU-and-mul.
        _invoke_gemm1_silu(
            A=chunk_hidden,
            B=w1_t,
            C=act_out.view(chunk_rows, top_k, inter_width),
            B_scale=w1_scale_t,
            topk_weights=chunk_weights,
            sorted_token_ids=sorted_token_ids,
            expert_ids=expert_ids,
            num_tokens_post_padded=num_tokens_post_padded,
            mul_routed_weight=apply_router_weight_on_input,
            top_k=top_k,
            block_m=block_m,
            group_size=group_size,
            compute_type=compute_type,
        )

        if expert_map is not None:
            down_out.zero_()

        # GEMM2: down projection of the activated intermediate.
        _invoke_gemm2(
            A=act_out,
            B=w2_t,
            C=down_out,
            B_scale=w2_scale_t,
            topk_weights=chunk_weights,
            sorted_token_ids=sorted_token_ids,
            expert_ids=expert_ids,
            num_tokens_post_padded=num_tokens_post_padded,
            mul_routed_weight=not apply_router_weight_on_input,
            top_k=1,
            block_m=block_m,
            group_size=group_size,
            compute_type=compute_type,
        )

        # Sum the top_k expert outputs of every token into the output slice.
        moe_sum(down_out.view(down_out.shape), out_hidden_states[begin_idx:end_idx])

    return out_hidden_states

667 

668 

def fused_marlin_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    bias1: Optional[torch.Tensor],
    bias2: Optional[torch.Tensor],
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    quant_type_id: int,
    apply_router_weight_on_input: bool = False,
    global_num_experts: int = -1,
    activation: Any = None,
    activation_func: Optional[Callable] = None,
    moe_sum: Optional[Callable] = None,
    expert_map: Optional[torch.Tensor] = None,
    input_global_scale1: Optional[torch.Tensor] = None,
    input_global_scale2: Optional[torch.Tensor] = None,
    global_scale1: Optional[torch.Tensor] = None,
    global_scale2: Optional[torch.Tensor] = None,
    g_idx1: Optional[torch.Tensor] = None,
    g_idx2: Optional[torch.Tensor] = None,
    sort_indices1: Optional[torch.Tensor] = None,
    sort_indices2: Optional[torch.Tensor] = None,
    w1_zeros: Optional[torch.Tensor] = None,
    w2_zeros: Optional[torch.Tensor] = None,
    workspace: Optional[torch.Tensor] = None,
    intermediate_cache13: Optional[torch.Tensor] = None,
    intermediate_cache2: Optional[torch.Tensor] = None,
    is_k_full: bool = True,
    output: Optional[torch.Tensor] = None,
    input_dtype: Optional[torch.dtype] = None,
    inplace: bool = False,
    clamp_limit: Optional[float] = None,
    group_size: int = 128,
) -> torch.Tensor:
    """Marlin-compatible fused MoE forward pass backed by Triton INT4 kernels.

    Runs the gate/up GEMM with fused SiLU-and-mul, then the down GEMM, then
    sums the ``top_k`` expert outputs of each token. Marlin-specific options
    that this MVP does not implement raise ``NotImplementedError`` instead of
    being silently ignored. ``activation_func``, ``moe_sum``, ``workspace``,
    ``intermediate_cache13``, ``intermediate_cache2`` and ``is_k_full`` are
    accepted for signature compatibility and not used.

    Args:
        hidden_states: input activations (fp16 / bf16 / fp32).
        w1, w2: packed 4-bit expert weights for the gate/up and down GEMMs.
        bias1, bias2: expert biases; must be ``None`` (unsupported).
        w1_scale, w2_scale: per-group dequantization scales.
        topk_weights, topk_ids: routing weights and expert ids per token.
        quant_type_id: one of the ``QUANT_TYPE_*`` constants; only
            ``QUANT_TYPE_UINT4B8`` is implemented.
        apply_router_weight_on_input: apply routing weights in the first GEMM
            instead of the second.
        group_size: quantization group size along K; must be positive.
        output: optional tensor that receives a copy of the result.
        inplace: write the result into ``hidden_states``.

    Returns:
        ``output`` if given, else the result tensor (``hidden_states`` when
        ``inplace``).

    Raises:
        NotImplementedError: for any unsupported quantization or feature.
        ValueError: if both ``inplace`` and ``output`` are given.
    """
    if quant_type_id not in _SUPPORTED_QUANT_TYPES:
        raise NotImplementedError(
            f"MVP supports quant_type_id in {_SUPPORTED_QUANT_TYPES}, "
            f"got {quant_type_id}"
        )
    if quant_type_id not in _QUANT_TYPE_INT4:
        # Recognized id, but only the 4-bit kernels exist; the implementation
        # would otherwise fail on an assert (or misread 8-bit data under -O).
        raise NotImplementedError(
            f"quant_type_id={quant_type_id} (8-bit weights) is not yet "
            "implemented; only uint4b8 weights are supported"
        )
    if g_idx1 is not None or g_idx2 is not None:
        raise NotImplementedError("act_order (g_idx) not yet supported in MVP")
    if sort_indices1 is not None or sort_indices2 is not None:
        raise NotImplementedError("act_order (sort_indices) not yet supported in MVP")
    if input_dtype is not None:
        raise NotImplementedError("FP8 / INT8 input quantization not supported")
    if clamp_limit is not None:
        raise NotImplementedError("clamp_limit (GLM-4 swiglu) not supported")
    if input_global_scale1 is not None or input_global_scale2 is not None:
        raise NotImplementedError("input_global_scale not supported in MVP")
    if global_scale1 is not None or global_scale2 is not None:
        raise NotImplementedError("global_scale not supported in MVP")
    if bias1 is not None or bias2 is not None:
        # The kernels have no bias term; dropping it would give wrong outputs.
        raise NotImplementedError("expert bias (bias1 / bias2) not supported in MVP")
    if w1_zeros is not None or w2_zeros is not None:
        raise NotImplementedError("zero points (w1_zeros / w2_zeros) not supported in MVP")
    if group_size <= 0:
        # Scale indexing divides by the group size; non-positive values
        # (e.g. channelwise conventions) would index garbage.
        raise NotImplementedError(
            f"group_size must be positive in MVP, got {group_size}"
        )

    use_int4_w4a16 = quant_type_id in _QUANT_TYPE_INT4
    use_int8_w8a16 = quant_type_id in _QUANT_TYPE_INT8

    # Accept either a plain string or an enum-like object exposing a string
    # ``value`` or ``name``.
    activation_str = "silu"
    if activation is not None:
        for attr in ("value", "name"):
            v = getattr(activation, attr, None)
            if isinstance(v, str):
                activation_str = v.lower()
                break
        if isinstance(activation, str):
            activation_str = activation.lower()
    if activation_str != "silu":
        raise NotImplementedError(
            f"MVP only supports SiLU/SwiGLU activation, got {activation_str}"
        )

    if inplace and output is not None:
        raise ValueError("Cannot pass both inplace=True and output")

    result = _fused_marlin_moe_impl(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=inplace,
        activation=activation_str,
        apply_router_weight_on_input=apply_router_weight_on_input,
        use_int4_w4a16=use_int4_w4a16,
        use_int8_w8a16=use_int8_w8a16,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        w1_zp=w1_zeros,
        w2_zp=w2_zeros,
        w1_bias=bias1,
        w2_bias=bias2,
        block_shape=[0, group_size],
    )

    if output is not None:
        output.copy_(result)
        return output
    return result

770 

771 

# Public API of this module.
__all__ = [
    "fused_marlin_moe",
    "QUANT_TYPE_UINT4B8",
    "QUANT_TYPE_UINT8B128",
]