Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/w8a8_block_fp8_matmul.py: 0%

178 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import functools 

2import logging 

3import os 

4from typing import Any, Dict, List, Optional 

5 

6import torch 

7import triton 

8import triton.language as tl 

9import yaml 

10 

11from flag_gems import runtime 

12from flag_gems.runtime import torch_device_fn 

13from flag_gems.utils import libentry, libtuner 

14 

15logger = logging.getLogger( 

16 "flag_gems.runtime.backend._nvidia.hopper.ops.w8a8_block_fp8_matmul" 

17) 

18CACHE_USAGE_THRESHOLD = 0.8 

19EXPAND_CONFIG_FILENAME = os.path.normpath( 

20 os.path.join( 

21 os.path.dirname(__file__), 

22 "..", 

23 "w8a8_block_fp8_matmul_hopper_expand.yaml", 

24 ) 

25) 

26 

27 

28@functools.lru_cache 

29def get_w8a8_block_fp8_hopper_configs(N: int, K: int) -> Optional[Dict[int, Any]]: 

30 device_name = torch.cuda.get_device_name().replace(" ", "_") 

31 file_name = "w8a8_block_fp8_matmul_hopper.yaml" 

32 

33 cfg_file = os.path.join(os.path.dirname(__file__), "..", file_name) 

34 

35 if os.path.exists(cfg_file): 

36 with open(cfg_file) as f: 

37 logger.info( 

38 "Using config from %s for W8A8 block FP8 kernel.", 

39 cfg_file, 

40 ) 

41 dev_data = yaml.safe_load(f).get(device_name, {}) 

42 NK_data = dev_data.get(f"{N},{K}", {}) 

43 

44 result = {} 

45 for k, p in NK_data.items(): 

46 # unpack the list into dictionary 

47 result[int(k)] = { 

48 "BLOCK_SIZE_M": p[0], 

49 "BLOCK_SIZE_N": p[1], 

50 "BLOCK_SIZE_K": p[2], 

51 "GROUP_SIZE_M": p[3], 

52 "num_warps": p[4], 

53 "num_stages": p[5], 

54 } 

55 

56 if not result: 

57 return None 

58 return result 

59 

60 logger.warning( 

61 "Using default W8A8 Block FP8 kernel config. Performance might " 

62 "be sub-optimal! Config file not found at %s", 

63 cfg_file, 

64 ) 

65 return None 

66 

67 

68def _get_placeholder_tuner_configs(pre_hook=None): 

69 # Placeholder config for libtuner initialization before runtime shapes are known. 

70 return [ 

71 triton.Config( 

72 { 

73 "BLOCK_M": 64, 

74 "BLOCK_N": 64, 

75 "BLOCK_K": 128, 

76 "GROUP_M": 8, 

77 }, 

78 num_stages=3, 

79 num_warps=4, 

80 pre_hook=pre_hook, 

81 ) 

82 ] 

83 

84 

85def _get_fixed_matmul_meta(M: int, N: int, K: int, block_n: int, block_k: int): 

86 configs = get_w8a8_block_fp8_hopper_configs(N, K) 

87 if not configs: 

88 return { 

89 "BLOCK_M": 64, 

90 "BLOCK_N": block_n, 

91 "BLOCK_K": block_k, 

92 "GROUP_M": 32, 

93 "num_warps": 4, 

94 "num_stages": 2, 

95 } 

96 

97 config = configs[min(configs.keys(), key=lambda x: abs(x - M))] 

98 return { 

99 "BLOCK_M": config["BLOCK_SIZE_M"], 

100 "BLOCK_N": config["BLOCK_SIZE_N"], 

101 "BLOCK_K": config["BLOCK_SIZE_K"], 

102 "GROUP_M": config["GROUP_SIZE_M"], 

103 "num_warps": config["num_warps"], 

104 "num_stages": config["num_stages"], 

105 } 

106 

107 

108@libentry() 

109@libtuner( 

110 configs=_get_placeholder_tuner_configs(pre_hook=None), 

111 key=["M", "N", "K", "stride_am", "stride_bk"], 

112 strategy=["align32", "align32", "align32", "align32", "align32"], 

113 warmup=5, 

114 rep=5, 

115 flagtune_op_name="w8a8_block_fp8_matmul", 

116 flagtune_expand_op_name="w8a8_block_fp8_general", 

117 flagtune_yaml_path=EXPAND_CONFIG_FILENAME, 

118 flagtune_pre_hook=None, 

119) 

120@triton.jit 

121def w8a8_block_fp8_matmul_kernel_general( 

122 A, 

123 B, 

124 C, 

125 As, 

126 Bs, 

127 M, 

128 N, 

129 K, 

130 group_n, 

131 group_k, 

132 stride_am, 

133 stride_ak, 

134 stride_bk, 

135 stride_bn, 

136 stride_cm, 

137 stride_cn, 

138 stride_As_m, 

139 stride_As_k, 

140 stride_Bs_k, 

141 stride_Bs_n, 

142 BLOCK_M: tl.constexpr, 

143 BLOCK_N: tl.constexpr, 

144 BLOCK_K: tl.constexpr, 

145 GROUP_M: tl.constexpr, 

146): 

147 pid = tl.program_id(axis=0) 

148 num_pid_m = tl.cdiv(M, BLOCK_M) 

149 num_pid_n = tl.cdiv(N, BLOCK_N) 

150 num_pid_in_group = GROUP_M * num_pid_n 

151 group_id = pid // num_pid_in_group 

152 first_pid_m = group_id * GROUP_M 

153 group_size_m = min(num_pid_m - first_pid_m, GROUP_M) 

154 pid_m = first_pid_m + (pid % group_size_m) 

155 pid_n = (pid % num_pid_in_group) // group_size_m 

156 

157 offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M 

158 offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N 

159 offs_k = tl.arange(0, BLOCK_K) 

160 a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) 

161 b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) 

162 

163 As_ptrs = As + offs_am * stride_As_m 

164 offs_bsn = offs_bn // group_n 

165 Bs_ptrs = Bs + offs_bsn * stride_Bs_n 

166 

167 acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) 

168 for k in range(0, tl.cdiv(K, BLOCK_K)): 

169 a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0) 

170 b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0) 

171 

172 k_start = k * BLOCK_K 

173 offs_ks = k_start // group_k 

174 a_s = tl.load(As_ptrs + offs_ks * stride_As_k) 

175 b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) 

176 acc += tl.dot(a, b, out_dtype=tl.float32) * a_s[:, None] * b_s[None, :] 

177 a_ptrs += BLOCK_K * stride_ak 

178 b_ptrs += BLOCK_K * stride_bk 

179 

180 if C.dtype.element_ty == tl.bfloat16: 

181 c = acc.to(tl.bfloat16) 

182 elif C.dtype.element_ty == tl.float16: 

183 c = acc.to(tl.float16) 

184 else: 

185 c = acc.to(tl.float32) 

186 

187 offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) 

188 offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) 

189 c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] 

190 c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) 

191 tl.store(c_ptrs, c, mask=c_mask) 

192 

193 

194@libentry() 

195@libtuner( 

196 configs=_get_placeholder_tuner_configs(pre_hook=None), 

197 key=["M", "N", "K", "stride_am", "stride_bk"], 

198 strategy=["align32", "align32", "align32", "align32", "align32"], 

199 warmup=5, 

200 rep=5, 

201 flagtune_op_name="w8a8_block_fp8_matmul", 

202 flagtune_expand_op_name="w8a8_block_fp8_general_splitk", 

203 flagtune_yaml_path=EXPAND_CONFIG_FILENAME, 

204 flagtune_pre_hook=None, 

205) 

206@triton.jit 

207def w8a8_block_fp8_matmul_kernel_splitk( 

208 A, 

209 B, 

210 C, 

211 As, 

212 Bs, 

213 M, 

214 N, 

215 K, 

216 group_n, 

217 group_k, 

218 stride_am, 

219 stride_ak, 

220 stride_bk, 

221 stride_bn, 

222 stride_cm, 

223 stride_cn, 

224 stride_As_m, 

225 stride_As_k, 

226 stride_Bs_k, 

227 stride_Bs_n, 

228 BLOCK_M: tl.constexpr, 

229 BLOCK_N: tl.constexpr, 

230 BLOCK_K: tl.constexpr, 

231 SPLIT_K: tl.constexpr, 

232): 

233 pid = tl.program_id(0) 

234 pid_k = tl.program_id(1) 

235 

236 # grid_m = tl.cdiv(M, BLOCK_M) 

237 grid_n = tl.cdiv(N, BLOCK_N) 

238 pid_m = pid // grid_n 

239 pid_n = pid % grid_n 

240 

241 offset_am = pid_m * BLOCK_M 

242 offset_bn = pid_n * BLOCK_N 

243 offs_am = offset_am + tl.arange(0, BLOCK_M) 

244 offs_bn = offset_bn + tl.arange(0, BLOCK_N) 

245 

246 total_k_iters = tl.cdiv(K, BLOCK_K) 

247 k_per_split = tl.cdiv(total_k_iters, SPLIT_K) 

248 k_start = pid_k * k_per_split 

249 k_end = min((pid_k + 1) * k_per_split, total_k_iters) 

250 

251 acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) 

252 for k in range(k_start, k_end): 

253 offset_k = k * BLOCK_K 

254 offs_k = offset_k + tl.arange(0, BLOCK_K) 

255 

256 a = tl.load( 

257 A + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak, 

258 mask=(offs_am[:, None] < M) & (offs_k[None, :] < K), 

259 other=0.0, 

260 ) 

261 b = tl.load( 

262 B + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn, 

263 mask=(offs_k[:, None] < K) & (offs_bn[None, :] < N), 

264 other=0.0, 

265 ) 

266 

267 offs_ks = offset_k // group_k 

268 a_s = tl.load( 

269 As + offs_am * stride_As_m + offs_ks * stride_As_k, 

270 mask=offs_am < M, 

271 other=0.0, 

272 ) 

273 b_s = tl.load( 

274 Bs + offs_ks * stride_Bs_k + (offs_bn // group_n) * stride_Bs_n, 

275 mask=offs_bn < N, 

276 other=0.0, 

277 ) 

278 acc += tl.dot(a, b, out_dtype=tl.float32) * a_s[:, None] * b_s[None, :] 

279 

280 offs_cm = offset_am + tl.arange(0, BLOCK_M) 

281 offs_cn = offset_bn + tl.arange(0, BLOCK_N) 

282 c_ptrs = C + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn 

283 mask = (offs_cm < M)[:, None] & (offs_cn < N)[None, :] 

284 if C.dtype.element_ty == tl.bfloat16: 

285 tl.atomic_add(c_ptrs, acc.to(tl.bfloat16), mask=mask) 

286 elif C.dtype.element_ty == tl.float16: 

287 tl.atomic_add(c_ptrs, acc.to(tl.float16), mask=mask) 

288 else: 

289 tl.atomic_add(c_ptrs, acc.to(tl.float32), mask=mask) 

290 

291 

292def general_w8a8_block_fp8_matmul(a, b, c, a_s, b_s, M, N, K, group_n, group_k): 

293 logger.debug( 

294 "GEMS w8a8_block_fp8_matmul-hopper, [scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), " 

295 "[A column-major]: %s, [B column-major]: %s", 

296 M, 

297 N, 

298 K, 

299 a.stride(0) == 1, 

300 b.stride(0) == 1, 

301 ) 

302 

303 # Default W8A8 keeps the existing fixed-meta path. When explicitly included 

304 # in flag_gems.flagtune(...), launch through LibTuner so expanded configs 

305 # are selected by the same registry-driven mechanism used by mm. 

306 use_flagtune = runtime.flagtune_enabled("w8a8_block_fp8_matmul") 

307 

308 # Split-K path for small-N, large-K shapes 

309 if M < 2048 and N < 2112 and K >= 4096: 

310 if use_flagtune: 

311 splitk_grid = lambda META: ( 

312 triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), 

313 META["SPLIT_K"], 

314 ) 

315 c.zero_() 

316 with torch_device_fn.device(a.device): 

317 w8a8_block_fp8_matmul_kernel_splitk[splitk_grid]( 

318 a, 

319 b, 

320 c, 

321 a_s, 

322 b_s, 

323 M, 

324 N, 

325 K, 

326 group_n, 

327 group_k, 

328 a.stride(0), 

329 a.stride(1), 

330 b.stride(1), 

331 b.stride(0), 

332 c.stride(0), 

333 c.stride(1), 

334 a_s.stride(0), 

335 a_s.stride(1), 

336 b_s.stride(1), 

337 b_s.stride(0), 

338 ) 

339 else: 

340 SPLITK_BLOCK_K = 128 

341 SPLITK_BLOCK_M = 16 if M <= 16 else 64 

342 SPLITK_BLOCK_N = 64 if N > 256 else 32 

343 

344 grid_m = triton.cdiv(M, SPLITK_BLOCK_M) 

345 grid_n = triton.cdiv(N, SPLITK_BLOCK_N) 

346 grid_mn = grid_m * grid_n 

347 total_k_iters = triton.cdiv(K, SPLITK_BLOCK_K) 

348 

349 SM_COUNT = torch.cuda.get_device_properties(a.device).multi_processor_count 

350 split_k = min(total_k_iters, max(4, 2 * SM_COUNT // max(grid_mn, 1))) 

351 

352 c.zero_() 

353 splitk_grid = (grid_mn, split_k) 

354 

355 with torch_device_fn.device(a.device): 

356 w8a8_block_fp8_matmul_kernel_splitk.fn.fn[splitk_grid]( 

357 a, 

358 b, 

359 c, 

360 a_s, 

361 b_s, 

362 M, 

363 N, 

364 K, 

365 group_n, 

366 group_k, 

367 a.stride(0), 

368 a.stride(1), 

369 b.stride(1), 

370 b.stride(0), 

371 c.stride(0), 

372 c.stride(1), 

373 a_s.stride(0), 

374 a_s.stride(1), 

375 b_s.stride(1), 

376 b_s.stride(0), 

377 BLOCK_M=SPLITK_BLOCK_M, 

378 BLOCK_N=SPLITK_BLOCK_N, 

379 BLOCK_K=SPLITK_BLOCK_K, 

380 SPLIT_K=split_k, 

381 ) 

382 return c 

383 

384 else: 

385 grid = lambda meta: ( 

386 triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]), 

387 ) 

388 fixed_meta = ( 

389 None 

390 if use_flagtune 

391 else _get_fixed_matmul_meta(M, N, K, block_n=group_n, block_k=group_k) 

392 ) 

393 

394 def alloc_fn(size: int, align: int, stream: Optional[int]): 

395 return torch.empty(size, dtype=torch.int8, device=a.device) 

396 

397 triton.set_allocator(alloc_fn) 

398 if use_flagtune: 

399 launch = lambda: w8a8_block_fp8_matmul_kernel_general[grid]( 

400 a, 

401 b, 

402 c, 

403 a_s, 

404 b_s, 

405 M, 

406 N, 

407 K, 

408 group_n, 

409 group_k, 

410 a.stride(0), 

411 a.stride(1), 

412 b.stride(1), 

413 b.stride(0), 

414 c.stride(0), 

415 c.stride(1), 

416 a_s.stride(0), 

417 a_s.stride(1), 

418 b_s.stride(1), 

419 b_s.stride(0), 

420 ) 

421 else: 

422 launch = lambda: w8a8_block_fp8_matmul_kernel_general.fn.fn[grid]( 

423 a, 

424 b, 

425 c, 

426 a_s, 

427 b_s, 

428 M, 

429 N, 

430 K, 

431 group_n, 

432 group_k, 

433 a.stride(0), 

434 a.stride(1), 

435 b.stride(1), 

436 b.stride(0), 

437 c.stride(0), 

438 c.stride(1), 

439 a_s.stride(0), 

440 a_s.stride(1), 

441 b_s.stride(1), 

442 b_s.stride(0), 

443 **fixed_meta, 

444 ) 

445 

446 with torch_device_fn.device(a.device): 

447 launch() 

448 return c 

449 

450 

451def w8a8_block_fp8_matmul( 

452 A: torch.Tensor, 

453 B: torch.Tensor, 

454 As: torch.Tensor, 

455 Bs: torch.Tensor, 

456 block_size: List[int], 

457 output_dtype: torch.dtype = torch.float16, 

458) -> torch.Tensor: 

459 device = A.device 

460 assert len(block_size) == 2 

461 block_n, block_k = block_size 

462 

463 # handle non-contiguous inputs if necessary 

464 if A.ndim >= 2 and A.stride(-2) > 1 and A.stride(-1) > 1: 

465 A = A.contiguous() 

466 if B.ndim == 2 and B.stride(0) > 1 and B.stride(1) > 1: 

467 B = B.contiguous() 

468 if As.ndim >= 2 and As.stride(-2) > 1 and As.stride(-1) > 1: 

469 As = As.contiguous() 

470 if Bs.ndim == 2 and Bs.stride(0) > 1 and Bs.stride(1) > 1: 

471 Bs = Bs.contiguous() 

472 

473 # checks constraints 

474 assert A.shape[-1] == B.shape[-1], "incompatible dimensions" 

475 assert A.shape[:-1] == As.shape[:-1], "A and As dimensions mismatch" 

476 assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1], "invalid As shape" 

477 assert B.ndim == 2 and Bs.ndim == 2, "B and Bs must be 2D" 

478 

479 M = A.numel() // A.shape[-1] 

480 N, K = B.shape 

481 assert triton.cdiv(N, block_n) == Bs.shape[0], "invalid Bs N dimension" 

482 assert triton.cdiv(K, block_k) == Bs.shape[1], "invalid Bs K dimension" 

483 

484 # allocates output 

485 output_shape = A.shape[:-1] + (N,) 

486 c = torch.empty(output_shape, device=device, dtype=output_dtype) 

487 

488 a_2d = A.reshape(M, K) 

489 as_2d = As.reshape(M, As.shape[-1]) 

490 c_2d = c.reshape(M, N) 

491 

492 return general_w8a8_block_fp8_matmul( 

493 a_2d, 

494 B, 

495 c_2d, 

496 as_2d, 

497 Bs, 

498 M, 

499 N, 

500 K, 

501 block_n, 

502 block_k, 

503 ).reshape(c.shape)