Coverage for src/flag_gems/runtime/backend/_mthreads/ops/w8a8_block_fp8_matmul.py: 0%

168 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2import os 

3from typing import List 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import libentry, libtuner 

12from flag_gems.utils import triton_lang_extension as tle 

13 

14from .utils import create_tma_device_descriptor, get_cached_tma_device_descriptor 

15 

# Module logger; the name is spelled out explicitly rather than derived from
# ``__name__``.
logger = logging.getLogger("flag_gems.runtime.backend._mthreads.ops.w8a8_block_fp8_matmul")

# YAML file with the expanded tuning configs, located one directory above the
# directory that contains this module.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "..", "w8a8_block_fp8_matmul_mthreads_expand.yaml")
)

# Switch read by ``is_sqmma_compatible``; while it is False the SQMMA path is
# never selected.
SQMMA_ON = False

28 

29 

def is_supported_sqmma_layout(tensor):
    """Return whether ``tensor`` has a memory layout the SQMMA path accepts.

    A tensor qualifies when it is contiguous, or when its first stride is 1
    and its second stride equals its first dimension (i.e. a column-major
    2-D layout).
    """
    if tensor.is_contiguous():
        return True
    return tensor.stride(0) == 1 and tensor.stride(1) == tensor.shape[0]

34 

35 

def is_sqmma_compatible(a, b, output_dtype, n, k):
    """Decide whether the SQMMA kernel may be used for this problem.

    All of the following must hold: both operands are 2-D, the module-level
    ``SQMMA_ON`` switch is enabled, both operands are ``torch.float8_e4m3fn``,
    the output dtype is float16 or bfloat16, both operands pass
    ``is_supported_sqmma_layout``, and ``n`` and ``k`` are multiples of 16.

    The checks run in a fixed order and stop at the first failure, so later
    checks are never evaluated on operands that already failed an earlier one.
    """
    if a.dim() != 2 or not SQMMA_ON or b.dim() != 2:
        return False
    if not (a.dtype == b.dtype == torch.float8_e4m3fn):
        return False
    if output_dtype not in (torch.float16, torch.bfloat16):
        return False
    if not (is_supported_sqmma_layout(a) and is_supported_sqmma_layout(b)):
        return False
    return n % 16 == 0 and k % 16 == 0

48 

49 

def get_triton_type(elem_type):
    """Translate a torch dtype into the matching Triton language dtype.

    Only float16, bfloat16, float32 and float8_e4m3fn are mapped; any other
    dtype yields ``None``.
    """
    torch_to_triton = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float32: tl.float32,
        torch.float8_e4m3fn: tl.float8e4nv,
    }
    return torch_to_triton.get(elem_type)

58 

59 

def matmul_get_configs():
    """Return the single built-in tuning config for the general kernel.

    The decorator on ``w8a8_block_fp8_matmul_kernel`` uses this list whenever
    the ``USE_FLAGTUNE`` environment variable is not set to ``"1"``.
    """
    tile_meta = {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_M": 8}
    return [triton.Config(tile_meta, num_stages=3, num_warps=4)]

73 

74 

# The decorator arguments are evaluated once, when this module is imported:
# with USE_FLAGTUNE=1 the configs and the key strategy come from the YAML file
# named by EXPAND_CONFIG_FILENAME, otherwise the built-in defaults are used.
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "w8a8_block_fp8_general", pre_hook=None, yaml_path=EXPAND_CONFIG_FILENAME
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else matmul_get_configs(),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=runtime.get_expand_config(
        "w8a8_block_fp8_general", yaml_path=EXPAND_CONFIG_FILENAME
    )["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def w8a8_block_fp8_matmul_kernel(
    A,
    B,
    C,
    As,
    Bs,
    M,
    N,
    K,
    group_n,
    group_k,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """Compute one BLOCK_M x BLOCK_N tile of C from A, B and their scales.

    Each program accumulates, over K in steps of BLOCK_K,
    ``dot(a_tile, b_tile) * a_scale[:, None] * b_scale[None, :]`` in float32
    and stores the result cast to C's element type.

    A is addressed as [m, k] via (stride_am, stride_ak) and B as [k, n] via
    (stride_bk, stride_bn). As is addressed as [m, k // group_k] and Bs as
    [k // group_k, n // group_n] through their own strides.
    """
    # Map the 1-D program id onto a (pid_m, pid_n) output tile. Programs are
    # numbered so that GROUP_M rows of tiles are walked before moving on.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    # The last group may contain fewer than GROUP_M tile rows.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Row/column indices are wrapped with `% M` / `% N`, so the loads below
    # need no M/N mask; out-of-range rows/columns are discarded by the store
    # mask at the end.
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # One A scale per output row; one B scale per group of `group_n` columns.
    As_ptrs = As + offs_am * stride_As_m
    offs_bsn = offs_bn // group_n
    Bs_ptrs = Bs + offs_bsn * stride_Bs_n

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # Mask only the K tail; masked elements contribute 0 to the dot.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)

        # A single scale group is chosen per K tile, from the tile's first
        # element.
        # NOTE(review): this assumes a BLOCK_K tile never straddles a
        # `group_k` boundary - confirm for every tuned config.
        k_start = k * BLOCK_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    # Cast the float32 accumulator to C's element type (float32 fallback).
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    # Unwrapped indices here, so the mask drops rows/columns past M/N.
    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

163 

164 

def sqmma_descriptor_pre_hook(nargs):
    """Fill the three TMA descriptor buffers before an SQMMA kernel launch.

    ``nargs`` is the kernel-argument mapping; it must hold the operands
    (``A``, ``B``, ``C``), the tile sizes (``BLOCK_M``, ``BLOCK_N``,
    ``BLOCK_K``) and the descriptor buffers (``a_desc_ptr``, ``b_desc_ptr``,
    ``c_desc_ptr``). Descriptors for A and B come from the cached helper,
    while the one for C is created fresh on every call. All descriptors are
    built for the device that holds C.
    """
    lhs = nargs["A"]
    rhs = nargs["B"]
    out = nargs["C"]
    tile_m = nargs["BLOCK_M"]
    tile_n = nargs["BLOCK_N"]
    tile_k = nargs["BLOCK_K"]
    target_device = out.device

    nargs["a_desc_ptr"].copy_(
        get_cached_tma_device_descriptor(lhs, tile_m, tile_k, target_device)
    )
    nargs["b_desc_ptr"].copy_(
        get_cached_tma_device_descriptor(rhs, tile_k, tile_n, target_device)
    )
    nargs["c_desc_ptr"].copy_(
        create_tma_device_descriptor(out, tile_m, tile_n, target_device)
    )

181 

182 

def sqmma_get_configs(pre_hook=sqmma_descriptor_pre_hook):
    """Return the single built-in tuning config for the SQMMA kernel.

    ``pre_hook`` is attached to the config; by default it is
    ``sqmma_descriptor_pre_hook``.
    """
    tile_meta = {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_M": 8}
    return [triton.Config(tile_meta, num_stages=3, num_warps=4, pre_hook=pre_hook)]

197 

198 

# The decorator arguments are evaluated once, when this module is imported:
# with USE_FLAGTUNE=1 the configs and the key strategy come from the YAML file
# named by EXPAND_CONFIG_FILENAME, otherwise the built-in defaults are used.
# Every config carries sqmma_descriptor_pre_hook.
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "w8a8_block_fp8_general_tma",
        pre_hook=sqmma_descriptor_pre_hook,
        yaml_path=EXPAND_CONFIG_FILENAME,
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else sqmma_get_configs(),
    key=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    strategy=runtime.get_expand_config(
        "w8a8_block_fp8_general_tma", yaml_path=EXPAND_CONFIG_FILENAME
    )["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
)
@triton.jit
def w8a8_block_fp8_matmul_sqmma_kernel(
    A,
    B,
    C,
    As,
    Bs,
    a_desc_ptr,
    b_desc_ptr,
    c_desc_ptr,
    M,
    N,
    K,
    group_n,
    group_k,
    stride_am,
    stride_bk,
    stride_As_m,
    stride_As_k,
    stride_Bs_n,
    stride_Bs_k,
    dtype: tl.constexpr,
    input_dtype: tl.constexpr,
    output_dtype: tl.constexpr,
    GROUP_M: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    is_transpose_a: tl.constexpr = False,
    is_transpose_b: tl.constexpr = True,
):
    """Descriptor-based variant of the block-scaled matmul tile kernel.

    A and B tiles are fetched through ``a_desc_ptr`` / ``b_desc_ptr`` and the
    output tile is written through ``c_desc_ptr``. ``A``, ``B``, ``C``,
    ``stride_am``, ``stride_bk`` and ``dtype`` are not referenced in the body;
    the first three are read by ``sqmma_descriptor_pre_hook`` and the last
    three appear in the tuner key.
    """
    # Map the 1-D program id onto a (pid_m, pid_n) output tile, walking
    # GROUP_M rows of tiles at a time; the last group may be smaller.
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // group_size

    # Scalar int32 tile origins, used as descriptor coordinates.
    offs_am = (pid_m * BLOCK_M).to(tl.int32)
    offs_bn = (pid_n * BLOCK_N).to(tl.int32)
    offs_k = tl.zeros((), dtype=tl.int32)

    # Per-element row/column indices, used only for the scale loads.
    row_offset = offs_am + tl.arange(0, BLOCK_M)
    col_offset = offs_bn + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    tme_load_input_dtype = input_dtype
    c_store_dtype = output_dtype

    for _ in range(0, tl.cdiv(K, BLOCK_K)):
        # NOTE(review): the trailing transpose flag on the descriptor load
        # looks like an extension of the Triton build used by this backend -
        # confirm against that build's API.
        a = tl._experimental_descriptor_load(
            a_desc_ptr,
            [offs_am, offs_k],
            [BLOCK_M, BLOCK_K],
            tme_load_input_dtype,
            is_transpose_a,
        )
        b = tl._experimental_descriptor_load(
            b_desc_ptr,
            [offs_k, offs_bn],
            [BLOCK_K, BLOCK_N],
            tme_load_input_dtype,
            is_transpose_b,
        )

        # A single scale group is chosen per K tile, from the tile's first
        # element.
        # NOTE(review): assumes a BLOCK_K tile never straddles a `group_k`
        # boundary - confirm for every tuned config.
        scale_k = offs_k // group_k
        # Rows/columns past M/N read a scale of 0.0 instead of touching memory.
        a_s = tl.load(
            As + row_offset * stride_As_m + scale_k * stride_As_k,
            mask=row_offset < M,
            other=0.0,
        )
        b_s = tl.load(
            Bs + (col_offset // group_n) * stride_Bs_n + scale_k * stride_Bs_k,
            mask=col_offset < N,
            other=0.0,
        )
        acc += (
            tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
            * a_s[:, None]
            * b_s[None, :]
        )
        offs_k += BLOCK_K

    # No explicit mask on the store.
    # NOTE(review): presumably the descriptor clips writes past M/N - confirm.
    tl._experimental_descriptor_store(
        c_desc_ptr, acc.to(c_store_dtype), [offs_am, offs_bn]
    )

304 

305 

def general_w8a8_block_fp8_matmul(
    a,
    b,
    c,
    a_s,
    b_s,
    M,
    N,
    K,
    group_n,
    group_k,
):
    """Launch the pointer-based block-scaled matmul kernel and return ``c``.

    ``c`` is written in place. The strides of ``b`` and ``b_s`` are handed to
    the kernel swapped (dim 1 as the K stride, dim 0 as the N stride), so the
    kernel addresses ``b`` as [k, n] and ``b_s`` as [k_group, n_group].
    """
    logger.debug(
        "GEMS_MTHREADS W8A8_BLOCK_FP8_MATMUL(general), [shape info]: [-, %s, %s, %s](batch, M, N, K)",
        M,
        N,
        K,
    )

    def grid(meta):
        # One program per output tile, flattened into a 1-D launch grid.
        tiles_m = triton.cdiv(M, meta["BLOCK_M"])
        tiles_n = triton.cdiv(N, meta["BLOCK_N"])
        return (tiles_m * tiles_n,)

    stride_am, stride_ak = a.stride(0), a.stride(1)
    stride_bk, stride_bn = b.stride(1), b.stride(0)
    stride_cm, stride_cn = c.stride(0), c.stride(1)
    stride_as_m, stride_as_k = a_s.stride(0), a_s.stride(1)
    stride_bs_k, stride_bs_n = b_s.stride(1), b_s.stride(0)

    with torch_device_fn.device(a.device):
        w8a8_block_fp8_matmul_kernel[grid](
            a,
            b,
            c,
            a_s,
            b_s,
            M,
            N,
            K,
            group_n,
            group_k,
            stride_am,
            stride_ak,
            stride_bk,
            stride_bn,
            stride_cm,
            stride_cn,
            stride_as_m,
            stride_as_k,
            stride_bs_k,
            stride_bs_n,
        )
    return c

352 

353 

def sqmma_w8a8_block_fp8_matmul(
    a,
    b,
    c,
    a_s,
    b_s,
    M,
    N,
    K,
    group_n,
    group_k,
):
    """Launch the descriptor-based (SQMMA) block-scaled matmul and return ``c``.

    ``c`` is written in place. Column-major operands are passed through
    unchanged and signalled to the kernel via the transpose flags; any other
    non-contiguous operand is copied to a contiguous tensor first.
    """
    logger.debug(
        "GEMS_MTHREADS W8A8_BLOCK_FP8_MATMUL(sqmma), [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    device = a.device

    # A: not transposed unless it is stored column-major.
    is_transpose_a = False
    if not a.is_contiguous():
        a_is_col_major = a.stride(0) == 1 and a.stride(1) == a.shape[0]
        if a_is_col_major:
            is_transpose_a = True
        else:
            a = a.contiguous()

    # B: treated as transposed unless it is stored column-major.
    is_transpose_b = True
    if not b.is_contiguous():
        b_is_col_major = b.stride(0) == 1 and b.stride(1) == b.shape[0]
        if b_is_col_major:
            is_transpose_b = False
        else:
            b = b.contiguous()

    # 64-byte scratch buffers that the config pre-hook fills with descriptors.
    desc_a, desc_b, desc_c = (
        torch.empty((64,), dtype=torch.int8, device=device) for _ in range(3)
    )

    def grid(meta):
        tiles_m = triton.cdiv(M, meta["BLOCK_M"])
        tiles_n = triton.cdiv(N, meta["BLOCK_N"])
        return (tiles_m * tiles_n, 1, 1)

    with torch_device_fn.device(device):
        w8a8_block_fp8_matmul_sqmma_kernel[grid](
            a,
            b,
            c,
            a_s,
            b_s,
            desc_a,
            desc_b,
            desc_c,
            M,
            N,
            K,
            group_n,
            group_k,
            a.stride(0),
            b.stride(1),
            a_s.stride(0),
            a_s.stride(1),
            b_s.stride(0),
            b_s.stride(1),
            dtype=str(a.dtype).rpartition(".")[2],
            input_dtype=get_triton_type(a.dtype),
            output_dtype=get_triton_type(c.dtype),
            is_transpose_a=is_transpose_a,
            is_transpose_b=is_transpose_b,
        )
    return c

429 

430 

def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Multiply block-quantized ``A`` by block-quantized ``B`` with scales.

    Args:
        A: Left operand of shape ``(..., K)``.
        B: Right operand of shape ``(N, K)``; must be 2-D.
        As: Scales for ``A`` with shape ``A.shape[:-1] + (cdiv(K, block_k),)``.
        Bs: Scales for ``B`` with shape ``(cdiv(N, block_n), cdiv(K, block_k))``.
        block_size: ``[block_n, block_k]`` quantization group sizes.
        output_dtype: Dtype of the returned tensor.

    Returns:
        A new tensor of shape ``A.shape[:-1] + (N,)`` and dtype ``output_dtype``
        allocated on ``A``'s device.

    Raises:
        AssertionError: If ``block_size`` does not have two entries or the
            operand/scale shapes are inconsistent.

    The ``MUSA_ENABLE_SQMMA`` environment variable is forced to ``"1"`` for
    the duration of the call and restored (or removed) afterwards.
    """
    device = A.device
    assert len(block_size) == 2
    block_n, block_k = block_size

    # Tensors whose last two strides are both > 1 are copied to a contiguous
    # layout; row- or column-major inputs pass through untouched.
    if A.ndim >= 2 and A.stride(-2) > 1 and A.stride(-1) > 1:
        A = A.contiguous()
    if B.ndim == 2 and B.stride(0) > 1 and B.stride(1) > 1:
        B = B.contiguous()
    if As.ndim >= 2 and As.stride(-2) > 1 and As.stride(-1) > 1:
        As = As.contiguous()
    if Bs.ndim == 2 and Bs.stride(0) > 1 and Bs.stride(1) > 1:
        Bs = Bs.contiguous()

    assert A.shape[-1] == B.shape[-1], "incompatible dimensions"
    assert A.shape[:-1] == As.shape[:-1], "A and As dimensions mismatch"
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1], "invalid As shape"
    assert B.ndim == 2 and Bs.ndim == 2, "B and Bs must be 2D"

    # Collapse every leading dimension of A into a single row dimension M.
    M = A.numel() // A.shape[-1]
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0], "invalid Bs N dimension"
    assert triton.cdiv(K, block_k) == Bs.shape[1], "invalid Bs K dimension"

    out = torch.empty(A.shape[:-1] + (N,), device=device, dtype=output_dtype)

    a_mat = A.reshape(M, K)
    as_mat = As.reshape(M, As.shape[-1])
    out_mat = out.reshape(M, N)

    saved_flag = os.environ.get("MUSA_ENABLE_SQMMA")
    os.environ["MUSA_ENABLE_SQMMA"] = "1"
    try:
        if is_sqmma_compatible(a_mat, B, output_dtype, N, K):
            matmul_impl = sqmma_w8a8_block_fp8_matmul
        else:
            matmul_impl = general_w8a8_block_fp8_matmul
        result = matmul_impl(
            a_mat,
            B,
            out_mat,
            as_mat,
            Bs,
            M,
            N,
            K,
            block_n,
            block_k,
        )
        return result.reshape(out.shape)
    finally:
        # Put the environment back exactly as it was found.
        if saved_flag is not None:
            os.environ["MUSA_ENABLE_SQMMA"] = saved_flag
        else:
            os.environ.pop("MUSA_ENABLE_SQMMA", None)