Coverage for src/flag_gems/runtime/backend/_arm/int8/replace.py: 0%

41 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1"""Replace nn.Linear modules in a transformers model with TLEInt8Linear from 

2a pre-quantized safetensors state dict. 

3 

4State dict convention (matches llm-compressor / compressed-tensors W8A8 output): 

5 <module_name>.weight : int8 tensor [N, K] 

6 <module_name>.weight_scale : fp32 tensor [N] or scalar 

7 

8Example: 

9 from safetensors.torch import load_file 

10 from flag_gems.runtime.backend._arm.int8 import replace_linears_with_tle_int8 

11 state = load_file("Qwen3-1.7B-W8A8-INT8/model.safetensors") 

12 n = replace_linears_with_tle_int8(model, state, skip={"lm_head"}) 

13""" 

14 

15import logging 

16from typing import Iterable, Optional 

17 

18import torch 

19 

20from .tle_int8_linear import TLEInt8Linear 

21 

# Module-level logger keyed on this module's import path (__name__).
logger = logging.getLogger(name=__name__)

23 

24 

def replace_linears_with_tle_int8(
    model: torch.nn.Module,
    state_dict: dict,
    skip: Optional[Iterable[str]] = None,
    require_divisible_by: int = 4,
) -> int:
    """Replace eligible ``nn.Linear`` submodules of *model* with ``TLEInt8Linear``.

    Walks ``model.named_modules()``. For every ``nn.Linear`` named ``<name>``
    that has both ``<name>.weight`` and ``<name>.weight_scale`` in
    *state_dict*, and whose stored weight is ``torch.int8``, the module is
    swapped in place (via ``setattr`` on its parent) for
    ``TLEInt8Linear(weight, weight_scale)``.

    A Linear is left untouched, and counted in the summary log, when:

    * its name is listed in *skip*;
    * ``<name>.weight`` or ``<name>.weight_scale`` is missing from *state_dict*;
    * the stored weight is not ``torch.int8``;
    * the stored weight's shape is not ``(out_features, in_features)`` of the
      Linear it would replace;
    * K (in_features) or N (out_features) is not a multiple of
      *require_divisible_by*.

    After the walk, the ``_int_mm`` ARM override is applied. This happens
    unconditionally, even when nothing was replaced.

    Args:
        model: module modified in place (typically a transformers model).
        state_dict: mapping such as
            ``{"<name>.weight": int8 [N, K], "<name>.weight_scale": fp32 [N] or scalar, ...}``.
        skip: module names to keep as a plain Linear (e.g. ``{"lm_head"}``,
            because it is not followed by further decoder ops that could fuse
            it). A single string is treated as one name.
        require_divisible_by: alignment required for both K and N
            (SDOT requires 4).

    Returns:
        Number of Linear modules replaced.
    """
    if isinstance(skip, str):
        # set("lm_head") would yield a set of single characters; a bare string
        # almost certainly means a single module name.
        skip = (skip,)
    skip_set = set(skip) if skip else set()

    n_replaced = 0
    n_skipped_explicit = 0
    n_skipped_missing = 0
    n_skipped_dtype = 0
    n_skipped_shape = 0
    n_skipped_align = 0

    # Snapshot the module list first: the loop body mutates the module tree.
    for name, module in list(model.named_modules()):
        if not isinstance(module, torch.nn.Linear):
            continue
        if name in skip_set:
            # Count modules actually skipped, not merely names requested.
            n_skipped_explicit += 1
            continue

        w_key = f"{name}.weight"
        s_key = f"{name}.weight_scale"
        if w_key not in state_dict or s_key not in state_dict:
            n_skipped_missing += 1
            continue

        w = state_dict[w_key]
        s = state_dict[s_key]
        if w.dtype != torch.int8:
            n_skipped_dtype += 1
            continue

        # Guard against a state dict that doesn't match this model. This also
        # rejects non-2D tensors before the [N, K] unpack below.
        expected_shape = (module.out_features, module.in_features)
        if tuple(w.shape) != expected_shape:
            logger.warning(
                "replace_linears_with_tle_int8: %s weight shape %s does not "
                "match Linear (out_features, in_features)=%s; skipping",
                name,
                tuple(w.shape),
                expected_shape,
            )
            n_skipped_shape += 1
            continue

        N, K = w.shape
        if K % require_divisible_by != 0 or N % require_divisible_by != 0:
            logger.debug(
                "replace_linears_with_tle_int8: %s K=%d N=%d not divisible by %d",
                name,
                K,
                N,
                require_divisible_by,
            )
            n_skipped_align += 1
            continue

        # Resolve the parent from the dotted name; get_submodule("") is model.
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name)
        setattr(parent, child_name, TLEInt8Linear(w, s))
        n_replaced += 1

    # Engage the aten::_int_mm CPU override so TLEInt8Linear prefill's
    # torch._int_mm routes to the Triton SVE2 i8mm kernel (see quantize_live.py).
    from ..ops import apply_arm_overrides

    apply_arm_overrides(include=["_int_mm"])

    logger.info(
        "TLEInt8Linear: replaced %d Linear modules "
        "(skipped: %d dtype, %d alignment, %d shape, %d missing, %d explicit)",
        n_replaced,
        n_skipped_dtype,
        n_skipped_align,
        n_skipped_shape,
        n_skipped_missing,
        n_skipped_explicit,
    )
    return n_replaced