Coverage for src/flag_gems/runtime/backend/_mthreads/ops/mm.py: 0%

196 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7from triton.tools.tensor_descriptor import TensorDescriptor 

8 

9from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import libentry, libtuner 

12from flag_gems.utils import triton_lang_extension as ext 

13 

logger = logging.getLogger("flag_gems.runtime.backend._mthreads.ops.mm")

# Absolute, normalized path of the tuning-expansion YAML, which sits one
# directory above this module.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "..", "mm_mthreads_expand.yaml")
)

# Capability switch for the SQMMA (tensor-descriptor) path. It is computed a
# single time when the module is imported, so callers only ever read a plain
# boolean. The (major, minor) pair of the installed Triton is compared with
# (3, 2): Triton 3.1 and older yield False, Triton 3.2 and newer yield True.
SQMMA_ON = tuple(map(int, triton.__version__.split(".")[:2])) >= (3, 2)

24 

25 

def is_supported_sqmma_layout(tensor):
    """Tell whether ``tensor`` has a memory layout the SQMMA path accepts.

    Accepted layouts are a contiguous tensor, or a tensor whose first
    dimension has unit stride and whose second dimension has a stride equal
    to ``shape[0]`` (i.e. the transpose of a contiguous matrix).
    """
    if tensor.is_contiguous():
        return True
    # Column-major check: walking down a column moves one element, walking
    # across a row jumps by the number of rows.
    return tensor.stride(0) == 1 and tensor.stride(1) == tensor.shape[0]

30 

31 

def is_sqmma_compatible(a, b, N, K):
    """Decide whether ``a @ b`` may be dispatched to the SQMMA kernel.

    All of the following must hold: the SQMMA feature flag is on, both
    operands are 2-D, they share one dtype that is float16 or bfloat16, both
    have a supported memory layout, and ``N`` and ``K`` are multiples of 8.

    Args:
        a: Left operand.
        b: Right operand.
        N: Number of columns of ``b``.
        K: Shared inner dimension.

    Returns:
        True when every condition is met, otherwise False.
    """
    # The checks run in the same order as a single short-circuit chain would.
    if not SQMMA_ON:
        return False
    if a.dim() != 2 or b.dim() != 2:
        return False
    if a.dtype != b.dtype or a.dtype not in (torch.float16, torch.bfloat16):
        return False
    if not (is_supported_sqmma_layout(a) and is_supported_sqmma_layout(b)):
        return False
    return N % 8 == 0 and K % 8 == 0

44 

45 

@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of b that is strictly smaller than a (for positive
    # a and b): round a up to a multiple of b, then step back by one b.
    rounded_up = tl.cdiv(a, b) * b
    return rounded_up - b

50 

51 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=5,
    flagtune_op_name="mm",
    flagtune_expand_op_name="mm",
    flagtune_yaml_path=EXPAND_CONFIG_FILENAME,
)
@triton.jit
def mm_kernel(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    dtype: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    # Tiled matrix multiplication C[M, N] = A[M, K] @ B[K, N]. Each program
    # produces one BLOCK_M x BLOCK_N tile of C.
    #
    # A, B, C        : base pointers of the operands and the output.
    # M, N, K        : problem sizes.
    # stride_*       : element strides of A (m, k), B (k, n) and C (m, n).
    # dtype          : not read by the kernel body.
    # BLOCK_*/GROUP_M: tile sizes and the row-group size used when mapping
    #                  the 1-D program id onto tiles.
    # IS_FP64        : accumulate in float64 instead of float32.
    pid = ext.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance: programs are grouped
    # GROUP_M tile-rows at a time.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Row/column indices are wrapped with % M / % N so that the loads below
    # stay in bounds without an M/N mask; out-of-range rows/columns are
    # discarded by the masked store at the end. Indices are widened to int64
    # before being multiplied by strides.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
    # Everything before this K offset is covered by full, unmasked K blocks.
    prev_multiple = prev_multiple_of(K, BLOCK_K)

    if IS_FP64:
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
    else:
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for start_k in range(0, prev_multiple, BLOCK_K):
        rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
        a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
        b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
        if a.dtype != b.dtype:
            # Mixed-dtype operands are both converted to the output type.
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        if IS_FP64:
            acc += tl.dot(a, b, allow_tf32=False)
        else:
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    # loop peeling: the last (possibly partial) K block is handled with a
    # mask. `other=0.0` is required here: without it the masked-out lanes
    # hold undefined values that would be fed straight into tl.dot.
    rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
    mask_k = rk < K
    a = tl.load(
        A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
        mask=mask_k[None, :],
        other=0.0,
    )
    b = tl.load(
        B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
        mask=mask_k[:, None],
        other=0.0,
    )
    if a.dtype != b.dtype:
        a = a.to(C.dtype.element_ty)
        b = b.to(C.dtype.element_ty)
    if IS_FP64:
        acc += tl.dot(a, b, allow_tf32=False)
    else:
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
    rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    # Only rows < M and columns < N belong to the output.
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    tl.store(C, acc, mask=mask)

144 

145 

@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK_M": 64, "BLOCK_K": 64}),
        triton.Config({"BLOCK_M": 128, "BLOCK_K": 64}),
    ],
    key=["M", "K", "stride_am", "stride_bk"],
    strategy=["align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
    flagtune_op_name="mm",
    flagtune_expand_op_name="gemv",
    flagtune_yaml_path=EXPAND_CONFIG_FILENAME,
)
@triton.jit
def gemv_kernel(
    A,
    B,
    C,
    M,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_cm,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Matrix-vector product: each program reduces BLOCK_M rows of A against
    # the vector addressed through B / stride_bk and writes BLOCK_M outputs
    # through C / stride_cm.
    pid = ext.program_id(0)

    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    rows_in_range = rows < M

    # NOTE(review): the accumulator is always float32, also when the inputs
    # are float64 - confirm that this is acceptable for fp64 callers.
    partial = tl.zeros((BLOCK_M,), dtype=tl.float32)

    for k_base in range(0, K, BLOCK_K):
        cols = k_base + tl.arange(0, BLOCK_K)
        cols_in_range = cols < K

        # Out-of-range rows/columns are loaded as 0 so they do not
        # contribute to the sum.
        a_tile = tl.load(
            A + rows[:, None] * stride_am + cols[None, :] * stride_ak,
            mask=rows_in_range[:, None] & cols_in_range[None, :],
            other=0.0,
        )
        b_vec = tl.load(B + cols * stride_bk, mask=cols_in_range, other=0.0)

        # Keep the reduction in fp32 so N=1 GEMV matches the mm path more closely.
        a_tile = a_tile.to(tl.float32)
        b_vec = b_vec.to(tl.float32)
        partial += tl.sum(a_tile * b_vec[None, :], axis=1)

    out_ptrs = C + rows * stride_cm
    partial = partial.to(C.dtype.element_ty)
    tl.store(out_ptrs, partial, mask=rows_in_range)

200 

201 

# Floating-point dtypes ranked from lowest to highest; a later entry wins
# when two different dtypes are combined.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32, torch.float64]


def get_higher_dtype(a, b):
    """Return whichever of the two dtypes ranks higher in ``_ordered_datatypes``.

    Identical dtypes are returned unchanged without consulting the ranking.

    Raises:
        AssertionError: if the dtypes differ and either one is not ranked.
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    # Both are ranked and distinct, so the larger list position decides.
    return max(a, b, key=_ordered_datatypes.index)

217 

218 

def mm_fma(a, b):
    """Multiply two 2-D tensors with the generic ``mm_kernel``.

    Args:
        a: Left operand of shape ``(M, K)``.
        b: Right operand of shape ``(K, N)``.

    Returns:
        A newly allocated ``(M, N)`` tensor on ``a``'s device whose dtype is
        the higher-ranked of the two input dtypes.

    Raises:
        AssertionError: if the inner dimensions differ, or the dtypes differ
            and one of them is not a ranked floating-point dtype.
    """
    logger.debug("GEMS_MTHREADS MM(FMA)")
    device = a.device
    # handle non-contiguous inputs if necessary: a copy is made only when
    # neither dimension has a stride of at most 1.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # allocates output
    c_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=device, dtype=c_dtype)
    # launch kernel: one program per output tile.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )
    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            dtype=str(a.dtype).split(".")[-1],
            GROUP_M=8,
            # The kernel converts mixed-dtype operands to the output dtype,
            # so the accumulator precision has to follow c_dtype; looking at
            # a.dtype alone would accumulate fp32 x fp64 in float32.
            IS_FP64=c_dtype == torch.float64,
        )
    return c

257 

258 

def gemv_mm(a, b, c, M, K):
    """Run ``gemv_kernel`` to compute ``a @ b`` into ``c`` and return ``c``.

    Args:
        a: Matrix operand; its two strides are passed to the kernel.
        b: Vector operand; only its first stride is passed to the kernel.
        c: Output tensor; only its first stride is passed to the kernel.
        M: Number of rows of ``a``.
        K: Reduction length.
    """
    logger.debug(
        "GEMS_MTHREADS MM(GEMV), [shape info]: [%s, %s, 1](M, K, N)",
        M,
        K,
    )

    def grid(META):
        # One program per block of BLOCK_M rows.
        return (triton.cdiv(M, META["BLOCK_M"]),)

    with torch_device_fn.device(a.device):
        launch_args = (
            a,
            b,
            c,
            M,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            c.stride(0),
        )
        gemv_kernel[grid](*launch_args)
    return c

279 

280 

def mm_out(a, b, *, out):
    """Compute ``a @ b`` into the caller-provided tensor ``out``.

    Args:
        a: Left operand of shape ``(M, K)``.
        b: Right operand of shape ``(K, N)``.
        out: Destination tensor; it is written in place and returned.
            NOTE(review): its shape is not validated here - presumably the
            caller guarantees ``(M, N)``; confirm.

    Returns:
        ``out``.

    Raises:
        AssertionError: if the inner dimensions of ``a`` and ``b`` differ.
    """
    logger.debug("GEMS_MTHREADS MM_OUT")
    # handle non-contiguous inputs if necessary
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    c = out
    if N == 1:
        # A single output column is handled by the GEMV kernel.
        return gemv_mm(a, b, c, M, K)
    # The kernel multiplies the operands in their own dtype when they agree
    # and converts both to the output dtype when they differ; the
    # accumulator precision must follow that dtype rather than a.dtype alone.
    operand_dtype = a.dtype if a.dtype == b.dtype else c.dtype
    # launch kernel: one program per output tile.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )
    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            dtype=str(a.dtype).split(".")[-1],
            GROUP_M=8,
            IS_FP64=operand_dtype == torch.float64,
        )
    return c

319 

320 

def sqmma_descriptor_pre_hook(nargs):
    """Set each tensor descriptor's ``block_shape`` from the chosen tile sizes.

    ``nargs`` maps kernel argument names to their values. The A, B and C
    descriptors receive ``[BLOCK_M, BLOCK_K]``, ``[BLOCK_K, BLOCK_N]`` and
    ``[BLOCK_M, BLOCK_N]`` respectively.
    """
    tile_dims = (
        ("a_desc", "BLOCK_M", "BLOCK_K"),
        ("b_desc", "BLOCK_K", "BLOCK_N"),
        ("c_desc", "BLOCK_M", "BLOCK_N"),
    )
    for desc_name, first_dim, second_dim in tile_dims:
        nargs[desc_name].block_shape = [nargs[first_dim], nargs[second_dim]]

325 

326 

@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "mm_general_tma",
        pre_hook=sqmma_descriptor_pre_hook,
        yaml_path=EXPAND_CONFIG_FILENAME,
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else [
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_M": 8},
            num_stages=1,
            num_warps=4,
            pre_hook=sqmma_descriptor_pre_hook,
        )
    ],
    key=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    strategy=runtime.get_expand_config(
        "mm_general_tma", yaml_path=EXPAND_CONFIG_FILENAME
    )["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
)
@triton.jit
def mm_sqmma_kernel(
    a_desc,
    b_desc,
    c_desc,
    M,
    N,
    K,
    dtype: tl.constexpr,
    GROUP_M: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Descriptor-based matrix multiplication: each program computes one
    # BLOCK_M x BLOCK_N output tile, reading A and B tiles and writing the C
    # tile through tensor descriptors. `dtype` is not read by the body.
    pid = ext.program_id(0)
    tiles_m = tl.cdiv(M, BLOCK_M)
    tiles_n = tl.cdiv(N, BLOCK_N)
    # Grouped mapping of the 1-D program id onto (pid_m, pid_n), GROUP_M
    # tile-rows per group.
    programs_per_group = GROUP_M * tiles_n
    group_index = pid // programs_per_group
    rows_in_group = min(tiles_m - group_index * GROUP_M, GROUP_M)
    pid_m = group_index * GROUP_M + (pid % rows_in_group)
    pid_n = (pid % programs_per_group) // (rows_in_group)
    # Descriptor coordinates are passed as int32.
    offs_am = (pid_m * BLOCK_M).to(tl.int32)
    offs_bn = (pid_n * BLOCK_N).to(tl.int32)
    offs_k = 0
    offs_k = offs_k.to(tl.int32)
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k_step in range(0, tl.cdiv(K, BLOCK_K)):
        a_tile = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k])
        b_tile = tl.load_tensor_descriptor(b_desc, [offs_k, offs_bn])
        accumulator = tl.dot(a_tile, b_tile, acc=accumulator)
        offs_k += BLOCK_K
    # Convert the float32 accumulator to the descriptor's element type on
    # the way out.
    tl.store_tensor_descriptor(c_desc, [offs_am, offs_bn], accumulator.to(c_desc.dtype))

385 

386 

def mm_sqmma(A, B, M, N, K):
    """Multiply ``A @ B`` with the descriptor-based SQMMA kernel.

    Args:
        A: Left operand; copied to a contiguous tensor if it is not one.
        B: Right operand; copied to a contiguous tensor if it is not one.
        M: Number of rows of the result.
        N: Number of columns of the result.
        K: Shared inner dimension.

    Returns:
        A newly allocated ``(M, N)`` tensor on ``A``'s device with the
        operands' dtype.

    Raises:
        AssertionError: if ``A`` and ``B`` have different dtypes.
    """
    logger.debug("GEMS_MTHREADS MM(SQMMA)")
    device = A.device
    if not A.is_contiguous():
        A = A.contiguous()
    if not B.is_contiguous():
        B = B.contiguous()
    a_type = A.dtype
    b_type = B.dtype
    assert a_type == b_type, "Mat A and Mat B should have the same dtype"
    # Both dtypes are equal at this point, so the output simply reuses it.
    C = torch.empty((M, N), dtype=a_type, device=device)
    # The [1, 1] block shapes are placeholders; the autotuner pre-hook
    # overwrites block_shape with the selected tile sizes before launch.
    desc_a = TensorDescriptor.from_tensor(A, [1, 1])
    desc_b = TensorDescriptor.from_tensor(B, [1, 1])
    desc_c = TensorDescriptor.from_tensor(C, [1, 1])
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
        1,
        1,
    )
    # Launch under the operands' device, exactly like the other launchers in
    # this module, so the kernel does not run on whatever device happens to
    # be current.
    with torch_device_fn.device(device):
        mm_sqmma_kernel[grid](
            desc_a,
            desc_b,
            desc_c,
            M,
            N,
            K,
            str(a_type).split(".")[-1],
        )
    return C

417 

418 

def mm(a, b):
    """Matrix-multiply two 2-D tensors, picking the best available kernel.

    Dispatch order:
      1. ``N == 1``: GEMV kernel, output dtype is the higher-ranked input dtype.
      2. Inputs accepted by ``is_sqmma_compatible``: SQMMA kernel.
      3. Everything else: generic FMA kernel.

    Args:
        a: Left operand of shape ``(M, K)``.
        b: Right operand of shape ``(K, N)``.

    Returns:
        A newly allocated ``(M, N)`` result tensor.

    Raises:
        AssertionError: if the inner dimensions of ``a`` and ``b`` differ.
    """
    a_dtype = a.dtype
    b_dtype = b.dtype
    M, K = a.shape
    b_rows, N = b.shape
    # Validate before dispatching: only the FMA path checks this itself, so
    # the GEMV and SQMMA paths would otherwise read B with the wrong extent.
    assert K == b_rows, "incompatible dimensions"

    if N == 1:
        c_dtype = get_higher_dtype(a_dtype, b_dtype)
        c = torch.empty((M, N), device=a.device, dtype=c_dtype)
        return gemv_mm(a, b, c, M, K)

    if is_sqmma_compatible(a, b, N, K):
        return mm_sqmma(a, b, M, N, K)
    return mm_fma(a, b)