Coverage for src/flag_gems/runtime/backend/_arm/ops/__init__.py: 0%

91 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1from .addmm import addmm, addmm_out 

2from .all import all 

3from .any import any 

4from .argmax import argmax 

5from .attention import scaled_dot_product_attention 

6from .bmm import bmm 

7from .cumsum import cumsum 

8from .div import ( # noqa: F401 

9 div_mode, 

10 div_mode_, 

11 floor_divide, 

12 floor_divide_, 

13 remainder, 

14 remainder_, 

15 true_divide, 

16 true_divide_, 

17) 

18from .exponential_ import exponential_ 

19from .full import full 

20from .gather import gather 

21from .index_select import index_select 

22from .isin import isin 

23from .lt import lt 

24from .masked_fill import masked_fill 

25from .min import min 

26from .mm import mm, mm_out 

27from .multinomial import multinomial 

28from .pow import ( 

29 pow_scalar, 

30 pow_tensor_scalar, 

31 pow_tensor_scalar_, 

32 pow_tensor_tensor, 

33 pow_tensor_tensor_, 

34) 

35from .quantile import quantile 

36from .scatter import scatter 

37from .sort import sort 

38from .sub import sub 

39from .topk import topk 

40from .where import where_self_out 

41 

# Public surface of the ARM ops package: the op implementations imported above
# plus the apply_arm_overrides() entry point defined further down this module.
# NOTE(review): true_divide / true_divide_ are imported from .div but are not
# listed here — confirm that omission is intentional.
__all__ = [
    "addmm", "addmm_out",
    "all", "any", "argmax",
    "bmm", "cumsum",
    "div_mode", "div_mode_",
    "exponential_",
    "floor_divide", "floor_divide_",
    "full", "gather", "index_select", "isin", "lt",
    "masked_fill", "min",
    "mm", "mm_out",
    "multinomial",
    "pow_scalar",
    "pow_tensor_scalar", "pow_tensor_scalar_",
    "pow_tensor_tensor", "pow_tensor_tensor_",
    "quantile",
    "remainder", "remainder_",
    "scaled_dot_product_attention",
    "scatter", "sort", "sub", "topk",
    "where_self_out",
    "apply_arm_overrides",
]

81 

82import logging as _logging # noqa: E402 

83 

84import torch as _torch # noqa: E402 

85 

86# --------------------------------------------------------------------------- 

87# ARM CPU op overrides (torch.library impls + flag_gems / F.* monkeypatches). 

88# 

89# These are NOT standard aten_lib ops registered through the FlagGems Registrar, 

90# so they cannot ride enable()/only_enable()'s _FULL_CONFIG include/exclude. They 

91# are collected in a name-keyed registry and applied through a single idempotent 

92# entry point, apply_arm_overrides(include=, exclude=), so callers can select a 

93# subset (e.g. exclude {"mm"} to avoid its prefill regression on prefill-heavy 

94# workloads). 

95# --------------------------------------------------------------------------- 

96 

97# torch.library handles must stay alive — GC would revoke the registration. 

# Module-level slots for the torch.library.Library handles created by
# _register_argmax() and _register_mm(); both start out unset.
_argmax_aten_lib = _mm_aten_lib = None

100 

101 

def _register_quantized_linear_dynamic():
    """Register the Triton-CPU INT8 GEMM for ``quantized::linear_dynamic``.

    The sibling module is imported lazily, at call time, and its ``register``
    hook performs the actual registration (in the ``quantized::`` namespace).
    """
    from .quantized_linear_dynamic import register as _register_qld

    _register_qld()

107 

108 

def _register_int_mm():
    """Register the Triton-CPU INT8 GEMM for ``aten::_int_mm``.

    Per the original author's note this enables torchao INT8 paths. The
    sibling module is imported lazily and its ``register`` hook does the work.
    """
    from .int_mm import register as _register_int_mm_impl

    _register_int_mm_impl()

114 

115 

def _register_argmax():
    """Register the FlagGems argmax as the CPU kernel for ``aten::argmax``.

    Author's note: on the decode lm_head shape [1, 151936] this was measured
    2.2x faster than the stock kernel.

    The ``torch.library.Library`` handle is stored in the module-level
    ``_argmax_aten_lib`` so it is not garbage-collected (which would revoke
    the registration). A non-None handle also marks the override as already
    installed, making repeated calls no-ops.

    Raises:
        Whatever the lazy import, ``Library(...)`` or ``Library.impl(...)``
        raise. On failure the module-level handle is left untouched so a
        later call retries instead of silently returning early.
    """
    global _argmax_aten_lib
    if _argmax_aten_lib is not None:
        return
    from .argmax import argmax as _fg_argmax

    # Build and populate the library in a local first. Publishing the handle
    # only after impl() succeeds keeps a failed registration from being
    # mistaken for a completed one on the next call.
    lib = _torch.library.Library("aten", "IMPL")
    lib.impl("argmax", _fg_argmax, "CPU", allow_override=True)
    _argmax_aten_lib = lib
    _logging.getLogger(__name__).debug(
        "FlagGems ARM: registered Triton-CPU argmax for aten::argmax"
    )

128 

129 

def _override_fused_add_rms_norm_with_arm():
    """Point ``flag_gems.fused_add_rms_norm`` at the ARM Triton kernel.

    Author's notes carried over from the original comments:
      * A standalone rms_norm override was dropped — no measurable benefit
        over ATen native (Qwen3-1.7B INT8 decode A/B, 3 rounds: 9.93 vs
        9.97 tok/s, within noise).
      * fused_add_rms_norm is kept because it saves a residual-add memory
        roundtrip (vLLM path).
    """
    import flag_gems as _flag_gems_pkg

    from .rms_norm import fused_add_rms_norm as _arm_kernel

    setattr(_flag_gems_pkg, "fused_add_rms_norm", _arm_kernel)

    log = _logging.getLogger(__name__)
    log.debug(
        "FlagGems ARM: overrode flag_gems.fused_add_rms_norm with ARM Triton kernel"
    )

142 

143 

def _override_rope_with_arm():
    """Replace ``apply_rotary_pos_emb`` on flag_gems and flag_gems.fused.

    Per the original comment, the generic fused/rotary_embedding.py goes
    through ``@libentry()``, which crashes on CPU (DEVICE_COUNT), so both
    module attributes are redirected to the ARM implementation from ``.rope``.
    """
    import flag_gems as _top_level
    import flag_gems.fused as _fused_pkg

    from .rope import arm_apply_rotary_pos_emb as _rope_impl

    # Same replacement on both namespaces, top-level package first.
    for _module in (_top_level, _fused_pkg):
        _module.apply_rotary_pos_emb = _rope_impl

    log = _logging.getLogger(__name__)
    log.debug(
        "FlagGems ARM: overrode flag_gems.apply_rotary_pos_emb with pure-PyTorch"
    )

156 

157 

def _override_silu_and_mul_with_arm():
    """Replace ``silu_and_mul`` / ``silu_and_mul_out`` with the ARM versions.

    Per the original comment, the generic fused/silu_and_mul.py uses
    ``@pointwise_dynamic`` -> ``@libentry()``, which crashes on this backend.
    Both attributes are rebound on ``flag_gems`` and on ``flag_gems.fused``.
    """
    import flag_gems as _top_level
    import flag_gems.fused as _fused_pkg

    from .silu_and_mul import arm_silu_and_mul as _sam_impl
    from .silu_and_mul import arm_silu_and_mul_out as _sam_out_impl

    replacements = (
        ("silu_and_mul", _sam_impl),
        ("silu_and_mul_out", _sam_out_impl),
    )
    # Top-level package first, then flag_gems.fused; within each, the
    # functional variant before the out= variant.
    for _module in (_top_level, _fused_pkg):
        for _attr, _impl in replacements:
            setattr(_module, _attr, _impl)

    log = _logging.getLogger(__name__)
    log.debug("FlagGems ARM: overrode flag_gems.silu_and_mul / silu_and_mul_out")

173 

174 

def _register_mm():
    """Register the FlagGems mm as the CPU kernel for ``aten::mm`` (BF16).

    Author's performance notes: M=1 decode is 2-5x faster than ATen
    (unoptimized GEMV); M=64 prefill is 2-3x slower (ATen uses native BF16
    BFMMLA) — exclude "mm" for prefill-heavy workloads.

    The ``torch.library.Library`` handle is stored in the module-level
    ``_mm_aten_lib`` so it stays alive; a non-None handle also marks the
    override as already installed, making repeated calls no-ops.

    Raises:
        Whatever the lazy import, ``Library(...)`` or ``Library.impl(...)``
        raise. On failure the module-level handle is left untouched so a
        later call retries instead of silently returning early.
    """
    global _mm_aten_lib
    if _mm_aten_lib is not None:
        return
    from .mm import mm as _fg_mm

    # Publish the handle only after impl() succeeds; otherwise a failed
    # registration would look "already done" to every subsequent call.
    lib = _torch.library.Library("aten", "IMPL")
    lib.impl("mm", _fg_mm, "CPU", allow_override=True)
    _mm_aten_lib = lib
    _logging.getLogger(__name__).debug(
        "FlagGems ARM: registered Triton-CPU mm for aten::mm"
    )

189 

190 

def _register_sdpa():
    """Monkey-patch ``F.scaled_dot_product_attention`` with the ARM version.

    Author's notes: the replacement is a Triton-CPU Flash Attention (prefill
    4-5x); decode and other cases fall back to ATen.

    This is deliberately a monkey-patch rather than a torch.library override:
    a Library("aten", "IMPL") registration would route the ATen fallback
    inside the wrapper back into the wrapper (infinite recursion). Patching
    the Python attribute leaves the original C++ dispatch reachable via
    ``_aten_sdpa``, which attention.py captures at import.
    """
    import torch.nn.functional as _functional

    from .attention import scaled_dot_product_attention as _arm_sdpa

    setattr(_functional, "scaled_dot_product_attention", _arm_sdpa)

    log = _logging.getLogger(__name__)
    log.debug(
        "FlagGems ARM: monkey-patched F.scaled_dot_product_attention "
        "with Triton Flash Attention (prefill 4-5x speedup)"
    )

208 

209 

210# Name → applier. Names match the aten/flag_gems op they override so callers can 

211# select with the same vocabulary as only_enable(include=[...]). 

_ARM_OVERRIDE_REGISTRY = {
    # torch.library-style registrations delegated to sibling modules.
    "quantized_linear_dynamic": _register_quantized_linear_dynamic,
    "_int_mm": _register_int_mm,
    # aten::argmax CPU impl.
    "argmax": _register_argmax,
    # flag_gems / flag_gems.fused attribute replacements.
    "fused_add_rms_norm": _override_fused_add_rms_norm_with_arm,
    "apply_rotary_pos_emb": _override_rope_with_arm,
    "silu_and_mul": _override_silu_and_mul_with_arm,
    # aten::mm CPU impl.
    "mm": _register_mm,
    # torch.nn.functional monkey-patch.
    "scaled_dot_product_attention": _register_sdpa,
}

# Names of overrides that apply_arm_overrides() has applied successfully.
_ARM_OVERRIDES_APPLIED = set()

224 

225 

def apply_arm_overrides(include=None, exclude=None):
    """Apply ARM CPU op overrides, optionally restricted to a subset.

    Args:
        include: override name, or iterable of override names, to apply
            (None = all known). A bare string is treated as a single name,
            not as an iterable of characters.
        exclude: override name, or iterable of override names, to skip
            (applied after include). A bare string is a single name.

    Names are the keys of _ARM_OVERRIDE_REGISTRY (the aten/flag_gems op each
    one overrides). Names that are not registry keys are ignored, with a
    warning so typos do not pass unnoticed.

    Overrides are applied in registry order. Each override is applied
    successfully at most once per process: names already recorded in
    _ARM_OVERRIDES_APPLIED are skipped. An override whose applier raises is
    logged as a warning, is NOT recorded, and is therefore retried on a
    later call.
    """
    logger = _logging.getLogger(__name__)

    def _as_name_set(names):
        # Normalise None / a single string / any iterable into a set of names.
        if names is None:
            return set()
        if isinstance(names, str):
            # set("mm") would yield {"m"}; keep the string whole instead.
            return {names} if names else set()
        return set(names)

    known = set(_ARM_OVERRIDE_REGISTRY)
    requested = known if include is None else _as_name_set(include)
    skipped = _as_name_set(exclude)

    unknown = (requested | skipped) - known
    if unknown:
        logger.warning(
            "FlagGems ARM: ignoring unknown override name(s): %s",
            ", ".join(sorted(map(str, unknown))),
        )

    # Walk the registry (not the requested set) so application order is
    # deterministic across processes regardless of string hash randomisation.
    for name, applier in _ARM_OVERRIDE_REGISTRY.items():
        if name not in requested or name in skipped:
            continue
        if name in _ARM_OVERRIDES_APPLIED:
            continue
        try:
            applier()
        except Exception as e:  # noqa: BLE001
            # Best-effort by design: one failing override must not prevent
            # the remaining ones from being applied.
            logger.warning(
                "FlagGems ARM: failed to apply override '%s': %s", name, e
            )
        else:
            _ARM_OVERRIDES_APPLIED.add(name)

253 

254 

255# NOTE: overrides are NOT applied on import. The caller selects which ones to 

256# engage via apply_arm_overrides(include=[...]) — mirroring FlagGems' 

257# only_enable() opt-in model. This avoids silently monkeypatching aten::mm / 

258# F.scaled_dot_product_attention / flag_gems.* process-wide just by importing, 

259# and lets a workload pick the net-positive subset (e.g. exclude "mm" on 

260# prefill-heavy runs where the decode-tuned mm regresses prefill). 

261# 

262# from flag_gems.runtime.backend._arm.ops import apply_arm_overrides 

263# apply_arm_overrides() # engage all curated overrides 

264# apply_arm_overrides(include=["mm", "argmax"]) # only these 

265# apply_arm_overrides(exclude=["mm"]) # all but mm