Coverage for src/flag_gems/runtime/backend/_arm/ops/__init__.py: 0%
91 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1from .addmm import addmm, addmm_out
2from .all import all
3from .any import any
4from .argmax import argmax
5from .attention import scaled_dot_product_attention
6from .bmm import bmm
7from .cumsum import cumsum
8from .div import ( # noqa: F401
9 div_mode,
10 div_mode_,
11 floor_divide,
12 floor_divide_,
13 remainder,
14 remainder_,
15 true_divide,
16 true_divide_,
17)
18from .exponential_ import exponential_
19from .full import full
20from .gather import gather
21from .index_select import index_select
22from .isin import isin
23from .lt import lt
24from .masked_fill import masked_fill
25from .min import min
26from .mm import mm, mm_out
27from .multinomial import multinomial
28from .pow import (
29 pow_scalar,
30 pow_tensor_scalar,
31 pow_tensor_scalar_,
32 pow_tensor_tensor,
33 pow_tensor_tensor_,
34)
35from .quantile import quantile
36from .scatter import scatter
37from .sort import sort
38from .sub import sub
39from .topk import topk
40from .where import where_self_out
# Names exported by ``from <this package> import *``.
# NOTE(review): ``true_divide`` / ``true_divide_`` are imported by this module
# but are not listed here — confirm that omission is intentional.
__all__ = [
    "addmm",
    "addmm_out",
    "all",
    "any",
    "argmax",
    "bmm",
    "cumsum",
    "div_mode",
    "div_mode_",
    "exponential_",
    "floor_divide",
    "floor_divide_",
    "full",
    "gather",
    "index_select",
    "isin",
    "lt",
    "masked_fill",
    "min",
    "mm",
    "mm_out",
    "multinomial",
    "pow_scalar",
    "pow_tensor_scalar",
    "pow_tensor_scalar_",
    "pow_tensor_tensor",
    "pow_tensor_tensor_",
    "quantile",
    "remainder",
    "remainder_",
    "scaled_dot_product_attention",
    "scatter",
    "sort",
    "sub",
    "topk",
    "where_self_out",
    # Entry point for the opt-in ARM overrides; defined later in this module.
    "apply_arm_overrides",
]
82import logging as _logging # noqa: E402
84import torch as _torch # noqa: E402
86# ---------------------------------------------------------------------------
87# ARM CPU op overrides (torch.library impls + flag_gems / F.* monkeypatches).
88#
89# These are NOT standard aten_lib ops registered through the FlagGems Registrar,
90# so they cannot ride enable()/only_enable()'s _FULL_CONFIG include/exclude. They
91# are collected in a name-keyed registry and applied through a single idempotent
92# entry point, apply_arm_overrides(include=, exclude=), so callers can select a
93# subset (e.g. exclude {"mm"} to avoid its prefill regression on prefill-heavy
94# workloads).
95# ---------------------------------------------------------------------------
# Module-level slots for the torch.library.Library objects created by
# _register_mm() / _register_argmax(). They are held here because the handles
# must stay alive: garbage collection would revoke the registration. ``None``
# means "not registered yet".
_mm_aten_lib = None
_argmax_aten_lib = None
def _register_quantized_linear_dynamic():
    """Install the Triton-CPU INT8 GEMM for ``quantized::linear_dynamic``.

    The op lives in the ``quantized::`` namespace. The actual registration is
    delegated to ``quantized_linear_dynamic.register``, which is imported
    lazily at call time.
    """
    from .quantized_linear_dynamic import register as register_linear_dynamic

    register_linear_dynamic()
def _register_int_mm():
    """Install the Triton-CPU INT8 GEMM for ``aten::_int_mm``.

    This enables the torchao INT8 paths. The actual registration is delegated
    to ``int_mm.register``, which is imported lazily at call time.
    """
    from .int_mm import register as register_int_mm

    register_int_mm()
def _register_argmax():
    """Register the FlagGems argmax as the CPU implementation of ``aten::argmax``.

    Motivation (from the original benchmark note): decode lm_head argmax is
    2.2x faster for a [1, 151936] input.

    The call is idempotent: once the library handle is stored in the module
    global ``_argmax_aten_lib``, later calls return immediately. The handle is
    kept in that global because it must stay alive — garbage collection would
    revoke the registration.

    Raises:
        Whatever the lazy import or ``Library.impl`` raises. In that case the
        module global is left as ``None`` so a later call can retry.
    """
    global _argmax_aten_lib
    if _argmax_aten_lib is not None:
        # Already registered in this process.
        return
    from .argmax import argmax as _fg_argmax

    # Build the handle in a local first and publish it only after impl()
    # succeeded. Assigning the global up front would leave a half-initialised
    # handle behind on failure, and every retry would then return early and
    # look like a successful registration.
    lib = _torch.library.Library("aten", "IMPL")
    lib.impl("argmax", _fg_argmax, "CPU", allow_override=True)
    _argmax_aten_lib = lib
    _logging.getLogger(__name__).debug(
        "FlagGems ARM: registered Triton-CPU argmax for aten::argmax"
    )
def _override_fused_add_rms_norm_with_arm():
    """Rebind ``flag_gems.fused_add_rms_norm`` to the ARM implementation.

    Only the fused variant is overridden. A standalone rms_norm override was
    dropped because it showed no measurable benefit over ATen native
    (Qwen3-1.7B INT8 decode A/B, 3 rounds: 9.93 vs 9.97 tok/s, within noise).
    The fused op is kept because it saves a residual-add memory roundtrip
    (vLLM path).
    """
    import flag_gems as flag_gems_pkg

    from .rms_norm import fused_add_rms_norm as arm_impl

    log = _logging.getLogger(__name__)
    flag_gems_pkg.fused_add_rms_norm = arm_impl
    log.debug(
        "FlagGems ARM: overrode flag_gems.fused_add_rms_norm with ARM Triton kernel"
    )
def _override_rope_with_arm():
    """Rebind ``apply_rotary_pos_emb`` on flag_gems and flag_gems.fused.

    The generic fused/rotary_embedding.py goes through ``@libentry()``, which
    hits a DEVICE_COUNT crash on CPU, so both module attributes are pointed at
    the ARM replacement instead.
    """
    import flag_gems as flag_gems_pkg
    import flag_gems.fused as fused_pkg

    from .rope import arm_apply_rotary_pos_emb as arm_impl

    # Patch the top-level package first, then the fused subpackage.
    for target in (flag_gems_pkg, fused_pkg):
        target.apply_rotary_pos_emb = arm_impl

    log = _logging.getLogger(__name__)
    log.debug("FlagGems ARM: overrode flag_gems.apply_rotary_pos_emb with pure-PyTorch")
def _override_silu_and_mul_with_arm():
    """Rebind ``silu_and_mul`` / ``silu_and_mul_out`` on flag_gems and flag_gems.fused.

    The generic fused/silu_and_mul.py uses ``@pointwise_dynamic``, which goes
    through ``@libentry()`` and crashes, so all four attributes are pointed at
    the ARM replacements.
    """
    import flag_gems as flag_gems_pkg
    import flag_gems.fused as fused_pkg

    from .silu_and_mul import arm_silu_and_mul as arm_impl
    from .silu_and_mul import arm_silu_and_mul_out as arm_impl_out

    # Same assignment order as before: both attributes on the top-level
    # package, then both on the fused subpackage.
    for target in (flag_gems_pkg, fused_pkg):
        target.silu_and_mul = arm_impl
        target.silu_and_mul_out = arm_impl_out

    log = _logging.getLogger(__name__)
    log.debug("FlagGems ARM: overrode flag_gems.silu_and_mul / silu_and_mul_out")
def _register_mm():
    """Register the FlagGems mm as the CPU implementation of ``aten::mm``.

    This is the BF16 override. Trade-off, per the original benchmark note:
    M=1 decode is 2-5x faster than ATen (unoptimized GEMV), while M=64 prefill
    is 2-3x slower (ATen uses native BF16 BFMMLA). Exclude "mm" for
    prefill-heavy workloads.

    The call is idempotent: once the library handle is stored in the module
    global ``_mm_aten_lib``, later calls return immediately. The handle is kept
    in that global because it must stay alive — garbage collection would
    revoke the registration.

    Raises:
        Whatever the lazy import or ``Library.impl`` raises. In that case the
        module global is left as ``None`` so a later call can retry.
    """
    global _mm_aten_lib
    if _mm_aten_lib is not None:
        # Already registered in this process.
        return
    from .mm import mm as _fg_mm

    # Build the handle in a local first and publish it only after impl()
    # succeeded. Assigning the global up front would leave a half-initialised
    # handle behind on failure, and every retry would then return early and
    # look like a successful registration.
    lib = _torch.library.Library("aten", "IMPL")
    lib.impl("mm", _fg_mm, "CPU", allow_override=True)
    _mm_aten_lib = lib
    _logging.getLogger(__name__).debug(
        "FlagGems ARM: registered Triton-CPU mm for aten::mm"
    )
def _register_sdpa():
    """Monkey-patch ``torch.nn.functional.scaled_dot_product_attention``.

    The replacement is the Triton-CPU Flash Attention wrapper (prefill 4-5x);
    decode and other cases fall back to ATen.

    This is intentionally a monkey-patch and not a torch.library registration.
    A ``Library("aten", "IMPL")`` override would route the ATen fallback inside
    the wrapper back into the wrapper itself (infinite recursion). With the
    monkey-patch, the original C++ dispatch stays reachable through
    ``_aten_sdpa``, which attention.py captures at import.
    """
    import torch.nn.functional as functional

    from .attention import scaled_dot_product_attention as flash_sdpa

    log = _logging.getLogger(__name__)
    functional.scaled_dot_product_attention = flash_sdpa
    log.debug(
        "FlagGems ARM: monkey-patched F.scaled_dot_product_attention"
        " with Triton Flash Attention (prefill 4-5x speedup)"
    )
# Registry of available ARM overrides: override name -> zero-argument callable
# that applies it. The names match the aten/flag_gems op being overridden, so
# callers select overrides with the same vocabulary they would pass to
# only_enable(include=[...]).
_ARM_OVERRIDE_REGISTRY = {
    "quantized_linear_dynamic": _register_quantized_linear_dynamic,
    "_int_mm": _register_int_mm,
    "argmax": _register_argmax,
    "fused_add_rms_norm": _override_fused_add_rms_norm_with_arm,
    "apply_rotary_pos_emb": _override_rope_with_arm,
    "silu_and_mul": _override_silu_and_mul_with_arm,
    "mm": _register_mm,
    "scaled_dot_product_attention": _register_sdpa,
}

# Names of the overrides that have already been applied in this process.
_ARM_OVERRIDES_APPLIED = set()
def _normalize_override_names(names):
    """Return *names* as a set, treating a bare string as a single name."""
    if names is None:
        return set()
    if isinstance(names, str):
        # set("mm") would yield {"m"}; a lone string means one override name.
        return {names} if names else set()
    return set(names)


def apply_arm_overrides(include=None, exclude=None):
    """Apply ARM CPU op overrides, optionally restricted to a subset.

    Args:
        include: override name or iterable of names to apply (None = all
            known). An empty iterable selects nothing.
        exclude: override name or iterable of names to skip (applied after
            include).

    Names are the keys of _ARM_OVERRIDE_REGISTRY (the aten/flag_gems op each
    one overrides). Names that are not in the registry are ignored, but a
    warning is logged so that a typo does not silently turn into a no-op.

    Overrides are applied in registry order. Each override is applied at most
    once per process: names already recorded in _ARM_OVERRIDES_APPLIED are
    skipped. An override whose applier raises is logged as a warning and is
    not recorded, so a later call attempts it again; the exception is not
    propagated to the caller.
    """
    logger = _logging.getLogger(__name__)

    # ``None`` means "everything"; keep it distinct from an empty selection.
    requested = None if include is None else _normalize_override_names(include)
    skipped = _normalize_override_names(exclude)

    unknown = ((requested or set()) | skipped) - set(_ARM_OVERRIDE_REGISTRY)
    if unknown:
        logger.warning(
            "FlagGems ARM: ignoring unknown override name(s): %s",
            ", ".join(sorted(map(repr, unknown))),
        )

    # Walk the registry rather than a set so the application order is stable.
    for name, applier in _ARM_OVERRIDE_REGISTRY.items():
        if requested is not None and name not in requested:
            continue
        if name in skipped or name in _ARM_OVERRIDES_APPLIED:
            continue
        try:
            applier()
        except Exception as e:  # noqa: BLE001 - best effort: one failure must not block the rest
            logger.warning(
                "FlagGems ARM: failed to apply override '%s': %s", name, e
            )
        else:
            _ARM_OVERRIDES_APPLIED.add(name)
255# NOTE: overrides are NOT applied on import. The caller selects which ones to
256# engage via apply_arm_overrides(include=[...]) — mirroring FlagGems'
257# only_enable() opt-in model. This avoids silently monkeypatching aten::mm /
258# F.scaled_dot_product_attention / flag_gems.* process-wide just by importing,
259# and lets a workload pick the net-positive subset (e.g. exclude "mm" on
260# prefill-heavy runs where the decode-tuned mm regresses prefill).
261#
262# from flag_gems.runtime.backend._arm.ops import apply_arm_overrides
263# apply_arm_overrides() # engage all curated overrides
264# apply_arm_overrides(include=["mm", "argmax"]) # only these
265# apply_arm_overrides(exclude=["mm"]) # all but mm