Coverage for src/flag_gems/runtime/backend/_arm/ops/silu_and_mul.py: 0%

24 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1""" 

2ARM CPU fused silu_and_mul — TLE NEON SWIGLU for decode, ATen for prefill. 

3 

4For decode (M=1): TLE cpu_swiglu (NEON fast exp + fused silu*mul) via a 

5@triton.jit kernel (no ctypes — goes through the create_cpu_swiglu TLE path). 

6For prefill (M>1): ATen F.silu(x1) * x2 (fallback). 

7 

8Benchmarks (CIX P1 CD8180, BF16, OMP=8): 

9 N=6144 decode: ATen ~76μs → TLE SWIGLU ~33μs (2.3x speedup) 

10 28 layers × savings = 1.2ms/tok 

11""" 

12 

13 

14import torch 

15import torch.nn.functional as F 

16import triton 

17import triton.language as tl 

18from triton.language.extra.cpu.tle_ops import swiglu as _tle_swiglu 

19 

# Tri-state availability flag for the fused TLE SWIGLU kernel, updated by
# arm_silu_and_mul after it attempts a launch:
#   None  = not yet tried,
#   True  = TLE path works,
#   False = a launch raised; fall back to ATen. Nothing in this module resets
#           the flag once it is False.
_TLE_SWIGLU_OK = None

22 

23 

@triton.jit
def _swiglu_kernel(gate_ptr, up_ptr, out_ptr, N: tl.constexpr):
    """Run the whole SWIGLU, ``silu(gate) * up`` over ``N`` elements, as one TLE op.

    The body is a single coarse-grained call to the TLE ``swiglu`` op, so the
    full computation costs one kernel launch. Per the original author's note,
    the op is OMP-parallelized inside the C runtime.
    """
    _tle_swiglu(gate_ptr, up_ptr, out_ptr, N)

29 

30 

def arm_silu_and_mul(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """Compute ``silu(x1) * x2`` on ARM CPU, using the fused TLE kernel when safe.

    The fused ``_swiglu_kernel`` (decode path) is used only when all of the
    following hold; every other input goes through the ATen expression
    ``F.silu(x1) * x2``:

    * the TLE path has not previously failed in this process,
    * ``x1`` and ``x2`` are both bfloat16 and have identical shapes,
    * both are contiguous,
    * ``x1`` is a non-empty single row (``numel == shape[-1]``, i.e. M=1).

    Args:
        x1: Gate tensor; SiLU is applied to it.
        x2: Tensor multiplied element-wise with ``silu(x1)``.

    Returns:
        A new tensor holding ``silu(x1) * x2``.
    """
    global _TLE_SWIGLU_OK

    # The kernel receives raw pointers and processes exactly N elements from
    # each input, so x2 must match x1 in dtype and shape. Anything relying on
    # broadcasting or type promotion has to use ATen to get correct results.
    use_tle = (
        _TLE_SWIGLU_OK is not False
        and x1.dtype == torch.bfloat16
        and x2.dtype == x1.dtype
        and x2.shape == x1.shape
        and x1.dim() >= 1  # shape[-1] below would raise on a 0-dim tensor
        and x1.numel() > 0  # skip a degenerate N=0 launch; ATen handles empties
        and x1.numel() == x1.shape[-1]  # M=1
        and x1.is_contiguous()
        and x2.is_contiguous()
    )

    if use_tle:
        try:
            out = torch.empty_like(x1)
            _swiglu_kernel[(1,)](x1, x2, out, N=x1.numel())
        except Exception:
            # Best-effort fast path: remember the failure so later calls skip
            # straight to ATen, but leave a trace instead of failing silently.
            _TLE_SWIGLU_OK = False
            import logging

            logging.getLogger(__name__).warning(
                "TLE SWIGLU kernel failed; falling back to ATen silu_and_mul",
                exc_info=True,
            )
        else:
            _TLE_SWIGLU_OK = True
            return out

    return F.silu(x1) * x2

55 

56 

def arm_silu_and_mul_out(
    x1: torch.Tensor, x2: torch.Tensor, out: torch.Tensor
) -> torch.Tensor:
    """Write ``silu(x1) * x2`` into the caller-provided ``out`` tensor.

    The product is first computed by ``arm_silu_and_mul`` into a temporary
    tensor and then copied into ``out`` with ``Tensor.copy_``.

    Args:
        x1: Tensor passed to ``arm_silu_and_mul`` as the SiLU input.
        x2: Tensor passed to ``arm_silu_and_mul`` as the multiplier.
        out: Destination tensor, overwritten in place.

    Returns:
        ``out`` itself, after the copy.
    """
    # Tensor.copy_ returns the tensor it was called on, i.e. ``out``.
    return out.copy_(arm_silu_and_mul(x1, x2))