Coverage for src/flag_gems/ops/addmm.py: 49%
85 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import broadcastable_to, libentry, libtuner
10from flag_gems.utils import triton_lang_extension as ext
# Module-level logger; the ops below emit one debug record per call with the
# problem shape and the memory layout of their operands.
logger = logging.getLogger(__name__)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("addmm"),
    key=["M", "N", "K"],
    strategy=["align32", "align32", "align32"],
    warmup=5,
    rep=10,
    flagtune_op_name="addmm",
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,
    b_ptr,
    i_ptr,
    c_ptr,
    alpha,
    beta,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of ``alpha * (A @ B) + beta * I``.

    Program (pid_m, pid_n) owns output rows starting at pid_m * BLOCK_SIZE_M
    and columns starting at pid_n * BLOCK_SIZE_N of the M x N output ``c_ptr``.
    A is read as M x K, B as K x N, and the bias I is addressed as M x N
    through its own strides (stride_im, stride_in).

    ``alpha`` and ``beta`` are listed in ``do_not_specialize``, so their
    values are not baked into the compiled kernel. When ``beta == 0`` the bias
    is never read, matching ``torch.addmm``: NaN/Inf in the bias are not
    propagated in that case.
    """
    pid_m = ext.program_id(0)
    pid_n = ext.program_id(1)

    # Row / column indices of this program's output tile, plus the K offsets
    # of the first reduction step.
    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    if IS_FP64:
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float64)
    else:
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Out-of-range rows/columns and the ragged tail of K load as 0, so
        # they contribute nothing to the dot product.
        k_remaining = K - k * BLOCK_SIZE_K
        a = tl.load(
            a_ptrs,
            mask=(offs_am[:, None] < M) & (offs_k[None, :] < k_remaining),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < k_remaining) & (offs_bn[None, :] < N),
            other=0.0,
        )
        if IS_FP64:
            # Operands are downcast to fp32 before tl.dot; only the running
            # sum is kept in fp64. NOTE(review): presumably because tl.dot
            # lacks fp64 support on the targeted backends -- this loses
            # precision relative to torch.addmm for float64 inputs; confirm.
            a = a.to(tl.float32)
            b = b.to(tl.float32)
        accumulator += tl.dot(a, b, allow_tf32=False)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Epilogue: the output tile shares the row/column indices computed above.
    c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
    c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
    i_ptrs = i_ptr + stride_im * offs_am[:, None] + stride_in * offs_bn[None, :]
    # torch.addmm ignores the bias entirely when beta == 0 (NaN/Inf in it must
    # not leak through as 0 * NaN). Masking the load off yields bias == 0 and
    # also skips the memory traffic.
    bias = tl.load(i_ptrs, mask=c_mask & (beta != 0), other=0.0)

    accumulator = accumulator * alpha + bias * beta
    # Cast straight to the output element type rather than to the bias dtype,
    # which tl.store would then convert again (double rounding when they differ).
    c = accumulator.to(c_ptr.dtype.element_ty)
    tl.store(c_ptrs, c, mask=c_mask)
def addmm(bias, mat1, mat2, *, beta=1, alpha=1):
    """Return ``beta * bias + alpha * (mat1 @ mat2)`` as a new tensor.

    Args:
        bias: tensor broadcastable to ``(M, N)``; it is passed to the kernel
            as a broadcast view, not materialized.
        mat1: ``(M, K)`` matrix; made contiguous before the launch.
        mat2: ``(K, N)`` matrix; passed with its original strides.
        beta: scale applied to ``bias`` (keyword-only).
        alpha: scale applied to the matrix product (keyword-only).

    Returns:
        A new ``(M, N)`` tensor with ``mat1``'s dtype and device.

    Raises:
        AssertionError: (via ``assert``) if the inner dimensions differ or
            ``bias`` cannot be broadcast to ``(M, N)``.
    """
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape

    logger.debug(
        "GEMS ADDMM, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s, [bias column-major]: %s",
        M,
        N,
        K,
        mat1.stride(0) == 1,
        mat2.stride(0) == 1,
        # A 0-d (scalar) bias is broadcastable but has no stride(0); the
        # unguarded call raised IndexError even with debug logging disabled.
        bias.dim() > 0 and bias.stride(0) == 1,
    )

    out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    if out.numel() == 0:
        # Empty result: nothing to compute, so skip the autotuned launch.
        return out

    mat1 = mat1.contiguous()
    bias = bias.broadcast_to(out.shape)

    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            mat1.stride(0),
            mat1.stride(1),
            mat2.stride(0),
            mat2.stride(1),
            bias.stride(0),
            bias.stride(1),
            out.stride(0),
            out.stride(1),
            IS_FP64=mat1.dtype == torch.float64,
        )
    return out
def addmm_out(bias, mat1, mat2, *, beta=1, alpha=1, out=None):
    """Write ``beta * bias + alpha * (mat1 @ mat2)`` into ``out`` and return it.

    Args:
        bias: tensor broadcastable to ``(M, N)``; it is passed to the kernel
            as a broadcast view, not materialized.
        mat1: ``(M, K)`` matrix; made contiguous before the launch.
        mat2: ``(K, N)`` matrix; passed with its original strides.
        beta: scale applied to ``bias`` (keyword-only).
        alpha: scale applied to the matrix product (keyword-only).
        out: destination of shape ``(M, N)``. If None, a new tensor with
            ``mat1``'s dtype and device is allocated.

    Returns:
        ``out`` (or the newly allocated tensor).

    Raises:
        AssertionError: (via ``assert``) on mismatched inner dimensions, a
            non-broadcastable ``bias``, or an ``out`` whose shape is not
            ``(M, N)``.
    """
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape
    if out is None:
        out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    else:
        assert out.shape == (M, N), "Incompatible output shape"

    logger.debug(
        "GEMS ADDMM_OUT, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s, [bias column-major]: %s",
        M,
        N,
        K,
        mat1.stride(0) == 1,
        mat2.stride(0) == 1,
        # A 0-d (scalar) bias is broadcastable but has no stride(0); the
        # unguarded call raised IndexError even with debug logging disabled.
        bias.dim() > 0 and bias.stride(0) == 1,
    )

    if out.numel() == 0:
        # Empty result: nothing to compute, so skip the autotuned launch.
        return out

    mat1 = mat1.contiguous()
    bias = bias.broadcast_to(out.shape)

    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            mat1.stride(0),
            mat1.stride(1),
            mat2.stride(0),
            mat2.stride(1),
            bias.stride(0),
            bias.stride(1),
            out.stride(0),
            out.stride(1),
            IS_FP64=mat1.dtype == torch.float64,
        )
    return out
def addmm_dtype(bias, mat1, mat2, out_dtype, *, beta=1, alpha=1):
    """Allocating front-end for ``addmm_dtype_out``.

    Creates an uninitialized ``(mat1.shape[0], mat2.shape[1])`` tensor of
    ``out_dtype`` on ``mat1``'s device, then delegates all validation and
    computation to ``addmm_dtype_out``.
    """
    logger.debug("GEMS ADDMM_DTYPE")
    result_shape = (mat1.shape[0], mat2.shape[1])
    result = torch.empty(result_shape, device=mat1.device, dtype=out_dtype)
    return addmm_dtype_out(
        bias, mat1, mat2, out_dtype, alpha=alpha, beta=beta, out=result
    )
def addmm_dtype_out(bias, mat1, mat2, out_dtype, *, beta=1, alpha=1, out):
    """Compute ``beta * bias + alpha * (mat1 @ mat2)`` into ``out`` as ``out_dtype``.

    Accepted combinations:
      * ``mat1`` and ``mat2`` share one dtype;
      * ``out.dtype`` equals ``out_dtype``;
      * ``out_dtype`` equals the input dtype, or is float32 for
        float16/bfloat16 inputs;
      * ``bias`` has either ``out_dtype`` or the input dtype.

    The bias is converted to ``out_dtype`` before being handed to
    ``addmm_out``.

    Raises:
        RuntimeError: if any of the rules above is violated.
    """
    logger.debug("GEMS ADDMM_DTYPE_OUT")
    input_dtype = mat1.dtype

    if input_dtype != mat2.dtype:
        raise RuntimeError(
            f"mat1 and mat2 must have the same dtype, but got {mat1.dtype} and {mat2.dtype}"
        )
    if out.dtype != out_dtype:
        raise RuntimeError(
            "out_dtype must be the same as the dtype of the provided out tensor"
        )

    # Half-precision inputs may be widened to an fp32 result.
    widens_half_to_fp32 = out_dtype == torch.float32 and input_dtype in (
        torch.float16,
        torch.bfloat16,
    )
    if out_dtype != input_dtype and not widens_half_to_fp32:
        raise RuntimeError(
            "out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs"
        )
    if bias.dtype not in (out_dtype, input_dtype):
        raise RuntimeError("self dtype must match either out_dtype or mat1 dtype")

    converted_bias = bias.to(out_dtype)
    return addmm_out(converted_bias, mat1, mat2, beta=beta, alpha=alpha, out=out)