Coverage for src/flag_gems/fused/fused_marlin_moe.py: 14%
263 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1# SPDX-License-Identifier: Apache-2.0
2"""
3Fused Marlin MoE — v7: Tunable BLOCK_SIZE_K for improved pipelining.
5Key changes from v6:
6- BLOCK_SIZE_K is now an autotune parameter (32, 64, 128) instead of fixed at
7 group_size=128. Smaller K tiles enable more software pipeline stages and
8 reduce register pressure, improving memory latency hiding for bandwidth-bound
9 small batch sizes.
10- GROUP_SIZE_K constexpr correctly indexes scales when BLOCK_SIZE_K < group_size.
11 Math: accumulating partial sums within a scale group gives identical results.
12- Transposed B layout [E, K//2, N] from v6 is preserved for coalesced N-loads.
13- Two-pass GEMM1 (gate/up) with fused SiLU preserved from v6.
14"""
16from typing import Any, Callable, Optional
18import torch
19import triton
20import triton.language as tl
22from flag_gems.fused.fused_moe import write_zeros_to_output
23from flag_gems.fused.moe_align_block_size import moe_align_block_size
24from flag_gems.fused.moe_sum import moe_sum
# Quantization type identifiers accepted by ``fused_marlin_moe`` (``quant_type_id``).
QUANT_TYPE_UINT4B8 = 0
QUANT_TYPE_UINT8B128 = 1

# Families of identifiers, used to derive the use_int4 / use_int8 flags.
_QUANT_TYPE_INT4 = {QUANT_TYPE_UINT4B8}
_QUANT_TYPE_INT8 = {QUANT_TYPE_UINT8B128}
# Every identifier that passes the first validation step in fused_marlin_moe.
_SUPPORTED_QUANT_TYPES = set().union(_QUANT_TYPE_INT4, _QUANT_TYPE_INT8)
# ---------- Transpose cache ----------
# Each cache maps a key describing the source tensor (address, device, dtype,
# shape, strides) to ``(source, source_version, transposed)``.
#
# The source tensor is pinned on purpose. While an entry references it, its
# storage cannot be freed, so its ``data_ptr`` can never be recycled by an
# unrelated tensor and alias a stale entry. (The caching allocator reuses
# same-sized blocks aggressively, so keying on the address alone is unsafe.)
# The stored version counter detects in-place updates of the source
# (e.g. ``param.data.copy_(new_weights)``).
#
# NOTE: entries are never evicted. Every weight seen here, and its
# transpose, stays alive for the lifetime of the process.

_B_CACHE: dict = {}
_SCALE_CACHE: dict = {}


def _tensor_version(t: torch.Tensor) -> Optional[int]:
    """Return the autograd version counter of *t*, or None if it has none."""
    try:
        return t._version
    except RuntimeError:
        # Inference-mode tensors do not track a version counter; in-place
        # modifications of such tensors cannot be detected here.
        return None


def _cached_transpose(cache: dict, t: torch.Tensor) -> torch.Tensor:
    """Return ``t.transpose(1, 2).contiguous()``, memoized in *cache*.

    A cached result is reused only when the source has the same address,
    device, dtype, shape and strides, and its version counter is unchanged.
    """
    key = (t.data_ptr(), t.device, t.dtype, tuple(t.shape), tuple(t.stride()))
    version = _tensor_version(t)
    entry = cache.get(key)
    if entry is not None:
        _pinned_source, cached_version, cached = entry
        if cached_version == version:
            return cached
    transposed = t.transpose(1, 2).contiguous()
    # Overwrite any stale entry for this key so the cache does not grow on
    # repeated in-place updates of the same tensor.
    cache[key] = (t, version, transposed)
    return transposed


def _transpose_b(b: torch.Tensor) -> torch.Tensor:
    """Transpose B from [E, N, K//2] to [E, K//2, N] for coalesced N-loads."""
    return _cached_transpose(_B_CACHE, b)


def _transpose_scale(s: torch.Tensor) -> torch.Tensor:
    """Transpose scale from [E, N, K//gs] to [E, K//gs, N] for coalesced loads."""
    return _cached_transpose(_SCALE_CACHE, s)
# ---------- Autotune configs ----------
# Candidate tiles shared by both INT4 GEMM kernels. Each row of the table is
# (BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, num_warps, num_stages); the
# order of the resulting list matches the table order.
_AUTOTUNE_CONFIGS = [
    triton.Config(
        {"BLOCK_SIZE_N": block_n, "BLOCK_SIZE_K": block_k, "GROUP_SIZE_M": group_m},
        num_warps=warps,
        num_stages=stages,
    )
    for block_n, block_k, group_m, warps, stages in (
        # BLOCK_SIZE_K=128: intended for the compute-bound regime (large M)
        (64, 128, 1, 4, 4),
        (128, 128, 1, 4, 3),
        (128, 128, 4, 8, 3),
        (256, 128, 4, 8, 2),
        # BLOCK_SIZE_K=64: balanced pipelining, fewer registers per tile
        (64, 64, 1, 4, 5),
        (128, 64, 1, 4, 5),
        (128, 64, 4, 8, 4),
        # BLOCK_SIZE_K=32: deepest pipelines, aimed at bandwidth-bound small batches
        (64, 32, 1, 4, 8),
        (128, 32, 1, 4, 8),
        (128, 32, 4, 8, 6),
    )
]
def _select_block_m(M, E, top_k):
    """Choose BLOCK_SIZE_M from the average routed rows per expert.

    The average is ``M * top_k / E`` (E clamped to at least 1), itself
    clamped to at least 1. Up to 4 rows per expert -> 16, up to 32 -> 32,
    otherwise 64.
    """
    rows_per_expert = max(M * top_k / max(E, 1), 1)
    for upper_bound, block_m in ((4, 16), (32, 32)):
        if rows_per_expert <= upper_bound:
            return block_m
    return 64
# ---------- Autotune pruning ----------


def _prune_autotune_configs(configs, named_args, **kwargs):
    """Keep only the autotune configs whose BLOCK_SIZE_K the kernels support.

    Both INT4 kernels below apply one scale row per K tile
    (``scale_idx = k * BLOCK_SIZE_K // GROUP_SIZE_K``) and never mask the K
    dimension. A tile is therefore only correct when BLOCK_SIZE_K divides
    both the quantization group size and K. The autotuner measures speed,
    not accuracy, so without this filter it may pick a fast but wrong tile
    (e.g. BLOCK_SIZE_K=128 with group_size=64).

    Older Triton versions pass only positional arguments in *named_args*,
    so a missing value simply skips that check.

    Raises:
        ValueError: if no candidate config satisfies the constraints.
    """
    args = {**named_args, **kwargs}
    group_k = args.get("GROUP_SIZE_K")
    k_dim = args.get("K")

    def _supports(cfg):
        block_k = cfg.kwargs["BLOCK_SIZE_K"]
        if group_k is not None and group_k % block_k != 0:
            return False
        if k_dim is not None and k_dim % block_k != 0:
            return False
        return True

    valid = [cfg for cfg in configs if _supports(cfg)]
    if not valid:
        raise ValueError(
            f"No autotune config has a BLOCK_SIZE_K dividing both "
            f"GROUP_SIZE_K={group_k} and K={k_dim}"
        )
    return valid


# ---------- GEMM1: two-pass gate/up with fused SiLU ----------
# B layout: [E, K//2, N] (transposed), stride_bk=N, stride_bn=1
#
# Computes, for one (BLOCK_SIZE_M x BLOCK_SIZE_N) output tile:
#     out = silu(w * (A @ W_gate)) * (w * (A @ W_up))
# where W_gate is columns [0, N//2) and W_up is columns [N//2, N) of the
# fused weight, and w is the routed weight (only when MUL_ROUTED_WEIGHT).
# The autotune key includes BLOCK_SIZE_M and GROUP_SIZE_K so a config tuned
# for one block / group size is never reused for another.


@triton.autotune(
    configs=_AUTOTUNE_CONFIGS,
    key=["N", "K", "BLOCK_SIZE_M", "GROUP_SIZE_K"],
    prune_configs_by={"early_config_prune": _prune_autotune_configs},
)
@triton.jit
def _int4_gemm_silu_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    N: tl.constexpr,
    K: tl.constexpr,
    EM,
    num_valid_tokens,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bse,
    stride_bsk,
    stride_bsn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    GROUP_SIZE_K: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
):
    # N is the fused gate+up width; this kernel writes N // 2 outputs per row.
    N_out = N // 2

    # Grouped mapping of the 1-D program id onto (pid_m, pid_n).
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N_out, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Row blocks past the padded token count have no work.
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return

    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id).to(tl.int64)
    # Ids >= num_valid_tokens (e.g. alignment padding) are masked out.
    token_mask = offs_token < num_valid_tokens

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        # Expert id -1 (presumably an expert not handled locally under an
        # expert_map): emit zeros for this tile instead of computing.
        write_zeros_to_output(
            c_ptr,
            stride_cm,
            stride_cn,
            pid_n,
            N_out,
            offs_token,
            token_mask,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            compute_type,
        )
        return

    # Gate columns are [0, N_out); the matching up columns are offset by N_out.
    offs_bn_gate = (
        pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
    ) % N_out
    offs_bn_up = offs_bn_gate + N_out
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    # Each routed row (token * top_k + j) reads input row token.
    a_base = a_ptr + (offs_token[:, None] // top_k * stride_am)
    b_expert_base = b_ptr + off_experts * stride_be
    # Two int4 values per byte: even k -> low nibble, odd k -> high nibble.
    b_shifter = (offs_k[:, None] % 2) * 4

    # B is transposed: [E, K//2, N], stride_bk=N (between packed K), stride_bn=1 (N contiguous)
    b_ptrs_gate = (
        b_expert_base
        + (offs_k[:, None] // 2) * stride_bk
        + offs_bn_gate[None, :] * stride_bn
    )
    b_ptrs_up = (
        b_expert_base
        + (offs_k[:, None] // 2) * stride_bk
        + offs_bn_up[None, :] * stride_bn
    )

    # Scale is transposed: [E, K//gs, N], stride_bsk=N, stride_bsn=1
    scale_base_gate = b_scale_ptr + off_experts * stride_bse + offs_bn_gate * stride_bsn
    scale_base_up = b_scale_ptr + off_experts * stride_bse + offs_bn_up * stride_bsn

    # Dequantization is factored out of the dot product:
    #     sum_k a_k * s * (q_k - 8) = s * (sum_k a_k * q_k - 8 * sum_k a_k)
    # This holds only while one scale covers the whole K tile, i.e.
    # BLOCK_SIZE_K divides GROUP_SIZE_K (enforced by _prune_autotune_configs).

    # ---- Pass 1: Gate projection ----
    a_ptrs = a_base + offs_k[None, :] * stride_ak
    acc_gate = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0)
        b_g = ((tl.load(b_ptrs_gate) >> b_shifter) & 0xF).to(compute_type)
        raw_dot = tl.dot(a, b_g)
        row_sum = tl.sum(a.to(tl.float32), axis=1)
        scale_idx = k * BLOCK_SIZE_K // GROUP_SIZE_K
        scale_g = tl.load(scale_base_gate + scale_idx * stride_bsk).to(tl.float32)
        acc_gate += scale_g[None, :] * (raw_dot - 8.0 * row_sum[:, None])

        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs_gate += (BLOCK_SIZE_K // 2) * stride_bk

    # ---- Pass 2: Up projection ----
    a_ptrs = a_base + offs_k[None, :] * stride_ak
    acc_up = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0)
        b_u = ((tl.load(b_ptrs_up) >> b_shifter) & 0xF).to(compute_type)
        raw_dot = tl.dot(a, b_u)
        row_sum = tl.sum(a.to(tl.float32), axis=1)
        scale_idx = k * BLOCK_SIZE_K // GROUP_SIZE_K
        scale_u = tl.load(scale_base_up + scale_idx * stride_bsk).to(tl.float32)
        acc_up += scale_u[None, :] * (raw_dot - 8.0 * row_sum[:, None])

        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs_up += (BLOCK_SIZE_K // 2) * stride_bk

    # ---- Routed weight ----
    # Scale both projections BEFORE the activation. Since the GEMM is linear,
    # this equals applying the weight to the input rows:
    # silu(w*g) * (w*u). Applying it after SiLU would give silu(g)*u*w instead.
    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        acc_gate = acc_gate * moe_weight[:, None]
        acc_up = acc_up * moe_weight[:, None]

    # ---- Fused SiLU: silu(gate) * up ----
    accumulator = tl.fdiv(acc_gate, (1.0 + tl.exp(-acc_gate))) * acc_up
    accumulator = accumulator.to(compute_type)

    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N_out)
    tl.store(c_ptrs, accumulator, mask=c_mask)


# ---------- GEMM2: standard INT4 GEMM with factored zero-point ----------
# B layout: [E, K//2, N] (transposed), stride_bk=N, stride_bn=1
#
# Computes one (BLOCK_SIZE_M x BLOCK_SIZE_N) tile of A @ dequant(W), where
# dequant(q) = scale * (q - 8), optionally scaled by the routed weight.


@triton.autotune(
    configs=_AUTOTUNE_CONFIGS,
    key=["N", "K", "BLOCK_SIZE_M", "GROUP_SIZE_K"],
    prune_configs_by={"early_config_prune": _prune_autotune_configs},
)
@triton.jit
def _int4_gemm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    N: tl.constexpr,
    K: tl.constexpr,
    EM,
    num_valid_tokens,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bse,
    stride_bsk,
    stride_bsn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    GROUP_SIZE_K: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
):
    # Grouped mapping of the 1-D program id onto (pid_m, pid_n).
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Row blocks past the padded token count have no work.
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return

    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id).to(tl.int64)
    # Ids >= num_valid_tokens (e.g. alignment padding) are masked out.
    token_mask = offs_token < num_valid_tokens

    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
    if off_experts == -1:
        # Expert id -1: emit zeros for this tile instead of computing.
        write_zeros_to_output(
            c_ptr,
            stride_cm,
            stride_cn,
            pid_n,
            N,
            offs_token,
            token_mask,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            compute_type,
        )
        return

    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )
    # B transposed: [E, K//2, N], stride_bk=N, stride_bn=1
    b_ptrs = (
        b_ptr
        + off_experts * stride_be
        + (offs_k[:, None] // 2) * stride_bk
        + offs_bn[None, :] * stride_bn
    )
    # Two int4 values per byte: even k -> low nibble, odd k -> high nibble.
    b_shifter = (offs_k[:, None] % 2) * 4
    scale_base = b_scale_ptr + off_experts * stride_bse + offs_bn * stride_bsn

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Factored dequantization: s * (sum a*q - 8 * sum a). Valid because one
    # scale covers the whole K tile (BLOCK_SIZE_K divides GROUP_SIZE_K).
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0)
        b_int = ((tl.load(b_ptrs) >> b_shifter) & 0xF).to(compute_type)
        raw_dot = tl.dot(a, b_int)
        row_sum = tl.sum(a.to(tl.float32), axis=1)
        scale_idx = k * BLOCK_SIZE_K // GROUP_SIZE_K
        scale = tl.load(scale_base + scale_idx * stride_bsk).to(tl.float32)
        accumulator += scale[None, :] * (raw_dot - 8.0 * row_sum[:, None])

        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]

    accumulator = accumulator.to(compute_type)

    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
# ---------- Launch wrappers ----------


def _invoke_gemm1_silu(
    A,
    B,
    C,
    B_scale,
    topk_weights,
    sorted_token_ids,
    expert_ids,
    num_tokens_post_padded,
    mul_routed_weight,
    top_k,
    block_m,
    group_size,
    compute_type,
):
    """Launch ``_int4_gemm_silu_kernel`` (GEMM1: gate/up projection + SiLU).

    ``B`` and ``B_scale`` must already be in the transposed ``[E, K//2, N]``
    and ``[E, K//gs, N]`` layouts. ``B.size(2)`` is the fused gate+up width;
    the launch grid covers half of it, one column per activated output.
    ``C`` is addressed through the strides of its dims 1 and 2.
    """
    fused_width = B.size(2)
    hidden_dim = A.size(1)
    act_width = fused_width // 2
    num_rows = A.size(0)

    padded_len = sorted_token_ids.size(0)
    if num_rows < block_m:
        # With fewer rows than one block, each routed (row, expert) pair can
        # populate at most one block of block_m rows (assuming
        # moe_align_block_size pads each expert to a multiple of block_m).
        padded_len = min(padded_len, num_rows * top_k * block_m)

    def grid(meta):
        m_tiles = triton.cdiv(padded_len, meta["BLOCK_SIZE_M"])
        n_tiles = triton.cdiv(act_width, meta["BLOCK_SIZE_N"])
        return (m_tiles * n_tiles,)

    a_strides = (A.stride(0), A.stride(1))
    # Transposed layouts: dim 0 = expert, dim 1 = (packed) K, dim 2 = N.
    b_strides = tuple(B.stride(dim) for dim in range(3))
    c_strides = (C.stride(1), C.stride(2))
    scale_strides = tuple(B_scale.stride(dim) for dim in range(3))

    _int4_gemm_silu_kernel[grid](
        A,
        B,
        C,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        fused_width,
        hidden_dim,
        padded_len,
        num_rows * top_k,
        *a_strides,
        *b_strides,
        *c_strides,
        *scale_strides,
        BLOCK_SIZE_M=block_m,
        GROUP_SIZE_K=group_size,
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
    )
def _invoke_gemm2(
    A,
    B,
    C,
    B_scale,
    topk_weights,
    sorted_token_ids,
    expert_ids,
    num_tokens_post_padded,
    mul_routed_weight,
    top_k,
    block_m,
    group_size,
    compute_type,
):
    """Launch ``_int4_gemm_kernel`` (GEMM2: down projection).

    ``B`` and ``B_scale`` must already be in the transposed ``[E, K//2, N]``
    and ``[E, K//gs, N]`` layouts. ``B.size(2)`` is the output width covered
    by the launch grid. ``C`` is addressed through the strides of its
    dims 1 and 2.
    """
    out_width = B.size(2)
    inner_dim = A.size(1)
    num_rows = A.size(0)

    padded_len = sorted_token_ids.size(0)
    if num_rows < block_m:
        # Same small-batch cap as GEMM1: at most one block per routed pair
        # (assuming moe_align_block_size pads each expert to block_m).
        padded_len = min(padded_len, num_rows * top_k * block_m)

    def grid(meta):
        m_tiles = triton.cdiv(padded_len, meta["BLOCK_SIZE_M"])
        n_tiles = triton.cdiv(out_width, meta["BLOCK_SIZE_N"])
        return (m_tiles * n_tiles,)

    a_strides = (A.stride(0), A.stride(1))
    b_strides = tuple(B.stride(dim) for dim in range(3))
    c_strides = (C.stride(1), C.stride(2))
    scale_strides = tuple(B_scale.stride(dim) for dim in range(3))

    _int4_gemm_kernel[grid](
        A,
        B,
        C,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        out_width,
        inner_dim,
        padded_len,
        num_rows * top_k,
        *a_strides,
        *b_strides,
        *c_strides,
        *scale_strides,
        BLOCK_SIZE_M=block_m,
        GROUP_SIZE_K=group_size,
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
    )
516# ---------- Implementation ----------
519def _fused_marlin_moe_impl(
520 hidden_states: torch.Tensor,
521 w1: torch.Tensor,
522 w2: torch.Tensor,
523 topk_weights: torch.Tensor,
524 topk_ids: torch.Tensor,
525 inplace: bool = False,
526 activation: str = "silu",
527 apply_router_weight_on_input: bool = False,
528 use_int8_w8a16: bool = False,
529 use_int4_w4a16: bool = False,
530 per_channel_quant: bool = False,
531 global_num_experts: int = -1,
532 expert_map: torch.Tensor | None = None,
533 w1_scale: Optional[torch.Tensor] = None,
534 w2_scale: Optional[torch.Tensor] = None,
535 w1_zp: torch.Tensor | None = None,
536 w2_zp: torch.Tensor | None = None,
537 block_shape: Optional[list[int]] = None,
538 w1_bias: Optional[torch.Tensor] = None,
539 w2_bias: Optional[torch.Tensor] = None,
540) -> torch.Tensor:
541 assert activation == "silu"
542 assert use_int4_w4a16
543 assert w1_zp is None and w2_zp is None
545 expected_packed_k = hidden_states.size(1) // 2
546 assert w1.size(2) == expected_packed_k
547 assert topk_weights.size() == topk_ids.size()
548 assert hidden_states.is_contiguous()
549 assert w1.stride(-1) == 1
550 assert w2.stride(-1) == 1
551 assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
553 num_tokens = hidden_states.size(0)
554 E, N, _ = w1.size()
555 K = w2.size(1)
556 if global_num_experts == -1:
557 global_num_experts = E
558 top_k_num = topk_ids.size(1)
559 group_size = block_shape[1]
561 # Transpose weights for coalesced N-dimension loads (cached)
562 w1_t = _transpose_b(w1) # [E, N, K//2] -> [E, K//2, N]
563 w2_t = _transpose_b(w2) # [E, N, K//2] -> [E, K//2, N]
564 w1_scale_t = _transpose_scale(w1_scale) # [E, N, K//gs] -> [E, K//gs, N]
565 w2_scale_t = _transpose_scale(w2_scale) # [E, N, K//gs] -> [E, K//gs, N]
567 CHUNK_SIZE: int = 16 * 1024
568 M = min(num_tokens, CHUNK_SIZE)
570 activation_out_dim = N // 2
572 block_m = _select_block_m(M, E, top_k_num)
574 intermediate_cache3 = torch.empty(
575 (M, top_k_num, K),
576 device=hidden_states.device,
577 dtype=hidden_states.dtype,
578 )
579 intermediate_cache2 = torch.empty(
580 (M * top_k_num, activation_out_dim),
581 device=hidden_states.device,
582 dtype=hidden_states.dtype,
583 )
585 if hidden_states.dtype == torch.bfloat16:
586 compute_type = tl.bfloat16
587 elif hidden_states.dtype == torch.float16:
588 compute_type = tl.float16
589 elif hidden_states.dtype == torch.float32:
590 compute_type = tl.float32
591 else:
592 raise ValueError(f"Unsupported dtype: {hidden_states.dtype}")
594 out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)
596 for chunk in range((num_tokens // CHUNK_SIZE) + 1):
597 begin_idx = chunk * CHUNK_SIZE
598 end_idx = min(begin_idx + CHUNK_SIZE, num_tokens)
599 curr_hidden = hidden_states[begin_idx:end_idx]
600 tokens_in_chunk = curr_hidden.size(0)
602 if tokens_in_chunk == 0:
603 break
605 if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
606 intermediate_cache2 = intermediate_cache2[: tokens_in_chunk * top_k_num]
607 intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
608 block_m = _select_block_m(tokens_in_chunk, E, top_k_num)
610 curr_topk_ids = topk_ids[begin_idx:end_idx]
611 curr_topk_weights = topk_weights[begin_idx:end_idx]
613 sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
614 curr_topk_ids,
615 block_m,
616 global_num_experts,
617 expert_map,
618 )
620 # ----- GEMM1: gate/up + SiLU fused (two-pass) -----
621 cache2_3d = intermediate_cache2.view(
622 tokens_in_chunk, top_k_num, activation_out_dim
623 )
624 _invoke_gemm1_silu(
625 A=curr_hidden,
626 B=w1_t,
627 C=cache2_3d,
628 B_scale=w1_scale_t,
629 topk_weights=curr_topk_weights,
630 sorted_token_ids=sorted_token_ids,
631 expert_ids=expert_ids,
632 num_tokens_post_padded=num_tokens_post_padded,
633 mul_routed_weight=apply_router_weight_on_input,
634 top_k=top_k_num,
635 block_m=block_m,
636 group_size=group_size,
637 compute_type=compute_type,
638 )
640 if expert_map is not None:
641 intermediate_cache3.zero_()
643 # ----- GEMM2: activated intermediate @ w2 -----
644 _invoke_gemm2(
645 A=intermediate_cache2,
646 B=w2_t,
647 C=intermediate_cache3,
648 B_scale=w2_scale_t,
649 topk_weights=curr_topk_weights,
650 sorted_token_ids=sorted_token_ids,
651 expert_ids=expert_ids,
652 num_tokens_post_padded=num_tokens_post_padded,
653 mul_routed_weight=not apply_router_weight_on_input,
654 top_k=1,
655 block_m=block_m,
656 group_size=group_size,
657 compute_type=compute_type,
658 )
660 # ----- Reduce: sum expert outputs per token -----
661 moe_sum(
662 intermediate_cache3.view(*intermediate_cache3.size()),
663 out_hidden_states[begin_idx:end_idx],
664 )
666 return out_hidden_states
def fused_marlin_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    bias1: Optional[torch.Tensor],
    bias2: Optional[torch.Tensor],
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    quant_type_id: int,
    apply_router_weight_on_input: bool = False,
    global_num_experts: int = -1,
    activation: Any = None,
    activation_func: Optional[Callable] = None,
    moe_sum: Optional[Callable] = None,
    expert_map: Optional[torch.Tensor] = None,
    input_global_scale1: Optional[torch.Tensor] = None,
    input_global_scale2: Optional[torch.Tensor] = None,
    global_scale1: Optional[torch.Tensor] = None,
    global_scale2: Optional[torch.Tensor] = None,
    g_idx1: Optional[torch.Tensor] = None,
    g_idx2: Optional[torch.Tensor] = None,
    sort_indices1: Optional[torch.Tensor] = None,
    sort_indices2: Optional[torch.Tensor] = None,
    w1_zeros: Optional[torch.Tensor] = None,
    w2_zeros: Optional[torch.Tensor] = None,
    workspace: Optional[torch.Tensor] = None,
    intermediate_cache13: Optional[torch.Tensor] = None,
    intermediate_cache2: Optional[torch.Tensor] = None,
    is_k_full: bool = True,
    output: Optional[torch.Tensor] = None,
    input_dtype: Optional[torch.dtype] = None,
    inplace: bool = False,
    clamp_limit: Optional[float] = None,
    group_size: int = 128,
) -> torch.Tensor:
    """Marlin-compatible MoE entry point backed by Triton INT4 kernels.

    Validates the requested feature set, normalizes ``activation`` to a
    lowercase string, and delegates to ``_fused_marlin_moe_impl`` with
    ``block_shape=[0, group_size]``.

    Only ``QUANT_TYPE_UINT4B8`` weights with SiLU activation are executed.
    ``activation_func``, ``moe_sum``, ``workspace``,
    ``intermediate_cache13``, ``intermediate_cache2`` and ``is_k_full`` are
    accepted for signature compatibility but are not used here.

    Returns:
        ``output`` (after copying the result into it) when given; otherwise
        the result tensor, which is ``hidden_states`` itself when
        ``inplace`` is True.

    Raises:
        NotImplementedError: for unsupported quant types (including the
            INT8 ``QUANT_TYPE_UINT8B128`` path), act-order, input
            quantization, clamping, global scales, explicit zero points,
            biases, or a non-SiLU activation.
        ValueError: if both ``inplace`` and ``output`` are requested.
    """
    if quant_type_id not in _SUPPORTED_QUANT_TYPES:
        raise NotImplementedError(
            f"MVP supports quant_type_id in {_SUPPORTED_QUANT_TYPES}, "
            f"got {quant_type_id}"
        )
    if quant_type_id in _QUANT_TYPE_INT8:
        # The fused kernels only decode 4-bit nibbles (``& 0xF``, zero point 8).
        # Fail loudly instead of relying on an assert that ``python -O`` would
        # strip, after which 8-bit weights would be decoded as INT4.
        raise NotImplementedError(
            "uint8b128 (INT8 weight) MoE is not implemented yet; "
            "only uint4b8 is supported"
        )
    if g_idx1 is not None or g_idx2 is not None:
        raise NotImplementedError("act_order (g_idx) not yet supported in MVP")
    if sort_indices1 is not None or sort_indices2 is not None:
        raise NotImplementedError("act_order (sort_indices) not yet supported in MVP")
    if input_dtype is not None:
        raise NotImplementedError("FP8 / INT8 input quantization not supported")
    if clamp_limit is not None:
        raise NotImplementedError("clamp_limit (GLM-4 swiglu) not supported")
    if input_global_scale1 is not None or input_global_scale2 is not None:
        raise NotImplementedError("input_global_scale not supported in MVP")
    if global_scale1 is not None or global_scale2 is not None:
        raise NotImplementedError("global_scale not supported in MVP")
    if w1_zeros is not None or w2_zeros is not None:
        # Zero points are fixed at 8 inside the kernels.
        raise NotImplementedError("explicit zero points (w1_zeros / w2_zeros) not supported")
    if bias1 is not None or bias2 is not None:
        # The kernels have no bias path; ignoring a bias would be silently wrong.
        raise NotImplementedError("MoE bias (bias1 / bias2) not supported")

    use_int4_w4a16 = quant_type_id in _QUANT_TYPE_INT4
    use_int8_w8a16 = quant_type_id in _QUANT_TYPE_INT8

    # Accept a plain string or an enum-like object exposing .value / .name.
    activation_str = "silu"
    if activation is not None:
        for attr in ("value", "name"):
            v = getattr(activation, attr, None)
            if isinstance(v, str):
                activation_str = v.lower()
                break
        if isinstance(activation, str):
            activation_str = activation.lower()
    if activation_str != "silu":
        raise NotImplementedError(
            f"MVP only supports SiLU/SwiGLU activation, got {activation_str}"
        )

    if inplace and output is not None:
        raise ValueError("Cannot pass both inplace=True and output")

    result = _fused_marlin_moe_impl(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=inplace,
        activation=activation_str,
        apply_router_weight_on_input=apply_router_weight_on_input,
        use_int4_w4a16=use_int4_w4a16,
        use_int8_w8a16=use_int8_w8a16,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        w1_zp=w1_zeros,
        w2_zp=w2_zeros,
        w1_bias=bias1,
        w2_bias=bias2,
        block_shape=[0, group_size],
    )

    if output is not None:
        output.copy_(result)
        return output
    return result
# Public API of this module: the MoE entry point and the quant-type ids it accepts.
__all__ = [
    "fused_marlin_moe",
    "QUANT_TYPE_UINT4B8",
    "QUANT_TYPE_UINT8B128",
]