Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/w8a8_block_fp8_matmul.py: 0%
178 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import functools
2import logging
3import os
4from typing import Any, Dict, List, Optional
6import torch
7import triton
8import triton.language as tl
9import yaml
11from flag_gems import runtime
12from flag_gems.runtime import torch_device_fn
13from flag_gems.utils import libentry, libtuner
# Module logger, registered under an explicit dotted name that mirrors this
# module's location in the flag_gems package.
logger = logging.getLogger(
    "flag_gems.runtime.backend._nvidia.hopper.ops.w8a8_block_fp8_matmul"
)

# NOTE(review): not referenced in this module's kernels or launchers;
# presumably a usage ratio consumed elsewhere -- confirm before changing.
CACHE_USAGE_THRESHOLD = 0.8

# FlagTune "expand" config consumed by the libtuner-decorated kernels. The file
# lives in the parent directory of this module; normpath folds away the "..".
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(
        os.path.dirname(__file__), "..", "w8a8_block_fp8_matmul_hopper_expand.yaml"
    )
)
@functools.lru_cache
def get_w8a8_block_fp8_hopper_configs(N: int, K: int) -> Optional[Dict[int, Any]]:
    """Load tuned launch configs for weight shape ``(N, K)`` on the current GPU.

    Reads ``w8a8_block_fp8_matmul_hopper.yaml`` from the parent directory of
    this module. The YAML is looked up as
    ``data[<device name, spaces -> underscores>]["N,K"][M]``, where each leaf is
    a list ``[BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M,
    num_warps, num_stages]``.

    Results are memoised per ``(N, K)`` by ``lru_cache``. NOTE(review): the
    device name is not part of the cache key, so the first lookup for a shape
    wins even if the current CUDA device changes later.

    Args:
        N: Output-feature dimension of the weight.
        K: Reduction dimension of the weight.

    Returns:
        ``{M: {"BLOCK_SIZE_M": ..., ..., "num_stages": ...}}`` for every M
        listed for this device and shape, or ``None`` when the file is missing,
        empty, or has no entry for this device/shape.
    """
    file_name = "w8a8_block_fp8_matmul_hopper.yaml"
    # Normalised the same way as the expand-config path so logged paths are clean.
    cfg_file = os.path.normpath(
        os.path.join(os.path.dirname(__file__), "..", file_name)
    )

    if not os.path.exists(cfg_file):
        logger.warning(
            "Using default W8A8 Block FP8 kernel config. Performance might "
            "be sub-optimal! Config file not found at %s",
            cfg_file,
        )
        return None

    # Only query CUDA once there is a file to look the device up in.
    device_name = torch.cuda.get_device_name().replace(" ", "_")
    with open(cfg_file, encoding="utf-8") as f:
        logger.info(
            "Using config from %s for W8A8 block FP8 kernel.",
            cfg_file,
        )
        # safe_load returns None for an empty document; treat that as "no entries".
        all_data = yaml.safe_load(f) or {}

    # ``or {}`` also covers keys that are present but hold a null value.
    dev_data = all_data.get(device_name) or {}
    NK_data = dev_data.get(f"{N},{K}") or {}

    result = {}
    for m_key, p in NK_data.items():
        # unpack the positional list into named launch parameters
        result[int(m_key)] = {
            "BLOCK_SIZE_M": p[0],
            "BLOCK_SIZE_N": p[1],
            "BLOCK_SIZE_K": p[2],
            "GROUP_SIZE_M": p[3],
            "num_warps": p[4],
            "num_stages": p[5],
        }

    if not result:
        # The file exists but does not cover this device/shape; make the
        # fallback visible instead of returning None silently.
        logger.info(
            "No W8A8 block FP8 config for device %s and (N, K) = (%s, %s) in %s; "
            "falling back to the default config.",
            device_name,
            N,
            K,
            cfg_file,
        )
        return None
    return result
def _get_placeholder_tuner_configs(pre_hook=None):
    """Return a one-element list with a fixed placeholder ``triton.Config``.

    libtuner needs a config list at decoration time, before any runtime shapes
    are known. This seeds it with a single 64x64x128 tile (GROUP_M=8) using
    4 warps and 3 pipeline stages.

    Args:
        pre_hook: Forwarded unchanged to ``triton.Config``.

    Returns:
        ``[triton.Config(...)]`` holding the placeholder meta-parameters.
    """
    placeholder_meta = {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_M": 8}
    placeholder = triton.Config(
        placeholder_meta,
        num_warps=4,
        num_stages=3,
        pre_hook=pre_hook,
    )
    return [placeholder]
def _get_fixed_matmul_meta(M: int, N: int, K: int, block_n: int, block_k: int):
    """Choose static launch meta-parameters for the general kernel.

    Used when FlagTune is disabled. Tuned configs for the weight shape
    ``(N, K)`` come from ``get_w8a8_block_fp8_hopper_configs``. Among them, the
    entry whose M key is nearest to ``M`` is used; on a tie, the key listed
    first wins. Without tuned configs, the fallback is a
    ``64 x block_n x block_k`` tile with GROUP_M=32, 4 warps and 2 stages.

    Returns:
        Dict with ``BLOCK_M``, ``BLOCK_N``, ``BLOCK_K``, ``GROUP_M``,
        ``num_warps`` and ``num_stages``, ready for ``**``-expansion into a
        kernel launch.
    """
    tuned = get_w8a8_block_fp8_hopper_configs(N, K)
    if not tuned:
        return dict(
            BLOCK_M=64,
            BLOCK_N=block_n,
            BLOCK_K=block_k,
            GROUP_M=32,
            num_warps=4,
            num_stages=2,
        )

    nearest_m = min(tuned, key=lambda m: abs(m - M))
    chosen = tuned[nearest_m]
    # (kernel meta-parameter name, config-file field name)
    renames = (
        ("BLOCK_M", "BLOCK_SIZE_M"),
        ("BLOCK_N", "BLOCK_SIZE_N"),
        ("BLOCK_K", "BLOCK_SIZE_K"),
        ("GROUP_M", "GROUP_SIZE_M"),
        ("num_warps", "num_warps"),
        ("num_stages", "num_stages"),
    )
    return {meta_key: chosen[cfg_key] for meta_key, cfg_key in renames}
# Grouped-tile block-FP8 GEMM kernel.
# With USE_FLAGTUNE=1 the libtuner candidates and search strategy come from
# EXPAND_CONFIG_FILENAME. Otherwise libtuner is only seeded with the placeholder
# config: the non-FlagTune launcher in this file calls ``.fn.fn`` with explicit
# meta-parameters instead. Tuning is keyed on the problem size and on the
# A/B strides (stride_am, stride_bk), i.e. on operand layout.
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "w8a8_block_fp8_general",
        pre_hook=None,
        yaml_path=EXPAND_CONFIG_FILENAME,
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else _get_placeholder_tuner_configs(pre_hook=None),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=runtime.get_expand_config(
        "w8a8_block_fp8_general", yaml_path=EXPAND_CONFIG_FILENAME
    )["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def w8a8_block_fp8_matmul_kernel_general(
    A,
    B,
    C,
    As,
    Bs,
    M,
    N,
    K,
    group_n,
    group_k,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """Compute one BLOCK_M x BLOCK_N tile of C from block-scaled operands.

    The result is accumulated in fp32 over K steps as
    dot(A_tile, B_tile) * As[m, kg] * Bs[n // group_n, kg], where
    kg = (step start) // group_k.

    A is addressed as (m, k) through stride_am/stride_ak. B is addressed as
    (k, n) through stride_bk/stride_bn; the launchers in this file pass an
    (N, K) ``b`` with its strides swapped.

    Row and column offsets beyond M/N wrap around (``% M`` / ``% N``) so loads
    stay in bounds, and the final store is masked. Only one scale pair is
    loaded per BLOCK_K step, so a step is assumed not to span two K
    quantization groups. NOTE(review): BLOCK_K <= group_k with
    group_k % BLOCK_K == 0 is not checked here.

    The fp32 accumulator is cast to C's element type (bf16/fp16, otherwise
    fp32) before the store.
    """
    pid = tl.program_id(axis=0)
    # Map the flat program id onto a (pid_m, pid_n) tile, walking GROUP_M tile
    # rows at a time. Consecutive programs vary pid_m and share pid_n, so they
    # reuse the same B tile column (presumably for L2 locality).
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    # The last group may have fewer than GROUP_M tile rows.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Wrap out-of-range rows/cols so the unmasked M/N loads stay in bounds;
    # the wrapped results are discarded by the masked store at the end.
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # As has one scale per row; Bs has one scale per group of group_n columns.
    # The K-group offset is added inside the loop.
    As_ptrs = As + offs_am * stride_As_m
    offs_bsn = offs_bn // group_n
    Bs_ptrs = Bs + offs_bsn * stride_Bs_n

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # Zero-fill the K tail of the final step.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)

        # K quantization group of this step, taken from its first K index.
        k_start = k * BLOCK_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
        # Dequantize the partial product with the row scale x column scale.
        acc += tl.dot(a, b, out_dtype=tl.float32) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    # Cast the fp32 accumulator to C's element type.
    if C.dtype.element_ty == tl.bfloat16:
        c = acc.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = acc.to(tl.float16)
    else:
        c = acc.to(tl.float32)

    # Recompute unwrapped offsets so the store mask drops padding rows/cols.
    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
# Split-K block-FP8 GEMM kernel, using the same libtuner wiring as the general
# kernel with the "w8a8_block_fp8_general_splitk" expand config.
# NOTE(review): when USE_FLAGTUNE is off, the placeholder config defines
# GROUP_M (not a parameter of this kernel) and no SPLIT_K. This appears
# harmless only because the non-FlagTune launcher in this file calls ``.fn.fn``
# with explicit BLOCK_M/BLOCK_N/BLOCK_K/SPLIT_K. Confirm before routing that
# path through libtuner.
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "w8a8_block_fp8_general_splitk",
        yaml_path=EXPAND_CONFIG_FILENAME,
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else _get_placeholder_tuner_configs(pre_hook=None),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=runtime.get_expand_config(
        "w8a8_block_fp8_general_splitk", yaml_path=EXPAND_CONFIG_FILENAME
    )["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def w8a8_block_fp8_matmul_kernel_splitk(
    A,
    B,
    C,
    As,
    Bs,
    M,
    N,
    K,
    group_n,
    group_k,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
):
    """Split-K block-FP8 GEMM that accumulates partial tiles into C atomically.

    Grid axis 0 enumerates output tiles (pid_m, pid_n), row-major over the
    N tiles. Grid axis 1 (pid_k) selects a contiguous range of
    cdiv(cdiv(K, BLOCK_K), SPLIT_K) K steps.

    Each program multiplies over its range and scales every step by
    As[m, kg] * Bs[n // group_n, kg], with kg = (step start) // group_k. It
    then casts the fp32 partial sum to C's element type and tl.atomic_add's it
    into C. C must therefore be zero-filled before launch. Programs whose K
    range is empty add zeros.

    All loads and the final atomic add are masked on M, N and/or K, so partial
    edge tiles are safe. As in the general kernel, one scale pair is used per
    BLOCK_K step, so a step is assumed not to span two K groups
    (NOTE(review): unchecked here). With an fp16/bf16 C, each partial sum is
    rounded to that type before it is added.
    """
    pid = tl.program_id(0)
    pid_k = tl.program_id(1)

    # grid_m (= tl.cdiv(M, BLOCK_M)) is implied by grid axis 0; only grid_n is
    # needed to unflatten pid into (pid_m, pid_n).
    grid_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // grid_n
    pid_n = pid % grid_n

    offset_am = pid_m * BLOCK_M
    offset_bn = pid_n * BLOCK_N
    offs_am = offset_am + tl.arange(0, BLOCK_M)
    offs_bn = offset_bn + tl.arange(0, BLOCK_N)

    # This program's slice [k_start, k_end) of the BLOCK_K steps; the last
    # slices may be short or empty.
    total_k_iters = tl.cdiv(K, BLOCK_K)
    k_per_split = tl.cdiv(total_k_iters, SPLIT_K)
    k_start = pid_k * k_per_split
    k_end = min((pid_k + 1) * k_per_split, total_k_iters)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(k_start, k_end):
        offset_k = k * BLOCK_K
        offs_k = offset_k + tl.arange(0, BLOCK_K)

        # Masked loads: out-of-range rows/cols/K contribute zeros.
        a = tl.load(
            A + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak,
            mask=(offs_am[:, None] < M) & (offs_k[None, :] < K),
            other=0.0,
        )
        b = tl.load(
            B + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn,
            mask=(offs_k[:, None] < K) & (offs_bn[None, :] < N),
            other=0.0,
        )

        # K quantization group of this step, taken from its first K index.
        offs_ks = offset_k // group_k
        a_s = tl.load(
            As + offs_am * stride_As_m + offs_ks * stride_As_k,
            mask=offs_am < M,
            other=0.0,
        )
        b_s = tl.load(
            Bs + offs_ks * stride_Bs_k + (offs_bn // group_n) * stride_Bs_n,
            mask=offs_bn < N,
            other=0.0,
        )
        # Dequantize the partial product with the row scale x column scale.
        acc += tl.dot(a, b, out_dtype=tl.float32) * a_s[:, None] * b_s[None, :]

    offs_cm = offset_am + tl.arange(0, BLOCK_M)
    offs_cn = offset_bn + tl.arange(0, BLOCK_N)
    c_ptrs = C + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
    mask = (offs_cm < M)[:, None] & (offs_cn < N)[None, :]
    # Several K-splits write the same tile, so accumulate with atomic adds in
    # C's element type.
    if C.dtype.element_ty == tl.bfloat16:
        tl.atomic_add(c_ptrs, acc.to(tl.bfloat16), mask=mask)
    elif C.dtype.element_ty == tl.float16:
        tl.atomic_add(c_ptrs, acc.to(tl.float16), mask=mask)
    else:
        tl.atomic_add(c_ptrs, acc.to(tl.float32), mask=mask)
def general_w8a8_block_fp8_matmul(a, b, c, a_s, b_s, M, N, K, group_n, group_k):
    """Launch the Hopper block-FP8 W8A8 GEMM on 2-D operands, writing into ``c``.

    Computes ``c[m, n] = sum_k a[m, k] * b[n, k] * a_s[m, k // group_k]
    * b_s[n // group_n, k // group_k]``. Accumulation is in fp32 and the
    result is cast to ``c.dtype``; each BLOCK_K step is assumed to lie inside
    one K group. ``b`` and ``b_s`` are indexed N-first, so their ``stride(1)``
    is passed as the K stride and ``stride(0)`` as the N stride.

    Two strategies are used:

    * split-K, when ``M < 2048 and N < 2112 and K >= 4096``: ``c`` is zeroed
      and every K-split atomically adds its partial tile into it;
    * otherwise, the grouped-tile general kernel writes ``c`` directly.

    With ``USE_FLAGTUNE=1``, launch parameters come from libtuner. Otherwise
    they are fixed: a heuristic for split-K, ``_get_fixed_matmul_meta`` for
    the general kernel.

    Args:
        a: Activations, expected shape ``(M, K)``.
        b: Weights, expected shape ``(N, K)``.
        c: Output buffer, expected shape ``(M, N)``; overwritten in place.
        a_s: Activation scales, expected shape ``(M, cdiv(K, group_k))``.
        b_s: Weight scales, expected shape ``(cdiv(N, group_n), cdiv(K, group_k))``.
        M, N, K: Problem sizes.
        group_n, group_k: Quantization block sizes along N and K.

    Returns:
        ``c``.
    """
    logger.debug(
        "GEMS w8a8_block_fp8_matmul-hopper, [scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )

    # Nothing to compute for an empty output. This avoids a zero-sized grid
    # (which the driver may reject) and the general kernel's ``% M`` / ``% N``
    # with a zero divisor.
    if M == 0 or N == 0:
        return c

    use_flagtune = os.environ.get("USE_FLAGTUNE") == "1"

    # Positional kernel arguments in signature order, shared by every launch.
    kernel_args = (
        a,
        b,
        c,
        a_s,
        b_s,
        M,
        N,
        K,
        group_n,
        group_k,
        a.stride(0),  # stride_am
        a.stride(1),  # stride_ak
        b.stride(1),  # stride_bk
        b.stride(0),  # stride_bn
        c.stride(0),  # stride_cm
        c.stride(1),  # stride_cn
        a_s.stride(0),  # stride_As_m
        a_s.stride(1),  # stride_As_k
        b_s.stride(1),  # stride_Bs_k
        b_s.stride(0),  # stride_Bs_n
    )

    # Split-K path for small-M/N, large-K shapes.
    if M < 2048 and N < 2112 and K >= 4096:
        # Every K-split atomically adds a partial tile, so start from zeros.
        c.zero_()
        if use_flagtune:
            splitk_grid = lambda META: (
                triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
                META["SPLIT_K"],
            )
            with torch_device_fn.device(a.device):
                w8a8_block_fp8_matmul_kernel_splitk[splitk_grid](*kernel_args)
            return c

        # The kernel applies one (a_s, b_s) pair per BLOCK_K step, so a step
        # must not straddle two K quantization groups. A fixed 128 silently
        # mis-scaled results for group_k < 128 (e.g. 64). min() keeps 128 for
        # the usual group_k >= 128 power-of-two case and otherwise uses group_k.
        splitk_block_k = min(128, group_k)
        splitk_block_m = 16 if M <= 16 else 64
        splitk_block_n = 64 if N > 256 else 32

        grid_mn = triton.cdiv(M, splitk_block_m) * triton.cdiv(N, splitk_block_n)
        total_k_iters = triton.cdiv(K, splitk_block_k)
        sm_count = torch.cuda.get_device_properties(a.device).multi_processor_count
        # Aim for roughly two programs per SM in total (at least 4 splits),
        # never more splits than there are K steps.
        split_k = min(total_k_iters, max(4, 2 * sm_count // max(grid_mn, 1)))

        with torch_device_fn.device(a.device):
            # ``.fn.fn`` unwraps the libentry/libtuner decorators (assumed to
            # expose the wrapped object as ``.fn``) to launch the raw JIT
            # kernel with explicit meta-parameters.
            w8a8_block_fp8_matmul_kernel_splitk.fn.fn[(grid_mn, split_k)](
                *kernel_args,
                BLOCK_M=splitk_block_m,
                BLOCK_N=splitk_block_n,
                BLOCK_K=splitk_block_k,
                SPLIT_K=split_k,
            )
        return c

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),
    )
    fixed_meta = (
        None
        if use_flagtune
        else _get_fixed_matmul_meta(M, N, K, block_n=group_n, block_k=group_k)
    )

    # Capture only the device. Closing over ``a`` itself would make the
    # globally registered allocator keep the input tensor alive until it is
    # replaced.
    device = a.device

    def alloc_fn(size: int, align: int, stream: Optional[int]):
        return torch.empty(size, dtype=torch.int8, device=device)

    triton.set_allocator(alloc_fn)
    with torch_device_fn.device(a.device):
        if use_flagtune:
            w8a8_block_fp8_matmul_kernel_general[grid](*kernel_args)
        else:
            w8a8_block_fp8_matmul_kernel_general.fn.fn[grid](
                *kernel_args, **fixed_meta
            )
    return c
def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Block-quantized W8A8 matmul ``(A * As) @ (B * Bs).T`` on Hopper.

    Leading dimensions of ``A`` are flattened into M for the 2-D kernels, and
    the output is reshaped back to ``A.shape[:-1] + (N,)``.

    Args:
        A: Activations of shape ``(..., K)``.
        B: 2-D weights of shape ``(N, K)``.
        As: Activation scales of shape ``(..., cdiv(K, block_k))``, with the
            same leading dims as ``A``.
        Bs: Weight scales of shape ``(cdiv(N, block_n), cdiv(K, block_k))``.
        block_size: ``[block_n, block_k]`` quantization block sizes.
        output_dtype: Dtype of the returned tensor (default ``torch.float16``).

    Returns:
        Tensor of shape ``A.shape[:-1] + (N,)`` and dtype ``output_dtype``.

    Raises:
        AssertionError: If the shapes are inconsistent with each other or with
            ``block_size``.
    """
    device = A.device
    assert len(block_size) == 2
    block_n, block_k = block_size

    # Copy only when neither of the two innermost dims is unit-stride; row- or
    # column-major operands are consumed through the kernels' stride args.
    if A.ndim >= 2 and A.stride(-2) > 1 and A.stride(-1) > 1:
        A = A.contiguous()
    if B.ndim == 2 and B.stride(0) > 1 and B.stride(1) > 1:
        B = B.contiguous()
    if As.ndim >= 2 and As.stride(-2) > 1 and As.stride(-1) > 1:
        As = As.contiguous()
    if Bs.ndim == 2 and Bs.stride(0) > 1 and Bs.stride(1) > 1:
        Bs = Bs.contiguous()

    # checks constraints
    assert A.shape[-1] == B.shape[-1], "incompatible dimensions"
    assert A.shape[:-1] == As.shape[:-1], "A and As dimensions mismatch"
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1], "invalid As shape"
    assert B.ndim == 2 and Bs.ndim == 2, "B and Bs must be 2D"

    # Rows after flattening the leading dims. Derived from the shape rather
    # than ``A.numel() // A.shape[-1]``, which raises ZeroDivisionError when K == 0.
    M = A.shape[:-1].numel()
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0], "invalid Bs N dimension"
    assert triton.cdiv(K, block_k) == Bs.shape[1], "invalid Bs K dimension"

    # allocates output
    output_shape = A.shape[:-1] + (N,)
    c = torch.empty(output_shape, device=device, dtype=output_dtype)

    # 2-D views for the kernels; c is freshly allocated and contiguous, so
    # c_2d aliases c.
    a_2d = A.reshape(M, K)
    as_2d = As.reshape(M, As.shape[-1])
    c_2d = c.reshape(M, N)

    return general_w8a8_block_fp8_matmul(
        a_2d,
        B,
        c_2d,
        as_2d,
        Bs,
        M,
        N,
        K,
        block_n,
        block_k,
    ).reshape(c.shape)