Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/w8a8_block_fp8_matmul.py: 0%

246 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import functools 

2import logging 

3import os 

4from typing import Any, Dict, List, Optional 

5 

6import torch 

7import triton 

8import triton.language as tl 

9import yaml 

10 

11from flag_gems import runtime 

12from flag_gems.runtime import torch_device_fn 

13from flag_gems.utils import libentry, libtuner 

14 

# Module-level switch read by is_tma_compatible(); while it is False the
# descriptor-based (TMA) kernel is never selected.
TMA_ON = False

CACHE_USAGE_THRESHOLD = 0.8

# The logger name is spelled out explicitly rather than derived from __name__.
logger = logging.getLogger(
    "flag_gems.runtime.backend._nvidia.hopper.ops.w8a8_block_fp8_matmul"
)

# Normalized path of the tuner "expand" YAML, located one directory above
# this module.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(
        os.path.dirname(__file__),
        "..",
        "w8a8_block_fp8_matmul_hopper_expand.yaml",
    )
)

28 

29 

@functools.lru_cache
def get_w8a8_block_fp8_hopper_configs(N: int, K: int) -> Optional[Dict[int, Any]]:
    """Load the pre-tuned kernel configs for an ``(N, K)`` weight shape.

    The configs are read from ``w8a8_block_fp8_matmul_hopper.yaml`` one
    directory above this module. The file is indexed first by device name
    (spaces replaced with underscores; any name containing a part that starts
    with ``H20`` is looked up as ``NVIDIA_H20``) and then by the string
    ``"N,K"``. Each entry maps a key (converted with ``int``) to a 6-element
    list ``[BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, num_warps,
    num_stages]``.

    Results (including ``None``) are memoized per ``(N, K)`` by ``lru_cache``.

    Args:
        N: Size of the N dimension used in the ``"N,K"`` lookup key.
        K: Size of the K dimension used in the ``"N,K"`` lookup key.

    Returns:
        A dict mapping each integer key to its config dict, or ``None`` when
        the file is missing or holds no entry for this device and shape.
    """
    device_name = torch.cuda.get_device_name().replace(" ", "_")
    if any(part.startswith("H20") for part in device_name.split("_")):
        device_name = "NVIDIA_H20"

    cfg_file = os.path.join(
        os.path.dirname(__file__), "..", "w8a8_block_fp8_matmul_hopper.yaml"
    )

    if not os.path.exists(cfg_file):
        logger.warning(
            "Using default W8A8 Block FP8 kernel config. Performance might "
            "be sub-optimal! Config file not found at %s",
            cfg_file,
        )
        return None

    with open(cfg_file) as f:
        logger.info(
            "Using config from %s for W8A8 block FP8 kernel.",
            cfg_file,
        )
        # safe_load returns None for an empty document, and a key may be
        # present with a null value; treat both as "no data" instead of
        # raising AttributeError on ``None.get``.
        all_data = yaml.safe_load(f) or {}

    dev_data = all_data.get(device_name) or {}
    nk_data = dev_data.get(f"{N},{K}") or {}

    # Unpack each positional list into a named config dictionary.
    result = {
        int(m_key): {
            "BLOCK_SIZE_M": params[0],
            "BLOCK_SIZE_N": params[1],
            "BLOCK_SIZE_K": params[2],
            "GROUP_SIZE_M": params[3],
            "num_warps": params[4],
            "num_stages": params[5],
        }
        for m_key, params in nk_data.items()
    }
    return result or None

71 

72 

def _get_placeholder_tuner_configs(pre_hook=None):
    """Build the one-element config list given to libtuner as a stand-in.

    Args:
        pre_hook: Optional callable forwarded to ``triton.Config``.

    Returns:
        A list holding a single ``triton.Config``.
    """
    # Stand-in tile sizes so the tuner can be constructed before any runtime
    # shape is known.
    placeholder_meta = {
        "BLOCK_M": 64,
        "BLOCK_N": 64,
        "BLOCK_K": 128,
        "GROUP_M": 8,
    }
    placeholder = triton.Config(
        placeholder_meta,
        num_stages=3,
        num_warps=4,
        pre_hook=pre_hook,
    )
    return [placeholder]

88 

89 

def _get_fixed_matmul_meta(M: int, N: int, K: int, block_n: int, block_k: int):
    """Pick the launch meta-parameters used when the tuner is bypassed.

    Looks up the pre-tuned configs for ``(N, K)`` and selects the entry whose
    integer key is numerically closest to ``M``. When no configs are
    available, a default built from ``block_n`` and ``block_k`` is returned.

    Returns:
        A dict with keys ``BLOCK_M``, ``BLOCK_N``, ``BLOCK_K``, ``GROUP_M``,
        ``num_warps`` and ``num_stages``.
    """
    tuned = get_w8a8_block_fp8_hopper_configs(N, K)
    if tuned:
        nearest_key = min(tuned.keys(), key=lambda candidate: abs(candidate - M))
        chosen = tuned[nearest_key]
        # Translate the YAML field names into the kernel's meta-parameter names.
        field_for = {
            "BLOCK_M": "BLOCK_SIZE_M",
            "BLOCK_N": "BLOCK_SIZE_N",
            "BLOCK_K": "BLOCK_SIZE_K",
            "GROUP_M": "GROUP_SIZE_M",
            "num_warps": "num_warps",
            "num_stages": "num_stages",
        }
        return {meta_name: chosen[field] for meta_name, field in field_for.items()}

    # Fallback: tile N and K by the quantization block sizes.
    return {
        "BLOCK_M": 64,
        "BLOCK_N": block_n,
        "BLOCK_K": block_k,
        "GROUP_M": 32,
        "num_warps": 4,
        "num_stages": 2,
    }

111 

112 

113def is_tma_compatible(a, b, n, k): 

114 """ 

115 Check if tensors are compatible with TMA (Tensor Memory Accelerator). 

116 

117 TMA requires 128-bit (16-byte) alignment for memory access. 

118 For FP8 inputs (1 byte/element), both N and K must be multiples of 16 

119 to satisfy the 16-byte alignment requirement. 

120 

121 Args: 

122 a, b: Input tensors 

123 n, k: Matrix dimensions 

124 

125 Returns: 

126 bool: True if compatible with TMA's 128-bit alignment requirement 

127 """ 

128 return ( 

129 a.dtype in (torch.float8_e4m3fn, torch.float8_e5m2) 

130 and b.dtype == a.dtype 

131 and TMA_ON 

132 and n % 16 == 0 

133 and k % 16 == 0 

134 ) 

135 

136 

def matmul_tma_set_block_size_hook(nargs):
    """Set ``block_shape`` on the A, B and C descriptors from the tile sizes.

    Args:
        nargs: Mapping that provides ``BLOCK_M``, ``BLOCK_N``, ``BLOCK_K``,
            the ``A_ROW_MAJOR`` / ``B_ROW_MAJOR`` flags and the ``a_desc``,
            ``b_desc`` and ``c_desc`` descriptor objects. The descriptors are
            modified in place.
    """
    block_m = nargs["BLOCK_M"]
    block_n = nargs["BLOCK_N"]
    block_k = nargs["BLOCK_K"]

    # A: the tile is [BLOCK_M, BLOCK_K] when row-major, otherwise the
    # transposed shape.
    a_tile = [block_m, block_k] if nargs["A_ROW_MAJOR"] else [block_k, block_m]
    nargs["a_desc"].block_shape = a_tile

    # B row-major: B is stored as [N, K], and the kernel loads a
    # [BLOCK_N, BLOCK_K] tile before transposing it to [BLOCK_K, BLOCK_N].
    # B column-major: the descriptor is built on B.T with shape [K, N], so
    # the loaded tile already has layout [BLOCK_K, BLOCK_N].
    b_tile = [block_n, block_k] if nargs["B_ROW_MAJOR"] else [block_k, block_n]
    nargs["b_desc"].block_shape = b_tile

    nargs["c_desc"].block_shape = [block_m, block_n]

156 

157 

# Pointer-based block-scaled matmul kernel: C = (A @ B) with per-block scales
# As / Bs applied to every K-tile partial product. The tuner configs and
# strategy come from the expand YAML only when USE_FLAGTUNE=1; otherwise a
# placeholder config is registered.
@libentry()
@libtuner(
    configs=(
        runtime.ops_get_configs(
            "w8a8_block_fp8_general",
            pre_hook=None,
            yaml_path=EXPAND_CONFIG_FILENAME,
        )
        if os.environ.get("USE_FLAGTUNE") == "1"
        else _get_placeholder_tuner_configs(pre_hook=None)
    ),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=(
        runtime.get_expand_config(
            "w8a8_block_fp8_general", yaml_path=EXPAND_CONFIG_FILENAME
        )["strategy"]
        if os.environ.get("USE_FLAGTUNE") == "1"
        else ["align32", "align32", "align32", "align32", "align32"]
    ),
    warmup=5,
    rep=5,
)
@triton.jit
def w8a8_block_fp8_matmul_kernel_general(
    A,
    B,
    C,
    As,
    Bs,
    M,
    N,
    K,
    group_n,
    group_k,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    # Map the 1-D program id onto an output tile (pid_m, pid_n); programs are
    # ordered in groups of up to GROUP_M tile rows.
    program = tl.program_id(axis=0)
    tiles_m = tl.cdiv(M, BLOCK_M)
    tiles_n = tl.cdiv(N, BLOCK_N)
    programs_per_group = GROUP_M * tiles_n
    group_idx = program // programs_per_group
    group_first_m = group_idx * GROUP_M
    rows_in_group = min(tiles_m - group_first_m, GROUP_M)
    pid_m = group_first_m + (program % rows_in_group)
    pid_n = (program % programs_per_group) // rows_in_group

    # Row/column indices are wrapped modulo M / N; the loads below mask only
    # along K.
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Scale pointers: one A scale per row, one B scale per group of group_n
    # columns; the K-group offset is added inside the loop.
    As_ptrs = As + offs_am * stride_As_m
    Bs_ptrs = Bs + (offs_bn // group_n) * stride_Bs_n

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k_tile in range(0, tl.cdiv(K, BLOCK_K)):
        k_base = k_tile * BLOCK_K
        k_left = K - k_base
        a = tl.load(a_ptrs, mask=offs_k[None, :] < k_left, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_left, other=0.0)

        # Scale column for this K tile, taken from the tile's first K index.
        scale_k = k_base // group_k
        a_s = tl.load(As_ptrs + scale_k * stride_As_k)
        b_s = tl.load(Bs_ptrs + scale_k * stride_Bs_k)
        acc += tl.dot(a, b, out_dtype=tl.float32) * a_s[:, None] * b_s[None, :]

        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    # Cast the fp32 accumulator according to the output element type.
    if C.dtype.element_ty == tl.bfloat16:
        c = acc.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = acc.to(tl.float16)
    else:
        c = acc.to(tl.float32)

    # The store uses unwrapped indices with an explicit bounds mask.
    out_rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    out_cols = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = C + stride_cm * out_rows[:, None] + stride_cn * out_cols[None, :]
    in_bounds = (out_rows[:, None] < M) & (out_cols[None, :] < N)
    tl.store(c_ptrs, c, mask=in_bounds)

248 

249 

# Descriptor-based variant of the block-scaled matmul: A, B and C are accessed
# through a_desc / b_desc / c_desc, whose block shapes are set by
# matmul_tma_set_block_size_hook (registered as the tuner pre_hook).
@libentry()
@libtuner(
    configs=(
        runtime.ops_get_configs(
            "w8a8_block_fp8_general_tma",
            pre_hook=matmul_tma_set_block_size_hook,
            yaml_path=EXPAND_CONFIG_FILENAME,
        )
        if os.environ.get("USE_FLAGTUNE") == "1"
        else _get_placeholder_tuner_configs(pre_hook=matmul_tma_set_block_size_hook)
    ),
    key=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    strategy=(
        runtime.get_expand_config(
            "w8a8_block_fp8_general_tma", yaml_path=EXPAND_CONFIG_FILENAME
        )["strategy"]
        if os.environ.get("USE_FLAGTUNE") == "1"
        else ["align32", "align32", "align32", "align32", "align32", "default"]
    ),
    warmup=5,
    rep=5,
)
@triton.jit
def w8a8_block_fp8_matmul_kernel_host_tma(
    a_desc,
    b_desc,
    c_desc,
    As,
    Bs,
    M,
    N,
    K,
    group_n,
    group_k,
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_n,
    stride_Bs_k,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    A_ROW_MAJOR: tl.constexpr,
    B_ROW_MAJOR: tl.constexpr,
    dtype: tl.constexpr,
    enable_warp_specialization=True,
):
    # Map the 1-D program id onto an output tile; program ids are re-ordered
    # in groups of GROUP_M tile rows for better L2 performance.
    program = tl.program_id(0)
    tiles_m = tl.cdiv(M, BLOCK_M)
    tiles_n = tl.cdiv(N, BLOCK_N)
    group_width = GROUP_M * tiles_n
    group_idx = program // group_width
    rows_in_group = min(tiles_m - group_idx * GROUP_M, GROUP_M)
    pid_m = group_idx * GROUP_M + (program % rows_in_group)
    pid_n = (program % group_width) // rows_in_group

    # Tile origin (int32) and the per-element row/column indices used for the
    # scale loads.
    row_start = (pid_m * BLOCK_M).to(tl.int32)
    col_start = (pid_n * BLOCK_N).to(tl.int32)
    rows = row_start + tl.arange(0, BLOCK_M)
    cols = col_start + tl.arange(0, BLOCK_N)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    num_k_tiles = tl.cdiv(K, BLOCK_K)

    for k_tile in range(num_k_tiles):
        k_start = (k_tile * BLOCK_K).to(tl.int32)

        # A tile: loaded directly when row-major, otherwise loaded from the
        # transposed descriptor and transposed back.
        if A_ROW_MAJOR:
            a = a_desc.load([row_start, k_start])
        else:
            a = tl.trans(a_desc.load([k_start, row_start]))

        # B tile: a row-major B yields [BLOCK_N, BLOCK_K] and is transposed;
        # otherwise the descriptor already yields [BLOCK_K, BLOCK_N].
        if B_ROW_MAJOR:
            b = tl.trans(b_desc.load([col_start, k_start]))
        else:
            b = b_desc.load([k_start, col_start])

        # Per-row A scales and per-column-group B scales for this K tile,
        # masked to the valid rows/columns.
        scale_k = (k_start // group_k).to(tl.int32)
        a_s = tl.load(
            As + rows * stride_As_m + scale_k * stride_As_k,
            mask=rows < M,
            other=0.0,
        )
        b_s = tl.load(
            Bs + (cols // group_n) * stride_Bs_n + scale_k * stride_Bs_k,
            mask=cols < N,
            other=0.0,
        )
        acc += tl.dot(a, b, out_dtype=tl.float32) * a_s[:, None] * b_s[None, :]

    # Write the tile through the output descriptor in its element type.
    c_desc.store([row_start, col_start], acc.to(c_desc.dtype))

346 

347 

# Split-K variant: grid axis 0 enumerates output tiles, grid axis 1 enumerates
# K slices. Each program accumulates its slice and adds it into C atomically,
# so the host code in this file zeroes C before launching.
@libentry()
@libtuner(
    configs=(
        runtime.ops_get_configs(
            "w8a8_block_fp8_general_splitk",
            yaml_path=EXPAND_CONFIG_FILENAME,
        )
        if os.environ.get("USE_FLAGTUNE") == "1"
        else _get_placeholder_tuner_configs(pre_hook=None)
    ),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=(
        runtime.get_expand_config(
            "w8a8_block_fp8_general_splitk", yaml_path=EXPAND_CONFIG_FILENAME
        )["strategy"]
        if os.environ.get("USE_FLAGTUNE") == "1"
        else ["align32", "align32", "align32", "align32", "align32"]
    ),
    warmup=5,
    rep=5,
)
@triton.jit
def w8a8_block_fp8_matmul_kernel_splitk(
    A,
    B,
    C,
    As,
    Bs,
    M,
    N,
    K,
    group_n,
    group_k,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
):
    tile_id = tl.program_id(0)
    pid_k = tl.program_id(1)

    # Row-major enumeration of output tiles along grid axis 0.
    tiles_n = tl.cdiv(N, BLOCK_N)
    pid_m = tile_id // tiles_n
    pid_n = tile_id % tiles_n

    row_start = pid_m * BLOCK_M
    col_start = pid_n * BLOCK_N
    rows = row_start + tl.arange(0, BLOCK_M)
    cols = col_start + tl.arange(0, BLOCK_N)

    # This program handles K tiles [first_k_tile, last_k_tile); the range is
    # empty when pid_k lies past the last non-empty slice.
    num_k_tiles = tl.cdiv(K, BLOCK_K)
    tiles_per_split = tl.cdiv(num_k_tiles, SPLIT_K)
    first_k_tile = pid_k * tiles_per_split
    last_k_tile = min((pid_k + 1) * tiles_per_split, num_k_tiles)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k_tile in range(first_k_tile, last_k_tile):
        k_base = k_tile * BLOCK_K
        ks = k_base + tl.arange(0, BLOCK_K)

        a = tl.load(
            A + rows[:, None] * stride_am + ks[None, :] * stride_ak,
            mask=(rows[:, None] < M) & (ks[None, :] < K),
            other=0.0,
        )
        b = tl.load(
            B + ks[:, None] * stride_bk + cols[None, :] * stride_bn,
            mask=(ks[:, None] < K) & (cols[None, :] < N),
            other=0.0,
        )

        # Scales for this K tile, indexed by the tile's first K element.
        scale_k = k_base // group_k
        a_s = tl.load(
            As + rows * stride_As_m + scale_k * stride_As_k,
            mask=rows < M,
            other=0.0,
        )
        b_s = tl.load(
            Bs + scale_k * stride_Bs_k + (cols // group_n) * stride_Bs_n,
            mask=cols < N,
            other=0.0,
        )
        acc += tl.dot(a, b, out_dtype=tl.float32) * a_s[:, None] * b_s[None, :]

    # Accumulate this slice's partial result into C in the output dtype.
    c_ptrs = C + rows[:, None] * stride_cm + cols[None, :] * stride_cn
    in_bounds = (rows < M)[:, None] & (cols < N)[None, :]
    if C.dtype.element_ty == tl.bfloat16:
        tl.atomic_add(c_ptrs, acc.to(tl.bfloat16), mask=in_bounds)
    elif C.dtype.element_ty == tl.float16:
        tl.atomic_add(c_ptrs, acc.to(tl.float16), mask=in_bounds)
    else:
        tl.atomic_add(c_ptrs, acc.to(tl.float32), mask=in_bounds)

449 

450 

def general_w8a8_block_fp8_matmul(a, b, c, a_s, b_s, M, N, K, group_n, group_k):
    """Run a block-scaled FP8 matmul, choosing between three kernels.

    Selection order:

    1. Split-K kernel when ``N <= 512 and K == 7168 and M < 8276``.
    2. Descriptor (TMA) kernel when ``is_tma_compatible`` accepts the inputs
       and ``triton.tools.tensor_descriptor.TensorDescriptor`` is importable.
    3. The general pointer-based kernel otherwise.

    With ``USE_FLAGTUNE=1`` the kernels are launched through their tuner
    wrappers; otherwise the underlying JIT functions (``.fn.fn``) are launched
    directly with fixed meta-parameters.

    Args:
        a: 2-D left operand; its strides are passed as ``(stride_am, stride_ak)``.
        b: 2-D right operand; ``b.stride(0)`` is passed as the N stride and
            ``b.stride(1)`` as the K stride.
        c: 2-D output buffer; written in place (zeroed first on the split-K
            path).
        a_s: 2-D scales for ``a``.
        b_s: 2-D scales for ``b``.
        M, N, K: Problem sizes.
        group_n, group_k: Quantization group sizes along N and K.

    Returns:
        ``c``.
    """
    logger.debug(
        "GEMS w8a8_block_fp8_matmul-hopper, [scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )

    use_flagtune = os.environ.get("USE_FLAGTUNE") == "1"

    # Positional arguments shared by the pointer-based kernels (general and
    # split-K). Both declare the B strides as (stride_bk, stride_bn) and the
    # B-scale strides as (stride_Bs_k, stride_Bs_n), hence the swapped order.
    pointer_args = (
        a,
        b,
        c,
        a_s,
        b_s,
        M,
        N,
        K,
        group_n,
        group_k,
        a.stride(0),
        a.stride(1),
        b.stride(1),
        b.stride(0),
        c.stride(0),
        c.stride(1),
        a_s.stride(0),
        a_s.stride(1),
        b_s.stride(1),
        b_s.stride(0),
    )

    # Split-K path for small-N, large-K shapes
    if N <= 512 and K == 7168 and M < 8276:
        # The split-K kernel accumulates with atomic adds, so C must start
        # at zero.
        c.zero_()
        if use_flagtune:
            # NOTE(review): if the tuner benchmarks several configs against
            # this same output buffer, the atomic adds would accumulate
            # across runs - confirm libtuner resets or clones outputs.
            def splitk_grid(meta):
                return (
                    triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),
                    meta["SPLIT_K"],
                )

            with torch_device_fn.device(a.device):
                w8a8_block_fp8_matmul_kernel_splitk[splitk_grid](*pointer_args)
        else:
            splitk_block_k = 128
            splitk_block_m = 16 if M <= 16 else 64
            splitk_block_n = 64 if N > 256 else 32

            grid_mn = triton.cdiv(M, splitk_block_m) * triton.cdiv(N, splitk_block_n)
            total_k_iters = triton.cdiv(K, splitk_block_k)

            # Use at least 4 K slices, more when there are few output tiles
            # relative to twice the SM count, capped at the number of K tiles.
            sm_count = torch.cuda.get_device_properties(a.device).multi_processor_count
            split_k = min(total_k_iters, max(4, 2 * sm_count // max(grid_mn, 1)))

            with torch_device_fn.device(a.device):
                w8a8_block_fp8_matmul_kernel_splitk.fn.fn[(grid_mn, split_k)](
                    *pointer_args,
                    BLOCK_M=splitk_block_m,
                    BLOCK_N=splitk_block_n,
                    BLOCK_K=splitk_block_k,
                    SPLIT_K=split_k,
                )
        return c

    def grid(meta):
        return (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),)

    fixed_meta = (
        None
        if use_flagtune
        else _get_fixed_matmul_meta(M, N, K, block_n=group_n, block_k=group_k)
    )

    # Resolve TensorDescriptor only when the TMA path is applicable. A guarded
    # import is used instead of ``hasattr(triton.tools.tensor_descriptor, ...)``
    # because attribute access on the parent package raises AttributeError
    # when the submodule has not been imported yet (or does not exist).
    tensor_descriptor_cls = None
    if is_tma_compatible(a, b, N, K):
        try:
            # triton 3.5.0
            from triton.tools.tensor_descriptor import (
                TensorDescriptor as tensor_descriptor_cls,
            )
        except ImportError:
            # Not available in this Triton build; fall back to the general
            # kernel below.
            pass

    if tensor_descriptor_cls is not None:
        a_row_major = a.stride(1) == 1
        b_row_major = b.stride(1) == 1
        # Placeholder block shape; the real one is set by
        # matmul_tma_set_block_size_hook before the kernel runs.
        dummy_block = [1, 1]

        # Column-major operands are described through their transposed views.
        a_view = a if a_row_major else a.T
        b_view = b if b_row_major else b.T
        a_desc = tensor_descriptor_cls(
            a_view, a_view.shape, a_view.stride(), dummy_block
        )
        b_desc = tensor_descriptor_cls(
            b_view, b_view.shape, b_view.stride(), dummy_block
        )
        c_desc = tensor_descriptor_cls(c, c.shape, c.stride(), dummy_block)

        # The TMA kernel declares B strides as (stride_bn, stride_bk) and the
        # B-scale strides as (stride_Bs_n, stride_Bs_k): natural order here.
        tma_args = (
            a_desc,
            b_desc,
            c_desc,
            a_s,
            b_s,
            M,
            N,
            K,
            group_n,
            group_k,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            a_s.stride(0),
            a_s.stride(1),
            b_s.stride(0),
            b_s.stride(1),
        )
        tma_kwargs = {
            "A_ROW_MAJOR": a_row_major,
            "B_ROW_MAJOR": b_row_major,
            "dtype": str(a.dtype).split(".")[-1],
        }

        if use_flagtune:
            kernel = w8a8_block_fp8_matmul_kernel_host_tma
        else:
            # The fixed-config path bypasses libtuner, so we must apply the
            # descriptor block-shape update that would normally run via the
            # TMA pre_hook before launching the underlying JIT kernel.
            matmul_tma_set_block_size_hook(
                {
                    "BLOCK_M": fixed_meta["BLOCK_M"],
                    "BLOCK_N": fixed_meta["BLOCK_N"],
                    "BLOCK_K": fixed_meta["BLOCK_K"],
                    "a_desc": a_desc,
                    "b_desc": b_desc,
                    "c_desc": c_desc,
                    "A_ROW_MAJOR": a_row_major,
                    "B_ROW_MAJOR": b_row_major,
                }
            )
            kernel = w8a8_block_fp8_matmul_kernel_host_tma.fn.fn
            tma_kwargs.update(fixed_meta)

        with torch_device_fn.device(a.device):
            kernel[grid](*tma_args, **tma_kwargs)
        return c

    def alloc_fn(size: int, align: int, stream: Optional[int]):
        return torch.empty(size, dtype=torch.int8, device=a.device)

    # Registers a process-wide Triton allocator backed by torch on a's device.
    triton.set_allocator(alloc_fn)

    with torch_device_fn.device(a.device):
        if use_flagtune:
            w8a8_block_fp8_matmul_kernel_general[grid](*pointer_args)
        else:
            w8a8_block_fp8_matmul_kernel_general.fn.fn[grid](
                *pointer_args, **fixed_meta
            )
    return c

697 

698 

def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Multiply block-quantized ``A`` by block-quantized ``B`` with scales.

    Args:
        A: Activation tensor of shape ``[..., K]``.
        B: 2-D weight tensor of shape ``[N, K]``.
        As: Scales for ``A``; same leading dims as ``A`` and
            ``cdiv(K, block_k)`` entries along the last dim.
        Bs: 2-D scales for ``B`` of shape
            ``[cdiv(N, block_n), cdiv(K, block_k)]``.
        block_size: ``[block_n, block_k]`` quantization block sizes.
        output_dtype: Dtype of the returned tensor.

    Returns:
        Tensor of shape ``A.shape[:-1] + (N,)`` in ``output_dtype``, allocated
        on ``A.device``.

    Raises:
        AssertionError: If ``block_size`` does not have two entries or the
            operand/scale shapes are inconsistent.
    """
    device = A.device
    assert len(block_size) == 2
    block_n, block_k = block_size

    def _no_unit_stride(t: torch.Tensor) -> bool:
        # True when neither of the last two dims has stride 1; such a tensor
        # is copied to a contiguous layout below.
        return t.stride(-2) > 1 and t.stride(-1) > 1

    # handle non-contiguous inputs if necessary
    if A.ndim >= 2 and _no_unit_stride(A):
        A = A.contiguous()
    if B.ndim == 2 and _no_unit_stride(B):
        B = B.contiguous()
    if As.ndim >= 2 and _no_unit_stride(As):
        As = As.contiguous()
    if Bs.ndim == 2 and _no_unit_stride(Bs):
        Bs = Bs.contiguous()

    # checks constraints
    assert A.shape[-1] == B.shape[-1], "incompatible dimensions"
    assert A.shape[:-1] == As.shape[:-1], "A and As dimensions mismatch"
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1], "invalid As shape"
    assert B.ndim == 2 and Bs.ndim == 2, "B and Bs must be 2D"

    # Every leading dimension of A is folded into a single M dimension.
    M = A.numel() // A.shape[-1]
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0], "invalid Bs N dimension"
    assert triton.cdiv(K, block_k) == Bs.shape[1], "invalid Bs K dimension"

    # allocates output
    c = torch.empty(A.shape[:-1] + (N,), device=device, dtype=output_dtype)

    # Hand the dispatcher 2-D versions of A, As and the output.
    result_2d = general_w8a8_block_fp8_matmul(
        A.reshape(M, K),
        B,
        c.reshape(M, N),
        As.reshape(M, As.shape[-1]),
        Bs,
        M,
        N,
        K,
        block_n,
        block_k,
    )
    return result_2d.reshape(c.shape)