Coverage for src/flag_gems/runtime/backend/_arm/fused/patch_qwen3_mlp.py: 0%

77 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1"""Monkey-patch Qwen3MLP.forward to use triton-cpu's fused_mlp_bf16 kernel. 

2 

3This patch replaces the 5-op ATen sequence 

4 gate_proj(x) → silu → up_proj(x) → mul → down_proj 

5with a single fused C kernel call when: 

6 - decode shape (M=1) 

7 - BF16 activation 

8 - gate_proj and up_proj are INT8 SDOT-packed Linears 

9 (expose attributes: _packed, _w_scale, K, N) 

10 

11Measured benefit on Qwen3-1.7B W8A8-INT8 decode (3 rounds × 5 runs median, 

12CIX P1 CD8180, 8 big cores, OMP=8, performance governor): 

13 

14 ENABLE_MLP_PATCH=1 ON → 9.92 tok/s median (9.88, 10.04, 9.92) 

15 ENABLE_MLP_PATCH=0 OFF → 9.73 tok/s median (9.61, 9.73, 9.76) 

16 → +1.95% median (+2.5% mean) consistent across 3 rounds. 

17 

18Usage: 

19 from flag_gems.runtime.backend._arm.fused.patch_qwen3_mlp import patch_qwen3_mlp 

20 patch_qwen3_mlp(model) 

21""" 

22 

23import logging 

24import types 

25 

26import torch 

27import triton 

28import triton.language as tl 

29from triton.language.extra.cpu.tle_ops import fused_mlp as _tle_fused_mlp 

30 

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Registry of `id(module)` values for MLP instances whose `forward` is
# currently replaced: patch_qwen3_mlp() adds entries and
# unpatch_qwen3_mlp() discards them.
_PATCHED: set = set()

34 

35 

# Triton-JIT entry point that forwards all of its arguments, unchanged and in
# the same order, to the TLE `fused_mlp` op imported as `_tle_fused_mlp`.
#
# FusedMLPWrapper.forward launches it with a single-program grid `(1,)`,
# passing the flattened BF16 activation, the gate/up `_packed` and `_w_scale`
# attributes, and a freshly allocated BF16 output of length N.
# NOTE(review): the exact layout expected for the packed weights and scales is
# defined by the TLE op, not by this file — confirm against its documentation.
@triton.jit
def _fused_mlp_kernel(
    x_ptr,
    gate_packed_ptr,
    up_packed_ptr,
    gate_scale_ptr,
    up_scale_ptr,
    out_ptr,
    K: tl.constexpr,  # compile-time constant, forwarded as-is
    N: tl.constexpr,  # compile-time constant, forwarded as-is
):
    # Coarse TLE op: gate GEMV + up GEMV + SWIGLU fused in one kernel launch.
    _tle_fused_mlp(
        x_ptr,
        gate_packed_ptr,
        up_packed_ptr,
        gate_scale_ptr,
        up_scale_ptr,
        out_ptr,
        K,
        N,
    )

58 

59 

class FusedMLPWrapper:
    """Dispatch an MLP forward to the fused kernel when it is safe to do so.

    Holds references to the three projections and the activation of an MLP
    block. ``forward`` runs gate + up + activation-multiply in a single fused
    kernel launch, then applies ``down_proj``, when all of these hold:

    - both ``gate_linear`` and ``up_linear`` expose ``_packed``, ``_w_scale``,
      ``K`` and ``N``, and their ``K`` / ``N`` values agree;
    - the input is BF16;
    - the input is a single row whose width equals ``K``.

    Every other case (prefill with more than one row, non-BF16 input,
    unpacked or mismatched projections, empty input) is computed as
    ``down_proj(act_fn(gate(x)) * up(x))`` through each layer's own forward.
    """

    # Attributes a projection must expose to be eligible for the fused kernel.
    _PACKED_ATTRS = ("_packed", "_w_scale", "K", "N")

    def __init__(self, gate_linear, up_linear, down_linear, act_fn):
        """Capture the projections and decide once whether fusing is possible.

        Args:
            gate_linear: Callable gate projection.
            up_linear: Callable up projection.
            down_linear: Callable down projection, applied on every path.
            act_fn: Activation applied to the gate output on the fallback path.
        """
        self._gate_linear = gate_linear
        self._up_linear = up_linear
        self.down_proj = down_linear
        self.act_fn = act_fn

        both_packed = all(
            hasattr(linear, attr)
            for linear in (gate_linear, up_linear)
            for attr in self._PACKED_ATTRS
        )
        # The kernel receives a single K and a single N for both projections,
        # so the two layers must agree on them; otherwise stay on the fallback.
        self._fused = bool(
            both_packed
            and gate_linear.K == up_linear.K
            and gate_linear.N == up_linear.N
        )
        if self._fused:
            self._gate_packed = gate_linear._packed
            self._up_packed = up_linear._packed
            self._gate_scale = gate_linear._w_scale
            self._up_scale = up_linear._w_scale
            self._K = gate_linear.K
            self._N = gate_linear.N

    def _can_fuse(self, x) -> bool:
        """Return True when ``x`` may be handed to the fused kernel."""
        if not self._fused or x.dtype != torch.bfloat16 or x.dim() < 1:
            return False
        # A single row (M == 1) of width K: the last dimension is K and there
        # are exactly K elements in total. Checking it this way also avoids
        # dividing by a zero-sized last dimension.
        # NOTE(review): assumes K is the input width the kernel reads from
        # x_ptr — confirm against the packed Linear implementation.
        return x.shape[-1] == self._K and x.numel() == self._K

    def forward(self, x):
        """Apply the MLP to ``x`` and return the ``down_proj`` output."""
        if self._can_fuse(x):
            shape = x.shape
            xc = x.reshape(-1).contiguous()
            # Allocate the kernel output next to the input rather than on the
            # process-wide default device.
            out = torch.empty(self._N, dtype=torch.bfloat16, device=x.device)
            _fused_mlp_kernel[(1,)](
                xc,
                self._gate_packed,
                self._up_packed,
                self._gate_scale,
                self._up_scale,
                out,
                K=self._K,
                N=self._N,
            )
            # Restore the leading dimensions of x before the down projection.
            return self.down_proj(out.reshape(*shape[:-1], self._N))
        # ATen fallback: compose gate+up+act+mul+down via each Linear's own forward
        gate = self._gate_linear(x)
        up = self._up_linear(x)
        return self.down_proj(self.act_fn(gate) * up)

113 

114 

115def _get_qwen_mlp_classes() -> tuple: 

116 """Return a tuple of MLP classes to patch (Qwen3MLP + Qwen3_5MLP if available). 

117 

118 Both classes share the same structure (gate_proj, up_proj, down_proj, act_fn), 

119 so the FusedMLPWrapper works on either. 

120 """ 

121 classes = [] 

122 for modname, clsname in [ 

123 ("transformers.models.qwen3.modeling_qwen3", "Qwen3MLP"), 

124 ("transformers.models.qwen3_5.modeling_qwen3_5", "Qwen3_5MLP"), 

125 ("transformers.models.llama.modeling_llama", "LlamaMLP"), # MiniCPM5 etc. 

126 ]: 

127 try: 

128 mod = __import__(modname, fromlist=[clsname]) 

129 classes.append(getattr(mod, clsname)) 

130 except (ImportError, AttributeError): 

131 pass 

132 return tuple(classes) 

133 

134 

def patch_qwen3_mlp(model) -> int:
    """Replace the forward of every supported MLP in ``model`` with the fused one.

    Walks ``model.named_modules()`` and, for each instance of a class returned
    by ``_get_qwen_mlp_classes()``, builds a ``FusedMLPWrapper`` from its
    ``gate_proj`` / ``up_proj`` / ``down_proj`` / ``act_fn`` and installs it as
    the instance's ``forward``. The previous forward is kept on the instance
    as ``_original_forward`` and the wrapper as ``_fused_mlp_wrapper``.

    Safe to call multiple times: an instance that already carries
    ``_fused_mlp_wrapper`` is left alone. The attribute, not ``id(module)``,
    decides this, because ``id()`` values can be reused once an earlier model
    has been garbage-collected and a stale id would make a fresh module look
    already patched. Ids are still recorded in ``_PATCHED``.

    Args:
        model: Object exposing ``named_modules()``.

    Returns:
        Number of MLP instances patched by this call.
    """
    mlp_classes = _get_qwen_mlp_classes()
    if not mlp_classes:
        logger.debug("No Qwen MLP classes found in transformers, skipping patch")
        return 0

    n = 0
    # Iterate over a snapshot so attribute writes below cannot disturb the walk.
    for _, module in list(model.named_modules()):
        if not isinstance(module, mlp_classes):
            continue
        if hasattr(module, "_fused_mlp_wrapper"):
            # Already patched; keep the id registry in sync and move on.
            _PATCHED.add(id(module))
            continue
        wrapper = FusedMLPWrapper(
            module.gate_proj,
            module.up_proj,
            module.down_proj,
            module.act_fn,
        )
        module._original_forward = module.forward
        module._fused_mlp_wrapper = wrapper
        # Bind the wrapper as a default argument so each module keeps its own
        # wrapper (avoids the late-binding closure pitfall inside the loop).
        module.forward = types.MethodType(
            lambda self, x, _w=wrapper: _w.forward(x),
            module,
        )
        _PATCHED.add(id(module))
        n += 1
    if n > 0:
        cls_names = ", ".join(c.__name__ for c in mlp_classes)
        logger.info(
            "Patched %d MLP modules (classes: %s) with fused_mlp_bf16", n, cls_names
        )
    return n

169 

170 

def unpatch_qwen3_mlp(model) -> int:
    """Restore the original MLP forward (for testing / revert).

    For every instance in ``model`` of a class returned by
    ``_get_qwen_mlp_classes()`` that carries both ``_fused_mlp_wrapper`` and
    ``_original_forward``, reinstalls ``_original_forward`` as ``forward`` and
    removes both bookkeeping attributes. The module's id is discarded from
    ``_PATCHED`` in every case, so stale ids do not accumulate.

    Whether a module is patched is decided by its attributes rather than by
    ``id(module) in _PATCHED``, since ``id()`` values can be reused by
    unrelated objects after garbage collection.

    Args:
        model: Object exposing ``named_modules()``.

    Returns:
        Number of MLP instances whose forward was actually restored.
    """
    mlp_classes = _get_qwen_mlp_classes()
    if not mlp_classes:
        return 0
    n = 0
    for _, module in list(model.named_modules()):
        if not isinstance(module, mlp_classes):
            continue
        # Drop the id whether or not anything is restored below.
        _PATCHED.discard(id(module))
        patched = hasattr(module, "_fused_mlp_wrapper") and hasattr(
            module, "_original_forward"
        )
        if not patched:
            continue
        module.forward = module._original_forward
        del module._original_forward
        del module._fused_mlp_wrapper
        n += 1
    return n