Coverage for src/flag_gems/runtime/backend/_mthreads/ops/mm.py: 0%
206 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
7from triton.tools.tensor_descriptor import TensorDescriptor
9from flag_gems import runtime
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import libentry, libtuner
12from flag_gems.utils import triton_lang_extension as ext
logger = logging.getLogger("flag_gems.runtime.backend._mthreads.ops.mm")

EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "..", "mm_mthreads_expand.yaml")
)


def _triton_major_minor(version):
    """Return ``(major, minor)`` parsed from a dotted version string.

    Only the leading ASCII digits of each of the first two dotted components
    are used, so suffixes such as ``"3.2.0+git1a2b"`` or ``"3.2rc1"`` do not
    raise. A component with no leading digits counts as 0.
    """
    parts = []
    for component in version.split(".")[:2]:
        digits = ""
        for ch in component:
            if not ("0" <= ch <= "9"):
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


# Module-level capability flag: evaluated once at import time, then reused as
# a constant for the entire process lifetime with no repeated parsing overhead.
# False when Triton < 3.2 (e.g. 3.1), True when Triton >= 3.2. Parsing is
# tolerant of non-numeric version suffixes so module import cannot fail here.
SQMMA_ON = _triton_major_minor(triton.__version__) >= (3, 2)
def is_supported_sqmma_layout(tensor):
    """Return True if *tensor* is contiguous or laid out column-major.

    The column-major case is unit stride along dim 0 with a dim-1 stride equal
    to the number of rows, e.g. the transpose of a contiguous matrix.
    """
    if tensor.is_contiguous():
        return True
    is_column_major = tensor.stride(0) == 1 and tensor.stride(1) == tensor.shape[0]
    return is_column_major
def is_sqmma_compatible(a, b, N, K):
    """Decide whether ``a @ b`` may take the SQMMA tensor-descriptor path.

    Requires, in this order: ``SQMMA_ON`` (Triton >= 3.2); both operands 2-D;
    equal dtypes that are float16 or bfloat16; both operands accepted by
    ``is_supported_sqmma_layout``; and N and K both multiples of 8
    (presumably a descriptor stride-alignment requirement — unconfirmed here).
    """
    if not SQMMA_ON:
        return False
    if a.dim() != 2 or b.dim() != 2:
        return False
    if a.dtype != b.dtype or a.dtype not in (torch.float16, torch.bfloat16):
        return False
    if not (is_supported_sqmma_layout(a) and is_supported_sqmma_layout(b)):
        return False
    return N % 8 == 0 and K % 8 == 0
@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of b strictly below a, i.e. (ceil(a / b) - 1) * b
    # (assumes b > 0). For a == 0 this evaluates to -b.
    return (tl.cdiv(a, b) - 1) * b
@libentry()
@libtuner(
    configs=runtime.ops_get_configs("mm", yaml_path=EXPAND_CONFIG_FILENAME)
    if os.environ.get("USE_FLAGTUNE") == "1"
    else runtime.get_tuned_config("mm"),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=runtime.get_expand_config("mm", yaml_path=EXPAND_CONFIG_FILENAME)[
        "strategy"
    ]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def mm_kernel(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    dtype: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    # C[M, N] = A[M, K] @ B[K, N] using explicit strides. Each program computes
    # one BLOCK_M x BLOCK_N tile of C. Accumulation is float64 when IS_FP64 is
    # set and float32 otherwise; the result is cast to C's element type before
    # the store. `dtype` is not read in the body.
    pid = ext.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # Re-order program IDs in groups of GROUP_M tile-rows for better L2 reuse.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Wrap out-of-range rows/cols (% M, % N) so the loads need no M/N mask;
    # those rows/cols are dropped by the masked store at the end.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
    # Unmasked full BLOCK_K chunks cover [0, prev_multiple); the last chunk
    # is peeled below and masked against K.
    prev_multiple = prev_multiple_of(K, BLOCK_K)

    if IS_FP64:
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
    else:
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for start_k in range(0, prev_multiple, BLOCK_K):
        rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
        a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
        b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
        # Mixed operand dtypes are unified to C's element type before tl.dot.
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        if IS_FP64:
            acc += tl.dot(a, b, allow_tf32=False)
        else:
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    # Peeled tail chunk [prev_multiple, prev_multiple + BLOCK_K), masked by K.
    rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
    # For K == 0, prev_multiple_of returns -BLOCK_K, so every offset would
    # pass `rk < K` on its own; `rk >= 0` keeps negative offsets from loading.
    mask_k = (rk >= 0) & (rk < K)
    # other=0.0: masked-off lanes of tl.load are otherwise undefined and would
    # feed arbitrary values into tl.dot.
    a = tl.load(
        A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
        mask=mask_k[None, :],
        other=0.0,
    )
    b = tl.load(
        B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
        mask=mask_k[:, None],
        other=0.0,
    )
    if a.dtype != b.dtype:
        a = a.to(C.dtype.element_ty)
        b = b.to(C.dtype.element_ty)
    if IS_FP64:
        acc += tl.dot(a, b, allow_tf32=False)
    else:
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    acc = acc.to(C.dtype.element_ty)
    # Rematerialize rm and rn (as int64) to save registers.
    rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
    rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # Masked store: only in-range rows/columns of C are written.
    tl.store(C, acc, mask=mask)
@libentry()
@libtuner(
    configs=runtime.ops_get_configs("gemv", yaml_path=EXPAND_CONFIG_FILENAME)
    if os.environ.get("USE_FLAGTUNE") == "1"
    else [
        triton.Config({"BLOCK_M": 64, "BLOCK_K": 64}),
        triton.Config({"BLOCK_M": 128, "BLOCK_K": 64}),
    ],
    key=["M", "K", "stride_am", "stride_bk"],
    strategy=runtime.get_expand_config("gemv", yaml_path=EXPAND_CONFIG_FILENAME)[
        "strategy"
    ]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
)
@triton.jit
def gemv_kernel(
    A,
    B,
    C,
    M,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_cm,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # C[:, 0] = A @ B[:, 0]: each program reduces BLOCK_M rows of A against
    # the first column of B, BLOCK_K elements of K per iteration.
    pid = ext.program_id(0)

    row_start = pid * BLOCK_M
    row_offset = row_start + tl.arange(0, BLOCK_M)
    row_mask = row_offset < M

    # Accumulate in float64 when C is float64 (like mm_kernel's IS_FP64 path)
    # so float64 N == 1 products keep double precision; otherwise keep the
    # reduction in float32 so N == 1 GEMV matches the mm path more closely.
    if C.dtype.element_ty == tl.float64:
        acc = tl.zeros((BLOCK_M,), dtype=tl.float64)
    else:
        acc = tl.zeros((BLOCK_M,), dtype=tl.float32)

    for k_start in range(0, K, BLOCK_K):
        k_offset = k_start + tl.arange(0, BLOCK_K)
        k_mask = k_offset < K

        a_ptrs = A + row_offset[:, None] * stride_am + k_offset[None, :] * stride_ak
        a = tl.load(a_ptrs, mask=row_mask[:, None] & k_mask[None, :], other=0.0)

        b_ptrs = B + k_offset * stride_bk
        b = tl.load(b_ptrs, mask=k_mask, other=0.0)

        a = a.to(acc.dtype)
        b = b.to(acc.dtype)
        acc += tl.sum(a * b[None, :], axis=1)

    c_ptrs = C + row_offset * stride_cm
    acc = acc.to(C.dtype.element_ty)
    tl.store(c_ptrs, acc, mask=row_mask)
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32, torch.float64]


def get_higher_dtype(a, b):
    """Return whichever of dtypes *a* and *b* ranks higher.

    The ranking follows ``_ordered_datatypes``:
    float16 < bfloat16 < float32 < float64. Identical inputs are returned as-is,
    even if unranked. Otherwise both must be ranked, which is checked with
    ``assert``.
    """
    if a is b:
        return a
    assert a in _ordered_datatypes
    assert b in _ordered_datatypes
    return max((a, b), key=_ordered_datatypes.index)
def mm_fma(a, b):
    """Compute ``a @ b`` with the general strided Triton ``mm_kernel``.

    Args:
        a: 2-D tensor of shape (M, K).
        b: 2-D tensor of shape (K, N).

    Returns:
        A new (M, N) tensor on ``a.device``. Its dtype is the higher of
        ``a.dtype`` and ``b.dtype`` per ``get_higher_dtype``.

    Raises:
        AssertionError: if ``a.shape[1] != b.shape[0]``.
    """
    logger.debug("GEMS_MTHREADS MM(FMA)")
    device = a.device
    # Make a contiguous copy only when both strides exceed 1; other layouts
    # are passed to the kernel with their strides as-is.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    c_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=device, dtype=c_dtype)
    # mm_kernel casts both operands to C's element type when their dtypes
    # differ, so tl.dot sees a.dtype only when a and b agree. The fp64 dot
    # path must be chosen from that dtype, not from a.dtype alone.
    dot_dtype = a.dtype if a.dtype == b.dtype else c_dtype
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )
    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            dtype=str(a.dtype).split(".")[-1],
            GROUP_M=8,
            IS_FP64=dot_dtype == torch.float64,
        )
    return c
def gemv_mm(a, b, c, M, K):
    """Write ``a @ b`` into *c* for the N == 1 case using ``gemv_kernel``.

    Args:
        a: (M, K) matrix.
        b: (K, 1) matrix; the kernel reads its first column via ``b.stride(0)``.
        c: (M, 1) output, written in place via ``c.stride(0)``.
        M: number of rows of ``a``.
        K: reduction length.

    Returns:
        *c*.
    """
    logger.debug(
        "GEMS_MTHREADS MM(GEMV), [shape info]: [%s, %s, 1](M, K, N)",
        M,
        K,
    )
    stride_am, stride_ak = a.stride(0), a.stride(1)
    stride_bk = b.stride(0)
    stride_cm = c.stride(0)

    def grid(meta):
        # One program per BLOCK_M rows of the output.
        return (triton.cdiv(M, meta["BLOCK_M"]),)

    with torch_device_fn.device(a.device):
        gemv_kernel[grid](a, b, c, M, K, stride_am, stride_ak, stride_bk, stride_cm)
    return c
def mm_out(a, b, *, out):
    """Compute ``a @ b`` into the preallocated tensor *out* and return it.

    N == 1 is routed to ``gemv_mm``; everything else goes to ``mm_kernel``.

    Args:
        a: 2-D tensor of shape (M, K).
        b: 2-D tensor of shape (K, N).
        out: destination tensor. NOTE(review): assumed to be (M, N) on the
            same device; its shape is not validated here.

    Returns:
        *out*.

    Raises:
        AssertionError: if ``a.shape[1] != b.shape[0]``.
    """
    logger.debug("GEMS_MTHREADS MM_OUT")
    # Make a contiguous copy only when both strides exceed 1.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    c = out
    if N == 1:
        return gemv_mm(a, b, c, M, K)
    # mm_kernel casts both operands to C's element type when their dtypes
    # differ, so the dtype reaching tl.dot is a.dtype only when a and b agree.
    dot_dtype = a.dtype if a.dtype == b.dtype else c.dtype
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )
    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            dtype=str(a.dtype).split(".")[-1],
            GROUP_M=8,
            IS_FP64=dot_dtype == torch.float64,
        )
    return c
def matmul_sqmma_set_block_size_hook(nargs):
    """Set the A/B/C descriptor block shapes from the chosen tile sizes.

    Reads BLOCK_M, BLOCK_N and BLOCK_K from *nargs*, then assigns
    ``a_desc`` -> [BLOCK_M, BLOCK_K], ``b_desc`` -> [BLOCK_K, BLOCK_N] and
    ``c_desc`` -> [BLOCK_M, BLOCK_N]. Mutates the descriptors in place and
    returns None.
    """
    bm, bn, bk = (nargs[name] for name in ("BLOCK_M", "BLOCK_N", "BLOCK_K"))
    tile_shapes = {
        "a_desc": [bm, bk],
        "b_desc": [bk, bn],
        "c_desc": [bm, bn],
    }
    for desc_name, shape in tile_shapes.items():
        nargs[desc_name].block_shape = shape
def sqmma_get_configs(pre_hook=matmul_sqmma_set_block_size_hook):
    """Return the SQMMA autotuning candidates.

    Both candidates use 128x128 output tiles, one stage and four warps. They
    differ only in BLOCK_K: 128 first, then 64. *pre_hook* is attached to every
    candidate; it defaults to ``matmul_sqmma_set_block_size_hook``.
    """
    return [
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": block_k},
            num_stages=1,
            num_warps=4,
            pre_hook=pre_hook,
        )
        for block_k in (128, 64)
    ]
@libentry()
@libtuner(
    configs=sqmma_get_configs(),
    key=["M", "N", "K", "dtype"],
    strategy=["align32", "align32", "align32", "default"],
)
@triton.jit
def mm_sqmma_kernel(
    a_desc,
    b_desc,
    c_desc,
    M,
    N,
    K,
    dtype: tl.constexpr,
    GROUP_M: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Tensor-descriptor matmul: each program computes one BLOCK_M x BLOCK_N
    # tile of C = A @ B. A/B tiles are loaded through a_desc/b_desc and the
    # tile is stored through c_desc. The descriptors' block_shape is set by
    # the config pre_hook (sqmma_get_configs() defaults to
    # matmul_sqmma_set_block_size_hook) from the selected BLOCK_M/N/K.
    # `dtype` is not read in the body; it only appears in the autotune key.
    # NOTE(review): there is no explicit masking of the last K tile or of M/N
    # edges; presumably the descriptor loads/stores handle out-of-range
    # elements. Confirm against the backend's tensor-descriptor semantics.
    pid = ext.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # Re-order program IDs in groups of GROUP_M tile-rows (same scheme as
    # mm_kernel).
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # Tile origin offsets handed to the descriptors, as int32.
    offs_am = (pid_m * BLOCK_M).to(tl.int32)
    offs_bn = (pid_n * BLOCK_N).to(tl.int32)
    # Running K offset; presumably pinned to int32 so the loop-carried value
    # keeps one type across iterations.
    offs_k = 0
    offs_k = offs_k.to(tl.int32)
    # Accumulate in float32; cast to the output descriptor's dtype on store.
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k])
        b = tl.load_tensor_descriptor(b_desc, [offs_k, offs_bn])
        accumulator = tl.dot(a, b, acc=accumulator)
        offs_k += BLOCK_K
    tl.store_tensor_descriptor(c_desc, [offs_am, offs_bn], accumulator.to(c_desc.dtype))
def get_triton_type(elem_type):
    """Return the Triton dtype matching torch dtype *elem_type*, else None.

    Only torch.float16, torch.bfloat16 and torch.float8_e4m3fn are mapped.
    """
    torch_to_triton = (
        (torch.float16, tl.float16),
        (torch.bfloat16, tl.bfloat16),
        (torch.float8_e4m3fn, tl.float8e4nv),
    )
    return dict(torch_to_triton).get(elem_type)
def mm_sqmma(A, B, M, N, K, GROUP_M):
    """Compute ``A @ B`` with the tensor-descriptor kernel ``mm_sqmma_kernel``.

    Args:
        A: (M, K) matrix; copied to a contiguous layout if it is not one.
        B: (K, N) matrix; copied to a contiguous layout if it is not one.
        M, N, K: problem sizes.
        GROUP_M: program-id grouping factor forwarded to the kernel.

    Returns:
        A new (M, N) tensor on ``A.device`` with A's dtype.

    Raises:
        AssertionError: if A and B have different dtypes.
    """
    logger.debug("GEMS_MTHREADS MM(SQMMA)")
    device = A.device
    if not A.is_contiguous():
        A = A.contiguous()
    if not B.is_contiguous():
        B = B.contiguous()
    a_type = A.dtype
    b_type = B.dtype
    assert a_type == b_type, "Mat A and Mat B should have the same dtype"
    c_dtype = get_higher_dtype(a_type, b_type)
    C = torch.empty((M, N), dtype=c_dtype, device=device)
    # Real block_shape values are filled in by matmul_sqmma_set_block_size_hook
    # at autotune/launch time based on the BLOCK_M/N/K selected by libtuner.
    dummy_block = [1, 1]
    desc_a = TensorDescriptor(A, A.shape, A.stride(), dummy_block)
    desc_b = TensorDescriptor(B, B.shape, B.stride(), dummy_block)
    desc_c = TensorDescriptor(C, C.shape, C.stride(), dummy_block)
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
        1,
        1,
    )
    # Launch under the operands' device guard, as every other launcher in
    # this module does.
    with torch_device_fn.device(A.device):
        mm_sqmma_kernel[grid](
            desc_a,
            desc_b,
            desc_c,
            M,
            N,
            K,
            dtype=str(a_type).split(".")[-1],
            GROUP_M=GROUP_M,
        )
    return C
def mm(a, b):
    """Compute the matrix product ``a @ b`` of two 2-D tensors.

    Dispatch:
      * N == 1: ``gemv_mm`` into a new tensor of the higher input dtype;
      * operands accepted by ``is_sqmma_compatible``: ``mm_sqmma``;
      * otherwise: ``mm_fma``.

    Raises:
        AssertionError: if ``a.shape[1] != b.shape[0]``.
        ValueError: if ``a`` or ``b`` is not 2-D (from shape unpacking).
    """
    M, K = a.shape
    K_b, N = b.shape
    # Validate before dispatch. gemv_kernel reads b using a's K (masked only
    # against K), so mismatched shapes would read past b. mm_fma/mm_out
    # assert the same condition.
    assert K == K_b, "incompatible dimensions"
    if N == 1:
        c_dtype = get_higher_dtype(a.dtype, b.dtype)
        c = torch.empty((M, N), device=a.device, dtype=c_dtype)
        return gemv_mm(a, b, c, M, K)

    if is_sqmma_compatible(a, b, N, K):
        GROUP_M = 8
        return mm_sqmma(a, b, M, N, K, GROUP_M)
    return mm_fma(a, b)