Coverage for src/flag_gems/runtime/backend/_mthreads/ops/mm.py: 0%

227 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry, libtuner 

11from flag_gems.utils import triton_lang_extension as tle 

12 

13from .utils import create_tma_device_descriptor, get_cached_tma_device_descriptor 

14 

logger = logging.getLogger("flag_gems.runtime.backend._mthreads.ops.mm")

# Expand-config YAML consulted when USE_FLAGTUNE=1; it lives in the parent
# directory of this module's folder.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "..", "mm_mthreads_expand.yaml")
)

# SQMMA capability flag, derived once at import time from the (major, minor)
# prefix of the installed Triton version and then reused as a constant:
# True for Triton >= 3.2, False for older releases such as 3.1.
_TRITON_MAJOR_MINOR = tuple(map(int, triton.__version__.split(".")[:2]))
SQMMA_ON = _TRITON_MAJOR_MINOR >= (3, 2)

25 

26 

def is_supported_sqmma_layout(tensor):
    """Return True if a 2-D ``tensor`` is row-major or column-major dense.

    Row-major means ``tensor.is_contiguous()``. Column-major means the tensor
    is the transpose of a contiguous matrix: unit stride along dim 0 and a
    dim-1 stride equal to ``tensor.shape[0]``.
    """
    if tensor.is_contiguous():
        return True
    column_major = tensor.stride(0) == 1 and tensor.stride(1) == tensor.shape[0]
    return column_major

31 

32 

def is_sqmma_compatible(a, b, N, K):
    """Return True if ``a @ b`` may be dispatched to the SQMMA (TMA) kernel.

    Conditions:
      * ``SQMMA_ON`` (installed Triton supports the path);
      * both operands 2-D, same dtype, fp16 or bf16;
      * each operand row-major or column-major (``is_supported_sqmma_layout``);
      * N and K multiples of 8, and M as well when ``a`` is column-major;
      * 16-byte aligned base addresses for ``a`` and ``b``.

    These checks only narrow eligibility; inputs that fail them are expected
    to take a non-SQMMA path instead.
    """
    if not (
        SQMMA_ON
        and a.dim() == 2
        and b.dim() == 2
        and a.dtype == b.dtype
        and a.dtype in (torch.float16, torch.bfloat16)
        and is_supported_sqmma_layout(a)
        and is_supported_sqmma_layout(b)
        and N % 8 == 0
        and K % 8 == 0
    ):
        return False
    # For 2-byte dtypes the %8 checks keep leading strides at multiples of 16
    # bytes: K covers a row-major ``a`` and a column-major ``b``; N covers a
    # row-major ``b`` and the row-major output. A column-major ``a`` has
    # leading stride ``a.stride(1) == a.shape[0] == M``, which needs the same
    # check.
    # NOTE(review): based on the TMA 16-byte global-stride rule; confirm the
    # MUSA TME descriptor has the same requirement.
    if not a.is_contiguous() and a.shape[0] % 8 != 0:
        return False
    # Offset views can start at an address that is not 16-byte aligned.
    return a.data_ptr() % 16 == 0 and b.data_ptr() % 16 == 0

45 

46 

@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of b strictly below a, written as (ceil(a / b) - 1) * b.
    # For a == 0 this yields -b, so callers must not treat it as an index.
    return (tl.cdiv(a, b) - 1) * b

51 

52 

@libentry()
@libtuner(
    configs=runtime.ops_get_configs("mm", yaml_path=EXPAND_CONFIG_FILENAME)
    if os.environ.get("USE_FLAGTUNE") == "1"
    else runtime.get_tuned_config("mm"),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=runtime.get_expand_config("mm", yaml_path=EXPAND_CONFIG_FILENAME)[
        "strategy"
    ]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def mm_kernel(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    dtype: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    # Computes one BLOCK_M x BLOCK_N tile of C = A @ B, accumulating in fp32
    # and casting to C's element type on store.
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # Re-order program IDs in groups of GROUP_M tile-rows for better L2 reuse.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Out-of-range rows/cols wrap back into [0, M) / [0, N) so the K loop can
    # load without masks; those wrapped lanes are discarded by the store mask.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
    prev_multiple = prev_multiple_of(K, BLOCK_K)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Full BLOCK_K chunks: every k index is in range, so no masks are needed.
    for start_k in range(0, prev_multiple, BLOCK_K):
        rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
        a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
        b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    # Loop peeling: the last, possibly partial, K chunk.
    # Masked-off lanes are loaded as 0 (`other=0.0`) so they add nothing to the
    # dot product; without `other` their values are undefined.
    # `rk >= 0` covers K == 0, where prev_multiple_of returns -BLOCK_K and rk
    # would otherwise address memory before the start of A and B.
    rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
    mask_k = (rk >= 0) & (rk < K)
    a = tl.load(
        A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
        mask=mask_k[None, :],
        other=0.0,
    )
    b = tl.load(
        B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
        mask=mask_k[:, None],
        other=0.0,
    )
    if a.dtype != b.dtype:
        a = a.to(C.dtype.element_ty)
        b = b.to(C.dtype.element_ty)
    acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    acc = acc.to(C.dtype.element_ty)
    # Rematerialize rm and rn (as int64) here rather than keeping them live
    # across the K loop, to save registers.
    rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
    rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    tl.store(C, acc, mask=mask)

138 

139 

@libentry()
@libtuner(
    configs=runtime.ops_get_configs("gemv", yaml_path=EXPAND_CONFIG_FILENAME)
    if os.environ.get("USE_FLAGTUNE") == "1"
    else [triton.Config({"BLOCK_M": 64, "BLOCK_K": 64})],
    key=["M", "K", "stride_am", "stride_bk"],
    strategy=runtime.get_expand_config("gemv", yaml_path=EXPAND_CONFIG_FILENAME)[
        "strategy"
    ]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
)
@triton.jit
def gemv_kernel(
    A,
    B,
    C,
    M,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_cm,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Matrix-vector product for the N == 1 case. Each program owns BLOCK_M rows
    # of A, sweeps K in BLOCK_K chunks, and reduces in fp32 so that N=1 GEMV
    # matches the mm path more closely.
    rows = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    row_ok = rows < M
    a_rows = A + rows[:, None] * stride_am
    partial = tl.zeros((BLOCK_M,), dtype=tl.float32)

    for k0 in range(0, K, BLOCK_K):
        ks = k0 + tl.arange(0, BLOCK_K)
        k_ok = ks < K
        a_blk = tl.load(
            a_rows + ks[None, :] * stride_ak,
            mask=row_ok[:, None] & k_ok[None, :],
            other=0.0,
        ).to(tl.float32)
        b_vec = tl.load(B + ks * stride_bk, mask=k_ok, other=0.0).to(tl.float32)
        partial += tl.sum(a_blk * b_vec[None, :], axis=1)

    tl.store(C + rows * stride_cm, partial.to(C.dtype.element_ty), mask=row_ok)

194 

195 

# Supported floating dtypes, ordered from lowest to highest rank.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]


def get_higher_dtype(a, b):
    """Return whichever of ``a`` and ``b`` ranks higher in ``_ordered_datatypes``.

    Identical inputs are returned unchanged without validation. Otherwise,
    both must be in ``_ordered_datatypes``, or an ``AssertionError`` is raised.
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    rank_a = _ordered_datatypes.index(a)
    rank_b = _ordered_datatypes.index(b)
    return b if rank_a < rank_b else a

211 

212 

def mm_fma(a, b):
    """Compute ``a @ b`` with the generic ``mm_kernel`` into a new tensor.

    Args:
        a: ``(M, K)`` matrix.
        b: ``(K, N)`` matrix.

    Returns:
        A new ``(M, N)`` tensor of dtype ``get_higher_dtype(a.dtype, b.dtype)``.

    Raises:
        AssertionError: if the inner dimensions differ.
    """
    logger.debug("GEMS_MTHREADS MM(FMA)")
    device = a.device
    # Copy an operand only when neither of its dimensions has unit stride.
    a, b = (
        t.contiguous() if t.stride(0) > 1 and t.stride(1) > 1 else t for t in (a, b)
    )
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    out = torch.empty((M, N), device=device, dtype=get_higher_dtype(a.dtype, b.dtype))

    def grid(meta):
        # One program per output tile.
        return (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),)

    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            out,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            out.stride(0),
            out.stride(1),
            dtype=str(a.dtype).split(".")[-1],
            GROUP_M=8,
        )
    return out

250 

251 

def gemv_mm(a, b, c, M, K):
    """Compute the N == 1 product ``a @ b`` into the preallocated ``c``.

    Args:
        a: ``(M, K)`` matrix.
        b: column operand; only ``b.stride(0)`` is used to walk it.
        c: preallocated output; only ``c.stride(0)`` is used to address it.
        M: number of rows of ``a``.
        K: inner dimension.

    Returns:
        ``c``, after ``gemv_kernel`` has written into it.
    """
    logger.debug(
        "GEMS_MTHREADS MM(GEMV), [shape info]: [%s, %s, 1](M, K, N)",
        M,
        K,
    )

    def grid(meta):
        # One program per BLOCK_M rows of the output.
        return (triton.cdiv(M, meta["BLOCK_M"]),)

    strides = (a.stride(0), a.stride(1), b.stride(0), c.stride(0))
    with torch_device_fn.device(a.device):
        gemv_kernel[grid](a, b, c, M, K, *strides)
    return c

272 

273 

def mm_out(a, b, *, out):
    """Compute ``a @ b`` into the caller-provided tensor ``out`` and return it.

    No output is allocated here. ``N == 1`` is routed to ``gemv_mm``, and
    every other shape uses ``mm_kernel``.

    Raises:
        AssertionError: if the inner dimensions differ.
    """
    logger.debug("GEMS_MTHREADS MM_OUT")
    # Copy an operand only when neither of its dimensions has unit stride.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape

    if N == 1:
        return gemv_mm(a, b, out, M, K)

    def grid(meta):
        # One program per output tile.
        return (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),)

    launch_args = (
        a,
        b,
        out,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        out.stride(0),
        out.stride(1),
    )
    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            *launch_args,
            dtype=str(a.dtype).split(".")[-1],
            GROUP_M=8,
        )
    return out

311 

312 

def sqmma_descriptor_pre_hook(nargs):
    """Fill the descriptor buffers passed to ``mm_sqmma_kernel``.

    ``nargs`` maps kernel argument names to their values. The A and B
    descriptors come from ``get_cached_tma_device_descriptor``. The C
    descriptor is built fresh with ``create_tma_device_descriptor``. Each one
    is copied into the matching ``*_desc_ptr`` buffer, in the order A, B, C.
    """
    block_m, block_n, block_k = (nargs[k] for k in ("BLOCK_M", "BLOCK_N", "BLOCK_K"))
    dev = nargs["C"].device

    plan = (
        ("a_desc_ptr", get_cached_tma_device_descriptor, nargs["A"], block_m, block_k),
        ("b_desc_ptr", get_cached_tma_device_descriptor, nargs["B"], block_k, block_n),
        ("c_desc_ptr", create_tma_device_descriptor, nargs["C"], block_m, block_n),
    )
    for slot, make_descriptor, tensor, box_rows, box_cols in plan:
        nargs[slot].copy_(make_descriptor(tensor, box_rows, box_cols, dev))

329 

330 

# SQMMA GEMM: computes one BLOCK_M x BLOCK_N tile of C = A @ B through
# Triton's experimental descriptor loads/stores. A, B and C are not
# dereferenced in the body. Data moves through a_desc_ptr / b_desc_ptr /
# c_desc_ptr, which sqmma_descriptor_pre_hook fills from nargs["A"/"B"/"C"]
# (the hook is attached to every config below). The stride_* and dtype
# arguments are unused in the body but appear in the autotune key.
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "mm_general_tma",
        pre_hook=sqmma_descriptor_pre_hook,
        yaml_path=EXPAND_CONFIG_FILENAME,
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else [
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64},
            num_stages=1,
            num_warps=4,
            pre_hook=sqmma_descriptor_pre_hook,
        )
    ],
    key=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    strategy=runtime.get_expand_config(
        "mm_general_tma", yaml_path=EXPAND_CONFIG_FILENAME
    )["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
)
@triton.jit
def mm_sqmma_kernel(
    A,
    B,
    C,
    a_desc_ptr,
    b_desc_ptr,
    c_desc_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    dtype: tl.constexpr,
    GROUP_M: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    # Triton element types for the A/B loads and the C store.
    ab_dtype: tl.constexpr,
    c_dtype: tl.constexpr,
    # True when the operand is column-major (the transpose of a contiguous
    # matrix); its tile is then loaded with swapped box dims and transposed.
    is_transpose_a: tl.constexpr = False,
    is_transpose_b: tl.constexpr = False,
):
    # Grouped program-id ordering (GROUP_M tile-rows per group), same scheme
    # as mm_kernel.
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # Element offsets of this tile; descriptor coordinates are passed as int32.
    offs_am = pid_m * BLOCK_M
    offs_bn = pid_n * BLOCK_N
    offs_k = 0
    offs_am = offs_am.to(tl.int32)
    offs_bn = offs_bn.to(tl.int32)
    offs_k = offs_k.to(tl.int32)
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    tme_load_ab_dtype = ab_dtype
    c_store_dtype = c_dtype
    # ceil(K / BLOCK_K) iterations. No explicit mask is applied to the last
    # partial K tile.
    # NOTE(review): presumably the descriptor zero-fills out-of-range
    # elements; confirm for MUSA TME.
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        if is_transpose_a:
            a = tl._experimental_descriptor_load(
                a_desc_ptr,
                [offs_k, offs_am],
                [BLOCK_K, BLOCK_M],
                tme_load_ab_dtype,
            )
            a = tl.trans(a)
        else:
            a = tl._experimental_descriptor_load(
                a_desc_ptr,
                [offs_am, offs_k],
                [BLOCK_M, BLOCK_K],
                tme_load_ab_dtype,
            )
        if is_transpose_b:
            b = tl._experimental_descriptor_load(
                b_desc_ptr,
                [offs_bn, offs_k],
                [BLOCK_N, BLOCK_K],
                tme_load_ab_dtype,
            )
            b = tl.trans(b)
        else:
            b = tl._experimental_descriptor_load(
                b_desc_ptr,
                [offs_k, offs_bn],
                [BLOCK_K, BLOCK_N],
                tme_load_ab_dtype,
            )
        accumulator += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
        offs_k += BLOCK_K
    # The full tile is stored without a mask.
    # NOTE(review): edge tiles past M/N presumably rely on the descriptor's
    # bounds clipping; confirm.
    accumulator = accumulator.to(c_store_dtype)
    tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn])

433 

434 

def get_triton_type(elem_type):
    """Map a torch dtype to the Triton dtype used by the SQMMA kernel.

    Supported: float16, bfloat16, float8_e4m3fn. Any other dtype (for
    example ``torch.float32``) maps to ``None``.
    """
    triton_by_torch = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float8_e4m3fn: tl.float8e4nv,
    }
    try:
        return triton_by_torch[elem_type]
    except KeyError:
        return None

442 

443 

def _prepare_sqmma_operand(t):
    """Return ``(tensor, is_transposed)`` in a layout the SQMMA kernel reads.

    A contiguous tensor is used as-is. A column-major tensor (the transpose
    of a contiguous matrix) is used as-is and flagged as transposed. Any
    other layout is copied to a contiguous tensor.
    """
    if t.is_contiguous():
        return t, False
    if t.stride(0) == 1 and t.stride(1) == t.shape[0]:
        return t, True
    return t.contiguous(), False


def mm_sqmma(A, B, M, N, K, GROUP_M):
    """Compute ``A @ B`` with ``mm_sqmma_kernel`` into a new ``(M, N)`` tensor.

    Args:
        A: ``(M, K)`` matrix.
        B: ``(K, N)`` matrix with the same dtype as ``A``.
        M, N, K: problem sizes.
        GROUP_M: tile-row grouping factor for program-id reordering.

    Returns:
        A new ``(M, N)`` tensor with ``A``'s dtype.

    Raises:
        AssertionError: if the dtypes or the inner dimensions differ.
    """
    logger.debug("GEMS_MTHREADS MM(SQMMA)")
    device = A.device
    A, is_transpose_a = _prepare_sqmma_operand(A)
    B, is_transpose_b = _prepare_sqmma_operand(B)
    a_type = A.dtype
    b_type = B.dtype
    assert a_type == b_type, "Mat A and Mat B should have the same dtype"
    assert A.shape[1] == B.shape[0], "incompatible dimensions"
    c_dtype = get_higher_dtype(a_type, b_type)
    C = torch.empty((M, N), dtype=c_dtype, device=device)
    # 64-byte scratch buffers. sqmma_descriptor_pre_hook copies the A/B/C
    # descriptors into them before the kernel launch.
    desc_a = torch.empty((64,), dtype=torch.int8, device=device)
    desc_b = torch.empty((64,), dtype=torch.int8, device=device)
    desc_c = torch.empty((64,), dtype=torch.int8, device=device)
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
        1,
        1,
    )
    # Launch under A's device context, like the other launchers in this module.
    with torch_device_fn.device(A.device):
        mm_sqmma_kernel[grid](
            A,
            B,
            C,
            desc_a,
            desc_b,
            desc_c,
            M,
            N,
            K,
            A.stride(0),
            A.stride(1),
            B.stride(0),
            B.stride(1),
            str(a_type).split(".")[-1],
            GROUP_M=GROUP_M,
            ab_dtype=get_triton_type(a_type),
            c_dtype=get_triton_type(c_dtype),
            is_transpose_a=is_transpose_a,
            is_transpose_b=is_transpose_b,
        )
    return C

495 

496 

def mm(a, b):
    """Compute ``a @ b`` via the GEMV, SQMMA or FMA implementation.

    Dispatch:
      * ``N == 1`` uses ``gemv_mm``;
      * otherwise, if ``is_sqmma_compatible`` holds, ``mm_sqmma``;
      * otherwise ``mm_fma``.

    ``MUSA_ENABLE_SQMMA`` is set to "1" for non-fp32 inputs and removed for
    fp32 inputs while the call runs. Its previous value is restored on exit.
    NOTE: ``os.environ`` is process-global, so concurrent calls from several
    threads can observe each other's setting.

    Args:
        a: ``(M, K)`` matrix.
        b: ``(K, N)`` matrix.

    Returns:
        A new ``(M, N)`` tensor of dtype ``get_higher_dtype(a.dtype, b.dtype)``.

    Raises:
        AssertionError: if ``a.shape[1] != b.shape[0]``.
    """
    a_dtype = a.dtype
    b_dtype = b.dtype
    M, K = a.shape
    _, N = b.shape
    # Validate before dispatch. The GEMV path indexes b using a's K without
    # re-checking, so a mismatch would read past the end of b.
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    # fp32 does not support MMA instructions, only enable SQMMA for fp16/bf16
    need_sqmma = a_dtype != torch.float32 and b_dtype != torch.float32
    prev_sqmma = os.environ.get("MUSA_ENABLE_SQMMA")
    if need_sqmma:
        os.environ["MUSA_ENABLE_SQMMA"] = "1"
    else:
        os.environ.pop("MUSA_ENABLE_SQMMA", None)
    try:
        if N == 1:
            c_dtype = get_higher_dtype(a_dtype, b_dtype)
            c = torch.empty((M, N), device=a.device, dtype=c_dtype)
            return gemv_mm(a, b, c, M, K)

        if is_sqmma_compatible(a, b, N, K):
            GROUP_M = 8
            return mm_sqmma(a, b, M, N, K, GROUP_M)
        return mm_fma(a, b)
    finally:
        # Restore the caller's environment exactly.
        if prev_sqmma is None:
            os.environ.pop("MUSA_ENABLE_SQMMA", None)
        else:
            os.environ["MUSA_ENABLE_SQMMA"] = prev_sqmma