Coverage for src/flag_gems/fused/fused_marlin_moe.py: 20%

109 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1# SPDX-License-Identifier: Apache-2.0 

2""" 

3Fused Marlin MoE for FlagGems. 

4 

5Aligns with the interface of vLLM v0.20.0: 

6 vllm/model_executor/layers/fused_moe/fused_marlin_moe.py :: fused_marlin_moe 

7 

8PHASE 2 (this file): bypass `fused_experts_impl`'s dequant-then-FP16-GEMM 

9shortcut and dispatch directly to the wna16 Triton kernel 

10(`fused_moe_kernel_gptq_awq`) for true fused-dequant W4A16/W8A16 GEMM. 

11 

12The local helper `_fused_marlin_moe_impl` mirrors `fused_experts_impl`'s 

13orchestration (chunk loop, moe_align, two GEMMs, activation, reduction) 

14but deletes the INT4/INT8 dequant branch and forwards `block_shape` so 

15the wna16 path is actually taken. 

16 

17MVP scope: 

18 - quant_type: GPTQ uint4b8 (INT4) and uint8b128 (INT8) 

19 - activation: SwiGLU / SiLU 

20 - act_order: NOT supported (g_idx / sort_indices must be None) 

21 - FP8 input: NOT supported 

22 - LoRA, clamp_limit, expert_map: NOT supported 

23""" 

24import functools 

25from typing import Any, Callable, Optional 

26 

27import torch 

28import triton.language as tl 

29 

30from flag_gems.fused.fused_moe import ( 

31 MoEActivation, 

32 _get_config_dtype_str, 

33 _get_config_quant_dtype, 

34 apply_moe_activation, 

35 dispatch_fused_moe_kernel, 

36 moe_kernel_quantize_input, 

37 try_get_optimal_moe_config, 

38) 

39from flag_gems.fused.moe_align_block_size import moe_align_block_size 

40from flag_gems.fused.moe_sum import moe_sum 

41 

# ----------------------------------------------------------------------------
# quant_type_id constants — mirror a subset of vLLM scalar_types ids.
# ----------------------------------------------------------------------------
# GPTQ INT4 (weight stored as w + 8, dequant subtracts 8)
QUANT_TYPE_UINT4B8 = 0
# INT8 (weight stored as w + 128)
QUANT_TYPE_UINT8B128 = 1

# Membership sets used to translate a quant_type_id into the
# use_int4_w4a16 / use_int8_w8a16 flags and to reject unsupported ids.
# They stay plain sets so that formatting _SUPPORTED_QUANT_TYPES in an
# error message keeps producing the same text.
_QUANT_TYPE_INT4 = {QUANT_TYPE_UINT4B8}
_QUANT_TYPE_INT8 = {QUANT_TYPE_UINT8B128}
_SUPPORTED_QUANT_TYPES = _QUANT_TYPE_INT4 | _QUANT_TYPE_INT8

53 

54 

55# ---------------------------------------------------------------------------- 

56# Phase-2 impl: copy of fused_experts_impl but with the dequant shortcut 

57# removed so the wna16 Triton kernel is actually invoked for W4A16/W8A16. 

58# ---------------------------------------------------------------------------- 

def _fused_marlin_moe_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_int8_w8a16: bool = False,
    use_int4_w4a16: bool = False,
    per_channel_quant: bool = False,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_zp: Optional[torch.Tensor] = None,
    w2_zp: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the two-GEMM MoE pipeline on weight-only-quantized experts.

    Like fused_experts_impl, but:
      - drops all paths irrelevant to W4A16/W8A16 (no fp8, int8_w8a8, mxfp).
      - REMOVES the `w = w.to(fp16) * scale.unsqueeze(-1)` dequant shortcut.
      - forwards block_shape so the wna16 kernel uses the right group_size.

    Tokens are processed in chunks of at most 16 * 1024 rows. For each
    chunk the function aligns tokens to experts, runs GEMM 1 against
    ``w1``, applies the activation, runs GEMM 2 against ``w2`` and reduces
    the per-expert outputs into the output rows of that chunk.

    Args:
        hidden_states: 2-D, contiguous activations; dtype must be float32,
            float16 or bfloat16.
        w1: 3-D expert weights; ``w1.size()`` is unpacked as ``(E, N, _)``
            and the last dimension must equal the hidden size (INT8) or
            half of it (INT4).
        w2: Expert weights for the second GEMM; ``w2.size(1)`` is used as
            the output width of that GEMM.
        topk_weights: Router weights, same size as ``topk_ids``.
        topk_ids: Selected expert ids; ``topk_ids.size(1)`` is the top-k.
        inplace: When True the result is written into ``hidden_states``.
        activation: Must be ``"silu"``.
        apply_router_weight_on_input: Forwarded to GEMM 1; its negation is
            forwarded to GEMM 2.
        use_int8_w8a16: Select the INT8 weight path.
        use_int4_w4a16: Select the INT4 weight path. Exactly one of the
            two flags must be set.
        per_channel_quant: Forwarded to the quantize and dispatch helpers.
        global_num_experts: Expert count passed to the alignment helper;
            ``-1`` means "use ``E`` from ``w1``".
        expert_map: Optional tensor forwarded to ``moe_align_block_size``;
            when given, the GEMM 2 output buffer is zeroed before use.
        w1_scale, w2_scale: Weight scales forwarded to the dispatcher.
        w1_zp, w2_zp: Optional zero points forwarded to the dispatcher.
        block_shape: Forwarded to config lookup, quantize and dispatch.
        w1_bias, w2_bias: Optional biases forwarded as ``B_bias``.

    Returns:
        ``hidden_states`` itself when ``inplace`` is True, otherwise a new
        tensor created with ``torch.empty_like(hidden_states)``.

    Raises:
        AssertionError: On unsupported activation, missing or ambiguous
            quantization flags, or shape/stride/dtype violations.
        ValueError: If the dtype of ``hidden_states`` has no matching
            Triton compute type.
    """
    assert (
        activation == "silu"
    ), f"Only 'silu' activation is supported, got {activation}"
    assert (
        use_int4_w4a16 or use_int8_w8a16
    ), "_fused_marlin_moe_impl expects a quantized path"
    # Both flags at once would make the packed-K check below (INT4 wins)
    # disagree with whatever the config/dispatch helpers pick.
    assert not (
        use_int4_w4a16 and use_int8_w8a16
    ), "use_int4_w4a16 and use_int8_w8a16 are mutually exclusive"

    activation_enum = MoEActivation.from_str(activation)

    # Packed-aware shape check.
    #   W4A16 (pack_factor=2): w1.size(2) == K // 2
    #   W8A16 (pack_factor=1): w1.size(2) == K
    expected_packed_k = (
        hidden_states.size(1) // 2 if use_int4_w4a16 else hidden_states.size(1)
    )
    assert w1.size(2) == expected_packed_k, (
        f"w1 packed K mismatch: hidden_size={hidden_states.size(1)}, "
        f"use_int4_w4a16={use_int4_w4a16}, expected w1.size(2)={expected_packed_k}, "
        f"got {w1.size(2)}"
    )

    assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

    num_tokens = hidden_states.size(0)
    E, N, _ = w1.size()
    K = w2.size(1)
    if global_num_experts == -1:
        global_num_experts = E
    top_k_num = topk_ids.size(1)

    CHUNK_SIZE: int = 16 * 1024
    M = min(num_tokens, CHUNK_SIZE)

    config_dtype = _get_config_dtype_str(
        use_fp8_w8a8=False,
        use_int8_w8a16=use_int8_w8a16,
        use_int4_w4a16=use_int4_w4a16,
        ocp_mx_scheme=None,
        dtype=hidden_states.dtype,
    )
    quant_dtype = _get_config_quant_dtype(
        use_fp8_w8a8=False,
        use_int8_w8a8=False,
        ocp_mx_scheme=None,
    )

    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.size(),
        w2.size(),
        top_k_num,
        config_dtype,
        block_shape=block_shape,
        E=E,
    )

    def _config_with_split_k_one(num_rows: int) -> dict:
        # Copy before overriding: the mapping handed back by the config
        # lookup may be shared with (or cached for) other callers, and
        # writing SPLIT_K into it in place would leak into them.
        cfg = dict(get_config_func(num_rows))
        cfg["SPLIT_K"] = 1
        return cfg

    config = _config_with_split_k_one(M)

    # cache1 and cache3 share memory (non-overlapping lifetime)
    cache13 = torch.empty(
        M * top_k_num * max(N, K),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache1 = cache13[: M * top_k_num * N].view(M, top_k_num, N)
    intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)

    activation_out_dim = MoEActivation.adjust_N_for_activation(N, activation_enum)
    intermediate_cache2 = torch.empty(
        (M * top_k_num, activation_out_dim),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    if hidden_states.dtype == torch.bfloat16:
        compute_type = tl.bfloat16
    elif hidden_states.dtype == torch.float16:
        compute_type = tl.float16
    elif hidden_states.dtype == torch.float32:
        compute_type = tl.float32
    else:
        raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

    out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)

    # ★ Phase-2 KEY DIFFERENCE: the W4A16/W8A16 dequant shortcut that lived
    #   here in `fused_experts_impl` is intentionally REMOVED. The wna16
    #   Triton kernel will consume INT4 weights + scale directly.

    for begin_chunk_idx in range(0, num_tokens, CHUNK_SIZE):
        end_chunk_idx = min(begin_chunk_idx + CHUNK_SIZE, num_tokens)
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk = end_chunk_idx - begin_chunk_idx

        # Only a trailing chunk after the first can be shorter than the
        # buffers allocated above; shrink the views and re-pick the config.
        if begin_chunk_idx > 0 and tokens_in_chunk < CHUNK_SIZE:
            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
            intermediate_cache2 = intermediate_cache2[: tokens_in_chunk * top_k_num]
            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
            config = _config_with_split_k_one(tokens_in_chunk)

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

        # Activation quantization is a no-op for W4A16/W8A16 (no input quant).
        qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
            A=curr_hidden_states,
            A_scale=None,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_channel_quant,
            block_shape=block_shape,
            ocp_mx_scheme=None,
        )

        # Use the routed-path (skip the SPARSITY_FACTOR shortcut, which is
        # explicitly disabled for quantized + block_shape configs anyway).
        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            curr_topk_ids,
            config["BLOCK_SIZE_M"],
            global_num_experts,
            expert_map,
        )

        # ----- GEMM 1: hidden @ w1 (fused dequant on B inside the kernel) -----
        dispatch_fused_moe_kernel(
            qcurr_hidden_states,
            w1,
            intermediate_cache1,
            a1q_scale,
            w1_scale,
            w1_zp,
            curr_topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            apply_router_weight_on_input,
            top_k_num,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=False,
            use_int8_w8a8=False,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            block_shape=block_shape,
            B_bias=w1_bias,
        )

        # ----- Activation: SwiGLU = silu(gate) * up -----
        apply_moe_activation(
            activation_enum, intermediate_cache2, intermediate_cache1.view(-1, N)
        )

        qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
            A=intermediate_cache2,
            A_scale=None,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_channel_quant,
            block_shape=block_shape,
            ocp_mx_scheme=None,
        )

        if expert_map is not None:
            intermediate_cache3.zero_()

        # ----- GEMM 2: act @ w2 (fused dequant on B inside the kernel) -----
        # The router weight is applied here exactly when it was not
        # applied on the input side, and top_k is 1 because each row of
        # intermediate_cache2 is already a (token, expert) pair.
        dispatch_fused_moe_kernel(
            qintermediate_cache2,
            w2,
            intermediate_cache3,
            a2q_scale,
            w2_scale,
            w2_zp,
            curr_topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            not apply_router_weight_on_input,
            1,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=False,
            use_int8_w8a8=False,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
            per_channel_quant=per_channel_quant,
            block_shape=block_shape,
            B_bias=w2_bias,
        )

        # ----- Reduce: sum topk-weighted expert outputs back per token -----
        moe_sum(
            intermediate_cache3,
            out_hidden_states[begin_chunk_idx:end_chunk_idx],
        )

    return out_hidden_states

296 

297 

298# ---------------------------------------------------------------------------- 

299# Public entry point: vLLM-aligned wrapper. 

300# ---------------------------------------------------------------------------- 

def fused_marlin_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    bias1: Optional[torch.Tensor],
    bias2: Optional[torch.Tensor],
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    quant_type_id: int,
    apply_router_weight_on_input: bool = False,
    global_num_experts: int = -1,
    activation: Any = None,
    activation_func: Optional[Callable] = None,
    moe_sum: Optional[Callable] = None,
    expert_map: Optional[torch.Tensor] = None,
    input_global_scale1: Optional[torch.Tensor] = None,
    input_global_scale2: Optional[torch.Tensor] = None,
    global_scale1: Optional[torch.Tensor] = None,
    global_scale2: Optional[torch.Tensor] = None,
    g_idx1: Optional[torch.Tensor] = None,
    g_idx2: Optional[torch.Tensor] = None,
    sort_indices1: Optional[torch.Tensor] = None,
    sort_indices2: Optional[torch.Tensor] = None,
    w1_zeros: Optional[torch.Tensor] = None,
    w2_zeros: Optional[torch.Tensor] = None,
    workspace: Optional[torch.Tensor] = None,
    intermediate_cache13: Optional[torch.Tensor] = None,
    intermediate_cache2: Optional[torch.Tensor] = None,
    is_k_full: bool = True,
    output: Optional[torch.Tensor] = None,
    input_dtype: Optional[torch.dtype] = None,
    inplace: bool = False,
    clamp_limit: Optional[float] = None,
    group_size: int = 128,
) -> torch.Tensor:
    """Phase-2 entry point: dispatch to local wna16-using impl.

    Validates the MVP restrictions, maps ``quant_type_id`` to the
    INT4/INT8 flags, normalises ``activation`` to a lowercase string and
    forwards everything to ``_fused_marlin_moe_impl`` with
    ``block_shape=[0, group_size]``.

    Args:
        hidden_states, w1, w2, topk_weights, topk_ids: Forwarded unchanged.
        bias1, bias2: Forwarded as ``w1_bias`` / ``w2_bias``.
        w1_scale, w2_scale: Forwarded unchanged.
        quant_type_id: One of the supported ``QUANT_TYPE_*`` ids.
        apply_router_weight_on_input, global_num_experts: Forwarded.
        activation: ``None`` (meaning SiLU), a string, or an object whose
            ``value`` or ``name`` attribute is a string.
        expert_map: Forwarded unchanged.
        w1_zeros, w2_zeros: Forwarded as ``w1_zp`` / ``w2_zp``.
        output: Optional destination; the result is copied into it.
        inplace: Forwarded; cannot be combined with ``output``.
        group_size: Second entry of the forwarded ``block_shape``.
        input_global_scale1/2, global_scale1/2, g_idx1/2,
        sort_indices1/2, input_dtype, clamp_limit: Must be ``None``.
        activation_func, moe_sum, workspace, intermediate_cache13,
        intermediate_cache2, is_k_full: Accepted for signature
            compatibility only; this function never reads them.

    Returns:
        ``output`` when it was supplied, otherwise the tensor returned by
        ``_fused_marlin_moe_impl``.

    Raises:
        NotImplementedError: For any feature outside the MVP scope,
            including an activation that is not SiLU or cannot be
            interpreted as a name.
        ValueError: If both ``inplace=True`` and ``output`` are given.
    """
    # NOTE: the `moe_sum` parameter shadows the module-level `moe_sum`
    # import inside this function only; it is intentionally never called.

    # ---- MVP guardrails --------------------------------------------------
    if quant_type_id not in _SUPPORTED_QUANT_TYPES:
        raise NotImplementedError(
            f"MVP supports quant_type_id in {_SUPPORTED_QUANT_TYPES}, "
            f"got {quant_type_id}"
        )
    if g_idx1 is not None or g_idx2 is not None:
        raise NotImplementedError("act_order (g_idx) not yet supported in MVP")
    if sort_indices1 is not None or sort_indices2 is not None:
        raise NotImplementedError("act_order (sort_indices) not yet supported in MVP")
    if input_dtype is not None:
        raise NotImplementedError("FP8 / INT8 input quantization not supported")
    if clamp_limit is not None:
        raise NotImplementedError("clamp_limit (GLM-4 swiglu) not supported")
    if input_global_scale1 is not None or input_global_scale2 is not None:
        raise NotImplementedError("input_global_scale not supported in MVP")
    if global_scale1 is not None or global_scale2 is not None:
        raise NotImplementedError("global_scale not supported in MVP")

    use_int4_w4a16 = quant_type_id in _QUANT_TYPE_INT4
    use_int8_w8a16 = quant_type_id in _QUANT_TYPE_INT8

    # ---- Normalise the activation to a lowercase name --------------------
    if activation is None:
        activation_str = "silu"
    elif isinstance(activation, str):
        # Checked first so that string-valued enum members resolve through
        # their own string content, as before.
        activation_str = activation.lower()
    else:
        for attr in ("value", "name"):
            candidate = getattr(activation, attr, None)
            if isinstance(candidate, str):
                activation_str = candidate.lower()
                break
        else:
            # Previously such objects were silently treated as SiLU, which
            # could run the wrong activation without any warning.
            raise NotImplementedError(
                f"MVP only supports SiLU/SwiGLU activation, got {activation!r}"
            )
    if activation_str != "silu":
        raise NotImplementedError(
            f"MVP only supports SiLU/SwiGLU activation, got {activation_str}"
        )

    if inplace and output is not None:
        raise ValueError("Cannot pass both inplace=True and output")

    # NOTE(review): expert_map is forwarded as-is even though the module
    # docstring lists it as not supported — confirm which is intended.
    result = _fused_marlin_moe_impl(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=inplace,
        activation=activation_str,
        apply_router_weight_on_input=apply_router_weight_on_input,
        use_int4_w4a16=use_int4_w4a16,
        use_int8_w8a16=use_int8_w8a16,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        w1_zp=w1_zeros,
        w2_zp=w2_zeros,
        w1_bias=bias1,
        w2_bias=bias2,
        # Critical for Phase 2: block_shape=[0, group_size] makes the
        # wna16 Triton kernel use the per-group scales correctly.
        block_shape=[0, group_size],
    )

    if output is not None:
        output.copy_(result)
        return output
    return result

406 

407 

408__all__ = ["fused_marlin_moe", "QUANT_TYPE_UINT4B8", "QUANT_TYPE_UINT8B128"]