Coverage for src/flag_gems/runtime/backend/_arm/ops/mm.py: 0%

329 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2import os 

3from collections import OrderedDict 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.utils import triton_lang_extension as tle 

10 

# Ordered rules for the generic tiled kernel. _select_mm_config scans them top
# to bottom and uses the "config" (BLOCK_M, BLOCK_N, BLOCK_K) of the first rule
# whose bounds admit (M, N, K): M <= m_max (None = unbounded), N >= n_min,
# K >= k_min.
# NOTE(review): mm()/mm_out() route M > 1 bf16 inputs to ATen addmm, so bf16
# data only reaches these configs after being upcast to fp32 (M <= 1 paths);
# M > 1 traffic here is fp16/fp32.
MM_GENERIC_CONFIG_TABLE = (
    # Decode-like long vocab projection prefers narrower N tiles.
    dict(m_max=1, n_min=65536, k_min=0, config=(4, 16, 8)),
    # Batched decode / small-M prefill: BM=4, BN=8 measured best for M=2-4
    # (1.08-1.19x vs native bf16 on ARM).
    dict(m_max=4, n_min=2048, k_min=0, config=(4, 8, 8)),
    # Prefill with large K: a wider BLOCK_K cuts the number of K iterations.
    dict(m_max=8, n_min=0, k_min=2048, config=(8, 8, 32)),
    dict(m_max=8, n_min=2048, k_min=0, config=(8, 8, 8)),
    dict(m_max=8, n_min=0, k_min=0, config=(8, 8, 8)),
    # M > 8: (64, 32, 32) benchmarked best on CIX P1 (2026-03-07). Triton bf16
    # prefill stays ~3x slower than ATen BFMMLA because Triton does not emit
    # BFMMLA for tl.dot(bf16, bf16); larger tiles only reduce overhead.
    dict(m_max=None, n_min=0, k_min=0, config=(64, 32, 32)),
)

27 

# Rules for the single-row kernel (mm_m1_kernel). _select_mm_m1_config returns
# the "config" (BLOCK_N, BLOCK_K) of the first rule with N >= n_min and
# K >= k_min. A config of None, or no matching rule at all (N < 256), makes
# the launcher decline so the caller falls back to another path.
MM_M1_CONFIG_TABLE = (
    # Keep very large vocab projections off this kernel.
    dict(n_min=65536, k_min=0, config=None),
    # Qwen3-4B gate/up (N=9728, K=2560): BN=64 BK=16 measured ~9% faster.
    # The K >= 2560 threshold avoids regressing 1.7B (K=2048) shapes.
    dict(n_min=4096, k_min=2560, config=(64, 16)),
    dict(n_min=2048, k_min=0, config=(32, 8)),
    # Narrower outputs (256 <= N < 2048): tile shape chosen by K.
    dict(n_min=256, k_min=3072, config=(128, 8)),
    dict(n_min=256, k_min=2048, config=(32, 16)),
    dict(n_min=256, k_min=0, config=(64, 8)),
)

41 

# Rules for the single-row transposed-RHS kernel. _select_mm_m1_transposed_config
# returns the "config" (BLOCK_N, BLOCK_K) of the first rule with N >= n_min,
# K >= k_min and, when "k_max" is present, K <= k_max.
MM_M1_TRANSPOSED_CONFIG_TABLE = (
    # Large vocab projection (lm_head N~=152k): BN=2 gives fine-grained OMP
    # load balancing. Tuned on CIX P1 aarch64 (2026-03-04): 30ms vs ATen 65ms
    # (2.17x faster).
    # NOTE(review): BK=64 was justified as filling a 64-byte cache line, but
    # 64 bf16 elements span 128 bytes -- confirm the rationale.
    dict(n_min=65536, k_min=0, k_max=1536, config=(2, 64)),
    dict(n_min=2048, k_min=0, k_max=1536, config=(4, 64)),
    dict(n_min=0, k_min=2048, config=(4, 64)),
    # Catch-all: matches every shape.
    dict(n_min=0, k_min=0, config=(4, 64)),
)

51 

# Module-level LRU caches (least-recently-used entry first) and the running
# byte total of the tensors each one holds. The prepack cache stores
# contiguous copies of RHS operands; the fp32 cast cache stores fp32 copies
# of bf16 tensors.
_MM_PREPACK_CACHE, _MM_PREPACK_CACHE_BYTES = OrderedDict(), 0
_MM_FP32_CAST_CACHE, _MM_FP32_CAST_CACHE_BYTES = OrderedDict(), 0

56 

57 

@triton.jit
def mm_kernel(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    dot_out_dtype: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
):
    # Generic tiled matmul: C[M, N] = A[M, K] @ B[K, N], every operand
    # addressed through explicit element strides.
    #
    # Grid axis 0 enumerates BLOCK_M x BLOCK_N output tiles (remapped below in
    # bands of GROUP_M tile-rows). Axis 1 (pid_z) picks which interleaved
    # BLOCK_K slice of K this program reduces when SPLIT_K > 1.
    #
    # dot_out_dtype: accumulator dtype passed to tl.dot.
    # SPLIT_K: when > 1, partial tiles are combined with tl.atomic_add, which
    #   adds into whatever C already holds. mm()/mm_out() in this file always
    #   launch with SPLIT_K=1 and a 1-wide axis 1.
    # EVEN_K: K splits into whole BLOCK_K * SPLIT_K steps, so K masking is
    #   skipped.
    # matrix multiplication
    pid = tle.program_id(0)
    pid_z = tle.program_id(1)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    # Consecutive pids walk down a band of GROUP_M tile-rows before moving to
    # the next tile-column; the last band shrinks when grid_m % GROUP_M != 0.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Wrap row/column indices back into range so even the unmasked loads stay
    # in bounds; the extra lanes are discarded by the store mask below.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    # This program's first K slice; each iteration advances by BLOCK_K * SPLIT_K.
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            # Zero fill (in C's element type) for K lanes past the end.
            _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
            a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
        # Mixed input dtypes: bring both operands to C's element type so
        # tl.dot sees matching dtypes.
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=False)
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk
    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    # Only in-range rows/columns are written back.
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        tl.atomic_add(C, acc, mask=mask)

127 

128 

@triton.jit
def mm_m1_kernel(
    A,
    B,
    C,
    N,
    K,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cn,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
):
    # Single-row GEMV: C[0, :N] = A[0, :K] @ B[:K, :N].
    # Each program owns BLOCK_N output columns and walks K in BLOCK_K steps,
    # reducing a [BLOCK_K, BLOCK_N] tile of B against a BLOCK_K slice of A.
    # EVEN_K: K is a multiple of BLOCK_K, so only the N tail needs masking.
    pid_n = tle.program_id(0)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    # The last program covers columns past N when N % BLOCK_N != 0; those
    # lanes must never be dereferenced or stored.
    n_mask = rn < N

    a_ptr = A + rk * stride_ak
    b_ptr = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn
    acc = tl.zeros((BLOCK_N,), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_K)):
        if EVEN_K:
            a = tl.load(a_ptr)
            # K needs no mask here, but the N tail still does: an unmasked load
            # would read past the end of B for the last column block.
            b = tl.load(b_ptr, mask=n_mask[None, :], other=0.0)
        else:
            k_remaining = K - k * BLOCK_K
            a = tl.load(a_ptr, mask=rk < k_remaining, other=0.0)
            b = tl.load(
                b_ptr,
                mask=(rk[:, None] < k_remaining) & n_mask[None, :],
                other=0.0,
            )

        # Multiply in fp32, as the transposed-RHS kernel does. Otherwise
        # fp16/bf16 products are rounded (and fp16 ones may overflow) before
        # they reach the fp32 accumulator. This also handles mixed a/b dtypes.
        a_fp = a.to(tl.float32)
        b_fp = b.to(tl.float32)
        acc += tl.sum(b_fp * a_fp[:, None], axis=0)
        a_ptr += BLOCK_K * stride_ak
        b_ptr += BLOCK_K * stride_bk

    c_ptr = C + rn * stride_cn
    tl.store(c_ptr, acc.to(C.dtype.element_ty), mask=n_mask)

175 

176 

@triton.jit
def mm_m1_transposed_rhs_kernel(
    A,
    B,
    C,
    N,
    K,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cn,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
):
    # Single-row GEMV tuned for a K-contiguous RHS (e.g. a weight.t() view):
    # C[0, :N] = A[0, :K] @ B[:K, :N].
    # Each program owns BLOCK_N output columns. It loads B as [BLOCK_N, BLOCK_K]
    # tiles and reduces over K in BLOCK_K steps, entirely in fp32.
    # EVEN_K: K is a multiple of BLOCK_K, so only the N tail needs masking.
    pid_n = tle.program_id(0)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)

    a_ptr = A + rk * stride_ak
    # For transposed RHS views (stride_bk == 1), load [BLOCK_N, BLOCK_K]
    # so the K dimension is contiguous in memory.
    bt_ptr = B + rn[:, None] * stride_bn + rk[None, :] * stride_bk
    acc = tl.zeros((BLOCK_N,), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_K)):
        if EVEN_K:
            a = tl.load(a_ptr)
            # No K mask needed, but the last program's columns may run past N.
            bt = tl.load(bt_ptr, mask=rn[:, None] < N, other=0.0)
        else:
            k_remaining = K - k * BLOCK_K
            a = tl.load(a_ptr, mask=rk < k_remaining, other=0.0)
            bt = tl.load(
                bt_ptr,
                mask=(rn[:, None] < N) & (rk[None, :] < k_remaining),
                other=0.0,
            )

        # Upcast before multiplying so both the products and the reduction
        # happen in fp32.
        a_fp = a.to(tl.float32)
        bt_fp = bt.to(tl.float32)
        # One dot product per output column: reduce the tile along its K axis.
        acc += tl.sum(bt_fp * a_fp[None, :], axis=1)
        a_ptr += BLOCK_K * stride_ak
        bt_ptr += BLOCK_K * stride_bk

    c_ptr = C + rn * stride_cn
    tl.store(c_ptr, acc.to(C.dtype.element_ty), mask=rn < N)

223 

224 

# Floating-point dtypes ranked from lowest to highest; get_higher_dtype returns
# whichever of its two arguments appears later in this list.
_ordered_datatypes = [
    torch.float16,
    torch.bfloat16,
    torch.float32,
]

226 

227 

def get_higher_dtype(a, b):
    """Return whichever of dtypes ``a`` and ``b`` ranks higher in ``_ordered_datatypes``.

    Identical dtypes are returned immediately, even if they are not ranked.
    Otherwise both must appear in ``_ordered_datatypes`` (checked by assert).
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    # The first ranked dtype that matches either argument is the lower one, so
    # the other argument is the higher one.
    lower = next((d for d in _ordered_datatypes if d is a or d is b), None)
    if lower is None:
        return None
    return b if lower is a else a

240 

241 

def _match_mnk_rule(M, N, K, rule):
    """Return True if shape (M, N, K) satisfies the bounds in ``rule``.

    ``rule`` may contain ``m_max`` (inclusive upper bound on M; missing or
    None means unbounded), ``n_min`` and ``k_min`` (inclusive lower bounds on
    N and K; missing means 0).
    """
    m_cap = rule.get("m_max")
    m_ok = m_cap is None or M <= m_cap
    return m_ok and N >= rule.get("n_min", 0) and K >= rule.get("k_min", 0)

253 

254 

def _select_mm_config(M, N, K):
    """Pick (BLOCK_M, BLOCK_N, BLOCK_K) for the generic kernel.

    Returns the ``config`` of the first rule in ``MM_GENERIC_CONFIG_TABLE``
    that admits (M, N, K), or ``(8, 8, 8)`` if none does.
    """
    candidates = (
        rule["config"]
        for rule in MM_GENERIC_CONFIG_TABLE
        if _match_mnk_rule(M, N, K, rule)
    )
    return next(candidates, (8, 8, 8))

260 

261 

def _select_mm_m1_config(N, K):
    """Pick (BLOCK_N, BLOCK_K) for the single-row kernel, or None to skip it.

    Returns the ``config`` of the first rule in ``MM_M1_CONFIG_TABLE`` with
    ``N >= n_min`` and ``K >= k_min``. That config may itself be None.
    Returns None when no rule matches (e.g. N < 256).
    """
    return next(
        (
            rule["config"]
            for rule in MM_M1_CONFIG_TABLE
            if N >= rule.get("n_min", 0) and K >= rule.get("k_min", 0)
        ),
        None,
    )

268 

269 

def _select_mm_m1_transposed_config(N, K):
    """Pick (BLOCK_N, BLOCK_K) for the single-row transposed-RHS kernel.

    Returns the ``config`` of the first rule in
    ``MM_M1_TRANSPOSED_CONFIG_TABLE`` with ``N >= n_min``, ``K >= k_min`` and
    (if the rule has ``k_max``) ``K <= k_max``. Falls back to ``(64, 8)``.
    """
    for rule in MM_M1_TRANSPOSED_CONFIG_TABLE:
        if N < rule.get("n_min", 0) or K < rule.get("k_min", 0):
            continue
        k_upper = rule.get("k_max")
        if k_upper is not None and K > k_upper:
            continue
        return rule["config"]
    return 64, 8

280 

281 

def _m1_fastpath_enabled():
    """Return True unless FLAGGEMS_ARM_M1_FASTPATH is set to a non-truthy value.

    Accepted truthy values (case-insensitive): "1", "true", "on". Default "1".
    """
    setting = os.getenv("FLAGGEMS_ARM_M1_FASTPATH", "1")
    return setting.lower() in {"1", "true", "on"}

284 

285 

def _m1_transposed_fastpath_enabled():
    """Return True unless FLAGGEMS_ARM_M1_TRANSPOSED_FASTPATH disables it.

    Accepted truthy values (case-insensitive): "1", "true", "on". Default "1".
    """
    setting = os.getenv("FLAGGEMS_ARM_M1_TRANSPOSED_FASTPATH", "1")
    return setting.lower() in {"1", "true", "on"}

292 

293 

def _use_m1_transposed_fastpath_shape(N, K):
    """Return True when both N and K are at least 256.

    Tiny matrices can hit unstable LLVM lowering on the ARM CPU backend for
    the specialised transposed kernel, so they stay on the generic path.
    """
    return min(N, K) >= 256

298 

299 

def _mm_prepack_enabled():
    """Return True only if FLAGGEMS_ARM_MM_PREPACK is "1", "true" or "on".

    The comparison is case-insensitive and the default is "0" (disabled).
    """
    setting = os.getenv("FLAGGEMS_ARM_MM_PREPACK", "0")
    return setting.lower() in {"1", "true", "on"}

302 

303 

def _get_env_int(name, default):
    """Read environment variable ``name`` as an int.

    Returns ``default`` when the variable is unset or cannot be parsed as an
    integer.
    """
    try:
        raw = os.getenv(name, str(default))
        value = int(raw)
    except (ValueError, TypeError):
        value = default
    return value

309 

310 

def _tensor_nbytes(t):
    """Return the logical size of ``t`` in bytes (elements x bytes per element)."""
    bytes_per_element = int(t.element_size())
    return bytes_per_element * int(t.numel())

313 

314 

def _is_rhs_transposed_layout(rhs):
    """Return True if 2-D ``rhs`` is laid out like a ``weight.t()`` view.

    That means a unit stride along dim 0 and a dim-1 stride of at least
    ``rhs.shape[0]``, so K-direction elements are contiguous.
    """
    if rhs.ndim == 2:
        unit_row_stride = rhs.stride(0) == 1
        return unit_row_stride and rhs.stride(1) >= rhs.shape[0]
    return False

320 

321 

def _prepack_key(rhs):
    """Build the cache key that identifies a prepacked copy of ``rhs``.

    The key combines the data pointer, shape, strides, version counter, dtype
    and device. The version counter is incremented by in-place modifications,
    so an updated weight no longer matches a stale packed copy. This mirrors
    ``_fp32_cast_key``.

    Inference-mode tensors have no version counter; for them the version slot
    is None, which reduces to pointer/shape/stride matching.
    """
    try:
        version = int(rhs._version)
    except (AttributeError, RuntimeError):
        # Reading ``_version`` raises RuntimeError for inference tensors.
        version = None
    return (
        int(rhs.data_ptr()),
        tuple(rhs.shape),
        tuple(rhs.stride()),
        version,
        str(rhs.dtype),
        str(rhs.device),
    )

330 

331 

def _maybe_get_prepacked_rhs(rhs):
    """Return a cached contiguous copy of ``rhs``, or None if prepacking is off or ineligible.

    Prepacking is opt-in. It runs only when FLAGGEMS_ARM_MM_PREPACK is enabled
    and FLAGGEMS_ARM_MM_PREPACK_MAX_BYTES (the total cache budget, default 0)
    is positive. A tensor is not packed if it exceeds either of:

    - FLAGGEMS_ARM_MM_PREPACK_MAX_TENSOR_BYTES (default 8 MiB; 0 disables
      this per-tensor cap);
    - the total budget.

    Copies live in the LRU ``_MM_PREPACK_CACHE``, keyed by ``_prepack_key(rhs)``.
    Least-recently-used entries are evicted until the new copy fits the byte
    budget and the cache holds fewer than FLAGGEMS_ARM_MM_PREPACK_MAX_ENTRIES
    (default 32) entries. Hits are matched purely on ``_prepack_key``.

    Returns:
        The cached contiguous tensor, shared across calls and so to be treated
        as read-only, or None.
    """
    global _MM_PREPACK_CACHE_BYTES

    if not _mm_prepack_enabled():
        return None
    budget = max(_get_env_int("FLAGGEMS_ARM_MM_PREPACK_MAX_BYTES", 0), 0)
    if budget <= 0:
        return None

    per_tensor_cap = max(
        _get_env_int("FLAGGEMS_ARM_MM_PREPACK_MAX_TENSOR_BYTES", 8 * 1024 * 1024), 0
    )
    nbytes = _tensor_nbytes(rhs)
    exceeds_cap = per_tensor_cap > 0 and nbytes > per_tensor_cap
    if exceeds_cap or nbytes > budget:
        return None

    cache_key = _prepack_key(rhs)
    hit = _MM_PREPACK_CACHE.get(cache_key)
    if hit is not None:
        # Refresh recency so this entry is evicted last.
        _MM_PREPACK_CACHE.move_to_end(cache_key)
        return hit

    fresh = rhs.contiguous()
    fresh_bytes = _tensor_nbytes(fresh)
    entry_limit = max(_get_env_int("FLAGGEMS_ARM_MM_PREPACK_MAX_ENTRIES", 32), 1)

    # Evict from the least-recently-used end until the new copy fits.
    while _MM_PREPACK_CACHE:
        over_budget = _MM_PREPACK_CACHE_BYTES + fresh_bytes > budget
        at_capacity = len(_MM_PREPACK_CACHE) >= entry_limit
        if not (over_budget or at_capacity):
            break
        _, oldest = _MM_PREPACK_CACHE.popitem(last=False)
        _MM_PREPACK_CACHE_BYTES -= _tensor_nbytes(oldest)

    if fresh_bytes > budget:
        return None

    _MM_PREPACK_CACHE[cache_key] = fresh
    _MM_PREPACK_CACHE_BYTES += fresh_bytes
    return fresh

372 

373 

def _mm_fp32_cast_cache_enabled():
    """Return True unless FLAGGEMS_ARM_MM_FP32_CAST_CACHE disables the fp32 cast cache.

    Accepted truthy values (case-insensitive): "1", "true", "on". Default "1".
    """
    setting = os.getenv("FLAGGEMS_ARM_MM_FP32_CAST_CACHE", "1")
    return setting.lower() in {"1", "true", "on"}

380 

381 

def _fp32_cast_key(t):
    """Build the fp32-cast cache key for ``t``.

    The key is (data_ptr, shape, stride, version counter, dtype, device).
    The version counter changes on in-place modification, so a rewritten
    tensor misses the cache.
    """
    version = int(getattr(t, "_version", 0))
    location = (int(t.data_ptr()), tuple(t.shape), tuple(t.stride()))
    return (*location, version, str(t.dtype), str(t.device))

391 

392 

def _maybe_get_cached_fp32(t):
    """Return ``t`` converted to fp32, reusing a cached conversion when possible.

    A tensor is cached only if it meets all of these conditions:

    - FLAGGEMS_ARM_MM_FP32_CAST_CACHE is enabled;
    - it is bf16;
    - it does not require grad;
    - it is not an inference tensor;
    - it has at least FLAGGEMS_ARM_MM_FP32_CAST_MIN_NUMEL elements
      (default 4096).

    Every other tensor is converted directly with ``t.to(torch.float32)``.

    Cached copies live in the LRU ``_MM_FP32_CAST_CACHE``, keyed by
    ``_fp32_cast_key(t)``. Its limits are:

    - FLAGGEMS_ARM_MM_FP32_CAST_MAX_BYTES: total budget, default 2**31;
      0 disables caching;
    - FLAGGEMS_ARM_MM_FP32_CAST_MAX_TENSOR_BYTES: per-entry cap, default
      2**30; 0 disables the cap;
    - FLAGGEMS_ARM_MM_FP32_CAST_MAX_ENTRIES: entry count, default 64.

    Returns:
        An fp32 tensor. Cache hits return the shared cached tensor, so callers
        must treat the result as read-only.
    """
    global _MM_FP32_CAST_CACHE_BYTES
    if (
        not _mm_fp32_cast_cache_enabled()
        or t.dtype is not torch.bfloat16
        or t.requires_grad
        # Inference tensors have no version counter, so reading ``_version``
        # (as _fp32_cast_key does) would raise RuntimeError.
        or t.is_inference()
    ):
        return t.to(torch.float32)

    min_numel = max(_get_env_int("FLAGGEMS_ARM_MM_FP32_CAST_MIN_NUMEL", 4096), 0)
    if t.numel() < min_numel:
        return t.to(torch.float32)

    max_bytes = max(_get_env_int("FLAGGEMS_ARM_MM_FP32_CAST_MAX_BYTES", 2**31), 0)
    if max_bytes <= 0:
        return t.to(torch.float32)

    key = _fp32_cast_key(t)
    cached = _MM_FP32_CAST_CACHE.get(key)
    if cached is not None:
        # Refresh recency so this entry is evicted last.
        _MM_FP32_CAST_CACHE.move_to_end(key)
        return cached

    fp32_t = t.to(torch.float32)
    fp32_bytes = _tensor_nbytes(fp32_t)
    max_tensor_bytes = max(
        _get_env_int("FLAGGEMS_ARM_MM_FP32_CAST_MAX_TENSOR_BYTES", 2**30), 0
    )
    # Too large to cache: hand back the uncached conversion. After this check,
    # fp32_bytes <= max_bytes is guaranteed.
    if (max_tensor_bytes > 0 and fp32_bytes > max_tensor_bytes) or fp32_bytes > max_bytes:
        return fp32_t

    max_entries = max(_get_env_int("FLAGGEMS_ARM_MM_FP32_CAST_MAX_ENTRIES", 64), 1)
    # Evict least-recently-used entries until the new one fits.
    while _MM_FP32_CAST_CACHE and (
        _MM_FP32_CAST_CACHE_BYTES + fp32_bytes > max_bytes
        or len(_MM_FP32_CAST_CACHE) >= max_entries
    ):
        _, evicted = _MM_FP32_CAST_CACHE.popitem(last=False)
        _MM_FP32_CAST_CACHE_BYTES -= _tensor_nbytes(evicted)

    _MM_FP32_CAST_CACHE[key] = fp32_t
    _MM_FP32_CAST_CACHE_BYTES += fp32_bytes
    return fp32_t

440 

441 

def _launch_mm_m1_kernel(a, b, c, N, K):
    """Launch ``mm_m1_kernel`` if ``_select_mm_m1_config`` has a config for (N, K).

    Args:
        a: Left operand. Only ``a.stride(1)`` is passed, so it is assumed to
            hold a single row (M == 1).
        b: Right operand, indexed as (K, N) via ``b.stride(0)``/``b.stride(1)``.
        c: Output, written through ``c.stride(1)``.
        N: Number of output columns.
        K: Reduction length.

    Returns:
        True if the kernel was launched. False if no config applies and the
        caller must use another path.
    """
    cfg = _select_mm_m1_config(N, K)
    if cfg is None:
        return False
    block_n, block_k = cfg
    launch_grid = (triton.cdiv(N, block_n),)
    mm_m1_kernel[launch_grid](
        a,
        b,
        c,
        N,
        K,
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(1),
        BLOCK_N=block_n,
        BLOCK_K=block_k,
        EVEN_K=(K % block_k == 0),
    )
    return True

464 

465 

def _launch_mm_m1_transposed_rhs_kernel(a, b, c, N, K):
    """Launch ``mm_m1_transposed_rhs_kernel`` for a single-row ``a``.

    Args:
        a: Left operand. Only ``a.stride(1)`` is passed, so it is assumed to
            hold a single row (M == 1).
        b: Right operand, indexed as (K, N) via ``b.stride(0)``/``b.stride(1)``.
            It is intended for K-contiguous layouts such as ``weight.t()``.
        c: Output, written through ``c.stride(1)``.
        N: Number of output columns.
        K: Reduction length.

    Returns:
        True if the kernel was launched, False if no config was selected.
    """
    cfg = _select_mm_m1_transposed_config(N, K)
    if cfg is None:
        return False
    block_n, block_k = cfg
    launch_grid = (triton.cdiv(N, block_n),)
    mm_m1_transposed_rhs_kernel[launch_grid](
        a,
        b,
        c,
        N,
        K,
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(1),
        BLOCK_N=block_n,
        BLOCK_K=block_k,
        EVEN_K=(K % block_k == 0),
    )
    return True

488 

489 

def mm(a, b):
    """Compute ``a @ b`` for 2-D tensors on the ARM CPU backend.

    Dispatch order:

    1. ``N < 256`` and ``M <= 8`` with an fp32/fp64 ``a`` and a non-bf16
       ``b``: ``numpy.dot``, which avoids Triton launch overhead.
    2. ``M == 1``: the specialised single-row kernels, each used only if
       enabled and applicable, in this order:
       a. prepacked RHS;
       b. transposed-RHS kernel;
       c. row-major M=1 kernel.
    3. ``M > 1`` with a bf16 operand: ATen via ``torch.addmm(beta=0)``.
    4. Otherwise: the generic tiled Triton kernel. bf16 operands only reach
       it with ``M <= 1`` and are upcast to fp32 first.

    Args:
        a: ``(M, K)`` tensor.
        b: ``(K, N)`` tensor.

    Returns:
        A new ``(M, N)`` tensor of dtype ``get_higher_dtype(a.dtype, b.dtype)``.
        On the NumPy path the dtype follows NumPy's type promotion instead.

    Raises:
        AssertionError: if ``a.shape[1] != b.shape[0]``.
    """
    logging.debug("GEMS MM")
    device = a.device
    # Copy only when neither dimension is unit-stride; row- and column-major
    # layouts are consumed as-is through their strides.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape

    # Small-shape fallback: NumPy BLAS avoids Triton's launch overhead for
    # shapes such as k/v_proj decode (M=1, N=128, K=896). NumPy has no
    # bfloat16, so ``b.numpy()`` would raise on a bf16 RHS; such inputs take
    # the Triton/ATen paths below instead.
    if (
        N < 256
        and M <= 8
        and a.dtype in (torch.float32, torch.float64)
        and b.dtype is not torch.bfloat16
    ):
        import numpy as np

        return torch.from_numpy(np.dot(a.detach().numpy(), b.detach().numpy()))

    c_dtype = get_higher_dtype(a.dtype, b.dtype)
    use_fp32_kernel = a.dtype is torch.bfloat16 or b.dtype is torch.bfloat16

    if M == 1:
        # Decode path: keep operands in their native dtype to avoid expensive
        # full-tensor bf16 -> fp32 copies. The M=1 kernels accumulate in fp32;
        # with a bf16 operand they write an fp32 buffer that is cast at the end.
        c_kernel = torch.empty(
            (M, N),
            device=device,
            dtype=(torch.float32 if use_fp32_kernel else c_dtype),
        )
        launched = False
        if (
            _m1_transposed_fastpath_enabled()
            and _use_m1_transposed_fastpath_shape(N, K)
            and _is_rhs_transposed_layout(b)
        ):
            # Prefer an (optional) cached contiguous copy of the transposed RHS.
            packed_rhs = _maybe_get_prepacked_rhs(b)
            launched = packed_rhs is not None and _launch_mm_m1_kernel(
                a, packed_rhs, c_kernel, N, K
            )
            if not launched:
                launched = _launch_mm_m1_transposed_rhs_kernel(a, b, c_kernel, N, K)
        if not launched:
            launched = _m1_fastpath_enabled() and _launch_mm_m1_kernel(
                a, b, c_kernel, N, K
            )
        if launched:
            return c_kernel.to(c_dtype) if use_fp32_kernel else c_kernel

    if use_fp32_kernel and M > 1:
        # ATen's native bf16 mm uses ARM BFMMLA and is 3-5x faster than Triton.
        # torch.mm() would recurse through the torch.library override, so use
        # torch.addmm with beta=0, which ignores its ``input`` argument.
        # addmm requires matching operand dtypes, so mixed bf16/fp32 inputs are
        # first promoted to c_dtype. This is a no-op when both are bf16.
        return torch.addmm(
            torch.empty(N, device=device, dtype=c_dtype),
            a.to(c_dtype),
            b.to(c_dtype),
            beta=0,
            alpha=1,
        )

    # Generic Triton path. bf16 operands only get here with M <= 1 (M > 1
    # returned above) and are upcast to fp32. The fp32 copy of ``b`` may come
    # from the fp32 cast cache.
    if use_fp32_kernel:
        a_kernel = a.to(torch.float32)
        b_kernel = _maybe_get_cached_fp32(b)
        c_kernel = torch.empty((M, N), device=device, dtype=torch.float32)
    else:
        a_kernel, b_kernel = a, b
        c_kernel = torch.empty((M, N), device=device, dtype=c_dtype)

    BLOCK_M, BLOCK_N, BLOCK_K = _select_mm_config(M, N, K)
    EVEN_K = K % BLOCK_K == 0
    # One program per output tile; axis 1 is unused because SPLIT_K == 1.
    grid = lambda META: (
        triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
        1,
    )
    mm_kernel[grid](
        a_kernel,
        b_kernel,
        c_kernel,
        M,
        N,
        K,
        a_kernel.stride(0),
        a_kernel.stride(1),
        b_kernel.stride(0),
        b_kernel.stride(1),
        c_kernel.stride(0),
        c_kernel.stride(1),
        dot_out_dtype=tl.float32,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        GROUP_M=8,
        SPLIT_K=1,
        EVEN_K=EVEN_K,
    )
    return c_kernel.to(c_dtype) if use_fp32_kernel else c_kernel

592 

593 

def mm_out(a, b, *, out):
    """Compute ``a @ b`` into the preallocated ``out`` tensor.

    Uses the same dispatch as ``mm`` apart from the NumPy small-shape
    fallback:

    1. ``M == 1``: the single-row Triton kernels.
    2. ``M > 1`` with a bf16 operand: ATen ``torch.addmm(beta=0)``.
    3. Otherwise: the generic tiled Triton kernel.

    When the computation runs in a different dtype than ``out``, the result
    is computed into a temporary and copied into ``out``.

    Args:
        a: ``(M, K)`` tensor.
        b: ``(K, N)`` tensor.
        out: ``(M, N)`` destination tensor.

    Returns:
        ``out``.

    Raises:
        AssertionError: on mismatched inner dimensions, a missing ``out``, or
            an ``out`` whose shape is not ``(M, N)``.
    """
    logging.debug("GEMS MM_OUT")
    # Copy only when neither dimension is unit-stride.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()

    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    assert out is not None, "out tensor is required"
    assert out.shape == (M, N), "incompatible out shape"
    use_fp32_kernel = a.dtype is torch.bfloat16 or b.dtype is torch.bfloat16

    if M == 1:
        # Decode path: native-dtype operands. A bf16 operand means the kernel
        # writes an fp32 scratch buffer that is copied into ``out`` afterwards.
        out_kernel = (
            torch.empty((M, N), device=out.device, dtype=torch.float32)
            if use_fp32_kernel
            else out
        )
        launched = False
        if (
            _m1_transposed_fastpath_enabled()
            and _use_m1_transposed_fastpath_shape(N, K)
            and _is_rhs_transposed_layout(b)
        ):
            packed_rhs = _maybe_get_prepacked_rhs(b)
            launched = packed_rhs is not None and _launch_mm_m1_kernel(
                a, packed_rhs, out_kernel, N, K
            )
            if not launched:
                launched = _launch_mm_m1_transposed_rhs_kernel(
                    a, b, out_kernel, N, K
                )
        if not launched:
            launched = _m1_fastpath_enabled() and _launch_mm_m1_kernel(
                a, b, out_kernel, N, K
            )
        if launched:
            if use_fp32_kernel:
                out.copy_(out_kernel.to(out.dtype))
            return out

    if use_fp32_kernel and M > 1:
        # ATen fallback, as in mm(): torch.addmm with beta=0 ignores its
        # ``input``. Operands are promoted to a common dtype because addmm
        # rejects mixed dtypes. When ``out`` has a different dtype, compute
        # into a temporary and copy it over.
        c_dtype = get_higher_dtype(a.dtype, b.dtype)
        bias = torch.empty(N, device=out.device, dtype=c_dtype)
        a_mm = a.to(c_dtype)
        b_mm = b.to(c_dtype)
        if out.dtype == c_dtype:
            torch.addmm(bias, a_mm, b_mm, beta=0, alpha=1, out=out)
        else:
            out.copy_(torch.addmm(bias, a_mm, b_mm, beta=0, alpha=1))
        return out

    # Generic Triton path. bf16 operands only get here with M <= 1 and are
    # upcast to fp32, with the result staged in an fp32 buffer.
    if use_fp32_kernel:
        a_kernel = a.to(torch.float32)
        b_kernel = _maybe_get_cached_fp32(b)
        out_kernel = torch.empty((M, N), device=out.device, dtype=torch.float32)
    else:
        a_kernel, b_kernel, out_kernel = a, b, out

    BLOCK_M, BLOCK_N, BLOCK_K = _select_mm_config(M, N, K)
    EVEN_K = K % BLOCK_K == 0

    # One program per output tile; axis 1 is unused because SPLIT_K == 1.
    grid = lambda META: (
        triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
        1,
    )
    mm_kernel[grid](
        a_kernel,
        b_kernel,
        out_kernel,
        M,
        N,
        K,
        a_kernel.stride(0),
        a_kernel.stride(1),
        b_kernel.stride(0),
        b_kernel.stride(1),
        out_kernel.stride(0),
        out_kernel.stride(1),
        dot_out_dtype=tl.float32,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        GROUP_M=8,
        SPLIT_K=1,
        EVEN_K=EVEN_K,
    )
    if use_fp32_kernel:
        out.copy_(out_kernel.to(out.dtype))
    return out
696 return out