Coverage for src/flag_gems/ops/fp8_matmul.py: 19%
67 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
"""
FP8 Matrix Multiplication — Triton Kernel (Block-wise Scaling)
Fixed-config version for H20 deployment (no autotune warmup).

API:
    fp8_matmul(a, a_s, b, b_s) -> Tensor

    a:   (..., K) float8_e4m3fn, contiguous
    a_s: (..., ceil(K / group_size)) float32, per-token-group scale
    b:   (N, K) float8_e4m3fn, contiguous
    b_s: (ceil(N / group_size), ceil(K / group_size)) float32, per-block scale
    group_size = 128

    Returns: (..., N) bfloat16

Based on v45. Fixed config: BLOCK_M=64, BLOCK_N=64, BLOCK_K=128,
GROUP_SIZE_M=4, num_stages=3, num_warps=4 (best for M >= 128 on H20).
"""
20import torch
21import triton
22import triton.language as tl
# Scale granularity: one scale per GROUP_SIZE elements along K (for both a_s
# and b_s) and per GROUP_SIZE rows of b (for b_s).
GROUP_SIZE = 128

# Fixed launch config — best for M >= 128 on H20 (covers the majority of
# production shapes).
# Tiling constraints relied on by the kernel:
#   * BLOCK_K must divide GROUP_SIZE, because one scale is applied per K tile.
#   * When BLOCK_N <= GROUP_SIZE, BLOCK_N must divide GROUP_SIZE, because a
#     single b_s row is loaded per N tile.
BLOCK_M = 64
BLOCK_N = 64
BLOCK_K = 128
GROUP_SIZE_M = 4  # row tiles per launch-order group (grouped ordering)
NUM_STAGES = 3
NUM_WARPS = 4
@triton.jit
def _fp8_matmul_kernel(
    A,
    B,
    C,
    As,
    Bs,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    stride_cm,
    stride_cn,
    stride_as_m,
    stride_as_k,
    stride_bs_n,
    stride_bs_k,
    GROUP_K: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Compute one (BLOCK_M, BLOCK_N) tile of C = scaled(A) @ scaled(B)^T.

    A is indexed as (M, K) and B as (N, K) through the given strides.
    As holds one scale per (row of A, GROUP_K-wide K group). Bs holds one scale
    per (GROUP_K rows of B, GROUP_K-wide K group). Products are accumulated in
    float32 and written to C as bfloat16.

    Compile-time constraints:
      * GROUP_K % BLOCK_K == 0, so a K tile never spans two scale groups.
      * If BLOCK_N <= GROUP_K, then GROUP_K % BLOCK_N == 0, so a whole N tile
        maps to a single row of Bs.
    """
    tl.static_assert(GROUP_K % BLOCK_K == 0, "BLOCK_K must divide GROUP_K")
    if BLOCK_N <= GROUP_K:
        tl.static_assert(GROUP_K % BLOCK_N == 0, "BLOCK_N must divide GROUP_K")

    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)

    # Grouped launch order (as in the Triton matmul tutorial): consecutive
    # program ids walk GROUP_SIZE_M row tiles before advancing to the next
    # column tile, intended to improve L2 reuse.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Unwrapped tile indices; these may exceed M / N on edge tiles.
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Wrapped indices keep every load of A, B, As and Bs in bounds. Lanes that
    # wrapped are discarded by the masked store at the end.
    offs_m = rm % M
    offs_n = rn % N
    offs_k = tl.arange(0, BLOCK_K)

    a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = B + offs_n[:, None] * stride_bn + offs_k[None, :] * stride_bk

    as_ptrs = As + offs_m * stride_as_m
    # Per-column Bs row index (used when an N tile spans several scale groups).
    offs_bs_n = offs_n // GROUP_K
    # Single Bs row index for the whole N tile (used when BLOCK_N <= GROUP_K).
    bs_scalar_n_idx = pid_n * BLOCK_N // GROUP_K

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    num_k_iters = tl.cdiv(K, BLOCK_K)
    for k in range(0, num_k_iters):
        # Scale-group column shared by this whole K tile.
        k_idx = k * BLOCK_K // GROUP_K
        a_s = tl.load(as_ptrs + k_idx * stride_as_k)

        if BLOCK_N <= GROUP_K:
            b_s_val = tl.load(Bs + bs_scalar_n_idx * stride_bs_n + k_idx * stride_bs_k)
        else:
            b_s = tl.load(Bs + offs_bs_n * stride_bs_n + k_idx * stride_bs_k)

        # Zero-fill the tail of the last K tile when K % BLOCK_K != 0.
        mask_k = offs_k < K - k * BLOCK_K
        a = tl.load(a_ptrs, mask=mask_k[None, :], other=0.0)
        b = tl.load(b_ptrs, mask=mask_k[None, :], other=0.0)

        dot = tl.dot(a, tl.trans(b))

        # Apply the block scales to this K tile's partial product before
        # accumulating.
        if BLOCK_N <= GROUP_K:
            acc += dot * (a_s[:, None] * b_s_val)
        else:
            acc += dot * (a_s[:, None] * b_s[None, :])

        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    c = acc.to(tl.bfloat16)
    # Store through the unwrapped indices under a bounds mask. Edge tiles then
    # write only the elements they own, instead of re-writing wrapped rows and
    # columns that belong to other programs.
    c_ptrs = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
    c_mask = (rm[:, None] < M) & (rn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
119def fp8_matmul(
120 a: torch.Tensor,
121 a_s: torch.Tensor,
122 b: torch.Tensor,
123 b_s: torch.Tensor,
124 scale_dtype: torch.dtype = torch.float32,
125) -> torch.Tensor:
126 """
127 Block-wise scaled FP8 matrix multiplication.
129 Args:
130 a: (..., K) float8_e4m3fn, contiguous
131 a_s: (..., K // 128) float32, per-token-group scale
132 b: (N, K) float8_e4m3fn, contiguous
133 b_s: (N // 128, K // 128) float32, per-block scale
134 Returns:
135 (..., N) bfloat16
136 """
137 assert b.ndim == 2
138 assert a.is_contiguous() and b.is_contiguous()
139 assert a_s.is_contiguous() and b_s.is_contiguous()
141 K = a.size(-1)
142 M = a.numel() // K
143 N, K2 = b.shape
144 assert K == K2
146 if scale_dtype == torch.float8_e8m0fnu:
147 a_s = a_s.to(torch.float32)
148 b_s = b_s.to(torch.float32)
150 out_shape = (*a.size()[:-1], N)
151 a_2d = a.view(M, K)
152 a_s_2d = a_s.view(M, -1)
154 C = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
156 grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),)
158 _fp8_matmul_kernel[grid](
159 a_2d,
160 b,
161 C,
162 a_s_2d,
163 b_s,
164 M,
165 N,
166 K,
167 a_2d.stride(0),
168 a_2d.stride(1),
169 b.stride(0),
170 b.stride(1),
171 C.stride(0),
172 C.stride(1),
173 a_s_2d.stride(0),
174 a_s_2d.stride(1),
175 b_s.stride(0),
176 b_s.stride(1),
177 GROUP_K=GROUP_SIZE,
178 BLOCK_M=BLOCK_M,
179 BLOCK_N=BLOCK_N,
180 BLOCK_K=BLOCK_K,
181 GROUP_SIZE_M=GROUP_SIZE_M,
182 num_stages=NUM_STAGES,
183 num_warps=NUM_WARPS,
184 )
186 return C.view(out_shape)