Coverage for src/flag_gems/runtime/backend/_mthreads/ops/w8a8_block_fp8_matmul.py: 0%
141 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import os
3from typing import List
5import torch
6import triton
7import triton.language as tl
8from triton.tools.tensor_descriptor import TensorDescriptor
10from flag_gems import runtime
11from flag_gems.runtime import torch_device_fn
12from flag_gems.utils import libentry, libtuner
13from flag_gems.utils import triton_lang_extension as ext
# Module logger; the channel name is written out in full instead of using __name__.
logger = logging.getLogger(
    "flag_gems.runtime.backend._mthreads.ops.w8a8_block_fp8_matmul"
)

# Normalised path of the tuning YAML, which sits one directory above this file.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(
        os.path.dirname(__file__), "..", "w8a8_block_fp8_matmul_mthreads_expand.yaml"
    )
)

# Switch for the tensor-descriptor (SQMMA) path; is_sqmma_compatible() requires it
# to be truthy, so while it is False every call takes the general kernel.
SQMMA_ON = False
def is_supported_sqmma_layout(tensor):
    """Return True when *tensor* is contiguous or a column-major 2-D matrix.

    "Column-major" here means ``stride(0) == 1`` and ``stride(1) == shape[0]``.
    The second stride is only inspected when the first one is 1.
    """
    if tensor.is_contiguous():
        return True
    if tensor.stride(0) != 1:
        return False
    return tensor.stride(1) == tensor.shape[0]
def is_sqmma_compatible(a, b, output_dtype, n, k):
    """Decide whether the tensor-descriptor (SQMMA) kernel may be used.

    All of the following must hold, checked in this order:
    ``a`` is 2-D, ``SQMMA_ON`` is truthy, ``b`` is 2-D, both operands are
    ``torch.float8_e4m3fn``, the output dtype is float16 or bfloat16, both
    operands pass ``is_supported_sqmma_layout``, and ``n`` and ``k`` are
    multiples of 16.
    """
    if a.dim() != 2 or not SQMMA_ON or b.dim() != 2:
        return False
    if not (a.dtype == b.dtype == torch.float8_e4m3fn):
        return False
    if output_dtype not in (torch.float16, torch.bfloat16):
        return False
    if not is_supported_sqmma_layout(a):
        return False
    if not is_supported_sqmma_layout(b):
        return False
    return n % 16 == 0 and k % 16 == 0
def matmul_get_configs():
    """Return the single tuning candidate used by the general kernel.

    The candidate uses 64x64 output tiles, a K step of 128, GROUP_M of 8,
    three pipeline stages and four warps.
    """
    tile_params = {
        "BLOCK_M": 64,
        "BLOCK_N": 64,
        "BLOCK_K": 128,
        "GROUP_M": 8,
    }
    default_config = triton.Config(tile_params, num_stages=3, num_warps=4)
    return [default_config]
@libentry()
@libtuner(
    configs=matmul_get_configs(),
    # Five tuning keys, matched by five strategy entries below.
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def w8a8_block_fp8_matmul_kernel(
    A,
    B,
    C,
    As,
    Bs,
    M,
    N,
    K,
    group_n,
    group_k,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """Compute one BLOCK_M x BLOCK_N tile of C from block-scaled A and B.

    Each program walks K in steps of BLOCK_K and accumulates, in float32,
    ``dot(a_tile, b_tile) * a_scale[:, None] * b_scale[None, :]``.

    Args:
        A, B, C: pointers to the operands and the output. A is addressed as
            (m, k) through stride_am/stride_ak, B as (k, n) through
            stride_bk/stride_bn, and C as (m, n) through stride_cm/stride_cn.
        As: pointer to A's scales, indexed by (row, k // group_k).
        Bs: pointer to B's scales, indexed by (k // group_k, n // group_n).
        M, N, K: problem sizes.
        group_n, group_k: number of N columns / K elements per scale entry.
        stride_*: element strides for the tensors above.
        BLOCK_M, BLOCK_N, BLOCK_K: tile sizes (compile-time constants).
        GROUP_M: number of tile-rows per band in the program ordering.
    """
    # Map the 1-D program id to a (pid_m, pid_n) tile. Tiles are handed out in
    # bands of GROUP_M tile-rows; the last band may be shorter (group_size_m).
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Row/column indices are wrapped with % M and % N so the loads below stay
    # in range without an M/N mask; the final store masks out the wrapped
    # rows and columns instead.
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Scale pointers: one A scale per row, one B scale per group of group_n
    # columns. The K-group offset is added inside the loop.
    As_ptrs = As + offs_am * stride_As_m
    offs_bsn = offs_bn // group_n
    Bs_ptrs = Bs + offs_bsn * stride_Bs_n

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # Only the K tail is masked; masked-out elements contribute zero.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)

        # A single scale index is used for the whole K tile.
        # NOTE(review): this looks correct only when a BLOCK_K tile does not
        # straddle two scale groups (e.g. group_k is a multiple of BLOCK_K)
        # -- confirm against the tuned configs.
        k_start = k * BLOCK_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    # Convert the float32 accumulator to the output element type; anything
    # other than bfloat16/float16 is kept as float32.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    # Unwrapped output indices, masked against M and N.
    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
def sqmma_descriptor_pre_hook(nargs):
    """Set each descriptor's ``block_shape`` from the tile sizes in *nargs*.

    ``a_desc`` gets [BLOCK_M, BLOCK_K], ``b_desc`` gets [BLOCK_K, BLOCK_N] and
    ``c_desc`` gets [BLOCK_M, BLOCK_N]. The descriptors are mutated in place.
    """
    tile_layout = (
        ("a_desc", "BLOCK_M", "BLOCK_K"),
        ("b_desc", "BLOCK_K", "BLOCK_N"),
        ("c_desc", "BLOCK_M", "BLOCK_N"),
    )
    for desc_key, rows_key, cols_key in tile_layout:
        nargs[desc_key].block_shape = [nargs[rows_key], nargs[cols_key]]
# The decorator arguments are evaluated once, when this module is imported, so
# USE_FLAGTUNE must be set before import to select the YAML-driven configs and
# strategy; otherwise a single hard-coded config is used.
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "w8a8_block_fp8_general_tma",
        pre_hook=sqmma_descriptor_pre_hook,
        yaml_path=EXPAND_CONFIG_FILENAME,
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else [
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_M": 8},
            num_stages=3,
            num_warps=4,
            pre_hook=sqmma_descriptor_pre_hook,
        )
    ],
    key=["M", "N", "K"],
    # Only the first three strategy entries are used, one per key.
    strategy=runtime.get_expand_config(
        "w8a8_block_fp8_general_tma", yaml_path=EXPAND_CONFIG_FILENAME
    )["strategy"][:3]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def w8a8_block_fp8_matmul_sqmma_kernel(
    a_desc,
    b_desc,
    c_desc,
    As,
    Bs,
    M,
    N,
    K,
    group_n,
    group_k,
    stride_As_m,
    stride_As_k,
    stride_Bs_n,
    stride_Bs_k,
    GROUP_M: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Tensor-descriptor variant: compute one BLOCK_M x BLOCK_N tile of C.

    Args:
        a_desc: descriptor loaded at [row, k]; its tile is [BLOCK_M, BLOCK_K]
            once the pre-hook has run.
        b_desc: descriptor loaded at [k, col]; its tile is [BLOCK_K, BLOCK_N],
            so it must describe a tensor indexed as (K, N).
        c_desc: descriptor stored at [row, col] with tile [BLOCK_M, BLOCK_N].
        As: pointer to A's scales, indexed by (row, k // group_k).
        Bs: pointer to B's scales, indexed by (col // group_n, k // group_k).
        M, N, K: problem sizes.
        group_n, group_k: number of N columns / K elements per scale entry.
        stride_As_m, stride_As_k, stride_Bs_n, stride_Bs_k: scale strides.
            Note the Bs strides come in (n, k) order here.
        GROUP_M, BLOCK_M, BLOCK_N, BLOCK_K: compile-time tiling constants.
    """
    # Map the 1-D program id to a (pid_m, pid_n) tile, in bands of GROUP_M
    # tile-rows; the last band may be shorter (group_size).
    pid = ext.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // group_size

    # Scalar int32 tile origins used as descriptor coordinates.
    offs_am = (pid_m * BLOCK_M).to(tl.int32)
    offs_bn = (pid_n * BLOCK_N).to(tl.int32)
    offs_k = tl.zeros((), dtype=tl.int32)

    # Per-element row/column indices, used only for the scale loads.
    row_offset = offs_am + tl.arange(0, BLOCK_M)
    col_offset = offs_bn + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for _ in range(0, tl.cdiv(K, BLOCK_K)):
        # NOTE(review): the descriptor loads carry no explicit M/N/K mask;
        # presumably the descriptor handles out-of-range tiles -- confirm.
        a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k])
        b = tl.load_tensor_descriptor(b_desc, [offs_k, offs_bn])

        # One scale index for the whole K tile.
        # NOTE(review): assumes a BLOCK_K tile does not straddle two scale
        # groups -- confirm against the tuned configs.
        scale_k = offs_k // group_k
        # Rows/columns past M/N read a scale of 0.0.
        a_s = tl.load(
            As + row_offset * stride_As_m + scale_k * stride_As_k,
            mask=row_offset < M,
            other=0.0,
        )
        b_s = tl.load(
            Bs + (col_offset // group_n) * stride_Bs_n + scale_k * stride_Bs_k,
            mask=col_offset < N,
            other=0.0,
        )
        acc += (
            tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
            * a_s[:, None]
            * b_s[None, :]
        )
        offs_k += BLOCK_K

    # Cast the float32 accumulator to the descriptor's dtype and write the tile.
    tl.store_tensor_descriptor(c_desc, [offs_am, offs_bn], acc.to(c_desc.dtype))
def general_w8a8_block_fp8_matmul(
    a,
    b,
    c,
    a_s,
    b_s,
    M,
    N,
    K,
    group_n,
    group_k,
):
    """Launch the pointer-based kernel and return ``c``.

    ``a`` and ``c`` are passed with their strides in natural (0, 1) order.
    ``b`` and ``b_s`` are passed with strides swapped (1, 0), so the kernel
    addresses ``b`` as (k, n) and ``b_s`` as (k-group, n-group).

    Args:
        a, b: 2-D operands.
        c: 2-D output tensor, written in place.
        a_s, b_s: 2-D scale tensors for ``a`` and ``b``.
        M, N, K: problem sizes.
        group_n, group_k: scale-group sizes along N and K.

    Returns:
        The same ``c`` tensor that was passed in.
    """
    logger.debug(
        "GEMS_MTHREADS W8A8_BLOCK_FP8_MATMUL(general), [shape info]: [-, %s, %s, %s](batch, M, N, K)",
        M,
        N,
        K,
    )

    def launch_grid(meta):
        # One program per output tile, flattened to a 1-D grid.
        tiles_m = triton.cdiv(M, meta["BLOCK_M"])
        tiles_n = triton.cdiv(N, meta["BLOCK_N"])
        return (tiles_m * tiles_n,)

    with torch_device_fn.device(a.device):
        stride_am, stride_ak = a.stride(0), a.stride(1)
        stride_bk, stride_bn = b.stride(1), b.stride(0)
        stride_cm, stride_cn = c.stride(0), c.stride(1)
        stride_as_m, stride_as_k = a_s.stride(0), a_s.stride(1)
        stride_bs_k, stride_bs_n = b_s.stride(1), b_s.stride(0)
        w8a8_block_fp8_matmul_kernel[launch_grid](
            a,
            b,
            c,
            a_s,
            b_s,
            M,
            N,
            K,
            group_n,
            group_k,
            stride_am,
            stride_ak,
            stride_bk,
            stride_bn,
            stride_cm,
            stride_cn,
            stride_as_m,
            stride_as_k,
            stride_bs_k,
            stride_bs_n,
        )
    return c
def sqmma_w8a8_block_fp8_matmul(
    a,
    b,
    c,
    a_s,
    b_s,
    M,
    N,
    K,
    group_n,
    group_k,
):
    """Launch the tensor-descriptor (SQMMA) kernel and return ``c``.

    Args:
        a: (M, K) operand; copied to a contiguous tensor if it is not one.
        b: (N, K) operand, the same layout the public entry point accepts.
        c: (M, N) output tensor, written in place through its descriptor.
        a_s: scales for ``a``, indexed (row, k-group).
        b_s: scales for ``b``, indexed (n-group, k-group).
        M, N, K: problem sizes.
        group_n, group_k: scale-group sizes along N and K.

    Returns:
        The same ``c`` tensor that was passed in.
    """
    logger.debug(
        "GEMS_MTHREADS W8A8_BLOCK_FP8_MATMUL(sqmma), [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    device = a.device
    if not a.is_contiguous():
        a = a.contiguous()

    # The kernel loads b_desc tiles of [BLOCK_K, BLOCK_N] at [k, n], so the
    # descriptor has to be built over a row-major (K, N) tensor. ``b`` arrives
    # as (N, K); describing it directly would make the kernel read the wrong
    # elements. A column-major ``b`` transposes to a contiguous view for free;
    # a row-major ``b`` costs one transposed copy here. Avoiding that copy
    # would require the kernel to load [BLOCK_N, BLOCK_K] tiles and transpose.
    b_kn = b.t()
    if not b_kn.is_contiguous():
        b_kn = b_kn.contiguous()

    # Block shapes are placeholders; the tuner's pre-hook overwrites them with
    # the selected tile sizes before launch.
    desc_a = TensorDescriptor.from_tensor(a, [1, 1])
    desc_b = TensorDescriptor.from_tensor(b_kn, [1, 1])
    desc_c = TensorDescriptor.from_tensor(c, [1, 1])

    # One program per output tile, flattened along the first grid axis.
    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),
        1,
        1,
    )

    with torch_device_fn.device(device):
        w8a8_block_fp8_matmul_sqmma_kernel[grid](
            desc_a,
            desc_b,
            desc_c,
            a_s,
            b_s,
            M,
            N,
            K,
            group_n,
            group_k,
            a_s.stride(0),
            a_s.stride(1),
            # This kernel takes the Bs strides in (n, k) order.
            b_s.stride(0),
            b_s.stride(1),
        )
    return c
def _compact_if_doubly_strided(tensor, rank_ok):
    """Return a contiguous copy when both trailing strides of *tensor* exceed 1.

    Strides are only inspected when *rank_ok* is true; otherwise *tensor* is
    returned unchanged.
    """
    if rank_ok and tensor.stride(-2) > 1 and tensor.stride(-1) > 1:
        return tensor.contiguous()
    return tensor


def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Block-scaled matmul of ``A`` (..., K) with ``B`` (N, K).

    Args:
        A: left operand; its last dimension is K and leading dimensions are
            flattened into M.
        B: 2-D right operand of shape (N, K).
        As: scales for ``A``, shaped ``A.shape[:-1] + (cdiv(K, block_k),)``.
        Bs: 2-D scales for ``B``, shaped (cdiv(N, block_n), cdiv(K, block_k)).
        block_size: ``[block_n, block_k]``, the scale-group sizes.
        output_dtype: dtype of the result tensor.

    Returns:
        A tensor of shape ``A.shape[:-1] + (N,)`` with dtype ``output_dtype``,
        allocated on ``A``'s device.

    Raises:
        AssertionError: if the shapes above are not satisfied.
    """
    device = A.device
    assert len(block_size) == 2
    block_n, block_k = block_size

    # Tensors with no unit stride in their last two dims are compacted first.
    # A/As may have any rank >= 2, while B/Bs are only touched when 2-D.
    A = _compact_if_doubly_strided(A, A.ndim >= 2)
    B = _compact_if_doubly_strided(B, B.ndim == 2)
    As = _compact_if_doubly_strided(As, As.ndim >= 2)
    Bs = _compact_if_doubly_strided(Bs, Bs.ndim == 2)

    assert A.shape[-1] == B.shape[-1], "incompatible dimensions"
    assert A.shape[:-1] == As.shape[:-1], "A and As dimensions mismatch"
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1], "invalid As shape"
    assert B.ndim == 2 and Bs.ndim == 2, "B and Bs must be 2D"

    M = A.numel() // A.shape[-1]
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0], "invalid Bs N dimension"
    assert triton.cdiv(K, block_k) == Bs.shape[1], "invalid Bs K dimension"

    output_shape = A.shape[:-1] + (N,)
    c = torch.empty(output_shape, device=device, dtype=output_dtype)

    # Flatten everything to 2-D for the kernels.
    a_2d = A.reshape(M, K)
    as_2d = As.reshape(M, As.shape[-1])
    c_2d = c.reshape(M, N)

    # Both launchers share one signature, so pick one and call it once.
    if is_sqmma_compatible(a_2d, B, output_dtype, N, K):
        launch = sqmma_w8a8_block_fp8_matmul
    else:
        launch = general_w8a8_block_fp8_matmul
    out_2d = launch(a_2d, B, c_2d, as_2d, Bs, M, N, K, block_n, block_k)
    return out_2d.reshape(c.shape)