Coverage for src/flag_gems/runtime/backend/_arm/ops/addmm.py: 0%

181 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.utils import broadcastable_to 

9from flag_gems.utils import triton_lang_extension as tle 

10 

# Launch configurations for ``addmm_m1_kernel`` (M == 1, row-major RHS).
# Rules are scanned top to bottom by ``_select_addmm_m1_config``; the first
# rule with N >= n_min and K >= k_min wins. ``config`` is (BLOCK_N, BLOCK_K).
ADDMM_M1_CONFIG_TABLE = (
    dict(n_min=4096, k_min=0, config=(64, 8)),
    dict(n_min=2048, k_min=0, config=(32, 16)),
    dict(n_min=0, k_min=3072, config=(16, 16)),
    dict(n_min=0, k_min=0, config=(8, 32)),  # catch-all row
)

# Launch configurations for ``addmm_m1_transposed_rhs_kernel`` (M == 1,
# column-major RHS). Same first-match scan and (BLOCK_N, BLOCK_K) layout,
# consumed by ``_select_addmm_m1_transposed_config``.
ADDMM_M1_TRANSPOSED_CONFIG_TABLE = (
    # Tuned on CIX P1 aarch64 (2026-03-04): BK=64 fills a full cache line.
    dict(n_min=65536, k_min=0, config=(2, 64)),
    dict(n_min=2048, k_min=0, config=(4, 64)),
    dict(n_min=0, k_min=2048, config=(4, 64)),
    dict(n_min=0, k_min=0, config=(4, 64)),  # catch-all row
)

25 

26 

def _select_addmm_m1_config(N, K):
    """Return the (BLOCK_N, BLOCK_K) launch config for ``addmm_m1_kernel``.

    Scans ``ADDMM_M1_CONFIG_TABLE`` in order and returns the ``config`` of the
    first rule whose thresholds are met (``N >= n_min`` and ``K >= k_min``;
    a missing threshold counts as 0). Returns ``(8, 32)`` if nothing matches.
    """
    matching_configs = (
        entry["config"]
        for entry in ADDMM_M1_CONFIG_TABLE
        if N >= entry.get("n_min", 0) and K >= entry.get("k_min", 0)
    )
    return next(matching_configs, (8, 32))

32 

33 

def _select_addmm_m1_transposed_config(N, K):
    """Return the (BLOCK_N, BLOCK_K) config for ``addmm_m1_transposed_rhs_kernel``.

    Scans ``ADDMM_M1_TRANSPOSED_CONFIG_TABLE`` in order and returns the
    ``config`` of the first rule with ``N >= n_min`` and ``K >= k_min``
    (a missing threshold counts as 0).
    """
    for rule in ADDMM_M1_TRANSPOSED_CONFIG_TABLE:
        if N >= rule.get("n_min", 0) and K >= rule.get("k_min", 0):
            return rule["config"]
    # Only reachable if the table ever loses its catch-all row. Fall back to
    # this table's tuned default (BK=64) instead of the unrelated (8, 32)
    # default used by the non-transposed kernel.
    return 4, 64

39 

40 

def _is_rhs_transposed_layout(rhs):
    """Report whether a 2-D ``rhs`` is stored column-major.

    True when dimension 0 is unit-stride and dimension 1's stride spans at
    least ``shape[0]`` elements, i.e. the tensor looks like the transpose of
    a row-major matrix. Inputs that are not 2-D always yield False.
    """
    if rhs.ndim == 2:
        row_stride, col_stride = rhs.stride()
        return row_stride == 1 and col_stride >= rhs.shape[0]
    return False

45 

46 

def _use_addmm_m1_transposed_fastpath_shape(N, K):
    """Decide whether (N, K) is large enough for the transposed-RHS M == 1 kernel.

    Both dimensions must be at least 256: tiny matrices are kept off this
    path to avoid unstable LLVM lowering on the ARM CPU backend.
    """
    smallest_supported_dim = 256
    return min(N, K) >= smallest_supported_dim

50 

51 

def _use_addmm_m1_fastpath_shape(N, K):
    """Return True when both N and K are at least 256 (M == 1 fast-path gate)."""
    return all(dim >= 256 for dim in (N, K))

54 

55 

@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_m1_kernel(
    a_ptr,
    b_ptr,
    i_ptr,
    c_ptr,
    alpha,
    beta,
    N,
    K,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_in,
    stride_cn,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
):
    """Compute one BLOCK_N-wide slice of ``alpha * (a @ b) + beta * bias`` for M == 1.

    Program ``pid`` owns output columns ``[pid * BLOCK_N, (pid + 1) * BLOCK_N)``
    of the single output row. It walks K in BLOCK_K steps over a
    (BLOCK_K, BLOCK_N) tile of b and accumulates in fp32.

    ``EVEN_K`` (K divisible by BLOCK_K) only removes the k-tail mask. The
    column mask ``rn < N`` is always applied, because N need not be a
    multiple of BLOCK_N and the last program's tile can run past the end
    of b. When ``beta == 0`` the bias is never loaded.
    """
    pid_n = tle.program_id(0)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)

    a_ptrs = a_ptr + rk * stride_ak
    b_ptrs = b_ptr + rk[:, None] * stride_bk + rn[None, :] * stride_bn
    acc = tl.zeros((BLOCK_N,), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_K)):
        if EVEN_K:
            a = tl.load(a_ptrs)
            # No k tail here, but columns >= N (last partial tile) must still
            # be masked to avoid reading out of bounds.
            b = tl.load(b_ptrs, mask=rn[None, :] < N, other=0.0)
        else:
            k_remaining = K - k * BLOCK_K
            a = tl.load(a_ptrs, mask=rk < k_remaining, other=0.0)
            b = tl.load(
                b_ptrs,
                mask=(rk[:, None] < k_remaining) & (rn[None, :] < N),
                other=0.0,
            )

        a_fp = a.to(tl.float32)
        b_fp = b.to(tl.float32)
        # Broadcast a down the columns, multiply, and reduce over k (axis 0).
        acc += tl.sum(b_fp * a_fp[:, None], axis=0)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    if beta == 0:
        # Bias does not contribute; skip loading it.
        out = acc * alpha
    else:
        bias_ptrs = i_ptr + rn * stride_in
        bias = tl.load(bias_ptrs, mask=rn < N, other=0.0).to(tl.float32)
        out = acc * alpha + bias * beta
    c_ptrs = c_ptr + rn * stride_cn
    tl.store(c_ptrs, out.to(c_ptr.dtype.element_ty), mask=rn < N)

110 

111 

@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_m1_transposed_rhs_kernel(
    a_ptr,
    b_ptr,
    i_ptr,
    c_ptr,
    alpha,
    beta,
    N,
    K,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_in,
    stride_cn,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
):
    """Compute one BLOCK_N-wide slice of ``alpha * (a @ b) + beta * bias`` for M == 1.

    Same math as ``addmm_m1_kernel``, but each b tile is addressed as
    (BLOCK_N, BLOCK_K): n on the first axis, k on the second. Partial dot
    products are reduced along axis 1. The Python wrappers only route here
    when ``_is_rhs_transposed_layout(mat2)`` holds, i.e. ``mat2.stride(0) == 1``.
    That stride is passed as ``stride_bk``, so each tile row is unit-stride
    along k.

    Program ``pid`` owns output columns ``[pid * BLOCK_N, (pid + 1) * BLOCK_N)``.
    Products are accumulated in fp32. ``EVEN_K`` (K divisible by BLOCK_K)
    drops the k-tail mask; the column mask ``rn < N`` is always applied.
    When ``beta == 0`` the bias is never loaded.
    """
    pid_n = tle.program_id(0)
    # Output columns owned by this program, and k offsets within one step.
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)

    a_ptrs = a_ptr + rk * stride_ak
    # (BLOCK_N, BLOCK_K) tile: row i walks column rn[i] of b along k.
    bt_ptrs = b_ptr + rn[:, None] * stride_bn + rk[None, :] * stride_bk
    acc = tl.zeros((BLOCK_N,), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_K)):
        if EVEN_K:
            # No k tail; only columns >= N (last partial tile) need masking.
            a = tl.load(a_ptrs)
            bt = tl.load(bt_ptrs, mask=rn[:, None] < N, other=0.0)
        else:
            k_remaining = K - k * BLOCK_K
            a = tl.load(a_ptrs, mask=rk < k_remaining, other=0.0)
            bt = tl.load(
                bt_ptrs,
                mask=(rn[:, None] < N) & (rk[None, :] < k_remaining),
                other=0.0,
            )

        a_fp = a.to(tl.float32)
        bt_fp = bt.to(tl.float32)
        # Broadcast a across rows, multiply, and reduce over k (axis 1).
        acc += tl.sum(bt_fp * a_fp[None, :], axis=1)
        a_ptrs += BLOCK_K * stride_ak
        bt_ptrs += BLOCK_K * stride_bk

    if beta == 0:
        # Bias does not contribute; skip loading it.
        out = acc * alpha
    else:
        bias_ptrs = i_ptr + rn * stride_in
        bias = tl.load(bias_ptrs, mask=rn < N, other=0.0).to(tl.float32)
        out = acc * alpha + bias * beta
    c_ptrs = c_ptr + rn * stride_cn
    # Cast from the fp32 accumulator to the output dtype on store.
    tl.store(c_ptrs, out.to(c_ptr.dtype.element_ty), mask=rn < N)

166 

167 

def _launch_addmm_m1_kernel(mat1, mat2, bias, out, alpha, beta, N, K):
    """Launch ``addmm_m1_kernel`` over ceil(N / BLOCK_N) programs.

    Block sizes come from ``_select_addmm_m1_config``. ``EVEN_K`` is set when
    K is a multiple of BLOCK_K so the kernel can skip the k-tail mask.
    Only the column strides of ``mat1``, ``bias`` and ``out`` are forwarded,
    since the kernel handles a single output row.
    """
    block_n, block_k = _select_addmm_m1_config(N, K)
    num_programs = triton.cdiv(N, block_n)

    stride_ak = mat1.stride(1)
    stride_bk, stride_bn = mat2.stride(0), mat2.stride(1)
    stride_in = bias.stride(1)
    stride_cn = out.stride(1)

    addmm_m1_kernel[lambda _meta: (num_programs,)](
        mat1,
        mat2,
        bias,
        out,
        alpha,
        beta,
        N,
        K,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_in,
        stride_cn,
        BLOCK_N=block_n,
        BLOCK_K=block_k,
        EVEN_K=K % block_k == 0,
    )

189 

190 

def _launch_addmm_m1_transposed_rhs_kernel(mat1, mat2, bias, out, alpha, beta, N, K):
    """Launch ``addmm_m1_transposed_rhs_kernel`` over ceil(N / BLOCK_N) programs.

    Block sizes come from ``_select_addmm_m1_transposed_config``. ``EVEN_K``
    is set when K is a multiple of BLOCK_K. As with the non-transposed
    launcher, only the column strides of ``mat1``, ``bias`` and ``out`` are
    forwarded.
    """
    block_n, block_k = _select_addmm_m1_transposed_config(N, K)
    num_programs = triton.cdiv(N, block_n)

    stride_ak = mat1.stride(1)
    stride_bk, stride_bn = mat2.stride(0), mat2.stride(1)
    stride_in = bias.stride(1)
    stride_cn = out.stride(1)

    addmm_m1_transposed_rhs_kernel[lambda _meta: (num_programs,)](
        mat1,
        mat2,
        bias,
        out,
        alpha,
        beta,
        N,
        K,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_in,
        stride_cn,
        BLOCK_N=block_n,
        BLOCK_K=block_k,
        EVEN_K=K % block_k == 0,
    )

212 

213 

# NOTE(review): @libentry() is left disabled on this backend; the reason is
# not recorded here -- confirm before re-enabling.
# @libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("addmm"),
    key=["M", "N", "K"],
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,
    b_ptr,
    i_ptr,
    c_ptr,
    alpha,
    beta,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Tiled ``c = alpha * (a @ b) + beta * bias`` for an (M, K) x (K, N) product.

    Program (pid_m, pid_n) computes one BLOCK_SIZE_M x BLOCK_SIZE_N output
    tile. It accumulates ``tl.dot`` results in fp32 over K in BLOCK_SIZE_K
    steps, masking the M/N edges and the K tail. Block sizes are chosen by
    the autotuner keyed on (M, N, K).

    The epilogue stays in fp32 in both branches and converts once to the
    output element type on store. Keeping ``c`` the same type in both
    branches of the runtime ``if beta == 0`` is required by Triton; it also
    avoids narrowing the result to the bias dtype when bias and output
    dtypes differ. When ``beta == 0`` the bias is never loaded.
    """
    pid_m = tle.program_id(0)
    pid_n = tle.program_id(1)

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(
            a_ptrs,
            mask=(offs_am[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_bn[None, :] < N),
            other=0.0,
        )
        accumulator += tl.dot(a, b, allow_tf32=False)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    if beta == 0:
        c = accumulator * alpha
    else:
        i_ptrs = i_ptr + stride_im * offs_cm[:, None] + stride_in * offs_cn[None, :]
        # Upcast so the bias term is computed in fp32 whatever its dtype.
        bias = tl.load(i_ptrs, mask=c_mask, other=0.0).to(tl.float32)
        c = accumulator * alpha + bias * beta
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    # Convert to the *output* dtype (not the bias dtype) exactly once.
    tl.store(c_ptrs, c.to(c_ptr.dtype.element_ty), mask=c_mask)

278 

279 

def addmm(bias, mat1, mat2, *, beta=1, alpha=1):
    """Return ``beta * bias + alpha * (mat1 @ mat2)`` computed with Triton.

    ``mat1`` is (M, K), ``mat2`` is (K, N), and ``bias`` must broadcast to
    (M, N). Operands strided in both dimensions are made contiguous first.

    If any of ``mat1``/``mat2``/``bias`` is bfloat16, all three are upcast
    to float32. The kernel then writes an fp32 buffer that is cast back to
    ``mat1.dtype``. This is needed because bf16 masked loads (v8bf16) hit a
    fatal "Cannot select" error in the AArch64 LLVM backend.

    Single-row problems with N >= 256 and K >= 256 use the dedicated M == 1
    kernels, picking the transposed-RHS variant for a column-major ``mat2``.
    Everything else runs through the autotuned ``addmm_kernel``.
    """
    logging.debug("GEMS ADDMM")
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape

    # Densify operands that are strided along both dimensions.
    if min(mat1.stride()) > 1:
        mat1 = mat1.contiguous()
    if min(mat2.stride()) > 1:
        mat2 = mat2.contiguous()
    out_shape = (M, N)
    bias = bias.broadcast_to(out_shape)

    # Both kernel paths share the same bf16 -> fp32 workaround.
    upcast = torch.bfloat16 in (mat1.dtype, mat2.dtype, bias.dtype)
    if upcast:
        mat1_kernel = mat1.to(torch.float32)
        mat2_kernel = mat2.to(torch.float32)
        bias_kernel = bias.to(torch.float32)
        kernel_dtype = torch.float32
    else:
        mat1_kernel, mat2_kernel, bias_kernel = mat1, mat2, bias
        kernel_dtype = mat1.dtype
    out_kernel = torch.empty(out_shape, device=mat1.device, dtype=kernel_dtype)

    if M == 1 and _use_addmm_m1_fastpath_shape(N, K):
        use_transposed = _is_rhs_transposed_layout(
            mat2_kernel
        ) and _use_addmm_m1_transposed_fastpath_shape(N, K)
        launch = (
            _launch_addmm_m1_transposed_rhs_kernel
            if use_transposed
            else _launch_addmm_m1_kernel
        )
        launch(mat1_kernel, mat2_kernel, bias_kernel, out_kernel, alpha, beta, N, K)
    else:

        def grid(meta):
            return (
                triton.cdiv(M, meta["BLOCK_SIZE_M"]),
                triton.cdiv(N, meta["BLOCK_SIZE_N"]),
            )

        addmm_kernel[grid](
            mat1_kernel,
            mat2_kernel,
            bias_kernel,
            out_kernel,
            alpha,
            beta,
            M,
            N,
            K,
            mat1_kernel.stride(0),
            mat1_kernel.stride(1),
            mat2_kernel.stride(0),
            mat2_kernel.stride(1),
            bias_kernel.stride(0),
            bias_kernel.stride(1),
            out_kernel.stride(0),
            out_kernel.stride(1),
        )

    return out_kernel.to(mat1.dtype) if upcast else out_kernel

366 

367 

def addmm_out(bias, mat1, mat2, *, beta=1, alpha=1, out=None):
    """Write ``beta * bias + alpha * (mat1 @ mat2)`` into ``out`` and return it.

    ``mat1`` is (M, K) and ``mat2`` is (K, N). ``out`` is allocated as
    (M, N) in ``mat1.dtype`` when omitted; otherwise it must already be
    (M, N). ``bias`` must broadcast to ``out.shape``.

    If any of ``mat1``/``mat2``/``bias`` is bfloat16, every kernel path
    (including the M == 1 fast path) runs on float32 copies. The results go
    to an fp32 scratch buffer, which is then copied into ``out``. This is
    needed because bf16 masked loads (v8bf16) hit a fatal "Cannot select"
    error in the AArch64 LLVM backend. ``addmm`` applies the same workaround.

    Raises:
        AssertionError: on incompatible operand, bias, or output shapes.
    """
    logging.debug("GEMS ADDMM_OUT")
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    M, K = mat1.shape
    _, N = mat2.shape

    if out is None:
        out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    else:
        assert out.shape == (M, N), "Incompatible output shape"

    assert broadcastable_to(bias.shape, out.shape), "Incompatible input shape"

    if mat1.stride(0) > 1 and mat1.stride(1) > 1:
        mat1 = mat1.contiguous()
    if mat2.stride(0) > 1 and mat2.stride(1) > 1:
        mat2 = mat2.contiguous()
    bias = bias.broadcast_to(out.shape)

    # Upcast bf16 operands for *both* paths. Previously the M == 1 path only
    # allocated an fp32 output and still fed bf16 inputs to the kernel.
    use_fp32 = (
        mat1.dtype is torch.bfloat16
        or mat2.dtype is torch.bfloat16
        or bias.dtype is torch.bfloat16
    )
    mat1_kernel = mat1.to(torch.float32) if use_fp32 else mat1
    # Tensor.to preserves strides here, so the transposed-layout check below
    # still sees the original layout of mat2.
    mat2_kernel = mat2.to(torch.float32) if use_fp32 else mat2
    bias_kernel = bias.to(torch.float32) if use_fp32 else bias
    out_kernel = (
        torch.empty(out.shape, device=out.device, dtype=torch.float32)
        if use_fp32
        else out
    )

    if M == 1 and _use_addmm_m1_fastpath_shape(N, K):
        if _is_rhs_transposed_layout(
            mat2_kernel
        ) and _use_addmm_m1_transposed_fastpath_shape(N, K):
            _launch_addmm_m1_transposed_rhs_kernel(
                mat1_kernel, mat2_kernel, bias_kernel, out_kernel, alpha, beta, N, K
            )
        else:
            _launch_addmm_m1_kernel(
                mat1_kernel, mat2_kernel, bias_kernel, out_kernel, alpha, beta, N, K
            )
    else:
        grid = lambda META: (
            triton.cdiv(M, META["BLOCK_SIZE_M"]),
            triton.cdiv(N, META["BLOCK_SIZE_N"]),
        )
        addmm_kernel[grid](
            mat1_kernel,
            mat2_kernel,
            bias_kernel,
            out_kernel,
            alpha,
            beta,
            M,
            N,
            K,
            mat1_kernel.stride(0),
            mat1_kernel.stride(1),
            mat2_kernel.stride(0),
            mat2_kernel.stride(1),
            bias_kernel.stride(0),
            bias_kernel.stride(1),
            out_kernel.stride(0),
            out_kernel.stride(1),
        )

    if use_fp32:
        out.copy_(out_kernel.to(out.dtype))
    return out