Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/w8a8_block_fp8_matmul.py: 0%
178 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import functools
2import logging
3import os
4from typing import Any, Dict, List, Optional
6import torch
7import triton
8import triton.language as tl
9import yaml
11from flag_gems import runtime
12from flag_gems.runtime import torch_device_fn
13from flag_gems.utils import libentry, libtuner
# Module-level logger, named with this module's dotted import path.
logger = logging.getLogger(
    "flag_gems.runtime.backend._nvidia.hopper.ops.w8a8_block_fp8_matmul"
)

# NOTE(review): not referenced in the visible part of this module.
CACHE_USAGE_THRESHOLD = 0.8

# Normalised path of the expanded tuning-config YAML, located one directory
# above this module; passed to ``libtuner`` as ``flagtune_yaml_path``.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(
        os.path.dirname(__file__), "..", "w8a8_block_fp8_matmul_hopper_expand.yaml"
    )
)
@functools.lru_cache
def get_w8a8_block_fp8_hopper_configs(N: int, K: int) -> Optional[Dict[int, Any]]:
    """Load the pre-tuned kernel configs for an ``(N, K)`` weight shape.

    The configs are read from ``w8a8_block_fp8_matmul_hopper.yaml`` in the
    directory above this module. The YAML is indexed first by the current CUDA
    device name (spaces replaced by underscores), then by the string
    ``"N,K"``; each entry maps an ``M`` value to a six-element list
    ``[BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, num_warps,
    num_stages]``.

    Results are memoised per ``(N, K)`` by ``functools.lru_cache``, so the
    returned dict is shared between callers and must not be mutated.

    Args:
        N: Number of rows of the weight matrix.
        K: Shared reduction dimension.

    Returns:
        A dict mapping ``M`` (as ``int``) to a config dict, or ``None`` when
        the file is missing, empty, or has no entry for this device/shape.
    """
    device_name = torch.cuda.get_device_name().replace(" ", "_")
    file_name = "w8a8_block_fp8_matmul_hopper.yaml"
    cfg_file = os.path.join(os.path.dirname(__file__), "..", file_name)

    if not os.path.exists(cfg_file):
        logger.warning(
            "Using default W8A8 Block FP8 kernel config. Performance might "
            "be sub-optimal! Config file not found at %s",
            cfg_file,
        )
        return None

    with open(cfg_file, encoding="utf-8") as f:
        logger.info(
            "Using config from %s for W8A8 block FP8 kernel.",
            cfg_file,
        )
        # ``safe_load`` yields None for an empty document; treat that (and any
        # key whose value is empty/None) as "no configs" instead of crashing.
        all_data = yaml.safe_load(f) or {}

    dev_data = all_data.get(device_name) or {}
    nk_data = dev_data.get(f"{N},{K}") or {}

    # Unpack each positional list into a named-field dictionary.
    result = {
        int(m): {
            "BLOCK_SIZE_M": p[0],
            "BLOCK_SIZE_N": p[1],
            "BLOCK_SIZE_K": p[2],
            "GROUP_SIZE_M": p[3],
            "num_warps": p[4],
            "num_stages": p[5],
        }
        for m, p in nk_data.items()
    }
    return result or None
def _get_placeholder_tuner_configs(pre_hook=None):
    """Return a one-element list holding a stand-in ``triton.Config``.

    ``libtuner`` needs a config list at decoration time, before any runtime
    shape is known; this supplies a fixed placeholder for that purpose.

    Args:
        pre_hook: Forwarded unchanged to ``triton.Config``.
    """
    placeholder_meta = {
        "BLOCK_M": 64,
        "BLOCK_N": 64,
        "BLOCK_K": 128,
        "GROUP_M": 8,
    }
    placeholder = triton.Config(
        placeholder_meta,
        num_stages=3,
        num_warps=4,
        pre_hook=pre_hook,
    )
    return [placeholder]
def _get_fixed_matmul_meta(M: int, N: int, K: int, block_n: int, block_k: int):
    """Choose launch meta-parameters for the non-tuned general kernel.

    If pre-tuned configs exist for ``(N, K)``, the entry whose ``M`` key is
    closest to the requested ``M`` is used. Otherwise a default is built from
    the quantisation block sizes.

    Args:
        M: Number of rows of the activation matrix.
        N: Number of rows of the weight matrix.
        K: Shared reduction dimension.
        block_n: Fallback ``BLOCK_N`` when no tuned config is available.
        block_k: Fallback ``BLOCK_K`` when no tuned config is available.

    Returns:
        A dict with keys ``BLOCK_M``, ``BLOCK_N``, ``BLOCK_K``, ``GROUP_M``,
        ``num_warps`` and ``num_stages``.
    """
    tuned = get_w8a8_block_fp8_hopper_configs(N, K)
    if tuned:
        # Nearest tuned M; ties resolve to the first key in dict order.
        nearest_m = min(tuned, key=lambda m: abs(m - M))
        picked = tuned[nearest_m]
        # (kernel meta name, name used in the YAML-derived config)
        renames = (
            ("BLOCK_M", "BLOCK_SIZE_M"),
            ("BLOCK_N", "BLOCK_SIZE_N"),
            ("BLOCK_K", "BLOCK_SIZE_K"),
            ("GROUP_M", "GROUP_SIZE_M"),
            ("num_warps", "num_warps"),
            ("num_stages", "num_stages"),
        )
        return {dst: picked[src] for dst, src in renames}

    return {
        "BLOCK_M": 64,
        "BLOCK_N": block_n,
        "BLOCK_K": block_k,
        "GROUP_M": 32,
        "num_warps": 4,
        "num_stages": 2,
    }
@libentry()
@libtuner(
    configs=_get_placeholder_tuner_configs(pre_hook=None),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=5,
    flagtune_op_name="w8a8_block_fp8_matmul",
    flagtune_expand_op_name="w8a8_block_fp8_general",
    flagtune_yaml_path=EXPAND_CONFIG_FILENAME,
    flagtune_pre_hook=None,
)
@triton.jit
def w8a8_block_fp8_matmul_kernel_general(
    A,
    B,
    C,
    As,
    Bs,
    M,
    N,
    K,
    group_n,
    group_k,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """Block-scaled matmul: one program computes one BLOCK_M x BLOCK_N tile of C.

    A is addressed as (M, K) and B as (K, N) through the given strides. For
    every K step the partial product is multiplied by a per-row scale read
    from ``As`` and a per-column scale read from ``Bs``; the scale index along
    K is ``k_start // group_k`` and along N is ``col // group_n``.
    Accumulation is in float32; the result is cast to C's element type
    (bfloat16, float16, otherwise float32) before the masked store.
    """
    # Map the 1-D program id onto (pid_m, pid_n), walking GROUP_M row tiles
    # at a time.
    pid = tl.program_id(axis=0)
    tiles_m = tl.cdiv(M, BLOCK_M)
    tiles_n = tl.cdiv(N, BLOCK_N)
    tiles_per_group = GROUP_M * tiles_n
    group_idx = pid // tiles_per_group
    group_row0 = group_idx * GROUP_M
    rows_in_group = min(tiles_m - group_row0, GROUP_M)
    pid_m = group_row0 + (pid % rows_in_group)
    pid_n = (pid % tiles_per_group) // rows_in_group

    # Row/column indices wrap modulo M/N, so the loads below only need a mask
    # along K; out-of-range rows/columns are discarded by the store mask.
    row_idx = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    col_idx = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    k_idx = tl.arange(0, BLOCK_K)
    a_tile_ptrs = A + (row_idx[:, None] * stride_am + k_idx[None, :] * stride_ak)
    b_tile_ptrs = B + (k_idx[:, None] * stride_bk + col_idx[None, :] * stride_bn)

    # Base pointers of the scales; the K-group offset is added per iteration.
    a_scale_ptrs = As + row_idx * stride_As_m
    b_scale_ptrs = Bs + (col_idx // group_n) * stride_Bs_n

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        k_start = k * BLOCK_K
        k_left = K - k_start
        a = tl.load(a_tile_ptrs, mask=k_idx[None, :] < k_left, other=0.0)
        b = tl.load(b_tile_ptrs, mask=k_idx[:, None] < k_left, other=0.0)

        scale_k = k_start // group_k
        a_s = tl.load(a_scale_ptrs + scale_k * stride_As_k)
        b_s = tl.load(b_scale_ptrs + scale_k * stride_Bs_k)
        acc += tl.dot(a, b, out_dtype=tl.float32) * a_s[:, None] * b_s[None, :]

        a_tile_ptrs += BLOCK_K * stride_ak
        b_tile_ptrs += BLOCK_K * stride_bk

    # Cast the float32 accumulator to the output element type.
    if C.dtype.element_ty == tl.bfloat16:
        c = acc.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = acc.to(tl.float16)
    else:
        c = acc.to(tl.float32)

    # Un-wrapped indices for the store, masked against the true bounds.
    out_rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    out_cols = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = C + stride_cm * out_rows[:, None] + stride_cn * out_cols[None, :]
    c_mask = (out_rows[:, None] < M) & (out_cols[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
# NOTE(review): the placeholder config supplies GROUP_M, while this kernel
# takes SPLIT_K; presumably libtuner substitutes configs from the expand YAML
# before any tuned launch — confirm.
@libentry()
@libtuner(
    configs=_get_placeholder_tuner_configs(pre_hook=None),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=5,
    flagtune_op_name="w8a8_block_fp8_matmul",
    flagtune_expand_op_name="w8a8_block_fp8_general_splitk",
    flagtune_yaml_path=EXPAND_CONFIG_FILENAME,
    flagtune_pre_hook=None,
)
@triton.jit
def w8a8_block_fp8_matmul_kernel_splitk(
    A,
    B,
    C,
    As,
    Bs,
    M,
    N,
    K,
    group_n,
    group_k,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
):
    """Split-K variant of the block-scaled matmul.

    Grid axis 0 enumerates the BLOCK_M x BLOCK_N output tiles (row-major over
    tiles); grid axis 1 enumerates the K splits. Each program reduces its own
    contiguous range of K iterations, applying the ``As``/``Bs`` scales per
    iteration, and adds its partial tile into C with ``tl.atomic_add``.
    Because results are accumulated into C, the caller must zero C before
    launching.
    """
    pid = tl.program_id(0)
    pid_k = tl.program_id(1)

    tiles_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // tiles_n
    pid_n = pid % tiles_n

    row0 = pid_m * BLOCK_M
    col0 = pid_n * BLOCK_N
    rows = row0 + tl.arange(0, BLOCK_M)
    cols = col0 + tl.arange(0, BLOCK_N)
    row_ok = rows < M
    col_ok = cols < N

    # Range of K iterations owned by this split; the last split may be short
    # (or empty) when the iteration count is not a multiple of SPLIT_K.
    k_iters_total = tl.cdiv(K, BLOCK_K)
    k_iters_per_split = tl.cdiv(k_iters_total, SPLIT_K)
    iter_begin = pid_k * k_iters_per_split
    iter_end = min((pid_k + 1) * k_iters_per_split, k_iters_total)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(iter_begin, iter_end):
        k0 = k * BLOCK_K
        ks = k0 + tl.arange(0, BLOCK_K)
        k_ok = ks < K

        a = tl.load(
            A + rows[:, None] * stride_am + ks[None, :] * stride_ak,
            mask=row_ok[:, None] & k_ok[None, :],
            other=0.0,
        )
        b = tl.load(
            B + ks[:, None] * stride_bk + cols[None, :] * stride_bn,
            mask=k_ok[:, None] & col_ok[None, :],
            other=0.0,
        )

        # Scales for this K step: per row from As, per column group from Bs.
        scale_k = k0 // group_k
        a_s = tl.load(
            As + rows * stride_As_m + scale_k * stride_As_k,
            mask=row_ok,
            other=0.0,
        )
        b_s = tl.load(
            Bs + scale_k * stride_Bs_k + (cols // group_n) * stride_Bs_n,
            mask=col_ok,
            other=0.0,
        )
        acc += tl.dot(a, b, out_dtype=tl.float32) * a_s[:, None] * b_s[None, :]

    c_ptrs = C + rows[:, None] * stride_cm + cols[None, :] * stride_cn
    out_mask = row_ok[:, None] & col_ok[None, :]
    # Accumulate the partial tile in C's element type.
    if C.dtype.element_ty == tl.bfloat16:
        tl.atomic_add(c_ptrs, acc.to(tl.bfloat16), mask=out_mask)
    elif C.dtype.element_ty == tl.float16:
        tl.atomic_add(c_ptrs, acc.to(tl.float16), mask=out_mask)
    else:
        tl.atomic_add(c_ptrs, acc.to(tl.float32), mask=out_mask)
def general_w8a8_block_fp8_matmul(a, b, c, a_s, b_s, M, N, K, group_n, group_k):
    """Launch the appropriate block-scaled FP8 matmul kernel and return ``c``.

    Two kernels are available:

    * split-K, chosen when ``M < 2048 and N < 2112 and K >= 4096``; ``c`` is
      zeroed first because that kernel accumulates with atomic adds;
    * the general grouped kernel, for every other shape.

    When ``runtime.flagtune_enabled("w8a8_block_fp8_matmul")`` is true the
    decorated (libentry/libtuner) kernel is launched so meta-parameters come
    from the tuner. Otherwise the kernel is reached through ``.fn.fn`` —
    presumably the raw jitted function beneath both decorators — and explicit
    meta-parameters are passed.

    Args:
        a: 2-D activation matrix, indexed as (M, K).
        b: 2-D weight matrix, indexed as (N, K); its strides are passed
            swapped so the kernel addresses it as (K, N).
        c: 2-D output buffer, indexed as (M, N); written in place.
        a_s: 2-D scales for ``a``.
        b_s: 2-D scales for ``b`` (strides passed swapped, like ``b``).
        M, N, K: Problem sizes.
        group_n, group_k: Quantisation block sizes along N and K.

    Returns:
        ``c``.
    """
    logger.debug(
        "GEMS w8a8_block_fp8_matmul-hopper, [scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )

    # Default W8A8 keeps the fixed-meta path. When the op is explicitly
    # included in flag_gems.flagtune(...), launch through LibTuner instead.
    use_flagtune = runtime.flagtune_enabled("w8a8_block_fp8_matmul")

    def kernel_args():
        # Positional arguments shared by both kernels, in signature order.
        return (
            a,
            b,
            c,
            a_s,
            b_s,
            M,
            N,
            K,
            group_n,
            group_k,
            a.stride(0),
            a.stride(1),
            b.stride(1),
            b.stride(0),
            c.stride(0),
            c.stride(1),
            a_s.stride(0),
            a_s.stride(1),
            b_s.stride(1),
            b_s.stride(0),
        )

    # Split-K path for small-N, large-K shapes.
    if M < 2048 and N < 2112 and K >= 4096:
        if use_flagtune:

            def splitk_grid(META):
                return (
                    triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
                    META["SPLIT_K"],
                )

            c.zero_()
            with torch_device_fn.device(a.device):
                w8a8_block_fp8_matmul_kernel_splitk[splitk_grid](*kernel_args())
            return c

        tile_k = 128
        tile_m = 16 if M <= 16 else 64
        tile_n = 64 if N > 256 else 32

        num_tiles = triton.cdiv(M, tile_m) * triton.cdiv(N, tile_n)
        k_iters = triton.cdiv(K, tile_k)

        # Aim for roughly two programs per SM, with at least 4 splits and at
        # most one split per K iteration.
        sm_count = torch.cuda.get_device_properties(a.device).multi_processor_count
        split_k = min(k_iters, max(4, 2 * sm_count // max(num_tiles, 1)))

        c.zero_()
        fixed_grid = (num_tiles, split_k)
        with torch_device_fn.device(a.device):
            w8a8_block_fp8_matmul_kernel_splitk.fn.fn[fixed_grid](
                *kernel_args(),
                BLOCK_M=tile_m,
                BLOCK_N=tile_n,
                BLOCK_K=tile_k,
                SPLIT_K=split_k,
            )
        return c

    # General path: 1-D grid with one program per output tile.
    def grid(meta):
        return (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),)

    fixed_meta = (
        None
        if use_flagtune
        else _get_fixed_matmul_meta(M, N, K, block_n=group_n, block_k=group_k)
    )

    def alloc_fn(size: int, align: int, stream: Optional[int]):
        # Byte buffer on the same device as ``a``.
        return torch.empty(size, dtype=torch.int8, device=a.device)

    triton.set_allocator(alloc_fn)

    with torch_device_fn.device(a.device):
        if use_flagtune:
            w8a8_block_fp8_matmul_kernel_general[grid](*kernel_args())
        else:
            w8a8_block_fp8_matmul_kernel_general.fn.fn[grid](
                *kernel_args(), **fixed_meta
            )
    return c
def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Compute the block-scaled matmul of ``A`` and ``B`` with per-block scales.

    Args:
        A: Activations of shape ``(..., K)``.
        B: 2-D weights of shape ``(N, K)``.
        As: Scales for ``A``, shape ``A.shape[:-1] + (cdiv(K, block_k),)``.
        Bs: 2-D scales for ``B``, shape ``(cdiv(N, block_n), cdiv(K, block_k))``.
        block_size: ``[block_n, block_k]`` quantisation block sizes.
        output_dtype: Element type of the returned tensor.

    Returns:
        A tensor of shape ``A.shape[:-1] + (N,)`` and dtype ``output_dtype``,
        allocated on ``A``'s device.

    Raises:
        AssertionError: If ``block_size`` does not have two entries or the
            tensor shapes are inconsistent.
    """
    assert len(block_size) == 2
    block_n, block_k = block_size

    # Copy an input only when neither of its two trailing strides is <= 1.
    if A.ndim >= 2 and A.stride(-2) > 1 and A.stride(-1) > 1:
        A = A.contiguous()
    if B.ndim == 2 and B.stride(0) > 1 and B.stride(1) > 1:
        B = B.contiguous()
    if As.ndim >= 2 and As.stride(-2) > 1 and As.stride(-1) > 1:
        As = As.contiguous()
    if Bs.ndim == 2 and Bs.stride(0) > 1 and Bs.stride(1) > 1:
        Bs = Bs.contiguous()

    # Shape constraints.
    assert A.shape[-1] == B.shape[-1], "incompatible dimensions"
    assert A.shape[:-1] == As.shape[:-1], "A and As dimensions mismatch"
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1], "invalid As shape"
    assert B.ndim == 2 and Bs.ndim == 2, "B and Bs must be 2D"

    # All leading dimensions of A are folded into a single M.
    M = A.numel() // A.shape[-1]
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0], "invalid Bs N dimension"
    assert triton.cdiv(K, block_k) == Bs.shape[1], "invalid Bs K dimension"

    # Output keeps A's leading dimensions, with N as the last dimension.
    c = torch.empty(A.shape[:-1] + (N,), device=A.device, dtype=output_dtype)

    # The launcher works on 2-D views.
    a_mat = A.reshape(M, K)
    a_scale_mat = As.reshape(M, As.shape[-1])
    c_mat = c.reshape(M, N)

    result = general_w8a8_block_fp8_matmul(
        a_mat,
        B,
        c_mat,
        a_scale_mat,
        Bs,
        M,
        N,
        K,
        block_n,
        block_k,
    )
    return result.reshape(c.shape)