Coverage for src/flag_gems/runtime/backend/_mthreads/ops/bmm.py: 0%
159 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry, libtuner
11from flag_gems.utils import triton_lang_extension as tle
13from .utils import create_tma_device_descriptor, get_cached_tma_device_descriptor
# Module logger: fixed backend prefix plus the last component of this
# module's dotted name.
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.rpartition(".")[2]}'
)

# Normalised path of the tuning YAML that sits one directory above this file.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(os.path.dirname(__file__), os.pardir, "bmm_mthreads_expand.yaml")
)
def is_supported_sqmma_layout(tensor):
    """Return True when ``tensor`` has a memory layout the SQMMA path accepts.

    A tensor qualifies if it is contiguous, or if its first dimension has
    unit stride while the stride of its second dimension equals the size of
    the first dimension.
    """
    if tensor.is_contiguous():
        return True
    leading_is_unit_stride = tensor.stride(0) == 1
    return leading_is_unit_stride and tensor.stride(1) == tensor.shape[0]
def is_sqmma_compatible(a, b, N, K):
    """Decide whether the operand pair may be routed to the SQMMA kernel.

    All of the following must hold: both operands share one dtype, that
    dtype is float16 or bfloat16, both operands pass
    ``is_supported_sqmma_layout``, and ``N`` and ``K`` are multiples of 8.
    """
    if a.dtype != b.dtype:
        return False
    if a.dtype not in (torch.float16, torch.bfloat16):
        return False
    # Layout checks run only once the dtype checks have passed.
    if not (is_supported_sqmma_layout(a) and is_supported_sqmma_layout(b)):
        return False
    return N % 8 == 0 and K % 8 == 0
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("bmm"),
    key=["M", "N", "K"],
    strategy=["align32", "align32", "align32"],
)
@triton.heuristics(runtime.get_heuristic_config("bmm"))
@triton.jit
def bmm_kernel(
    A,
    B,
    O,
    M,
    N,
    K,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    DIVISIBLE_M: tl.constexpr,
    DIVISIBLE_N: tl.constexpr,
    DIVISIBLE_K: tl.constexpr,
):
    """Compute one (TILE_M, TILE_N) output tile of ``O[b] = A[b] @ B[b]``.

    Grid axes 0 and 1 enumerate output tiles; grid axis 2 is the batch index.
    The pointer arithmetic addresses A, B and O as densely packed row-major
    (batch, M, K), (batch, K, N) and (batch, M, N) buffers.  ``GROUP_M``
    selects how program ids are remapped onto tiles, and the constexpr
    ``DIVISIBLE_*`` flags decide whether bounds masks are built at all.
    """
    # Move the base pointers to the matrices of this program's batch entry.
    batch_id = tle.program_id(2)
    A += batch_id * M * K
    B += batch_id * K * N
    O += batch_id * M * N

    pid_x = tle.program_id(0)
    pid_y = tle.program_id(1)

    if GROUP_M == 1:
        pid_m = pid_x
        pid_n = pid_y
    else:
        # Remap the linearised program id so that consecutive ids walk
        # through groups of up to GROUP_M tile rows.
        grid_x = tle.num_programs(0)
        grid_y = tle.num_programs(1)
        linear_pid = pid_x + pid_y * grid_x

        ctas_per_group = grid_y * GROUP_M

        group_id = linear_pid // ctas_per_group
        pid_in_group = linear_pid % ctas_per_group
        # The trailing group holds only the leftover rows.
        group_size = tl.where(
            (group_id * GROUP_M + GROUP_M) > grid_x, grid_x % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + pid_in_group % group_size
        pid_n = pid_in_group // group_size

    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    # Row/column bounds masks exist only when the tile size does not divide
    # the corresponding dimension.
    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N

    a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]
    o_ptrs = O + offs_m[:, None] * N + offs_n[None, :]

    k_steps = tl.cdiv(K, TILE_K)
    acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for _ in range(k_steps):
        if DIVISIBLE_K:
            if DIVISIBLE_M:
                mask_a = None
            else:
                mask_a = mask_m[:, None]
            if DIVISIBLE_N:
                mask_b = None
            else:
                mask_b = mask_n[None, :]
        else:
            # The K range must be re-checked on every step.
            mask_k = offs_k < K
            if DIVISIBLE_M:
                mask_a = mask_k[None, :]
            else:
                mask_a = mask_m[:, None] & mask_k[None, :]
            if DIVISIBLE_N:
                mask_b = mask_k[:, None]
            else:
                mask_b = mask_k[:, None] & mask_n[None, :]

        # NOTE(review): no `other` value is passed to these masked loads, so
        # the lanes masked off along K still feed tl.dot -- confirm the
        # backend yields zeros for them.
        a_tile = tl.load(a_ptrs, mask_a)
        b_tile = tl.load(b_ptrs, mask_b)

        # Step to the next K slice.
        offs_k += TILE_K
        a_ptrs += TILE_K
        b_ptrs += TILE_K * N

        acc += tl.dot(a_tile, b_tile, allow_tf32=False)

    # Pick the store mask from the same divisibility flags.
    if DIVISIBLE_M:
        if DIVISIBLE_N:
            mask_c = None
        else:
            mask_c = mask_n[None, :]
    else:
        if DIVISIBLE_N:
            mask_c = mask_m[:, None]
        else:
            mask_c = mask_m[:, None] & mask_n[None, :]
    tl.store(o_ptrs, acc, mask_c)
def bmm_fma(A, B):
    """Batched matmul through the generic ``bmm_kernel``.

    Both operands are made contiguous first; the result is a new
    (batch, M, N) tensor with A's dtype on A's device.
    """
    logger.debug("GEMS_MTHREADS BMM(FMA)")
    batch, M, K = A.shape
    _b_batch, _b_k, N = B.shape
    lhs = A.contiguous()
    rhs = B.contiguous()
    result = torch.empty((batch, M, N), dtype=lhs.dtype, device=lhs.device)

    def launch_grid(meta):
        # One program per output tile, replicated along the batch axis.
        tiles_m = triton.cdiv(meta["M"], meta["TILE_M"])
        tiles_n = triton.cdiv(meta["N"], meta["TILE_N"])
        return (tiles_m, tiles_n, batch)

    with torch_device_fn.device(lhs.device):
        bmm_kernel[launch_grid](lhs, rhs, result, M, N, K)
    return result
def bmm_sqmma_descriptor_pre_hook(nargs):
    """Fill the three descriptor buffers before ``bmm_sqmma_kernel`` runs.

    ``nargs`` is the kernel-argument mapping handed to the hook.  A, B and C
    are each viewed as a 2-D matrix with the batch axis folded into the row
    axis; a descriptor for the selected block shape is built for each view
    and copied into ``a_desc_ptr`` / ``b_desc_ptr`` / ``c_desc_ptr``.
    """
    a = nargs["A"]
    b = nargs["B"]
    c = nargs["C"]
    batch, M, N, K = nargs["batch"], nargs["M"], nargs["N"], nargs["K"]
    block_m = nargs["BLOCK_SIZE_M"]
    block_n = nargs["BLOCK_SIZE_N"]
    block_k = nargs["BLOCK_SIZE_K"]
    device = c.device

    # (destination slot, descriptor factory, tensor, 2-D rows, 2-D cols,
    #  tile rows, tile cols).  The inputs go through the cached factory, the
    # output through the non-cached one.
    plan = (
        ("a_desc_ptr", get_cached_tma_device_descriptor, a, batch * M, K, block_m, block_k),
        ("b_desc_ptr", get_cached_tma_device_descriptor, b, batch * K, N, block_k, block_n),
        ("c_desc_ptr", create_tma_device_descriptor, c, batch * M, N, block_m, block_n),
    )
    for slot, build, tensor, rows, cols, tile_rows, tile_cols in plan:
        nargs[slot].copy_(
            build(tensor.reshape(rows, cols), tile_rows, tile_cols, device)
        )
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "bmm_sqmma",
        pre_hook=bmm_sqmma_descriptor_pre_hook,
        yaml_path=EXPAND_CONFIG_FILENAME,
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else [
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
            num_stages=1,
            num_warps=4,
            pre_hook=bmm_sqmma_descriptor_pre_hook,
        )
    ],
    key=["M", "N", "K"],
    strategy=runtime.get_expand_config("bmm_sqmma", yaml_path=EXPAND_CONFIG_FILENAME)[
        "strategy"
    ][:3]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def bmm_sqmma_kernel(
    A,
    B,
    C,
    a_desc_ptr,
    b_desc_ptr,
    c_desc_ptr,
    batch,
    M,
    N,
    K,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    ab_type: tl.constexpr,
    d_type: tl.constexpr,
):
    """Descriptor-based batched matmul: one output tile per program.

    Grid axis 0 enumerates (row-tile, column-tile) pairs with the row tile
    varying fastest; grid axis 1 is the batch index.  The descriptors are
    filled by the config pre-hook and describe A, B and C as 2-D matrices
    with the batch axis folded into the rows, so row offsets below carry a
    ``batch_index * M`` (or ``batch_index * K``) term.

    Inputs are loaded as ``ab_type``, accumulated in float32 and written as
    ``d_type``.  ``C`` is addressed directly (as a dense row-major
    (batch, M, N) buffer) only for row tiles that cross the end of a batch
    entry.
    """
    pid = tl.program_id(axis=0)
    batch_index = tl.program_id(axis=1)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m

    # First row of this tile inside its own batch entry, and inside the
    # flattened 2-D view respectively.
    tile_row = pid_m * BLOCK_SIZE_M
    offs_am = tile_row + batch_index * M
    offs_bn = pid_n * BLOCK_SIZE_N
    offs_ak = 0
    offs_bk = batch_index * K

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl._experimental_descriptor_load(
            a_desc_ptr, [offs_am, offs_ak], [BLOCK_SIZE_M, BLOCK_SIZE_K], ab_type
        )
        b = tl._experimental_descriptor_load(
            b_desc_ptr, [offs_bk, offs_bn], [BLOCK_SIZE_K, BLOCK_SIZE_N], ab_type
        )
        accumulator = tl.dot(a, b, acc=accumulator)
        offs_ak += BLOCK_SIZE_K
        offs_bk += BLOCK_SIZE_K
    out_tile = accumulator.to(d_type)

    if tile_row + BLOCK_SIZE_M <= M:
        # The tile lies entirely inside this batch entry.
        tl._experimental_descriptor_store(c_desc_ptr, out_tile, [offs_am, offs_bn])
    else:
        # The tile extends past row M of this batch entry.  In the flattened
        # (batch * M, N) view those extra rows belong to the NEXT batch
        # entry, so a descriptor store here would overwrite them with
        # products of the wrong operands.  Write only the rows (and columns)
        # this program owns.
        rows = tile_row + tl.arange(0, BLOCK_SIZE_M)
        cols = offs_bn + tl.arange(0, BLOCK_SIZE_N)
        flat_rows = (batch_index * M + rows).to(tl.int64)
        c_ptrs = C + flat_rows[:, None] * N + cols[None, :]
        own_mask = (rows[:, None] < M) & (cols[None, :] < N)
        tl.store(c_ptrs, out_tile, mask=own_mask)
def get_triton_type(elem_type):
    """Translate a torch dtype into the corresponding Triton dtype.

    float16, bfloat16 and float8_e4m3fn are mapped; any other dtype yields
    ``None``.
    """
    torch_to_triton = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float8_e4m3fn: tl.float8e4nv,
    }
    if elem_type in torch_to_triton:
        return torch_to_triton[elem_type]
    return None
def bmm_sqmma(A, B, elem_type, batch, M, N, K):
    """Batched matmul through the descriptor-based ``bmm_sqmma_kernel``.

    Args:
        A: left operand, indexed by the kernel as (batch, M, K).
        B: right operand, indexed by the kernel as (batch, K, N).
        elem_type: torch dtype of both operands.
        batch, M, N, K: problem sizes.

    Returns:
        A new (batch, M, N) tensor on A's device.  Its dtype equals
        ``elem_type``, except that bfloat16 inputs produce float16.
    """
    logger.debug("GEMS_MTHREADS BMM(SQMMA)")
    # Allocate on the operands' device rather than on whichever device
    # happens to be current.
    device = A.device
    ab_type = elem_type
    # NOTE(review): bfloat16 inputs are written out as float16, which differs
    # from the input dtype -- confirm this is a deliberate hardware/compiler
    # restriction rather than an oversight.
    c_type = torch.float16 if elem_type == torch.bfloat16 else elem_type
    C = torch.empty((batch, M, N), dtype=c_type, device=device)
    # 64-byte scratch buffers; the config pre-hook copies the descriptors
    # into them before the kernel starts.
    desc_a = torch.empty((64,), dtype=torch.int8, device=device)
    desc_b = torch.empty((64,), dtype=torch.int8, device=device)
    desc_c = torch.empty((64,), dtype=torch.int8, device=device)
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        batch,
        1,
    )
    # Launch under the same device guard that bmm_fma uses.
    with torch_device_fn.device(device):
        bmm_sqmma_kernel[grid](
            A,
            B,
            C,
            desc_a,
            desc_b,
            desc_c,
            batch,
            M,
            N,
            K,
            ab_type=get_triton_type(ab_type),
            d_type=get_triton_type(c_type),
        )
    return C
def bmm(a, b):
    """Batched matrix multiply entry point for the MUSA backend.

    Dispatches to ``bmm_sqmma`` when ``is_sqmma_compatible`` accepts the
    operands and to ``bmm_fma`` otherwise.  For the duration of the call the
    ``MUSA_ENABLE_SQMMA`` environment variable is set to "1" when neither
    operand is float32 and removed otherwise; its previous state is restored
    on exit, including when an exception propagates.
    """
    a_dtype = a.dtype
    b_dtype = b.dtype
    batch, M, K = a.shape
    _b_batch, _b_k, N = b.shape

    env_key = "MUSA_ENABLE_SQMMA"
    saved_value = os.environ.get(env_key)
    if torch.float32 in (a_dtype, b_dtype):
        os.environ.pop(env_key, None)
    else:
        os.environ[env_key] = "1"

    try:
        if not is_sqmma_compatible(a, b, N, K):
            return bmm_fma(a, b)
        return bmm_sqmma(a, b, a_dtype, batch, M, N, K)
    finally:
        # Put the variable back exactly as it was found.
        if saved_value is not None:
            os.environ[env_key] = saved_value
        else:
            os.environ.pop(env_key, None)
320 os.environ["MUSA_ENABLE_SQMMA"] = prev_sqmma