Coverage for src/flag_gems/ops/mm.py: 40%

159 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.ops.mm_streamk import streamk_mm 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry, libtuner 

11from flag_gems.utils import triton_lang_extension as tle 

12from flag_gems.utils.device_info import get_device_capability, get_sm_count 

13 

# NOTE(review): this threshold is not referenced anywhere in the visible part
# of this module — TODO confirm whether it is still needed.
CACHE_USAGE_THRESHOLD = 0.8

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

17 

18 

@triton.jit
def prev_multiple_of(a, b):
    # Return the largest multiple of b that is strictly less than a,
    # i.e. the largest x < a with x % b == 0 (for positive b).
    # When a is itself a multiple of b the result is a - b, not a;
    # in particular a == 0 yields -b.
    return tl.cdiv(a, b) * b - b

23 

24 

# Tiled matrix-multiplication kernel: each program computes one
# BLOCK_M x BLOCK_N tile of C = A @ B, accumulating over K in BLOCK_K steps.
#
# Arguments:
#   A, B, C            - pointers to the two operands and to the output.
#   M, N, K            - problem sizes: A is M x K, B is K x N, C is M x N.
#   stride_am/stride_ak, stride_bk/stride_bn, stride_cm/stride_cn
#                      - element strides of A, B and C along each dimension.
#   BLOCK_M/N/K        - tile sizes (compile-time constants from the tuner).
#   GROUP_M            - number of tile rows grouped together when re-ordering
#                        program ids.
#   IS_FP64            - accumulate in float64 instead of float32.
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    # Add 'stride_am' and 'stride_bk' to trigger autotune for tensors with the same shape but different strides.
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=10,
)
@triton.jit
def mm_kernel_general(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    # matrix multiplication
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Wrap row/column indices into [0, M) / [0, N) so the operand loads need
    # no M/N mask; results for wrapped rows/columns are dropped by the store
    # mask at the end.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
    rm = rm.to(tl.int64)
    rn = rn.to(tl.int64)
    # Start of the last K block. Everything before it is a full block and is
    # handled by the unmasked main loop; the last block is peeled below.
    prev_multiple = prev_multiple_of(K, BLOCK_K)

    if IS_FP64:
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
    else:
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for start_k in range(0, prev_multiple, BLOCK_K):
        rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
        a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
        b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
        # Mixed-dtype operands are both converted to the output element type.
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        if IS_FP64:
            acc += tl.dot(a, b, allow_tf32=False)
        else:
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    # loop peeling: the last (possibly partial) K block is loaded with a mask
    rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
    # For K == 0, prev_multiple is -BLOCK_K and every rk is negative. Those
    # indices satisfy rk < K, so the lower bound is required to keep the loads
    # from reading before the start of A and B. For K > 0, rk >= 0 always
    # holds and the mask is unchanged.
    mask_k = (rk >= 0) & (rk < K)
    a = tl.load(
        A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
        mask=mask_k[None, :],
        other=0.0,
    )
    b = tl.load(
        B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
        mask=mask_k[:, None],
        other=0.0,
    )
    if a.dtype != b.dtype:
        a = a.to(C.dtype.element_ty)
        b = b.to(C.dtype.element_ty)
    if IS_FP64:
        acc += tl.dot(a, b, allow_tf32=False)
    else:
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
    rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    # Only elements inside the M x N output are written back.
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    tl.store(C, acc, mask=mask)

118 

119 

# Supported floating-point dtypes, listed from lowest to highest rank as used
# by get_higher_dtype (a later entry wins over an earlier one).
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32, torch.float64]


def get_higher_dtype(a, b):
    """Return whichever of the two dtypes ranks higher in `_ordered_datatypes`.

    Identical dtypes are returned unchanged, even if they are not in the
    list. Otherwise both must be listed dtypes (checked with `assert`).
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    # The first listed dtype that matches either argument is the lower-ranked
    # one, so the answer is the *other* argument.
    lower = next((d for d in _ordered_datatypes if d is a or d is b), None)
    if lower is None:
        # Only reachable when assertions are disabled.
        return None
    return b if lower is a else a

135 

136 

def general_mm(a, b, c, M, N, K):
    """Launch `mm_kernel_general` to write the M x N product of a and b into c.

    a is indexed as M x K and b as K x N through their first two strides.
    Returns c.
    """

    def launch_grid(meta):
        # One program per output tile, laid out on a 1-D grid.
        tiles_m = triton.cdiv(M, meta["BLOCK_M"])
        tiles_n = triton.cdiv(N, meta["BLOCK_N"])
        return (tiles_m * tiles_n,)

    # NOTE(review): the float64 accumulator is selected from a.dtype only;
    # confirm this is intended when b or c is float64 but a is not.
    accumulate_fp64 = a.dtype == torch.float64

    with torch_device_fn.device(a.device):
        mm_kernel_general[launch_grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            GROUP_M=8,
            IS_FP64=accumulate_fp64,
        )
    return c

159 

160 

# Symmetric rank-k style kernel: computes C = A @ A^T for an M x K matrix A.
# Only the tiles of the lower triangle (including the diagonal) are computed;
# each off-diagonal tile is additionally written transposed to its mirrored
# position, so the full M x M output is filled.
#
# Arguments:
#   A, C                  - pointers to the input and the output.
#   M, K                  - A is M x K, C is M x M.
#   stride_am, stride_ak  - element strides of A.
#   stride_cm, stride_cn  - element strides of C.
#   BLOCK_M, BLOCK_K      - tile sizes (compile-time constants from the tuner);
#                           output tiles are BLOCK_M x BLOCK_M.
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm_self_transpose"),
    key=["M", "K", "stride_am", "stride_ak"],
    strategy=["align32", "align32", "align32", "align32"],
    warmup=2,
    rep=4,
)
@triton.jit
def mm_kernel_syrk(
    A,
    C,
    M,
    K,
    stride_am,
    stride_ak,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid = tl.program_id(0)

    # Packed lower-triangular launch domain:
    #   pid = row * (row + 1) / 2 + col, where 0 <= col <= row.
    #
    # Invert the triangular-number indexing by solving:
    #   row^2 + row - 2 * pid = 0
    #   => row = (-1 + sqrt(1 + 8 * pid)) / 2
    #
    # We take floor(...) as the candidate row, then apply an integer +/-1 correction
    # because fp32 sqrt can be off near triangular-number boundaries.
    pid_f = pid.to(tl.float32)
    pid_m = tl.floor((tl.sqrt(8.0 * pid_f + 1.0) - 1.0) / 2.0).to(tl.int32)
    # Candidate row starts after pid: step one row down.
    tri_start = pid_m * (pid_m + 1) // 2
    pid_m = tl.where(tri_start > pid, pid_m - 1, pid_m)
    # Next row already starts at or before pid: step one row up.
    next_tri_start = (pid_m + 1) * (pid_m + 2) // 2
    pid_m = tl.where(next_tri_start <= pid, pid_m + 1, pid_m)
    # Column is the offset of pid within its (corrected) row.
    tri_start = pid_m * (pid_m + 1) // 2
    pid_n = pid - tri_start

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_M + tl.arange(0, BLOCK_M)
    # Wrap indices into [0, M) so loads need no row mask; results for wrapped
    # rows are dropped by the store masks below.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
    ran = tl.max_contiguous(tl.multiple_of(rn % M, BLOCK_M), BLOCK_M).to(tl.int64)
    rm = rm.to(tl.int64)
    rn = rn.to(tl.int64)
    # NOTE(review): the accumulator is always float32, regardless of the input
    # dtype — confirm this is acceptable for float64 inputs.
    acc = tl.zeros((BLOCK_M, BLOCK_M), dtype=tl.float32)

    for start_k in range(0, K, BLOCK_K):
        rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
        mask_k = rk < K
        # a[m, k] = A[ram[m], rk[k]]
        a = tl.load(
            A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
            mask=mask_k[None, :],
            other=0.0,
        )
        # b[k, n] = A[ran[n], rk[k]], i.e. a BLOCK_K x BLOCK_M tile of A^T
        # read directly from A by swapping the roles of the strides.
        b = tl.load(
            A + (rk[:, None] * stride_ak + ran[None, :] * stride_am),
            mask=mask_k[:, None],
            other=0.0,
        )
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    out = acc.to(C.dtype.element_ty)
    # Write the tile at (pid_m, pid_n), masking rows/columns beyond M.
    c_ptr = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < M)[None, :]
    tl.store(c_ptr, out, mask=mask)

    # Off-diagonal tiles are mirrored: store the transposed tile at
    # (pid_n, pid_m) to fill the upper triangle.
    if pid_m > pid_n:
        c_t_ptr = C + (rn[:, None] * stride_cm + rm[None, :] * stride_cn)
        mask_t = (rn < M)[:, None] & (rm < M)[None, :]
        tl.store(c_t_ptr, tl.trans(out), mask=mask_t)

234 

235 

def is_syrk_transpose_pair(a, b):
    """Return True when b is the transposed view of the 2-D tensor a.

    Both tensors must be 2-D, have mirrored shapes and strides, and start at
    the same storage offset and data pointer.
    """
    if a.ndim != 2 or b.ndim != 2:
        return False

    # Shapes must be mirrored: (M, K) against (K, M).
    if a.shape[0] != b.shape[1] or a.shape[1] != b.shape[0]:
        return False

    # Strides must be mirrored as well, so b walks a's memory transposed.
    if a.stride(0) != b.stride(1) or a.stride(1) != b.stride(0):
        return False

    # Both views must begin at the same element of the same buffer.
    if a.storage_offset() != b.storage_offset():
        return False
    return a.data_ptr() == b.data_ptr()

247 

248 

def syrk_mm(a, c, M, K):
    """Launch `mm_kernel_syrk` for the M x K tensor a, writing into c.

    Returns c.
    """

    def launch_grid(meta):
        # Number of tile rows is tiles = ceil(M / BLOCK_M).
        # The packed lower triangle (diagonal included) contains
        #   1 + 2 + ... + tiles = tiles * (tiles + 1) / 2
        # tiles, one program each.
        tiles = triton.cdiv(M, meta["BLOCK_M"])
        return (tiles * (tiles + 1) // 2,)

    with torch_device_fn.device(a.device):
        mm_kernel_syrk[launch_grid](
            a,
            c,
            M,
            K,
            a.stride(0),
            a.stride(1),
            c.stride(0),
            c.stride(1),
        )
    return c

270 

271 

def streamk_scenario(a, b, M, N, K):
    """Decide whether the stream-K kernel should be used for this problem.

    True only on devices with major capability 8, for contiguous
    float16/bfloat16 operands, when K is more than five times both M and N.
    """
    # TODO: this may change over time according to real benchmark results.
    # Currently, the best configuration for streamk has only been tested on A100 (capability[0] == 8).
    # The optimal settings for other devices need to be determined through real testing.
    half_dtypes = (torch.float16, torch.bfloat16)

    capability = get_device_capability()
    if capability[0] != 8:
        return False
    if a.dtype not in half_dtypes:
        return False
    if b.dtype not in half_dtypes:
        return False
    if not a.is_contiguous():
        return False
    if not b.is_contiguous():
        return False
    # The reduction dimension must dominate both output dimensions.
    return K > M * 5 and K > N * 5

286 

287 

def mm(a, b):
    """Multiply the 2-D tensors a (M x K) and b (K x N) into a new tensor.

    Dispatches to the self-transpose kernel when b is the transposed view of
    a, to the stream-K kernel when `streamk_scenario` approves, and to the
    general kernel otherwise. The output is allocated on a's device.
    """
    logger.debug("GEMS MM")

    target_device = a.device

    # Fast path: a @ a.T only needs a single operand.
    if is_syrk_transpose_pair(a, b):
        rows, depth = a.shape
        result = torch.empty((rows, rows), device=target_device, dtype=a.dtype)
        return syrk_mm(a, result, rows, depth)

    # Copy an operand only when neither of its dimensions has stride <= 1;
    # layouts with one such dimension are passed to the kernels as-is.
    if not (a.stride(0) <= 1 or a.stride(1) <= 1):
        a = a.contiguous()
    if not (b.stride(0) <= 1 or b.stride(1) <= 1):
        b = b.contiguous()

    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    N = b.shape[1]

    # allocates output in the higher-ranked of the two input dtypes
    result = torch.empty(
        (M, N), device=target_device, dtype=get_higher_dtype(a.dtype, b.dtype)
    )

    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, result, M, N, K, sm_count=sm_count)
    return general_mm(a, b, result, M, N, K)

314 

315 

def mm_out(a, b, *, out):
    """Multiply the 2-D tensors a (M x K) and b (K x N) into `out`.

    Uses the same dispatch as `mm` (self-transpose, stream-K, or general
    kernel) but writes into the caller-provided tensor. Returns whatever the
    selected launcher returns for `out`.
    """
    logger.debug("GEMS MM_OUT")

    # Fast path: a @ a.T only needs a single operand.
    if is_syrk_transpose_pair(a, b):
        rows, depth = a.shape
        return syrk_mm(a, out, rows, depth)

    # Copy an operand only when neither of its dimensions has stride <= 1;
    # layouts with one such dimension are passed to the kernels as-is.
    if not (a.stride(0) <= 1 or a.stride(1) <= 1):
        a = a.contiguous()
    if not (b.stride(0) <= 1 or b.stride(1) <= 1):
        b = b.contiguous()

    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    N = b.shape[1]

    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, out, M, N, K, sm_count=sm_count)
    return general_mm(a, b, out, M, N, K)