Coverage for src/flag_gems/ops/w8a8_block_fp8_matmul.py: 53%

97 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import functools 

2import logging 

3import os 

4from typing import Any, Dict, List, Optional 

5 

6import torch 

7import triton 

8import triton.language as tl 

9import yaml 

10 

11import flag_gems 

12 

13logger = logging.getLogger(__name__) 

14 

15 

def _get_default_w8a8_block_fp8_config(block_n: int, block_k: int) -> Dict[str, Any]:
    """Build the fallback launch configuration for the W8A8 block-FP8 kernel.

    Args:
        block_n: Quantization block size along N. Used as ``BLOCK_SIZE_N``
            only when ``flag_gems.device`` is ``"cuda"``.
        block_k: Quantization block size along K. Used as ``BLOCK_SIZE_K``
            only when ``flag_gems.device`` is ``"cuda"``.

    Returns:
        A fresh dict with the keys ``BLOCK_SIZE_M``, ``BLOCK_SIZE_N``,
        ``BLOCK_SIZE_K``, ``GROUP_SIZE_M``, ``num_warps`` and ``num_stages``.
    """
    if flag_gems.device == "cuda":
        # On CUDA the N/K tiles follow the quantization block sizes.
        tile_n, tile_k, group_m, stages = block_n, block_k, 32, 2
    else:
        # Every other device gets one fixed configuration; block_n and
        # block_k are ignored on this path.
        tile_n, tile_k, group_m, stages = 64, 128, 4, 3

    return dict(
        BLOCK_SIZE_M=64,
        BLOCK_SIZE_N=tile_n,
        BLOCK_SIZE_K=tile_k,
        GROUP_SIZE_M=group_m,
        num_warps=4,
        num_stages=stages,
    )

35 

36 

@triton.jit
def w8a8_block_fp8_matmul_kernel(
    A,
    B,
    C,
    As,
    Bs,
    M,
    N,
    K,
    group_n,
    group_k,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Compute one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C from scaled A and B.

    Each program accumulates ``dot(a, b) * a_scale * b_scale`` in float32 over
    K in steps of BLOCK_SIZE_K, then stores the tile into C with a bounds mask.
    A is addressed as (row, k) through ``stride_am``/``stride_ak`` and B as
    (k, column) through ``stride_bk``/``stride_bn``. ``As`` holds one scale per
    row of A per K group; ``Bs`` holds one scale per (column // group_n,
    K group).

    NOTE(review): a single scale index ``k_start // group_k`` is used for each
    K step, so this appears to assume a BLOCK_SIZE_K step never straddles a
    ``group_k`` boundary -- confirm against the launch configurations.
    """
    # Map the 1-D program id to a (pid_m, pid_n) tile; tiles are visited in
    # groups of up to GROUP_SIZE_M tile-rows.
    program_idx = tl.program_id(axis=0)
    tiles_m = tl.cdiv(M, BLOCK_SIZE_M)
    tiles_n = tl.cdiv(N, BLOCK_SIZE_N)
    programs_per_group = GROUP_SIZE_M * tiles_n
    group_idx = program_idx // programs_per_group
    group_first_m = group_idx * GROUP_SIZE_M
    # The last group may contain fewer than GROUP_SIZE_M tile-rows.
    rows_in_group = min(tiles_m - group_first_m, GROUP_SIZE_M)
    pid_m = group_first_m + (program_idx % rows_in_group)
    pid_n = (program_idx % programs_per_group) // rows_in_group

    # Row/column indices are wrapped modulo M/N so that loads of a partial
    # tile stay in range; the final store is masked instead.
    load_rows = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    load_cols = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    k_lane = tl.arange(0, BLOCK_SIZE_K)
    a_tile_ptrs = A + (load_rows[:, None] * stride_am + k_lane[None, :] * stride_ak)
    b_tile_ptrs = B + (k_lane[:, None] * stride_bk + load_cols[None, :] * stride_bn)

    # Base pointers of the scales: per row for A, per group_n columns for B.
    a_scale_ptrs = As + load_rows * stride_As_m
    scale_cols = load_cols // group_n
    b_scale_ptrs = Bs + scale_cols * stride_Bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k_step in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_start = k_step * BLOCK_SIZE_K
        # Lanes beyond the end of K are loaded as zero.
        k_left = K - k_start
        a_tile = tl.load(a_tile_ptrs, mask=k_lane[None, :] < k_left, other=0.0)
        b_tile = tl.load(b_tile_ptrs, mask=k_lane[:, None] < k_left, other=0.0)

        # Scales of the K group that this step starts in.
        scale_k = k_start // group_k
        a_scale = tl.load(a_scale_ptrs + scale_k * stride_As_k)
        b_scale = tl.load(b_scale_ptrs + scale_k * stride_Bs_k)

        accumulator += tl.dot(a_tile, b_tile) * a_scale[:, None] * b_scale[None, :]

        a_tile_ptrs += BLOCK_SIZE_K * stride_ak
        b_tile_ptrs += BLOCK_SIZE_K * stride_bk

    # Convert to the element type of C; anything other than bf16/fp16 is
    # kept as float32.
    if C.dtype.element_ty == tl.bfloat16:
        out_tile = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        out_tile = accumulator.to(tl.float16)
    else:
        out_tile = accumulator.to(tl.float32)

    # Unwrapped indices for the store, masked to the valid M x N region.
    store_rows = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    store_cols = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    out_ptrs = C + stride_cm * store_rows[:, None] + stride_cn * store_cols[None, :]
    in_bounds = (store_rows[:, None] < M) & (store_cols[None, :] < N)
    tl.store(out_ptrs, out_tile, mask=in_bounds)

109 

110 

@functools.lru_cache
def get_w8a8_block_fp8_configs(
    N: int, K: int, block_n: int, block_k: int
) -> Optional[Dict[int, Any]]:
    """Load tuned kernel configurations for a given (N, K) and block size.

    The file ``fp8_w8a8-{block_n}-{block_k}.yaml`` is looked up in
    ``../utils/configs`` relative to this module. Inside it, entries are
    selected first by the CUDA device name (spaces replaced by underscores)
    and then by the key ``"{N},{K}"``. Results are memoized per argument
    tuple by ``functools.lru_cache``, so the returned dict is shared between
    callers and must not be mutated.

    Args:
        N: Number of rows of the weight matrix.
        K: Reduction dimension.
        block_n: Quantization block size along N.
        block_k: Quantization block size along K.

    Returns:
        A dict mapping an integer key to a launch configuration dict
        (``BLOCK_SIZE_M``, ``BLOCK_SIZE_N``, ``BLOCK_SIZE_K``,
        ``GROUP_SIZE_M``, ``num_warps``, ``num_stages``), or ``None`` when
        CUDA is unavailable, the file is missing, or it holds no matching
        entry. The integer key is presumably a batch size M; the caller in
        this file picks the key closest to M.
    """
    if not torch.cuda.is_available():
        logger.debug(
            "CUDA is unavailable on this backend; using default W8A8 block FP8 config."
        )
        return None

    device_name = torch.cuda.get_device_name().replace(" ", "_")
    file_name = f"fp8_w8a8-{block_n}-{block_k}.yaml"

    config_dir = os.path.join(os.path.dirname(__file__), "..", "utils", "configs")
    cfg_file = os.path.join(config_dir, file_name)

    if not os.path.exists(cfg_file):
        logger.warning(
            "Using default W8A8 Block FP8 kernel config. Performance might "
            "be sub-optimal! Config file not found at %s",
            cfg_file,
        )
        return None

    with open(cfg_file, encoding="utf-8") as f:
        logger.info(
            "Using config from %s for W8A8 block FP8 kernel.",
            cfg_file,
        )
        # safe_load returns None for an empty document; treat it as "no data"
        # instead of failing on None.get(...).
        all_data = yaml.safe_load(f) or {}

    # A key that is present with a null value is handled like a missing key.
    dev_data = all_data.get(device_name) or {}
    nk_data = dev_data.get(f"{N},{K}") or {}

    # Each entry is stored as a positional list; unpack it into a dictionary.
    result = {
        int(m_key): {
            "BLOCK_SIZE_M": params[0],
            "BLOCK_SIZE_N": params[1],
            "BLOCK_SIZE_K": params[2],
            "GROUP_SIZE_M": params[3],
            "num_warps": params[4],
            "num_stages": params[5],
        }
        for m_key, params in nk_data.items()
    }
    if not result:
        logger.debug(
            "No tuned W8A8 block FP8 config for device %s with N=%s, K=%s in %s; "
            "using default config.",
            device_name,
            N,
            K,
            cfg_file,
        )
        return None
    return result

157 

158 

def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Multiply block-quantized ``A`` by ``B`` transposed, applying the scales.

    Args:
        A: Contiguous activation tensor whose last dimension is K; all
            leading dimensions are flattened into M. Presumably FP8 data, as
            the function name suggests -- the dtype is not checked here.
        B: 2-D weight tensor of shape ``(N, K)``.
        As: Scales for ``A`` with shape ``A.shape[:-1] + (cdiv(K, block_k),)``.
        Bs: 2-D scales for ``B`` with shape
            ``(cdiv(N, block_n), cdiv(K, block_k))``.
        block_size: Two-element list ``[block_n, block_k]``.
        output_dtype: dtype of the returned tensor.

    Returns:
        A new tensor of shape ``A.shape[:-1] + (N,)`` and dtype
        ``output_dtype``.

    Raises:
        AssertionError: If any of the shape or contiguity checks fails.
    """
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]

    # Activation-side checks.
    assert A.shape[-1] == B.shape[-1]
    assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
    M = A.numel() // A.shape[-1]

    # Weight-side checks.
    assert B.ndim == 2 and Bs.ndim == 2
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0]
    assert triton.cdiv(K, block_k) == Bs.shape[1]

    C = A.new_empty(A.shape[:-1] + (N,), dtype=output_dtype)

    # Prefer a tuned configuration whose key is nearest to M; otherwise fall
    # back to the built-in default.
    tuned = get_w8a8_block_fp8_configs(N, K, block_n, block_k)
    if not tuned:
        launch_cfg = _get_default_w8a8_block_fp8_config(block_n, block_k)
    else:
        nearest_m = min(tuned, key=lambda m_key: abs(m_key - M))
        launch_cfg = tuned[nearest_m]

    def launch_grid(meta):
        # One program per output tile.
        tiles_m = triton.cdiv(M, meta["BLOCK_SIZE_M"])
        tiles_n = triton.cdiv(N, meta["BLOCK_SIZE_N"])
        return (tiles_m * tiles_n,)

    # B and Bs are laid out as (N, K), so their K stride is stride(1) and
    # their N stride is stride(0).
    w8a8_block_fp8_matmul_kernel[launch_grid](
        A,
        B,
        C,
        As,
        Bs,
        M,
        N,
        K,
        block_n,
        block_k,
        A.stride(-2),
        A.stride(-1),
        B.stride(1),
        B.stride(0),
        C.stride(-2),
        C.stride(-1),
        As.stride(-2),
        As.stride(-1),
        Bs.stride(1),
        Bs.stride(0),
        **launch_cfg,
    )

    return C