Coverage for src/flag_gems/runtime/backend/_mthreads/ops/bmm.py: 0%
135 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
7from triton.tools.tensor_descriptor import TensorDescriptor
9from flag_gems import runtime
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import libentry, libtuner
12from flag_gems.utils import triton_lang_extension as ext
# Module logger, namespaced under the Moore Threads ("_mthreads") backend ops
# package and suffixed with this module's short name.
logger = logging.getLogger(
    "flag_gems.runtime.backend._mthreads.ops." + __name__.rsplit(".", 1)[-1]
)

# Expanded tuning-config YAML for the SQMMA kernel, located in the parent
# directory of this file.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "..", "bmm_mthreads_expand.yaml")
)
def is_supported_sqmma_layout(tensor):
    """Return True if ``tensor``'s memory layout is accepted by the SQMMA path.

    Two layouts qualify:
      * fully contiguous (row-major), or
      * dim 0 has unit stride and dim 1's stride equals the size of dim 0,
        i.e. the strides of a transposed contiguous 2-D matrix.

    NOTE(review): ``bmm`` applies this check to 3-D (batch, rows, cols)
    operands, where the second layout does not describe a batched transpose;
    such inputs end up being copied by ``reshape`` in ``bmm_sqmma`` anyway.
    """
    if tensor.is_contiguous():
        return True
    unit_leading_stride = tensor.stride(0) == 1
    return unit_leading_stride and tensor.stride(1) == tensor.shape[0]
def is_sqmma_compatible(a, b, N, K):
    """Return True if ``a @ b`` may use the tensor-descriptor (SQMMA) kernel.

    Requirements:
      * ``a`` and ``b`` share a dtype, and it is float16 or bfloat16;
      * both operands pass ``is_supported_sqmma_layout``;
      * N and K are multiples of 8 (presumably so descriptor row strides stay
        16-byte aligned for 2-byte dtypes -- confirm).
    """
    same_half_dtype = a.dtype == b.dtype and a.dtype in (
        torch.float16,
        torch.bfloat16,
    )
    if not same_half_dtype:
        return False
    if not is_supported_sqmma_layout(a) or not is_supported_sqmma_layout(b):
        return False
    return N % 8 == 0 and K % 8 == 0
# Generic batched GEMM kernel (the "FMA" path): O[b] = A[b] @ B[b], with A, B
# and O addressed as dense row-major (M, K), (K, N) and (M, N) matrices per
# batch entry.
#
# Launch grid: axis 0 tiles M, axis 1 tiles N, axis 2 indexes the batch.
# TILE_* come from the "bmm" tuned configs; GROUP_M and DIVISIBLE_* are
# presumably supplied by the "bmm" heuristic config (the caller does not pass
# them). DIVISIBLE_X means X is a multiple of TILE_X, so no masking is needed
# along that dimension. IS_FP64 selects a float64 accumulator.
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("bmm"),
    key=["M", "N", "K"],
    strategy=["align32", "align32", "align32"],
)
@triton.heuristics(runtime.get_heuristic_config("bmm"))
@triton.jit
def bmm_kernel(
    A,
    B,
    O,
    M,
    N,
    K,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    DIVISIBLE_M: tl.constexpr,
    DIVISIBLE_N: tl.constexpr,
    DIVISIBLE_K: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    # batch offsets: move every base pointer to this program's batch entry
    pid_b = ext.program_id(2)
    A += pid_b * M * K
    B += pid_b * K * N
    O += pid_b * M * N

    pidx = ext.program_id(0)
    pidy = ext.program_id(1)

    if GROUP_M == 1:
        pid_m, pid_n = pidx, pidy
    else:
        # reorder CTAs: linearize the (x, y) program id, then remap it so that
        # consecutive programs sweep a band of up to GROUP_M M-tiles across
        # all N-tiles (presumably for better cache reuse of the B tiles).
        gridx = ext.num_programs(0)
        gridy = ext.num_programs(1)
        pid = pidx + pidy * gridx

        num_CTA_per_group = gridy * GROUP_M

        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        # The last band holds only gridx % GROUP_M M-tiles when it is partial.
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N

    a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]
    o_ptrs = O + offs_m[:, None] * N + offs_n[None, :]

    num_iters = tl.cdiv(K, TILE_K)
    if IS_FP64:
        o = tl.zeros((TILE_M, TILE_N), dtype=tl.float64)
    else:
        o = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for _ in range(num_iters):
        if DIVISIBLE_K:
            # Every K tile is full: only the M / N edges may need masking.
            # Masked-out lanes only reach output rows/columns that are never
            # stored, so no fill value is needed (and `other` cannot be given
            # when the mask is None).
            if DIVISIBLE_M:
                mask_a = None
            else:
                mask_a = mask_m[:, None]
            if DIVISIBLE_N:
                mask_b = None
            else:
                mask_b = mask_n[None, :]
            a = tl.load(a_ptrs, mask_a)
            b = tl.load(b_ptrs, mask_b)
        else:
            mask_k = offs_k < K
            if DIVISIBLE_M:
                mask_a = mask_k[None, :]
            else:
                mask_a = mask_m[:, None] & mask_k[None, :]
            if DIVISIBLE_N:
                mask_b = mask_k[:, None]
            else:
                mask_b = mask_k[:, None] & mask_n[None, :]
            # Lanes past K feed the reduction of every output element, so they
            # must be exactly zero in BOTH operands (a masked load without
            # `other` leaves them undefined, and 0 * NaN would still poison
            # the dot product).
            a = tl.load(a_ptrs, mask_a, other=0.0)
            b = tl.load(b_ptrs, mask_b, other=0.0)

        offs_k += TILE_K
        a_ptrs += TILE_K
        b_ptrs += TILE_K * N

        o += tl.dot(a, b, allow_tf32=False)

    if DIVISIBLE_M and DIVISIBLE_N:
        mask_c = None
    elif DIVISIBLE_M and not DIVISIBLE_N:
        mask_c = mask_n[None, :]
    elif not DIVISIBLE_M and DIVISIBLE_N:
        mask_c = mask_m[:, None]
    else:
        mask_c = mask_m[:, None] & mask_n[None, :]
    # The accumulator is cast to O's element type by the store.
    tl.store(o_ptrs, o, mask_c)
def bmm_fma(A, B):
    """Batched GEMM through the generic Triton ``bmm_kernel`` (the "FMA" path).

    Both operands are made contiguous first, because ``bmm_kernel`` addresses
    them as dense row-major per-batch matrices.

    Args:
        A: tensor of shape (batch, M, K).
        B: tensor of shape (batch, K, N).

    Returns:
        New tensor of shape (batch, M, N) with ``A``'s dtype and device.
    """
    logger.debug("GEMS_MTHREADS BMM(FMA)")
    batch, M, K = A.shape
    _, _, N = B.shape
    lhs = A.contiguous()
    rhs = B.contiguous()
    out = torch.empty((batch, M, N), dtype=lhs.dtype, device=lhs.device)

    def grid(meta):
        # One program per (TILE_M, TILE_N) output tile, for every batch entry.
        return (
            triton.cdiv(meta["M"], meta["TILE_M"]),
            triton.cdiv(meta["N"], meta["TILE_N"]),
            batch,
        )

    with torch_device_fn.device(lhs.device):
        bmm_kernel[grid](lhs, rhs, out, M, N, K, IS_FP64=lhs.dtype == torch.float64)
    return out
def bmm_sqmma_descriptor_pre_hook(nargs):
    """Resize the SQMMA tensor descriptors to the selected config's tiles.

    Installed as the ``pre_hook`` of every ``bmm_sqmma_kernel`` config. Before
    a launch, each descriptor's ``block_shape`` is set from the config's
    ``BLOCK_SIZE_*`` values: A tiles are M x K, B tiles are K x N and C tiles
    are M x N.

    Args:
        nargs: mapping of kernel argument names to values. The descriptor
            objects it holds are updated in place; nothing is returned.
    """
    tile_dims = (
        ("a_desc", "BLOCK_SIZE_M", "BLOCK_SIZE_K"),
        ("b_desc", "BLOCK_SIZE_K", "BLOCK_SIZE_N"),
        ("c_desc", "BLOCK_SIZE_M", "BLOCK_SIZE_N"),
    )
    for desc_name, rows_key, cols_key in tile_dims:
        nargs[desc_name].block_shape = [nargs[rows_key], nargs[cols_key]]
# Batched GEMM through tensor descriptors (the "SQMMA" path).
#
# a_desc / b_desc / c_desc describe 2-D views of shape (batch * M, K),
# (batch * K, N) and (batch * M, N); bmm_sqmma builds them that way. The batch
# entry is selected by offsetting descriptor rows by batch_index * M (A, C)
# or batch_index * K (B). Block shapes are set per config by
# bmm_sqmma_descriptor_pre_hook.
#
# Configs: with USE_FLAGTUNE=1 they (and the tuning strategy) are read from
# EXPAND_CONFIG_FILENAME; otherwise a single 128x128x64 config is used.
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "bmm_sqmma",
        pre_hook=bmm_sqmma_descriptor_pre_hook,
        yaml_path=EXPAND_CONFIG_FILENAME,
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else [
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
            num_stages=1,
            num_warps=4,
            pre_hook=bmm_sqmma_descriptor_pre_hook,
        )
    ],
    key=["M", "N", "K"],
    strategy=runtime.get_expand_config("bmm_sqmma", yaml_path=EXPAND_CONFIG_FILENAME)[
        "strategy"
    ][:3]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def bmm_sqmma_kernel(
    a_desc,
    b_desc,
    c_desc,
    batch,  # not read inside the kernel body
    M,
    N,
    K,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    # Axis 0 enumerates (M tile, N tile) pairs with the M tile varying fastest;
    # axis 1 is the batch entry.
    pid = tl.program_id(axis=0)
    batch_index = tl.program_id(axis=1)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m
    # Row offset into the flattened (batch * M, ...) views of A and C.
    # NOTE(review): when M is not a multiple of BLOCK_SIZE_M and this is not
    # the last batch entry, the last M tile covers rows that belong to batch
    # entry batch_index + 1 (they are in bounds of the flattened view), so it
    # reads that entry's A rows and stores over its C rows. The caller must
    # avoid that case.
    offs_am = (pid_m * BLOCK_SIZE_M + batch_index * M).to(tl.int32)
    offs_bn = (pid_n * BLOCK_SIZE_N).to(tl.int32)
    offs_ak = 0
    # Presumably cast so the loop-carried offset has a fixed int32 type.
    offs_ak = offs_ak.to(tl.int32)
    # Row offset into the flattened (batch * K, N) view of B.
    offs_bk = (batch_index * K).to(tl.int32)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # A and B offsets advance along K in lockstep.
    # NOTE(review): when K is not a multiple of BLOCK_SIZE_K, the last B tile
    # reads rows of the next batch entry; this relies on A's columns past K
    # being zero-filled by the descriptor load -- confirm on MUSA.
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_ak])
        b = tl.load_tensor_descriptor(b_desc, [offs_bk, offs_bn])
        accumulator = tl.dot(a, b, acc=accumulator)
        offs_ak += BLOCK_SIZE_K
        offs_bk += BLOCK_SIZE_K
    # fp32 accumulator is cast to the output descriptor's dtype on store.
    tl.store_tensor_descriptor(c_desc, [offs_am, offs_bn], accumulator.to(c_desc.dtype))
def bmm_sqmma(A, B, elem_type, batch, M, N, K):
    """Run a batched GEMM through the tensor-descriptor (SQMMA) kernel.

    A (batch, M, K), B (batch, K, N) and the output (batch, M, N) are viewed
    as 2-D matrices of shape (batch * M, K), (batch * K, N) and
    (batch * M, N). ``reshape`` copies an input whose strides cannot express
    that view. The kernel selects each batch entry by row offset.

    Args:
        A: left operand, shape (batch, M, K).
        B: right operand, shape (batch, K, N).
        elem_type: dtype of the inputs.
        batch, M, N, K: problem sizes.

    Returns:
        Tensor of shape (batch, M, N) with dtype ``elem_type`` on ``A``'s device.
    """
    # bf16 results are stored through an fp16 output descriptor, as before.
    # They are converted back afterwards so the returned dtype matches the
    # inputs (as torch.bmm and the FMA path do). Values outside the fp16 range
    # still saturate in the intermediate.
    c_type = elem_type if elem_type != torch.bfloat16 else torch.float16
    # Allocate on the inputs' device: a bare "musa" string would resolve to the
    # *current* device, which can differ from A's on multi-device hosts.
    C = torch.empty((batch, M, N), dtype=c_type, device=A.device)
    # The [1, 1] block shapes are placeholders; the kernel's pre_hook replaces
    # them with the selected config's BLOCK_SIZE_* before each launch.
    desc_a = TensorDescriptor.from_tensor(A.reshape(batch * M, K), [1, 1])
    desc_b = TensorDescriptor.from_tensor(B.reshape(batch * K, N), [1, 1])
    desc_c = TensorDescriptor.from_tensor(C.reshape(batch * M, N), [1, 1])

    def grid(meta):
        # Axis 0: every (M tile, N tile) pair; axis 1: batch entry.
        return (
            triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
            batch,
            1,
        )

    with torch_device_fn.device(A.device):
        bmm_sqmma_kernel[grid](
            desc_a,
            desc_b,
            desc_c,
            batch,
            M,
            N,
            K,
        )
    if c_type != elem_type:
        C = C.to(elem_type)
    return C
def bmm(a, b):
    """Batched matrix multiply for the Moore Threads (MUSA) backend.

    Args:
        a: tensor of shape (batch, M, K).
        b: tensor of shape (batch, K, N).

    Returns:
        Tensor of shape (batch, M, N).

    The tensor-descriptor (SQMMA) kernel is used when ``is_sqmma_compatible``
    accepts the operands, M >= 128, and no M tile can straddle two batch
    entries. Every other case goes through the generic FMA kernel.
    """
    a_dtype = a.dtype
    batch, M, K = a.shape
    _, _, N = b.shape
    # BLOCK_SIZE_M of the default SQMMA config.
    sqmma_block_m = 128
    # bmm_sqmma flattens `a` and the output to (batch * M, ...) 2-D views.
    # A partial last M tile of batch entry i would therefore read A rows of
    # entry i + 1 and overwrite that entry's output rows. With at most one
    # batch entry, the tail rows fall outside the view instead (presumably
    # clipped by the descriptor bounds).
    # NOTE(review): tuned configs (USE_FLAGTUNE=1) with BLOCK_SIZE_M > 128
    # would need a stricter divisibility check here.
    m_tiles_stay_in_batch = batch <= 1 or M % sqmma_block_m == 0
    if (
        is_sqmma_compatible(a, b, N, K)
        and M >= sqmma_block_m
        and m_tiles_stay_in_batch
    ):
        return bmm_sqmma(a, b, a_dtype, batch, M, N, K)
    return bmm_fma(a, b)