Coverage for src/flag_gems/runtime/backend/_mthreads/ops/mm.py: 0%
196 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
7from triton.tools.tensor_descriptor import TensorDescriptor
9from flag_gems import runtime
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import libentry, libtuner
12from flag_gems.utils import triton_lang_extension as ext
# Logger for this backend module; the name is spelled out as a literal.
logger = logging.getLogger("flag_gems.runtime.backend._mthreads.ops.mm")

# Path of the tuning/expand YAML, located one directory above this file.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "..", "mm_mthreads_expand.yaml")
)
# Module-level capability flag: evaluated once at import time and then reused
# as a constant, so the version string is never re-parsed on the hot path.
# True when the installed Triton is >= 3.2, False otherwise (e.g. 3.1).
# NOTE(review): only the first two dot-separated components are parsed and
# both must be plain integers; a non-numeric component would raise ValueError
# at import.
SQMMA_ON = tuple(int(x) for x in triton.__version__.split(".")[:2]) >= (3, 2)
def is_supported_sqmma_layout(tensor):
    """Return True if a 2-D ``tensor`` has a layout the SQMMA path accepts.

    Two layouts qualify: a contiguous tensor, or one whose first dimension
    has unit stride and whose second stride equals ``shape[0]`` (the stride
    pattern of a transposed contiguous matrix).

    The tensor must have at least two dimensions; ``stride(1)`` is queried
    for non-contiguous inputs.
    """
    if tensor.is_contiguous():
        return True
    return tensor.stride(0) == 1 and tensor.stride(1) == tensor.shape[0]
def is_sqmma_compatible(a, b, N, K):
    """Decide whether ``a @ b`` may be dispatched to the SQMMA kernel.

    All of the following must hold:
      * the module-level ``SQMMA_ON`` flag is set;
      * both operands are 2-D and share the same dtype;
      * that dtype is float16 or bfloat16;
      * both operands pass ``is_supported_sqmma_layout``;
      * ``N`` and ``K`` are multiples of 8.

    The checks short-circuit in this order, so the layout test (which reads
    ``stride(1)``) only runs after both operands are known to be 2-D.
    """
    return (
        SQMMA_ON
        and a.dim() == 2
        and b.dim() == 2
        and a.dtype == b.dtype
        and a.dtype in (torch.float16, torch.bfloat16)
        and is_supported_sqmma_layout(a)
        and is_supported_sqmma_layout(b)
        and N % 8 == 0
        and K % 8 == 0
    )
@triton.jit
def prev_multiple_of(a, b):
    """Return the largest multiple of ``b`` that is strictly less than ``a``.

    Computed as ``ceil(a / b) * b - b``. For ``a == 0`` this evaluates to
    ``-b``, so callers that use the result as a loop bound or offset should
    make sure ``a > 0``.
    """
    return tl.cdiv(a, b) * b - b
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=5,
    flagtune_op_name="mm",
    flagtune_expand_op_name="mm",
    flagtune_yaml_path=EXPAND_CONFIG_FILENAME,
)
@triton.jit
def mm_kernel(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    dtype: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Tiled matrix multiply ``C[M, N] = A[M, K] @ B[K, N]``.

    Each program computes one ``BLOCK_M x BLOCK_N`` tile of ``C``. Programs
    are launched on a 1-D grid and remapped to (row tile, column tile) in
    groups of ``GROUP_M`` row tiles.

    Args:
        A, B, C: base pointers of the operands and the output.
        M, N, K: problem sizes.
        stride_*: element strides of A (m, k), B (k, n) and C (m, n).
        dtype: unused inside the kernel body.
        BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M: tiling parameters.
        IS_FP64: accumulate in float64 instead of float32.

    NOTE(review): for ``K == 0`` ``prev_multiple_of`` returns ``-BLOCK_K``
    and the tail offsets below are negative while still passing the
    ``rk < K`` mask; callers are expected to handle ``K == 0`` on the host.
    """
    # matrix multiplication
    pid = ext.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Row/column indices are wrapped modulo M/N so the loads below never need
    # an M/N mask; the out-of-range rows/columns are discarded by the store
    # mask. Offsets are widened to int64 before being scaled by strides.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
    rm = rm.to(tl.int64)
    rn = rn.to(tl.int64)
    prev_multiple = prev_multiple_of(K, BLOCK_K)

    if IS_FP64:
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
    else:
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Main loop: full BLOCK_K slabs, no K mask needed.
    for start_k in range(0, prev_multiple, BLOCK_K):
        rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
        a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
        b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        if IS_FP64:
            acc += tl.dot(a, b, allow_tf32=False)
        else:
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    # loop peeling: the last (possibly partial) slab along K
    rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
    mask_k = rk < K
    # Masked-out lanes must contribute exactly zero to the dot product, so an
    # explicit fill value is supplied instead of relying on undefined data.
    a = tl.load(
        A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
        mask=mask_k[None, :],
        other=0.0,
    )
    b = tl.load(
        B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
        mask=mask_k[:, None],
        other=0.0,
    )
    if a.dtype != b.dtype:
        a = a.to(C.dtype.element_ty)
        b = b.to(C.dtype.element_ty)
    if IS_FP64:
        acc += tl.dot(a, b, allow_tf32=False)
    else:
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
    rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # write the tile back, dropping rows/columns beyond the matrix edge
    tl.store(C, acc, mask=mask)
@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK_M": 64, "BLOCK_K": 64}),
        triton.Config({"BLOCK_M": 128, "BLOCK_K": 64}),
    ],
    key=["M", "K", "stride_am", "stride_bk"],
    strategy=["align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
    flagtune_op_name="mm",
    flagtune_expand_op_name="gemv",
    flagtune_yaml_path=EXPAND_CONFIG_FILENAME,
)
@triton.jit
def gemv_kernel(
    A,
    B,
    C,
    M,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_cm,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Matrix-vector product ``C[M] = A[M, K] @ B[K]`` (the N == 1 case).

    Each program reduces ``BLOCK_M`` rows of ``A`` against ``B`` in chunks of
    ``BLOCK_K`` columns and stores ``BLOCK_M`` results.

    Args:
        A, B, C: base pointers of the matrix, the vector and the output.
        M, K: number of rows of A and the reduction length.
        stride_am, stride_ak: element strides of A.
        stride_bk: element stride of B along K.
        stride_cm: element stride of C along M.
        BLOCK_M, BLOCK_K: tiling parameters.
    """
    pid = ext.program_id(0)

    row_start = pid * BLOCK_M
    row_offset = row_start + tl.arange(0, BLOCK_M)
    row_mask = row_offset < M
    # Widen before scaling by strides so large matrices cannot overflow
    # 32-bit offset arithmetic (mm_kernel does the same).
    row_offset = row_offset.to(tl.int64)

    # Accumulate in fp32, except for float64 outputs, which keep full
    # precision exactly like the IS_FP64 branch of mm_kernel.
    if C.dtype.element_ty == tl.float64:
        acc = tl.zeros((BLOCK_M,), dtype=tl.float64)
    else:
        acc = tl.zeros((BLOCK_M,), dtype=tl.float32)

    for k_start in range(0, K, BLOCK_K):
        k_offset = k_start + tl.arange(0, BLOCK_K)
        k_mask = k_offset < K
        k_offset = k_offset.to(tl.int64)

        a_ptrs = A + row_offset[:, None] * stride_am + k_offset[None, :] * stride_ak
        a = tl.load(a_ptrs, mask=row_mask[:, None] & k_mask[None, :], other=0.0)

        b_ptrs = B + k_offset * stride_bk
        b = tl.load(b_ptrs, mask=k_mask, other=0.0)

        # Do the reduction in the accumulator dtype so the N=1 GEMV matches
        # the mm path more closely.
        a = a.to(acc.dtype)
        b = b.to(acc.dtype)
        acc += tl.sum(a * b[None, :], axis=1)

    c_ptrs = C + row_offset * stride_cm
    acc = acc.to(C.dtype.element_ty)
    tl.store(c_ptrs, acc, mask=row_mask)
# Floating-point dtypes ordered from lowest to highest promotion rank.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32, torch.float64]


def get_higher_dtype(a, b):
    """Return whichever of dtypes ``a`` and ``b`` ranks higher.

    Rank is the position in ``_ordered_datatypes``. Identical dtypes are
    returned as-is without being validated against that list.

    Raises:
        AssertionError: if the dtypes differ and either one is not listed in
            ``_ordered_datatypes``.
    """
    if a is b:
        return a

    # Raised explicitly (rather than via ``assert``) so the check still runs
    # under ``python -O`` instead of silently returning None.
    for dtype in (a, b):
        if dtype not in _ordered_datatypes:
            raise AssertionError(f"unsupported dtype: {dtype}")

    return max(a, b, key=_ordered_datatypes.index)
def mm_fma(a, b):
    """Compute ``a @ b`` with the generic ``mm_kernel`` and return a new tensor.

    Args:
        a: 2-D tensor of shape (M, K).
        b: 2-D tensor of shape (K, N).

    Returns:
        A freshly allocated (M, N) tensor on ``a``'s device whose dtype is
        the higher-ranked of the two input dtypes.

    Raises:
        AssertionError: if the inner dimensions do not match.
    """
    logger.debug("GEMS_MTHREADS MM(FMA)")
    device = a.device
    # An operand is copied only when neither of its dimensions is unit-stride.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    c_dtype = get_higher_dtype(a.dtype, b.dtype)
    # Degenerate shapes never reach the kernel: with K == 0 its peeled tail
    # would use negative offsets that still pass the ``rk < K`` mask, and with
    # M == 0 or N == 0 the launch grid would be empty. The product of empty
    # operands is all zeros.
    if M == 0 or N == 0 or K == 0:
        return torch.zeros((M, N), device=device, dtype=c_dtype)
    # allocates output
    c = torch.empty((M, N), device=device, dtype=c_dtype)

    # launch kernel: one program per output tile
    def grid(META):
        return (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)

    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            dtype=str(a.dtype).split(".")[-1],
            GROUP_M=8,
            # Derived from the promoted output dtype so that mixed
            # fp32/fp64 inputs also accumulate in float64.
            IS_FP64=c_dtype == torch.float64,
        )
    return c
def gemv_mm(a, b, c, M, K):
    """Launch ``gemv_kernel`` to compute the N == 1 product into ``c``.

    Args:
        a: matrix operand; its two strides are passed to the kernel.
        b: vector operand; only its stride along dimension 0 is used.
        c: preallocated output; only its stride along dimension 0 is used.
        M: number of rows of ``a``.
        K: reduction length.

    Returns:
        ``c`` after the kernel has been enqueued.
    """
    logger.debug(
        "GEMS_MTHREADS MM(GEMV), [shape info]: [%s, %s, 1](M, K, N)",
        M,
        K,
    )

    # one program per block of BLOCK_M rows
    def grid(META):
        return (triton.cdiv(M, META["BLOCK_M"]),)

    with torch_device_fn.device(a.device):
        gemv_kernel[grid](
            a,
            b,
            c,
            M,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            c.stride(0),
        )
    return c
def mm_out(a, b, *, out):
    """Compute ``a @ b`` into the caller-provided tensor ``out``.

    Args:
        a: 2-D tensor of shape (M, K).
        b: 2-D tensor of shape (K, N).
        out: destination tensor, indexed as (M, N) through its own strides.

    Returns:
        ``out``.

    Raises:
        AssertionError: if the inner dimensions do not match.
    """
    logger.debug("GEMS_MTHREADS MM_OUT")
    # An operand is copied only when neither of its dimensions is unit-stride.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    c = out
    # Degenerate shapes never reach a kernel: an empty output has nothing to
    # compute, and with K == 0 the product is all zeros (mm_kernel's peeled
    # tail would otherwise use negative offsets that pass its K mask).
    if M == 0 or N == 0:
        return c
    if K == 0:
        return c.zero_()
    if N == 1:
        return gemv_mm(a, b, c, M, K)

    # launch kernel: one program per output tile
    def grid(META):
        return (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)

    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            dtype=str(a.dtype).split(".")[-1],
            GROUP_M=8,
            IS_FP64=a.dtype == torch.float64,
        )
    return c
def sqmma_descriptor_pre_hook(nargs):
    """Set each tensor descriptor's block shape from the chosen tile sizes.

    Used as a ``triton.Config`` pre-hook. ``nargs`` is the mapping of kernel
    argument names to values and must contain ``a_desc``, ``b_desc``,
    ``c_desc`` plus ``BLOCK_M``, ``BLOCK_N`` and ``BLOCK_K``. The descriptors
    are mutated in place; nothing is returned.
    """
    block_m = nargs["BLOCK_M"]
    block_n = nargs["BLOCK_N"]
    block_k = nargs["BLOCK_K"]
    nargs["a_desc"].block_shape = [block_m, block_k]
    nargs["b_desc"].block_shape = [block_k, block_n]
    nargs["c_desc"].block_shape = [block_m, block_n]
@libentry()
@libtuner(
    # With USE_FLAGTUNE=1 the configs and key strategy come from the expand
    # YAML; otherwise a single fixed configuration is used.
    configs=runtime.ops_get_configs(
        "mm_general_tma",
        pre_hook=sqmma_descriptor_pre_hook,
        yaml_path=EXPAND_CONFIG_FILENAME,
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else [
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_M": 8},
            num_stages=1,
            num_warps=4,
            pre_hook=sqmma_descriptor_pre_hook,
        )
    ],
    key=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    strategy=runtime.get_expand_config(
        "mm_general_tma", yaml_path=EXPAND_CONFIG_FILENAME
    )["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
)
@triton.jit
def mm_sqmma_kernel(
    a_desc,
    b_desc,
    c_desc,
    M,
    N,
    K,
    dtype: tl.constexpr,
    GROUP_M: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Descriptor-based matrix multiply ``C[M, N] = A[M, K] @ B[K, N]``.

    Each program computes one ``BLOCK_M x BLOCK_N`` output tile, loading
    operand tiles through tensor descriptors and accumulating in float32.
    Programs are remapped to tiles in groups of ``GROUP_M`` row tiles.

    Args:
        a_desc, b_desc, c_desc: tensor descriptors for A, B and C; their
            block shapes are set by ``sqmma_descriptor_pre_hook``.
        M, N, K: problem sizes.
        dtype: unused inside the kernel body.
        GROUP_M, BLOCK_M, BLOCK_N, BLOCK_K: tiling parameters.

    NOTE(review): no explicit bounds masks are applied here; partial edge
    tiles are presumably handled by the descriptor loads/stores - confirm.
    """
    pid = ext.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # grouped program-id remapping, same scheme as mm_kernel
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # element offsets of this program's tile, as int32 descriptor coordinates
    offs_am = (pid_m * BLOCK_M).to(tl.int32)
    offs_bn = (pid_n * BLOCK_N).to(tl.int32)
    offs_k = 0
    offs_k = offs_k.to(tl.int32)
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k])
        b = tl.load_tensor_descriptor(b_desc, [offs_k, offs_bn])
        accumulator = tl.dot(a, b, acc=accumulator)
        offs_k += BLOCK_K
    tl.store_tensor_descriptor(c_desc, [offs_am, offs_bn], accumulator.to(c_desc.dtype))
def mm_sqmma(A, B, M, N, K):
    """Compute ``A @ B`` with the descriptor-based SQMMA kernel.

    Args:
        A: 2-D tensor of shape (M, K); copied if not contiguous.
        B: 2-D tensor of shape (K, N); copied if not contiguous.
        M, N, K: problem sizes.

    Returns:
        A freshly allocated (M, N) tensor on ``A``'s device.

    Raises:
        AssertionError: if ``A`` and ``B`` have different dtypes.
    """
    logger.debug("GEMS_MTHREADS MM(SQMMA)")
    device = A.device
    if not A.is_contiguous():
        A = A.contiguous()
    if not B.is_contiguous():
        B = B.contiguous()
    a_type = A.dtype
    b_type = B.dtype
    assert a_type == b_type, "Mat A and Mat B should have the same dtype"
    c_dtype = get_higher_dtype(a_type, b_type)
    C = torch.empty((M, N), dtype=c_dtype, device=device)
    # The [1, 1] block shapes are placeholders; the config pre-hook overwrites
    # them with the selected tile sizes before each launch.
    desc_a = TensorDescriptor.from_tensor(A, [1, 1])
    desc_b = TensorDescriptor.from_tensor(B, [1, 1])
    desc_c = TensorDescriptor.from_tensor(C, [1, 1])

    # one program per output tile, on a 1-D grid
    def grid(META):
        return (
            triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
            1,
            1,
        )

    # Launch on the operands' device, as every other launcher in this module
    # does, rather than on whatever device happens to be current.
    with torch_device_fn.device(device):
        mm_sqmma_kernel[grid](
            desc_a,
            desc_b,
            desc_c,
            M,
            N,
            K,
            str(a_type).split(".")[-1],
        )
    return C
def mm(a, b):
    """Matrix-multiply two 2-D tensors, dispatching to the best kernel.

    Dispatch order:
      1. empty problems (any of M, N, K is zero) return zeros directly;
      2. ``N == 1`` uses the GEMV kernel;
      3. inputs accepted by ``is_sqmma_compatible`` use the SQMMA kernel;
      4. everything else uses the generic ``mm_fma`` path.

    Args:
        a: tensor of shape (M, K).
        b: tensor of shape (K, N).

    Returns:
        A new (M, N) tensor on ``a``'s device.

    Raises:
        AssertionError: if the inner dimensions do not match.
    """
    a_dtype = a.dtype
    b_dtype = b.dtype
    M, K = a.shape
    b_rows, N = b.shape
    # Validate once up front: only mm_fma checks this itself, so the GEMV and
    # SQMMA paths would otherwise run on mismatched operands.
    assert K == b_rows, "incompatible dimensions"

    # The product of empty operands is all zeros; skip kernel launches (note
    # that K == 0 is a multiple of 8 and would otherwise reach SQMMA).
    if M == 0 or N == 0 or K == 0:
        c_dtype = get_higher_dtype(a_dtype, b_dtype)
        return torch.zeros((M, N), device=a.device, dtype=c_dtype)

    if N == 1:
        c_dtype = get_higher_dtype(a_dtype, b_dtype)
        c = torch.empty((M, N), device=a.device, dtype=c_dtype)
        return gemv_mm(a, b, c, M, K)

    if is_sqmma_compatible(a, b, N, K):
        return mm_sqmma(a, b, M, N, K)
    return mm_fma(a, b)