Coverage for src/flag_gems/runtime/backend/_arm/int8/replace.py: 0%
41 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1"""Replace nn.Linear modules in a transformers model with TLEInt8Linear from
2a pre-quantized safetensors state dict.
4State dict convention (matches llm-compressor / compressed-tensors W8A8 output):
5 <module_name>.weight : int8 tensor [N, K]
6 <module_name>.weight_scale : fp32 tensor [N] or scalar
8Example:
9 from safetensors.torch import load_file
10 from flag_gems.runtime.backend._arm.int8 import replace_linears_with_tle_int8
11 state = load_file("Qwen3-1.7B-W8A8-INT8/model.safetensors")
12 n = replace_linears_with_tle_int8(model, state, skip={"lm_head"})
13"""
15import logging
16from typing import Iterable, Optional
18import torch
20from .tle_int8_linear import TLEInt8Linear
22logger = logging.getLogger(__name__)
def replace_linears_with_tle_int8(
    model: torch.nn.Module,
    state_dict: dict,
    skip: Optional[Iterable[str]] = None,
    require_divisible_by: int = 4,
) -> int:
    """Swap pre-quantized ``nn.Linear`` submodules of *model* for ``TLEInt8Linear``.

    Every module yielded by ``model.named_modules()`` that is an ``nn.Linear``
    is considered. A module ``<name>`` is replaced in place (via ``setattr``
    on its parent module) when all of the following hold:

    * ``<name>`` is not listed in *skip*;
    * both ``<name>.weight`` and ``<name>.weight_scale`` exist in *state_dict*;
    * ``state_dict["<name>.weight"]`` has dtype ``torch.int8``;
    * both dimensions ``(N, K)`` of that weight are multiples of
      *require_divisible_by*.

    After the walk, ``apply_arm_overrides(include=["_int_mm"])`` is invoked
    unconditionally (even when nothing was replaced), and a summary line is
    logged at INFO level.

    Args:
        model: a torch.nn.Module (typically a transformers model). It is
            mutated in place.
        state_dict: {name.weight: int8 tensor, name.weight_scale: fp32 tensor, ...}.
        skip: module names to leave untouched (e.g. {"lm_head"}). A single
            string is treated as one module name rather than as an iterable
            of characters.
        require_divisible_by: required alignment for K and N (SDOT requires 4).

    Returns:
        Number of Linear modules replaced.
    """
    # A bare ``str`` is itself an iterable of characters; ``set("lm_head")``
    # would yield single letters and silently skip nothing. Treat it as one name.
    if not skip:
        skip_set = set()
    elif isinstance(skip, str):
        skip_set = {skip}
    else:
        skip_set = set(skip)

    n_replaced = 0
    n_skipped_dtype = 0
    n_skipped_align = 0
    n_skipped_missing = 0
    # Count Linear modules actually skipped by name, so the summary log does
    # not merely echo the size of the caller-supplied skip list.
    n_skipped_explicit = 0

    # Snapshot the module list first: the loop body mutates the module tree.
    for name, module in list(model.named_modules()):
        if not isinstance(module, torch.nn.Linear):
            continue
        if name in skip_set:
            n_skipped_explicit += 1
            continue

        w_key = f"{name}.weight"
        s_key = f"{name}.weight_scale"
        if w_key not in state_dict or s_key not in state_dict:
            n_skipped_missing += 1
            continue

        w = state_dict[w_key]
        s = state_dict[s_key]
        if w.dtype != torch.int8:
            n_skipped_dtype += 1
            continue

        N, K = w.shape
        if K % require_divisible_by != 0 or N % require_divisible_by != 0:
            logger.debug(
                "replace_linears_with_tle_int8: %s K=%d N=%d not divisible by %d",
                name,
                K,
                N,
                require_divisible_by,
            )
            n_skipped_align += 1
            continue

        # Walk the dotted name down to the parent module, then rebind the leaf
        # attribute to the int8 replacement.
        parts = name.split(".")
        parent = model
        for p in parts[:-1]:
            parent = getattr(parent, p)
        setattr(parent, parts[-1], TLEInt8Linear(w, s))
        n_replaced += 1

    # Engage the aten::_int_mm CPU override so TLEInt8Linear prefill's
    # torch._int_mm routes to the Triton SVE2 i8mm kernel (see quantize_live.py).
    # Imported here rather than at module top level, as in the original code.
    from ..ops import apply_arm_overrides

    apply_arm_overrides(include=["_int_mm"])

    logger.info(
        "TLEInt8Linear: replaced %d Linear modules "
        "(skipped: %d dtype, %d alignment, %d missing, %d explicit)",
        n_replaced,
        n_skipped_dtype,
        n_skipped_align,
        n_skipped_missing,
        n_skipped_explicit,
    )
    return n_replaced