Coverage for src/flag_gems/runtime/backend/_mthreads/ops/addmm.py: 0%

139 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7from triton.tools.tensor_descriptor import TensorDescriptor 

8 

9from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import broadcastable_to, libentry, libtuner 

12from flag_gems.utils import triton_lang_extension as ext 

13 

# Per-op logger namespaced under the mthreads backend; the suffix is the last
# dotted component of this module's import name.
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.rsplit(".", 1)[-1]}'
)


# YAML file holding the expanded tuning configuration for addmm_sqmma. It is
# resolved relative to this file: one directory above the directory that
# contains this module.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "..", "addmm_mthreads_expand.yaml")
)

22 

23 

def is_supported_sqmma_layout(tensor):
    """Return True if ``tensor`` is row-major contiguous or column-major.

    The column-major case is recognised by a unit stride along dim 0 and a
    dim-1 stride equal to the number of rows, i.e. the layout of a transposed
    contiguous matrix.
    """
    if tensor.is_contiguous():
        return True
    # Short-circuit keeps stride(1) from being queried unless stride(0) == 1.
    return tensor.stride(0) == 1 and tensor.stride(1) == tensor.shape[0]

28 

29 

def is_sqmma_compatible(a, b, N, K):
    """Return whether ``a @ b`` may take the descriptor-based SQMMA path.

    Requirements:
      * both operands are 2-D and share a dtype of float16 or bfloat16;
      * neither operand is empty (M, N and K all non-zero);
      * each operand is row-major contiguous or column-major
        (see ``is_supported_sqmma_layout``);
      * ``N`` and ``K`` are multiples of 8 -- presumably so that row strides of
        the 2-byte operands stay 16-byte aligned for the tensor descriptors
        (NOTE(review): confirm against the MUSA descriptor requirements).

    Args:
        a: Left operand, expected shape (M, K).
        b: Right operand, expected shape (K, N).
        N: Number of output columns.
        K: Reduction length.

    Returns:
        bool: True when every requirement holds.
    """
    return (
        a.dim() == 2
        and b.dim() == 2
        and a.dtype == b.dtype
        and a.dtype in (torch.float16, torch.bfloat16)
        # Zero-sized M/N/K would produce zero-extent tensor descriptors (TMA-style
        # descriptors generally require non-zero global dims). The FMA kernel
        # uses masked loads and a K-tile loop, so send degenerate shapes there.
        and a.numel() > 0
        and b.numel() > 0
        and is_supported_sqmma_layout(a)
        and is_supported_sqmma_layout(b)
        and N % 8 == 0
        and K % 8 == 0
    )

41 

42 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("addmm"),
    key=["M", "N", "K"],
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,  # left operand A, addressed as (M, K) through stride_am / stride_ak
    b_ptr,  # right operand B, addressed as (K, N) through stride_bk / stride_bn
    i_ptr,  # bias ("input"), addressed as (M, N) through stride_im / stride_in
    c_ptr,  # output C, addressed as (M, N) through stride_cm / stride_cn
    alpha,  # scale on A @ B; listed in do_not_specialize
    beta,  # scale on the bias; listed in do_not_specialize
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    IS_FP64: tl.constexpr = False,  # selects a float64 accumulator
):
    """Tiled ``C = alpha * (A @ B) + beta * bias`` using masked pointer loads.

    Each program computes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C.
    program_id(0) selects the row tile and program_id(1) the column tile.
    Out-of-range rows, columns and the K tail are masked and read as 0.
    The epilogue result is cast to the bias element type before the store.
    """
    pid_m = ext.program_id(0)
    pid_n = ext.program_id(1)

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    if IS_FP64:
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float64)
    else:
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # K - k * BLOCK_SIZE_K is the number of K columns still unread; the
        # mask zero-fills the ragged last K tile as well as M/N edges.
        a = tl.load(
            a_ptrs,
            mask=(offs_am[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_bn[None, :] < N),
            other=0.0,
        )
        if IS_FP64:
            # Operands are downcast to float32 before tl.dot, so each tile
            # product is computed at float32 precision and only the running
            # sum is kept in float64.
            a = a.to(tl.float32)
            b = b.to(tl.float32)
        accumulator += tl.dot(a, b, allow_tf32=False)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    i_ptrs = i_ptr + stride_im * offs_cm[:, None] + stride_in * offs_cn[None, :]
    bias = tl.load(i_ptrs, mask=c_mask, other=0.0)

    accumulator = accumulator * alpha + bias * beta
    # Output precision follows the bias dtype; if c_ptr has a different element
    # type the store presumably converts again -- NOTE(review): confirm.
    c = accumulator.to(bias.dtype)
    tl.store(c_ptrs, c, mask=c_mask)

113 

114 

def addmm_fma(bias, mat1, mat2, *, beta=1, alpha=1):
    """Compute ``beta * bias + alpha * (mat1 @ mat2)`` with ``addmm_kernel``.

    This is the pointer-based fallback used when the operands do not qualify
    for the descriptor-based SQMMA kernel.

    Args:
        bias: Tensor broadcastable to ``(mat1.shape[0], mat2.shape[1])``.
        mat1: 2-D left operand of shape (M, K).
        mat2: 2-D right operand of shape (K, N).
        beta: Scale applied to ``bias``.
        alpha: Scale applied to the matrix product.

    Returns:
        A new (M, N) tensor with ``mat1.dtype`` on ``mat1.device``.

    Raises:
        AssertionError: If the inner dimensions differ or ``bias`` cannot be
            broadcast to the output shape.
    """
    logger.debug("GEMS_MTHREADS ADDMM(FMA)")
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape

    lhs = mat1.contiguous()
    rhs = mat2.contiguous()
    out = torch.empty((M, N), device=lhs.device, dtype=lhs.dtype)
    # Materialize the broadcast so the kernel reads one bias element per output.
    full_bias = bias.broadcast_to(out.shape).contiguous()

    def grid(meta):
        # One program per (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile.
        return (
            triton.cdiv(M, meta["BLOCK_SIZE_M"]),
            triton.cdiv(N, meta["BLOCK_SIZE_N"]),
        )

    # Flattened in the kernel's parameter order: A, B, bias, C (row, col each).
    strides = (*lhs.stride(), *rhs.stride(), *full_bias.stride(), *out.stride())
    with torch_device_fn.device(lhs.device):
        addmm_kernel[grid](
            lhs,
            rhs,
            full_bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            *strides,
            IS_FP64=lhs.dtype == torch.float64,
        )
    return out

155 

156 

def addmm_sqmma_descriptor_pre_hook(nargs):
    """Resize the kernel's tensor-descriptor tiles to match the chosen config.

    Registered as the ``pre_hook`` of the addmm_sqmma autotune configs. It sets
    A to (M, K) tiles, B to (K, N) tiles, and bias and C to (M, N) tiles, each
    with its own list object.
    """
    block_m = nargs["BLOCK_SIZE_M"]
    block_n = nargs["BLOCK_SIZE_N"]
    block_k = nargs["BLOCK_SIZE_K"]
    nargs["a_desc"].block_shape = [block_m, block_k]
    nargs["b_desc"].block_shape = [block_k, block_n]
    for name in ("bias_desc", "c_desc"):
        nargs[name].block_shape = [block_m, block_n]

162 

163 

# Autotune setup, evaluated once at import time:
#   * USE_FLAGTUNE=1 -> configs and search strategy come from
#     EXPAND_CONFIG_FILENAME via the runtime helpers;
#   * otherwise      -> a single 128x128x64 config (num_stages=1, num_warps=4)
#     and a "default" strategy for each of the M/N/K keys.
# Every config installs addmm_sqmma_descriptor_pre_hook so the descriptors'
# block shapes match that config's tile sizes.
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "addmm_sqmma",
        pre_hook=addmm_sqmma_descriptor_pre_hook,
        yaml_path=EXPAND_CONFIG_FILENAME,
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else [
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
            num_stages=1,
            num_warps=4,
            pre_hook=addmm_sqmma_descriptor_pre_hook,
        )
    ],
    key=["M", "N", "K"],
    strategy=runtime.get_expand_config("addmm_sqmma", yaml_path=EXPAND_CONFIG_FILENAME)[
        "strategy"
    ]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["default", "default", "default"],
    warmup=5,
    rep=5,
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_sqmma_kernel(
    a_desc,  # descriptor over A (M, K)
    b_desc,  # descriptor over B (K, N)
    bias_desc,  # descriptor over the broadcast bias (M, N)
    c_desc,  # descriptor over the output C (M, N); its dtype sets the result type
    M,
    N,
    K,
    alpha,
    beta,
    DTYPE: tl.constexpr,  # dtype name string from the host; not read in the body
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """``C = alpha * (A @ B) + beta * bias`` via tensor-descriptor loads/stores.

    Launched on a 1-D grid. The program id is decoded with the M tile index
    varying fastest. There is no explicit masking: partial edge tiles rely on
    the descriptor loads/stores handling out-of-range elements
    (NOTE(review): presumably TMA-style zero-fill on load and clipping on
    store -- confirm for the MUSA backend).
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m
    # Descriptor coordinates are passed as int32 element offsets.
    offs_am = (pid_m * BLOCK_SIZE_M).to(tl.int32)
    offs_bn = (pid_n * BLOCK_SIZE_N).to(tl.int32)
    offs_k = 0
    offs_k = offs_k.to(tl.int32)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k])
        b = tl.load_tensor_descriptor(b_desc, [offs_k, offs_bn])
        accumulator = tl.dot(a, b, acc=accumulator)
        offs_k += BLOCK_SIZE_K
    bias = tl.load_tensor_descriptor(bias_desc, [offs_am, offs_bn])
    # Epilogue in float32, then cast to C's element type.
    result = (alpha * accumulator + beta * bias).to(c_desc.dtype)
    tl.store_tensor_descriptor(c_desc, [offs_am, offs_bn], result)

222 

223 

def addmm_sqmma(mat1, mat2, bias, elem_type, alpha, beta, M, N, K):
    """Compute ``beta * bias + alpha * (mat1 @ mat2)`` with the SQMMA kernel.

    Args:
        mat1: 2-D left operand of shape (M, K); copied to contiguous if needed.
        mat2: 2-D right operand of shape (K, N); copied to contiguous if needed.
        bias: Tensor broadcastable to (M, N).
        elem_type: Not read here; callers pass ``mat1.dtype``. Kept for
            signature compatibility.
        alpha: Scale applied to the matrix product.
        beta: Scale applied to ``bias``.
        M, N, K: Problem sizes; callers derive them from the operand shapes.

    Returns:
        A new (M, N) tensor with ``mat1.dtype`` on ``mat1.device``.

    Raises:
        AssertionError: If the inner dimensions differ, ``bias`` cannot be
            broadcast to (M, N), or the operand dtypes differ.
    """
    logger.debug("GEMS_MTHREADS ADDMM(SQMMA)")
    device = mat1.device
    # Same check as addmm_fma. Without it the kernel would iterate mat1's K
    # over a mat2 descriptor with a different row count and return a wrong
    # result instead of raising.
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    # .contiguous() returns the tensor itself when it is already contiguous.
    mat1 = mat1.contiguous()
    mat2 = mat2.contiguous()
    a_type = mat1.dtype
    assert a_type == mat2.dtype, "Mat A and Mat B should have the same dtype"
    C = torch.empty((M, N), dtype=a_type, device=device)
    bias = bias.broadcast_to(C.shape).contiguous()
    # [1, 1] is a placeholder block shape; addmm_sqmma_descriptor_pre_hook
    # overwrites it with the selected config's tile sizes before launch.
    desc_a = TensorDescriptor.from_tensor(mat1, [1, 1])
    desc_b = TensorDescriptor.from_tensor(mat2, [1, 1])
    desc_bias = TensorDescriptor.from_tensor(bias, [1, 1])
    desc_c = TensorDescriptor.from_tensor(C, [1, 1])
    # 1-D grid: one program per output tile.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        1,
        1,
    )
    # Launch on the operands' device, as addmm_fma does.
    with torch_device_fn.device(device):
        addmm_sqmma_kernel[grid](
            desc_a,
            desc_b,
            desc_bias,
            desc_c,
            M,
            N,
            K,
            alpha,
            beta,
            str(a_type).split(".")[-1],
        )
    return C

262 

263 

def addmm(bias, mat1, mat2, *, beta=1, alpha=1):
    """Return ``beta * bias + alpha * (mat1 @ mat2)``.

    Uses the descriptor-based SQMMA kernel when ``is_sqmma_compatible``
    accepts the operands, and ``addmm_fma`` otherwise.
    """
    M, K = mat1.shape
    _, N = mat2.shape
    if not is_sqmma_compatible(mat1, mat2, N, K):
        return addmm_fma(bias, mat1, mat2, alpha=alpha, beta=beta)
    return addmm_sqmma(mat1, mat2, bias, mat1.dtype, alpha, beta, M, N, K)

283 

284 

def addmm_dtype(bias, mat1, mat2, out_dtype, *, beta=1, alpha=1):
    """Out-of-place ``addmm`` whose result has dtype ``out_dtype``.

    Allocates an uninitialized ``(mat1.shape[0], mat2.shape[1])`` tensor on
    ``mat1.device`` and delegates validation and computation to
    ``addmm_dtype_out``.
    """
    logger.debug("GEMS_MTHREADS ADDMM_DTYPE")
    result_shape = (mat1.shape[0], mat2.shape[1])
    result = torch.empty(result_shape, device=mat1.device, dtype=out_dtype)
    return addmm_dtype_out(
        bias, mat1, mat2, out_dtype, alpha=alpha, beta=beta, out=result
    )

293 

294 

def _addmm_fma_widening(bias, mat1, mat2, out_dtype, alpha, beta):
    """Run ``addmm_kernel`` with its (M, N) result allocated in ``out_dtype``.

    ``addmm_kernel`` accumulates in float32 (float64 for float64 inputs) and
    casts its epilogue to the bias dtype before storing. Giving it a bias and
    a result buffer that are both ``out_dtype`` therefore keeps the
    full-precision result instead of rounding it through ``mat1.dtype``.
    """
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape
    mat1 = mat1.contiguous()
    mat2 = mat2.contiguous()
    result = torch.empty((M, N), device=mat1.device, dtype=out_dtype)
    bias = bias.to(out_dtype).broadcast_to(result.shape).contiguous()

    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            result,
            alpha,
            beta,
            M,
            N,
            K,
            mat1.stride(0),
            mat1.stride(1),
            mat2.stride(0),
            mat2.stride(1),
            bias.stride(0),
            bias.stride(1),
            result.stride(0),
            result.stride(1),
            IS_FP64=mat1.dtype == torch.float64,
        )
    return result


def addmm_dtype_out(bias, mat1, mat2, out_dtype, *, beta=1, alpha=1, out):
    """Compute ``beta * bias + alpha * (mat1 @ mat2)`` into ``out`` as ``out_dtype``.

    Args:
        bias: Tensor broadcastable to (M, N), whose dtype is ``out_dtype`` or
            ``mat1.dtype``.
        mat1: 2-D left operand of shape (M, K).
        mat2: 2-D right operand of shape (K, N), with the same dtype as ``mat1``.
        out_dtype: Either ``mat1.dtype``, or float32 when the inputs are
            float16/bfloat16.
        beta: Scale applied to ``bias``.
        alpha: Scale applied to the matrix product.
        out: Destination tensor; its dtype must equal ``out_dtype``.

    Returns:
        ``out``, filled via ``out.copy_``.

    Raises:
        RuntimeError: On any of the dtype mismatches checked below.
    """
    logger.debug("GEMS_MTHREADS ADDMM_DTYPE_OUT")
    if mat1.dtype != mat2.dtype:
        raise RuntimeError(
            f"mat1 and mat2 must have the same dtype, but got {mat1.dtype} and {mat2.dtype}"
        )
    if out.dtype != out_dtype:
        raise RuntimeError(
            "out_dtype must be the same as the dtype of the provided out tensor"
        )
    if not (
        out_dtype == mat1.dtype
        or (
            out_dtype == torch.float32 and mat1.dtype in (torch.float16, torch.bfloat16)
        )
    ):
        raise RuntimeError(
            "out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs"
        )
    if bias.dtype != out_dtype and bias.dtype != mat1.dtype:
        raise RuntimeError("self dtype must match either out_dtype or mat1 dtype")

    bias_c = bias.to(out_dtype)
    M, K = mat1.shape
    _, N = mat2.shape
    a_dtype = mat1.dtype

    if out_dtype != a_dtype:
        # Widening case (fp16/bf16 inputs -> fp32 output). addmm_sqmma and
        # addmm_fma both allocate their result in mat1.dtype, which would round
        # the fp32 accumulator to half precision before the copy into `out`.
        # Run the FMA kernel with an out_dtype result buffer instead. This
        # gives up the SQMMA path for this case in exchange for
        # full-precision output.
        result = _addmm_fma_widening(bias_c, mat1, mat2, out_dtype, alpha, beta)
    elif is_sqmma_compatible(mat1, mat2, N, K):
        result = addmm_sqmma(
            mat1,
            mat2,
            bias_c,
            a_dtype,
            alpha,
            beta,
            M,
            N,
            K,
        )
    else:
        result = addmm_fma(bias_c, mat1, mat2, alpha=alpha, beta=beta)
    out.copy_(result)
    return out