Coverage for src/flag_gems/runtime/backend/_mthreads/ops/addmm.py: 0%
139 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
7from triton.tools.tensor_descriptor import TensorDescriptor
9from flag_gems import runtime
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import broadcastable_to, libentry, libtuner
12from flag_gems.utils import triton_lang_extension as ext
# Module logger: fixed backend prefix plus the last dotted component of this
# module's import name.
logger = logging.getLogger(
    "flag_gems.runtime.backend._mthreads.ops." + __name__.rpartition(".")[2]
)


# Normalised path to the expand-config YAML, located one directory above the
# directory that contains this file.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(os.path.dirname(__file__), os.pardir, "addmm_mthreads_expand.yaml")
)
def is_supported_sqmma_layout(tensor):
    """Return True when ``tensor`` has a memory layout the SQMMA path accepts.

    Two layouts qualify: a contiguous tensor, or one whose dimension 0 has
    unit stride while dimension 1 has a stride equal to the size of
    dimension 0 (for a 2-D tensor, the transpose of a contiguous matrix).
    """
    if tensor.is_contiguous():
        return True
    transposed_dense = tensor.stride(0) == 1 and tensor.stride(1) == tensor.shape[0]
    return transposed_dense
def is_sqmma_compatible(a, b, N, K):
    """Decide whether the operands may be dispatched to the SQMMA kernel.

    All of the following must hold, checked in this order:

    * ``a`` and ``b`` are both 2-D;
    * they share one dtype, which is float16 or bfloat16;
    * both pass ``is_supported_sqmma_layout``;
    * ``N`` and ``K`` are multiples of 8.

    Returns:
        bool: True only if every condition above is satisfied.
    """
    if a.dim() != 2 or b.dim() != 2:
        return False
    if a.dtype != b.dtype or a.dtype not in (torch.float16, torch.bfloat16):
        return False
    if not (is_supported_sqmma_layout(a) and is_supported_sqmma_layout(b)):
        return False
    return N % 8 == 0 and K % 8 == 0
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("addmm"),
    key=["M", "N", "K"],
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,
    b_ptr,
    i_ptr,
    c_ptr,
    alpha,
    beta,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Tiled kernel computing ``c = alpha * (a @ b) + beta * i``.

    Each program instance produces one ``BLOCK_SIZE_M x BLOCK_SIZE_N`` tile
    of the output, selected by its program ids on grid axes 0 (rows) and
    1 (columns), and steps through K in chunks of ``BLOCK_SIZE_K``.

    Args:
        a_ptr, b_ptr, i_ptr, c_ptr: base pointers of the left operand, right
            operand, additive input and output.
        alpha, beta: scale factors for the product and the additive input
            (excluded from specialization by the ``triton.jit`` decorator).
        M, N, K: problem sizes; also the autotune key.
        stride_*: per-dimension strides multiplied with the row/column
            offsets when forming pointers for a, b, i and c.
        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K: compile-time tile sizes.
        IS_FP64: compile-time flag; selects a float64 accumulator and casts
            the loaded tiles to float32 before ``tl.dot``.
    """
    # Tile coordinates of this program instance.
    pid_m = ext.program_id(0)
    pid_n = ext.program_id(1)

    # Row/column/K offsets covered by this tile and the matching pointers
    # to the first K-chunk of a and b.
    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    if IS_FP64:
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float64)
    else:
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Masks zero-fill rows >= M, columns >= N and the K tail of the last
        # chunk so that padded elements contribute nothing to the dot product.
        a = tl.load(
            a_ptrs,
            mask=(offs_am[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_bn[None, :] < N),
            other=0.0,
        )
        if IS_FP64:
            # NOTE(review): operands are narrowed to float32 for the dot while
            # the accumulator stays float64 — presumably a backend limitation
            # on float64 dot; confirm, since this reduces precision.
            a = a.to(tl.float32)
            b = b.to(tl.float32)
        accumulator += tl.dot(a, b, allow_tf32=False)
        # Advance both operand pointers to the next K-chunk.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Epilogue: load the additive input tile, scale, cast and store.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    i_ptrs = i_ptr + stride_im * offs_cm[:, None] + stride_in * offs_cn[None, :]
    bias = tl.load(i_ptrs, mask=c_mask, other=0.0)

    accumulator = accumulator * alpha + bias * beta
    # The result is cast to the dtype of the additive input before the store.
    c = accumulator.to(bias.dtype)
    tl.store(c_ptrs, c, mask=c_mask)
def addmm_fma(bias, mat1, mat2, *, beta=1, alpha=1):
    """Compute ``alpha * (mat1 @ mat2) + beta * bias`` with ``addmm_kernel``.

    Args:
        bias: tensor broadcastable to ``(mat1.shape[0], mat2.shape[1])``.
        mat1: left operand, unpacked as ``(M, K)``.
        mat2: right operand, unpacked as ``(K, N)``.
        beta: scale applied to ``bias``.
        alpha: scale applied to the matrix product.

    Returns:
        A new ``(M, N)`` tensor on ``mat1``'s device with ``mat1``'s dtype.

    Raises:
        AssertionError: if the inner dimensions differ or ``bias`` cannot be
            broadcast to the output shape.
    """
    logger.debug("GEMS_MTHREADS ADDMM(FMA)")
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape

    # The kernel is fed dense copies of every operand.
    lhs = mat1.contiguous()
    rhs = mat2.contiguous()
    result = torch.empty((M, N), device=lhs.device, dtype=lhs.dtype)
    addend = bias.broadcast_to(result.shape).contiguous()
    use_fp64 = lhs.dtype == torch.float64

    def launch_grid(meta):
        # One program per (row tile, column tile) of the output.
        return (
            triton.cdiv(M, meta["BLOCK_SIZE_M"]),
            triton.cdiv(N, meta["BLOCK_SIZE_N"]),
        )

    with torch_device_fn.device(lhs.device):
        addmm_kernel[launch_grid](
            lhs,
            rhs,
            addend,
            result,
            alpha,
            beta,
            M,
            N,
            K,
            *lhs.stride(),
            *rhs.stride(),
            *addend.stride(),
            *result.stride(),
            IS_FP64=use_fp64,
        )
    return result
def addmm_sqmma_descriptor_pre_hook(nargs):
    """Set each tensor descriptor's ``block_shape`` from the chosen tile sizes.

    ``nargs`` is the kernel-argument mapping handed to the config pre-hook.
    The ``a`` descriptor gets an ``M x K`` block, ``b`` a ``K x N`` block, and
    both ``bias`` and ``c`` an ``M x N`` block.
    """
    block_m = nargs["BLOCK_SIZE_M"]
    block_n = nargs["BLOCK_SIZE_N"]
    block_k = nargs["BLOCK_SIZE_K"]
    block_shapes = {
        "a_desc": [block_m, block_k],
        "b_desc": [block_k, block_n],
        "bias_desc": [block_m, block_n],
        "c_desc": [block_m, block_n],
    }
    for desc_name, shape in block_shapes.items():
        nargs[desc_name].block_shape = shape
# Tuner configuration: when the environment variable USE_FLAGTUNE is "1" the
# configs and strategy are read from the expand-config YAML; otherwise a single
# fixed 128x128x64 config and the "default" strategy are used. Every config
# carries the pre-hook that sets the descriptors' block shapes.
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "addmm_sqmma",
        pre_hook=addmm_sqmma_descriptor_pre_hook,
        yaml_path=EXPAND_CONFIG_FILENAME,
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else [
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
            num_stages=1,
            num_warps=4,
            pre_hook=addmm_sqmma_descriptor_pre_hook,
        )
    ],
    key=["M", "N", "K"],
    strategy=runtime.get_expand_config("addmm_sqmma", yaml_path=EXPAND_CONFIG_FILENAME)[
        "strategy"
    ]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["default", "default", "default"],
    warmup=5,
    rep=5,
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_sqmma_kernel(
    a_desc,
    b_desc,
    bias_desc,
    c_desc,
    M,
    N,
    K,
    alpha,
    beta,
    DTYPE: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Tensor-descriptor kernel computing ``c = alpha * (a @ b) + beta * bias``.

    Launched on a 1-D grid; each program instance produces one
    ``BLOCK_SIZE_M x BLOCK_SIZE_N`` output tile.

    Args:
        a_desc, b_desc, bias_desc, c_desc: tensor descriptors for the two
            operands, the additive input and the output.
        M, N, K: problem sizes; also the autotune key.
        alpha, beta: scale factors for the product and the additive input.
        DTYPE: compile-time value supplied by the launcher; it is not
            referenced in the kernel body.
        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K: compile-time tile sizes.
    """
    # Decompose the linear program id: consecutive ids advance along M first,
    # then along N.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m
    # Element offsets of this tile's top-left corner.
    offs_am = (pid_m * BLOCK_SIZE_M).to(tl.int32)
    offs_bn = (pid_n * BLOCK_SIZE_N).to(tl.int32)
    offs_k = 0
    offs_k = offs_k.to(tl.int32)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # No explicit masks are used here.
    # NOTE(review): out-of-range rows/columns/K are presumably handled by the
    # tensor descriptors themselves — confirm for shapes not divisible by the
    # block sizes.
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k])
        b = tl.load_tensor_descriptor(b_desc, [offs_k, offs_bn])
        accumulator = tl.dot(a, b, acc=accumulator)
        offs_k += BLOCK_SIZE_K
    # Epilogue: scale, add the additive-input tile, cast to the output
    # descriptor's dtype and store.
    bias = tl.load_tensor_descriptor(bias_desc, [offs_am, offs_bn])
    result = (alpha * accumulator + beta * bias).to(c_desc.dtype)
    tl.store_tensor_descriptor(c_desc, [offs_am, offs_bn], result)
def addmm_sqmma(mat1, mat2, bias, elem_type, alpha, beta, M, N, K):
    """Compute ``alpha * (mat1 @ mat2) + beta * bias`` with ``addmm_sqmma_kernel``.

    Args:
        mat1: left operand; made contiguous if it is not already.
        mat2: right operand; made contiguous if it is not already.
        bias: tensor broadcastable to ``(mat1.shape[0], mat2.shape[1])``.
        elem_type: accepted for interface compatibility; not used here.
        alpha: scale applied to the matrix product.
        beta: scale applied to ``bias``.
        M, N, K: problem sizes supplied by the caller.

    Returns:
        A new ``(M, N)`` tensor on ``mat1``'s device with ``mat1``'s dtype.

    Raises:
        AssertionError: if the inner dimensions differ, ``bias`` cannot be
            broadcast to the output shape, or the operand dtypes differ.
    """
    logger.debug("GEMS_MTHREADS ADDMM(SQMMA)")
    device = mat1.device
    # Same precondition addmm_fma enforces; without it a mismatched inner
    # dimension would not be rejected on this path.
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    if not mat1.is_contiguous():
        mat1 = mat1.contiguous()
    if not mat2.is_contiguous():
        mat2 = mat2.contiguous()
    a_type = mat1.dtype
    b_type = mat2.dtype
    assert a_type == b_type, "Mat A and Mat B should have the same dtype"
    c_type = a_type
    C = torch.empty((M, N), dtype=c_type, device=device)
    bias = bias.broadcast_to(C.shape).contiguous()
    # Descriptors are created with a placeholder [1, 1] block; the config
    # pre-hook overwrites block_shape with the selected tile sizes.
    desc_a = TensorDescriptor.from_tensor(mat1, [1, 1])
    desc_b = TensorDescriptor.from_tensor(mat2, [1, 1])
    desc_bias = TensorDescriptor.from_tensor(bias, [1, 1])
    desc_c = TensorDescriptor.from_tensor(C, [1, 1])

    def grid(META):
        # 1-D launch: one program per output tile.
        return (
            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
            1,
            1,
        )

    # Launch under the operands' device, as addmm_fma does, so the kernel
    # does not run on a different current device.
    with torch_device_fn.device(device):
        addmm_sqmma_kernel[grid](
            desc_a,
            desc_b,
            desc_bias,
            desc_c,
            M,
            N,
            K,
            alpha,
            beta,
            str(a_type).split(".")[-1],
        )
    return C
def addmm(bias, mat1, mat2, *, beta=1, alpha=1):
    """Dispatch ``alpha * (mat1 @ mat2) + beta * bias`` to a kernel path.

    Uses ``addmm_sqmma`` when ``is_sqmma_compatible`` accepts the operands,
    and ``addmm_fma`` otherwise.

    Args:
        bias: additive input tensor.
        mat1: left operand, unpacked as ``(M, K)``.
        mat2: right operand, unpacked as ``(_, N)``.
        beta: scale applied to ``bias``.
        alpha: scale applied to the matrix product.

    Returns:
        The tensor produced by the selected implementation.
    """
    input_dtype = mat1.dtype
    M, K = mat1.shape
    _, N = mat2.shape

    if not is_sqmma_compatible(mat1, mat2, N, K):
        return addmm_fma(bias, mat1, mat2, alpha=alpha, beta=beta)
    return addmm_sqmma(mat1, mat2, bias, input_dtype, alpha, beta, M, N, K)
def addmm_dtype(bias, mat1, mat2, out_dtype, *, beta=1, alpha=1):
    """Run addmm into a freshly allocated output of dtype ``out_dtype``.

    Allocates a ``(mat1.shape[0], mat2.shape[1])`` tensor on ``mat1``'s
    device and delegates to ``addmm_dtype_out``, returning its result.
    """
    logger.debug("GEMS_MTHREADS ADDMM_DTYPE")
    out_shape = (mat1.shape[0], mat2.shape[1])
    destination = torch.empty(out_shape, device=mat1.device, dtype=out_dtype)
    return addmm_dtype_out(
        bias, mat1, mat2, out_dtype, beta=beta, alpha=alpha, out=destination
    )
def addmm_dtype_out(bias, mat1, mat2, out_dtype, *, beta=1, alpha=1, out):
    """Compute addmm and copy the result into ``out``.

    Args:
        bias: additive input; its dtype must equal ``out_dtype`` or
            ``mat1.dtype``. It is converted to ``out_dtype`` before use.
        mat1: left operand, unpacked as ``(M, K)``.
        mat2: right operand with the same dtype as ``mat1``.
        out_dtype: requested output dtype; must equal ``mat1.dtype``, or be
            float32 when ``mat1`` is float16/bfloat16.
        beta: scale applied to ``bias``.
        alpha: scale applied to the matrix product.
        out: destination tensor whose dtype must equal ``out_dtype``.

    Returns:
        ``out``, after the computed result has been copied into it.

    Raises:
        RuntimeError: if any of the dtype constraints above is violated.
    """
    logger.debug("GEMS_MTHREADS ADDMM_DTYPE_OUT")
    in_dtype = mat1.dtype
    if in_dtype != mat2.dtype:
        raise RuntimeError(
            f"mat1 and mat2 must have the same dtype, but got {mat1.dtype} and {mat2.dtype}"
        )
    if out.dtype != out_dtype:
        raise RuntimeError(
            "out_dtype must be the same as the dtype of the provided out tensor"
        )
    keeps_input_dtype = out_dtype == in_dtype
    widens_half_to_fp32 = out_dtype == torch.float32 and in_dtype in (
        torch.float16,
        torch.bfloat16,
    )
    if not keeps_input_dtype and not widens_half_to_fp32:
        raise RuntimeError(
            "out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs"
        )
    if bias.dtype not in (out_dtype, in_dtype):
        raise RuntimeError("self dtype must match either out_dtype or mat1 dtype")

    converted_bias = bias.to(out_dtype)
    M, K = mat1.shape
    _, N = mat2.shape

    # Same dispatch rule as addmm: SQMMA when compatible, FMA otherwise.
    if is_sqmma_compatible(mat1, mat2, N, K):
        computed = addmm_sqmma(
            mat1, mat2, converted_bias, in_dtype, alpha, beta, M, N, K
        )
    else:
        computed = addmm_fma(converted_bias, mat1, mat2, alpha=alpha, beta=beta)
    out.copy_(computed)
    return out