Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/w8a8_block_fp8_matmul.py: 0%
246 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import functools
2import logging
3import os
4from typing import Any, Dict, List, Optional
6import torch
7import triton
8import triton.language as tl
9import yaml
11from flag_gems import runtime
12from flag_gems.runtime import torch_device_fn
13from flag_gems.utils import libentry, libtuner
# Module-level logger, named after this module's import path.
logger = logging.getLogger(
    "flag_gems.runtime.backend._nvidia.hopper.ops.w8a8_block_fp8_matmul"
)

# Master switch for the host-side TMA kernel path; while this is False
# `is_tma_compatible` can never report a compatible pair of inputs.
TMA_ON = False

# NOTE(review): not referenced by any code in this section of the file.
CACHE_USAGE_THRESHOLD = 0.8

# Normalized path of the "expand" tuning YAML, located one directory above
# the directory that contains this file.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(
        os.path.dirname(__file__), "..", "w8a8_block_fp8_matmul_hopper_expand.yaml"
    )
)
@functools.lru_cache
def get_w8a8_block_fp8_hopper_configs(N: int, K: int) -> Optional[Dict[int, Any]]:
    """Load pre-tuned kernel configs for an ``(N, K)`` weight shape.

    The configs are read from ``w8a8_block_fp8_matmul_hopper.yaml`` one
    directory above this file. The YAML is indexed first by device name
    (spaces replaced by underscores; any device whose name contains a part
    starting with ``H20`` is mapped to ``NVIDIA_H20``) and then by the string
    ``"N,K"``. Each entry maps a batch-size key to a six-element list
    ``[BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, num_warps,
    num_stages]``.

    Results (including ``None``) are memoized per ``(N, K)`` by ``lru_cache``.

    Args:
        N: Number of rows of the weight matrix.
        K: Reduction dimension of the weight matrix.

    Returns:
        A dict mapping the integer batch-size key to a config dict, or
        ``None`` when the file is missing or holds no entry for this
        device/shape.
    """
    device_name = torch.cuda.get_device_name().replace(" ", "_")
    name_parts = device_name.split("_")
    if any(part.startswith("H20") for part in name_parts):
        device_name = "NVIDIA_H20"
    file_name = "w8a8_block_fp8_matmul_hopper.yaml"

    cfg_file = os.path.join(os.path.dirname(__file__), "..", file_name)

    if os.path.exists(cfg_file):
        with open(cfg_file) as f:
            logger.info(
                "Using config from %s for W8A8 block FP8 kernel.",
                cfg_file,
            )
            # safe_load returns None for an empty document; treat that, and
            # any null device/shape entry below, as "no tuned configs"
            # instead of failing on ``None.get``.
            all_data = yaml.safe_load(f) or {}

        dev_data = all_data.get(device_name) or {}
        NK_data = dev_data.get(f"{N},{K}") or {}

        # Unpack each positional list into a named config dictionary.
        result = {
            int(k): {
                "BLOCK_SIZE_M": p[0],
                "BLOCK_SIZE_N": p[1],
                "BLOCK_SIZE_K": p[2],
                "GROUP_SIZE_M": p[3],
                "num_warps": p[4],
                "num_stages": p[5],
            }
            for k, p in NK_data.items()
        }
        return result or None

    logger.warning(
        "Using default W8A8 Block FP8 kernel config. Performance might "
        "be sub-optimal! Config file not found at %s",
        cfg_file,
    )
    return None
def _get_placeholder_tuner_configs(pre_hook=None):
    """Return a one-element config list used to initialise ``libtuner``.

    The single config is a stand-in chosen before any runtime shape is
    known; ``pre_hook`` is forwarded unchanged to ``triton.Config``.
    """
    tile_meta = {
        "BLOCK_M": 64,
        "BLOCK_N": 64,
        "BLOCK_K": 128,
        "GROUP_M": 8,
    }
    placeholder = triton.Config(
        tile_meta,
        num_stages=3,
        num_warps=4,
        pre_hook=pre_hook,
    )
    return [placeholder]
def _get_fixed_matmul_meta(M: int, N: int, K: int, block_n: int, block_k: int):
    """Pick the launch meta-parameters used when autotuning is bypassed.

    If tuned configs exist for ``(N, K)``, the entry whose batch-size key is
    numerically closest to ``M`` is selected and its keys are renamed to the
    kernel's constexpr names. Otherwise a default is returned whose N/K tile
    sizes equal the quantization block sizes ``block_n`` / ``block_k``.

    Returns:
        Dict with ``BLOCK_M``, ``BLOCK_N``, ``BLOCK_K``, ``GROUP_M``,
        ``num_warps`` and ``num_stages``.
    """
    tuned = get_w8a8_block_fp8_hopper_configs(N, K)
    if tuned:
        closest_m = min(tuned, key=lambda candidate: abs(candidate - M))
        entry = tuned[closest_m]
        # (kernel meta name, name used in the tuned config entry)
        renames = (
            ("BLOCK_M", "BLOCK_SIZE_M"),
            ("BLOCK_N", "BLOCK_SIZE_N"),
            ("BLOCK_K", "BLOCK_SIZE_K"),
            ("GROUP_M", "GROUP_SIZE_M"),
            ("num_warps", "num_warps"),
            ("num_stages", "num_stages"),
        )
        return {meta_key: entry[cfg_key] for meta_key, cfg_key in renames}

    # No tuned entry: fall back to a fixed default.
    return {
        "BLOCK_M": 64,
        "BLOCK_N": block_n,
        "BLOCK_K": block_k,
        "GROUP_M": 32,
        "num_warps": 4,
        "num_stages": 2,
    }
113def is_tma_compatible(a, b, n, k):
114 """
115 Check if tensors are compatible with TMA (Tensor Memory Accelerator).
117 TMA requires 128-bit (16-byte) alignment for memory access.
118 For FP8 inputs (1 byte/element), both N and K must be multiples of 16
119 to satisfy the 16-byte alignment requirement.
121 Args:
122 a, b: Input tensors
123 n, k: Matrix dimensions
125 Returns:
126 bool: True if compatible with TMA's 128-bit alignment requirement
127 """
128 return (
129 a.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
130 and b.dtype == a.dtype
131 and TMA_ON
132 and n % 16 == 0
133 and k % 16 == 0
134 )
def matmul_tma_set_block_size_hook(nargs):
    """Set ``block_shape`` on the A/B/C descriptors found in ``nargs``.

    ``nargs`` is the kernel-argument mapping; it must provide ``BLOCK_M``,
    ``BLOCK_N``, ``BLOCK_K``, the ``A_ROW_MAJOR`` / ``B_ROW_MAJOR`` flags and
    the ``a_desc`` / ``b_desc`` / ``c_desc`` descriptor objects, which are
    mutated in place. Nothing is returned.
    """
    tile_m, tile_n, tile_k = nargs["BLOCK_M"], nargs["BLOCK_N"], nargs["BLOCK_K"]

    # A: tile is [M, K] when row-major, otherwise the transposed [K, M].
    nargs["a_desc"].block_shape = (
        [tile_m, tile_k] if nargs["A_ROW_MAJOR"] else [tile_k, tile_m]
    )

    # B row-major: B is stored as [N, K], and the kernel loads a
    # [BLOCK_N, BLOCK_K] tile before transposing it to [BLOCK_K, BLOCK_N].
    # B column-major: the descriptor is built on B.T with shape [K, N], so
    # the loaded tile already has layout [BLOCK_K, BLOCK_N].
    nargs["b_desc"].block_shape = (
        [tile_n, tile_k] if nargs["B_ROW_MAJOR"] else [tile_k, tile_n]
    )

    nargs["c_desc"].block_shape = [tile_m, tile_n]
# Pointer-based block-scaled matmul kernel.
#
# Each program computes one [BLOCK_M, BLOCK_N] tile of C. For every K tile it
# accumulates tl.dot(a, b) * a_s[:, None] * b_s[None, :] in float32, where
# a_s is one scale per row of the tile and b_s is one scale per column
# (indexed by column // group_n). The result is cast to C's element type
# (bfloat16, float16, otherwise float32) and stored with an M/N bounds mask.
#
# Tuner setup: with USE_FLAGTUNE=1 the configs and strategy come from the
# expand YAML; otherwise a single placeholder config and "align32" strategies
# are used.
#
# NOTE(review): A and B are presumably FP8 tensors; the kernel itself does not
# check their dtype.
# NOTE(review): only one scale per K tile is loaded (index
# k * BLOCK_K // group_k), so this looks like it assumes a K tile never spans
# two group_k groups -- confirm against the configs in use.
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "w8a8_block_fp8_general",
        pre_hook=None,
        yaml_path=EXPAND_CONFIG_FILENAME,
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else _get_placeholder_tuner_configs(pre_hook=None),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=runtime.get_expand_config(
        "w8a8_block_fp8_general", yaml_path=EXPAND_CONFIG_FILENAME
    )["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def w8a8_block_fp8_matmul_kernel_general(
    A,
    B,
    C,
    As,
    Bs,
    M,
    N,
    K,
    group_n,
    group_k,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    # Map the 1-D program id to a (pid_m, pid_n) tile. Programs are processed
    # in groups of GROUP_M row-tiles: within a group, consecutive pids step
    # through the row-tiles first and then advance the column-tile.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    # The last group may hold fewer than GROUP_M row-tiles.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Row/column offsets are wrapped with % M and % N, so the loads below need
    # no M/N mask; out-of-range lanes are discarded by the store mask instead.
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Scale pointers: one A scale per row, one B scale per group_n columns.
    As_ptrs = As + offs_am * stride_As_m
    offs_bsn = offs_bn // group_n
    Bs_ptrs = Bs + offs_bsn * stride_Bs_n

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # Mask the K tail of the last tile; masked lanes contribute zeros.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)

        # Index of the scale group that the start of this K tile falls into.
        k_start = k * BLOCK_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
        # Apply the scales to this tile's partial product before accumulating.
        acc += tl.dot(a, b, out_dtype=tl.float32) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    # Cast the float32 accumulator to the output element type.
    if C.dtype.element_ty == tl.bfloat16:
        c = acc.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = acc.to(tl.float16)
    else:
        c = acc.to(tl.float32)

    # Unwrapped offsets are used for the store so the mask can drop
    # out-of-range rows and columns.
    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
# Block-scaled matmul kernel that reads A/B and writes C through tensor
# descriptors (a_desc / b_desc / c_desc) instead of raw pointers.
#
# The descriptors' block shapes are set by matmul_tma_set_block_size_hook,
# which is registered as the tuner pre_hook for every config. A_ROW_MAJOR and
# B_ROW_MAJOR select which coordinate order is used for each load and whether
# the loaded tile must be transposed before tl.dot.
#
# Within the kernel body only the scale strides (stride_As_*, stride_Bs_*) are
# used; stride_am, stride_ak, stride_bn, stride_bk, stride_cm, stride_cn,
# dtype and enable_warp_specialization are not referenced. stride_am,
# stride_bk and dtype do appear in the tuner key.
#
# NOTE(review): only one scale per K tile is loaded, so this looks like it
# assumes a K tile never spans two group_k groups -- confirm.
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "w8a8_block_fp8_general_tma",
        pre_hook=matmul_tma_set_block_size_hook,
        yaml_path=EXPAND_CONFIG_FILENAME,
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else _get_placeholder_tuner_configs(pre_hook=matmul_tma_set_block_size_hook),
    key=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    strategy=runtime.get_expand_config(
        "w8a8_block_fp8_general_tma", yaml_path=EXPAND_CONFIG_FILENAME
    )["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
)
@triton.jit
def w8a8_block_fp8_matmul_kernel_host_tma(
    a_desc,
    b_desc,
    c_desc,
    As,
    Bs,
    M,
    N,
    K,
    group_n,
    group_k,
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_n,
    stride_Bs_k,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    A_ROW_MAJOR: tl.constexpr,
    B_ROW_MAJOR: tl.constexpr,
    dtype: tl.constexpr,
    enable_warp_specialization=True,
):
    # matrix multiplication
    pid = tl.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    # The last group may hold fewer than GROUP_M row-tiles.
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // group_size

    # Tile origin in elements; the scalar offsets are passed to the descriptor
    # loads/stores, the vector offsets are used for the scale loads.
    offset_am = (pid_m * BLOCK_M).to(tl.int32)
    offset_bn = (pid_n * BLOCK_N).to(tl.int32)
    offs_am = offset_am + tl.arange(0, BLOCK_M)
    offs_bn = offset_bn + tl.arange(0, BLOCK_N)
    iters = tl.cdiv(K, BLOCK_K)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(iters):
        offset_ak = (k * BLOCK_K).to(tl.int32)

        # A: load directly when row-major, otherwise load the transposed tile
        # and transpose it back.
        if A_ROW_MAJOR:
            a = a_desc.load([offset_am, offset_ak])
        else:
            a_t = a_desc.load([offset_ak, offset_am])
            a = tl.trans(a_t)

        # B: row-major B is indexed [n, k], so the tile is transposed after
        # loading; otherwise the descriptor is indexed [k, n] directly.
        if B_ROW_MAJOR:
            b_t = b_desc.load([offset_bn, offset_ak])
            b = tl.trans(b_t)
        else:
            b = b_desc.load([offset_ak, offset_bn])

        # Scale-group index for the start of this K tile. Scale loads are
        # masked on M/N and out-of-range lanes read 0.0.
        offs_ks = (offset_ak // group_k).to(tl.int32)
        a_s = tl.load(
            As + offs_am * stride_As_m + offs_ks * stride_As_k,
            mask=offs_am < M,
            other=0.0,
        )
        b_s = tl.load(
            Bs + (offs_bn // group_n) * stride_Bs_n + offs_ks * stride_Bs_k,
            mask=offs_bn < N,
            other=0.0,
        )
        acc += tl.dot(a, b, out_dtype=tl.float32) * a_s[:, None] * b_s[None, :]

    # Cast to the output descriptor's dtype and store the tile through it.
    c_desc.store([offset_am, offset_bn], acc.to(c_desc.dtype))
# Split-K variant of the pointer-based block-scaled matmul.
#
# The launch grid is 2-D: axis 0 enumerates [BLOCK_M, BLOCK_N] output tiles in
# row-major tile order, axis 1 enumerates SPLIT_K slices of the K loop. Each
# program accumulates its slice in float32 and adds it into C with
# tl.atomic_add, so C has to hold zeros beforehand (the host wrapper in this
# file calls c.zero_() before launching).
#
# NOTE(review): the partial sum is cast to C's element type before the atomic
# add, so for bfloat16/float16 outputs each slice is rounded individually.
# NOTE(review): the non-FLAGTUNE placeholder config supplies BLOCK_M/BLOCK_N/
# BLOCK_K/GROUP_M, while this signature takes SPLIT_K and has no GROUP_M; the
# host wrapper bypasses the tuner in that mode -- confirm the tuned launch is
# only used with USE_FLAGTUNE=1.
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "w8a8_block_fp8_general_splitk",
        yaml_path=EXPAND_CONFIG_FILENAME,
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else _get_placeholder_tuner_configs(pre_hook=None),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=runtime.get_expand_config(
        "w8a8_block_fp8_general_splitk", yaml_path=EXPAND_CONFIG_FILENAME
    )["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def w8a8_block_fp8_matmul_kernel_splitk(
    A,
    B,
    C,
    As,
    Bs,
    M,
    N,
    K,
    group_n,
    group_k,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
):
    pid = tl.program_id(0)
    pid_k = tl.program_id(1)

    # grid_m = tl.cdiv(M, BLOCK_M)
    # Row-major mapping of the axis-0 program id to an output tile.
    grid_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // grid_n
    pid_n = pid % grid_n

    offset_am = pid_m * BLOCK_M
    offset_bn = pid_n * BLOCK_N
    offs_am = offset_am + tl.arange(0, BLOCK_M)
    offs_bn = offset_bn + tl.arange(0, BLOCK_N)

    # Range of K iterations [k_start, k_end) owned by this split. The range is
    # empty for trailing splits when SPLIT_K does not divide the work evenly;
    # such programs add zeros.
    total_k_iters = tl.cdiv(K, BLOCK_K)
    k_per_split = tl.cdiv(total_k_iters, SPLIT_K)
    k_start = pid_k * k_per_split
    k_end = min((pid_k + 1) * k_per_split, total_k_iters)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(k_start, k_end):
        offset_k = k * BLOCK_K
        offs_k = offset_k + tl.arange(0, BLOCK_K)

        # Loads are masked on all of M, N and K; masked lanes read 0.0.
        a = tl.load(
            A + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak,
            mask=(offs_am[:, None] < M) & (offs_k[None, :] < K),
            other=0.0,
        )
        b = tl.load(
            B + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn,
            mask=(offs_k[:, None] < K) & (offs_bn[None, :] < N),
            other=0.0,
        )

        # One A scale per row and one B scale per group_n columns, both taken
        # from the scale group that the start of this K tile falls into.
        offs_ks = offset_k // group_k
        a_s = tl.load(
            As + offs_am * stride_As_m + offs_ks * stride_As_k,
            mask=offs_am < M,
            other=0.0,
        )
        b_s = tl.load(
            Bs + offs_ks * stride_Bs_k + (offs_bn // group_n) * stride_Bs_n,
            mask=offs_bn < N,
            other=0.0,
        )
        acc += tl.dot(a, b, out_dtype=tl.float32) * a_s[:, None] * b_s[None, :]

    # Combine this split's partial result into C atomically.
    offs_cm = offset_am + tl.arange(0, BLOCK_M)
    offs_cn = offset_bn + tl.arange(0, BLOCK_N)
    c_ptrs = C + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
    mask = (offs_cm < M)[:, None] & (offs_cn < N)[None, :]
    if C.dtype.element_ty == tl.bfloat16:
        tl.atomic_add(c_ptrs, acc.to(tl.bfloat16), mask=mask)
    elif C.dtype.element_ty == tl.float16:
        tl.atomic_add(c_ptrs, acc.to(tl.float16), mask=mask)
    else:
        tl.atomic_add(c_ptrs, acc.to(tl.float32), mask=mask)
def general_w8a8_block_fp8_matmul(a, b, c, a_s, b_s, M, N, K, group_n, group_k):
    """Dispatch a block-scaled matmul to one of three Triton kernels.

    The result is written into ``c`` in place. Kernel selection:

    1. Split-K kernel when ``N <= 512 and K == 7168 and M < 8276``. ``c`` is
       zeroed first because the kernel accumulates with atomic adds.
    2. Host tensor-descriptor (TMA) kernel when ``is_tma_compatible`` accepts
       the inputs and ``triton.tools.tensor_descriptor.TensorDescriptor`` can
       be imported.
    3. The pointer-based general kernel otherwise.

    When the environment variable ``USE_FLAGTUNE`` equals ``"1"`` kernels are
    launched through their tuner wrappers. Otherwise the wrappers are
    bypassed via ``.fn.fn`` and launch parameters come from
    ``_get_fixed_matmul_meta`` (or from fixed heuristics on the split-K path).

    Args:
        a: 2-D left operand, indexed ``[M, K]``.
        b: 2-D right operand, indexed ``[N, K]``.
        c: 2-D output buffer, indexed ``[M, N]``.
        a_s: Scales for ``a``, indexed ``[M, K-group]``.
        b_s: Scales for ``b``, indexed ``[N-group, K-group]``.
        M, N, K: Problem sizes.
        group_n, group_k: Quantization group sizes along N and K.

    Returns:
        ``c``, after the kernel has been launched.
    """
    logger.debug(
        "GEMS w8a8_block_fp8_matmul-hopper, [scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )

    use_flagtune = os.environ.get("USE_FLAGTUNE") == "1"

    # Positional arguments shared by the two pointer-based kernels (split-K
    # and general). B and Bs are indexed [N, K], so their strides are passed
    # swapped to match the kernels' (stride_*k, stride_*n) parameter order.
    pointer_args = (
        a,
        b,
        c,
        a_s,
        b_s,
        M,
        N,
        K,
        group_n,
        group_k,
        a.stride(0),
        a.stride(1),
        b.stride(1),
        b.stride(0),
        c.stride(0),
        c.stride(1),
        a_s.stride(0),
        a_s.stride(1),
        b_s.stride(1),
        b_s.stride(0),
    )

    # Split-K path for small-N, large-K shapes
    if N <= 512 and K == 7168 and M < 8276:
        if use_flagtune:
            splitk_grid = lambda META: (
                triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
                META["SPLIT_K"],
            )
            c.zero_()
            with torch_device_fn.device(a.device):
                w8a8_block_fp8_matmul_kernel_splitk[splitk_grid](*pointer_args)
        else:
            SPLITK_BLOCK_K = 128
            SPLITK_BLOCK_M = 16 if M <= 16 else 64
            SPLITK_BLOCK_N = 64 if N > 256 else 32

            grid_m = triton.cdiv(M, SPLITK_BLOCK_M)
            grid_n = triton.cdiv(N, SPLITK_BLOCK_N)
            grid_mn = grid_m * grid_n
            total_k_iters = triton.cdiv(K, SPLITK_BLOCK_K)

            # At least 4 splits, more when few output tiles exist relative to
            # twice the SM count; never more splits than K iterations.
            SM_COUNT = torch.cuda.get_device_properties(a.device).multi_processor_count
            split_k = min(total_k_iters, max(4, 2 * SM_COUNT // max(grid_mn, 1)))

            c.zero_()
            splitk_grid = (grid_mn, split_k)

            with torch_device_fn.device(a.device):
                w8a8_block_fp8_matmul_kernel_splitk.fn.fn[splitk_grid](
                    *pointer_args,
                    BLOCK_M=SPLITK_BLOCK_M,
                    BLOCK_N=SPLITK_BLOCK_N,
                    BLOCK_K=SPLITK_BLOCK_K,
                    SPLIT_K=split_k,
                )
        return c

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),
    )
    fixed_meta = (
        None
        if use_flagtune
        else _get_fixed_matmul_meta(M, N, K, block_n=group_n, block_k=group_k)
    )
    # Extra launch keywords: the tuner supplies them under FLAGTUNE, the
    # fixed meta supplies them otherwise.
    launch_meta = {} if use_flagtune else fixed_meta

    # Resolve TensorDescriptor with a guarded import. Reading
    # ``triton.tools.tensor_descriptor`` as an attribute raises AttributeError
    # when the submodule does not exist or has not been imported yet, which
    # would prevent the fallback to the general kernel.
    tensor_descriptor_cls = None
    if is_tma_compatible(a, b, N, K):
        try:
            # triton 3.5.0
            from triton.tools.tensor_descriptor import (
                TensorDescriptor as tensor_descriptor_cls,
            )
        except ImportError:
            tensor_descriptor_cls = None

    if tensor_descriptor_cls is not None:
        a_row_major = a.stride(1) == 1
        b_row_major = b.stride(1) == 1
        # Placeholder block shape; the real one is set by the block-size hook.
        dummy_block = [1, 1]

        # Describe the transposed view when an operand is not row-major.
        a_view = a if a_row_major else a.T
        b_view = b if b_row_major else b.T
        a_desc = tensor_descriptor_cls(
            a_view, a_view.shape, a_view.stride(), dummy_block
        )
        b_desc = tensor_descriptor_cls(
            b_view, b_view.shape, b_view.stride(), dummy_block
        )
        c_desc = tensor_descriptor_cls(c, c.shape, c.stride(), dummy_block)

        # The TMA kernel takes B / Bs strides in (n, k) order, unlike the
        # pointer-based kernels.
        tma_args = (
            a_desc,
            b_desc,
            c_desc,
            a_s,
            b_s,
            M,
            N,
            K,
            group_n,
            group_k,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            a_s.stride(0),
            a_s.stride(1),
            b_s.stride(0),
            b_s.stride(1),
        )
        tma_kwargs = {
            "A_ROW_MAJOR": a_row_major,
            "B_ROW_MAJOR": b_row_major,
            "dtype": str(a.dtype).split(".")[-1],
        }

        if use_flagtune:
            tma_kernel = w8a8_block_fp8_matmul_kernel_host_tma
        else:
            # The fixed-config path bypasses libtuner, so we must apply the
            # descriptor block-shape update that would normally run via the
            # TMA pre_hook before launching the underlying JIT kernel.
            matmul_tma_set_block_size_hook(
                {
                    "BLOCK_M": fixed_meta["BLOCK_M"],
                    "BLOCK_N": fixed_meta["BLOCK_N"],
                    "BLOCK_K": fixed_meta["BLOCK_K"],
                    "a_desc": a_desc,
                    "b_desc": b_desc,
                    "c_desc": c_desc,
                    "A_ROW_MAJOR": a_row_major,
                    "B_ROW_MAJOR": b_row_major,
                }
            )
            tma_kernel = w8a8_block_fp8_matmul_kernel_host_tma.fn.fn

        with torch_device_fn.device(a.device):
            tma_kernel[grid](*tma_args, **tma_kwargs, **launch_meta)
    else:

        def alloc_fn(size: int, align: int, stream: Optional[int]):
            return torch.empty(size, dtype=torch.int8, device=a.device)

        # NOTE(review): installs a process-wide Triton allocator on every
        # call; kept as in the original behaviour.
        triton.set_allocator(alloc_fn)

        if use_flagtune:
            general_kernel = w8a8_block_fp8_matmul_kernel_general
        else:
            general_kernel = w8a8_block_fp8_matmul_kernel_general.fn.fn

        with torch_device_fn.device(a.device):
            general_kernel[grid](*pointer_args, **launch_meta)
    return c
def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Block-scaled matmul entry point: validate inputs and launch the kernel.

    Args:
        A: Left operand; its last dimension is K and all leading dimensions
            are flattened into M.
        B: 2-D right operand of shape ``[N, K]``.
        As: Scales for ``A``; same leading dimensions as ``A`` and
            ``cdiv(K, block_k)`` entries in the last dimension.
        Bs: 2-D scales for ``B`` of shape
            ``[cdiv(N, block_n), cdiv(K, block_k)]``.
        block_size: Two-element list ``[block_n, block_k]``.
        output_dtype: Dtype of the returned tensor.

    Returns:
        Tensor of shape ``A.shape[:-1] + (N,)`` with dtype ``output_dtype``.

    Raises:
        AssertionError: If any of the shape constraints above is violated.
    """
    device = A.device
    assert len(block_size) == 2
    block_n, block_k = block_size

    def _compact(t, any_rank):
        # Copy only when neither of the last two dimensions has unit stride;
        # the stride checks are skipped when the rank test fails.
        rank_ok = t.ndim >= 2 if any_rank else t.ndim == 2
        if rank_ok and t.stride(-2) > 1 and t.stride(-1) > 1:
            return t.contiguous()
        return t

    # handle non-contiguous inputs if necessary
    A = _compact(A, True)
    B = _compact(B, False)
    As = _compact(As, True)
    Bs = _compact(Bs, False)

    # checks constraints
    k_dim = A.shape[-1]
    assert k_dim == B.shape[-1], "incompatible dimensions"
    lead_shape = A.shape[:-1]
    assert lead_shape == As.shape[:-1], "A and As dimensions mismatch"
    assert triton.cdiv(k_dim, block_k) == As.shape[-1], "invalid As shape"
    assert B.ndim == 2 and Bs.ndim == 2, "B and Bs must be 2D"

    M = A.numel() // k_dim
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0], "invalid Bs N dimension"
    assert triton.cdiv(K, block_k) == Bs.shape[1], "invalid Bs K dimension"

    # allocates output
    c = torch.empty(lead_shape + (N,), device=device, dtype=output_dtype)

    # The kernels work on 2-D operands; flatten the leading dimensions.
    a_flat = A.reshape(M, K)
    a_scale_flat = As.reshape(M, As.shape[-1])
    c_flat = c.reshape(M, N)

    result_2d = general_w8a8_block_fp8_matmul(
        a_flat,
        B,
        c_flat,
        a_scale_flat,
        Bs,
        M,
        N,
        K,
        block_n,
        block_k,
    )
    return result_2d.reshape(c.shape)