Coverage for src/flag_gems/runtime/backend/_mthreads/ops/addmm.py: 0%

138 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7from triton.tools.tensor_descriptor import TensorDescriptor 

8 

9from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import broadcastable_to, libentry, libtuner 

12from flag_gems.utils import triton_lang_extension as ext 

13 

14logger = logging.getLogger( 

15 f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}' 

16) 

17 

18 

19EXPAND_CONFIG_FILENAME = os.path.normpath( 

20 os.path.join(os.path.dirname(__file__), "..", "addmm_mthreads_expand.yaml") 

21) 

22 

23 

24def is_supported_sqmma_layout(tensor): 

25 return tensor.is_contiguous() or ( 

26 tensor.stride(0) == 1 and tensor.stride(1) == tensor.shape[0] 

27 ) 

28 

29 

30def is_sqmma_compatible(a, b, N, K): 

31 return ( 

32 a.dim() == 2 

33 and b.dim() == 2 

34 and a.dtype == b.dtype 

35 and a.dtype in (torch.float16, torch.bfloat16) 

36 and is_supported_sqmma_layout(a) 

37 and is_supported_sqmma_layout(b) 

38 and N % 8 == 0 

39 and K % 8 == 0 

40 ) 

41 

42 

43@libentry() 

44@libtuner( 

45 configs=runtime.get_tuned_config("addmm"), 

46 key=["M", "N", "K"], 

47) 

48@triton.jit(do_not_specialize=["alpha", "beta"]) 

49def addmm_kernel( 

50 a_ptr, 

51 b_ptr, 

52 i_ptr, 

53 c_ptr, 

54 alpha, 

55 beta, 

56 M, 

57 N, 

58 K, 

59 stride_am, 

60 stride_ak, 

61 stride_bk, 

62 stride_bn, 

63 stride_im, 

64 stride_in, 

65 stride_cm, 

66 stride_cn, 

67 BLOCK_SIZE_M: tl.constexpr, 

68 BLOCK_SIZE_N: tl.constexpr, 

69 BLOCK_SIZE_K: tl.constexpr, 

70 IS_FP64: tl.constexpr = False, 

71): 

72 pid_m = ext.program_id(0) 

73 pid_n = ext.program_id(1) 

74 

75 offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 

76 offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 

77 offs_k = tl.arange(0, BLOCK_SIZE_K) 

78 a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) 

79 b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) 

80 

81 if IS_FP64: 

82 accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float64) 

83 else: 

84 accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 

85 for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): 

86 a = tl.load( 

87 a_ptrs, 

88 mask=(offs_am[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), 

89 other=0.0, 

90 ) 

91 b = tl.load( 

92 b_ptrs, 

93 mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_bn[None, :] < N), 

94 other=0.0, 

95 ) 

96 if IS_FP64: 

97 a = a.to(tl.float32) 

98 b = b.to(tl.float32) 

99 accumulator += tl.dot(a, b, allow_tf32=False) 

100 a_ptrs += BLOCK_SIZE_K * stride_ak 

101 b_ptrs += BLOCK_SIZE_K * stride_bk 

102 

103 offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 

104 offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 

105 c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] 

106 c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) 

107 i_ptrs = i_ptr + stride_im * offs_cm[:, None] + stride_in * offs_cn[None, :] 

108 bias = tl.load(i_ptrs, mask=c_mask, other=0.0) 

109 

110 accumulator = accumulator * alpha + bias * beta 

111 c = accumulator.to(bias.dtype) 

112 tl.store(c_ptrs, c, mask=c_mask) 

113 

114 

115def addmm_fma(bias, mat1, mat2, *, beta=1, alpha=1): 

116 logger.debug("GEMS_MTHREADS ADDMM(FMA)") 

117 assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions" 

118 assert broadcastable_to( 

119 bias.shape, (mat1.shape[0], mat2.shape[1]) 

120 ), "Incompatible input shape" 

121 M, K = mat1.shape 

122 _, N = mat2.shape 

123 

124 mat1 = mat1.contiguous() 

125 mat2 = mat2.contiguous() 

126 out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype) 

127 bias = bias.broadcast_to(out.shape).contiguous() 

128 

129 grid = lambda META: ( 

130 triton.cdiv(M, META["BLOCK_SIZE_M"]), 

131 triton.cdiv(N, META["BLOCK_SIZE_N"]), 

132 ) 

133 with torch_device_fn.device(mat1.device): 

134 addmm_kernel[grid]( 

135 mat1, 

136 mat2, 

137 bias, 

138 out, 

139 alpha, 

140 beta, 

141 M, 

142 N, 

143 K, 

144 mat1.stride(0), 

145 mat1.stride(1), 

146 mat2.stride(0), 

147 mat2.stride(1), 

148 bias.stride(0), 

149 bias.stride(1), 

150 out.stride(0), 

151 out.stride(1), 

152 IS_FP64=mat1.dtype == torch.float64, 

153 ) 

154 return out 

155 

156 

157@triton.jit(do_not_specialize=["alpha", "beta"]) 

158def addmm_sqmma_kernel( 

159 a_desc, 

160 b_desc, 

161 bias_desc, 

162 c_desc, 

163 M, 

164 N, 

165 K, 

166 alpha, 

167 beta, 

168 BLOCK_SIZE_M: tl.constexpr, 

169 BLOCK_SIZE_N: tl.constexpr, 

170 BLOCK_SIZE_K: tl.constexpr, 

171): 

172 pid = tl.program_id(axis=0) 

173 num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) 

174 pid_m = pid % num_pid_m 

175 pid_n = pid // num_pid_m 

176 offs_am = (pid_m * BLOCK_SIZE_M).to(tl.int32) 

177 offs_bn = (pid_n * BLOCK_SIZE_N).to(tl.int32) 

178 offs_k = 0 

179 offs_k = offs_k.to(tl.int32) 

180 accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) 

181 for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): 

182 a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k]) 

183 b = tl.load_tensor_descriptor(b_desc, [offs_k, offs_bn]) 

184 accumulator = tl.dot(a, b, acc=accumulator) 

185 offs_k += BLOCK_SIZE_K 

186 bias = tl.load_tensor_descriptor(bias_desc, [offs_am, offs_bn]) 

187 result = (alpha * accumulator + beta * bias).to(c_desc.dtype) 

188 tl.store_tensor_descriptor(c_desc, [offs_am, offs_bn], result) 

189 

190 

191def get_triton_type(elem_type): 

192 type_map = { 

193 torch.float16: tl.float16, 

194 torch.bfloat16: tl.bfloat16, 

195 torch.float8_e4m3fn: tl.float8e4nv, 

196 } 

197 return type_map.get(elem_type, None) 

198 

199 

200def addmm_sqmma(mat1, mat2, bias, elem_type, alpha, beta, M, N, K): 

201 logger.debug("GEMS_MTHREADS ADDMM(SQMMA)") 

202 device = mat1.device 

203 assert broadcastable_to( 

204 bias.shape, (mat1.shape[0], mat2.shape[1]) 

205 ), "Incompatible input shape" 

206 if not mat1.is_contiguous(): 

207 mat1 = mat1.contiguous() 

208 if not mat2.is_contiguous(): 

209 mat2 = mat2.contiguous() 

210 a_type = mat1.dtype 

211 b_type = mat2.dtype 

212 assert a_type == b_type, "Mat A and Mat B should have the same dtype" 

213 c_type = a_type 

214 C = torch.empty((M, N), dtype=c_type, device=device) 

215 bias = bias.broadcast_to(C.shape).contiguous() 

216 BLOCK_SIZE_M = 128 

217 BLOCK_SIZE_N = 128 

218 BLOCK_SIZE_K = 64 

219 desc_a = TensorDescriptor.from_tensor(mat1, [BLOCK_SIZE_M, BLOCK_SIZE_K]) 

220 desc_b = TensorDescriptor.from_tensor(mat2, [BLOCK_SIZE_K, BLOCK_SIZE_N]) 

221 desc_bias = TensorDescriptor.from_tensor(bias, [BLOCK_SIZE_M, BLOCK_SIZE_N]) 

222 desc_c = TensorDescriptor.from_tensor(C, [BLOCK_SIZE_M, BLOCK_SIZE_N]) 

223 grid = lambda META: ( 

224 triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), 

225 1, 

226 1, 

227 ) 

228 addmm_sqmma_kernel[grid]( 

229 desc_a, 

230 desc_b, 

231 desc_bias, 

232 desc_c, 

233 M, 

234 N, 

235 K, 

236 alpha, 

237 beta, 

238 BLOCK_SIZE_M, 

239 BLOCK_SIZE_N, 

240 BLOCK_SIZE_K, 

241 num_warps=4, 

242 num_stages=1, 

243 ) 

244 return C 

245 

246 

247def addmm(bias, mat1, mat2, *, beta=1, alpha=1): 

248 a_dtype = mat1.dtype 

249 M, K = mat1.shape 

250 _, N = mat2.shape 

251 

252 if is_sqmma_compatible(mat1, mat2, N, K): 

253 return addmm_sqmma( 

254 mat1, 

255 mat2, 

256 bias, 

257 a_dtype, 

258 alpha, 

259 beta, 

260 M, 

261 N, 

262 K, 

263 ) 

264 else: 

265 return addmm_fma(bias, mat1, mat2, alpha=alpha, beta=beta) 

266 

267 

268def addmm_dtype(bias, mat1, mat2, out_dtype, *, beta=1, alpha=1): 

269 logger.debug("GEMS_MTHREADS ADDMM_DTYPE") 

270 out = torch.empty( 

271 (mat1.shape[0], mat2.shape[1]), 

272 device=mat1.device, 

273 dtype=out_dtype, 

274 ) 

275 return addmm_dtype_out(bias, mat1, mat2, out_dtype, beta=beta, alpha=alpha, out=out) 

276 

277 

278def addmm_dtype_out(bias, mat1, mat2, out_dtype, *, beta=1, alpha=1, out): 

279 logger.debug("GEMS_MTHREADS ADDMM_DTYPE_OUT") 

280 if mat1.dtype != mat2.dtype: 

281 raise RuntimeError( 

282 f"mat1 and mat2 must have the same dtype, but got {mat1.dtype} and {mat2.dtype}" 

283 ) 

284 if out.dtype != out_dtype: 

285 raise RuntimeError( 

286 "out_dtype must be the same as the dtype of the provided out tensor" 

287 ) 

288 if not ( 

289 out_dtype == mat1.dtype 

290 or ( 

291 out_dtype == torch.float32 and mat1.dtype in (torch.float16, torch.bfloat16) 

292 ) 

293 ): 

294 raise RuntimeError( 

295 "out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs" 

296 ) 

297 if bias.dtype != out_dtype and bias.dtype != mat1.dtype: 

298 raise RuntimeError("self dtype must match either out_dtype or mat1 dtype") 

299 

300 bias_c = bias.to(out_dtype) 

301 M, K = mat1.shape 

302 _, N = mat2.shape 

303 a_dtype = mat1.dtype 

304 

305 if is_sqmma_compatible(mat1, mat2, N, K): 

306 result = addmm_sqmma( 

307 mat1, 

308 mat2, 

309 bias_c, 

310 a_dtype, 

311 alpha, 

312 beta, 

313 M, 

314 N, 

315 K, 

316 ) 

317 else: 

318 result = addmm_fma(bias_c, mat1, mat2, alpha=alpha, beta=beta) 

319 out.copy_(result) 

320 return out