Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/w8a8_block_fp8_matmul.py: 0%

178 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import functools 

2import logging 

3import os 

4from typing import Any, Dict, List, Optional 

5 

6import torch 

7import triton 

8import triton.language as tl 

9import yaml 

10 

11from flag_gems import runtime 

12from flag_gems.runtime import torch_device_fn 

13from flag_gems.utils import libentry, libtuner 

14 

# Module logger, registered under an explicit dotted name instead of __name__.
logger = logging.getLogger(
    "flag_gems.runtime.backend._nvidia.hopper.ops.w8a8_block_fp8_matmul"
)

# NOTE(review): not read by any code visible in this module; presumably
# consumed elsewhere — confirm before removing.
CACHE_USAGE_THRESHOLD = 0.8

# Normalized path of the FlagTune "expand" YAML, located one directory above
# this file; consumed by the libtuner decorators when USE_FLAGTUNE=1.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(
        os.path.dirname(__file__), "..", "w8a8_block_fp8_matmul_hopper_expand.yaml"
    )
)

26 

27 

@functools.lru_cache
def get_w8a8_block_fp8_hopper_configs(N: int, K: int) -> Optional[Dict[int, Any]]:
    """Load tuned W8A8 block-FP8 launch configs for the ``(N, K)`` shape.

    Configs are read from ``w8a8_block_fp8_matmul_hopper.yaml`` in the parent
    directory of this module. The YAML is indexed first by the current CUDA
    device name (spaces replaced by ``_``) and then by the string key
    ``"N,K"``. Each entry under that key maps an M value to a 6-element list
    ``[BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, num_warps,
    num_stages]``.

    Results are memoized per ``(N, K)`` by ``functools.lru_cache``; the device
    name is not part of the cache key.

    Args:
        N: First component of the ``"N,K"`` lookup key.
        K: Second component of the ``"N,K"`` lookup key.

    Returns:
        A dict mapping ``int(M)`` to a config dict with the six keys above, or
        ``None`` when the file is missing, empty, or has no entry for this
        device and ``(N, K)``.
    """
    device_name = torch.cuda.get_device_name().replace(" ", "_")
    file_name = "w8a8_block_fp8_matmul_hopper.yaml"

    cfg_file = os.path.join(os.path.dirname(__file__), "..", file_name)

    if not os.path.exists(cfg_file):
        logger.warning(
            "Using default W8A8 Block FP8 kernel config. Performance might "
            "be sub-optimal! Config file not found at %s",
            cfg_file,
        )
        return None

    with open(cfg_file) as f:
        logger.info(
            "Using config from %s for W8A8 block FP8 kernel.",
            cfg_file,
        )
        # safe_load returns None for an empty document, and a key with no
        # value also loads as None. Normalize each level to {} so those cases
        # mean "no config" instead of raising AttributeError.
        all_data = yaml.safe_load(f) or {}

    dev_data = all_data.get(device_name) or {}
    NK_data = dev_data.get(f"{N},{K}") or {}

    result = {}
    for m_key, p in NK_data.items():
        # unpack the positional list into a named dictionary
        result[int(m_key)] = {
            "BLOCK_SIZE_M": p[0],
            "BLOCK_SIZE_N": p[1],
            "BLOCK_SIZE_K": p[2],
            "GROUP_SIZE_M": p[3],
            "num_warps": p[4],
            "num_stages": p[5],
        }

    return result or None

66 

67 

def _get_placeholder_tuner_configs(pre_hook=None):
    """Return the stand-in config list used to initialize libtuner.

    The kernels are decorated at import time, before any runtime shape is
    known, so a single fixed ``triton.Config`` is supplied. It uses
    BLOCK_M=64, BLOCK_N=64, BLOCK_K=128, GROUP_M=8, 4 warps and 3 stages.

    Args:
        pre_hook: Optional hook forwarded unchanged to ``triton.Config``.

    Returns:
        A one-element list containing that ``triton.Config``.
    """
    tile_meta = dict(BLOCK_M=64, BLOCK_N=64, BLOCK_K=128, GROUP_M=8)
    placeholder = triton.Config(
        tile_meta,
        num_warps=4,
        num_stages=3,
        pre_hook=pre_hook,
    )
    return [placeholder]

83 

84 

def _get_fixed_matmul_meta(M: int, N: int, K: int, block_n: int, block_k: int):
    """Return fixed meta-parameters for launching the general kernel untuned.

    If tuned configs exist for ``(N, K)`` (from
    ``get_w8a8_block_fp8_hopper_configs``), the entry whose M key is closest
    to ``M`` is used. Otherwise the default is BLOCK_M=64, BLOCK_N=block_n,
    BLOCK_K=block_k, GROUP_M=32, 4 warps and 2 stages.

    Args:
        M, N, K: GEMM problem size.
        block_n: Quantization block size along N.
        block_k: Quantization block size along K.

    Returns:
        Dict with ``BLOCK_M``, ``BLOCK_N``, ``BLOCK_K``, ``GROUP_M``,
        ``num_warps`` and ``num_stages`` keys, ready to be splatted into the
        kernel launch.
    """
    configs = get_w8a8_block_fp8_hopper_configs(N, K)
    if not configs:
        return {
            "BLOCK_M": 64,
            "BLOCK_N": block_n,
            "BLOCK_K": block_k,
            "GROUP_M": 32,
            "num_warps": 4,
            "num_stages": 2,
        }

    # Pick the tuned entry whose M bucket is nearest to the actual M.
    config = configs[min(configs.keys(), key=lambda x: abs(x - M))]

    # The general kernel applies one K-scale column per BLOCK_K tile
    # (k_start // group_k), so a tile must lie within a single quantization
    # group: block_k must be a multiple of BLOCK_K. Tuned entries are keyed
    # only by (N, K), not by block size, so guard against a mismatch.
    tile_k = config["BLOCK_SIZE_K"]
    if block_k % tile_k != 0:
        tile_k = block_k

    return {
        "BLOCK_M": config["BLOCK_SIZE_M"],
        "BLOCK_N": config["BLOCK_SIZE_N"],
        "BLOCK_K": tile_k,
        "GROUP_M": config["GROUP_SIZE_M"],
        "num_warps": config["num_warps"],
        "num_stages": config["num_stages"],
    }

106 

107 

# Block-scaled FP8 GEMM kernel (general tiled path).
#
# Each program computes one BLOCK_M x BLOCK_N tile of C.
# - A is read as an (M, K) matrix through (stride_am, stride_ak).
# - B is read as a (K, N) matrix through (stride_bk, stride_bn).
#   general_w8a8_block_fp8_matmul passes B's strides swapped, so an (N, K)
#   weight tensor is consumed transposed.
# - Dequantization scales are looked up per quantization group:
#   As[m, k // group_k] and Bs[k // group_k, n // group_n], both addressed
#   through the stride arguments.
# - The scales multiply each per-tile tl.dot result before it is added to the
#   fp32 accumulator.
#
# NOTE: only one K-group index (k_start // group_k) is used per BLOCK_K tile,
# so results are correct only when group_k is a multiple of BLOCK_K.
#
# Tuning: the libtuner arguments are evaluated once, at import time.
# - With USE_FLAGTUNE=1, candidate configs and the bucketing strategy come
#   from EXPAND_CONFIG_FILENAME.
# - Otherwise a single placeholder config is registered, and the host wrapper
#   bypasses the tuner by launching `.fn.fn` with explicit meta-parameters.
# The tuning key includes stride_am/stride_bk, so different operand layouts
# are tuned separately.
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "w8a8_block_fp8_general",
        pre_hook=None,
        yaml_path=EXPAND_CONFIG_FILENAME,
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else _get_placeholder_tuner_configs(pre_hook=None),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=runtime.get_expand_config(
        "w8a8_block_fp8_general", yaml_path=EXPAND_CONFIG_FILENAME
    )["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def w8a8_block_fp8_matmul_kernel_general(
    A,
    B,
    C,
    As,
    Bs,
    M,
    N,
    K,
    group_n,
    group_k,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    # Map the 1-D program id onto (pid_m, pid_n), walking GROUP_M tile-rows at
    # a time. This is the grouped ordering from the Triton matmul tutorial,
    # presumably to improve L2 cache reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Out-of-range rows/columns wrap around with %, so every A/B/scale load
    # stays in bounds. The duplicated values are discarded by c_mask at store.
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Per-row activation-scale pointers and per-column weight-scale pointers
    # (each output column maps to weight-scale group n // group_n).
    As_ptrs = As + offs_am * stride_As_m
    offs_bsn = offs_bn // group_n
    Bs_ptrs = Bs + offs_bsn * stride_Bs_n

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # Only the K tail needs masking; M/N were wrapped above.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)

        # One scale column per tile: the tile's first K index selects the group.
        k_start = k * BLOCK_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
        acc += tl.dot(a, b, out_dtype=tl.float32) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    # Cast the fp32 accumulator to C's element type (bf16/fp16, else fp32).
    if C.dtype.element_ty == tl.bfloat16:
        c = acc.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = acc.to(tl.float16)
    else:
        c = acc.to(tl.float32)

    # Recompute un-wrapped offsets so the store can mask off the padding.
    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

198 

199 

# Block-scaled FP8 GEMM kernel (split-K path).
#
# Program layout:
# - Axis 0 enumerates BLOCK_M x BLOCK_N output tiles in row-major order
#   (pid_m = pid // grid_n).
# - Axis 1 (pid_k) selects one of SPLIT_K contiguous chunks of the
#   cdiv(K, BLOCK_K) K-iterations. A program whose chunk starts past the end
#   runs zero iterations and adds zeros.
#
# Accumulation:
# - Each program accumulates its chunk in fp32, casts it to C's element type,
#   and tl.atomic_add's it into C. C must therefore be zeroed before launch
#   (general_w8a8_block_fp8_matmul calls c.zero_() first).
# - Partial sums are rounded to C's dtype before being added, and the order of
#   atomic adds across pid_k is not fixed. Low-precision outputs may therefore
#   differ slightly from a single-pass accumulation, and between runs.
#
# Indexing:
# - A is read as (M, K) via (stride_am, stride_ak) and B as (K, N) via
#   (stride_bk, stride_bn).
# - Scales are As[m, k // group_k] and Bs[k // group_k, n // group_n].
# - Only one K-group index is used per BLOCK_K tile, so group_k must be a
#   multiple of BLOCK_K.
# - Out-of-range rows/columns are masked (not wrapped).
#
# NOTE(review): the non-FlagTune placeholder config defines GROUP_M, which
# this kernel does not take (it expects SPLIT_K). That is harmless only because
# the non-FlagTune path launches `.fn.fn` with explicit meta-parameters.
# NOTE(review): with USE_FLAGTUNE=1 the tuner benchmarks this kernel by
# launching it repeatedly, and no reset/pre-hook is passed here. Confirm that
# libtuner re-zeroes C between trials; otherwise the tuning call may return
# accumulated results.
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "w8a8_block_fp8_general_splitk",
        yaml_path=EXPAND_CONFIG_FILENAME,
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else _get_placeholder_tuner_configs(pre_hook=None),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=runtime.get_expand_config(
        "w8a8_block_fp8_general_splitk", yaml_path=EXPAND_CONFIG_FILENAME
    )["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def w8a8_block_fp8_matmul_kernel_splitk(
    A,
    B,
    C,
    As,
    Bs,
    M,
    N,
    K,
    group_n,
    group_k,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
):
    pid = tl.program_id(0)
    pid_k = tl.program_id(1)

    # grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // grid_n
    pid_n = pid % grid_n

    offset_am = pid_m * BLOCK_M
    offset_bn = pid_n * BLOCK_N
    offs_am = offset_am + tl.arange(0, BLOCK_M)
    offs_bn = offset_bn + tl.arange(0, BLOCK_N)

    # This program's half-open range [k_start, k_end) of BLOCK_K iterations.
    total_k_iters = tl.cdiv(K, BLOCK_K)
    k_per_split = tl.cdiv(total_k_iters, SPLIT_K)
    k_start = pid_k * k_per_split
    k_end = min((pid_k + 1) * k_per_split, total_k_iters)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(k_start, k_end):
        offset_k = k * BLOCK_K
        offs_k = offset_k + tl.arange(0, BLOCK_K)

        a = tl.load(
            A + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak,
            mask=(offs_am[:, None] < M) & (offs_k[None, :] < K),
            other=0.0,
        )
        b = tl.load(
            B + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn,
            mask=(offs_k[:, None] < K) & (offs_bn[None, :] < N),
            other=0.0,
        )

        # One scale column per tile, selected by the tile's first K index.
        offs_ks = offset_k // group_k
        a_s = tl.load(
            As + offs_am * stride_As_m + offs_ks * stride_As_k,
            mask=offs_am < M,
            other=0.0,
        )
        b_s = tl.load(
            Bs + offs_ks * stride_Bs_k + (offs_bn // group_n) * stride_Bs_n,
            mask=offs_bn < N,
            other=0.0,
        )
        acc += tl.dot(a, b, out_dtype=tl.float32) * a_s[:, None] * b_s[None, :]

    # Accumulate this split's partial tile into C (which must start at zero).
    offs_cm = offset_am + tl.arange(0, BLOCK_M)
    offs_cn = offset_bn + tl.arange(0, BLOCK_N)
    c_ptrs = C + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
    mask = (offs_cm < M)[:, None] & (offs_cn < N)[None, :]
    if C.dtype.element_ty == tl.bfloat16:
        tl.atomic_add(c_ptrs, acc.to(tl.bfloat16), mask=mask)
    elif C.dtype.element_ty == tl.float16:
        tl.atomic_add(c_ptrs, acc.to(tl.float16), mask=mask)
    else:
        tl.atomic_add(c_ptrs, acc.to(tl.float32), mask=mask)

301 

302 

def general_w8a8_block_fp8_matmul(a, b, c, a_s, b_s, M, N, K, group_n, group_k):
    """Launch the block-scaled FP8 GEMM on 2-D operands and return ``c``.

    The kernels read ``a`` as (M, K) and, because ``b``/``b_s`` strides are
    passed swapped, read ``b`` (stored N-major, e.g. (N, K) as prepared by
    ``w8a8_block_fp8_matmul``) as a (K, N) operand. ``c`` is written in place.

    Shapes with M < 2048, N < 2112 and K >= 4096 use the split-K kernel. It
    zeroes ``c`` and accumulates partial tiles with atomic adds. All other
    shapes use the general tiled kernel.

    With ``USE_FLAGTUNE=1`` the libtuner-wrapped kernels pick their own
    meta-parameters. Otherwise the raw Triton kernels (``.fn.fn``) are
    launched with heuristic or YAML-tuned meta-parameters.

    Args:
        a, b, c: Activation, weight and output tensors.
        a_s, b_s: Activation and weight dequantization scales.
        M, N, K: GEMM problem size.
        group_n, group_k: Quantization block sizes along N and K.

    Returns:
        ``c``.
    """
    logger.debug(
        "GEMS w8a8_block_fp8_matmul-hopper, [scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )

    use_flagtune = os.environ.get("USE_FLAGTUNE") == "1"

    # Positional arguments shared by both kernels and all launch modes. B and
    # Bs strides are swapped so the kernels see them as (K, N)-indexed.
    kernel_args = (
        a,
        b,
        c,
        a_s,
        b_s,
        M,
        N,
        K,
        group_n,
        group_k,
        a.stride(0),
        a.stride(1),
        b.stride(1),
        b.stride(0),
        c.stride(0),
        c.stride(1),
        a_s.stride(0),
        a_s.stride(1),
        b_s.stride(1),
        b_s.stride(0),
    )

    # Split-K path for small-M, small-N, large-K shapes.
    if M < 2048 and N < 2112 and K >= 4096:
        # The split-K kernel accumulates into c with atomic adds.
        c.zero_()
        if use_flagtune:
            splitk_grid = lambda META: (
                triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
                META["SPLIT_K"],
            )
            with torch_device_fn.device(a.device):
                w8a8_block_fp8_matmul_kernel_splitk[splitk_grid](*kernel_args)
        else:
            # The kernel applies one K-scale column per BLOCK_K tile, so a
            # tile must not straddle a quantization group. Use the largest
            # power of two (capped at 128) that divides group_k; for the
            # common group_k == 128 this is 128.
            SPLITK_BLOCK_K = min(128, group_k & -group_k)
            SPLITK_BLOCK_M = 16 if M <= 16 else 64
            SPLITK_BLOCK_N = 64 if N > 256 else 32

            grid_m = triton.cdiv(M, SPLITK_BLOCK_M)
            grid_n = triton.cdiv(N, SPLITK_BLOCK_N)
            grid_mn = grid_m * grid_n
            total_k_iters = triton.cdiv(K, SPLITK_BLOCK_K)

            # Target about 2 * SM_COUNT programs in total, with at least 4
            # splits and never more splits than there are K iterations.
            SM_COUNT = torch.cuda.get_device_properties(a.device).multi_processor_count
            split_k = min(total_k_iters, max(4, 2 * SM_COUNT // max(grid_mn, 1)))

            splitk_grid = (grid_mn, split_k)
            with torch_device_fn.device(a.device):
                w8a8_block_fp8_matmul_kernel_splitk.fn.fn[splitk_grid](
                    *kernel_args,
                    BLOCK_M=SPLITK_BLOCK_M,
                    BLOCK_N=SPLITK_BLOCK_N,
                    BLOCK_K=SPLITK_BLOCK_K,
                    SPLIT_K=split_k,
                )
        return c

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),
    )
    # Resolved before entering the device context, as before, so the config
    # lookup sees the same current device as it always has.
    fixed_meta = (
        None
        if use_flagtune
        else _get_fixed_matmul_meta(M, N, K, block_n=group_n, block_k=group_k)
    )

    def alloc_fn(size: int, align: int, stream: Optional[int]):
        return torch.empty(size, dtype=torch.int8, device=a.device)

    triton.set_allocator(alloc_fn)
    with torch_device_fn.device(a.device):
        if use_flagtune:
            w8a8_block_fp8_matmul_kernel_general[grid](*kernel_args)
        else:
            w8a8_block_fp8_matmul_kernel_general.fn.fn[grid](
                *kernel_args,
                **fixed_meta,
            )
    return c

457 

458 

def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Block-scaled W8A8 FP8 matmul: returns ``A @ B.T`` with per-block scales.

    Args:
        A: Activations of shape ``(..., K)``.
        B: 2-D weight of shape ``(N, K)``.
        As: Activation scales of shape ``(..., cdiv(K, block_k))``; the leading
            dims must match ``A``.
        Bs: 2-D weight scales of shape ``(cdiv(N, block_n), cdiv(K, block_k))``.
        block_size: ``[block_n, block_k]`` quantization block sizes.
        output_dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``A.shape[:-1] + (N,)`` and dtype ``output_dtype``.

    Raises:
        AssertionError: on inconsistent shapes or a malformed ``block_size``.
    """
    device = A.device
    assert len(block_size) == 2
    block_n, block_k = block_size

    # Copy only operands whose last two dims both have stride > 1 (neither
    # row- nor column-major); the kernels accept arbitrary strides otherwise.
    if A.ndim >= 2 and A.stride(-2) > 1 and A.stride(-1) > 1:
        A = A.contiguous()
    if B.ndim == 2 and B.stride(0) > 1 and B.stride(1) > 1:
        B = B.contiguous()
    if As.ndim >= 2 and As.stride(-2) > 1 and As.stride(-1) > 1:
        As = As.contiguous()
    if Bs.ndim == 2 and Bs.stride(0) > 1 and Bs.stride(1) > 1:
        Bs = Bs.contiguous()

    # checks constraints
    assert A.shape[-1] == B.shape[-1], "incompatible dimensions"
    assert A.shape[:-1] == As.shape[:-1], "A and As dimensions mismatch"
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1], "invalid As shape"
    assert B.ndim == 2 and Bs.ndim == 2, "B and Bs must be 2D"

    # Flatten all leading dims into M. Taking the product of the leading shape
    # (rather than A.numel() // A.shape[-1]) avoids a ZeroDivisionError when
    # K == 0.
    M = A.shape[:-1].numel()
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0], "invalid Bs N dimension"
    assert triton.cdiv(K, block_k) == Bs.shape[1], "invalid Bs K dimension"

    # allocates output
    output_shape = A.shape[:-1] + (N,)
    c = torch.empty(output_shape, device=device, dtype=output_dtype)

    a_2d = A.reshape(M, K)
    as_2d = As.reshape(M, As.shape[-1])
    c_2d = c.reshape(M, N)

    return general_w8a8_block_fp8_matmul(
        a_2d,
        B,
        c_2d,
        as_2d,
        Bs,
        M,
        N,
        K,
        block_n,
        block_k,
    ).reshape(c.shape)