Coverage for src/flag_gems/runtime/backend/_mthreads/ops/mm.py: 0%

206 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7from triton.tools.tensor_descriptor import TensorDescriptor 

8 

9from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import libentry, libtuner 

12from flag_gems.utils import triton_lang_extension as ext 

13 

logger = logging.getLogger("flag_gems.runtime.backend._mthreads.ops.mm")

# YAML holding the expanded autotuning search space for mm/gemv. It is only
# consulted when USE_FLAGTUNE=1, and is resolved relative to the parent
# directory of this module.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(os.path.dirname(__file__), os.pardir, "mm_mthreads_expand.yaml")
)

# SQMMA (tensor-descriptor) capability flag, computed once at import time from
# the installed Triton's "major.minor" version and reused as a constant:
# True for Triton >= 3.2, False for older releases such as 3.1.
SQMMA_ON = tuple(map(int, triton.__version__.split(".")[:2])) >= (3, 2)

24 

25 

def is_supported_sqmma_layout(tensor):
    """Return True if ``tensor`` has a layout the SQMMA path accepts.

    Accepted layouts are row-major contiguous, or exactly column-major
    (unit stride along dim 0 and a dim-1 stride equal to ``shape[0]``, i.e.
    the transpose of a contiguous matrix). Assumes a 2-D tensor.
    """
    if tensor.is_contiguous():
        return True
    unit_row_stride = tensor.stride(0) == 1
    return unit_row_stride and tensor.stride(1) == tensor.shape[0]

30 

31 

def is_sqmma_compatible(a, b, N, K):
    """Return True if ``a @ b`` may use the SQMMA (tensor-descriptor) kernel.

    Requirements:
      * the installed Triton supports it (``SQMMA_ON``);
      * both operands are 2-D with the same fp16/bf16 dtype;
      * both operands have a layout accepted by ``is_supported_sqmma_layout``;
      * ``N`` and ``K`` are multiples of 8 (presumably so descriptor row
        strides stay 16-byte aligned for 2-byte dtypes -- TODO confirm).
    """
    if not SQMMA_ON:
        return False
    if a.dim() != 2 or b.dim() != 2:
        return False
    if a.dtype != b.dtype or a.dtype not in (torch.float16, torch.bfloat16):
        return False
    if not (is_supported_sqmma_layout(a) and is_supported_sqmma_layout(b)):
        return False
    return N % 8 == 0 and K % 8 == 0

44 

45 

@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of b strictly below a (for a > 0): one step of b below
    # ceil(a / b) * b. Note that a == 0 yields -b.
    return (tl.cdiv(a, b) - 1) * b

50 

51 

@libentry()
@libtuner(
    configs=runtime.ops_get_configs("mm", yaml_path=EXPAND_CONFIG_FILENAME)
    if os.environ.get("USE_FLAGTUNE") == "1"
    else runtime.get_tuned_config("mm"),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=runtime.get_expand_config("mm", yaml_path=EXPAND_CONFIG_FILENAME)[
        "strategy"
    ]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def mm_kernel(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    dtype: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    # Tiled C[M, N] = A[M, K] @ B[K, N]: each program computes one
    # BLOCK_M x BLOCK_N tile of C. `dtype` is not referenced in the body; it
    # only participates in kernel specialization. IS_FP64 selects an fp64
    # accumulator instead of fp32.
    pid = ext.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance: consecutive programs walk
    # down a group of GROUP_M tile-rows before moving to the next tile-column.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Wrap out-of-range rows/cols back into bounds so A/B loads need no M/N
    # mask; results for wrapped rows/cols are dropped by the masked store.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
    prev_multiple = prev_multiple_of(K, BLOCK_K)

    if IS_FP64:
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
    else:
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Full K blocks: every k index is in range, so no K mask is needed.
    for start_k in range(0, prev_multiple, BLOCK_K):
        rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
        a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
        b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
        # Mixed-precision operands are both cast to the output element type.
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        if IS_FP64:
            acc += tl.dot(a, b, allow_tf32=False)
        else:
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    # loop peeling: the last (possibly partial) K block. Masked-off K lanes
    # must read as 0 so they contribute nothing to the dot product; Triton does
    # not define the value of masked-off lanes when `other` is omitted.
    rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
    mask_k = rk < K
    a = tl.load(
        A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
        mask=mask_k[None, :],
        other=0.0,
    )
    b = tl.load(
        B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
        mask=mask_k[:, None],
        other=0.0,
    )
    if a.dtype != b.dtype:
        a = a.to(C.dtype.element_ty)
        b = b.to(C.dtype.element_ty)
    if IS_FP64:
        acc += tl.dot(a, b, allow_tf32=False)
    else:
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
    rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # masked write-back: only in-bounds rows/cols of the tile are stored
    tl.store(C, acc, mask=mask)

147 

148 

@libentry()
@libtuner(
    configs=runtime.ops_get_configs("gemv", yaml_path=EXPAND_CONFIG_FILENAME)
    if os.environ.get("USE_FLAGTUNE") == "1"
    else [
        triton.Config({"BLOCK_M": 64, "BLOCK_K": 64}),
        triton.Config({"BLOCK_M": 128, "BLOCK_K": 64}),
    ],
    key=["M", "K", "stride_am", "stride_bk"],
    strategy=runtime.get_expand_config("gemv", yaml_path=EXPAND_CONFIG_FILENAME)[
        "strategy"
    ]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
)
@triton.jit
def gemv_kernel(
    A,
    B,
    C,
    M,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_cm,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # c[M] = A[M, K] @ b[K] (the N == 1 case of mm): each program reduces one
    # BLOCK_M slice of rows over K in BLOCK_K chunks.
    pid = ext.program_id(0)

    row_start = pid * BLOCK_M
    row_offset = row_start + tl.arange(0, BLOCK_M)
    row_mask = row_offset < M

    # Accumulate in fp64 when the output is fp64 (matching mm_kernel's IS_FP64
    # path) and in fp32 otherwise. C's element type is a compile-time property,
    # so only one branch is emitted.
    if C.dtype.element_ty == tl.float64:
        acc = tl.zeros((BLOCK_M,), dtype=tl.float64)
    else:
        acc = tl.zeros((BLOCK_M,), dtype=tl.float32)

    for k_start in range(0, K, BLOCK_K):
        k_offset = k_start + tl.arange(0, BLOCK_K)
        k_mask = k_offset < K

        a_ptrs = A + row_offset[:, None] * stride_am + k_offset[None, :] * stride_ak
        a = tl.load(a_ptrs, mask=row_mask[:, None] & k_mask[None, :], other=0.0)

        b_ptrs = B + k_offset * stride_bk
        b = tl.load(b_ptrs, mask=k_mask, other=0.0)

        # Reduce in the accumulator's precision so N=1 GEMV matches the mm path.
        a = a.to(acc.dtype)
        b = b.to(acc.dtype)
        acc += tl.sum(a * b[None, :], axis=1)

    c_ptrs = C + row_offset * stride_cm
    acc = acc.to(C.dtype.element_ty)
    tl.store(c_ptrs, acc, mask=row_mask)

206 

207 

# Supported floating dtypes, ordered from lowest to highest "rank".
_ordered_datatypes = [
    torch.float16,
    torch.bfloat16,
    torch.float32,
    torch.float64,
]


def get_higher_dtype(a, b):
    """Return whichever of dtypes ``a`` / ``b`` ranks higher in ``_ordered_datatypes``.

    Identical dtypes are returned as-is, without validation.

    Raises:
        AssertionError: if the dtypes differ and either one is not in
            ``_ordered_datatypes``.
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    for candidate in _ordered_datatypes:
        if candidate is a or candidate is b:
            # The first match in ascending order is the lower-ranked dtype,
            # so the other argument is the higher one.
            return b if candidate is a else a

223 

224 

def mm_fma(a, b):
    """Compute ``a @ b`` with the generic tiled ``mm_kernel``.

    Args:
        a: 2-D tensor of shape (M, K).
        b: 2-D tensor of shape (K, N).

    Returns:
        A new (M, N) tensor whose dtype is the higher of ``a.dtype`` and
        ``b.dtype`` (see ``get_higher_dtype``).

    Raises:
        AssertionError: if the inner dimensions do not match.
    """
    logger.debug("GEMS_MTHREADS MM(FMA)")
    device = a.device
    # Copy only when neither dimension is unit-stride; any other strided
    # layout is handled by the kernel through the explicit stride arguments.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # allocates output
    c_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=device, dtype=c_dtype)
    # one program per BLOCK_M x BLOCK_N output tile
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )
    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            dtype=str(a.dtype).split(".")[-1],
            GROUP_M=8,
            # The kernel casts mixed-dtype operands to C's dtype, so the
            # accumulator precision must follow C (the higher dtype), not A
            # alone; otherwise fp32 @ fp64 would accumulate in fp32.
            IS_FP64=c_dtype == torch.float64,
        )
    return c

263 

264 

def gemv_mm(a, b, c, M, K):
    """Compute the N == 1 product ``a @ b`` into ``c`` via ``gemv_kernel``.

    Args:
        a: (M, K) matrix.
        b: (K, 1) matrix; only its dim-0 stride is passed to the kernel.
        c: preallocated (M, 1) output; only its dim-0 stride is passed.
        M, K: problem sizes.

    Returns:
        ``c``, filled in place.
    """
    logger.debug(
        "GEMS_MTHREADS MM(GEMV), [shape info]: [%s, %s, 1](M, K, N)",
        M,
        K,
    )

    def grid(META):
        # one program per BLOCK_M rows of the output
        return (triton.cdiv(M, META["BLOCK_M"]),)

    stride_am, stride_ak = a.stride(0), a.stride(1)
    stride_bk = b.stride(0)
    stride_cm = c.stride(0)
    with torch_device_fn.device(a.device):
        gemv_kernel[grid](a, b, c, M, K, stride_am, stride_ak, stride_bk, stride_cm)
    return c

285 

286 

def mm_out(a, b, *, out):
    """Compute ``a @ b`` into the caller-provided ``out`` tensor and return it.

    ``out`` is written through its own strides. Its shape and dtype are not
    validated here (assumed to be (M, N) -- TODO confirm callers guarantee
    this). N == 1 is routed to ``gemv_mm``.

    Raises:
        AssertionError: if the inner dimensions do not match.
    """
    logger.debug("GEMS_MTHREADS MM_OUT")
    # Fall back to a contiguous copy only when neither dimension is unit-stride.
    if min(a.stride(0), a.stride(1)) > 1:
        a = a.contiguous()
    if min(b.stride(0), b.stride(1)) > 1:
        b = b.contiguous()
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    if N == 1:
        return gemv_mm(a, b, out, M, K)

    def grid(META):
        # one program per BLOCK_M x BLOCK_N output tile
        tiles_m = triton.cdiv(M, META["BLOCK_M"])
        tiles_n = triton.cdiv(N, META["BLOCK_N"])
        return (tiles_m * tiles_n,)

    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            out,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            out.stride(0),
            out.stride(1),
            dtype=str(a.dtype).split(".")[-1],
            GROUP_M=8,
            IS_FP64=a.dtype == torch.float64,
        )
    return out

325 

326 

def matmul_sqmma_set_block_size_hook(nargs):
    """Config pre-hook that sizes the A/B/C tensor descriptors to the chosen tile.

    ``nargs`` maps kernel argument names to values. It must contain
    ``BLOCK_M``, ``BLOCK_N`` and ``BLOCK_K`` plus the ``a_desc``, ``b_desc``
    and ``c_desc`` descriptors; each descriptor's ``block_shape`` is
    overwritten with a fresh list.
    """
    bm, bn, bk = nargs["BLOCK_M"], nargs["BLOCK_N"], nargs["BLOCK_K"]
    tile_shapes = {
        "a_desc": [bm, bk],  # A tile: BLOCK_M x BLOCK_K
        "b_desc": [bk, bn],  # B tile: BLOCK_K x BLOCK_N
        "c_desc": [bm, bn],  # C tile: BLOCK_M x BLOCK_N
    }
    for desc_name, shape in tile_shapes.items():
        nargs[desc_name].block_shape = shape

334 

335 

def sqmma_get_configs(pre_hook=matmul_sqmma_set_block_size_hook):
    """Return the autotuning configs for ``mm_sqmma_kernel``.

    Every config uses a 128 x 128 output tile with ``num_stages=1`` and
    ``num_warps=4``; they differ only in BLOCK_K (128 first, then 64).
    ``pre_hook`` is attached to each config so the descriptor block shapes
    follow the selected tile.
    """
    return [
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": block_k},
            num_stages=1,
            num_warps=4,
            pre_hook=pre_hook,
        )
        for block_k in (128, 64)
    ]

351 

352 

@libentry()
@libtuner(
    configs=sqmma_get_configs(),
    key=["M", "N", "K", "dtype"],
    strategy=["align32", "align32", "align32", "default"],
)
@triton.jit
def mm_sqmma_kernel(
    a_desc,
    b_desc,
    c_desc,
    M,
    N,
    K,
    dtype: tl.constexpr,
    GROUP_M: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Tiled C = A @ B through tensor descriptors (SQMMA path). Each program
    # computes one BLOCK_M x BLOCK_N tile of C. The descriptors' block shapes
    # are set by the config pre_hook (matmul_sqmma_set_block_size_hook) before
    # launch. `dtype` is not read in the body; it is part of the autotune key.
    pid = ext.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # Grouped program ordering: consecutive pids walk down GROUP_M tile-rows
    # before advancing to the next tile-column.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # Element coordinates of this tile's top-left corner in A (rows) and B (cols).
    offs_am = (pid_m * BLOCK_M).to(tl.int32)
    offs_bn = (pid_n * BLOCK_N).to(tl.int32)
    offs_k = 0
    offs_k = offs_k.to(tl.int32)
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # No explicit masking: the final K block (and edge M/N tiles) rely on the
    # tensor descriptor's bounds handling -- presumably zero-fill on load and
    # clipping on store; NOTE(review): confirm for MUSA Triton.
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k])
        b = tl.load_tensor_descriptor(b_desc, [offs_k, offs_bn])
        accumulator = tl.dot(a, b, acc=accumulator)
        offs_k += BLOCK_K
    # fp32 accumulator is cast to C's descriptor dtype on store.
    tl.store_tensor_descriptor(c_desc, [offs_am, offs_bn], accumulator.to(c_desc.dtype))

392 

393 

def get_triton_type(elem_type):
    """Map a torch dtype to the corresponding Triton dtype.

    Only float16, bfloat16 and float8_e4m3fn are mapped; any other dtype
    yields None.
    """
    torch_to_triton = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float8_e4m3fn: tl.float8e4nv,
    }
    return torch_to_triton.get(elem_type)

401 

402 

def mm_sqmma(A, B, M, N, K, GROUP_M):
    """Compute ``A @ B`` with the tensor-descriptor SQMMA kernel.

    Args:
        A: (M, K) matrix.
        B: (K, N) matrix with the same dtype as ``A``.
        M, N, K: problem sizes.
        GROUP_M: number of tile-rows per program group in the launch order.

    Returns:
        A new (M, N) tensor with ``A``'s dtype.

    Raises:
        AssertionError: if ``A`` and ``B`` have different dtypes.
    """
    logger.debug("GEMS_MTHREADS MM(SQMMA)")
    device = A.device
    # Descriptors are built over row-major contiguous operands.
    if not A.is_contiguous():
        A = A.contiguous()
    if not B.is_contiguous():
        B = B.contiguous()
    a_type = A.dtype
    b_type = B.dtype
    assert a_type == b_type, "Mat A and Mat B should have the same dtype"
    c_dtype = get_higher_dtype(a_type, b_type)
    C = torch.empty((M, N), dtype=c_dtype, device=device)
    # Real block_shape values are filled in by matmul_sqmma_set_block_size_hook
    # at autotune/launch time based on the BLOCK_M/N/K selected by libtuner.
    dummy_block = [1, 1]
    desc_a = TensorDescriptor(A, A.shape, A.stride(), dummy_block)
    desc_b = TensorDescriptor(B, B.shape, B.stride(), dummy_block)
    desc_c = TensorDescriptor(C, C.shape, C.stride(), dummy_block)
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
        1,
        1,
    )
    # Launch under the operands' device context, so the kernel does not
    # run on whichever device happens to be current.
    with torch_device_fn.device(device):
        mm_sqmma_kernel[grid](
            desc_a,
            desc_b,
            desc_c,
            M,
            N,
            K,
            dtype=str(a_type).split(".")[-1],
            GROUP_M=GROUP_M,
        )
    return C

437 

438 

def mm(a, b):
    """Matrix product ``a @ b`` of two 2-D tensors on the MThreads backend.

    Dispatch:
      * N == 1 -> ``gemv_mm`` (output dtype is the higher of the two inputs);
      * operands accepted by ``is_sqmma_compatible`` -> ``mm_sqmma``;
      * everything else -> ``mm_fma``.

    Raises:
        AssertionError: if ``a.shape[1] != b.shape[0]``.
    """
    M, K = a.shape
    _, N = b.shape
    # Validate before dispatch. The GEMV and SQMMA paths take K from ``a`` and
    # would otherwise index ``b`` with a mismatched K (GEMV can read past the
    # end of ``b``). mm_fma performs this same check.
    assert K == b.shape[0], "incompatible dimensions"

    if N == 1:
        c_dtype = get_higher_dtype(a.dtype, b.dtype)
        c = torch.empty((M, N), device=a.device, dtype=c_dtype)
        return gemv_mm(a, b, c, M, K)

    if is_sqmma_compatible(a, b, N, K):
        GROUP_M = 8
        return mm_sqmma(a, b, M, N, K, GROUP_M)

    return mm_fma(a, b)