Coverage for src/flag_gems/runtime/backend/_sunrise/ops/addmm.py: 0%
85 statements
coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import broadcastable_to, libentry, libtuner
10from flag_gems.utils import triton_lang_extension as ext
# Module-level logger; the host-side addmm wrappers below record DEBUG
# messages through it.
logger = logging.getLogger(__name__)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("addmm"),
    key=["M", "N", "K"],
    strategy=["align32", "align32", "align32"],
    warmup=5,
    rep=10,
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,
    b_ptr,
    i_ptr,
    c_ptr,
    alpha,
    beta,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Compute one output tile of ``C = alpha * (A @ B) + beta * I``.

    A is addressed as (M, K) and B as (K, N). The bias I and the output C are
    addressed as (M, N) through their own strides. Program (pid_m, pid_n) owns
    the tile starting at row ``pid_m * BLOCK_SIZE_M`` and column
    ``pid_n * BLOCK_SIZE_N``. Rows and columns past M/N are masked out of every
    load and store.

    When ``beta == 0`` the bias term is dropped entirely, so NaN/Inf values in
    the bias are not propagated into C. This matches the documented
    ``torch.addmm`` behavior.

    The result is cast to the bias element type before it is stored.
    """
    pid_m = ext.program_id(0)
    pid_n = ext.program_id(1)

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # fp64 inputs keep an fp64 running sum; every other dtype accumulates in fp32.
    if IS_FP64:
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float64)
    else:
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Elements of the current K-slice that lie beyond K are zero-filled, so
        # the ragged last slice contributes nothing to the dot product.
        a = tl.load(
            a_ptrs,
            mask=(offs_am[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_bn[None, :] < N),
            other=0.0,
        )
        if IS_FP64:
            # The fp64 operands are narrowed to fp32 before tl.dot. Presumably
            # this backend's dot does not accept fp64 operands (TODO confirm).
            # Each partial product is therefore only fp32-accurate, even though
            # the running sum is kept in fp64.
            a = a.to(tl.float32)
            b = b.to(tl.float32)
        accumulator += tl.dot(a, b, allow_tf32=False)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    i_ptrs = i_ptr + stride_im * offs_cm[:, None] + stride_in * offs_cn[None, :]
    # The bias tile is loaded even when beta == 0, because its element type
    # decides the dtype of the stored result below.
    bias = tl.load(i_ptrs, mask=c_mask, other=0.0)

    if beta == 0:
        # torch.addmm ignores its input when beta == 0. Skipping the term keeps
        # NaN/Inf in the bias out of the result (0 * NaN would still be NaN).
        accumulator = accumulator * alpha
    else:
        accumulator = accumulator * alpha + bias * beta
    c = accumulator.to(bias.dtype)
    tl.store(c_ptrs, c, mask=c_mask)
def addmm(bias, mat1, mat2, *, beta=1, alpha=1):
    """Return ``beta * bias + alpha * (mat1 @ mat2)`` as a new tensor.

    Args:
        bias: Tensor broadcastable to ``(M, N)``. It is broadcast (as a view,
            not a copy) to the output shape before the kernel reads it.
        mat1: ``(M, K)`` matrix. It is made contiguous before the launch.
        mat2: ``(K, N)`` matrix. It is passed with its own strides, so it may
            be non-contiguous.
        beta: Scale applied to ``bias``.
        alpha: Scale applied to ``mat1 @ mat2``.

    Returns:
        A new ``(M, N)`` tensor on ``mat1.device`` with dtype ``mat1.dtype``.

    Raises:
        AssertionError: if the inner dimensions of ``mat1``/``mat2`` disagree
            or ``bias`` cannot be broadcast to ``(M, N)``.
    """
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape

    # The logging arguments are evaluated even when DEBUG is disabled. A 0-d
    # bias is accepted above, so it must not reach bias.stride(0), which would
    # raise IndexError.
    logger.debug(
        "GEMS ADDMM, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s, [bias column-major]: %s",
        M,
        N,
        K,
        mat1.stride(0) == 1,
        mat2.stride(0) == 1,
        bias.dim() > 0 and bias.stride(0) == 1,
    )
    mat1 = mat1.contiguous()
    out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    # Expanded dimensions of the broadcast view have stride 0, so the kernel
    # reads the same bias element for every row/column it was expanded along.
    bias = bias.broadcast_to(out.shape)

    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            mat1.stride(0),
            mat1.stride(1),
            mat2.stride(0),
            mat2.stride(1),
            bias.stride(0),
            bias.stride(1),
            out.stride(0),
            out.stride(1),
            IS_FP64=mat1.dtype == torch.float64,
        )
    return out
def addmm_out(bias, mat1, mat2, *, beta=1, alpha=1, out=None):
    """Write ``beta * bias + alpha * (mat1 @ mat2)`` into ``out`` and return it.

    Args:
        bias: Tensor broadcastable to ``(M, N)``. It is broadcast (as a view)
            to ``out.shape`` before the kernel reads it.
        mat1: ``(M, K)`` matrix. It is made contiguous before the launch.
        mat2: ``(K, N)`` matrix. It is passed with its own strides.
        beta: Scale applied to ``bias``.
        alpha: Scale applied to ``mat1 @ mat2``.
        out: Destination of shape ``(M, N)``. When ``None``, a new tensor with
            ``mat1``'s device and dtype is allocated.

    Returns:
        ``out`` (or the newly allocated tensor).

    Raises:
        AssertionError: if the inner dimensions disagree, ``bias`` cannot be
            broadcast to ``(M, N)``, or a provided ``out`` has the wrong shape.
    """
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape
    if out is None:
        out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    else:
        assert out.shape == (M, N), "Incompatible output shape"

    # The logging arguments are evaluated even when DEBUG is disabled. A 0-d
    # bias is accepted above, so it must not reach bias.stride(0), which would
    # raise IndexError.
    logger.debug(
        "GEMS ADDMM_OUT, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s, [bias column-major]: %s",
        M,
        N,
        K,
        mat1.stride(0) == 1,
        mat2.stride(0) == 1,
        bias.dim() > 0 and bias.stride(0) == 1,
    )
    mat1 = mat1.contiguous()
    bias = bias.broadcast_to(out.shape)

    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            mat1.stride(0),
            mat1.stride(1),
            mat2.stride(0),
            mat2.stride(1),
            bias.stride(0),
            bias.stride(1),
            out.stride(0),
            out.stride(1),
            IS_FP64=mat1.dtype == torch.float64,
        )
    return out
def addmm_dtype(bias, mat1, mat2, out_dtype, *, beta=1, alpha=1):
    """Allocate an ``out_dtype`` result and compute addmm into it.

    The result has shape ``(mat1.shape[0], mat2.shape[1])`` and lives on
    ``mat1.device``. Dtype validation and the computation itself are delegated
    to ``addmm_dtype_out``.
    """
    logger.debug("GEMS ADDMM_DTYPE")
    result_shape = (mat1.shape[0], mat2.shape[1])
    result = torch.empty(result_shape, device=mat1.device, dtype=out_dtype)
    return addmm_dtype_out(
        bias, mat1, mat2, out_dtype, beta=beta, alpha=alpha, out=result
    )
def addmm_dtype_out(bias, mat1, mat2, out_dtype, *, beta=1, alpha=1, out):
    """Validate the dtype combination, then compute addmm into ``out``.

    The accepted combinations are:
      * ``mat1`` and ``mat2`` share one dtype;
      * ``out.dtype`` equals ``out_dtype``;
      * ``out_dtype`` is the input dtype, or ``torch.float32`` when the inputs
        are ``float16``/``bfloat16``;
      * ``bias`` has either ``out_dtype`` or the input dtype.

    ``bias`` is converted to ``out_dtype`` and the work is handed to
    ``addmm_out``.

    Raises:
        RuntimeError: if any of the dtype rules above is violated.
    """
    logger.debug("GEMS ADDMM_DTYPE_OUT")
    in_dtype = mat1.dtype

    if in_dtype != mat2.dtype:
        raise RuntimeError(
            f"mat1 and mat2 must have the same dtype, but got {mat1.dtype} and {mat2.dtype}"
        )
    if out.dtype != out_dtype:
        raise RuntimeError(
            "out_dtype must be the same as the dtype of the provided out tensor"
        )

    # Half-precision inputs may be widened to an fp32 result; nothing else may
    # change dtype.
    widens_half_to_fp32 = out_dtype == torch.float32 and in_dtype in (
        torch.float16,
        torch.bfloat16,
    )
    if out_dtype != in_dtype and not widens_half_to_fp32:
        raise RuntimeError(
            "out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs"
        )
    if bias.dtype not in (out_dtype, in_dtype):
        raise RuntimeError("self dtype must match either out_dtype or mat1 dtype")

    return addmm_out(
        bias.to(out_dtype), mat1, mat2, beta=beta, alpha=alpha, out=out
    )