Coverage for src/flag_gems/runtime/backend/_arm/fused/patch_llama_arch.py: 0%

59 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1"""Apply Qwen3-equivalent TLE patches to HF Llama models (e.g. MiniCPM5-0.9B). 

2 

3Llama interfaces are nearly identical to Qwen3 for the three op-density patches: 

4- apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1) — same sig 

5- LlamaRMSNorm.forward(self, hidden_states) — same math (weight * x) 

6- LlamaDecoderLayer.forward — same structure 

7 

8This file reuses the kernels from patch_qwen3_* and just retargets Llama's 

9modeling module. 

10""" 

11import logging 

12 

13from flag_gems.runtime.backend._arm.fused.patch_qwen3_layer_norm import ( 

14 _PATCHED as _LAYER_PATCHED, 

15) 

16from flag_gems.runtime.backend._arm.fused.patch_qwen3_layer_norm import ( 

17 _make_patched_forward as _make_layer_patched, 

18) 

19from flag_gems.runtime.backend._arm.fused.patch_qwen3_rmsnorm import ( 

20 _PATCHED as _RMS_PATCHED, 

21) 

22from flag_gems.runtime.backend._arm.fused.patch_qwen3_rmsnorm import ( 

23 _make_patched_forward as _make_rmsnorm_patched, 

24) 

25 

26# Import the kernels + patch helpers from the Qwen3 patches. 

27from flag_gems.runtime.backend._arm.fused.patch_qwen3_rope import ( 

28 _PATCHED as _ROPE_PATCHED, 

29) 

30from flag_gems.runtime.backend._arm.fused.patch_qwen3_rope import ( 

31 _patched_apply_rotary_pos_emb, 

32) 

33 

# Dotted import path of the HF transformers Llama modeling module that every
# patch function in this file retargets.
_LLAMA_MODULE = "transformers.models.llama.modeling_llama"

logger = logging.getLogger(name=__name__)

37 

38 

def patch_llama_rope() -> int:
    """Replace Llama's ``apply_rotary_pos_emb`` with our TLE @triton.jit kernel.

    Reuses ``_patched_apply_rotary_pos_emb`` from the Qwen3 RoPE patch and
    records the Llama module path in the shared ``_ROPE_PATCHED`` registry so
    that repeated calls are no-ops.

    Returns:
        1 if this call installed the patch; 0 if the Llama modeling module
        cannot be imported, has no ``apply_rotary_pos_emb``, or was already
        patched.
    """
    try:
        mod = __import__(_LLAMA_MODULE, fromlist=["apply_rotary_pos_emb"])
    except (ImportError, AttributeError):
        logger.warning("Llama modeling module not available")
        return 0
    if not hasattr(mod, "apply_rotary_pos_emb"):
        return 0
    if _LLAMA_MODULE in _ROPE_PATCHED:
        return 0
    original = getattr(mod, "apply_rotary_pos_emb")
    # The "original" slot is shared with the Qwen3 RoPE patch. Only fill it
    # when it is still empty, so patching Llama after Qwen3 does not clobber
    # the Qwen3 original with Llama's function.
    if _ROPE_PATCHED.get("original") is None:
        _ROPE_PATCHED["original"] = original
    _ROPE_PATCHED[_LLAMA_MODULE] = True
    setattr(mod, "apply_rotary_pos_emb", _patched_apply_rotary_pos_emb)
    logger.info("Patched %s.apply_rotary_pos_emb", _LLAMA_MODULE)
    return 1

56 

57 

def patch_llama_rmsnorm() -> int:
    """Swap ``LlamaRMSNorm.forward`` for the single-Triton-kernel version.

    The wrapper is built by the Qwen3 RMSNorm helper from the class's current
    ``forward``. The ``(class, original_forward)`` pair is stored in the shared
    ``_RMS_PATCHED`` registry under ``(_LLAMA_MODULE, "LlamaRMSNorm")``.

    Returns:
        1 when the patch is applied by this call, otherwise 0 (module not
        importable, class missing, or already patched).
    """
    registry_key = (_LLAMA_MODULE, "LlamaRMSNorm")
    try:
        llama_mod = __import__(_LLAMA_MODULE, fromlist=["LlamaRMSNorm"])
    except (ImportError, AttributeError):
        return 0
    if not hasattr(llama_mod, "LlamaRMSNorm"):
        return 0
    norm_cls = getattr(llama_mod, "LlamaRMSNorm")
    if registry_key in _RMS_PATCHED:
        return 0
    # Keep the unpatched forward so the registry can restore or inspect it.
    original_forward = norm_cls.forward
    _RMS_PATCHED[registry_key] = (norm_cls, original_forward)
    norm_cls.forward = _make_rmsnorm_patched(original_forward)
    logger.info(f"Patched {_LLAMA_MODULE}.LlamaRMSNorm.forward")
    return 1

75 

76 

def patch_llama_layer_norm() -> int:
    """Route ``LlamaDecoderLayer.forward`` through the fused add+RMSNorm path.

    The fused kernel reads ``self.post_attention_layernorm.weight`` and
    ``.variance_epsilon``; LlamaRMSNorm exposes both. The wrapper comes from
    the Qwen3 layer-norm helper. The ``(class, original_forward)`` pair is
    recorded in ``_LAYER_PATCHED`` under ``(_LLAMA_MODULE, "LlamaDecoderLayer")``.

    Returns:
        1 when the patch is applied by this call, otherwise 0 (module not
        importable, class missing, or already patched).
    """
    registry_key = (_LLAMA_MODULE, "LlamaDecoderLayer")
    try:
        llama_mod = __import__(_LLAMA_MODULE, fromlist=["LlamaDecoderLayer"])
    except (ImportError, AttributeError):
        return 0
    if not hasattr(llama_mod, "LlamaDecoderLayer"):
        return 0
    layer_cls = getattr(llama_mod, "LlamaDecoderLayer")
    if registry_key in _LAYER_PATCHED:
        return 0
    # Keep the unpatched forward so the registry can restore or inspect it.
    original_forward = layer_cls.forward
    _LAYER_PATCHED[registry_key] = (layer_cls, original_forward)
    layer_cls.forward = _make_layer_patched(original_forward)
    logger.info(f"Patched {_LLAMA_MODULE}.LlamaDecoderLayer.forward")
    return 1

97 return 1 

98 

99 

def patch_llama_arch() -> dict:
    """Run the RoPE, RMSNorm and decoder-layer patches, in that order.

    Returns:
        A dict with keys ``"rope"``, ``"rmsnorm"`` and ``"layer_norm"``. Each
        value is the int returned by the corresponding patch function
        (1 if applied by this call, 0 otherwise).
    """
    counts = {}
    counts["rope"] = patch_llama_rope()
    counts["rmsnorm"] = patch_llama_rmsnorm()
    counts["layer_norm"] = patch_llama_layer_norm()
    return counts
106 }