Coverage for src/flag_gems/runtime/backend/_arm/ops/quantized_linear_dynamic.py: 0%

139 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1""" 

2FlagGems ARM backend: Triton-CPU INT8 GEMM for quantized::linear_dynamic. 

3 

4Replaces the OneDNN/ACL implementation of torch.ops.quantized.linear_dynamic 

5with a Triton-CPU i8mm kernel on ARM64 (SVE2 + i8mm). 

6 

7Kernel configs (validated on CIX P1 CD8180): 

8 M=1 → BM=1, BN=64, BK=4 (ConvertDotGeneric, 63 GOPS decode) 

9 M=2 → BM=2, BN=64, BK=4 (ConvertDotGeneric, LLVM unrolls K=4) 

10 M%64==0 → BM=64, BN=64, BK=32 (SVE2 i8mm dynamic ForOp, 411 GOPS) 

11 M%8==0 → BM=8, BN=64, BK=32 (SVE2 i8mm dynamic ForOp, 100-128 GOPS) 

12 otherwise → pad M to next %8==0, BM=8 (zero-pad extra rows, then slice output) 

13 

14Fusion optimisation (2026-03-06): 

15 _i8mm_fused_kernel takes FP32 activation input directly and outputs FP32. 

16 Quantisation (FP32→INT8) and dequantisation (INT32→FP32) are fused inside 

17 the kernel, eliminating 7 separate PyTorch operator calls per linear layer: 

18 BEFORE: abs, max, div, round_, clamp_, to(int8), empty(int32), 

19 dot-kernel, to(float32), mul_ 

20 AFTER: abs, max, fused-kernel (saves ~17 ms/tok on Qwen3-1.7B) 

21 

22Weight tiling optimisation (2026-03-06): 

23 _i8mm_fused_tiled_kernel uses pre-tiled weights [K//BK, N//BN, BK, BN]. 

24 Each B tile is contiguous in memory, eliminating strided cache-miss pattern 

25 of the row-major [K,N] layout (stride_bk = N = 18944 causes L2 misses). 

26 Applied to the fused prefill path (M≥3 except M=4 and M%64==0, which use the row-major legacy kernel); decode (M=1,2) keeps row-major layout. 

27 Extra memory: ~1x weight size (e.g. +1.7 GB for Qwen3-1.7B). One-time cost 

28 at first inference per weight. 

29 

30Weight cache: keyed on w.data_ptr() (stable physical address). 

31""" 

32 

33import logging 

34 

35import torch 

36import triton 

37import triton.language as tl 

38 

39logger = logging.getLogger(__name__) 

40 

41# Tile dimensions for prefill weight layout (must match kernel constexprs) 

42_TILE_BK = 32 

43_TILE_BN = 64 

44 

45# Runtime flag: enable M-padding for non-M%8 prefill shapes (Phase 4). 

46# Set to False to revert to the Phase 3 unpadded path (BM=4 if M%4==0, else BM=1) for benchmarking. 

47_ENABLE_PADDING = True 

48 

49 

50# --------------------------------------------------------------------------- 

51# Fused + tiled kernel: FP32 input → INT8 quant → tiled INT8 GEMM → FP32 out 

52# Used for prefill paths (M≥4, BK=32) where B tile is contiguous in memory. 

53# --------------------------------------------------------------------------- 

54 

55 

@triton.jit
def _i8mm_fused_tiled_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_cm,
    stride_cn,
    N_TILES,  # int32: N // BLOCK_N (number of N-tiles)
    inv_x_scale,  # float32 scalar: 127.0 / x_abs_max
    out_scale,  # float32 scalar: (x_abs_max / 127.0) * w_scale
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,  # must equal _TILE_BN (64)
    BLOCK_K: tl.constexpr,  # must equal _TILE_BK (32)
):
    """
    Quantise-GEMM-dequantise kernel reading pre-tiled weights.

    Each program instance produces one BLOCK_M x BLOCK_N tile of C:
      * A is addressed through (stride_am, stride_ak); every loaded tile is
        scaled by ``inv_x_scale``, clamped to [-128, 127] and cast to int8
        (no explicit rounding step).
      * B is addressed as a flat array of tiles: tile (k, pid_n) starts at
        element ``(k * N_TILES + pid_n) * BLOCK_K * BLOCK_N`` and is read
        as a dense BLOCK_K x BLOCK_N block.
      * The int32 accumulator is converted to float32, multiplied by
        ``out_scale`` and written to C through (stride_cm, stride_cn).

    M and N are accepted but not read in the body.  Neither the loads nor
    the store use a mask, so the launch grid and shapes must cover whole
    tiles only (M % BLOCK_M == 0, K % BLOCK_K == 0, N == N_TILES * BLOCK_N).
    """
    tile_row = tl.program_id(0)
    tile_col = tl.program_id(1)
    rows = tile_row * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tile_col * BLOCK_N + tl.arange(0, BLOCK_N)

    accum = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
    for kt in range(0, tl.cdiv(K, BLOCK_K)):
        ks = kt * BLOCK_K + tl.arange(0, BLOCK_K)

        # Activation tile: load, scale, saturate, narrow to int8.
        act_ptrs = a_ptr + rows[:, None] * stride_am + ks[None, :] * stride_ak
        act = tl.load(act_ptrs)
        act_sat = tl.minimum(tl.maximum(act * inv_x_scale, -128.0), 127.0)
        act_q = act_sat.to(tl.int8)

        # Weight tile (kt, tile_col): one dense BLOCK_K * BLOCK_N run.
        tile_ptr = b_ptr + (kt * N_TILES + tile_col) * BLOCK_K * BLOCK_N
        wgt_ptrs = (
            tile_ptr
            + tl.arange(0, BLOCK_K)[:, None] * BLOCK_N
            + tl.arange(0, BLOCK_N)[None, :]
        )
        wgt_q = tl.load(wgt_ptrs)

        accum = accum + tl.dot(act_q, wgt_q)

    # Dequantise the int32 accumulator and write the output tile.
    result = accum.to(tl.float32) * out_scale
    out_ptrs = c_ptr + rows[:, None] * stride_cm + cols[None, :] * stride_cn
    tl.store(out_ptrs, result)

119 

120 

121# --------------------------------------------------------------------------- 

122# Fused kernel: FP32 input → INT8 quant → row-major INT8 GEMM → FP32 out 

123# Used for decode paths (M=1,2, BK=4) where tile is tiny (4×64 bytes). 

124# --------------------------------------------------------------------------- 

125 

126 

@triton.jit
def _i8mm_fused_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    inv_x_scale,  # float32 scalar: 127.0 / x_abs_max
    out_scale,  # float32 scalar: (x_abs_max / 127.0) * w_scale
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """
    Quantise-GEMM-dequantise kernel reading strided (row-major [K, N]) weights.

    Per program instance: walk K in BLOCK_K steps, scale/clamp/cast each
    activation tile to int8, multiply it with the matching weight tile
    addressed through (stride_bk, stride_bn), accumulate in int32, then
    store ``acc * out_scale`` as float32.

    M and N are accepted but not read in the body.  Loads and the store are
    unmasked, so only whole tiles may be launched (M % BLOCK_M == 0,
    N % BLOCK_N == 0, K % BLOCK_K == 0).
    """
    tile_row = tl.program_id(0)
    tile_col = tl.program_id(1)
    rows = tile_row * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tile_col * BLOCK_N + tl.arange(0, BLOCK_N)

    accum = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
    for kt in range(0, tl.cdiv(K, BLOCK_K)):
        ks = kt * BLOCK_K + tl.arange(0, BLOCK_K)

        # Activation tile: load, scale, saturate, narrow to int8.
        act_ptrs = a_ptr + rows[:, None] * stride_am + ks[None, :] * stride_ak
        act = tl.load(act_ptrs)
        act_sat = tl.minimum(tl.maximum(act * inv_x_scale, -128.0), 127.0)
        act_q = act_sat.to(tl.int8)

        # Weight tile addressed with explicit K/N strides.
        wgt_ptrs = b_ptr + ks[:, None] * stride_bk + cols[None, :] * stride_bn
        wgt_q = tl.load(wgt_ptrs)

        accum = accum + tl.dot(act_q, wgt_q)

    result = accum.to(tl.float32) * out_scale
    out_ptrs = c_ptr + rows[:, None] * stride_cm + cols[None, :] * stride_cn
    tl.store(out_ptrs, result)

175 

176 

177# --------------------------------------------------------------------------- 

178# Legacy unfused kernel (kept for reference / debugging) 

179# --------------------------------------------------------------------------- 

180 

181 

@triton.jit
def _i8mm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """
    Plain tiled GEMM with an int32 accumulator and int32 output.

    No scaling is applied inside the kernel: operands are loaded as stored
    and the accumulator is written out as int32.  The host wrapper in this
    file feeds it pre-quantised int8 operands and dequantises afterwards.

    M and N are accepted but not read in the body.  Loads and the store are
    unmasked, so only whole tiles may be launched.
    """
    tile_row = tl.program_id(0)
    tile_col = tl.program_id(1)
    rows = tile_row * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tile_col * BLOCK_N + tl.arange(0, BLOCK_N)

    accum = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
    for kt in range(0, tl.cdiv(K, BLOCK_K)):
        ks = kt * BLOCK_K + tl.arange(0, BLOCK_K)
        lhs_ptrs = a_ptr + rows[:, None] * stride_am + ks[None, :] * stride_ak
        rhs_ptrs = b_ptr + ks[:, None] * stride_bk + cols[None, :] * stride_bn
        lhs = tl.load(lhs_ptrs)
        rhs = tl.load(rhs_ptrs)
        accum = accum + tl.dot(lhs, rhs)

    out_ptrs = c_ptr + rows[:, None] * stride_cm + cols[None, :] * stride_cn
    tl.store(out_ptrs, accum.to(tl.int32))

215 

216 

217# --------------------------------------------------------------------------- 

218# Weight cache 

219# --------------------------------------------------------------------------- 

220 

221# w_raw.data_ptr() → (weight_kn [K,N], weight_tiled [K//BK,N//BN,BK,BN] or None, 

222# weight_scale float, bias or None) 

223_weight_cache: dict = {} 

224 

225 

def _get_weight(W_prepack):
    """
    Return the cached kernel-ready weights for a packed linear parameter.

    The packed object is unpacked on every call; the unpacked weight's
    ``data_ptr()`` is the cache key.  On a miss the following tuple is
    built, stored in ``_weight_cache`` and returned:

        (weight_kn, weight_tiled, weight_scale, bias)

    * weight_kn    -- int8 representation transposed to [K, N], contiguous.
    * weight_tiled -- weight_kn re-blocked to [K//BK, N//BN, BK, BN]
                      (contiguous), or None when K or N is not a multiple of
                      the tile size (_TILE_BK / _TILE_BN).
    * weight_scale -- ``w.q_scale()`` as a Python float.
    * bias         -- whatever ``unpack()`` returned as its second element.

    NOTE(review): entries are never evicted and the key is a raw address;
    this presumes unpack() hands back the same storage for the lifetime of
    the packed object -- confirm for the backends in use.
    """
    w, bias = W_prepack.unpack()  # w: qint8 [N, K]
    cache_key = w.data_ptr()
    cached = _weight_cache.get(cache_key)
    if cached is not None:
        return cached

    # Row-major [K, N] copy used by the strided kernels.
    weight_kn = w.int_repr().T.contiguous()
    k_dim, n_dim = weight_kn.shape
    bk, bn = _TILE_BK, _TILE_BN

    tileable = k_dim % bk == 0 and n_dim % bn == 0
    if not tileable:
        weight_tiled = None
        logger.debug(
            "FlagGems ARM: K=%d N=%d not divisible by BK=%d BN=%d; "
            "tiled layout disabled for this layer",
            k_dim,
            n_dim,
            bk,
            bn,
        )
    else:
        # Split both axes into (tile index, offset inside tile), then bring
        # the two tile indices to the front so each tile is contiguous.
        blocked = weight_kn.reshape(k_dim // bk, bk, n_dim // bn, bn)
        weight_tiled = blocked.permute(0, 2, 1, 3).contiguous()

    entry = (weight_kn, weight_tiled, float(w.q_scale()), bias)
    _weight_cache[cache_key] = entry
    return entry

258 

259 

260# --------------------------------------------------------------------------- 

261# Core implementation 

262# --------------------------------------------------------------------------- 

263 

264 

def _launch_fused_rowmajor(
    x, weight_kn, out, M, N, K, inv_x_scale, out_scale, BM, BN, BK
):
    """Launch _i8mm_fused_kernel on a (M // BM, N // BN) grid (row-major weights)."""
    _i8mm_fused_kernel[(M // BM, N // BN)](
        x,
        weight_kn,
        out,
        M,
        N,
        K,
        x.stride(0),
        x.stride(1),
        weight_kn.stride(0),
        weight_kn.stride(1),
        out.stride(0),
        out.stride(1),
        inv_x_scale=inv_x_scale,
        out_scale=out_scale,
        BLOCK_M=BM,
        BLOCK_N=BN,
        BLOCK_K=BK,
    )


def _run_unfused(x2d, weight_kn, M, N, K, inv_x_scale, out_scale, BM, BN, BK):
    """Quantise in torch, run _i8mm_kernel (int8 x int8 -> int32), dequantise in torch."""
    # NOTE: no .round_() -- the fused kernels cast with .to(int8) without
    # rounding, and every M must quantise the same way to avoid drift
    # between paths.
    x_q = (x2d * inv_x_scale).clamp_(-128, 127).to(torch.int8)
    c_i32 = torch.empty(M, N, dtype=torch.int32)
    _i8mm_kernel[(M // BM, N // BN)](
        x_q,
        weight_kn,
        c_i32,
        M,
        N,
        K,
        x_q.stride(0),
        x_q.stride(1),
        weight_kn.stride(0),
        weight_kn.stride(1),
        c_i32.stride(0),
        c_i32.stride(1),
        BLOCK_M=BM,
        BLOCK_N=BN,
        BLOCK_K=BK,
    )
    return c_i32.to(torch.float32).mul_(out_scale)


def _reference_int8_linear(x2d, weight_kn, inv_x_scale, out_scale):
    """Pure-torch equivalent of the unfused path, for shapes the kernels cannot tile."""
    x_q = (x2d * inv_x_scale).clamp_(-128, 127).to(torch.int8)
    acc = torch.matmul(x_q.to(torch.int32), weight_kn.to(torch.int32))
    return acc.to(torch.float32).mul_(out_scale)


def _triton_quantized_linear_dynamic(X, W_prepack, reduce_range=False):
    """
    Triton-CPU replacement for torch.ops.quantized.linear_dynamic (CPU).

    X        : float32 tensor, shape [..., K] (any memory layout)
    W_prepack: torch.ScriptObject (LinearPackedParamsBase), qint8 [N, K]
    reduce_range: accepted for signature compatibility; not used.
    Returns  : float32 tensor, shape [..., N]

    Raises ValueError when X's last dimension differs from the weight's K.

    Routing on M (product of the leading dimensions of X):
      M = 1, 2          fused row-major kernel, BM=M, BN=64, BK=4.
      M % 64 == 0       unfused kernel, BM=64, BN=64, BK=32
                        (quant/dequant done in torch).
      M = 4             unfused kernel, BM=1, BN=64, BK=4.
      everything else   fused kernel, BN=64, BK=32; tiled weights when
                        available, otherwise row-major.  BM=8, with X
                        zero-padded to the next multiple of 8 rows when
                        _ENABLE_PADDING is set (padding rows are sliced off
                        the result); without padding BM=4 if M % 4 == 0
                        else BM=1.

    The kernels use no masks, so they are only launched when N is a
    multiple of BN and K is a multiple of the selected BK.  Other shapes
    are computed with a torch int32 matmul that applies the same
    quantisation, instead of reading past the end of the buffers.
    An input with zero elements returns zeros (plus bias) without
    launching anything.
    """
    weight_kn, weight_tiled, weight_scale, bias = _get_weight(W_prepack)

    K = X.shape[-1]
    N = weight_kn.shape[1]
    orig_shape = X.shape

    if K != weight_kn.shape[0]:
        raise ValueError(
            f"quantized::linear_dynamic: input feature size {K} does not "
            f"match weight feature size {weight_kn.shape[0]}"
        )

    # Empty batch: the abs-max reduction below is undefined on zero elements.
    if X.numel() == 0:
        empty_out = torch.zeros(*orig_shape[:-1], N, dtype=torch.float32)
        return empty_out + bias if bias is not None else empty_out

    # reshape (not view) so non-contiguous inputs are accepted.
    x2d = X.reshape(-1, K)
    M = x2d.shape[0]

    # Compute activation scale (one reduction, unavoidable for per-tensor quant)
    x_abs_max = x2d.abs().max().item()
    if x_abs_max == 0.0:
        out2d = torch.zeros(M, N, dtype=torch.float32)
        if bias is not None:
            out2d = out2d + bias
        return out2d.view(*orig_shape[:-1], N)

    inv_x_scale = 127.0 / x_abs_max
    out_scale = (x_abs_max / 127.0) * weight_scale

    # Select the kernel family and block sizes for this M.
    BN = 64
    if M <= 2:
        route, BM, BK = "decode", M, 4
    elif M % 64 == 0:
        route, BM, BK = "unfused", 64, 32
    elif M == 4:
        route, BM, BK = "unfused", 1, 4
    else:
        route, BM, BK = "fused", 8, 32  # BM finalised below

    if N % BN != 0 or K % BK != 0:
        # Unmasked kernels would read/write outside the tensors here.
        logger.debug(
            "FlagGems ARM: M=%d N=%d K=%d not tileable by BN=%d BK=%d; "
            "using torch int32 matmul fallback",
            M,
            N,
            K,
            BN,
            BK,
        )
        out2d = _reference_int8_linear(x2d, weight_kn, inv_x_scale, out_scale)

    elif route == "decode":
        out2d = torch.empty(M, N, dtype=torch.float32)
        _launch_fused_rowmajor(
            x2d, weight_kn, out2d, M, N, K, inv_x_scale, out_scale, BM, BN, BK
        )

    elif route == "unfused":
        out2d = _run_unfused(
            x2d, weight_kn, M, N, K, inv_x_scale, out_scale, BM, BN, BK
        )

    else:
        use_tiled = weight_tiled is not None

        if M % 8 == 0:
            BM = 8
            x_kernel, M_kernel = x2d, M
        elif _ENABLE_PADDING:
            # Pad to the next multiple of 8, e.g. M=84 -> M_kernel=88.
            M_kernel = ((M + 7) // 8) * 8
            BM = 8
            x_kernel = torch.zeros(M_kernel, K, dtype=x2d.dtype)
            x_kernel[:M].copy_(x2d)
        else:
            # Phase 3 fallback: no padding, BM=4 if aligned else BM=1
            BM = 4 if M % 4 == 0 else 1
            x_kernel, M_kernel = x2d, M

        out_kernel = torch.empty(M_kernel, N, dtype=torch.float32)

        if use_tiled:
            _i8mm_fused_tiled_kernel[(M_kernel // BM, N // BN)](
                x_kernel,
                weight_tiled,
                out_kernel,
                M_kernel,
                N,
                K,
                x_kernel.stride(0),
                x_kernel.stride(1),
                out_kernel.stride(0),
                out_kernel.stride(1),
                N // BN,
                inv_x_scale=inv_x_scale,
                out_scale=out_scale,
                BLOCK_M=BM,
                BLOCK_N=BN,
                BLOCK_K=BK,
            )
        else:
            _launch_fused_rowmajor(
                x_kernel,
                weight_kn,
                out_kernel,
                M_kernel,
                N,
                K,
                inv_x_scale,
                out_scale,
                BM,
                BN,
                BK,
            )

        # Slice off the padding rows (out_kernel[:M] is a view, no copy)
        out2d = out_kernel[:M] if M_kernel != M else out_kernel

    if bias is not None:
        out2d = out2d + bias
    return out2d.reshape(*orig_shape[:-1], N)

485 

486 

487# --------------------------------------------------------------------------- 

488# Registration 

489# --------------------------------------------------------------------------- 

490 

491_quantized_lib = None # keep reference alive to prevent GC 

492 

493 

def register():
    """
    Register Triton implementation for quantized::linear_dynamic on CPU.

    Idempotent: once a registration has succeeded, further calls return
    immediately.  The library handle is stored in the module-level
    ``_quantized_lib`` only after ``impl()`` succeeds, so a failed attempt
    is not mistaken for a completed one and a later call can retry.

    Failures are logged as a warning and swallowed on purpose: the override
    is an optimisation, so failing to install it must not break import.
    """
    global _quantized_lib
    if _quantized_lib is not None:
        return

    try:
        lib = torch.library.Library("quantized", "IMPL")
        lib.impl(
            "linear_dynamic",
            _triton_quantized_linear_dynamic,
            "CPU",
            allow_override=True,
        )
    except Exception as e:
        logger.warning(
            "FlagGems ARM: failed to register quantized::linear_dynamic override: %s",
            e,
        )
        return

    # Keep the handle alive only for a completed registration.
    _quantized_lib = lib
    logger.debug(
        "FlagGems ARM: registered Triton-CPU i8mm (fused+tiled) for quantized::linear_dynamic"
    )