Coverage for src/flag_gems/runtime/backend/_arm/fused/patch_qwen3_mlp.py: 0%
77 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1"""Monkey-patch Qwen3MLP.forward to use triton-cpu's fused_mlp_bf16 kernel.
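This patch replaces the first four ops of the 5-op ATen sequence
    gate_proj(x) → silu → up_proj(x) → mul → down_proj
with a single fused C kernel call (down_proj still runs through its own
forward afterwards) when:
  - decode shape (M=1)
  - BF16 activation
  - gate_proj and up_proj are INT8 SDOT-packed Linears
    (expose attributes: _packed, _w_scale, K, N)
In every other case the MLP is composed from the individual projections.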
11Measured benefit on Qwen3-1.7B W8A8-INT8 decode (3 rounds × 5 runs median,
12CIX P1 CD8180, 8 big cores, OMP=8, performance governor):
14 ENABLE_MLP_PATCH=1 ON → 9.92 tok/s median (9.88, 10.04, 9.92)
15 ENABLE_MLP_PATCH=0 OFF → 9.73 tok/s median (9.61, 9.73, 9.76)
16 → +1.95% median (+2.5% mean) consistent across 3 rounds.
18Usage:
19 from flag_gems.runtime.backend._arm.fused.patch_qwen3_mlp import patch_qwen3_mlp
20 patch_qwen3_mlp(model)
21"""
23import logging
24import types
26import torch
27import triton
28import triton.language as tl
29from triton.language.extra.cpu.tle_ops import fused_mlp as _tle_fused_mlp
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# id() values of the MLP module instances that patch_qwen3_mlp() has patched;
# entries are removed again by unpatch_qwen3_mlp().
_PATCHED: set = set()
# Triton-JIT entry point that forwards all of its arguments, unchanged and in
# the same order, to the TLE `fused_mlp` op imported as `_tle_fused_mlp`.
#
# As called from FusedMLPWrapper.forward in this file:
#   x_ptr            - the flattened, contiguous activation tensor
#   gate_packed_ptr  - gate projection's `_packed` attribute
#   up_packed_ptr    - up projection's `_packed` attribute
#   gate_scale_ptr   - gate projection's `_w_scale` attribute
#   up_scale_ptr     - up projection's `_w_scale` attribute
#   out_ptr          - BF16 output buffer holding N elements
#   K, N             - compile-time constants (tl.constexpr); N is the output
#                      length, K is presumably the input length — TODO confirm
#                      against the TLE op's documentation.
@triton.jit
def _fused_mlp_kernel(
    x_ptr,
    gate_packed_ptr,
    up_packed_ptr,
    gate_scale_ptr,
    up_scale_ptr,
    out_ptr,
    K: tl.constexpr,
    N: tl.constexpr,
):
    # Coarse TLE op: gate GEMV + up GEMV + SWIGLU fused in one kernel launch.
    _tle_fused_mlp(
        x_ptr,
        gate_packed_ptr,
        up_packed_ptr,
        gate_scale_ptr,
        up_scale_ptr,
        out_ptr,
        K,
        N,
    )
class FusedMLPWrapper:
    """Run a gated MLP through the fused decode kernel, with an ATen fallback.

    The wrapper keeps references to the three projections and the activation
    of one MLP module. ``forward`` launches ``_fused_mlp_kernel`` for the
    gate/up/activation part when all of the following hold, and then applies
    ``down_proj`` to the kernel output:

      - both gate and up projections expose ``_packed`` and ``_w_scale``,
        the gate projection exposes ``K`` and ``N``, and the up projection
        (if it exposes ``K``/``N`` at all) agrees with the gate projection;
      - the input holds exactly one row (M == 1);
      - the input dtype is ``torch.bfloat16``.

    Otherwise it computes ``down_proj(act_fn(gate(x)) * up(x))`` by calling
    each projection's own forward.

    NOTE(review): the fused path never calls ``act_fn`` — the kernel comment
    says it fuses SWIGLU — so it presumably assumes ``act_fn`` is SiLU and
    that the packed projections carry no bias. Confirm for every patched
    MLP class.
    """

    def __init__(self, gate_linear, up_linear, down_linear, act_fn):
        """Store the projections and decide once whether fusing is possible.

        Args:
            gate_linear: callable gate projection.
            up_linear: callable up projection.
            down_linear: callable down projection (always run unfused).
            act_fn: activation used by the fallback path.
        """
        self._gate_linear = gate_linear
        self._up_linear = up_linear
        self.down_proj = down_linear
        self.act_fn = act_fn

        packed = all(
            hasattr(layer, attr)
            for layer in (gate_linear, up_linear)
            for attr in ("_packed", "_w_scale")
        )
        has_dims = hasattr(gate_linear, "K") and hasattr(gate_linear, "N")
        # The kernel receives a single (K, N) pair for both projections, so a
        # gate/up pair with different dimensions must not take the fused path.
        # An up projection that does not expose K/N is accepted as before.
        self._fused = bool(
            packed
            and has_dims
            and getattr(up_linear, "K", gate_linear.K) == gate_linear.K
            and getattr(up_linear, "N", gate_linear.N) == gate_linear.N
        )
        if self._fused:
            self._gate_packed = gate_linear._packed
            self._up_packed = up_linear._packed
            self._gate_scale = gate_linear._w_scale
            self._up_scale = up_linear._w_scale
            self._K = gate_linear.K
            self._N = gate_linear.N

    def _use_fused(self, x) -> bool:
        """Return True when *x* is a single BF16 row and fusing is available."""
        if not self._fused or x.dtype != torch.bfloat16:
            return False
        # Guard the division: 0-dim tensors have no last dimension and an
        # empty last dimension would divide by zero. Neither is a decode row.
        if x.dim() == 0 or x.shape[-1] == 0:
            return False
        return x.numel() // x.shape[-1] == 1

    def forward(self, x):
        """Apply the MLP to *x*, using the fused kernel when eligible.

        Returns whatever ``down_proj`` returns. On the fused path its input
        has *x*'s leading dimensions with the last dimension replaced by N.
        """
        if self._use_fused(x):
            shape = x.shape
            xc = x.reshape(-1).contiguous()
            # Allocate next to the input instead of on the global default
            # device, which may have been changed by the caller.
            out = torch.empty(self._N, dtype=torch.bfloat16, device=x.device)
            _fused_mlp_kernel[(1,)](
                xc,
                self._gate_packed,
                self._up_packed,
                self._gate_scale,
                self._up_scale,
                out,
                K=self._K,
                N=self._N,
            )
            return self.down_proj(out.reshape(*shape[:-1], self._N))
        # ATen fallback: compose gate+up+act+mul+down via each Linear's own forward
        gate = self._gate_linear(x)
        up = self._up_linear(x)
        return self.down_proj(self.act_fn(gate) * up)
115def _get_qwen_mlp_classes() -> tuple:
116 """Return a tuple of MLP classes to patch (Qwen3MLP + Qwen3_5MLP if available).
118 Both classes share the same structure (gate_proj, up_proj, down_proj, act_fn),
119 so the FusedMLPWrapper works on either.
120 """
121 classes = []
122 for modname, clsname in [
123 ("transformers.models.qwen3.modeling_qwen3", "Qwen3MLP"),
124 ("transformers.models.qwen3_5.modeling_qwen3_5", "Qwen3_5MLP"),
125 ("transformers.models.llama.modeling_llama", "LlamaMLP"), # MiniCPM5 etc.
126 ]:
127 try:
128 mod = __import__(modname, fromlist=[clsname])
129 classes.append(getattr(mod, clsname))
130 except (ImportError, AttributeError):
131 pass
132 return tuple(classes)
def patch_qwen3_mlp(model) -> int:
    """Replace the forward of every supported MLP in *model* with the fused one.

    Each matching module gets a ``FusedMLPWrapper`` built from its
    ``gate_proj``, ``up_proj``, ``down_proj`` and ``act_fn``. The previous
    forward is saved on the instance as ``_original_forward`` and the wrapper
    as ``_fused_mlp_wrapper``.

    Safe to call multiple times: a module that already carries
    ``_fused_mlp_wrapper`` is left alone. The marker on the instance is used
    rather than ``id(module)`` because an id is only unique while the object
    is alive and can be reused by a later, unrelated module.

    Args:
        model: object exposing ``named_modules()``.

    Returns:
        Number of MLP instances patched by this call.
    """
    mlp_classes = _get_qwen_mlp_classes()
    if not mlp_classes:
        logger.debug("No Qwen MLP classes found in transformers, skipping patch")
        return 0

    n = 0
    # Snapshot the traversal so attribute assignment below cannot affect it.
    for _, module in list(model.named_modules()):
        if not isinstance(module, mlp_classes):
            continue
        if hasattr(module, "_fused_mlp_wrapper"):
            continue  # already patched
        wrapper = FusedMLPWrapper(
            module.gate_proj,
            module.up_proj,
            module.down_proj,
            module.act_fn,
        )
        module._original_forward = module.forward
        module._fused_mlp_wrapper = wrapper
        # Bind as an instance method so the module is still passed as self;
        # the wrapper is captured as a default to avoid late binding.
        module.forward = types.MethodType(
            lambda self, x, _w=wrapper: _w.forward(x),
            module,
        )
        # Still recorded so unpatch_qwen3_mlp() can find the instance.
        _PATCHED.add(id(module))
        n += 1
    if n > 0:
        cls_names = ", ".join(c.__name__ for c in mlp_classes)
        logger.info(
            "Patched %d MLP modules (classes: %s) with fused_mlp_bf16", n, cls_names
        )
    return n
def unpatch_qwen3_mlp(model) -> int:
    """Restore the original forward of every patched MLP in *model*.

    A module counts as patched when it carries the ``_fused_mlp_wrapper``
    marker set by ``patch_qwen3_mlp``. The marker on the instance is used
    rather than ``id(module)`` because ids can be reused after an object is
    freed. The module's id is dropped from ``_PATCHED`` in every case so a
    stale entry cannot block a later patch.

    Args:
        model: object exposing ``named_modules()``.

    Returns:
        Number of MLP instances whose forward was actually restored.
    """
    mlp_classes = _get_qwen_mlp_classes()
    if not mlp_classes:
        return 0

    n = 0
    for _, module in list(model.named_modules()):
        if not isinstance(module, mlp_classes):
            continue
        _PATCHED.discard(id(module))
        if not hasattr(module, "_fused_mlp_wrapper"):
            continue  # never patched by this module; nothing to undo
        if hasattr(module, "_original_forward"):
            module.forward = module._original_forward
            del module._original_forward
            n += 1
        del module._fused_mlp_wrapper
    return n