Coverage for src/flag_gems/runtime/backend/_arm/ops/addmm.py: 0%
181 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.utils import broadcastable_to
9from flag_gems.utils import triton_lang_extension as tle
# (BLOCK_N, BLOCK_K) tiles for the M == 1 addmm kernel. Rows are scanned
# top-to-bottom and the first row whose n_min <= N and k_min <= K is used.
ADDMM_M1_CONFIG_TABLE = (
    dict(n_min=4096, k_min=0, config=(64, 8)),
    dict(n_min=2048, k_min=0, config=(32, 16)),
    dict(n_min=0, k_min=3072, config=(16, 16)),
    # Catch-all row: every (N, K) satisfies it.
    dict(n_min=0, k_min=0, config=(8, 32)),
)
# (BLOCK_N, BLOCK_K) tiles for the transposed-RHS M == 1 kernel, scanned
# top-to-bottom; the first row with n_min <= N and k_min <= K wins.
# Tuned on CIX P1 aarch64 (2026-03-04): BK=64 fills a full cache line.
ADDMM_M1_TRANSPOSED_CONFIG_TABLE = (
    dict(n_min=65536, k_min=0, config=(2, 64)),
    # The remaining rows currently share one tile; they are kept as separate
    # rows (presumably so each tier can be retuned independently).
    dict(n_min=2048, k_min=0, config=(4, 64)),
    dict(n_min=0, k_min=2048, config=(4, 64)),
    dict(n_min=0, k_min=0, config=(4, 64)),
)
def _select_addmm_m1_config(N, K):
    """Pick the (BLOCK_N, BLOCK_K) tile for ``addmm_m1_kernel``.

    Rows of ``ADDMM_M1_CONFIG_TABLE`` are tried in order; the first whose
    ``n_min`` and ``k_min`` thresholds (0 when absent) are both met by ``N``
    and ``K`` supplies the tile. ``(8, 32)`` is returned if no row matches.
    """
    return next(
        (
            entry["config"]
            for entry in ADDMM_M1_CONFIG_TABLE
            if N >= entry.get("n_min", 0) and K >= entry.get("k_min", 0)
        ),
        (8, 32),
    )
def _select_addmm_m1_transposed_config(N, K):
    """Pick the (BLOCK_N, BLOCK_K) tile for ``addmm_m1_transposed_rhs_kernel``.

    Rows of ``ADDMM_M1_TRANSPOSED_CONFIG_TABLE`` are tried in order; the first
    whose ``n_min`` and ``k_min`` thresholds (0 when absent) are both met by
    ``N`` and ``K`` supplies the tile.

    Returns:
        A ``(BLOCK_N, BLOCK_K)`` tuple.
    """
    for rule in ADDMM_M1_TRANSPOSED_CONFIG_TABLE:
        if N >= rule.get("n_min", 0) and K >= rule.get("k_min", 0):
            return rule["config"]
    # Only reachable if the table loses its catch-all row. Fall back to the
    # same BLOCK_K=64 tile that row uses, rather than the (8, 32) default of
    # the non-transposed selector.
    return 4, 64
41def _is_rhs_transposed_layout(rhs):
42 if rhs.ndim != 2:
43 return False
44 return rhs.stride(0) == 1 and rhs.stride(1) >= rhs.shape[0]
47def _use_addmm_m1_transposed_fastpath_shape(N, K):
48 # Avoid unstable LLVM lowering for tiny matrices on ARM cpu backend.
49 return N >= 256 and K >= 256
52def _use_addmm_m1_fastpath_shape(N, K):
53 return N >= 256 and K >= 256
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_m1_kernel(
    a_ptr,
    b_ptr,
    i_ptr,
    c_ptr,
    alpha,
    beta,
    N,
    K,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_in,
    stride_cn,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
):
    """Compute a BLOCK_N-wide slice of ``c = alpha * (a @ b) + beta * bias`` (M == 1).

    ``a`` is a row of K elements (stride ``stride_ak``); ``b`` is addressed as
    (K, N) through ``stride_bk``/``stride_bn``; ``bias`` and ``c`` are rows of
    N elements. Program ``pid`` owns columns ``pid * BLOCK_N + [0, BLOCK_N)``
    and reduces over K in BLOCK_K chunks, accumulating in float32.

    ``EVEN_K`` must only be set when K is a multiple of BLOCK_K; it removes the
    K-axis masks. The N tail is always masked. ``bias`` is not read when
    ``beta == 0``.
    """
    pid_n = tle.program_id(0)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)

    a_ptrs = a_ptr + rk * stride_ak
    # (BLOCK_K, BLOCK_N) tile of b: rows follow K, columns follow N.
    b_ptrs = b_ptr + rk[:, None] * stride_bk + rn[None, :] * stride_bn
    acc = tl.zeros((BLOCK_N,), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_K)):
        if EVEN_K:
            a = tl.load(a_ptrs)
            # EVEN_K says nothing about N: the last program can extend past
            # column N-1, so the N tail must still be masked to avoid reading
            # beyond the end of b.
            b = tl.load(b_ptrs, mask=rn[None, :] < N, other=0.0)
        else:
            k_remaining = K - k * BLOCK_K
            a = tl.load(a_ptrs, mask=rk < k_remaining, other=0.0)
            b = tl.load(
                b_ptrs,
                mask=(rk[:, None] < k_remaining) & (rn[None, :] < N),
                other=0.0,
            )

        a_fp = a.to(tl.float32)
        b_fp = b.to(tl.float32)
        # Scale each row of the b tile by the matching a element, then reduce
        # over K (axis 0).
        acc += tl.sum(b_fp * a_fp[:, None], axis=0)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    if beta == 0:
        # Skip the bias load entirely so non-finite bias values cannot leak in.
        out = acc * alpha
    else:
        bias_ptrs = i_ptr + rn * stride_in
        bias = tl.load(bias_ptrs, mask=rn < N, other=0.0).to(tl.float32)
        out = acc * alpha + bias * beta
    c_ptrs = c_ptr + rn * stride_cn
    tl.store(c_ptrs, out.to(c_ptr.dtype.element_ty), mask=rn < N)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_m1_transposed_rhs_kernel(
    a_ptr,
    b_ptr,
    i_ptr,
    c_ptr,
    alpha,
    beta,
    N,
    K,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_in,
    stride_cn,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
):
    """M == 1 addmm variant that loads ``b`` as (BLOCK_N, BLOCK_K) tiles.

    Computes a BLOCK_N-wide slice of ``c = alpha * (a @ b) + beta * bias``,
    where ``b`` is addressed as (K, N) through ``stride_bk``/``stride_bn``.
    Unlike ``addmm_m1_kernel``, the tile is indexed [n, k], so the inner axis
    follows ``stride_bk``. NOTE(review): this is presumably the contiguous
    axis when the caller picked this kernel for a column-major ``b``; confirm
    at the call site.

    Accumulation is in float32. ``EVEN_K`` must only be set when K is a
    multiple of BLOCK_K. ``bias`` is not read when ``beta == 0``.
    """
    pid_n = tle.program_id(0)
    # Output columns owned by this program; the last program may exceed N.
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)

    a_ptrs = a_ptr + rk * stride_ak
    # (BLOCK_N, BLOCK_K) tile: rows follow N, columns follow K.
    bt_ptrs = b_ptr + rn[:, None] * stride_bn + rk[None, :] * stride_bk
    acc = tl.zeros((BLOCK_N,), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_K)):
        if EVEN_K:
            a = tl.load(a_ptrs)
            # K needs no mask here; the (BLOCK_N, 1) N-tail mask is broadcast
            # across the K axis of the tile.
            bt = tl.load(bt_ptrs, mask=rn[:, None] < N, other=0.0)
        else:
            k_remaining = K - k * BLOCK_K
            a = tl.load(a_ptrs, mask=rk < k_remaining, other=0.0)
            bt = tl.load(
                bt_ptrs,
                mask=(rn[:, None] < N) & (rk[None, :] < k_remaining),
                other=0.0,
            )

        a_fp = a.to(tl.float32)
        bt_fp = bt.to(tl.float32)
        # Multiply each tile row by the a chunk and reduce over K (axis 1).
        acc += tl.sum(bt_fp * a_fp[None, :], axis=1)
        a_ptrs += BLOCK_K * stride_ak
        bt_ptrs += BLOCK_K * stride_bk

    if beta == 0:
        # Bias is never loaded in this branch.
        out = acc * alpha
    else:
        bias_ptrs = i_ptr + rn * stride_in
        bias = tl.load(bias_ptrs, mask=rn < N, other=0.0).to(tl.float32)
        out = acc * alpha + bias * beta
    c_ptrs = c_ptr + rn * stride_cn
    # Cast to the destination element type; columns >= N are not written.
    tl.store(c_ptrs, out.to(c_ptr.dtype.element_ty), mask=rn < N)
def _launch_addmm_m1_kernel(mat1, mat2, bias, out, alpha, beta, N, K):
    """Run ``addmm_m1_kernel`` on a 1-D grid of ``cdiv(N, BLOCK_N)`` programs.

    All tensors are expected to be 2-D. The launch reads ``stride(1)`` of
    ``mat1``, ``bias`` and ``out`` and both strides of ``mat2``; the tile shape
    comes from ``_select_addmm_m1_config``.
    """
    block_n, block_k = _select_addmm_m1_config(N, K)
    num_programs = triton.cdiv(N, block_n)
    addmm_m1_kernel[(num_programs,)](
        mat1,
        mat2,
        bias,
        out,
        alpha,
        beta,
        N,
        K,
        mat1.stride(1),  # stride_ak
        mat2.stride(0),  # stride_bk
        mat2.stride(1),  # stride_bn
        bias.stride(1),  # stride_in
        out.stride(1),  # stride_cn
        BLOCK_N=block_n,
        BLOCK_K=block_k,
        EVEN_K=K % block_k == 0,
    )
def _launch_addmm_m1_transposed_rhs_kernel(mat1, mat2, bias, out, alpha, beta, N, K):
    """Run ``addmm_m1_transposed_rhs_kernel`` on ``cdiv(N, BLOCK_N)`` programs.

    All tensors are expected to be 2-D. The launch reads ``stride(1)`` of
    ``mat1``, ``bias`` and ``out`` and both strides of ``mat2``; the tile shape
    comes from ``_select_addmm_m1_transposed_config``.
    """
    block_n, block_k = _select_addmm_m1_transposed_config(N, K)
    num_programs = triton.cdiv(N, block_n)
    addmm_m1_transposed_rhs_kernel[(num_programs,)](
        mat1,
        mat2,
        bias,
        out,
        alpha,
        beta,
        N,
        K,
        mat1.stride(1),  # stride_ak
        mat2.stride(0),  # stride_bk
        mat2.stride(1),  # stride_bn
        bias.stride(1),  # stride_in
        out.stride(1),  # stride_cn
        BLOCK_N=block_n,
        BLOCK_K=block_k,
        EVEN_K=K % block_k == 0,
    )
# @libentry()
# NOTE(review): libentry is disabled for this backend; the reason is not
# recorded in this file.
@triton.autotune(
    configs=runtime.get_tuned_config("addmm"),
    key=["M", "N", "K"],
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,
    b_ptr,
    i_ptr,
    c_ptr,
    alpha,
    beta,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Tiled ``c = alpha * (a @ b) + beta * bias`` for an (M, K) x (K, N) product.

    Program ``(pid_m, pid_n)`` computes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of
    ``c``. Tile sizes come from the autotuner, keyed on (M, N, K).

    The product and the epilogue are both evaluated in float32, and the result
    is cast once to ``c``'s element type. ``bias`` is not read when
    ``beta == 0``.
    """
    pid_m = tle.program_id(0)
    pid_n = tle.program_id(1)

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(
            a_ptrs,
            mask=(offs_am[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_bn[None, :] < N),
            other=0.0,
        )
        accumulator += tl.dot(a, b, allow_tf32=False)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    # beta is not specialized, so this is a runtime branch and `out` must have
    # the same type on both sides. Keep the epilogue in float32 and cast once
    # below. The previous `.to(bias.dtype)` rounded the result through the
    # bias dtype (e.g. an fp16 bias with fp32 output), and it gave the two
    # branches different types whenever bias.dtype != c's dtype.
    if beta == 0:
        out = accumulator * alpha
    else:
        i_ptrs = i_ptr + stride_im * offs_cm[:, None] + stride_in * offs_cn[None, :]
        bias = tl.load(i_ptrs, mask=c_mask, other=0.0).to(tl.float32)
        out = accumulator * alpha + bias * beta
    c = out.to(c_ptr.dtype.element_ty)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    tl.store(c_ptrs, c, mask=c_mask)
def addmm(bias, mat1, mat2, *, beta=1, alpha=1):
    """Return ``beta * bias + alpha * (mat1 @ mat2)`` as a new tensor.

    Args:
        bias: tensor broadcastable to ``(M, N)``.
        mat1: ``(M, K)`` matrix.
        mat2: ``(K, N)`` matrix.
        beta: scale applied to ``bias``.
        alpha: scale applied to the matrix product.

    Returns:
        A new ``(M, N)`` tensor with ``mat1``'s dtype, on ``mat1``'s device.

    Raises:
        AssertionError: if the inner dimensions differ or ``bias`` cannot be
            broadcast to ``(M, N)``.
    """
    logging.debug("GEMS ADDMM")
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape

    # Copy an operand only when it is strided along both dims; row-major and
    # column-major layouts are passed through unchanged.
    if mat1.stride(0) > 1 and mat1.stride(1) > 1:
        mat1 = mat1.contiguous()
    if mat2.stride(0) > 1 and mat2.stride(1) > 1:
        mat2 = mat2.contiguous()
    out_shape = (M, N)
    bias = bias.broadcast_to(out_shape)

    # BF16 masked_load on v8bf16 is not supported in the AArch64 LLVM backend
    # (fatal "Cannot select" error), so any bf16 operand sends the whole
    # computation through fp32, and the result is cast back at the end.
    upcast = torch.bfloat16 in (mat1.dtype, mat2.dtype, bias.dtype)
    if upcast:
        lhs = mat1.to(torch.float32)
        rhs = mat2.to(torch.float32)
        addend = bias.to(torch.float32)
        result_dtype = torch.float32
    else:
        lhs, rhs, addend = mat1, mat2, bias
        result_dtype = mat1.dtype
    result = torch.empty(out_shape, device=mat1.device, dtype=result_dtype)

    if M == 1 and _use_addmm_m1_fastpath_shape(N, K):
        # Single-row problem: use the dedicated vector-matrix kernels.
        if _is_rhs_transposed_layout(rhs) and _use_addmm_m1_transposed_fastpath_shape(
            N, K
        ):
            _launch_addmm_m1_transposed_rhs_kernel(
                lhs, rhs, addend, result, alpha, beta, N, K
            )
        else:
            _launch_addmm_m1_kernel(lhs, rhs, addend, result, alpha, beta, N, K)
    else:

        def grid(meta):
            return (
                triton.cdiv(M, meta["BLOCK_SIZE_M"]),
                triton.cdiv(N, meta["BLOCK_SIZE_N"]),
            )

        addmm_kernel[grid](
            lhs,
            rhs,
            addend,
            result,
            alpha,
            beta,
            M,
            N,
            K,
            lhs.stride(0),
            lhs.stride(1),
            rhs.stride(0),
            rhs.stride(1),
            addend.stride(0),
            addend.stride(1),
            result.stride(0),
            result.stride(1),
        )

    return result.to(mat1.dtype) if upcast else result
def addmm_out(bias, mat1, mat2, *, beta=1, alpha=1, out=None):
    """Compute ``beta * bias + alpha * (mat1 @ mat2)`` and write it into ``out``.

    Args:
        bias: tensor broadcastable to ``out.shape``.
        mat1: ``(M, K)`` matrix.
        mat2: ``(K, N)`` matrix.
        beta: scale applied to ``bias``.
        alpha: scale applied to the matrix product.
        out: optional ``(M, N)`` destination. When omitted, it is allocated
            with ``mat1``'s dtype and device.

    Returns:
        ``out``, holding the result.

    Raises:
        AssertionError: if the inner dimensions differ, ``out`` has the wrong
            shape, or ``bias`` cannot be broadcast to ``out.shape``.
    """
    logging.debug("GEMS ADDMM_OUT")
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    M, K = mat1.shape
    _, N = mat2.shape

    if out is None:
        out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    else:
        assert out.shape == (M, N), "Incompatible output shape"

    assert broadcastable_to(bias.shape, out.shape), "Incompatible input shape"

    # Copy an operand only when it is strided along both dims.
    if mat1.stride(0) > 1 and mat1.stride(1) > 1:
        mat1 = mat1.contiguous()
    if mat2.stride(0) > 1 and mat2.stride(1) > 1:
        mat2 = mat2.contiguous()
    bias = bias.broadcast_to(out.shape)

    # BF16 masked_load on v8bf16 is not supported in the AArch64 LLVM backend
    # (fatal "Cannot select" error), so any bf16 input is upcast to fp32 for
    # BOTH the M == 1 fast path and the generic kernel. Previously the M == 1
    # path only made the output fp32 and fed bf16 inputs straight to the
    # kernel. A bf16 `out` is also computed via an fp32 scratch buffer.
    # NOTE(review): that assumes bf16 stores hit the same lowering gap; the
    # detour is numerically equivalent either way.
    use_fp32 = torch.bfloat16 in (mat1.dtype, mat2.dtype, bias.dtype, out.dtype)
    mat1_kernel = mat1.to(torch.float32) if use_fp32 else mat1
    mat2_kernel = mat2.to(torch.float32) if use_fp32 else mat2
    bias_kernel = bias.to(torch.float32) if use_fp32 else bias
    out_kernel = (
        torch.empty(out.shape, device=out.device, dtype=torch.float32)
        if use_fp32
        else out
    )

    if M == 1 and _use_addmm_m1_fastpath_shape(N, K):
        if _is_rhs_transposed_layout(
            mat2_kernel
        ) and _use_addmm_m1_transposed_fastpath_shape(N, K):
            _launch_addmm_m1_transposed_rhs_kernel(
                mat1_kernel, mat2_kernel, bias_kernel, out_kernel, alpha, beta, N, K
            )
        else:
            _launch_addmm_m1_kernel(
                mat1_kernel, mat2_kernel, bias_kernel, out_kernel, alpha, beta, N, K
            )
    else:
        grid = lambda META: (
            triton.cdiv(M, META["BLOCK_SIZE_M"]),
            triton.cdiv(N, META["BLOCK_SIZE_N"]),
        )
        addmm_kernel[grid](
            mat1_kernel,
            mat2_kernel,
            bias_kernel,
            out_kernel,
            alpha,
            beta,
            M,
            N,
            K,
            mat1_kernel.stride(0),
            mat1_kernel.stride(1),
            mat2_kernel.stride(0),
            mat2_kernel.stride(1),
            bias_kernel.stride(0),
            bias_kernel.stride(1),
            out_kernel.stride(0),
            out_kernel.stride(1),
        )

    if use_fp32:
        # copy_ performs the fp32 -> out.dtype conversion itself.
        out.copy_(out_kernel)
    return out