Coverage for src/flag_gems/runtime/backend/_arm/int8/quantize_live.py: 0%

44 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

"""Live (in-memory) W8 per-channel symmetric quantization of nn.Linear layers,
replacing each with a TLEInt8Linear.

Use case: take a BF16 model that has no pre-quantized state dict (e.g. the
public Qwen3.5-2B BF16 release) and turn it into a TLE INT8 stack on the fly.

The quantization scheme follows the llm-compressor / compressed-tensors W8A8
output format: per-channel symmetric INT8 weights and per-token symmetric INT8
activations. Weight scales here are computed as max(|w[n, :]|) / 127 per output
row. Activation quantization is already done inside TLEInt8Linear.forward; this
module only handles the weight side.

Example::

    import torch
    from transformers import AutoModelForCausalLM
    from flag_gems.runtime.backend._arm.int8 import quantize_and_replace_linears

    m = AutoModelForCausalLM.from_pretrained("...", dtype=torch.bfloat16)
    n = quantize_and_replace_linears(m, skip={"lm_head"})
"""
import logging
from typing import Iterable, Optional, Tuple

import torch

from .tle_int8_linear import TLEInt8Linear

logger = logging.getLogger(__name__)

27 

28 

def _quantize_weight_per_channel_sym(
    w: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D [N, K] weight to INT8 with one symmetric scale per row.

    Row ``n`` gets ``scale[n] = max(|w[n, :]|) / 127``. The row maximum is
    floored at 1e-8 first, so an all-zero row yields a tiny positive scale
    instead of a division by zero. Each element becomes
    ``round(w / scale)`` clamped to [-128, 127] and cast to int8.

    Args:
        w: [N, K] weight tensor. For an nn.Linear weight this is
            [out_features, in_features]. It is detached and upcast to fp32
            before quantizing.

    Returns:
        (w_int8, w_scale): an [N, K] int8 tensor and a contiguous [N] fp32
        tensor of per-row scales.
    """
    src = w.detach().float()

    # Largest magnitude in every row (reduce over K); keep the [N, 1] shape so
    # it broadcasts against src below.
    row_absmax = torch.amax(torch.abs(src), dim=1, keepdim=True)
    row_absmax = torch.clamp(row_absmax, min=1e-8)
    row_scale = row_absmax / 127.0  # [N, 1]

    quantized = torch.clamp(torch.round(src / row_scale), -128, 127)
    w_int8 = quantized.to(torch.int8)  # [N, K]

    # Drop the broadcast axis: [N, 1] -> [N].
    w_scale = row_scale.squeeze(-1).contiguous().to(torch.float32)
    return w_int8, w_scale

45 

46 

def _select_linears_to_quantize(
    model: torch.nn.Module,
    skip_set: set,
    require_divisible_by: int,
    skip_with_bias: bool,
) -> Tuple[list, set, int, int]:
    """Pick the dotted names of the nn.Linear modules to replace, without modifying ``model``.

    For every nn.Linear yielded by ``model.named_modules()``, the filters are
    applied in this order: exact-name skip, bias policy, K/N alignment, and
    finally a guard against the root module (which has no parent to re-register
    a replacement on).

    Returns:
        (names, matched_skips, n_skipped_bias, n_skipped_align). ``names``
        holds module names only, so the caller keeps no module references
        alive. ``matched_skips`` is the subset of ``skip_set`` that named an
        nn.Linear.

    Raises:
        ValueError: a non-skipped Linear has a bias and ``skip_with_bias`` is
            False. Because this runs before any replacement, the model is left
            untouched.
    """
    names = []
    matched_skips = set()
    n_skipped_bias = 0
    n_skipped_align = 0
    for name, module in model.named_modules():
        if not isinstance(module, torch.nn.Linear):
            continue
        if name in skip_set:
            matched_skips.add(name)
            continue
        if module.bias is not None:
            if skip_with_bias:
                n_skipped_bias += 1
                continue
            raise ValueError(
                f"{name} has bias=True; TLEInt8Linear does not support bias"
            )

        N, K = module.weight.shape
        if K % require_divisible_by != 0 or N % require_divisible_by != 0:
            n_skipped_align += 1
            logger.debug(
                "quantize_and_replace_linears: %s K=%d N=%d not divisible by %d",
                name,
                K,
                N,
                require_divisible_by,
            )
            continue

        if not name:
            # named_modules() reports the root as "". There is no parent
            # attribute to rebind, so the root can't be swapped in place.
            logger.warning(
                "quantize_and_replace_linears: the root module is an nn.Linear "
                "and cannot be replaced in place; leaving it as-is"
            )
            continue

        names.append(name)
    return names, matched_skips, n_skipped_bias, n_skipped_align


def quantize_and_replace_linears(
    model: torch.nn.Module,
    skip: Optional[Iterable[str]] = None,
    require_divisible_by: int = 4,
    skip_with_bias: bool = True,
) -> int:
    """Quantize each eligible nn.Linear weight to per-channel symmetric INT8 in
    memory and replace the layer with a TLEInt8Linear, modifying ``model`` in
    place.

    Selection happens first, by name only and without touching the model.
    Layers are then replaced one at a time. After each swap this function drops
    its own reference to the original module, so that module's weight can be
    reclaimed before the next layer is quantized. It stays alive only if
    something else references it, e.g. a tied parameter.

    Args:
        model: any torch.nn.Module (typically a transformers model).
        skip: exact module names (as reported by ``model.named_modules()``) to
            leave alone, e.g. {"lm_head"} when the head is tied or too large to
            benefit from INT8 GEMV. A single string is treated as one name.
            Names that match no nn.Linear are reported with a warning.
        require_divisible_by: both K (weight dim 1) and N (weight dim 0) must be
            divisible by this. SDOT requires K%4==0 and N%4==0, so
            smaller-aligned linears stay BF16 (e.g. tiny GDN scalar projections
            in_proj_a/b). Must be >= 1.
        skip_with_bias: TLEInt8Linear has no bias parameter. If True, leave any
            nn.Linear with a non-None bias as-is. If False, raise ValueError
            when a biased Linear is found (for models expected to have none,
            e.g. Qwen3-style). The check runs before any replacement, so the
            model is not left half-converted.

    Returns:
        Number of Linear modules replaced.

    Raises:
        ValueError: ``require_divisible_by`` < 1, or a biased Linear was found
            while ``skip_with_bias`` is False.

    Note:
        ``named_modules()`` yields a shared module only once, so a Linear that
        is registered under several names is replaced only at its first name.
    """
    if require_divisible_by < 1:
        raise ValueError(
            f"require_divisible_by must be >= 1, got {require_divisible_by}"
        )
    if isinstance(skip, str):
        # A str is itself an iterable of str; set("lm_head") would be its letters.
        skip_set = {skip}
    else:
        skip_set = set(skip) if skip else set()

    names, matched_skips, n_skipped_bias, n_skipped_align = (
        _select_linears_to_quantize(
            model, skip_set, require_divisible_by, skip_with_bias
        )
    )

    unmatched_skips = skip_set - matched_skips
    if unmatched_skips:
        logger.warning(
            "quantize_and_replace_linears: skip names matched no nn.Linear: %s",
            sorted(unmatched_skips, key=str),
        )

    n_replaced = 0
    for name in names:
        parent_name, _, child_name = name.rpartition(".")
        # get_submodule("") returns model itself, which covers top-level children.
        parent = model.get_submodule(parent_name)
        linear = getattr(parent, child_name)
        w_int8, w_scale = _quantize_weight_per_channel_sym(linear.weight.data)
        setattr(parent, child_name, TLEInt8Linear(w_int8, w_scale))
        # The model no longer references the old Linear. Dropping our local
        # reference lets its BF16 weight be freed now rather than at return.
        del linear
        n_replaced += 1

    # Engage the aten::_int_mm CPU override so TLEInt8Linear prefill's
    # torch._int_mm routes to the Triton SVE2 i8mm kernel instead of ATen's
    # scalar fallback (~15x slower). Idempotent + process-global on the CPU
    # dispatch key. Only _int_mm is needed here; mm/argmax overrides were
    # measured to not help the INT8 decode path.
    from ..ops import apply_arm_overrides

    apply_arm_overrides(include=["_int_mm"])

    logger.info(
        "quantize_and_replace_linears: replaced %d Linears "
        "(skipped: %d alignment, %d bias, %d explicit)",
        n_replaced,
        n_skipped_align,
        n_skipped_bias,
        len(matched_skips),
    )
    return n_replaced