Coverage for src/flag_gems/ops/addmm.py: 50%
86 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import broadcastable_to, libentry, libtuner
11from flag_gems.utils import triton_lang_extension as ext
# Module-level logger named after this module; the wrappers below emit
# their debug traces through it.
logger = logging.getLogger(__name__)
@libentry()
@libtuner(
    # Tuning configs and strategy are taken from the "addmm" expand config
    # when the USE_FLAGTUNE environment variable is "1"; otherwise the
    # pre-tuned config list and a fixed per-key strategy are used.
    configs=runtime.ops_get_configs("addmm", pre_hook=None)
    if os.environ.get("USE_FLAGTUNE") == "1"
    else runtime.get_tuned_config("addmm"),
    key=["M", "N", "K"],
    strategy=runtime.get_expand_config("addmm")["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32"],
    warmup=5,
    rep=10,
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,
    b_ptr,
    i_ptr,
    c_ptr,
    alpha,
    beta,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C = alpha * (A @ B) + beta * I.

    a_ptr, b_ptr, i_ptr and c_ptr are the base pointers of A (M x K),
    B (K x N), the additive input I and the output C (both M x N); each
    ``stride_*`` argument is the element stride of the matching tensor along
    the named dimension. Program id 0 selects the row tile and program id 1
    the column tile. ``alpha`` and ``beta`` are listed in
    ``do_not_specialize`` on the ``triton.jit`` decorator.
    """
    pid_m = ext.program_id(0)
    pid_n = ext.program_id(1)

    # Row/column offsets of this program's tile and the K offsets of the
    # first K-slice; broadcasting them yields 2-D pointer tiles into A and B.
    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # The accumulator is float64 only when IS_FP64 is set, float32 otherwise.
    if IS_FP64:
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float64)
    else:
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # Walk the K dimension in BLOCK_SIZE_K slices.
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Out-of-range rows/columns and the K tail of the last slice are
        # masked off and read as 0.0, so they add nothing to the product.
        a = tl.load(
            a_ptrs,
            mask=(offs_am[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_bn[None, :] < N),
            other=0.0,
        )
        if IS_FP64:
            # NOTE(review): operands are narrowed to float32 before the dot, so
            # each partial product is computed in float32 even though the
            # accumulator is float64. Presumably tl.dot lacks float64 operand
            # support on the target backend -- confirm.
            a = a.to(tl.float32)
            b = b.to(tl.float32)
        accumulator += tl.dot(a, b, allow_tf32=False)
        # Advance both pointer tiles to the next K-slice.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Epilogue: the additive input and the output share the same tile offsets
    # and the same bounds mask.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    i_ptrs = i_ptr + stride_im * offs_cm[:, None] + stride_in * offs_cn[None, :]
    bias = tl.load(i_ptrs, mask=c_mask, other=0.0)

    accumulator = accumulator * alpha + bias * beta
    # The result is cast to the dtype of the additive input before the store.
    c = accumulator.to(bias.dtype)
    tl.store(c_ptrs, c, mask=c_mask)
def addmm(bias, mat1, mat2, *, beta=1, alpha=1):
    """Return ``beta * bias + alpha * (mat1 @ mat2)`` computed by ``addmm_kernel``.

    Args:
        bias: Tensor whose shape is broadcastable to ``(M, N)``; a 0-dim
            tensor is accepted.
        mat1: 2-D tensor of shape ``(M, K)``.
        mat2: 2-D tensor of shape ``(K, N)``.
        beta: Multiplier applied to ``bias``.
        alpha: Multiplier applied to the matrix product.

    Returns:
        A newly allocated ``(M, N)`` tensor on ``mat1``'s device with
        ``mat1``'s dtype.

    Raises:
        AssertionError: If the inner dimensions of ``mat1`` and ``mat2``
            differ, or ``bias`` cannot be broadcast to ``(M, N)``.
    """
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape

    logger.debug(
        "GEMS ADDMM, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s, [bias column-major]: %s",
        M,
        N,
        K,
        mat1.stride(0) == 1,
        mat2.stride(0) == 1,
        # Logging arguments are evaluated eagerly, and stride(0) raises
        # IndexError on a 0-dim tensor, so guard on the rank first.
        bias.dim() > 0 and bias.stride(0) == 1,
    )
    mat1 = mat1.contiguous()
    # mat2 is passed through unchanged: the kernel addresses it via the
    # strides handed over below, so no copy is made here.
    out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    bias = bias.broadcast_to(out.shape)

    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            mat1.stride(0),
            mat1.stride(1),
            mat2.stride(0),
            mat2.stride(1),
            bias.stride(0),
            bias.stride(1),
            out.stride(0),
            out.stride(1),
            IS_FP64=mat1.dtype == torch.float64,
        )
    return out
def addmm_out(bias, mat1, mat2, *, beta=1, alpha=1, out=None):
    """Write ``beta * bias + alpha * (mat1 @ mat2)`` into ``out`` and return it.

    Args:
        bias: Tensor whose shape is broadcastable to ``(M, N)``; a 0-dim
            tensor is accepted.
        mat1: 2-D tensor of shape ``(M, K)``.
        mat2: 2-D tensor of shape ``(K, N)``.
        beta: Multiplier applied to ``bias``.
        alpha: Multiplier applied to the matrix product.
        out: Optional destination of shape ``(M, N)``. When ``None``, a tensor
            with ``mat1``'s device and dtype is allocated.

    Returns:
        ``out`` (the supplied tensor, or the freshly allocated one).

    Raises:
        AssertionError: If the inner dimensions of ``mat1`` and ``mat2``
            differ, ``bias`` cannot be broadcast to ``(M, N)``, or a supplied
            ``out`` does not have shape ``(M, N)``.
    """
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape
    if out is None:
        out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    else:
        assert out.shape == (M, N), "Incompatible output shape"
    logger.debug(
        "GEMS ADDMM_OUT, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s, [bias column-major]: %s",
        M,
        N,
        K,
        mat1.stride(0) == 1,
        mat2.stride(0) == 1,
        # Logging arguments are evaluated eagerly, and stride(0) raises
        # IndexError on a 0-dim tensor, so guard on the rank first.
        bias.dim() > 0 and bias.stride(0) == 1,
    )
    mat1 = mat1.contiguous()
    # Expand bias to the full output shape so it always has two strides to
    # hand to the kernel.
    bias = bias.broadcast_to(out.shape)

    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            mat1.stride(0),
            mat1.stride(1),
            mat2.stride(0),
            mat2.stride(1),
            bias.stride(0),
            bias.stride(1),
            out.stride(0),
            out.stride(1),
            IS_FP64=mat1.dtype == torch.float64,
        )
    return out
def addmm_dtype(bias, mat1, mat2, out_dtype, *, beta=1, alpha=1):
    """Run addmm into a freshly allocated result of dtype ``out_dtype``.

    Allocates a ``(mat1.shape[0], mat2.shape[1])`` tensor on ``mat1``'s device
    and hands it, together with every other argument, to ``addmm_dtype_out``,
    whose return value is returned unchanged.
    """
    logger.debug("GEMS ADDMM_DTYPE")
    rows, cols = mat1.shape[0], mat2.shape[1]
    result = torch.empty((rows, cols), device=mat1.device, dtype=out_dtype)
    return addmm_dtype_out(
        bias, mat1, mat2, out_dtype, beta=beta, alpha=alpha, out=result
    )
def addmm_dtype_out(bias, mat1, mat2, out_dtype, *, beta=1, alpha=1, out):
    """Validate the dtype combination, then delegate to ``addmm_out``.

    Checks, in order: ``mat1`` and ``mat2`` share a dtype; ``out`` already has
    dtype ``out_dtype``; ``out_dtype`` equals the input dtype or is float32
    for float16/bfloat16 inputs; ``bias`` has either ``out_dtype`` or the
    input dtype. ``bias`` is then converted to ``out_dtype`` before the call.

    Raises:
        RuntimeError: On the first of the checks above that fails.
    """
    logger.debug("GEMS ADDMM_DTYPE_OUT")

    in_dtype = mat1.dtype
    if in_dtype != mat2.dtype:
        raise RuntimeError(
            f"mat1 and mat2 must have the same dtype, but got {mat1.dtype} and {mat2.dtype}"
        )

    if out.dtype != out_dtype:
        raise RuntimeError(
            "out_dtype must be the same as the dtype of the provided out tensor"
        )

    keeps_input_dtype = out_dtype == in_dtype
    widens_half_to_fp32 = out_dtype == torch.float32 and in_dtype in (
        torch.float16,
        torch.bfloat16,
    )
    if not keeps_input_dtype and not widens_half_to_fp32:
        raise RuntimeError(
            "out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs"
        )

    if bias.dtype not in (out_dtype, in_dtype):
        raise RuntimeError("self dtype must match either out_dtype or mat1 dtype")

    return addmm_out(
        bias.to(out_dtype), mat1, mat2, beta=beta, alpha=alpha, out=out
    )