Coverage for src/flag_gems/runtime/backend/_arm/fused/patch_llama_arch.py: 0%
59 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1"""Apply Qwen3-equivalent TLE patches to HF Llama models (e.g. MiniCPM5-0.9B).
3Llama interfaces are nearly identical to Qwen3 for the three op-density patches:
4- apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1) — same sig
5- LlamaRMSNorm.forward(self, hidden_states) — same math (weight * x)
6- LlamaDecoderLayer.forward — same structure
8This file reuses the kernels from patch_qwen3_* and just retargets Llama's
9modeling module.
10"""
11import logging
13from flag_gems.runtime.backend._arm.fused.patch_qwen3_layer_norm import (
14 _PATCHED as _LAYER_PATCHED,
15)
16from flag_gems.runtime.backend._arm.fused.patch_qwen3_layer_norm import (
17 _make_patched_forward as _make_layer_patched,
18)
19from flag_gems.runtime.backend._arm.fused.patch_qwen3_rmsnorm import (
20 _PATCHED as _RMS_PATCHED,
21)
22from flag_gems.runtime.backend._arm.fused.patch_qwen3_rmsnorm import (
23 _make_patched_forward as _make_rmsnorm_patched,
24)
26# Import the kernels + patch helpers from the Qwen3 patches.
27from flag_gems.runtime.backend._arm.fused.patch_qwen3_rope import (
28 _PATCHED as _ROPE_PATCHED,
29)
30from flag_gems.runtime.backend._arm.fused.patch_qwen3_rope import (
31 _patched_apply_rotary_pos_emb,
32)
# Module-level logger used by every patch function in this file.
logger = logging.getLogger(__name__)

# Dotted import path of the Hugging Face Llama modeling module that the
# patch functions in this file import and modify in place.
_LLAMA_MODULE = "transformers.models.llama.modeling_llama"
def patch_llama_rope() -> int:
    """Replace Llama ``apply_rotary_pos_emb`` with our TLE @triton.jit kernel.

    The Llama modeling module is imported lazily and its module-level
    ``apply_rotary_pos_emb`` attribute is rebound to
    ``_patched_apply_rotary_pos_emb``. The call is idempotent: once
    ``_LLAMA_MODULE`` is recorded in the RoPE registry, later calls do nothing.

    Returns:
        1 if the function was patched by this call; 0 if the modeling module
        could not be imported, does not expose ``apply_rotary_pos_emb``, or
        was already patched.
    """
    try:
        mod = __import__(_LLAMA_MODULE, fromlist=["apply_rotary_pos_emb"])
    except (ImportError, AttributeError):
        logger.warning("Llama modeling module not available")
        return 0

    if not hasattr(mod, "apply_rotary_pos_emb"):
        return 0
    if _LLAMA_MODULE in _ROPE_PATCHED:
        return 0

    original = mod.apply_rotary_pos_emb

    # The registry is shared with the Qwen3 RoPE patch. Only fill the shared
    # "original" slot when it is still empty, so an original saved by the
    # Qwen3 patch is never overwritten by Llama's.
    _ROPE_PATCHED.setdefault("original", original)

    # Keep Llama's own original under the module key (instead of a bare True)
    # so it stays recoverable even when the shared slot belongs to Qwen3.
    # NOTE(review): membership checks behave as before; confirm no consumer
    # relies on this value being literally True.
    _ROPE_PATCHED[_LLAMA_MODULE] = original

    mod.apply_rotary_pos_emb = _patched_apply_rotary_pos_emb
    logger.info("Patched %s.apply_rotary_pos_emb", _LLAMA_MODULE)
    return 1
def patch_llama_rmsnorm() -> int:
    """Replace ``LlamaRMSNorm.forward`` with a single Triton kernel.

    The class and its original ``forward`` are stored in the RMSNorm registry
    under the key ``(_LLAMA_MODULE, "LlamaRMSNorm")`` before the method is
    rebound to the wrapper produced by ``_make_rmsnorm_patched``. The presence
    of that key makes repeated calls a no-op.

    Returns:
        1 if ``forward`` was patched by this call; 0 if the modeling module
        could not be imported, does not define ``LlamaRMSNorm``, or was
        already patched.
    """
    try:
        mod = __import__(_LLAMA_MODULE, fromlist=["LlamaRMSNorm"])
    except (ImportError, AttributeError):
        # Report the skip the same way patch_llama_rope does, rather than
        # failing silently.
        logger.warning("Llama modeling module not available")
        return 0

    if not hasattr(mod, "LlamaRMSNorm"):
        return 0

    key = (_LLAMA_MODULE, "LlamaRMSNorm")
    if key in _RMS_PATCHED:
        return 0

    cls = mod.LlamaRMSNorm
    orig = cls.forward
    # Record (class, original forward) before rebinding so the original
    # method remains reachable through the registry.
    _RMS_PATCHED[key] = (cls, orig)
    cls.forward = _make_rmsnorm_patched(orig)
    logger.info("Patched %s.LlamaRMSNorm.forward", _LLAMA_MODULE)
    return 1
def patch_llama_layer_norm() -> int:
    """Wire fused_add_rms_norm into ``LlamaDecoderLayer.forward``.

    Note: the kernel reads ``self.post_attention_layernorm.weight`` and
    ``.variance_epsilon``; per the original author's note, both are present
    on ``LlamaRMSNorm``.

    The class and its original ``forward`` are stored in the layer registry
    under the key ``(_LLAMA_MODULE, "LlamaDecoderLayer")`` before the method is
    rebound to the wrapper produced by ``_make_layer_patched``. The presence
    of that key makes repeated calls a no-op.

    Returns:
        1 if ``forward`` was patched by this call; 0 if the modeling module
        could not be imported, does not define ``LlamaDecoderLayer``, or was
        already patched.
    """
    try:
        mod = __import__(_LLAMA_MODULE, fromlist=["LlamaDecoderLayer"])
    except (ImportError, AttributeError):
        # Report the skip the same way patch_llama_rope does, rather than
        # failing silently.
        logger.warning("Llama modeling module not available")
        return 0

    if not hasattr(mod, "LlamaDecoderLayer"):
        return 0

    key = (_LLAMA_MODULE, "LlamaDecoderLayer")
    if key in _LAYER_PATCHED:
        return 0

    cls = mod.LlamaDecoderLayer
    orig = cls.forward
    # Record (class, original forward) before rebinding so the original
    # method remains reachable through the registry.
    _LAYER_PATCHED[key] = (cls, orig)
    cls.forward = _make_layer_patched(orig)
    logger.info("Patched %s.LlamaDecoderLayer.forward", _LLAMA_MODULE)
    return 1
def patch_llama_arch() -> dict:
    """Run the RoPE, RMSNorm and decoder-layer patches, in that order.

    Returns:
        A dict with the keys ``"rope"``, ``"rmsnorm"`` and ``"layer_norm"``,
        each mapped to the int returned by the corresponding patch function.
    """
    # Ordered (result key, patch function) pairs; the comprehension below
    # calls them top to bottom, matching the insertion order of the result.
    steps = (
        ("rope", patch_llama_rope),
        ("rmsnorm", patch_llama_rmsnorm),
        ("layer_norm", patch_llama_layer_norm),
    )
    return {label: apply_patch() for label, apply_patch in steps}