Coverage for src/flag_gems/runtime/backend/_mthreads/ops/w8a8_block_fp8_matmul.py: 0%
142 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
2import os
3from typing import List
5import torch
6import triton
7import triton.language as tl
8from triton.tools.tensor_descriptor import TensorDescriptor
10from flag_gems import runtime
11from flag_gems.runtime import torch_device_fn
12from flag_gems.utils import libentry, libtuner
13from flag_gems.utils import triton_lang_extension as ext
# Module logger, named after this module's dotted import path.
logger = logging.getLogger(
    "flag_gems.runtime.backend._mthreads.ops.w8a8_block_fp8_matmul"
)

# Normalized path of the tuning YAML, resolved one directory above this file.
_EXPAND_CONFIG_BASENAME = "w8a8_block_fp8_matmul_mthreads_expand.yaml"
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(os.path.dirname(__file__), os.pardir, _EXPAND_CONFIG_BASENAME)
)

# Gate for the SQMMA code path: is_sqmma_compatible() requires this to be
# truthy, so with False the general kernel is always selected.
SQMMA_ON = False
def is_supported_sqmma_layout(tensor):
    """Return True if ``tensor`` has a layout the SQMMA path accepts.

    Two layouts qualify:

    * the tensor is contiguous, or
    * dim 0 has unit stride and the stride of dim 1 equals the size of
      dim 0 (i.e. the transpose of a contiguous 2-D tensor).

    The stride checks index dims 0 and 1, so a non-contiguous input is
    expected to have at least two dims once dim 0 has unit stride.
    """
    if tensor.is_contiguous():
        return True
    # Only look at dim 1 once dim 0 is known to have unit stride.
    if tensor.stride(0) != 1:
        return False
    return tensor.stride(1) == tensor.shape[0]
def is_sqmma_compatible(a, b, output_dtype, n, k):
    """Decide whether the SQMMA kernel may be used for this problem.

    All of the following must hold:

    * ``a`` and ``b`` are 2-D and the module-level ``SQMMA_ON`` flag is set;
    * both operands are ``torch.float8_e4m3fn``;
    * ``output_dtype`` is ``torch.float16`` or ``torch.bfloat16``;
    * both operands pass ``is_supported_sqmma_layout``;
    * ``n`` and ``k`` are multiples of 16.

    Returns:
        bool: True only when every condition above is satisfied.
    """
    # Cheap structural checks first; the order mirrors the original chain.
    if a.dim() != 2 or not SQMMA_ON or b.dim() != 2:
        return False
    if not (a.dtype == b.dtype == torch.float8_e4m3fn):
        return False
    if output_dtype not in (torch.float16, torch.bfloat16):
        return False
    if not (is_supported_sqmma_layout(a) and is_supported_sqmma_layout(b)):
        return False
    return n % 16 == 0 and k % 16 == 0
def get_triton_type(elem_type):
    """Translate a torch dtype into the matching Triton element type.

    Supported inputs are ``torch.float16``, ``torch.bfloat16``,
    ``torch.float32`` and ``torch.float8_e4m3fn``.

    Returns:
        The corresponding ``tl`` dtype, or ``None`` for any other input.
    """
    torch_to_triton = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float32: tl.float32,
        torch.float8_e4m3fn: tl.float8e4nv,
    }
    # dict.get already yields None for an unmapped dtype.
    return torch_to_triton.get(elem_type)
def matmul_get_configs():
    """Return the built-in tuning configs for the general kernel.

    The list holds a single ``triton.Config`` with a 64x64 output tile,
    ``BLOCK_K`` of 128, ``GROUP_M`` of 8, three pipeline stages and four
    warps. It is the config set used when ``USE_FLAGTUNE`` is not ``"1"``.
    """
    default_meta = {
        "BLOCK_M": 64,
        "BLOCK_N": 64,
        "BLOCK_K": 128,
        "GROUP_M": 8,
    }
    default_config = triton.Config(default_meta, num_stages=3, num_warps=4)
    return [default_config]
@libentry()
@libtuner(
    # With USE_FLAGTUNE=1 both the configs and the strategy come from the
    # YAML file; otherwise the built-in single config is used.
    configs=(
        runtime.ops_get_configs(
            "w8a8_block_fp8_general", pre_hook=None, yaml_path=EXPAND_CONFIG_FILENAME
        )
        if os.environ.get("USE_FLAGTUNE") == "1"
        else matmul_get_configs()
    ),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=(
        runtime.get_expand_config(
            "w8a8_block_fp8_general", yaml_path=EXPAND_CONFIG_FILENAME
        )["strategy"]
        if os.environ.get("USE_FLAGTUNE") == "1"
        else ["align32", "align32", "align32", "align32", "align32"]
    ),
    warmup=5,
    rep=5,
)
@triton.jit
def w8a8_block_fp8_matmul_kernel(
    A,
    B,
    C,
    As,
    Bs,
    M,
    N,
    K,
    group_n,
    group_k,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """Block-scaled matmul: each program writes one BLOCK_M x BLOCK_N tile of C.

    A is addressed as [m, k] through (stride_am, stride_ak) and B as [k, n]
    through (stride_bk, stride_bn). For every K step the partial product is
    multiplied by a per-row scale read from As[m, k_block] and a per-column
    scale read from Bs[n // group_n, k_block] before being accumulated in
    float32. The result is cast to bfloat16 or float16 when C has that
    element type, and to float32 otherwise.

    NOTE(review): a single scale index ``k_start // group_k`` is used per K
    iteration, which presumably assumes every BLOCK_K-wide tile lies inside
    one ``group_k`` block -- confirm for configs where BLOCK_K > group_k.
    """
    # Map the flat program id to an (M-tile, N-tile) pair, walking the
    # M-tiles in groups of GROUP_M.
    pid = tl.program_id(axis=0)
    tiles_m = tl.cdiv(M, BLOCK_M)
    tiles_n = tl.cdiv(N, BLOCK_N)
    programs_per_group = GROUP_M * tiles_n
    group_index = pid // programs_per_group
    group_first_m = group_index * GROUP_M
    rows_in_group = min(tiles_m - group_first_m, GROUP_M)
    pid_m = group_first_m + (pid % rows_in_group)
    pid_n = (pid % programs_per_group) // rows_in_group

    # Row/column indices are wrapped with a modulo, so the loads below are
    # masked along K only.
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    a_tile_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_tile_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Base pointers of the scales: one per A row, one per group of B columns.
    a_scale_ptrs = As + offs_am * stride_As_m
    b_scale_ptrs = Bs + (offs_bn // group_n) * stride_Bs_n

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k_iter in range(0, tl.cdiv(K, BLOCK_K)):
        k_begin = k_iter * BLOCK_K
        k_left = K - k_begin
        # Zero-fill the part of the tile that runs past K.
        a_tile = tl.load(a_tile_ptrs, mask=offs_k[None, :] < k_left, other=0.0)
        b_tile = tl.load(b_tile_ptrs, mask=offs_k[:, None] < k_left, other=0.0)

        scale_col = k_begin // group_k
        row_scale = tl.load(a_scale_ptrs + scale_col * stride_As_k)
        col_scale = tl.load(b_scale_ptrs + scale_col * stride_Bs_k)
        acc += tl.dot(a_tile, b_tile) * row_scale[:, None] * col_scale[None, :]
        a_tile_ptrs += BLOCK_K * stride_ak
        b_tile_ptrs += BLOCK_K * stride_bk

    # Cast the float32 accumulator according to C's element type.
    if C.dtype.element_ty == tl.bfloat16:
        out_tile = acc.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        out_tile = acc.to(tl.float16)
    else:
        out_tile = acc.to(tl.float32)

    # The store uses un-wrapped indices and masks out rows/columns beyond M/N.
    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    out_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    in_bounds = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(out_ptrs, out_tile, mask=in_bounds)
@triton.jit
def w8a8_block_fp8_matmul_sqmma_kernel(
    a_desc,
    b_desc,
    c_desc,
    As,
    Bs,
    M,
    N,
    K,
    group_n,
    group_k,
    stride_As_m,
    stride_As_k,
    stride_Bs_n,
    stride_Bs_k,
    GROUP_M: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Tensor-descriptor variant of the block-scaled matmul.

    Operand tiles are fetched with ``tl.load_tensor_descriptor`` from
    ``a_desc`` at [m, k] and ``b_desc`` at [k, n]; the output tile is written
    with ``tl.store_tensor_descriptor`` after casting to ``c_desc.dtype``.
    Scales are read from As[row, k_block] and Bs[col // group_n, k_block],
    masked to rows < M and columns < N with 0.0 as the fill value.

    NOTE(review): as in the general kernel, one scale index
    ``offs_k // group_k`` is used per K iteration -- presumably each
    BLOCK_K tile is expected to fall inside one ``group_k`` block; confirm.
    """
    # Grouped mapping from the flat program id to an (M-tile, N-tile) pair.
    pid = ext.program_id(0)
    tiles_m = tl.cdiv(M, BLOCK_M)
    tiles_n = tl.cdiv(N, BLOCK_N)
    programs_per_group = GROUP_M * tiles_n
    group_index = pid // programs_per_group
    rows_in_group = min(tiles_m - group_index * GROUP_M, GROUP_M)
    pid_m = group_index * GROUP_M + (pid % rows_in_group)
    pid_n = (pid % programs_per_group) // rows_in_group

    # Scalar tile origins passed to the descriptor loads/stores.
    offs_am = (pid_m * BLOCK_M).to(tl.int32)
    offs_bn = (pid_n * BLOCK_N).to(tl.int32)
    offs_k = tl.zeros((), dtype=tl.int32)

    # Per-element row/column indices, used only for the scale loads.
    scale_rows = offs_am + tl.arange(0, BLOCK_M)
    scale_cols = offs_bn + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for _ in range(0, tl.cdiv(K, BLOCK_K)):
        a_tile = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k])
        b_tile = tl.load_tensor_descriptor(b_desc, [offs_k, offs_bn])

        k_block = offs_k // group_k
        row_scale = tl.load(
            As + scale_rows * stride_As_m + k_block * stride_As_k,
            mask=scale_rows < M,
            other=0.0,
        )
        col_scale = tl.load(
            Bs + (scale_cols // group_n) * stride_Bs_n + k_block * stride_Bs_k,
            mask=scale_cols < N,
            other=0.0,
        )
        acc += (
            tl.dot(a_tile, b_tile, out_dtype=tl.float32, allow_tf32=False)
            * row_scale[:, None]
            * col_scale[None, :]
        )
        offs_k += BLOCK_K

    tl.store_tensor_descriptor(c_desc, [offs_am, offs_bn], acc.to(c_desc.dtype))
def general_w8a8_block_fp8_matmul(
    a,
    b,
    c,
    a_s,
    b_s,
    M,
    N,
    K,
    group_n,
    group_k,
):
    """Launch ``w8a8_block_fp8_matmul_kernel`` and return ``c``.

    ``b`` and ``b_s`` are handed to the kernel with their two strides
    swapped (dim 1 as the K stride, dim 0 as the N stride), so the kernel
    addresses them as [k, n]. ``a``, ``c`` and ``a_s`` are passed with their
    strides in natural order. All five tensors are indexed with dims 0 and 1
    here, so they are expected to be 2-D.

    Args:
        a, b: operand tensors.
        c: pre-allocated output tensor, written in place.
        a_s, b_s: scale tensors for ``a`` and ``b``.
        M, N, K: problem sizes.
        group_n, group_k: scale block sizes along N and K.

    Returns:
        The same ``c`` object that was passed in.
    """
    logger.debug(
        "GEMS_MTHREADS W8A8_BLOCK_FP8_MATMUL(general), [shape info]: [-, %s, %s, %s](batch, M, N, K)",
        M,
        N,
        K,
    )

    def grid(meta):
        # One program per output tile, flattened to a 1-D launch grid.
        return (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),)

    stride_am, stride_ak = a.stride(0), a.stride(1)
    # b / b_s: dim 1 is K, dim 0 is N from the kernel's point of view.
    stride_bk, stride_bn = b.stride(1), b.stride(0)
    stride_cm, stride_cn = c.stride(0), c.stride(1)
    stride_as_m, stride_as_k = a_s.stride(0), a_s.stride(1)
    stride_bs_k, stride_bs_n = b_s.stride(1), b_s.stride(0)

    with torch_device_fn.device(a.device):
        w8a8_block_fp8_matmul_kernel[grid](
            a,
            b,
            c,
            a_s,
            b_s,
            M,
            N,
            K,
            group_n,
            group_k,
            stride_am,
            stride_ak,
            stride_bk,
            stride_bn,
            stride_cm,
            stride_cn,
            stride_as_m,
            stride_as_k,
            stride_bs_k,
            stride_bs_n,
        )
    return c
def sqmma_w8a8_block_fp8_matmul(
    a,
    b,
    c,
    a_s,
    b_s,
    M,
    N,
    K,
    group_n,
    group_k,
):
    """Launch ``w8a8_block_fp8_matmul_sqmma_kernel`` and return ``c``.

    ``a`` and ``b`` are made contiguous if they are not already, then
    wrapped (together with ``c``) in ``TensorDescriptor`` objects using a
    fixed 64x64 output tile and a K tile of 128. The kernel is launched with
    ``GROUP_M`` 8, 4 warps and 3 stages.

    NOTE(review): the public entry point passes ``b`` with shape (N, K),
    while the descriptor built here tiles ``b`` as [BLOCK_K, BLOCK_N] and
    the kernel reads it at [k, n]. Confirm the intended operand layout
    before enabling ``SQMMA_ON``.

    Returns:
        The same ``c`` object that was passed in.
    """
    # The column-major flags are logged from the strides as received,
    # before any contiguous() copy below.
    logger.debug(
        "GEMS_MTHREADS W8A8_BLOCK_FP8_MATMUL(sqmma), [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    device = a.device
    a = a if a.is_contiguous() else a.contiguous()
    b = b if b.is_contiguous() else b.contiguous()

    # Fixed launch geometry for this path (no autotuning).
    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 128
    GROUP_M = 8

    desc_a = TensorDescriptor.from_tensor(a, [BLOCK_M, BLOCK_K])
    desc_b = TensorDescriptor.from_tensor(b, [BLOCK_K, BLOCK_N])
    desc_c = TensorDescriptor.from_tensor(c, [BLOCK_M, BLOCK_N])

    def grid(meta):
        # One program per output tile along the first grid axis.
        tiles = triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"])
        return (tiles, 1, 1)

    with torch_device_fn.device(device):
        w8a8_block_fp8_matmul_sqmma_kernel[grid](
            desc_a,
            desc_b,
            desc_c,
            a_s,
            b_s,
            M,
            N,
            K,
            group_n,
            group_k,
            a_s.stride(0),
            a_s.stride(1),
            b_s.stride(0),
            b_s.stride(1),
            GROUP_M,
            BLOCK_M,
            BLOCK_N,
            BLOCK_K,
            num_warps=4,
            num_stages=3,
        )
    return c
def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Compute a block-scaled matmul of ``A`` against ``B`` and return the result.

    Args:
        A: activation tensor of shape ``(..., K)``; leading dims are
            flattened into M.
        B: 2-D weight tensor of shape ``(N, K)``.
        As: scales for ``A`` with shape
            ``A.shape[:-1] + (cdiv(K, block_k),)``.
        Bs: 2-D scales for ``B`` with shape
            ``(cdiv(N, block_n), cdiv(K, block_k))``.
        block_size: ``[block_n, block_k]`` scale block sizes.
        output_dtype: dtype of the returned tensor.

    Returns:
        A tensor of shape ``A.shape[:-1] + (N,)`` on ``A``'s device.

    Raises:
        AssertionError: if ``block_size`` does not have two entries or any
            of the shape relations above does not hold.

    The operand dtypes are not validated here; the SQMMA path is only taken
    when ``is_sqmma_compatible`` accepts the operands, otherwise the general
    kernel is used.
    """
    device = A.device
    assert len(block_size) == 2
    block_n, block_k = block_size

    def _no_unit_stride(t, first_dim, second_dim):
        # True when neither of the two dims is laid out with stride <= 1.
        return t.stride(first_dim) > 1 and t.stride(second_dim) > 1

    # Copy to contiguous memory only when neither trailing dim has unit
    # stride; row-major and transposed layouts are left as they are.
    if A.ndim >= 2 and _no_unit_stride(A, -2, -1):
        A = A.contiguous()
    if B.ndim == 2 and _no_unit_stride(B, 0, 1):
        B = B.contiguous()
    if As.ndim >= 2 and _no_unit_stride(As, -2, -1):
        As = As.contiguous()
    if Bs.ndim == 2 and _no_unit_stride(Bs, 0, 1):
        Bs = Bs.contiguous()

    # Shape validation for the operands and their scales.
    assert A.shape[-1] == B.shape[-1], "incompatible dimensions"
    assert A.shape[:-1] == As.shape[:-1], "A and As dimensions mismatch"
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1], "invalid As shape"
    assert B.ndim == 2 and Bs.ndim == 2, "B and Bs must be 2D"

    M = A.numel() // A.shape[-1]
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0], "invalid Bs N dimension"
    assert triton.cdiv(K, block_k) == Bs.shape[1], "invalid Bs K dimension"

    c = torch.empty(A.shape[:-1] + (N,), device=device, dtype=output_dtype)

    # Collapse leading dims so the kernels see plain 2-D problems.
    a_2d = A.reshape(M, K)
    as_2d = As.reshape(M, As.shape[-1])
    c_2d = c.reshape(M, N)

    if is_sqmma_compatible(a_2d, B, output_dtype, N, K):
        matmul_impl = sqmma_w8a8_block_fp8_matmul
    else:
        matmul_impl = general_w8a8_block_fp8_matmul

    result_2d = matmul_impl(
        a_2d,
        B,
        c_2d,
        as_2d,
        Bs,
        M,
        N,
        K,
        block_n,
        block_k,
    )
    return result_2d.reshape(c.shape)