Coverage for src/flag_gems/runtime/backend/_arm/ops/rms_norm.py: 0%

3 statements  

coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

"""
ARM CPU fused_add_rms_norm wrapper.

Wraps the _arm/fused/fused_add_rms_norm.py Triton kernel so it can be used
as a drop-in replacement for flag_gems.fused_add_rms_norm on ARM64 CPU.

Standalone rms_norm (without residual add) was removed: A/B measurement on
Qwen3-1.7B INT8 decode showed no measurable benefit over the stock
Qwen3RMSNorm.forward, which runs on ATen's native ops. See
_arm/fused/fused_add_rms_norm.py for the full note.
"""

from flag_gems.runtime.backend._arm.fused.fused_add_rms_norm import (
    fused_add_rms_norm as _arm_fused_add_rms_norm,
)


def fused_add_rms_norm(x, residual, normalized_shape, weight, eps=1e-5):
    """
    ARM CPU drop-in for flag_gems.fused_add_rms_norm.

    In-place: residual = x + residual; x = rms_norm(residual) * weight.
    Returns (x, residual).
    """
    return _arm_fused_add_rms_norm(x, residual, normalized_shape, weight, eps)
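The docstring's in-place contract (residual = x + residual; x = rms_norm(residual) * weight) can be pinned down with a plain-Python reference, which is handy as an oracle when checking the ARM kernel's output. This is a sketch, not the Triton kernel: `fused_add_rms_norm_reference` is a hypothetical helper, it works on nested lists instead of torch tensors, and it assumes normalization over the last (hidden) dimension only with the standard RMSNorm formula r / sqrt(mean(r**2) + eps) * weight, the same eps placement Qwen3RMSNorm uses.

```python
import math
from typing import List, Tuple

Row = List[float]


def fused_add_rms_norm_reference(
    x: List[Row],
    residual: List[Row],
    weight: Row,
    eps: float = 1e-5,
) -> Tuple[List[Row], List[Row]]:
    """Hypothetical reference for the fused op's semantics (not the kernel).

    Assumes normalized_shape == (len(weight),), i.e. each row is normalized
    independently over its last dimension. Both x and residual are mutated
    in place and returned, mirroring the wrapper's documented contract.
    """
    for row_x, row_r in zip(x, residual):
        if len(row_x) != len(weight) or len(row_r) != len(weight):
            raise ValueError("hidden size must match len(weight)")
        # Step 1: residual = x + residual, written back into residual.
        for i, v in enumerate(row_x):
            row_r[i] += v
        # Step 2: inverse root-mean-square of the updated residual row.
        mean_sq = sum(v * v for v in row_r) / len(row_r)
        inv_rms = 1.0 / math.sqrt(mean_sq + eps)
        # Step 3: x = rms_norm(residual) * weight, written back into x.
        for i, (r, w) in enumerate(zip(row_r, weight)):
            row_x[i] = r * inv_rms * w
    return x, residual
```

For example, with x = [[1.0, 2.0]], residual = [[1.0, 0.0]], weight = [1.0, 1.0] and eps = 0.0, residual becomes [[2.0, 2.0]] (RMS 2.0) and x becomes [[1.0, 1.0]]. Returning the same objects that were passed in matches the "Returns (x, residual)" contract, so callers can use either the return value or the mutated inputs.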