Coverage for src/flag_gems/runtime/backend/_arm/int8/tle_int8_linear.py: 0%

44 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1"""Drop-in replacement for nn.Linear using TLE SDOT GEMV decode + torch._int_mm prefill. 

2 

3Decode (M=1, BF16): BF16 activation → in-kernel quant → INT8 SDOT matmul → 

4BF16 output, all in one @triton.jit that calls triton-cpu's sdot_gemv_fused_bf16 

5TLE builtin (which dispatches to the NEON SDOT C runtime). 

6 

7Prefill (M != 1, or any non-BF16 input): per-row dynamic INT8 quantization in fp32 Python, then 

8torch._int_mm (which is hooked by flag_gems _arm/ops/int_mm.py to route 

9to the Triton SVE2 i8mm kernel), then external dequant. 

10 

11The class exposes the attributes FusedMLPWrapper (in fused/patch_qwen3_mlp.py) 

12checks for: _packed, _w_scale, K, N — so when gate/up/down are replaced with 

13TLEInt8Linear, the MLP patch can fuse them into fused_mlp_bf16. 

14""" 

15 

16import torch 

17import triton 

18import triton.language as tl 

19from triton.language.extra.cpu.tle_ops import sdot_gemv_fused_bf16 as _cpu_fused_gemv 

20 

21# Prefill GEMM goes through torch._int_mm, which is routed to FlagGems' Triton 

22# SVE2 i8mm kernel by the aten::_int_mm CPU override. That override is opt-in 

23# (apply_arm_overrides), engaged by quantize_and_replace_linears / replace_* 

24# at setup; without it torch._int_mm falls back to ATen's scalar _int_mm. 

25 

26 

@triton.jit
def _tle_fused_bf16_gemv_kernel(
    x_ptr,
    b_packed_ptr,
    w_scale_ptr,
    out_ptr,
    K: tl.constexpr,
    N: tl.constexpr,
):
    """TLE GEMV: BF16 x [K] @ INT8 packed W [K//4, N//4, 4, 4] → BF16 out [N].

    Performs in-kernel dynamic INT8 quantization of x and dequantization of
    output. Single OMP region, no intermediate tensors.

    The kernel body is a single call into the ``sdot_gemv_fused_bf16`` TLE
    builtin (imported in this file as ``_cpu_fused_gemv``); all arguments are
    forwarded unchanged, in the same order.

    Args:
        x_ptr: Pointer to the activation vector (BF16, K elements per the
            summary line above).
        b_packed_ptr: Pointer to the packed INT8 weight.
        w_scale_ptr: Pointer to the weight scales.
            NOTE(review): presumably N fp32 values, one per output channel —
            confirm against the builtin's contract.
        out_ptr: Pointer to the output vector (BF16, N elements per the
            summary line above).
        K: Reduction dimension, passed as a compile-time constant.
        N: Output dimension, passed as a compile-time constant.
    """
    # No per-program indexing happens here; the builtin receives the raw
    # pointers and the two constexpr sizes.
    _cpu_fused_gemv(x_ptr, b_packed_ptr, w_scale_ptr, out_ptr, K, N)

42 

43 

def pack_weights_sdot(w_kn: torch.Tensor) -> torch.Tensor:
    """Rearrange a row-major [K, N] INT8 weight into SDOT tiles.

    The result has shape [K//4, N//4, 4, 4] and is contiguous, with
    ``packed[kb, nb, n, k] == w_kn[4 * kb + k, 4 * nb + n]``: inside every
    4x4 tile the four K values belonging to one output column sit next to
    each other in memory, which is the grouping an SDOT lane consumes.

    Args:
        w_kn: 2-D weight tensor laid out as [K, N].

    Returns:
        A new contiguous tensor of shape [K//4, N//4, 4, 4].

    Raises:
        ValueError: If K or N is not a multiple of 4.
    """
    K, N = w_kn.shape
    if K % 4 or N % 4:
        raise ValueError(
            f"pack_weights_sdot requires K%4==0 and N%4==0, got K={K} N={N}"
        )
    k_blocks = K // 4
    n_blocks = N // 4
    # Split both axes into (block, offset-within-block) ...
    tiled = w_kn.reshape(k_blocks, 4, n_blocks, 4)
    # ... then reorder (k_blk, k_in, n_blk, n_in) -> (k_blk, n_blk, n_in, k_in)
    # and materialise the permuted layout.
    return tiled.permute(0, 2, 3, 1).contiguous()

57 

58 

class TLEInt8Linear(torch.nn.Module):
    """nn.Linear replacement with TLE SDOT decode + torch._int_mm prefill.

    The layer has no bias term. Weights are supplied already quantized.

    Args:
        w_int8: [N, K] int8 tensor (pre-quantized weight, same layout as
            nn.Linear's .weight.data attribute but dtype=int8).
        w_scale: Per-output-channel weight scales. Any tensor holding exactly
            N elements is accepted (e.g. [N], [N, 1], [1, N]); a
            single-element tensor is broadcast to all N channels.

    Raises:
        TypeError: If ``w_int8`` is not int8.
        ValueError: If K or N is not a multiple of 4 (SDOT lane requirement,
            raised by ``pack_weights_sdot``) or if ``w_scale`` has neither 1
            nor N elements.

    Attributes exposed for downstream fusion passes (e.g. patch_qwen3_mlp):
        _packed:     [K//4, N//4, 4, 4] int8 — for SDOT decode
        _w_int8_kn:  [K, N] int8 — for torch._int_mm prefill
        _w_scale:    [N] fp32, contiguous — per-column scale
        K, N:        ints
    """

    def __init__(self, w_int8: torch.Tensor, w_scale: torch.Tensor):
        super().__init__()
        if w_int8.dtype != torch.int8:
            raise TypeError(f"w_int8 must be int8, got {w_int8.dtype}")
        self.N, self.K = w_int8.shape
        w_kn = w_int8.t().contiguous()  # [K, N]
        self._packed = pack_weights_sdot(w_kn)  # [K//4, N//4, 4, 4]
        self._w_int8_kn = w_kn  # [K, N] for torch._int_mm
        self._w_scale = self._normalize_scale(w_scale, self.N)  # [N]

    @staticmethod
    def _normalize_scale(w_scale: torch.Tensor, n: int) -> torch.Tensor:
        """Return ``w_scale`` as a contiguous fp32 vector of exactly ``n`` elements.

        A single-element tensor is broadcast to length ``n``; this matters
        because the decode kernel is handed a raw pointer to the scales, so
        the buffer must really contain one value per output channel.

        Raises:
            ValueError: If ``w_scale`` has neither 1 nor ``n`` elements.
        """
        flat = w_scale.to(torch.float32).reshape(-1)
        if flat.numel() == 1:
            flat = flat.expand(n)
        elif flat.numel() != n:
            raise ValueError(
                f"w_scale must have 1 or {n} elements, got shape "
                f"{tuple(w_scale.shape)}"
            )
        return flat.contiguous()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``x @ W^T`` and return a BF16 tensor of shape [..., N].

        Args:
            x: Activations whose last dimension equals K.

        Returns:
            BF16 tensor with the leading dimensions of ``x`` and last
            dimension N (both code paths produce BF16).

        Raises:
            ValueError: If the last dimension of ``x`` is not K.
        """
        shape = x.shape
        if x.dim() == 0 or shape[-1] != self.K:
            # Guard before touching the native kernel / reshaping: a wrong
            # feature size would otherwise hand a mis-sized buffer to the
            # decode kernel or silently re-chunk rows in the prefill path.
            raise ValueError(
                f"expected input with last dimension {self.K}, "
                f"got shape {tuple(shape)}"
            )
        M = x.numel() // self.K
        if M == 1 and x.dtype == torch.bfloat16:
            # Decode fast path — one TLE SDOT GEMV kernel call
            xc = x.reshape(-1).contiguous()
            out = torch.empty(self.N, dtype=torch.bfloat16, device=x.device)
            _tle_fused_bf16_gemv_kernel[(1,)](
                xc,
                self._packed,
                self._w_scale,
                out,
                K=self.K,
                N=self.N,
            )
            return out.reshape(*shape[:-1], self.N)

        # Prefill: per-row dynamic INT8 quant → _int_mm → dequant
        xf32 = x.reshape(-1, self.K).contiguous().float()
        # Symmetric per-row scale; the clamp keeps all-zero rows from
        # dividing by zero.
        absmax = xf32.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
        x_scale = absmax / 127.0
        # Round to nearest before the int8 cast: a bare cast truncates toward
        # zero, which biases every quantized activation.
        x_int8 = torch.round(xf32 / x_scale).clamp_(-128, 127).to(torch.int8)
        try:
            out_i32 = torch._int_mm(x_int8, self._w_int8_kn)
        except Exception:
            # FlagGems _int_mm may fall back to aten::mm with int32 operands,
            # which re-enters FlagGems mm and fails for non-BF16 dtype.
            # Use an fp32 matmul fallback that bypasses that chain. This path
            # multiplies the *unquantized* fp32 activations by the
            # dequantized weight, so no activation quantization is applied.
            w_fp32 = self._w_int8_kn.to(torch.float32) * self._w_scale.unsqueeze(
                0
            )  # [K, N]
            out_f32 = xf32 @ w_fp32
        else:
            # Undo both quantizations: per-row activation scale [M, 1] and
            # per-column weight scale [1, N].
            out_f32 = out_i32.float() * x_scale * self._w_scale.unsqueeze(0)
        return out_f32.to(torch.bfloat16).reshape(*shape[:-1], self.N)

    def extra_repr(self) -> str:
        """Summarise the layer dimensions for ``repr()`` / ``print(model)``."""
        return f"in_features={self.K}, out_features={self.N}, dtype=int8"