Coverage for src/flag_gems/ops/w8a8_block_fp8_matmul.py: 53%
97 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import functools
2import logging
3import os
4from typing import Any, Dict, List, Optional
6import torch
7import triton
8import triton.language as tl
9import yaml
11import flag_gems
# Module-level logger; used to report which W8A8 block FP8 kernel config is selected.
logger = logging.getLogger(__name__)
def _get_default_w8a8_block_fp8_config(block_n: int, block_k: int) -> Dict[str, Any]:
    """Return a fallback kernel config for when no tuned config is available.

    ``w8a8_block_fp8_matmul_kernel`` reads a single scale per K tile (it indexes
    the scale tensors with ``k_start // group_k``). BLOCK_SIZE_K must therefore
    divide ``block_k``. Otherwise a tile straddles two scale blocks and is
    scaled with the wrong factor.

    Args:
        block_n: Quantization block size along N.
        block_k: Quantization block size along K.

    Returns:
        Keyword arguments for ``w8a8_block_fp8_matmul_kernel``: tile sizes,
        ``GROUP_SIZE_M``, ``num_warps`` and ``num_stages``.
    """
    import math

    if flag_gems.device != "cuda":
        # Fixed tiles for non-CUDA backends. BLOCK_SIZE_K is capped at 128 but
        # must stay a divisor of block_k. gcd(128, block_k) is 128 for any
        # multiple of 128, so the common case is unchanged. Otherwise it is the
        # largest power of two (<= 128) dividing block_k.
        return {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": 64,
            "BLOCK_SIZE_K": math.gcd(128, block_k),
            "GROUP_SIZE_M": 4,
            "num_warps": 4,
            "num_stages": 3,
        }

    # On CUDA the N/K tiles match the quantization block exactly.
    return {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": block_n,
        "BLOCK_SIZE_K": block_k,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 2,
    }
@triton.jit
def w8a8_block_fp8_matmul_kernel(
    # Operand pointers: A is addressed as (M x K), B as (K x N) via its
    # strides, C as (M x N); As / Bs hold the per-block scales.
    A,
    B,
    C,
    As,
    Bs,
    # Problem sizes.
    M,
    N,
    K,
    # Quantization block size along N (indexes Bs) and along K (indexes As/Bs).
    group_n,
    group_k,
    # Element strides for every operand.
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    # Compile-time tile sizes and launch-order group height.
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Compute one (BLOCK_SIZE_M x BLOCK_SIZE_N) tile of C with block scales.

    For each K tile the fp32 partial product dot(a, b) is multiplied by
    As[m, k_start // group_k] and Bs[n // group_n, k_start // group_k].
    The sum is then cast to C's element type and stored with bounds masking.
    """
    # Map the 1-D program id to a (pid_m, pid_n) output tile. Programs are
    # ordered in groups of GROUP_SIZE_M tile-rows. This is the grouped ordering
    # from Triton's matmul tutorial, presumably for cache reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group may contain fewer than GROUP_SIZE_M tile-rows.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Row/column indices wrap modulo M/N so every load stays in bounds. The
    # wrapped (out-of-range) rows/columns are masked off at the final store.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Scale row pointers: one A scale row per output row; each B scale is
    # shared by group_n consecutive output columns.
    As_ptrs = As + offs_am * stride_As_m
    offs_bsn = offs_bn // group_n
    Bs_ptrs = Bs + offs_bsn * stride_Bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Zero-fill the K tail so it contributes nothing to the dot product.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

        # Only one scale column is read per K tile, chosen by the tile's first
        # K index. This is exact only if the whole tile lies inside a single
        # group_k-wide scale block (e.g. BLOCK_SIZE_K divides group_k).
        k_start = k * BLOCK_SIZE_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Cast the fp32 accumulator to C's element type (bf16 / fp16, else fp32).
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    # Recompute the un-wrapped offsets for the store and mask rows >= M and
    # columns >= N.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
@functools.lru_cache
def get_w8a8_block_fp8_configs(
    N: int, K: int, block_n: int, block_k: int
) -> Optional[Dict[int, Any]]:
    """Load tuned kernel configs for a weight shape and quantization block size.

    Configs are read from ``../utils/configs/fp8_w8a8-{block_n}-{block_k}.yaml``
    relative to this module. The YAML has the following structure:

    * top-level keys are device names (spaces replaced by underscores);
    * each device maps ``"N,K"`` strings to a mapping from M;
    * each M maps to ``[BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K,
      GROUP_SIZE_M, num_warps, num_stages]``.

    Results are memoized per argument tuple by ``functools.lru_cache``.
    NOTE: the device name is not part of the cache key, so the first device
    queried wins for a given (N, K, block_n, block_k).

    Args:
        N: Number of rows of the weight matrix B.
        K: Reduction dimension.
        block_n: Quantization block size along N.
        block_k: Quantization block size along K.

    Returns:
        A mapping from M to a kernel-config dict. Returns None if CUDA is
        unavailable, the config file is missing, or the file has no entry
        for the current device and (N, K).
    """
    if not torch.cuda.is_available():
        logger.debug(
            "CUDA is unavailable on this backend; using default W8A8 block FP8 config."
        )
        return None

    device_name = torch.cuda.get_device_name().replace(" ", "_")
    file_name = f"fp8_w8a8-{block_n}-{block_k}.yaml"

    config_dir = os.path.join(os.path.dirname(__file__), "..", "utils", "configs")
    cfg_file = os.path.join(config_dir, file_name)

    if not os.path.exists(cfg_file):
        logger.warning(
            "Using default W8A8 Block FP8 kernel config. Performance might "
            "be sub-optimal! Config file not found at %s",
            cfg_file,
        )
        return None

    with open(cfg_file, encoding="utf-8") as f:
        # safe_load returns None for an empty document; treat that as "no
        # entries" instead of crashing on None.get(...).
        all_data = yaml.safe_load(f) or {}

    # A device or "N,K" key present without a value also parses as None.
    dev_data = all_data.get(device_name) or {}
    NK_data = dev_data.get(f"{N},{K}") or {}

    result = {}
    for m, p in NK_data.items():
        # Each entry is a positional list; unpack it into kernel kwargs.
        result[int(m)] = {
            "BLOCK_SIZE_M": p[0],
            "BLOCK_SIZE_N": p[1],
            "BLOCK_SIZE_K": p[2],
            "GROUP_SIZE_M": p[3],
            "num_warps": p[4],
            "num_stages": p[5],
        }

    if not result:
        # The file exists but has nothing for this device/shape: report the
        # fallback instead of claiming the file's config is in use.
        logger.warning(
            "Using default W8A8 Block FP8 kernel config. Performance might "
            "be sub-optimal! No entry for device %s and N,K=%s,%s in %s",
            device_name,
            N,
            K,
            cfg_file,
        )
        return None

    logger.info(
        "Using config from %s for W8A8 block FP8 kernel.",
        cfg_file,
    )
    return result
def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Matmul ``A @ B.T`` with block-wise scales applied to both operands.

    Computes ``C[..., n] = sum_k A[..., k] * As[..., k // block_k]
    * B[n, k] * Bs[n // block_n, k // block_k]``, accumulated in fp32 by
    ``w8a8_block_fp8_matmul_kernel``. The result is exact when the selected
    BLOCK_SIZE_K divides ``block_k``.

    Args:
        A: Tensor of shape ``(..., K)``; must be contiguous. Presumably FP8
            (per the op name); the dtype is not checked here.
        B: Tensor of shape ``(N, K)``.
        As: Scales for A, shape ``(..., ceil(K / block_k))`` with the same
            leading dims as A.
        Bs: Scales for B, shape ``(ceil(N / block_n), ceil(K / block_k))``.
        block_size: ``[block_n, block_k]`` quantization block size.
        output_dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``(..., N)`` with dtype ``output_dtype``.

    Raises:
        AssertionError: if the shapes or contiguity violate the contract above.
    """
    assert len(block_size) == 2, "block_size must be [block_n, block_k]"
    block_n, block_k = block_size[0], block_size[1]

    assert A.shape[-1] == B.shape[-1], "A and B must share the reduction dim K"
    assert (
        A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
    ), "As must match A's leading dims and A must be contiguous"
    assert (
        triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
    ), "As.shape[-1] must equal ceil(K / block_k)"
    # Product of the leading dims. Unlike A.numel() // K, this does not divide
    # by zero when K == 0, and it gives 1 for a 1-D A.
    M = A.shape[:-1].numel()

    assert B.ndim == 2 and Bs.ndim == 2, "B and Bs must be 2-D"
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0], "Bs.shape[0] must be ceil(N / block_n)"
    assert triton.cdiv(K, block_k) == Bs.shape[1], "Bs.shape[1] must be ceil(K / block_k)"

    C_shape = A.shape[:-1] + (N,)
    C = A.new_empty(C_shape, dtype=output_dtype)

    # Empty output: nothing to compute, so skip launching an empty grid.
    if C.numel() == 0:
        return C
    # An empty reduction dimension yields an all-zero product.
    if K == 0:
        return C.zero_()

    # The kernel addresses A, As and C as 2-D (M, *) matrices. Collapse the
    # leading dims explicitly:
    # - A and C are contiguous, so view() is free.
    # - reshape() keeps a compatible As as a view (including a column-major
    #   2-D As) and copies only when the leading dims cannot be flattened.
    # This also makes 1-D inputs work; they have no stride(-2).
    A_2d = A.view(M, K)
    As_2d = As.reshape(M, As.shape[-1])
    C_2d = C.view(M, N)

    configs = get_w8a8_block_fp8_configs(N, K, block_n, block_k)
    if configs:
        # Pick the tuned config whose M is closest to the actual M.
        config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
    else:
        config = _get_default_w8a8_block_fp8_config(block_n, block_k)

    def grid(META):
        # One program per output tile.
        return (
            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        )

    w8a8_block_fp8_matmul_kernel[grid](
        A_2d,
        B,
        C_2d,
        As_2d,
        Bs,
        M,
        N,
        K,
        block_n,
        block_k,
        A_2d.stride(0),
        A_2d.stride(1),
        # B is (N, K); the kernel indexes it as (K, N).
        B.stride(1),
        B.stride(0),
        C_2d.stride(0),
        C_2d.stride(1),
        As_2d.stride(0),
        As_2d.stride(1),
        # Bs is (N blocks, K blocks); the kernel takes the K stride first.
        Bs.stride(1),
        Bs.stride(0),
        **config,
    )

    return C