Coverage for src/flag_gems/runtime/backend/_arm/ops/quantized_linear_dynamic.py: 0%
139 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
"""
FlagGems ARM backend: Triton-CPU INT8 GEMM for quantized::linear_dynamic.

Replaces the OneDNN/ACL implementation of torch.ops.quantized.linear_dynamic
with Triton-CPU i8mm kernels on ARM64 (SVE2 + i8mm).

Kernel routing by M, the number of activation rows (validated on CIX P1 CD8180):
  M=1       → fused,   BM=1,  BN=64, BK=4  (ConvertDotGeneric, 63 GOPS decode)
  M=2       → fused,   BM=2,  BN=64, BK=4  (ConvertDotGeneric, LLVM unrolls K=4)
  M%64==0   → unfused, BM=64, BN=64, BK=32 (SVE2 i8mm dynamic ForOp, 411 GOPS)
  M=4       → unfused, BM=1,  BN=64, BK=4  (faster than BM=4 BK=32 here)
  M%8==0    → fused with tiled weights, BM=8, BN=64, BK=32
              (SVE2 i8mm dynamic ForOp, 100-128 GOPS)
  otherwise → when _ENABLE_PADDING is set, zero-pad M up to the next multiple
              of 8, run the M%8==0 path, then slice the padding rows off

Fusion optimisation (2026-03-06):
  The fused kernels (_i8mm_fused_kernel, _i8mm_fused_tiled_kernel) take the
  FP32 activation directly and write FP32 output. Quantisation (FP32→INT8)
  and dequantisation (INT32→FP32) happen inside the kernel, eliminating 7
  separate PyTorch operator calls per linear layer:
    BEFORE: abs, max, div, round_, clamp_, to(int8), empty(int32),
            dot-kernel, to(float32), mul_
    AFTER:  abs, max, fused-kernel (saves ~17 ms/tok on Qwen3-1.7B)
  The unfused paths (M%64==0, M=4) still quantise and dequantise in PyTorch
  around _i8mm_kernel, because those configurations measured faster there.

Weight tiling optimisation (2026-03-06):
  _i8mm_fused_tiled_kernel reads pre-tiled weights [K//BK, N//BN, BK, BN].
  Each B tile is contiguous in memory, eliminating the strided cache-miss
  pattern of the row-major [K,N] layout (e.g. stride_bk = N = 18944 causes
  L2 misses). Tiling is used only by the fused prefill paths (M%8==0 and
  padded M), and only when K%32==0 and N%64==0. The decode (M=1,2) and
  unfused (M=4, M%64==0) paths read the row-major [K,N] copy. Extra memory
  is ~1x weight size for the tiled copy (e.g. +1.7 GB for Qwen3-1.7B), a
  one-time cost at the first inference per weight.

Weight cache: keyed on w.data_ptr(), the address of the unpacked weight's
data. Such a key is only unique while that weight tensor is alive.
"""
33import logging
35import torch
36import triton
37import triton.language as tl
logger = logging.getLogger(__name__)

# Tile shape (BK, BN) of the pre-tiled prefill weight layout
# [K//BK, N//BN, BK, BN]; must agree with the kernel's BLOCK_K / BLOCK_N.
_TILE_BK: int = 32
_TILE_BN: int = 64

# Phase 4 switch: zero-pad prefill batches whose M is not a multiple of 8.
# Flip to False to fall back to the Phase 3 BM=4 static path (benchmarking).
_ENABLE_PADDING: bool = True
# ---------------------------------------------------------------------------
# Fused + tiled kernel: FP32 input → INT8 quant → tiled INT8 GEMM → FP32 out
# Used by the fused prefill path of _triton_quantized_linear_dynamic (BK=32),
# where each B tile is contiguous in memory.
# ---------------------------------------------------------------------------


@triton.jit
def _i8mm_fused_tiled_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,  # not read by the body (loads/stores are unmasked)
    N,  # not read by the body (loads/stores are unmasked)
    K,
    stride_am,
    stride_ak,
    stride_cm,
    stride_cn,
    N_TILES,  # int32: N // BLOCK_N (number of N-tiles)
    inv_x_scale,  # float32 scalar: 127.0 / x_abs_max
    out_scale,  # float32 scalar: (x_abs_max / 127.0) * w_scale
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,  # must equal _TILE_BN (64)
    BLOCK_K: tl.constexpr,  # must equal _TILE_BK (32)
):
    """
    Fused INT8 GEMM with tiled weight layout.

    A[M,K] fp32 (activation, addressed through stride_am / stride_ak)
    B tiled int8 layout [K//BK, N//BN, BK, BN] — each tile contiguous
    C[M,N] fp32 output

    Each program computes one BLOCK_M x BLOCK_N tile of C. program_id(0)
    selects the row tile and program_id(1) selects the column tile.

    On every K step the A tile is quantised in-kernel: it is multiplied by
    inv_x_scale, clamped to [-128, 127] and cast with .to(tl.int8), with no
    explicit rounding. Products accumulate in an int32 tile, which is
    converted to float32 and multiplied by out_scale once at the end.

    The tiled layout ensures each B tile load is a contiguous BK*BN-byte
    block, eliminating the stride-bk=N cache-miss pattern of row-major [K,N].
    SVE2 i8mm (smmla) path fires as before: both operands are int8.

    Loads and stores are unmasked, so the caller must guarantee that the row
    count is a multiple of BLOCK_M, N is a multiple of BLOCK_N, and K is a
    multiple of BLOCK_K. The loop runs cdiv(K, BLOCK_K) steps, so a K
    remainder would read past the end of each A row and of the tiled B.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        offs_k = k * BLOCK_K + tl.arange(0, BLOCK_K)

        # Load FP32 activation tile; quantise to INT8 in-kernel
        a_fp32 = tl.load(
            a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
        )
        a_scaled = a_fp32 * inv_x_scale
        a_clamped = tl.minimum(tl.maximum(a_scaled, -128.0), 127.0)
        a_int8 = a_clamped.to(tl.int8)

        # Load tiled B: tile (k, pid_n) is contiguous BK*BN bytes
        # Layout: b_ptr[k * N_TILES + pid_n][BK][BN]
        b_base = b_ptr + (k * N_TILES + pid_n) * BLOCK_K * BLOCK_N
        b = tl.load(
            b_base
            + tl.arange(0, BLOCK_K)[:, None] * BLOCK_N
            + tl.arange(0, BLOCK_N)[None, :]
        )
        acc += tl.dot(a_int8, b)

    # Dequantise: int32 → float32, scale and store
    c_fp32 = acc.to(tl.float32) * out_scale
    tl.store(
        c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn,
        c_fp32,
    )
# ---------------------------------------------------------------------------
# Fused kernel: FP32 input → INT8 quant → row-major INT8 GEMM → FP32 out
# Used for decode paths (M=1,2, BK=4) where tile is tiny (4×64 bytes), and by
# the fused prefill path when no tiled weight copy is available.
# ---------------------------------------------------------------------------


@triton.jit
def _i8mm_fused_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,  # not read by the body (loads/stores are unmasked)
    N,  # not read by the body (loads/stores are unmasked)
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    inv_x_scale,  # float32 scalar: 127.0 / x_abs_max
    out_scale,  # float32 scalar: (x_abs_max / 127.0) * w_scale
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """
    Fused INT8 GEMM with row-major weight layout [K, N].
    Used for decode (M=1,2, BK=4): LLVM fully unrolls K=4 loop.

    A is fp32 and addressed through stride_am / stride_ak. B is int8 and
    addressed through stride_bk / stride_bn. C is the fp32 output. Each
    program computes one BLOCK_M x BLOCK_N tile of C: program_id(0) picks the
    row tile, program_id(1) the column tile.

    The A tile is quantised in-kernel: it is scaled by inv_x_scale, clamped
    to [-128, 127] and cast with .to(tl.int8), with no explicit rounding.
    Products accumulate in int32, and the result is dequantised with
    out_scale before the store.

    Loads and stores are unmasked. The caller must ensure that the row count
    is a multiple of BLOCK_M, N is a multiple of BLOCK_N, and K is a multiple
    of BLOCK_K; otherwise the last tiles read and write out of bounds.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        offs_k = k * BLOCK_K + tl.arange(0, BLOCK_K)

        # Load FP32 activation tile; quantise to INT8 in-kernel
        a_fp32 = tl.load(
            a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
        )
        a_scaled = a_fp32 * inv_x_scale
        a_clamped = tl.minimum(tl.maximum(a_scaled, -128.0), 127.0)
        a_int8 = a_clamped.to(tl.int8)

        # Row-major [K, N] weight tile (strided by stride_bk between rows)
        b = tl.load(b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
        acc += tl.dot(a_int8, b)

    # Dequantise: int32 → float32, scale and store
    c_fp32 = acc.to(tl.float32) * out_scale
    tl.store(
        c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn,
        c_fp32,
    )
# ---------------------------------------------------------------------------
# Unfused kernel: INT8 A × INT8 B → INT32 C. Not just a reference copy: the
# dispatcher launches it for the M%64==0 and M==4 paths, with quantisation
# and dequantisation done in PyTorch around the launch.
# ---------------------------------------------------------------------------


@triton.jit
def _i8mm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,  # not read by the body (loads/stores are unmasked)
    N,  # not read by the body (loads/stores are unmasked)
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """
    Unfused INT8 GEMM: A int8, B int8 → C int32.

    Each program computes one BLOCK_M x BLOCK_N tile of C: program_id(0)
    picks the row tile and program_id(1) the column tile. Loads and stores
    are unmasked, so the row count, N and K must be multiples of BLOCK_M,
    BLOCK_N and BLOCK_K respectively.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        offs_k = k * BLOCK_K + tl.arange(0, BLOCK_K)
        a = tl.load(a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
        b = tl.load(b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
        acc += tl.dot(a, b)
    # acc is already int32 (see tl.zeros above); the cast is a no-op.
    tl.store(
        c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn,
        acc.to(tl.int32),
    )
# ---------------------------------------------------------------------------
# Weight cache
# ---------------------------------------------------------------------------

# w.data_ptr() → (weight_kn [K,N], weight_tiled [K//BK,N//BN,BK,BN] or None,
#                 weight_scale float, bias or None)
_weight_cache: dict = {}

# w.data_ptr() → the unpacked weight tensor the cache entry was built from.
# A raw address is only a unique key while the tensor that owns it is alive.
# Once that tensor is freed, the allocator may hand the same address to a
# different weight, which would then silently hit the stale entry above.
# Holding the tensor here pins its storage, so the address cannot be reused
# while the entry exists.
_weight_keepalive: dict = {}


def _get_weight(W_prepack):
    """
    Return kernel-ready copies of a prepacked qint8 linear weight, cached.

    W_prepack: object whose ``unpack()`` returns ``(w, bias)``, where ``w`` is
        a per-tensor-quantised qint8 tensor of shape [N, K].

    Returns ``(weight_kn, weight_tiled, weight_scale, bias)``:
        weight_kn    : int8 [K, N] row-major copy of ``w.int_repr()``
        weight_tiled : int8 [K//BK, N//BN, BK, BN] copy (BK=_TILE_BK,
                       BN=_TILE_BN), or None when K % BK or N % BN is non-zero
        weight_scale : ``w.q_scale()`` as a Python float
        bias         : passed through unchanged from ``unpack()``

    NOTE(review): the zero point of ``w`` is not applied (``int_repr()`` is
    used as-is), so this assumes symmetric weights with zero point 0 — confirm.
    NOTE(review): caching assumes ``unpack()`` returns the same stored weight
    on every call. If it materialised a fresh tensor each time, every call
    would miss the cache and the keep-alive would retain each copy.
    """
    w, bias = W_prepack.unpack()  # w: qint8 [N, K]
    key = w.data_ptr()
    cached = _weight_cache.get(key)
    if cached is not None:
        return cached

    # Row-major [K, N] for the decode and unfused paths.
    weight_kn = w.int_repr().T.contiguous()  # int8 [K, N]
    K, N = weight_kn.shape

    # Tiled [K//BK, N//BN, BK, BN] for the fused prefill path (BK=32, BN=64).
    # Each tile is BK*BN contiguous bytes → eliminates strided cache misses.
    BK, BN = _TILE_BK, _TILE_BN
    if K % BK == 0 and N % BN == 0:
        weight_tiled = (
            weight_kn.reshape(K // BK, BK, N // BN, BN).permute(0, 2, 1, 3).contiguous()
        )  # int8 [K//BK, N//BN, BK, BN]
    else:
        weight_tiled = None
        logger.debug(
            "FlagGems ARM: K=%d N=%d not divisible by BK=%d BN=%d; "
            "tiled layout disabled for this layer",
            K,
            N,
            BK,
            BN,
        )

    entry = (weight_kn, weight_tiled, float(w.q_scale()), bias)
    _weight_cache[key] = entry
    # Pin w's storage so `key` cannot be reused by another weight.
    _weight_keepalive[key] = w
    return entry
# ---------------------------------------------------------------------------
# Core implementation
# ---------------------------------------------------------------------------


def _reference_i8_linear(x2d, weight_kn, inv_x_scale, out_scale):
    """
    Pure-PyTorch fallback for shapes the unmasked Triton kernels cannot cover.

    Quantises the activation like the kernels do: scale, clamp to
    [-128, 127], then a truncating int8 cast. The int8 products are
    accumulated in float64, which is exact because every partial sum is an
    integer bounded by 128 * 127 * K, far below 2**53. Dequantisation
    matches the unfused paths: integer result → float32, then multiply by
    out_scale.

    Returns a contiguous float32 [M, N] tensor.
    """
    x_q = (x2d * inv_x_scale).clamp_(-128, 127).to(torch.int8)
    acc = x_q.to(torch.float64) @ weight_kn.to(torch.float64)
    return acc.to(torch.float32).mul_(out_scale)


def _launch_fused_rowmajor(x, weight_kn, out, inv_x_scale, out_scale, BM, BN, BK):
    """
    Launch _i8mm_fused_kernel (row-major [K, N] weight) to fill all of ``out``.

    Requires out.shape[0] % BM == 0, N % BN == 0 and K % BK == 0.
    """
    rows, n = out.shape
    _i8mm_fused_kernel[(rows // BM, n // BN)](
        x,
        weight_kn,
        out,
        rows,
        n,
        x.shape[1],
        x.stride(0),
        x.stride(1),
        weight_kn.stride(0),
        weight_kn.stride(1),
        out.stride(0),
        out.stride(1),
        inv_x_scale=inv_x_scale,
        out_scale=out_scale,
        BLOCK_M=BM,
        BLOCK_N=BN,
        BLOCK_K=BK,
    )


def _launch_fused_tiled(x, weight_tiled, out, inv_x_scale, out_scale, BM, BN, BK):
    """
    Launch _i8mm_fused_tiled_kernel (tiled weight) to fill all of ``out``.

    Requires out.shape[0] % BM == 0 and a weight tiled with BK x BN tiles.
    """
    rows, n = out.shape
    _i8mm_fused_tiled_kernel[(rows // BM, n // BN)](
        x,
        weight_tiled,
        out,
        rows,
        n,
        x.shape[1],
        x.stride(0),
        x.stride(1),
        out.stride(0),
        out.stride(1),
        n // BN,
        inv_x_scale=inv_x_scale,
        out_scale=out_scale,
        BLOCK_M=BM,
        BLOCK_N=BN,
        BLOCK_K=BK,
    )


def _run_unfused(x2d, weight_kn, inv_x_scale, out_scale, BM, BN, BK):
    """
    Quantise in PyTorch, run _i8mm_kernel (int8 × int8 → int32), dequantise.

    Requires M % BM == 0, N % BN == 0 and K % BK == 0 (the kernel is
    unmasked). Returns a float32 [M, N] tensor.
    """
    M, K = x2d.shape
    N = weight_kn.shape[1]
    # NOTE: no .round_() — match the fused kernels' truncating .to(int8).
    # Rounding here while the fused kernels truncate gives this M a different
    # rounding mode from other M's, which causes argmax drift at long
    # generations.
    x_q = (x2d * inv_x_scale).clamp_(-128, 127).to(torch.int8)
    c_i32 = torch.empty(M, N, dtype=torch.int32)
    _i8mm_kernel[(M // BM, N // BN)](
        x_q,
        weight_kn,
        c_i32,
        M,
        N,
        K,
        x_q.stride(0),
        x_q.stride(1),
        weight_kn.stride(0),
        weight_kn.stride(1),
        c_i32.stride(0),
        c_i32.stride(1),
        BLOCK_M=BM,
        BLOCK_N=BN,
        BLOCK_K=BK,
    )
    return c_i32.to(torch.float32).mul_(out_scale)


def _run_fused_prefill(x2d, weight_kn, weight_tiled, inv_x_scale, out_scale, BN, BK):
    """
    Fused prefill path: BM=8 with Dynamic ForOp SVE2 i8mm, zero-padding M to a
    multiple of 8 when _ENABLE_PADDING is set.

    Uses the tiled weight when available. Returns a float32 [M, N] tensor,
    which is a view of the first M rows when padding was applied.
    """
    M, K = x2d.shape
    N = weight_kn.shape[1]
    if M % 8 == 0:
        BM, M_kernel, x_kernel = 8, M, x2d
    elif _ENABLE_PADDING:
        # Pad to next multiple of 8 → Dynamic ForOp path,
        # e.g. M=84 → M_kernel=88 (4 extra zero rows, sliced off below).
        BM = 8
        M_kernel = ((M + 7) // 8) * 8
        x_kernel = torch.zeros(M_kernel, K, dtype=x2d.dtype)
        x_kernel[:M].copy_(x2d)
    else:
        # Phase 3 fallback: no padding, BM=4 if aligned else BM=1.
        BM = 4 if M % 4 == 0 else 1
        M_kernel, x_kernel = M, x2d

    out_kernel = torch.empty(M_kernel, N, dtype=torch.float32)
    if weight_tiled is not None:
        _launch_fused_tiled(
            x_kernel, weight_tiled, out_kernel, inv_x_scale, out_scale,
            BM=BM, BN=BN, BK=BK,
        )
    else:
        _launch_fused_rowmajor(
            x_kernel, weight_kn, out_kernel, inv_x_scale, out_scale,
            BM=BM, BN=BN, BK=BK,
        )
    # Slice off the padding rows (out_kernel[:M] is a view, no copy).
    return out_kernel[:M] if M_kernel != M else out_kernel


def _int8_gemm(x2d, weight_kn, weight_tiled, x_abs_max, weight_scale):
    """
    Route a non-empty [M, K] activation with x_abs_max != 0 to a kernel path.

    Returns the float32 [M, N] product without bias.

    Routing (observed empirically via A/B vs commit 80be6a2e^):
      M=1, M=2  → fused row-major kernel, BM=M, BK=4 (ConvertDotGeneric;
                  LLVM fully unrolls the K=4 loop → fastest for tiny GEMV).
      M%64==0   → unfused _i8mm_kernel, BM=64, BK=32. The fused kernel
                  regresses ~15-20% here because of BM=64 BK=32 epilog
                  register pressure.
      M=4       → unfused _i8mm_kernel, BM=1, BK=4. The BM=4 BK=32 SVE2 static
                  path is slower for this tiny shape.
      otherwise → fused kernel, BM=8, BK=32 (SVE2 i8mm Dynamic ForOp, ~1.4x),
                  padded to M%8==0 if needed.
    Shapes whose N or K the chosen tiles do not divide go to
    _reference_i8_linear, because the kernels do not mask.
    """
    M, K = x2d.shape
    N = weight_kn.shape[1]
    inv_x_scale = 127.0 / x_abs_max
    out_scale = (x_abs_max / 127.0) * weight_scale
    BN = 64
    # BK of the path M is routed to below; every kernel needs K % BK == 0.
    BK = 4 if M in (1, 2, 4) else 32

    if N % BN != 0 or K % BK != 0:
        # Unmasked kernels would leave trailing columns of the output unwritten
        # or read past each row; use the PyTorch fallback instead.
        return _reference_i8_linear(x2d, weight_kn, inv_x_scale, out_scale)
    if M <= 2:
        out2d = torch.empty(M, N, dtype=torch.float32)
        _launch_fused_rowmajor(
            x2d, weight_kn, out2d, inv_x_scale, out_scale, BM=M, BN=BN, BK=BK
        )
        return out2d
    if M % 64 == 0:
        return _run_unfused(x2d, weight_kn, inv_x_scale, out_scale, BM=64, BN=BN, BK=BK)
    if M == 4:
        return _run_unfused(x2d, weight_kn, inv_x_scale, out_scale, BM=1, BN=BN, BK=BK)
    return _run_fused_prefill(
        x2d, weight_kn, weight_tiled, inv_x_scale, out_scale, BN=BN, BK=BK
    )


def _triton_quantized_linear_dynamic(X, W_prepack, reduce_range=False):
    """
    Triton-CPU replacement for torch.ops.quantized.linear_dynamic (CPU).

    X           : float32 tensor, shape [..., K]
    W_prepack   : torch.ScriptObject (LinearPackedParamsBase), qint8 [N, K]
    reduce_range: accepted for schema compatibility; ignored here.
    Returns     : float32 tensor, shape [..., N]

    The activation is quantised per tensor with scale x_abs_max / 127. The
    int32 GEMM result is dequantised with (x_abs_max / 127) * weight_scale,
    and the bias, if any, is then added in float32. Kernel routing by M lives
    in _int8_gemm: decode (M=1,2) uses the row-major fused kernel; prefill
    uses the unfused kernel for M%64==0 and M=4 and the tiled fused kernel
    otherwise. For example, M=84 is padded to M_kernel=88, which unlocks the
    Dynamic ForOp path (100-128 GOPS) instead of the old BM=4 static path
    (57-73 GOPS).
    """
    weight_kn, weight_tiled, weight_scale, bias = _get_weight(W_prepack)

    K = X.shape[-1]
    N = weight_kn.shape[1]
    lead_shape = X.shape[:-1]

    # reshape (not view) so non-contiguous activations are accepted. It only
    # copies when a view is impossible, and the kernels take explicit strides.
    x2d = X.reshape(-1, K)
    M = x2d.shape[0]
    if M == 0:
        # Empty batch: nothing to compute, and .max() below would raise.
        return torch.empty(*lead_shape, N, dtype=torch.float32)

    # Activation scale (one reduction, unavoidable for per-tensor quant).
    x_abs_max = x2d.abs().max().item()
    if x_abs_max == 0.0:
        # All-zero activation: the product is identically zero.
        out2d = torch.zeros(M, N, dtype=torch.float32)
    else:
        out2d = _int8_gemm(x2d, weight_kn, weight_tiled, x_abs_max, weight_scale)

    if bias is not None:
        out2d = out2d + bias
    return out2d.view(*lead_shape, N)
# ---------------------------------------------------------------------------
# Registration
# ---------------------------------------------------------------------------

_quantized_lib = None  # keep reference alive to prevent GC


def register():
    """
    Register the Triton implementation of quantized::linear_dynamic on CPU.

    Idempotent: once registration has succeeded, later calls return
    immediately. A failed attempt is logged as a warning and leaves the module
    unregistered, so a later call can retry it.
    """
    global _quantized_lib
    if _quantized_lib is not None:
        return

    try:
        lib = torch.library.Library("quantized", "IMPL")
        lib.impl(
            "linear_dynamic",
            _triton_quantized_linear_dynamic,
            "CPU",
            allow_override=True,
        )
    except Exception as e:
        # Best effort at this boundary: keep the stock kernel and report why.
        logger.warning(
            "FlagGems ARM: failed to register quantized::linear_dynamic override: %s",
            e,
        )
        return

    # Publish only after impl() succeeded. Otherwise a half-built Library would
    # make later calls believe registration already happened.
    _quantized_lib = lib
    logger.debug(
        "FlagGems ARM: registered Triton-CPU i8mm (fused+tiled) for quantized::linear_dynamic"
    )