Coverage for src/flag_gems/runtime/backend/_arm/int8/tle_int8_linear.py: 0%
44 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1"""Drop-in replacement for nn.Linear using TLE SDOT GEMV decode + torch._int_mm prefill.
3Decode (M=1, BF16): BF16 activation → in-kernel quant → INT8 SDOT matmul →
4BF16 output, all in one @triton.jit that calls triton-cpu's sdot_gemv_fused_bf16
5TLE builtin (which dispatches to the NEON SDOT C runtime).
7Prefill (M>1): per-row dynamic INT8 quantization in fp32 Python, then
8torch._int_mm (which is hooked by flag_gems _arm/ops/int_mm.py to route
9to the Triton SVE2 i8mm kernel), then external dequant.
11The class exposes the attributes FusedMLPWrapper (in fused/patch_qwen3_mlp.py)
12checks for: _packed, _w_scale, K, N — so when gate/up/down are replaced with
13TLEInt8Linear, the MLP patch can fuse them into fused_mlp_bf16.
14"""
16import torch
17import triton
18import triton.language as tl
19from triton.language.extra.cpu.tle_ops import sdot_gemv_fused_bf16 as _cpu_fused_gemv
21# Prefill GEMM goes through torch._int_mm, which is routed to FlagGems' Triton
22# SVE2 i8mm kernel by the aten::_int_mm CPU override. That override is opt-in
23# (apply_arm_overrides), engaged by quantize_and_replace_linears / replace_*
24# at setup; without it torch._int_mm falls back to ATen's scalar _int_mm.
@triton.jit
def _tle_fused_bf16_gemv_kernel(
    x_ptr,
    b_packed_ptr,
    w_scale_ptr,
    out_ptr,
    K: tl.constexpr,
    N: tl.constexpr,
):
    """Decode-path GEMV: BF16 x [K] times packed INT8 W → BF16 out [N].

    Thin Triton entry point whose whole body is one call to the
    ``sdot_gemv_fused_bf16`` TLE builtin (imported as ``_cpu_fused_gemv``).
    The builtin is expected to quantize ``x`` to INT8, run the SDOT matmul
    against the [K//4, N//4, 4, 4] packed weight and dequantize the result
    inside a single OMP region, so no intermediate tensors are created on
    the Python side.

    Args:
        x_ptr: pointer to the BF16 activation vector (K elements).
        b_packed_ptr: pointer to the SDOT-packed INT8 weight.
        w_scale_ptr: pointer to the weight scales (presumably N fp32
            values, one per output column — confirm against the builtin).
        out_ptr: pointer to the BF16 output vector (N elements).
        K: reduction length, compile-time constant.
        N: output length, compile-time constant.
    """
    _cpu_fused_gemv(x_ptr, b_packed_ptr, w_scale_ptr, out_ptr, K, N)
def pack_weights_sdot(w_kn: torch.Tensor) -> torch.Tensor:
    """Rearrange a row-major [K, N] INT8 weight into SDOT tiles.

    The result has shape [K//4, N//4, 4, 4]: the two leading axes select a
    4x4 tile (K-block, then N-block); inside a tile the third axis walks the
    4 N-lanes and the last axis holds the 4 consecutive K-bytes that one SDOT
    consumes. The returned tensor is contiguous so every tile occupies one
    unbroken run of memory.

    Args:
        w_kn: 2-D weight tensor laid out as [K, N].

    Returns:
        A contiguous tensor of shape [K//4, N//4, 4, 4] with the same dtype.

    Raises:
        ValueError: if K or N is not a multiple of 4.
    """
    K, N = w_kn.shape
    if K % 4 or N % 4:
        raise ValueError(
            f"pack_weights_sdot requires K%4==0 and N%4==0, got K={K} N={N}"
        )
    # Split both axes into (block, offset-in-block): [kb, ki, nb, ni].
    tiled = w_kn.reshape(K // 4, 4, N // 4, 4)
    # Bring the block indices to the front and make K-in-block innermost:
    # [kb, nb, ni, ki].
    tile_major = tiled.permute(0, 2, 3, 1)
    return tile_major.contiguous()
class TLEInt8Linear(torch.nn.Module):
    """nn.Linear replacement with TLE SDOT decode + torch._int_mm prefill.

    Args:
        w_int8: [N, K] int8 tensor (pre-quantized weight, same layout as
            nn.Linear's .weight.data attribute but dtype=int8).
        w_scale: per-output-channel weight scales. Either N elements in any
            shape that flattens to [N] (e.g. [N], [N, 1], [1, N]) or a
            single element, which is broadcast to all N output channels.

    Required: K % 4 == 0 and N % 4 == 0 (SDOT lane requirement).

    Attributes exposed for downstream fusion passes (e.g. patch_qwen3_mlp):
        _packed:    [K//4, N//4, 4, 4] int8 — for SDOT decode
        _w_int8_kn: [K, N] int8             — for torch._int_mm prefill
        _w_scale:   [N] fp32, contiguous    — per-output-channel scale
        K, N:       ints

    Raises:
        TypeError: if ``w_int8`` is not an int8 tensor.
        ValueError: if ``w_scale`` holds neither 1 nor N elements, or if
            ``pack_weights_sdot`` rejects the weight shape.
    """

    def __init__(self, w_int8: torch.Tensor, w_scale: torch.Tensor):
        super().__init__()
        if w_int8.dtype != torch.int8:
            raise TypeError(f"w_int8 must be int8, got {w_int8.dtype}")
        self.N, self.K = w_int8.shape
        w_kn = w_int8.t().contiguous()  # [K, N]
        self._packed = pack_weights_sdot(w_kn)  # [K//4, N//4, 4, 4]
        self._w_int8_kn = w_kn  # [K, N] for torch._int_mm
        self._w_scale = self._normalize_w_scale(w_scale, self.N)  # [N]

    @staticmethod
    def _normalize_w_scale(w_scale: torch.Tensor, n: int) -> torch.Tensor:
        """Return ``w_scale`` as a contiguous fp32 vector of exactly ``n`` scales."""
        flat = w_scale.to(torch.float32).reshape(-1)
        if flat.numel() == 1:
            # The decode kernel is handed a raw pointer plus N, so a scalar
            # scale must be materialised as N real values rather than left
            # as a one-element buffer.
            flat = flat.expand(n)
        elif flat.numel() != n:
            raise ValueError(
                f"w_scale must hold 1 or {n} elements, "
                f"got shape {tuple(w_scale.shape)}"
            )
        return flat.contiguous()

    @staticmethod
    def _quantize_rows(xf32: torch.Tensor):
        """Symmetric per-row INT8 quantization of a 2-D fp32 tensor.

        Returns ``(x_int8, x_scale)`` where ``x_scale`` has shape [M, 1] and
        ``x_int8 * x_scale`` approximates ``xf32``.
        """
        # The clamp keeps all-zero rows from producing a zero scale.
        absmax = xf32.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
        x_scale = absmax / 127.0
        # Round to nearest before the cast: a bare float→int8 conversion
        # truncates toward zero and biases every value low in magnitude.
        x_int8 = torch.round(xf32 / x_scale).clamp_(-128, 127).to(torch.int8)
        return x_int8, x_scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the quantized linear map to ``x`` of shape [..., K].

        Returns a bfloat16 tensor of shape [..., N].

        Raises:
            RuntimeError: if the last dimension of ``x`` is not K (same
                exception type nn.Linear raises on a shape mismatch).
        """
        shape = x.shape
        if x.dim() == 0 or shape[-1] != self.K:
            # Checked up front because the decode kernel is launched with K
            # regardless of how many elements the input buffer really has.
            raise RuntimeError(
                f"TLEInt8Linear expected input with last dim {self.K}, "
                f"got shape {tuple(shape)}"
            )
        M = x.numel() // self.K
        if M == 1 and x.dtype == torch.bfloat16:
            # Decode fast path — one TLE SDOT GEMV kernel call
            xc = x.reshape(-1).contiguous()
            out = torch.empty(self.N, dtype=torch.bfloat16)
            _tle_fused_bf16_gemv_kernel[(1,)](
                xc,
                self._packed,
                self._w_scale,
                out,
                K=self.K,
                N=self.N,
            )
            return out.reshape(*shape[:-1], self.N)

        # Prefill: per-row dynamic INT8 quant → _int_mm → dequant
        xf32 = x.reshape(-1, self.K).contiguous().float()
        x_int8, x_scale = self._quantize_rows(xf32)
        w_scale_row = self._w_scale.unsqueeze(0)  # [1, N]
        try:
            out_i32 = torch._int_mm(x_int8, self._w_int8_kn)
        except Exception:
            # Deliberate best-effort: FlagGems _int_mm may fall back to
            # aten::mm with int32 operands, which re-enters FlagGems mm and
            # fails for non-BF16 dtype. Use an fp32 matmul on the
            # unquantized activations, which bypasses that chain.
            w_fp32 = self._w_int8_kn.to(torch.float32) * w_scale_row  # [K, N]
            out_f32 = xf32 @ w_fp32
        else:
            out_f32 = out_i32.float() * x_scale * w_scale_row
        return out_f32.to(torch.bfloat16).reshape(*shape[:-1], self.N)

    def extra_repr(self) -> str:
        return f"in_features={self.K}, out_features={self.N}, dtype=int8"