Coverage for src/flag_gems/runtime/backend/_arm/int8/quantize_live.py: 0%
44 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1"""Live (in-memory) W8 per-channel symmetric quantization of nn.Linear layers,
2replacing each with a TLEInt8Linear.
4Use case: take a BF16 model that has no pre-quantized state dict (e.g. the
5public Qwen3.5-2B BF16 release) and turn it into a TLE INT8 stack on the fly.
7Quantization scheme matches the llm-compressor / compressed-tensors W8A8
8output format (per-channel symmetric INT8 weights, per-token symmetric INT8
9activations). Activation quantization is already done inside
10TLEInt8Linear.forward; this helper only handles the weight side.
12Example:
13 from transformers import AutoModelForCausalLM
14 from flag_gems.runtime.backend._arm.int8 import quantize_and_replace_linears
16 m = AutoModelForCausalLM.from_pretrained("...", dtype=torch.bfloat16)
17 n = quantize_and_replace_linears(m, skip={"lm_head"})
18"""
19import logging
20from typing import Iterable, Optional, Tuple
22import torch
24from .tle_int8_linear import TLEInt8Linear
# Module logger: quantize_and_replace_linears reports per-layer skips (debug)
# and a replacement summary (info) through it.
logger = logging.getLogger(name=__name__)
29def _quantize_weight_per_channel_sym(
30 w: torch.Tensor,
31) -> Tuple[torch.Tensor, torch.Tensor]:
32 """Per-output-channel symmetric INT8 quant of an [N, K] weight.
34 Returns:
35 w_int8: [N, K] int8
36 w_scale: [N] fp32 (per-channel)
37 """
38 w_fp32 = w.detach().to(torch.float32)
39 # max(|w|) along K (axis 1)
40 absmax = w_fp32.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) # [N, 1]
41 scale = absmax / 127.0 # [N, 1]
42 w_int8 = (w_fp32 / scale).round().clamp(-128, 127).to(torch.int8) # [N, K]
43 scale_flat = scale.squeeze(-1).contiguous().to(torch.float32) # [N]
44 return w_int8, scale_flat
def quantize_and_replace_linears(
    model: torch.nn.Module,
    skip: Optional[Iterable[str]] = None,
    require_divisible_by: int = 4,
    skip_with_bias: bool = True,
) -> int:
    """Quantize eligible nn.Linear layers to W8 and swap in TLEInt8Linear.

    Walks ``model.named_modules()``. Each selected ``nn.Linear`` weight is
    quantized in memory to per-output-channel symmetric INT8 and the module is
    replaced, at the same attribute path, by ``TLEInt8Linear(w_int8, w_scale)``.
    The model is modified in place.

    Args:
        model: any torch.nn.Module (typically a transformers model).
        skip: module names, as reported by ``model.named_modules()``, to leave
            alone (e.g. {"lm_head"} when the head is tied or too large to
            benefit from INT8 GEMV). A single string is treated as one name,
            not as an iterable of characters. Matching is exact: listing a
            parent module does not skip the Linears inside it.
        require_divisible_by: SDOT requires K%4==0 and N%4==0; linears whose
            in/out features are not multiples of this value keep their
            original dtype (e.g. tiny GDN scalar projections in_proj_a/b).
        skip_with_bias: TLEInt8Linear has no bias parameter; if True, leave any
            nn.Linear with a non-None bias as-is. If False, raise ValueError
            instead (matches Qwen3-style models, which have no biases). The
            check runs before any module is replaced, so the model is left
            unmodified when it raises.

    Returns:
        Number of Linear modules replaced.

    Raises:
        ValueError: if ``skip_with_bias`` is False and a non-skipped Linear
            has a bias.

    Notes:
        If ``model`` itself is an ``nn.Linear`` it cannot be swapped in place;
        it is left unchanged and a warning is logged. After swapping, this
        function keeps no reference to the original Linear, so its weight can
        be reclaimed right away unless something else still references it
        (e.g. a weight tied to another module).
    """
    if not skip:
        skip_set = set()
    elif isinstance(skip, str):
        # A bare string is one module name; set("lm_head") would otherwise be
        # a set of characters and silently skip nothing.
        skip_set = {skip}
    else:
        skip_set = set(skip)

    targets = []
    matched_skips = set()
    n_skipped_align = 0
    n_skipped_bias = 0

    # Pass 1: choose which Linears to replace without mutating the model, so a
    # ValueError for a biased Linear cannot leave the model half-converted.
    for name, module in model.named_modules():
        if not isinstance(module, torch.nn.Linear):
            continue
        if name in skip_set:
            matched_skips.add(name)
            continue
        if module.bias is not None:
            if skip_with_bias:
                n_skipped_bias += 1
                continue
            raise ValueError(
                f"{name} has bias=True; TLEInt8Linear does not support bias"
            )

        N, K = module.weight.shape
        if K % require_divisible_by != 0 or N % require_divisible_by != 0:
            n_skipped_align += 1
            logger.debug(
                "quantize_and_replace_linears: %s K=%d N=%d not divisible by %d",
                name,
                K,
                N,
                require_divisible_by,
            )
            continue

        if not name:
            # ``model`` itself is the Linear: there is no parent attribute to
            # rebind, so it cannot be replaced in place.
            logger.warning(
                "quantize_and_replace_linears: root module is an nn.Linear and "
                "cannot be replaced in place; leaving it unchanged"
            )
            continue

        targets.append(name)
    # named_modules() always yields the root, so ``module`` is bound here. Drop
    # it so it cannot pin a Linear that is about to be replaced.
    del module

    # Pass 2: quantize and swap one layer at a time. Only names are held, so
    # once a layer is swapped nothing here keeps the original Linear alive.
    for name in targets:
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name)  # "" -> model itself
        linear = getattr(parent, child_name)
        w_int8, w_scale = _quantize_weight_per_channel_sym(linear.weight.data)
        setattr(parent, child_name, TLEInt8Linear(w_int8, w_scale))
        del linear
    n_replaced = len(targets)

    # Engage the aten::_int_mm CPU override so TLEInt8Linear prefill's
    # torch._int_mm routes to the Triton SVE2 i8mm kernel instead of ATen's
    # scalar fallback (~15x slower). Idempotent + process-global on the CPU
    # dispatch key. Only _int_mm is needed here; mm/argmax overrides were
    # measured to not help the INT8 decode path.
    from ..ops import apply_arm_overrides

    apply_arm_overrides(include=["_int_mm"])

    unmatched = skip_set - matched_skips
    if unmatched:
        # Usually a typo or a fully-qualified-name mismatch; the intended
        # layer was quantized rather than skipped.
        logger.warning(
            "quantize_and_replace_linears: skip entries matched no nn.Linear: %s",
            sorted(unmatched, key=str),
        )

    logger.info(
        "quantize_and_replace_linears: replaced %d Linears "
        "(skipped: %d alignment, %d bias, %d explicit)",
        n_replaced,
        n_skipped_align,
        n_skipped_bias,
        len(matched_skips),
    )
    return n_replaced