Coverage for src/flag_gems/ops/fp8_matmul.py: 19%

67 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1""" 

2FP8 Matrix Multiplication — Triton Kernel (Block-wise Scaling) 

3Fixed config version for H20 deployment (no autotune warmup). 

4 

5API: 

6 fp8_matmul(a, a_s, b, b_s) -> Tensor 

7 

8 a: (..., K) float8_e4m3fn, contiguous 

9 a_s: (..., K // group_size) float32, per-token-group scale 

10 b: (N, K) float8_e4m3fn, contiguous 

11 b_s: (N // group_size, K // group_size) float32, per-block scale 

12 group_size = 128 

13 

14 Returns: (..., N) bfloat16 

15 

16Based on v45. Fixed config: BLOCK_M=64, BLOCK_N=64, BLOCK_K=128, 

17GROUP_SIZE_M=4, num_stages=3, num_warps=4 (best for M>=128 on H20). 

18""" 

19 

20import torch 

21import triton 

22import triton.language as tl 

23 

# Quantization group size for the scales; handed to the kernel as GROUP_K.
GROUP_SIZE = 128

# Fixed config — best for M>=128 on H20 (covers majority of production shapes)
# Output tile height (rows of C handled by one program).
BLOCK_M = 64
# Output tile width (columns of C handled by one program).
BLOCK_N = 64
# Number of K elements consumed per iteration of the kernel's reduction loop.
BLOCK_K = 128
# Number of row-tiles grouped together when mapping program ids to tiles.
GROUP_SIZE_M = 4
# Passed to the kernel launch as num_stages.
NUM_STAGES = 3
# Passed to the kernel launch as num_warps.
NUM_WARPS = 4

33 

34# Debug print helper (flush immediately for real-time visibility) 

35# def _p(msg): 

36# rank = os.environ.get("RANK", os.environ.get("LOCAL_RANK", "0")) 

37# print(f"[fp8_matmul][rank{rank}] {msg}", flush=True) 

38 

39 

@triton.jit
def _fp8_matmul_kernel(
    A,
    B,
    C,
    As,
    Bs,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    stride_cm,
    stride_cn,
    stride_as_m,
    stride_as_k,
    stride_bs_n,
    stride_bs_k,
    GROUP_K: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Compute one (BLOCK_M, BLOCK_N) tile of C = (A * As) @ (B * Bs)^T.

    A is addressed as (M, K) and B as (N, K) through the given strides; the
    result tile is accumulated in float32 and written to C as bfloat16.
    As is indexed by (row, k-group) and Bs by (n-group, k-group), where a
    group spans GROUP_K elements.

    Each program handles one output tile. Program ids are mapped to tiles in
    groups of GROUP_SIZE_M row-tiles.
    """
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)

    # Grouped mapping from the flat program id to a (pid_m, pid_n) tile.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group may contain fewer than GROUP_SIZE_M row-tiles.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Load offsets are wrapped with a modulo so that rows/columns past the
    # end of the matrix read valid (but irrelevant) data instead of going
    # out of bounds. These wrapped offsets must NOT be used for the store.
    offs_m = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_n = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)

    a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = B + offs_n[:, None] * stride_bn + offs_k[None, :] * stride_bk

    as_ptrs = As + offs_m * stride_as_m
    # Per-column scale-group index (used when a tile spans several groups).
    offs_bs_n = offs_n // GROUP_K
    # Single scale-group index for the whole tile (used when BLOCK_N <= GROUP_K).
    # NOTE(review): this assumes a tile never straddles two N-groups, i.e.
    # GROUP_K is a multiple of BLOCK_N — confirm if the config changes.
    bs_scalar_n_idx = pid_n * BLOCK_N // GROUP_K

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    num_k_iters = tl.cdiv(K, BLOCK_K)
    for k in range(0, num_k_iters):
        # Scale-group index along K for this iteration.
        k_idx = k * BLOCK_K // GROUP_K
        a_s = tl.load(as_ptrs + k_idx * stride_as_k)

        if BLOCK_N <= GROUP_K:
            b_s_val = tl.load(Bs + bs_scalar_n_idx * stride_bs_n + k_idx * stride_bs_k)
        else:
            b_s = tl.load(Bs + offs_bs_n * stride_bs_n + k_idx * stride_bs_k)

        # Zero-fill the K tail so it does not contribute to the dot product.
        mask_k = offs_k < K - k * BLOCK_K
        a = tl.load(a_ptrs, mask=mask_k[None, :], other=0.0)
        b = tl.load(b_ptrs, mask=mask_k[None, :], other=0.0)

        dot = tl.dot(a, tl.trans(b))

        # Apply the scales to the partial product before accumulating.
        if BLOCK_N <= GROUP_K:
            acc += dot * (a_s[:, None] * b_s_val)
        else:
            acc += dot * (a_s[:, None] * b_s[None, :])

        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    c = acc.to(tl.bfloat16)

    # Store using the unwrapped tile coordinates and mask off the padding.
    # Storing through the modulo-wrapped offsets would write the padding
    # rows/columns of an edge tile back onto rows/columns 0.., duplicating
    # writes and, for columns, possibly using the wrong tile-level B scale.
    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = C + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

117 

118 

def fp8_matmul(
    a: torch.Tensor,
    a_s: torch.Tensor,
    b: torch.Tensor,
    b_s: torch.Tensor,
    scale_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Block-wise scaled FP8 matrix multiplication.

    Args:
        a:   (..., K) float8_e4m3fn, contiguous
        a_s: (..., K // 128) float32, per-token-group scale
        b:   (N, K) float8_e4m3fn, contiguous
        b_s: (N // 128, K // 128) float32, per-block scale
        scale_dtype: dtype the scales are stored in. When it is
            torch.float8_e8m0fnu the scales are converted to float32
            before the kernel is launched.
    Returns:
        (..., N) bfloat16. If `a` has no rows or `b` has no rows, an empty
        tensor of the right shape is returned without launching the kernel.
    """
    assert b.ndim == 2
    assert a.is_contiguous() and b.is_contiguous()
    assert a_s.is_contiguous() and b_s.is_contiguous()

    K = a.size(-1)
    M = a.numel() // K
    N, K2 = b.shape
    assert K == K2

    out_shape = (*a.size()[:-1], N)
    C = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)

    # Empty problem: nothing to compute. Returning here also avoids
    # `a_s.view(M, -1)`, which raises on a zero-element tensor because the
    # inferred dimension is ambiguous.
    if M == 0 or N == 0:
        return C.view(out_shape)

    # Look the dtype up defensively: PyTorch builds that do not define
    # float8_e8m0fnu would otherwise raise AttributeError on every call.
    e8m0_dtype = getattr(torch, "float8_e8m0fnu", None)
    if e8m0_dtype is not None and scale_dtype == e8m0_dtype:
        a_s = a_s.to(torch.float32)
        b_s = b_s.to(torch.float32)

    # Flatten all leading dimensions of `a` (and its scales) into M rows.
    a_2d = a.view(M, K)
    a_s_2d = a_s.view(M, -1)

    # One program per (BLOCK_M, BLOCK_N) output tile, on a 1-D grid.
    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),)

    _fp8_matmul_kernel[grid](
        a_2d,
        b,
        C,
        a_s_2d,
        b_s,
        M,
        N,
        K,
        a_2d.stride(0),
        a_2d.stride(1),
        b.stride(0),
        b.stride(1),
        C.stride(0),
        C.stride(1),
        a_s_2d.stride(0),
        a_s_2d.stride(1),
        b_s.stride(0),
        b_s.stride(1),
        GROUP_K=GROUP_SIZE,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        GROUP_SIZE_M=GROUP_SIZE_M,
        num_stages=NUM_STAGES,
        num_warps=NUM_WARPS,
    )

    return C.view(out_shape)