Coverage for src/flag_gems/runtime/backend/_mthreads/ops/w8a8_block_fp8_matmul.py: 0%

141 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2import os 

3from typing import List 

4 

5import torch 

6import triton 

7import triton.language as tl 

8from triton.tools.tensor_descriptor import TensorDescriptor 

9 

10from flag_gems import runtime 

11from flag_gems.runtime import torch_device_fn 

12from flag_gems.utils import libentry, libtuner 

13from flag_gems.utils import triton_lang_extension as ext 

14 

# Explicit dotted name so log records from this op are routed under the
# backend's ops namespace regardless of how the module was imported.
logger: logging.Logger = logging.getLogger(
    "flag_gems.runtime.backend._mthreads.ops.w8a8_block_fp8_matmul"
)

# Expanded tuning-config YAML, located one directory above this file.
EXPAND_CONFIG_FILENAME: str = os.path.normpath(
    os.path.join(
        os.path.dirname(__file__), "..", "w8a8_block_fp8_matmul_mthreads_expand.yaml"
    )
)

# Master switch for the tensor-descriptor (SQMMA) path; while False,
# is_sqmma_compatible() always rejects it and the general kernel is used.
SQMMA_ON: bool = False

27 

28 

def is_supported_sqmma_layout(tensor):
    """Return True for a row-major contiguous tensor or a packed column-major 2D view.

    The column-major case (``stride(0) == 1`` and ``stride(1) == shape[0]``)
    is exactly the layout of ``t.T`` for a contiguous 2D ``t``.
    """
    if tensor.is_contiguous():
        return True
    packed_column_major = tensor.stride(0) == 1 and tensor.stride(1) == tensor.shape[0]
    return packed_column_major

33 

34 

def is_sqmma_compatible(a, b, output_dtype, n, k):
    """Decide whether the tensor-descriptor (SQMMA) kernel may handle this GEMM.

    Requires the SQMMA_ON switch, 2D float8_e4m3fn operands in a supported
    layout, an fp16/bf16 output, and ``n``/``k`` divisible by 16.
    """
    # Checks run in the same order as the original short-circuit chain.
    if a.dim() != 2 or not SQMMA_ON or b.dim() != 2:
        return False
    if not (a.dtype == b.dtype == torch.float8_e4m3fn):
        return False
    if output_dtype not in (torch.float16, torch.bfloat16):
        return False
    if not is_supported_sqmma_layout(a) or not is_supported_sqmma_layout(b):
        return False
    # Presumably a descriptor alignment requirement (16 fp8 elements = 16 bytes
    # per row stride) -- TODO confirm against the MUSA TensorDescriptor rules.
    return n % 16 == 0 and k % 16 == 0

47 

48 

def matmul_get_configs():
    """Return the autotuning search space for the general kernel (one fixed config)."""
    tile_params = {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_M": 8}
    only_config = triton.Config(tile_params, num_stages=3, num_warps=4)
    return [only_config]

62 

63 

# General (pointer-based) block-scaled FP8 GEMM. For each BLOCK_K chunk of K it
# accumulates tl.dot(A_chunk, B_chunk) scaled by the A row scales and the B
# column-group scales for that chunk, then casts the fp32 result to C's dtype.
#
# Autotuning is keyed on the problem size plus A's row stride and B's K
# stride; "align32" is the libtuner strategy for each key (its exact bucketing
# lives in flag_gems.utils and is not visible here).
@libentry()
@libtuner(
    configs=matmul_get_configs(),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def w8a8_block_fp8_matmul_kernel(
    A,  # addressed as (m, k) via stride_am / stride_ak
    B,  # addressed as (k, n) via stride_bk / stride_bn
    C,  # (m, n) output, addressed via stride_cm / stride_cn
    As,  # A scales, read at [m, k_start // group_k]
    Bs,  # B scales, read at [k_start // group_k, n // group_n] via stride_Bs_k / stride_Bs_n
    M,
    N,
    K,
    group_n,  # N extent covered by one Bs scale
    group_k,  # K extent covered by one As/Bs scale
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,  # row-tiles per group in the launch order
):
    # Map the 1D program id to an output tile. Consecutive pids walk down the
    # GROUP_M row-tiles of a group before moving to the next column tile (the
    # grouped ordering from the Triton matmul tutorial, for cache reuse).
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    # The last group may contain fewer than GROUP_M row-tiles.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Row/column indices past M/N wrap around via % so every load stays in
    # bounds; those lanes are dropped by c_mask at the store.
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Scale pointers: one A scale per row, one B scale per group_n columns.
    As_ptrs = As + offs_am * stride_As_m
    offs_bsn = offs_bn // group_n
    Bs_ptrs = Bs + offs_bsn * stride_Bs_n

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # Zero-fill the K tail of the final chunk.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)

        # One scale per chunk, looked up at the chunk's first K index. This
        # assumes a chunk never straddles two group_k scale blocks (group_k a
        # multiple of BLOCK_K) -- TODO confirm callers guarantee this.
        k_start = k * BLOCK_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    # Cast the fp32 accumulator to C's element type (bf16 / fp16, else fp32).
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    # Store with the unwrapped offsets and mask off rows/cols outside M/N.
    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

144 

145 

def sqmma_descriptor_pre_hook(nargs):
    """Resize each tensor descriptor's tile to the BLOCK_* sizes about to be launched.

    ``nargs`` maps kernel argument names (descriptors and BLOCK_* meta
    parameters) to their values; each descriptor's ``block_shape`` is
    overwritten in place: A -> [M, K], B -> [K, N], C -> [M, N] tiles.
    """
    tile_axes = (
        ("a_desc", "BLOCK_M", "BLOCK_K"),
        ("b_desc", "BLOCK_K", "BLOCK_N"),
        ("c_desc", "BLOCK_M", "BLOCK_N"),
    )
    for desc_key, rows_key, cols_key in tile_axes:
        nargs[desc_key].block_shape = [nargs[rows_key], nargs[cols_key]]

150 

151 

# Tensor-descriptor (SQMMA) variant of the block-scaled FP8 GEMM.
#
# Config source: with USE_FLAGTUNE=1 the search space and per-key strategy
# come from the "w8a8_block_fp8_general_tma" entry of EXPAND_CONFIG_FILENAME;
# otherwise a single fixed config is used. Every config carries
# sqmma_descriptor_pre_hook so the descriptors' block shapes follow the
# BLOCK_* values being launched.
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "w8a8_block_fp8_general_tma",
        pre_hook=sqmma_descriptor_pre_hook,
        yaml_path=EXPAND_CONFIG_FILENAME,
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else [
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_M": 8},
            num_stages=3,
            num_warps=4,
            pre_hook=sqmma_descriptor_pre_hook,
        )
    ],
    key=["M", "N", "K"],
    # Only the first three strategy entries are used, one per key above.
    strategy=runtime.get_expand_config(
        "w8a8_block_fp8_general_tma", yaml_path=EXPAND_CONFIG_FILENAME
    )["strategy"][:3]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def w8a8_block_fp8_matmul_sqmma_kernel(
    a_desc,  # tiles [BLOCK_M, BLOCK_K] loaded at (m, k)
    b_desc,  # tiles [BLOCK_K, BLOCK_N] loaded at (k, n): dim 0 of the described tensor is K
    c_desc,  # tiles [BLOCK_M, BLOCK_N] stored at (m, n)
    As,  # A scales, read at [m, k // group_k]
    Bs,  # B scales, read at [n // group_n, k // group_k]
    M,
    N,
    K,
    group_n,
    group_k,
    stride_As_m,
    stride_As_k,
    stride_Bs_n,  # NOTE: N stride precedes K stride here, unlike the general kernel
    stride_Bs_k,
    GROUP_M: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Grouped pid -> (pid_m, pid_n) mapping, same scheme as the general kernel.
    # ext.program_id is presumably flag_gems' wrapper around tl.program_id.
    pid = ext.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // group_size

    # Tile origin coordinates for the descriptor loads/stores (int32 scalars),
    # plus a running K coordinate.
    offs_am = (pid_m * BLOCK_M).to(tl.int32)
    offs_bn = (pid_n * BLOCK_N).to(tl.int32)
    offs_k = tl.zeros((), dtype=tl.int32)

    # Per-element row/column indices, used only for the scale lookups.
    row_offset = offs_am + tl.arange(0, BLOCK_M)
    col_offset = offs_bn + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for _ in range(0, tl.cdiv(K, BLOCK_K)):
        # No explicit masks here: tiles crossing M/N/K are presumably
        # zero-filled by the descriptor load -- TODO confirm for this backend.
        a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k])
        b = tl.load_tensor_descriptor(b_desc, [offs_k, offs_bn])

        # One scale per chunk at its first K index (assumes group_k is a
        # multiple of BLOCK_K). Rows/cols past M/N get scale 0.0.
        scale_k = offs_k // group_k
        a_s = tl.load(
            As + row_offset * stride_As_m + scale_k * stride_As_k,
            mask=row_offset < M,
            other=0.0,
        )
        b_s = tl.load(
            Bs + (col_offset // group_n) * stride_Bs_n + scale_k * stride_Bs_k,
            mask=col_offset < N,
            other=0.0,
        )
        acc += (
            tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
            * a_s[:, None]
            * b_s[None, :]
        )
        offs_k += BLOCK_K

    # Cast to the output descriptor's dtype and store the whole tile; the
    # out-of-range part is presumably clipped by the descriptor store.
    tl.store_tensor_descriptor(c_desc, [offs_am, offs_bn], acc.to(c_desc.dtype))

238 

239 

def general_w8a8_block_fp8_matmul(
    a,
    b,
    c,
    a_s,
    b_s,
    M,
    N,
    K,
    group_n,
    group_k,
):
    """Launch the pointer-based block-scaled FP8 kernel and return ``c``.

    ``b`` is (N, K) and ``b_s`` is (N-groups, K-groups) as produced by
    w8a8_block_fp8_matmul, but the kernel takes their K stride before their
    N stride, so both stride pairs are passed in swapped order.
    """
    logger.debug(
        "GEMS_MTHREADS W8A8_BLOCK_FP8_MATMUL(general), [shape info]: [-, %s, %s, %s](batch, M, N, K)",
        M,
        N,
        K,
    )

    def grid(meta):
        # One program per BLOCK_M x BLOCK_N output tile, flattened to 1D.
        tiles_m = triton.cdiv(M, meta["BLOCK_M"])
        tiles_n = triton.cdiv(N, meta["BLOCK_N"])
        return (tiles_m * tiles_n,)

    with torch_device_fn.device(a.device):
        w8a8_block_fp8_matmul_kernel[grid](
            a, b, c, a_s, b_s,
            M, N, K,
            group_n, group_k,
            a.stride(0), a.stride(1),  # stride_am, stride_ak
            b.stride(1), b.stride(0),  # stride_bk, stride_bn
            c.stride(0), c.stride(1),  # stride_cm, stride_cn
            a_s.stride(0), a_s.stride(1),  # stride_As_m, stride_As_k
            b_s.stride(1), b_s.stride(0),  # stride_Bs_k, stride_Bs_n
        )
    return c

286 

287 

def sqmma_w8a8_block_fp8_matmul(
    a,
    b,
    c,
    a_s,
    b_s,
    M,
    N,
    K,
    group_n,
    group_k,
):
    """Launch the tensor-descriptor (SQMMA) kernel and return ``c``.

    Expects ``a`` as (M, K) and ``b`` as (N, K) -- the layouts
    w8a8_block_fp8_matmul passes -- with ``a_s`` / ``b_s`` holding the
    matching per-block scales and ``c`` the (M, N) output to fill.
    """
    logger.debug(
        "GEMS_MTHREADS W8A8_BLOCK_FP8_MATMUL(sqmma), [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    device = a.device
    if not a.is_contiguous():
        a = a.contiguous()

    # The kernel loads B tiles at (k, n) with block shape [BLOCK_K, BLOCK_N]
    # (see sqmma_descriptor_pre_hook) and feeds them straight into
    # tl.dot(a, b), so the B descriptor must describe a row-major (K, N)
    # tensor. ``b`` arrives as (N, K), so build the descriptor on its
    # transpose: a column-major ``b`` transposes to a contiguous (K, N) view
    # for free; a row-major ``b`` needs one transposing copy.
    b_kn = b.t()
    if not b_kn.is_contiguous():
        b_kn = b_kn.contiguous()

    # [1, 1] is a placeholder block shape; sqmma_descriptor_pre_hook replaces
    # it with the tuned BLOCK_* sizes before each launch.
    desc_a = TensorDescriptor.from_tensor(a, [1, 1])
    desc_b = TensorDescriptor.from_tensor(b_kn, [1, 1])
    desc_c = TensorDescriptor.from_tensor(c, [1, 1])

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),
        1,
        1,
    )

    with torch_device_fn.device(device):
        w8a8_block_fp8_matmul_sqmma_kernel[grid](
            desc_a,
            desc_b,
            desc_c,
            a_s,
            b_s,
            M,
            N,
            K,
            group_n,
            group_k,
            a_s.stride(0),
            a_s.stride(1),
            b_s.stride(0),  # stride_Bs_n (b_s is N-groups x K-groups)
            b_s.stride(1),  # stride_Bs_k
        )
    return c

343 

344 

def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Block-scaled matmul ``A @ B.T`` with per-block dequantization scales.

    Args:
        A: input of shape ``(..., K)``; all leading dims are flattened into M.
        B: 2D weight of shape ``(N, K)``.
        As: scales for ``A`` of shape ``(..., cdiv(K, block_k))`` with the same
            leading dims as ``A``.
        Bs: 2D scales for ``B`` of shape ``(cdiv(N, block_n), cdiv(K, block_k))``.
        block_size: ``[block_n, block_k]``, the scale granularity along N and K.
        output_dtype: dtype of the returned tensor.

    Returns:
        A tensor of shape ``A.shape[:-1] + (N,)`` and dtype ``output_dtype``.

    Raises:
        AssertionError: on inconsistent shapes (these checks are asserts and
            are skipped under ``python -O``).
    """
    device = A.device
    assert len(block_size) == 2
    block_n, block_k = block_size

    # Make an operand contiguous only when both of its trailing strides exceed
    # 1; row-major or column-major operands are passed through unchanged.
    if A.ndim >= 2 and A.stride(-2) > 1 and A.stride(-1) > 1:
        A = A.contiguous()
    if B.ndim == 2 and B.stride(0) > 1 and B.stride(1) > 1:
        B = B.contiguous()
    if As.ndim >= 2 and As.stride(-2) > 1 and As.stride(-1) > 1:
        As = As.contiguous()
    if Bs.ndim == 2 and Bs.stride(0) > 1 and Bs.stride(1) > 1:
        Bs = Bs.contiguous()

    assert A.shape[-1] == B.shape[-1], "incompatible dimensions"
    assert A.shape[:-1] == As.shape[:-1], "A and As dimensions mismatch"
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1], "invalid As shape"
    assert B.ndim == 2 and Bs.ndim == 2, "B and Bs must be 2D"

    # Product of the leading dims. Unlike A.numel() // A.shape[-1], this does
    # not raise ZeroDivisionError when K == 0.
    M = A.shape[:-1].numel()
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0], "invalid Bs N dimension"
    assert triton.cdiv(K, block_k) == Bs.shape[1], "invalid Bs K dimension"

    output_shape = A.shape[:-1] + (N,)
    c = torch.empty(output_shape, device=device, dtype=output_dtype)

    # Degenerate problems: skip the kernel launch (and its autotuning). An
    # empty reduction (K == 0) yields an all-zero product.
    if c.numel() == 0:
        return c
    if K == 0:
        return c.zero_()

    a_2d = A.reshape(M, K)
    as_2d = As.reshape(M, As.shape[-1])
    c_2d = c.reshape(M, N)
    if is_sqmma_compatible(a_2d, B, output_dtype, N, K):
        return sqmma_w8a8_block_fp8_matmul(
            a_2d,
            B,
            c_2d,
            as_2d,
            Bs,
            M,
            N,
            K,
            block_n,
            block_k,
        ).reshape(c.shape)

    return general_w8a8_block_fp8_matmul(
        a_2d,
        B,
        c_2d,
        as_2d,
        Bs,
        M,
        N,
        K,
        block_n,
        block_k,
    ).reshape(c.shape)