Coverage for src/flag_gems/runtime/backend/_mthreads/ops/w8a8_block_fp8_matmul.py: 0%

142 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2import os 

3from typing import List 

4 

5import torch 

6import triton 

7import triton.language as tl 

8from triton.tools.tensor_descriptor import TensorDescriptor 

9 

10from flag_gems import runtime 

11from flag_gems.runtime import torch_device_fn 

12from flag_gems.utils import libentry, libtuner 

13from flag_gems.utils import triton_lang_extension as ext 

14 

# Module logger. The dotted name is spelled out explicitly instead of being
# derived from ``__name__``.
logger = logging.getLogger(
    "flag_gems.runtime.backend._mthreads.ops.w8a8_block_fp8_matmul"
)
# Normalized path of the expand-config YAML, which sits one directory above
# the directory containing this file. It is handed to the runtime config
# helpers as ``yaml_path`` when the tuner is built with USE_FLAGTUNE=1.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(
        os.path.dirname(__file__),
        "..",
        "w8a8_block_fp8_matmul_mthreads_expand.yaml",
    )
)

# Switch consulted by is_sqmma_compatible(). While it is False that check
# always fails, so the tensor-descriptor (sqmma) launcher is never selected
# and every call goes through the general kernel.
SQMMA_ON = False

27 

28 

def is_supported_sqmma_layout(tensor):
    """Report whether ``tensor`` has a memory layout the sqmma path accepts.

    A tensor qualifies when it is contiguous, or when dimension 0 has unit
    stride and the stride of dimension 1 equals the size of dimension 0.
    Only dimensions 0 and 1 are inspected, and only when the tensor is not
    contiguous.

    Args:
        tensor: Tensor whose strides and shape are examined.

    Returns:
        True if the layout is accepted, False otherwise.
    """
    if tensor.is_contiguous():
        return True
    dim0_is_unit_stride = tensor.stride(0) == 1
    return dim0_is_unit_stride and tensor.stride(1) == tensor.shape[0]

33 

34 

def is_sqmma_compatible(a, b, output_dtype, n, k):
    """Decide whether the sqmma (tensor-descriptor) kernel may be used.

    All of the following must hold, checked in this order:
      * ``a`` is 2-D, the module switch ``SQMMA_ON`` is enabled, ``b`` is 2-D;
      * both operands have dtype ``torch.float8_e4m3fn``;
      * ``output_dtype`` is float16 or bfloat16;
      * both operands pass ``is_supported_sqmma_layout``;
      * ``n`` and ``k`` are multiples of 16.

    Args:
        a: First matmul operand.
        b: Second matmul operand.
        output_dtype: Requested dtype of the result.
        n: N extent of the problem.
        k: K extent of the problem.

    Returns:
        True only when every condition above is satisfied.
    """
    if a.dim() != 2 or not SQMMA_ON or b.dim() != 2:
        return False
    if not (a.dtype == b.dtype == torch.float8_e4m3fn):
        return False
    if output_dtype not in (torch.float16, torch.bfloat16):
        return False
    if not (is_supported_sqmma_layout(a) and is_supported_sqmma_layout(b)):
        return False
    return n % 16 == 0 and k % 16 == 0

47 

48 

def get_triton_type(elem_type):
    """Translate a torch dtype into the corresponding Triton dtype.

    Args:
        elem_type: A torch dtype.

    Returns:
        The Triton dtype mapped to ``elem_type`` (float16, bfloat16, float32
        and float8_e4m3fn are covered), or None when there is no mapping.
    """
    torch_to_triton = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float32: tl.float32,
        torch.float8_e4m3fn: tl.float8e4nv,
    }
    try:
        return torch_to_triton[elem_type]
    except KeyError:
        return None

57 

58 

def matmul_get_configs():
    """Return the built-in tuning candidates for the general kernel.

    This list is what the tuner decorator on the general kernel receives
    when the USE_FLAGTUNE environment variable is not "1". It holds a single
    configuration: 64x64 output tiles, a K step of 128, GROUP_M of 8,
    3 stages and 4 warps.
    """
    tile_meta = {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_M": 8}
    default_config = triton.Config(tile_meta, num_stages=3, num_warps=4)
    return [default_config]

72 

73 

@libentry()
@libtuner(
    # Candidate configs and the key strategy are read from the expand YAML
    # when USE_FLAGTUNE=1; otherwise the built-in defaults are used.
    configs=runtime.ops_get_configs(
        "w8a8_block_fp8_general", pre_hook=None, yaml_path=EXPAND_CONFIG_FILENAME
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else matmul_get_configs(),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=runtime.get_expand_config(
        "w8a8_block_fp8_general", yaml_path=EXPAND_CONFIG_FILENAME
    )["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def w8a8_block_fp8_matmul_kernel(
    A,
    B,
    C,
    As,
    Bs,
    M,
    N,
    K,
    group_n,
    group_k,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """Compute one BLOCK_M x BLOCK_N tile of C from block-scaled A and B.

    For every BLOCK_K slice of K the program loads an A tile and a B tile,
    multiplies them with ``tl.dot``, scales the partial product by a per-row
    factor read from ``As`` and a per-column factor read from ``Bs``, and
    adds it to a float32 accumulator. The accumulator is then cast to C's
    element type (bfloat16, float16, otherwise float32) and stored under an
    M/N bounds mask.

    Args:
        A, B, C: Base pointers of the two operands and the output. They are
            addressed purely through the strides below.
        As, Bs: Base pointers of the scale factors. ``As`` is indexed by
            (row, k // group_k) and ``Bs`` by (col // group_n, k // group_k).
        M, N, K: Problem extents.
        group_n, group_k: Divisors used to turn a column index / K offset
            into a scale index.
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn:
            Strides of A, B and C along the named logical axes.
        stride_As_m, stride_As_k, stride_Bs_k, stride_Bs_n: Strides of the
            scale tensors along the named logical axes.
        BLOCK_M, BLOCK_N, BLOCK_K: Compile-time tile sizes.
        GROUP_M: Compile-time number of row tiles per program-id group.

    NOTE(review): the K scale index is computed once per BLOCK_K slice from
    the slice's starting offset, which appears to assume that a slice never
    straddles a ``group_k`` boundary — confirm against the tuned configs.
    """
    # Map the 1-D program id to a (pid_m, pid_n) tile coordinate. Program ids
    # are handed out in groups that span up to GROUP_M row tiles.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    # The last group may hold fewer than GROUP_M row tiles.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Row/column indices are wrapped modulo M/N so that the loads below,
    # which mask only along K, never index past the last row or column.
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # One A scale per row; one B scale per group of ``group_n`` columns.
    As_ptrs = As + offs_am * stride_As_m
    offs_bsn = offs_bn // group_n
    Bs_ptrs = Bs + offs_bsn * stride_Bs_n

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # Elements past the end of K are replaced by 0.0.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)

        # Scale index along K for this slice, taken from its first element.
        k_start = k * BLOCK_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
        # Apply both scales to the partial product before accumulating.
        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    # Cast the float32 accumulator to the element type of the output.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    # Output indices are not wrapped; out-of-range rows/columns are masked.
    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

162 

163 

@triton.jit
def w8a8_block_fp8_matmul_sqmma_kernel(
    a_desc,
    b_desc,
    c_desc,
    As,
    Bs,
    M,
    N,
    K,
    group_n,
    group_k,
    stride_As_m,
    stride_As_k,
    stride_Bs_n,
    stride_Bs_k,
    GROUP_M: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Tensor-descriptor variant: compute one BLOCK_M x BLOCK_N tile of C.

    Operand tiles are fetched with ``tl.load_tensor_descriptor`` and the
    result is written with ``tl.store_tensor_descriptor``; the scale factors
    are read through ordinary pointer loads. Each K step multiplies the two
    tiles with ``tl.dot`` (float32 output, tf32 disabled), scales the product
    by a per-row ``As`` factor and a per-column ``Bs`` factor, and adds it to
    a float32 accumulator, which is finally cast to ``c_desc.dtype``.

    Args:
        a_desc, b_desc, c_desc: Tensor descriptors for the operands and the
            output. ``a_desc`` is read at (row, k), ``b_desc`` at (k, col)
            and ``c_desc`` is written at (row, col).
        As, Bs: Base pointers of the scale factors. ``As`` is indexed by
            (row, k // group_k) and ``Bs`` by (col // group_n, k // group_k).
        M, N, K: Problem extents.
        group_n, group_k: Divisors used to turn a column index / K offset
            into a scale index.
        stride_As_m, stride_As_k, stride_Bs_n, stride_Bs_k: Strides of the
            scale tensors along the named logical axes.
        GROUP_M: Compile-time number of row tiles per program-id group.
        BLOCK_M, BLOCK_N, BLOCK_K: Compile-time tile sizes.

    NOTE(review): the descriptor loads and the store carry no explicit
    bounds mask; presumably the descriptors handle partial tiles — confirm.
    NOTE(review): the K scale index is derived from the starting offset of
    each BLOCK_K step, which appears to assume a step never straddles a
    ``group_k`` boundary — confirm.
    """
    # Map the 1-D program id to a (pid_m, pid_n) tile coordinate. Program ids
    # are handed out in groups that span up to GROUP_M row tiles.
    pid = ext.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    width = GROUP_M * grid_n
    group_id = pid // width
    # The last group may hold fewer than GROUP_M row tiles.
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // group_size

    # Scalar int32 origins of the tile, used as descriptor coordinates.
    offs_am = (pid_m * BLOCK_M).to(tl.int32)
    offs_bn = (pid_n * BLOCK_N).to(tl.int32)
    offs_k = tl.zeros((), dtype=tl.int32)

    # Per-element row/column indices, used only for the scale loads.
    row_offset = offs_am + tl.arange(0, BLOCK_M)
    col_offset = offs_bn + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for _ in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k])
        b = tl.load_tensor_descriptor(b_desc, [offs_k, offs_bn])

        # Scale index along K for this step, taken from its first element.
        scale_k = offs_k // group_k
        # Rows/columns outside the problem read a scale of 0.0.
        a_s = tl.load(
            As + row_offset * stride_As_m + scale_k * stride_As_k,
            mask=row_offset < M,
            other=0.0,
        )
        b_s = tl.load(
            Bs + (col_offset // group_n) * stride_Bs_n + scale_k * stride_Bs_k,
            mask=col_offset < N,
            other=0.0,
        )
        # Apply both scales to the partial product before accumulating.
        acc += (
            tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
            * a_s[:, None]
            * b_s[None, :]
        )
        offs_k += BLOCK_K

    tl.store_tensor_descriptor(c_desc, [offs_am, offs_bn], acc.to(c_desc.dtype))

225 

226 

def general_w8a8_block_fp8_matmul(
    a,
    b,
    c,
    a_s,
    b_s,
    M,
    N,
    K,
    group_n,
    group_k,
):
    """Launch the pointer-based block-scaled matmul kernel and return ``c``.

    Args:
        a: First operand; its strides along dims 0 and 1 are passed as the
            M and K strides.
        b: Second operand; its stride along dim 1 is passed as the K stride
            and its stride along dim 0 as the N stride.
        c: Output tensor, written by the kernel.
        a_s: Scales for ``a``; dims 0 and 1 are passed as the M and K strides.
        b_s: Scales for ``b``; dim 1 is passed as the K stride and dim 0 as
            the N stride.
        M, N, K: Problem extents.
        group_n, group_k: Scale group sizes forwarded to the kernel.

    Returns:
        The same ``c`` object that was passed in.
    """
    logger.debug(
        "GEMS_MTHREADS W8A8_BLOCK_FP8_MATMUL(general), [shape info]: [-, %s, %s, %s](batch, M, N, K)",
        M,
        N,
        K,
    )

    def grid(meta):
        # One program per output tile, flattened into a 1-D launch grid.
        tiles_m = triton.cdiv(M, meta["BLOCK_M"])
        tiles_n = triton.cdiv(N, meta["BLOCK_N"])
        return (tiles_m * tiles_n,)

    # Name each stride after the kernel parameter it feeds.
    stride_am, stride_ak = a.stride(0), a.stride(1)
    stride_bk, stride_bn = b.stride(1), b.stride(0)
    stride_cm, stride_cn = c.stride(0), c.stride(1)
    stride_as_m, stride_as_k = a_s.stride(0), a_s.stride(1)
    stride_bs_k, stride_bs_n = b_s.stride(1), b_s.stride(0)

    with torch_device_fn.device(a.device):
        w8a8_block_fp8_matmul_kernel[grid](
            a,
            b,
            c,
            a_s,
            b_s,
            M,
            N,
            K,
            group_n,
            group_k,
            stride_am,
            stride_ak,
            stride_bk,
            stride_bn,
            stride_cm,
            stride_cn,
            stride_as_m,
            stride_as_k,
            stride_bs_k,
            stride_bs_n,
        )
    return c

273 

274 

def sqmma_w8a8_block_fp8_matmul(
    a,
    b,
    c,
    a_s,
    b_s,
    M,
    N,
    K,
    group_n,
    group_k,
):
    """Launch the tensor-descriptor block-scaled matmul kernel and return ``c``.

    Both operands are made contiguous if they are not already, wrapped in
    tensor descriptors with fixed tile sizes (64x128 for ``a``, 128x64 for
    ``b``, 64x64 for ``c``), and handed to the sqmma kernel with 4 warps and
    3 stages.

    Args:
        a, b: Matmul operands.
        c: Output tensor, written by the kernel.
        a_s, b_s: Scale tensors; their dim-0 and dim-1 strides are forwarded.
        M, N, K: Problem extents.
        group_n, group_k: Scale group sizes forwarded to the kernel.

    Returns:
        The same ``c`` object that was passed in.

    NOTE(review): the caller in this file passes ``b`` with shape (N, K),
    while the descriptor below uses a [BLOCK_K, BLOCK_N] tile and the kernel
    loads it at (k, n) — confirm the intended layout of ``b`` before
    enabling this path.
    """
    # The layout flags are logged for the tensors as received, before any
    # contiguous copy is made.
    a_is_col_major = a.stride(0) == 1
    b_is_col_major = b.stride(0) == 1
    logger.debug(
        "GEMS_MTHREADS W8A8_BLOCK_FP8_MATMUL(sqmma), [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a_is_col_major,
        b_is_col_major,
    )
    device = a.device
    a = a if a.is_contiguous() else a.contiguous()
    b = b if b.is_contiguous() else b.contiguous()

    tile_m, tile_n, tile_k, group_m = 64, 64, 128, 8

    desc_a = TensorDescriptor.from_tensor(a, [tile_m, tile_k])
    desc_b = TensorDescriptor.from_tensor(b, [tile_k, tile_n])
    desc_c = TensorDescriptor.from_tensor(c, [tile_m, tile_n])

    def grid(meta):
        # One program per output tile along the first grid axis.
        tiles_m = triton.cdiv(M, meta["BLOCK_M"])
        tiles_n = triton.cdiv(N, meta["BLOCK_N"])
        return (tiles_m * tiles_n, 1, 1)

    with torch_device_fn.device(device):
        w8a8_block_fp8_matmul_sqmma_kernel[grid](
            desc_a,
            desc_b,
            desc_c,
            a_s,
            b_s,
            M,
            N,
            K,
            group_n,
            group_k,
            a_s.stride(0),
            a_s.stride(1),
            b_s.stride(0),
            b_s.stride(1),
            group_m,
            tile_m,
            tile_n,
            tile_k,
            num_warps=4,
            num_stages=3,
        )
    return c

341 

342 

def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Block-scaled matmul of ``A`` with a 2-D ``B`` of shape (N, K).

    Leading dimensions of ``A`` are flattened into a single M dimension, the
    product is computed by either the sqmma launcher or the general launcher,
    and the result is returned with shape ``A.shape[:-1] + (N,)``.

    Args:
        A: Left operand; its last dimension is K.
        B: Right operand, 2-D with shape (N, K).
        As: Scales for ``A``. Shape matches ``A`` except that the last
            dimension is ``ceil(K / block_k)``.
        Bs: Scales for ``B``, 2-D with shape
            ``(ceil(N / block_n), ceil(K / block_k))``.
        block_size: Two-element list ``[block_n, block_k]``.
        output_dtype: dtype of the returned tensor.

    Returns:
        A tensor of dtype ``output_dtype`` on ``A``'s device.

    Raises:
        AssertionError: If ``block_size`` does not have two entries or any of
            the shape checks fails.
    """
    device = A.device
    assert len(block_size) == 2
    block_n, block_k = block_size

    def _no_unit_stride_in_last_two_dims(tensor):
        # True when neither of the two trailing dimensions has stride <= 1.
        return tensor.stride(-2) > 1 and tensor.stride(-1) > 1

    # Copy to a contiguous layout only when neither trailing dimension is
    # already unit-stride; otherwise the strides are passed through as-is.
    if A.ndim >= 2 and _no_unit_stride_in_last_two_dims(A):
        A = A.contiguous()
    if B.ndim == 2 and _no_unit_stride_in_last_two_dims(B):
        B = B.contiguous()
    if As.ndim >= 2 and _no_unit_stride_in_last_two_dims(As):
        As = As.contiguous()
    if Bs.ndim == 2 and _no_unit_stride_in_last_two_dims(Bs):
        Bs = Bs.contiguous()

    assert A.shape[-1] == B.shape[-1], "incompatible dimensions"
    assert A.shape[:-1] == As.shape[:-1], "A and As dimensions mismatch"
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1], "invalid As shape"
    assert B.ndim == 2 and Bs.ndim == 2, "B and Bs must be 2D"

    M = A.numel() // A.shape[-1]
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0], "invalid Bs N dimension"
    assert triton.cdiv(K, block_k) == Bs.shape[1], "invalid Bs K dimension"

    c = torch.empty(A.shape[:-1] + (N,), device=device, dtype=output_dtype)

    # The launchers operate on 2-D views: (M, K) x (N, K) -> (M, N).
    a_2d = A.reshape(M, K)
    as_2d = As.reshape(M, As.shape[-1])
    c_2d = c.reshape(M, N)

    if is_sqmma_compatible(a_2d, B, output_dtype, N, K):
        launch = sqmma_w8a8_block_fp8_matmul
    else:
        launch = general_w8a8_block_fp8_matmul

    result_2d = launch(a_2d, B, c_2d, as_2d, Bs, M, N, K, block_n, block_k)
    return result_2d.reshape(c.shape)