Coverage for src/flag_gems/runtime/backend/_mthreads/ops/mm.py: 0%
227 statements
coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry, libtuner
11from flag_gems.utils import triton_lang_extension as tle
13from .utils import create_tma_device_descriptor, get_cached_tma_device_descriptor
logger = logging.getLogger("flag_gems.runtime.backend._mthreads.ops.mm")

# Tuning-space YAML consulted when USE_FLAGTUNE=1; it sits in the parent
# directory of this ops/ package.
_this_dir = os.path.dirname(__file__)
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(_this_dir, "..", "mm_mthreads_expand.yaml")
)
del _this_dir

# Capability flag, computed once at import time from the (major, minor)
# Triton version: True for Triton >= 3.2, False for older releases such as
# 3.1. Reused as a constant for the lifetime of the process.
SQMMA_ON = tuple(map(int, triton.__version__.split(".")[:2])) >= (3, 2)
def is_supported_sqmma_layout(tensor):
    """Return True if ``tensor`` is row-major contiguous or exactly column-major.

    Column-major here means unit stride along dim 0 and a dim-1 stride of
    exactly ``shape[0]`` elements (the transpose of a contiguous tensor).
    """
    if tensor.is_contiguous():
        return True
    if tensor.stride(0) != 1:
        return False
    return tensor.stride(1) == tensor.shape[0]
def is_sqmma_compatible(a, b, N, K):
    """Return True if ``a @ b`` may take the SQMMA (descriptor-based) path.

    Conditions:
      * ``SQMMA_ON`` (Triton >= 3.2);
      * both operands are 2-D and share an fp16/bf16 dtype;
      * each operand is row-major or column-major
        (``is_supported_sqmma_layout``);
      * ``N`` and ``K`` are multiples of 8, and ``M`` (``a.shape[0]``) is a
        multiple of 8 when ``a`` is column-major.

    Args:
        a: left operand.
        b: right operand.
        N: number of output columns.
        K: inner dimension.
    """
    if not SQMMA_ON:
        return False
    if a.dim() != 2 or b.dim() != 2:
        return False
    if a.dtype != b.dtype or a.dtype not in (torch.float16, torch.bfloat16):
        return False
    if not (is_supported_sqmma_layout(a) and is_supported_sqmma_layout(b)):
        return False
    if N % 8 != 0 or K % 8 != 0:
        return False
    # NOTE(review): the N/K checks line up with the leading stride of every
    # descriptor except a column-major ``a``, whose leading stride is M
    # (``a.stride(1) == a.shape[0]``). This presumably needs the same 16-byte
    # alignment (8 elements of a 2-byte dtype). Confirm against the MUSA
    # descriptor requirements. Rejecting here only falls back to the FMA path.
    if not a.is_contiguous() and a.shape[0] % 8 != 0:
        return False
    return True
@triton.jit
def prev_multiple_of(a, b):
    """Return the largest multiple of ``b`` that is strictly less than ``a``.

    Examples: (100, 64) -> 64, (128, 64) -> 64, (0, 64) -> -64.
    """
    return (tl.cdiv(a, b) - 1) * b
@libentry()
@libtuner(
    configs=runtime.ops_get_configs("mm", yaml_path=EXPAND_CONFIG_FILENAME)
    if os.environ.get("USE_FLAGTUNE") == "1"
    else runtime.get_tuned_config("mm"),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=runtime.get_expand_config("mm", yaml_path=EXPAND_CONFIG_FILENAME)[
        "strategy"
    ]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def mm_kernel(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    dtype: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """Tiled GEMM ``C = A @ B`` with fp32 accumulation.

    Each program computes one BLOCK_M x BLOCK_N tile of C. Program ids are
    swizzled in groups of GROUP_M tile-rows. The K loop runs unmasked over
    whole BLOCK_K chunks. The last (possibly partial, possibly empty) chunk
    is peeled off and loaded with a K mask. Mixed input dtypes are cast to
    C's element type before ``tl.dot``. ``dtype`` is not read in the body.
    """
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Wrap out-of-range rows/cols back into [0, M) / [0, N) so the loads need
    # no M/N mask; results for wrapped lanes are discarded by the masked store.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
    prev_multiple = prev_multiple_of(K, BLOCK_K)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for start_k in range(0, prev_multiple, BLOCK_K):
        rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
        a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
        b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    # loop peeling: the final chunk [prev_multiple, prev_multiple + BLOCK_K).
    rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
    # When K == 0, prev_multiple == -BLOCK_K, so rk is negative and `rk < K`
    # alone would be true. Exclude negative indices so nothing before the
    # start of A/B is read.
    mask_k = (rk >= 0) & (rk < K)
    # Masked lanes are filled with zero so they add nothing to the dot.
    a = tl.load(
        A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
        mask=mask_k[None, :],
        other=0.0,
    )
    b = tl.load(
        B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
        mask=mask_k[:, None],
        other=0.0,
    )
    if a.dtype != b.dtype:
        a = a.to(C.dtype.element_ty)
        b = b.to(C.dtype.element_ty)
    acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn (unwrapped, int64) to save registers
    rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
    rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    # Drop rows/cols beyond M/N (the lanes that were wrapped for the loads).
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    tl.store(C, acc, mask=mask)
@libentry()
@libtuner(
    configs=runtime.ops_get_configs("gemv", yaml_path=EXPAND_CONFIG_FILENAME)
    if os.environ.get("USE_FLAGTUNE") == "1"
    else [triton.Config({"BLOCK_M": 64, "BLOCK_K": 64})],
    key=["M", "K", "stride_am", "stride_bk"],
    strategy=runtime.get_expand_config("gemv", yaml_path=EXPAND_CONFIG_FILENAME)[
        "strategy"
    ]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
)
@triton.jit
def gemv_kernel(
    A,
    B,
    C,
    M,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_cm,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Matrix-vector product ``C[:, 0] = A @ B[:, 0]``, used for N == 1.

    Each program reduces BLOCK_M rows of A against B, BLOCK_K elements at a
    time, with out-of-range rows/columns masked to zero.
    """
    rows = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    rows_ok = rows < M
    a_row_base = A + rows[:, None] * stride_am

    partial = tl.zeros((BLOCK_M,), dtype=tl.float32)
    for k0 in range(0, K, BLOCK_K):
        cols = k0 + tl.arange(0, BLOCK_K)
        cols_ok = cols < K
        a_tile = tl.load(
            a_row_base + cols[None, :] * stride_ak,
            mask=rows_ok[:, None] & cols_ok[None, :],
            other=0.0,
        )
        b_vec = tl.load(B + cols * stride_bk, mask=cols_ok, other=0.0)
        # Reduce in fp32 so the N == 1 result tracks the mm path closely.
        partial += tl.sum(a_tile.to(tl.float32) * b_vec.to(tl.float32)[None, :], axis=1)

    tl.store(C + rows * stride_cm, partial.to(C.dtype.element_ty), mask=rows_ok)
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]


def get_higher_dtype(a, b):
    """Return whichever dtype ranks later in ``float16 < bfloat16 < float32``.

    Identical dtypes are returned unchanged, even if they are not in the
    ranking. Otherwise both must appear in ``_ordered_datatypes``, or an
    ``AssertionError`` is raised.
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    rank_of_a = _ordered_datatypes.index(a)
    rank_of_b = _ordered_datatypes.index(b)
    return b if rank_of_b > rank_of_a else a
def mm_fma(a, b):
    """Compute ``a @ b`` with the generic (non-SQMMA) Triton GEMM kernel.

    An operand whose two strides both exceed 1 (neither row- nor
    column-major) is copied to contiguous storage first. The result dtype is
    ``get_higher_dtype(a.dtype, b.dtype)``.
    """
    logger.debug("GEMS_MTHREADS MM(FMA)")
    device = a.device

    # mm_kernel walks arbitrary strides, but only unit-stride operands load
    # efficiently; materialize anything with no unit stride.
    if min(a.stride(0), a.stride(1)) > 1:
        a = a.contiguous()
    if min(b.stride(0), b.stride(1)) > 1:
        b = b.contiguous()

    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape

    result = torch.empty(
        (M, N), device=device, dtype=get_higher_dtype(a.dtype, b.dtype)
    )

    def grid(meta):
        # One program per output tile, flattened into a 1-D launch.
        return (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),)

    with torch_device_fn.device(device):
        mm_kernel[grid](
            a,
            b,
            result,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            result.stride(0),
            result.stride(1),
            dtype=str(a.dtype).split(".")[-1],
            GROUP_M=8,
        )
    return result
def gemv_mm(a, b, c, M, K):
    """Launch ``gemv_kernel`` to write ``a @ b`` (``b`` has one column) into ``c``.

    Returns ``c``. Only ``b``'s and ``c``'s dim-0 strides are passed, since
    each has a single column.
    """
    logger.debug(
        "GEMS_MTHREADS MM(GEMV), [shape info]: [%s, %s, 1](M, K, N)", M, K
    )

    def grid(meta):
        return (triton.cdiv(M, meta["BLOCK_M"]),)

    stride_args = (a.stride(0), a.stride(1), b.stride(0), c.stride(0))
    with torch_device_fn.device(a.device):
        gemv_kernel[grid](a, b, c, M, K, *stride_args)
    return c
def mm_out(a, b, *, out):
    """Compute ``a @ b`` into the preallocated tensor ``out`` and return it.

    Uses the GEMV kernel when ``N == 1`` and the generic Triton GEMM
    otherwise; the SQMMA path is not used here. An operand whose two strides
    both exceed 1 is made contiguous first.

    Raises:
        AssertionError: if ``a.shape[1] != b.shape[0]`` or ``out`` is not
            ``(M, N)``.
    """
    logger.debug("GEMS_MTHREADS MM_OUT")
    # handle non-contiguous inputs if necessary
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    c = out
    # Both kernels store through c's strides with masks bounded only by M and
    # N. A differently shaped ``out`` would therefore be written out of bounds.
    assert tuple(c.shape) == (M, N), "out has incompatible shape"
    if N == 1:
        return gemv_mm(a, b, c, M, K)
    # launch kernel: one program per (BLOCK_M x BLOCK_N) output tile
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )
    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            dtype=str(a.dtype).split(".")[-1],
            GROUP_M=8,
        )
    return c
def sqmma_descriptor_pre_hook(nargs):
    """Fill the three descriptor buffers of ``mm_sqmma_kernel`` before a launch.

    Reads tensors ``A``/``B``/``C`` and the chosen ``BLOCK_*`` sizes from the
    launch arguments, then copies in one descriptor per operand. The box
    shapes are (BLOCK_M, BLOCK_K) for A, (BLOCK_K, BLOCK_N) for B and
    (BLOCK_M, BLOCK_N) for C. A and B go through
    ``get_cached_tma_device_descriptor``. C's descriptor is always rebuilt
    with ``create_tma_device_descriptor``, presumably because C is freshly
    allocated on every call.
    """
    a, b, c = nargs["A"], nargs["B"], nargs["C"]
    bm, bn, bk = nargs["BLOCK_M"], nargs["BLOCK_N"], nargs["BLOCK_K"]
    dev = c.device

    a_desc = get_cached_tma_device_descriptor(a, bm, bk, dev)
    nargs["a_desc_ptr"].copy_(a_desc)
    b_desc = get_cached_tma_device_descriptor(b, bk, bn, dev)
    nargs["b_desc_ptr"].copy_(b_desc)
    c_desc = create_tma_device_descriptor(c, bm, bn, dev)
    nargs["c_desc_ptr"].copy_(c_desc)
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "mm_general_tma",
        pre_hook=sqmma_descriptor_pre_hook,
        yaml_path=EXPAND_CONFIG_FILENAME,
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else [
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64},
            num_stages=1,
            num_warps=4,
            pre_hook=sqmma_descriptor_pre_hook,
        )
    ],
    key=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    strategy=runtime.get_expand_config(
        "mm_general_tma", yaml_path=EXPAND_CONFIG_FILENAME
    )["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
)
@triton.jit
def mm_sqmma_kernel(
    A,
    B,
    C,
    a_desc_ptr,
    b_desc_ptr,
    c_desc_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    dtype: tl.constexpr,
    GROUP_M: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    ab_dtype: tl.constexpr,
    c_dtype: tl.constexpr,
    is_transpose_a: tl.constexpr = False,
    is_transpose_b: tl.constexpr = False,
):
    """GEMM ``C = A @ B`` using descriptor-based tile loads and stores.

    ``a_desc_ptr``/``b_desc_ptr``/``c_desc_ptr`` are device buffers filled by
    ``sqmma_descriptor_pre_hook`` before each launch. Each program computes one
    BLOCK_M x BLOCK_N tile of C. It accumulates ``tl.dot`` results in fp32
    over cdiv(K, BLOCK_K) steps, casts to ``c_dtype`` and stores through
    ``c_desc_ptr``.

    ``A``, ``B`` and ``C`` are not read in the body; the pre-hook reads them.
    ``stride_am``, ``stride_bk`` and ``dtype`` are autotuning keys. There is
    no explicit masking. NOTE(review): edge tiles rely on the descriptor
    loads/stores bounding out-of-range boxes; confirm this against the
    descriptor helpers.

    ``is_transpose_a`` / ``is_transpose_b`` select operands stored
    column-major. Those are loaded as transposed boxes and flipped with
    ``tl.trans``.
    """
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # Grouped program-id swizzle over GROUP_M tile-rows (same scheme as
    # mm_kernel).
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # Element offsets of this tile's origin, in int32 for the descriptor
    # load/store coordinates.
    offs_am = pid_m * BLOCK_M
    offs_bn = pid_n * BLOCK_N
    offs_k = 0
    offs_am = offs_am.to(tl.int32)
    offs_bn = offs_bn.to(tl.int32)
    offs_k = offs_k.to(tl.int32)
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    tme_load_ab_dtype = ab_dtype
    c_store_dtype = c_dtype
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        if is_transpose_a:
            # A is column-major: fetch a [BLOCK_K, BLOCK_M] box at (k, m)
            # and transpose it into the [BLOCK_M, BLOCK_K] tile tl.dot expects.
            a = tl._experimental_descriptor_load(
                a_desc_ptr,
                [offs_k, offs_am],
                [BLOCK_K, BLOCK_M],
                tme_load_ab_dtype,
            )
            a = tl.trans(a)
        else:
            a = tl._experimental_descriptor_load(
                a_desc_ptr,
                [offs_am, offs_k],
                [BLOCK_M, BLOCK_K],
                tme_load_ab_dtype,
            )
        if is_transpose_b:
            # B is column-major: fetch a [BLOCK_N, BLOCK_K] box at (n, k)
            # and transpose it into [BLOCK_K, BLOCK_N].
            b = tl._experimental_descriptor_load(
                b_desc_ptr,
                [offs_bn, offs_k],
                [BLOCK_N, BLOCK_K],
                tme_load_ab_dtype,
            )
            b = tl.trans(b)
        else:
            b = tl._experimental_descriptor_load(
                b_desc_ptr,
                [offs_k, offs_bn],
                [BLOCK_K, BLOCK_N],
                tme_load_ab_dtype,
            )
        accumulator += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
        offs_k += BLOCK_K
    # Cast the fp32 accumulator to the output dtype and write the whole tile
    # through the C descriptor.
    accumulator = accumulator.to(c_store_dtype)
    tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn])
def get_triton_type(elem_type):
    """Translate a torch dtype into the matching Triton element type.

    Only float16, bfloat16 and float8_e4m3fn are mapped. Any other dtype
    (e.g. float32) yields ``None``.
    """
    torch_to_triton = dict(
        [
            (torch.float16, tl.float16),
            (torch.bfloat16, tl.bfloat16),
            (torch.float8_e4m3fn, tl.float8e4nv),
        ]
    )
    return torch_to_triton.get(elem_type)
def mm_sqmma(A, B, M, N, K, GROUP_M):
    """Compute ``A @ B`` with the descriptor-based SQMMA kernel.

    Args:
        A: left operand with ``M`` rows and ``K`` columns.
        B: right operand with ``K`` rows and ``N`` columns; must share A's
            dtype (expected fp16/bf16, the dtypes ``is_sqmma_compatible``
            admits).
        M, N, K: problem sizes.
        GROUP_M: number of tile-rows per group in the program-id swizzle.

    Returns:
        A new ``(M, N)`` tensor of the operands' dtype.

    Raises:
        AssertionError: if A and B have different dtypes.
    """
    logger.debug("GEMS_MTHREADS MM(SQMMA)")
    device = A.device
    # A column-major operand (stride(0) == 1, stride(1) == shape[0]) is used
    # in place via transposed tile loads. Any other non-contiguous layout is
    # copied to row-major first.
    is_transpose_a = False
    is_transpose_b = False
    if not A.is_contiguous():
        if A.stride(0) == 1 and A.stride(1) == A.shape[0]:
            is_transpose_a = True
        else:
            A = A.contiguous()
    if not B.is_contiguous():
        if B.stride(0) == 1 and B.stride(1) == B.shape[0]:
            is_transpose_b = True
        else:
            B = B.contiguous()
    a_type = A.dtype
    b_type = B.dtype
    assert a_type == b_type, "Mat A and Mat B should have the same dtype"
    c_dtype = get_higher_dtype(a_type, b_type)
    C = torch.empty((M, N), dtype=c_dtype, device=device)
    # 64-byte device buffers that sqmma_descriptor_pre_hook fills with the
    # A/B/C descriptors before each launch.
    desc_a = torch.empty((64,), dtype=torch.int8, device=device)
    desc_b = torch.empty((64,), dtype=torch.int8, device=device)
    desc_c = torch.empty((64,), dtype=torch.int8, device=device)
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
        1,
        1,
    )
    # Launch under the operands' device, as every other launch site in this
    # module does. Otherwise the kernel targets whatever device is current.
    with torch_device_fn.device(device):
        mm_sqmma_kernel[grid](
            A,
            B,
            C,
            desc_a,
            desc_b,
            desc_c,
            M,
            N,
            K,
            A.stride(0),
            A.stride(1),
            B.stride(0),
            B.stride(1),
            str(a_type).split(".")[-1],
            GROUP_M=GROUP_M,
            ab_dtype=get_triton_type(a_type),
            c_dtype=get_triton_type(c_dtype),
            is_transpose_a=is_transpose_a,
            is_transpose_b=is_transpose_b,
        )
    return C
def mm(a, b):
    """Matrix multiply ``a @ b`` on the MThreads backend.

    Dispatch order:
      * ``N == 1``: the GEMV kernel (``gemv_mm``);
      * inputs accepted by ``is_sqmma_compatible``: the SQMMA kernel;
      * otherwise: the generic Triton GEMM (``mm_fma``).

    For the duration of the call, ``MUSA_ENABLE_SQMMA`` is set to "1" when
    neither input is float32 and removed otherwise. Its previous value is
    restored afterwards.

    Raises:
        AssertionError: if ``a.shape[1] != b.shape[0]``.
    """
    a_dtype = a.dtype
    b_dtype = b.dtype
    M, K = a.shape
    _, N = b.shape
    # Validate before touching the environment. The GEMV and SQMMA paths do
    # not re-check this themselves, and GEMV would index ``b`` with ``a``'s
    # K, reading past the end of ``b`` if it has fewer rows.
    assert K == b.shape[0], "incompatible dimensions"
    # fp32 does not support MMA instructions, only enable SQMMA for fp16/bf16
    need_sqmma = a_dtype != torch.float32 and b_dtype != torch.float32
    prev_sqmma = os.environ.get("MUSA_ENABLE_SQMMA")
    if need_sqmma:
        os.environ["MUSA_ENABLE_SQMMA"] = "1"
    else:
        os.environ.pop("MUSA_ENABLE_SQMMA", None)
    try:
        if N == 1:
            c_dtype = get_higher_dtype(a_dtype, b_dtype)
            c = torch.empty((M, N), device=a.device, dtype=c_dtype)
            return gemv_mm(a, b, c, M, K)

        if is_sqmma_compatible(a, b, N, K):
            GROUP_M = 8
            return mm_sqmma(
                a,
                b,
                M,
                N,
                K,
                GROUP_M,
            )
        return mm_fma(a, b)
    finally:
        # Restore the caller's MUSA_ENABLE_SQMMA exactly as it was.
        if prev_sqmma is None:
            os.environ.pop("MUSA_ENABLE_SQMMA", None)
        else:
            os.environ["MUSA_ENABLE_SQMMA"] = prev_sqmma