Coverage for src/flag_gems/runtime/backend/_arm/ops/silu_and_mul.py: 0%
24 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1"""
2ARM CPU fused silu_and_mul — TLE NEON SWIGLU for decode, ATen for prefill.
4For decode (M=1): TLE cpu_swiglu (NEON fast exp + fused silu*mul) via a
5@triton.jit kernel (no ctypes — goes through the create_cpu_swiglu TLE path).
6For prefill (M>1): ATen F.silu(x1) * x2 (fallback).
8Benchmarks (CIX P1 CD8180, BF16, OMP=8):
9 N=6144 decode: ATen ~76μs → TLE SWIGLU ~33μs (2.3x speedup)
10 28 layers × savings = 1.2ms/tok
11"""
14import torch
15import torch.nn.functional as F
16import triton
17import triton.language as tl
18from triton.language.extra.cpu.tle_ops import swiglu as _tle_swiglu
# Tri-state status of the TLE SWIGLU fast path, read and written by
# arm_silu_and_mul:
#   None  = not yet tried,
#   True  = TLE path works (a kernel launch completed without raising),
#   False = a launch raised; fall back to ATen from then on.
# Nothing in the visible code resets the flag once it has been set to False.
_TLE_SWIGLU_OK = None
@triton.jit
def _swiglu_kernel(gate_ptr, up_ptr, out_ptr, N: tl.constexpr):
    # Triton-JIT wrapper around the TLE `swiglu` op: it forwards its four
    # arguments unchanged and does no indexing or masking of its own.
    #
    #   gate_ptr, up_ptr: the two inputs (SiLU is applied to `gate`).
    #   out_ptr:          destination for the result.
    #   N:                element count, passed as a compile-time constant.
    #
    # NOTE(review): because N is a tl.constexpr, each distinct N presumably
    # compiles its own kernel specialization — confirm against the Triton docs.
    #
    # One coarse TLE op = the whole SWIGLU (silu(gate) * up over N elements),
    # OMP-parallelized inside the C runtime → 1 kernel launch.
    _tle_swiglu(gate_ptr, up_ptr, out_ptr, N)
def arm_silu_and_mul(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """Compute ``silu(x1) * x2``, using the fused TLE kernel when it is safe.

    Decode-shaped inputs (a single row of bfloat16 data) are sent to
    ``_swiglu_kernel`` in one launch; the module benchmarks report this as
    roughly 2.3x faster than ATen. Every other input, and every call made
    after the fused path has failed once, is computed with
    ``F.silu(x1) * x2``.

    The fused path is taken only when all of the following hold:

    * both tensors are bfloat16 and contiguous,
    * both tensors have exactly the same shape,
    * ``x1`` has at least one dimension and is non-empty,
    * ``x1`` holds a single row (``x1.numel() == x1.shape[-1]``).

    Args:
        x1: Gate input; SiLU is applied to it.
        x2: Tensor multiplied element-wise with ``silu(x1)``.

    Returns:
        A newly allocated tensor holding ``silu(x1) * x2``.
    """
    global _TLE_SWIGLU_OK

    # The kernel receives a single element count for all three buffers, so
    # x2 must match x1 in dtype and shape exactly; a broadcastable or
    # differently-typed x2 would otherwise be read out of bounds or
    # misinterpreted. The dim()/numel() guards keep scalar and empty inputs
    # on the ATen path (shape[-1] raises IndexError on a 0-dim tensor).
    fused_eligible = (
        x1.dtype == torch.bfloat16
        and x2.dtype == torch.bfloat16
        and x1.dim() > 0
        and x1.shape == x2.shape
        and x1.numel() > 0
        and x1.numel() == x1.shape[-1]  # M=1
        and x1.is_contiguous()
        and x2.is_contiguous()
    )

    if fused_eligible and _TLE_SWIGLU_OK is not False:
        try:
            out = torch.empty_like(x1)
            _swiglu_kernel[(1,)](x1, x2, out, N=x1.numel())
        except Exception:
            # Best-effort fast path: remember the failure so later calls go
            # straight to ATen, and report it once instead of hiding it.
            import logging

            _TLE_SWIGLU_OK = False
            logging.getLogger(__name__).warning(
                "TLE SWIGLU kernel failed; using the ATen fallback from now on",
                exc_info=True,
            )
        else:
            _TLE_SWIGLU_OK = True
            return out

    return F.silu(x1) * x2
def arm_silu_and_mul_out(
    x1: torch.Tensor, x2: torch.Tensor, out: torch.Tensor
) -> torch.Tensor:
    """Write ``silu(x1) * x2`` into the caller-supplied ``out`` tensor.

    The product is computed by ``arm_silu_and_mul`` into a temporary and then
    copied into ``out``; the same ``out`` object is handed back to the caller.
    """
    out.copy_(arm_silu_and_mul(x1, x2))
    return out