Coverage for src/flag_gems/runtime/backend/_mthreads/ops/addmm.py: 0%
138 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
7from triton.tools.tensor_descriptor import TensorDescriptor
9from flag_gems import runtime
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import broadcastable_to, libentry, libtuner
12from flag_gems.utils import triton_lang_extension as ext
# Module-level logger. Its name is the fixed mthreads-ops prefix followed by
# the last dotted component of this module's import name.
logger = logging.getLogger(
    "flag_gems.runtime.backend._mthreads.ops." + __name__.rpartition(".")[2]
)
# Normalised path to "addmm_mthreads_expand.yaml", which is looked up one
# directory above the directory that contains this file.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(
        os.path.dirname(__file__),
        os.pardir,
        "addmm_mthreads_expand.yaml",
    )
)
def is_supported_sqmma_layout(tensor):
    """Return whether ``tensor`` has a memory layout accepted by the SQMMA path.

    Two layouts are accepted:

    * any contiguous tensor, or
    * a 2-D tensor whose first dimension has unit stride and whose second
      stride equals the size of the first dimension (the layout produced by
      transposing a contiguous matrix).

    Args:
        tensor: The ``torch.Tensor`` to inspect.

    Returns:
        bool: ``True`` if the layout is accepted, ``False`` otherwise.
    """
    if tensor.is_contiguous():
        return True
    # The column-major test reads stride(1), which only exists for rank >= 2.
    # Reject non-contiguous tensors of any other rank instead of letting the
    # stride lookup raise IndexError (rank 1) or match by accident (rank > 2).
    if tensor.dim() != 2:
        return False
    return tensor.stride(0) == 1 and tensor.stride(1) == tensor.shape[0]
def is_sqmma_compatible(a, b, N, K):
    """Decide whether the SQMMA kernel can be used for ``a @ b``.

    All of the following must hold: both operands are 2-D, they share one
    dtype which is float16 or bfloat16, both pass
    ``is_supported_sqmma_layout``, and ``N`` and ``K`` are multiples of 8.

    Args:
        a: Left operand tensor.
        b: Right operand tensor.
        N: Number of output columns.
        K: Size of the reduction dimension.

    Returns:
        bool: ``True`` when every condition above is satisfied.
    """
    # Rank is checked before the layout helper, which indexes stride(1).
    if a.dim() != 2 or b.dim() != 2:
        return False
    if a.dtype != b.dtype or a.dtype not in (torch.float16, torch.bfloat16):
        return False
    if not (is_supported_sqmma_layout(a) and is_supported_sqmma_layout(b)):
        return False
    return N % 8 == 0 and K % 8 == 0
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("addmm"),
    key=["M", "N", "K"],
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,
    b_ptr,
    i_ptr,
    c_ptr,
    alpha,
    beta,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Compute one output tile of ``alpha * (A @ B) + beta * bias``.

    Program ids 0 and 1 select the tile along M and N. The K dimension is
    walked in steps of ``BLOCK_SIZE_K``; out-of-range elements are loaded as
    0.0 so partial tiles contribute nothing. The result is cast to the dtype
    of the loaded bias values before being stored through ``c_ptr``.

    ``a_ptr``/``b_ptr``/``i_ptr``/``c_ptr`` are the base pointers of A, B,
    the bias and the output; the ``stride_*`` arguments are their element
    strides along the dimension named by the suffix.

    NOTE(review): when ``IS_FP64`` is set the accumulator is float64 but both
    operands are narrowed to float32 before ``tl.dot`` — confirm that this
    precision trade-off is intended.
    """
    tile_m = ext.program_id(0)
    tile_n = ext.program_id(1)

    rows = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    cols = tile_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    ks = tl.arange(0, BLOCK_SIZE_K)
    a_tile_ptrs = a_ptr + (rows[:, None] * stride_am + ks[None, :] * stride_ak)
    b_tile_ptrs = b_ptr + (ks[:, None] * stride_bk + cols[None, :] * stride_bn)

    # Row/column bounds do not change across K steps, so build them once.
    row_in_bounds = rows[:, None] < M
    col_in_bounds = cols[None, :] < N

    if IS_FP64:
        acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float64)
    else:
        acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Number of valid K elements remaining from the start of this step.
        k_remaining = K - k * BLOCK_SIZE_K
        a = tl.load(
            a_tile_ptrs,
            mask=row_in_bounds & (ks[None, :] < k_remaining),
            other=0.0,
        )
        b = tl.load(
            b_tile_ptrs,
            mask=(ks[:, None] < k_remaining) & col_in_bounds,
            other=0.0,
        )
        if IS_FP64:
            a = a.to(tl.float32)
            b = b.to(tl.float32)
        acc += tl.dot(a, b, allow_tf32=False)
        a_tile_ptrs += BLOCK_SIZE_K * stride_ak
        b_tile_ptrs += BLOCK_SIZE_K * stride_bk

    # Epilogue: the output tile and the bias tile share row/column offsets.
    tile_mask = row_in_bounds & col_in_bounds
    bias_ptrs = i_ptr + stride_im * rows[:, None] + stride_in * cols[None, :]
    out_ptrs = c_ptr + stride_cm * rows[:, None] + stride_cn * cols[None, :]
    bias = tl.load(bias_ptrs, mask=tile_mask, other=0.0)

    acc = acc * alpha + bias * beta
    tl.store(out_ptrs, acc.to(bias.dtype), mask=tile_mask)
def addmm_fma(bias, mat1, mat2, *, beta=1, alpha=1):
    """Compute ``beta * bias + alpha * (mat1 @ mat2)`` with ``addmm_kernel``.

    ``mat1`` and ``mat2`` are made contiguous, and ``bias`` is broadcast to
    the output shape and materialised, before the kernel is launched on
    ``mat1``'s device.

    Args:
        bias: Tensor broadcastable to ``(mat1.shape[0], mat2.shape[1])``.
        mat1: Left matrix of shape ``(M, K)``.
        mat2: Right matrix of shape ``(K, N)``.
        beta: Scale applied to ``bias``.
        alpha: Scale applied to the matrix product.

    Returns:
        A new ``(M, N)`` tensor with ``mat1``'s dtype and device.

    Raises:
        AssertionError: If the inner dimensions differ or ``bias`` cannot be
            broadcast to the output shape.
    """
    logger.debug("GEMS_MTHREADS ADDMM(FMA)")
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    target_shape = (mat1.shape[0], mat2.shape[1])
    assert broadcastable_to(bias.shape, target_shape), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape

    lhs = mat1.contiguous()
    rhs = mat2.contiguous()
    out = torch.empty((M, N), device=lhs.device, dtype=lhs.dtype)
    bias_full = bias.broadcast_to(out.shape).contiguous()

    def launch_grid(meta):
        # One program per (BLOCK_SIZE_M x BLOCK_SIZE_N) output tile.
        return (
            triton.cdiv(M, meta["BLOCK_SIZE_M"]),
            triton.cdiv(N, meta["BLOCK_SIZE_N"]),
        )

    with torch_device_fn.device(lhs.device):
        addmm_kernel[launch_grid](
            lhs,
            rhs,
            bias_full,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            lhs.stride(0),
            lhs.stride(1),
            rhs.stride(0),
            rhs.stride(1),
            bias_full.stride(0),
            bias_full.stride(1),
            out.stride(0),
            out.stride(1),
            IS_FP64=lhs.dtype == torch.float64,
        )
    return out
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_sqmma_kernel(
    a_desc,
    b_desc,
    bias_desc,
    c_desc,
    M,
    N,
    K,
    alpha,
    beta,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Compute one output tile of ``alpha * (A @ B) + beta * bias``.

    All operands are accessed through tensor descriptors. The 1-D program id
    is split into a tile index along M (``pid % num_tiles_m``) and along N
    (``pid // num_tiles_m``). Products are accumulated in float32 and the
    result is cast to ``c_desc.dtype`` before the store. ``N`` is accepted
    but not read inside the kernel body.
    """
    linear_pid = tl.program_id(axis=0)
    num_tiles_m = tl.cdiv(M, BLOCK_SIZE_M)
    tile_m = linear_pid % num_tiles_m
    tile_n = linear_pid // num_tiles_m
    row_start = (tile_m * BLOCK_SIZE_M).to(tl.int32)
    col_start = (tile_n * BLOCK_SIZE_N).to(tl.int32)
    # Running start offset along K, kept as int32 like the other offsets.
    k_start = 0
    k_start = k_start.to(tl.int32)
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        lhs = tl.load_tensor_descriptor(a_desc, [row_start, k_start])
        rhs = tl.load_tensor_descriptor(b_desc, [k_start, col_start])
        acc = tl.dot(lhs, rhs, acc=acc)
        k_start += BLOCK_SIZE_K
    bias_tile = tl.load_tensor_descriptor(bias_desc, [row_start, col_start])
    out_tile = (alpha * acc + beta * bias_tile).to(c_desc.dtype)
    tl.store_tensor_descriptor(c_desc, [row_start, col_start], out_tile)
def get_triton_type(elem_type):
    """Translate a torch dtype into the corresponding Triton dtype.

    Args:
        elem_type: A ``torch.dtype``.

    Returns:
        The Triton dtype for float16, bfloat16 or float8_e4m3fn, and ``None``
        for every other input.
    """
    torch_to_triton = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float8_e4m3fn: tl.float8e4nv,
    }
    try:
        return torch_to_triton[elem_type]
    except KeyError:
        return None
def addmm_sqmma(mat1, mat2, bias, elem_type, alpha, beta, M, N, K):
    """Compute ``beta * bias + alpha * (mat1 @ mat2)`` with the SQMMA kernel.

    Args:
        mat1: Left matrix of shape ``(M, K)``.
        mat2: Right matrix of shape ``(K, N)``; must share ``mat1``'s dtype.
        bias: Tensor broadcastable to ``(M, N)``.
        elem_type: Element dtype supplied by the callers; it is not read
            here and is kept only for interface compatibility.
        alpha: Scale applied to the matrix product.
        beta: Scale applied to ``bias``.
        M: Number of rows of the output.
        N: Number of columns of the output.
        K: Size of the reduction dimension.

    Returns:
        A new ``(M, N)`` tensor with ``mat1``'s dtype and device.

    Raises:
        AssertionError: If the inner dimensions differ, ``bias`` cannot be
            broadcast to the output shape, or the operand dtypes differ.
    """
    logger.debug("GEMS_MTHREADS ADDMM(SQMMA)")
    device = mat1.device
    # Same shape validation as addmm_fma, so both code paths reject
    # mismatched operands instead of launching on inconsistent sizes.
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    assert mat1.dtype == mat2.dtype, "Mat A and Mat B should have the same dtype"

    # contiguous() returns the tensor itself when it is already contiguous,
    # so no explicit is_contiguous() pre-check is needed.
    mat1 = mat1.contiguous()
    mat2 = mat2.contiguous()
    C = torch.empty((M, N), dtype=mat1.dtype, device=device)
    bias = bias.broadcast_to(C.shape).contiguous()

    BLOCK_SIZE_M = 128
    BLOCK_SIZE_N = 128
    BLOCK_SIZE_K = 64
    desc_a = TensorDescriptor.from_tensor(mat1, [BLOCK_SIZE_M, BLOCK_SIZE_K])
    desc_b = TensorDescriptor.from_tensor(mat2, [BLOCK_SIZE_K, BLOCK_SIZE_N])
    desc_bias = TensorDescriptor.from_tensor(bias, [BLOCK_SIZE_M, BLOCK_SIZE_N])
    desc_c = TensorDescriptor.from_tensor(C, [BLOCK_SIZE_M, BLOCK_SIZE_N])

    # Block sizes are fixed here, so the launch grid is a plain tuple:
    # one program per output tile, flattened into the first grid axis.
    grid = (
        triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),
        1,
        1,
    )
    # Launch under the operands' device, mirroring addmm_fma, so the kernel
    # does not run against whichever device happens to be current.
    with torch_device_fn.device(device):
        addmm_sqmma_kernel[grid](
            desc_a,
            desc_b,
            desc_bias,
            desc_c,
            M,
            N,
            K,
            alpha,
            beta,
            BLOCK_SIZE_M,
            BLOCK_SIZE_N,
            BLOCK_SIZE_K,
            num_warps=4,
            num_stages=1,
        )
    return C
def addmm(bias, mat1, mat2, *, beta=1, alpha=1):
    """Compute ``beta * bias + alpha * (mat1 @ mat2)``.

    Dispatches to ``addmm_sqmma`` when ``is_sqmma_compatible`` accepts the
    operands, and to ``addmm_fma`` otherwise.

    Args:
        bias: Tensor added to the scaled product.
        mat1: Left matrix of shape ``(M, K)``.
        mat2: Right matrix of shape ``(K, N)``.
        beta: Scale applied to ``bias``.
        alpha: Scale applied to the matrix product.

    Returns:
        The tensor produced by the selected implementation.
    """
    input_dtype = mat1.dtype
    M, K = mat1.shape
    _, N = mat2.shape

    if not is_sqmma_compatible(mat1, mat2, N, K):
        return addmm_fma(bias, mat1, mat2, alpha=alpha, beta=beta)
    return addmm_sqmma(mat1, mat2, bias, input_dtype, alpha, beta, M, N, K)
def addmm_dtype(bias, mat1, mat2, out_dtype, *, beta=1, alpha=1):
    """Compute addmm into a freshly allocated tensor of dtype ``out_dtype``.

    Allocates a ``(mat1.shape[0], mat2.shape[1])`` tensor on ``mat1``'s
    device and forwards everything to ``addmm_dtype_out``.

    Returns:
        The tensor returned by ``addmm_dtype_out``.
    """
    logger.debug("GEMS_MTHREADS ADDMM_DTYPE")
    result_shape = (mat1.shape[0], mat2.shape[1])
    out = torch.empty(result_shape, device=mat1.device, dtype=out_dtype)
    return addmm_dtype_out(
        bias, mat1, mat2, out_dtype, beta=beta, alpha=alpha, out=out
    )
def addmm_dtype_out(bias, mat1, mat2, out_dtype, *, beta=1, alpha=1, out):
    """Compute addmm and copy the result into ``out``.

    Args:
        bias: Tensor added to the scaled product; its dtype must equal
            ``out_dtype`` or ``mat1.dtype``.
        mat1: Left matrix of shape ``(M, K)``.
        mat2: Right matrix of shape ``(K, N)``; must share ``mat1``'s dtype.
        out_dtype: Requested output dtype: either ``mat1.dtype``, or
            float32 when the inputs are float16/bfloat16.
        beta: Scale applied to ``bias``.
        alpha: Scale applied to the matrix product.
        out: Destination tensor; its dtype must equal ``out_dtype``.

    Returns:
        ``out``, after the result has been copied into it.

    Raises:
        RuntimeError: If any of the dtype constraints above is violated.

    NOTE(review): ``addmm_sqmma`` and ``addmm_fma`` both allocate their
    result in ``mat1``'s dtype, so with half-precision inputs and a float32
    ``out`` the values are converted only by the final ``copy_`` — confirm
    that this is the intended precision.
    """
    logger.debug("GEMS_MTHREADS ADDMM_DTYPE_OUT")
    input_dtype = mat1.dtype
    if input_dtype != mat2.dtype:
        raise RuntimeError(
            f"mat1 and mat2 must have the same dtype, but got {mat1.dtype} and {mat2.dtype}"
        )
    if out.dtype != out_dtype:
        raise RuntimeError(
            "out_dtype must be the same as the dtype of the provided out tensor"
        )

    keeps_input_dtype = out_dtype == input_dtype
    widens_half_to_fp32 = out_dtype == torch.float32 and input_dtype in (
        torch.float16,
        torch.bfloat16,
    )
    if not (keeps_input_dtype or widens_half_to_fp32):
        raise RuntimeError(
            "out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs"
        )
    if bias.dtype not in (out_dtype, input_dtype):
        raise RuntimeError("self dtype must match either out_dtype or mat1 dtype")

    bias_c = bias.to(out_dtype)
    M, K = mat1.shape
    _, N = mat2.shape

    if is_sqmma_compatible(mat1, mat2, N, K):
        result = addmm_sqmma(mat1, mat2, bias_c, input_dtype, alpha, beta, M, N, K)
    else:
        result = addmm_fma(bias_c, mat1, mat2, alpha=alpha, beta=beta)
    out.copy_(result)
    return out