Coverage for src/flag_gems/runtime/backend/_arm/ops/mm.py: 0%
329 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
2import os
3from collections import OrderedDict
5import torch
6import triton
7import triton.language as tl
9from flag_gems.utils import triton_lang_extension as tle
# Tile-size rules for the generic kernel (mm_kernel), scanned top to bottom by
# _select_mm_config; the first rule whose bounds match wins. A rule matches
# when M <= m_max (m_max=None means unbounded), N >= n_min and K >= k_min.
# "config" is (BLOCK_M, BLOCK_N, BLOCK_K).
# NOTE(review): mm() and mm_out() send every M > 1 call with a bfloat16
# operand to torch.addmm, and upcast M == 1 bf16 operands to fp32 before this
# kernel, so the "bf16-direct" tuning notes below describe inputs that no
# longer reach it -- re-benchmark before retuning.
MM_GENERIC_CONFIG_TABLE = (
    # Decode-like long vocab projection prefers narrower N tiles.
    {"m_max": 1, "n_min": 65536, "k_min": 0, "config": (4, 16, 8)},
    # Batched decode/prefill small-M cases with bf16-direct inputs.
    # BM=4, BN=8 is optimal for M=2-4 (1.08-1.19x vs native bf16 on ARM).
    {"m_max": 4, "n_min": 2048, "k_min": 0, "config": (4, 8, 8)},
    # Prefill with large K: use larger BLOCK_K to reduce loop iterations.
    {"m_max": 8, "n_min": 0, "k_min": 2048, "config": (8, 8, 32)},
    {"m_max": 8, "n_min": 2048, "k_min": 0, "config": (8, 8, 8)},
    {"m_max": 8, "n_min": 0, "k_min": 0, "config": (8, 8, 8)},
    # Prefill M>8: (64,32,32) benchmarked as best on CIX P1 (2026-03-07).
    # Triton BF16 prefill is still ~3x slower than ATen BFMMLA -- a fundamental
    # limit of Triton not emitting BFMMLA for tl.dot(bf16,bf16). Larger tiles
    # reduce overhead vs the (8,8,8) default but cannot close the BFMMLA gap.
    # Catch-all: m_max=None with zero minimums matches every shape.
    {"m_max": None, "n_min": 0, "k_min": 0, "config": (64, 32, 32)},
)
# Tile-size rules for the M == 1 kernel (mm_m1_kernel), scanned top to bottom
# by _select_mm_m1_config: the first rule with N >= n_min and K >= k_min wins.
# "config" is (BLOCK_N, BLOCK_K). A config of None means the M1 fast path is
# not used for that shape (_launch_mm_m1_kernel then returns False).
MM_M1_CONFIG_TABLE = (
    # Keep very large vocab projection on the generic kernel.
    {"n_min": 65536, "k_min": 0, "config": None},
    # Qwen3-4B gate/up (N=9728, K=2560): BN=64 BK=16 is 9% faster.
    # K>=2560 threshold avoids regressing 1.7B (K=2048) shapes.
    {"n_min": 4096, "k_min": 2560, "config": (64, 16)},
    {"n_min": 2048, "k_min": 0, "config": (32, 8)},
    # Mid-size N (256 <= N < 2048): the tile shape is chosen per K band.
    {"n_min": 256, "k_min": 3072, "config": (128, 8)},
    {"n_min": 256, "k_min": 2048, "config": (32, 16)},
    {"n_min": 256, "k_min": 0, "config": (64, 8)},
    # N < 256 (e.g. k/v_proj N=128): no rule matches, so the M1 fastpath is
    # skipped and the caller falls through to another path.
)
# Tile-size rules for mm_m1_transposed_rhs_kernel, scanned top to bottom by
# _select_mm_m1_transposed_config: the first rule with N >= n_min,
# K >= k_min and (when the key is present) K <= k_max wins.
# "config" is (BLOCK_N, BLOCK_K).
MM_M1_TRANSPOSED_CONFIG_TABLE = (
    # Large vocab projection (lm_head N~=152k): BN=2 for fine-grained OMP
    # load balancing; BK=64 fills a full 64-byte cache line per K-step.
    # NOTE(review): 64 K-elements span 128 bytes in bf16 (256 in fp32), i.e.
    # more than one 64-byte line -- confirm the cache-line rationale.
    # Tuned on CIX P1 aarch64 (2026-03-04): 30ms vs ATen 65ms (2.17x faster).
    {"n_min": 65536, "k_min": 0, "k_max": 1536, "config": (2, 64)},
    {"n_min": 2048, "k_min": 0, "k_max": 1536, "config": (4, 64)},
    {"n_min": 0, "k_min": 2048, "config": (4, 64)},
    # Catch-all (no k_max): matches every remaining shape, including
    # 1536 < K < 2048.
    {"n_min": 0, "k_min": 0, "config": (4, 64)},
)
# Module-level LRU caches (OrderedDict, least recently used entry first; hits
# are moved to the end, evictions pop from the front) plus the running byte
# total of the tensors each one holds.
# _MM_PREPACK_CACHE: contiguous copies of transposed RHS views, filled by
#   _maybe_get_prepacked_rhs.
# _MM_FP32_CAST_CACHE: float32 copies of bfloat16 tensors, filled by
#   _maybe_get_cached_fp32.
# NOTE(review): no lock guards these globals -- confirm callers never run
# mm()/mm_out() concurrently from several Python threads.
_MM_PREPACK_CACHE = OrderedDict()
_MM_PREPACK_CACHE_BYTES = 0
_MM_FP32_CAST_CACHE = OrderedDict()
_MM_FP32_CAST_CACHE_BYTES = 0
# Generic tiled GEMM: C = A @ B, with A addressed as [M, K], B as [K, N] and
# C as [M, N] through the explicit strides.
# Grid axis 0 enumerates BLOCK_M x BLOCK_N output tiles (in GROUP_M-grouped
# order); grid axis 1 (pid_z) selects one of SPLIT_K interleaved K-slices.
# With SPLIT_K > 1 the partial tiles are combined with tl.atomic_add, so C
# would have to start zeroed (NOTE(review): every caller in this file passes
# SPLIT_K=1).
# EVEN_K=True drops the K-tail mask and is only safe when K is a multiple of
# BLOCK_K * SPLIT_K. Products accumulate in dot_out_dtype and are cast to C's
# element type before the store.
@triton.jit
def mm_kernel(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    dot_out_dtype: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
):
    # matrix multiplication
    pid = tle.program_id(0)
    pid_z = tle.program_id(1)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance: within a group,
    # consecutive programs step through up to GROUP_M tile-rows first and only
    # then advance to the next tile-column.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Wrap out-of-range rows/columns back into bounds (% M, % N) so the loads
    # below need no M/N mask; the store mask discards those lanes.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            # K elements left from the start of this BLOCK_K * SPLIT_K step;
            # rk already carries this program's pid_z slice offset.
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            # Zero of C's element type, used as the fill for masked-off lanes.
            _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
            a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
        # Mixed input dtypes: bring both operands to C's element type first.
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=False)
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk
    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        tl.atomic_add(C, acc, mask=mask)
# GEMV kernel for M == 1: C[0, n] = sum_k A[0, k] * B[k, n], with B read as
# [BLOCK_K, BLOCK_N] tiles (K along rows, N along columns).
# Each program produces BLOCK_N consecutive outputs and accumulates in fp32;
# the result is cast to C's element type on store.
# EVEN_K=True drops the K-tail mask (valid when K % BLOCK_K == 0); the N-tail
# is masked on every load, because the last program's columns run past N
# whenever N % BLOCK_N != 0.
@triton.jit
def mm_m1_kernel(
    A,
    B,
    C,
    N,
    K,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cn,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
):
    pid_n = tle.program_id(0)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    # Columns >= N exist in the last tile when N % BLOCK_N != 0; they must be
    # masked even when K divides evenly, otherwise the final K row reads past
    # the end of B.
    mask_n = rn < N

    a_ptr = A + rk * stride_ak
    b_ptr = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn
    acc = tl.zeros((BLOCK_N,), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_K)):
        if EVEN_K:
            a = tl.load(a_ptr)
            b = tl.load(b_ptr, mask=mask_n[None, :], other=0.0)
        else:
            k_remaining = K - k * BLOCK_K
            a = tl.load(a_ptr, mask=rk < k_remaining, other=0.0)
            b = tl.load(
                b_ptr,
                mask=(rk[:, None] < k_remaining) & mask_n[None, :],
                other=0.0,
            )

        # Multiply in fp32, as mm_m1_transposed_rhs_kernel does. Forming the
        # products in a 16-bit input dtype would round each one (and possibly
        # the BLOCK_K partial sum) to that dtype before it reaches the fp32
        # accumulator. This also covers mismatched input dtypes.
        acc += tl.sum(b.to(tl.float32) * a.to(tl.float32)[:, None], axis=0)
        a_ptr += BLOCK_K * stride_ak
        b_ptr += BLOCK_K * stride_bk

    c_ptr = C + rn * stride_cn
    tl.store(c_ptr, acc.to(C.dtype.element_ty), mask=mask_n)
# GEMV kernel for M == 1 when the RHS is a transposed view (stride_bk == 1,
# e.g. weight.t()): C[0, n] = sum_k A[0, k] * B[k, n].
# Each program owns BLOCK_N outputs and gathers B as [BLOCK_N, BLOCK_K] tiles
# so that K runs along the contiguous axis. Both operands are upcast to fp32
# and reduced along K into an fp32 accumulator, which is cast to C's element
# type on store.
# EVEN_K=True drops the K-tail mask (valid when K % BLOCK_K == 0); the N-tail
# is always masked.
@triton.jit
def mm_m1_transposed_rhs_kernel(
    A,
    B,
    C,
    N,
    K,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cn,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
):
    pid_n = tle.program_id(0)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)

    a_ptr = A + rk * stride_ak
    # For transposed RHS views (stride_bk == 1), load [BLOCK_N, BLOCK_K]
    # so the K dimension is contiguous in memory.
    bt_ptr = B + rn[:, None] * stride_bn + rk[None, :] * stride_bk
    acc = tl.zeros((BLOCK_N,), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_K)):
        if EVEN_K:
            a = tl.load(a_ptr)
            # K is never ragged here, but the last tile may still run past N.
            bt = tl.load(bt_ptr, mask=rn[:, None] < N, other=0.0)
        else:
            k_remaining = K - k * BLOCK_K
            a = tl.load(a_ptr, mask=rk < k_remaining, other=0.0)
            bt = tl.load(
                bt_ptr,
                mask=(rn[:, None] < N) & (rk[None, :] < k_remaining),
                other=0.0,
            )

        # Upcast both operands so products and the K reduction run in fp32.
        a_fp = a.to(tl.float32)
        bt_fp = bt.to(tl.float32)
        acc += tl.sum(bt_fp * a_fp[None, :], axis=1)
        a_ptr += BLOCK_K * stride_ak
        bt_ptr += BLOCK_K * stride_bk

    c_ptr = C + rn * stride_cn
    tl.store(c_ptr, acc.to(C.dtype.element_ty), mask=rn < N)
# Floating-point dtypes accepted for mixed-dtype promotion, lowest rank first.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]


def get_higher_dtype(a, b):
    """Return whichever of ``a`` / ``b`` ranks later in ``_ordered_datatypes``.

    Identical dtypes short-circuit, so any dtype paired with itself (e.g.
    float64/float64) is returned unchanged. Otherwise both dtypes must appear
    in ``_ordered_datatypes`` (checked with ``assert``).
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    rank_a = _ordered_datatypes.index(a)
    rank_b = _ordered_datatypes.index(b)
    return b if rank_a < rank_b else a
def _match_mnk_rule(M, N, K, rule):
    """Return True when the GEMM shape (M, N, K) satisfies every bound in *rule*.

    Recognised keys:
      * ``m_max`` -- inclusive upper bound on M; missing or ``None`` means
        no bound.
      * ``n_min`` / ``k_min`` -- inclusive lower bounds on N / K; missing
        means 0.
    """
    m_limit = rule.get("m_max")
    m_ok = m_limit is None or M <= m_limit
    return m_ok and N >= rule.get("n_min", 0) and K >= rule.get("k_min", 0)
def _select_mm_config(M, N, K):
    """Pick (BLOCK_M, BLOCK_N, BLOCK_K) for mm_kernel.

    Returns the config of the first MM_GENERIC_CONFIG_TABLE rule matching
    (M, N, K), or (8, 8, 8) if no rule matches.
    """
    return next(
        (
            rule["config"]
            for rule in MM_GENERIC_CONFIG_TABLE
            if _match_mnk_rule(M, N, K, rule)
        ),
        (8, 8, 8),
    )
def _select_mm_m1_config(N, K):
    """Pick (BLOCK_N, BLOCK_K) for mm_m1_kernel, or None to skip that path.

    The first MM_M1_CONFIG_TABLE rule with N >= n_min and K >= k_min decides;
    its config may itself be None. When no rule matches, None is returned.
    """
    candidates = (
        rule["config"]
        for rule in MM_M1_CONFIG_TABLE
        if N >= rule.get("n_min", 0) and K >= rule.get("k_min", 0)
    )
    # No matching rule (e.g. N < 256): skip M1 fastpath
    return next(candidates, None)
def _select_mm_m1_transposed_config(N, K):
    """Pick (BLOCK_N, BLOCK_K) for mm_m1_transposed_rhs_kernel.

    Uses the first MM_M1_TRANSPOSED_CONFIG_TABLE rule with N >= n_min,
    K >= k_min and, when ``k_max`` is set, K <= k_max. Falls back to (64, 8).
    """
    for entry in MM_M1_TRANSPOSED_CONFIG_TABLE:
        if N < entry.get("n_min", 0) or K < entry.get("k_min", 0):
            continue
        upper_k = entry.get("k_max")
        if upper_k is not None and K > upper_k:
            continue
        return entry["config"]
    return 64, 8
def _m1_fastpath_enabled():
    """Whether the M == 1 mm_m1_kernel fast path may be used.

    Reads FLAGGEMS_ARM_M1_FASTPATH (default "1"); the case-insensitive values
    "1", "true" and "on" enable it, anything else disables it.
    """
    setting = os.getenv("FLAGGEMS_ARM_M1_FASTPATH", "1")
    return setting.lower() in {"1", "true", "on"}
def _m1_transposed_fastpath_enabled():
    """Whether the transposed-RHS M == 1 kernel may be used.

    Reads FLAGGEMS_ARM_M1_TRANSPOSED_FASTPATH (default "1"); the
    case-insensitive values "1", "true" and "on" enable it.
    """
    truthy = {"1", "true", "on"}
    flag = os.getenv("FLAGGEMS_ARM_M1_TRANSPOSED_FASTPATH", "1")
    return flag.lower() in truthy
def _use_m1_transposed_fastpath_shape(N, K):
    """Return True when both N and K are at least 256.

    Tiny matrices can hit unstable LLVM lowering on the ARM CPU backend for
    the specialized transposed-RHS kernel, so those shapes keep the generic
    path.
    """
    return min(N, K) >= 256
def _mm_prepack_enabled():
    """Whether transposed RHS tensors may be prepacked into contiguous copies.

    Opt-in: reads FLAGGEMS_ARM_MM_PREPACK (default "0"); the case-insensitive
    values "1", "true" and "on" enable it.
    """
    raw_value = os.getenv("FLAGGEMS_ARM_MM_PREPACK", "0")
    return raw_value.lower() in {"1", "true", "on"}
def _get_env_int(name, default):
    """Read environment variable *name* as an int.

    Returns *default* when the variable is unset or its value cannot be
    parsed by ``int()``.
    """
    try:
        parsed = int(os.getenv(name, str(default)))
    except (TypeError, ValueError):
        parsed = default
    return parsed
def _tensor_nbytes(t):
    """Size in bytes of *t*'s logical elements (numel * element_size)."""
    element_count = int(t.numel())
    bytes_per_element = int(t.element_size())
    return element_count * bytes_per_element
def _is_rhs_transposed_layout(rhs):
    """Return True when 2-D *rhs* looks like a transposed (column-major) view.

    That is: unit stride along dim 0 and a dim-1 stride of at least
    ``rhs.shape[0]``, as produced by ``weight.t()`` on a contiguous weight.
    Tensors that are not 2-D always return False.
    """
    if rhs.ndim != 2:
        return False
    row_stride, col_stride = rhs.stride()
    # Typical weight.t() view: stride(0) == 1, stride(1) == K.
    return row_stride == 1 and col_stride >= rhs.shape[0]
def _prepack_key(rhs):
    """Build the _MM_PREPACK_CACHE key identifying *rhs*'s current contents.

    Combines the data pointer, shape, strides, in-place version counter,
    dtype and device, mirroring _fp32_cast_key. Without the version counter,
    an in-place update of the weight (e.g. ``w.copy_(...)``) leaves every
    other field unchanged, and the cache would keep returning the stale
    contiguous copy. Views share their base's version counter, so updating
    the base also changes the key of ``base.t()``.
    """
    return (
        int(rhs.data_ptr()),
        tuple(rhs.shape),
        tuple(rhs.stride()),
        int(getattr(rhs, "_version", 0)),
        str(rhs.dtype),
        str(rhs.device),
    )
def _maybe_get_prepacked_rhs(rhs):
    """Return a cached contiguous copy of *rhs*, or None when not prepacking.

    Prepacking is opt-in. FLAGGEMS_ARM_MM_PREPACK must be truthy, and
    FLAGGEMS_ARM_MM_PREPACK_MAX_BYTES (total cache budget, default 0) must be
    positive.

    A tensor is never cached, and None is returned, if it is larger than
    FLAGGEMS_ARM_MM_PREPACK_MAX_TENSOR_BYTES (default 8 MiB; 0 disables this
    per-tensor limit) or than the total budget. Otherwise least-recently-used
    entries are evicted until the new copy fits both the byte budget and
    FLAGGEMS_ARM_MM_PREPACK_MAX_ENTRIES (default 32, minimum 1).

    NOTE(review): entries are keyed on ``data_ptr()`` and the cache holds
    only the packed copy, not *rhs*. If *rhs* is freed and a new tensor with
    the same address and metadata is later allocated, a stale entry could be
    returned.
    """
    global _MM_PREPACK_CACHE_BYTES
    if not _mm_prepack_enabled():
        return None

    max_bytes = max(_get_env_int("FLAGGEMS_ARM_MM_PREPACK_MAX_BYTES", 0), 0)
    if max_bytes <= 0:
        return None

    max_tensor_bytes = max(
        _get_env_int("FLAGGEMS_ARM_MM_PREPACK_MAX_TENSOR_BYTES", 8 * 1024 * 1024), 0
    )
    rhs_bytes = _tensor_nbytes(rhs)
    if max_tensor_bytes > 0 and rhs_bytes > max_tensor_bytes:
        return None
    # Anything passing this check fits the budget once older entries are
    # evicted, so no size check is needed after the eviction loop.
    if rhs_bytes > max_bytes:
        return None

    key = _prepack_key(rhs)
    packed = _MM_PREPACK_CACHE.get(key)
    if packed is not None:
        _MM_PREPACK_CACHE.move_to_end(key)  # mark as most recently used
        return packed

    packed = rhs.contiguous()
    # contiguous() keeps numel and dtype, so this equals rhs_bytes.
    packed_bytes = _tensor_nbytes(packed)
    max_entries = max(_get_env_int("FLAGGEMS_ARM_MM_PREPACK_MAX_ENTRIES", 32), 1)
    while _MM_PREPACK_CACHE and (
        _MM_PREPACK_CACHE_BYTES + packed_bytes > max_bytes
        or len(_MM_PREPACK_CACHE) >= max_entries
    ):
        _, evicted = _MM_PREPACK_CACHE.popitem(last=False)
        _MM_PREPACK_CACHE_BYTES -= _tensor_nbytes(evicted)

    _MM_PREPACK_CACHE[key] = packed
    _MM_PREPACK_CACHE_BYTES += packed_bytes
    return packed
def _mm_fp32_cast_cache_enabled():
    """Whether bf16 -> fp32 casts of mm operands may be cached.

    Reads FLAGGEMS_ARM_MM_FP32_CAST_CACHE (default "1"); the case-insensitive
    values "1", "true" and "on" enable caching.
    """
    accepted = {"1", "true", "on"}
    return os.getenv("FLAGGEMS_ARM_MM_FP32_CAST_CACHE", "1").lower() in accepted
def _fp32_cast_key(t):
    """Build the _MM_FP32_CAST_CACHE key for tensor *t*.

    The 6-tuple is (data_ptr, shape, strides, version counter, dtype string,
    device string); the version counter makes in-place writes to *t* produce
    a different key.
    """
    location = (int(t.data_ptr()), tuple(t.shape), tuple(t.stride()))
    version = int(getattr(t, "_version", 0))
    return location + (version, str(t.dtype), str(t.device))
def _maybe_get_cached_fp32(t):
    """Return *t* converted to float32, reusing a cached copy when allowed.

    Only bfloat16 tensors that do not require grad and have at least
    FLAGGEMS_ARM_MM_FP32_CAST_MIN_NUMEL elements (default 4096) are cached.
    Everything else is converted with ``t.to(torch.float32)`` on every call.

    Caching is on by default (FLAGGEMS_ARM_MM_FP32_CAST_CACHE) and is bounded
    by three limits, evicting least-recently-used entries first:
      * FLAGGEMS_ARM_MM_FP32_CAST_MAX_BYTES -- total bytes, default 2**31;
        0 disables caching.
      * FLAGGEMS_ARM_MM_FP32_CAST_MAX_TENSOR_BYTES -- per tensor, default
        2**30; 0 disables this per-tensor limit.
      * FLAGGEMS_ARM_MM_FP32_CAST_MAX_ENTRIES -- default 64, minimum 1.

    The returned tensor may be the cached object itself, so callers must not
    modify it in place.

    NOTE(review): the key includes the version counter, so in-place updates
    of *t* miss the cache. It is still based on ``data_ptr()``, though, and
    the cache does not keep *t* alive; a new tensor allocated at a freed
    address with identical metadata could hit a stale entry.
    """
    global _MM_FP32_CAST_CACHE_BYTES
    if (
        not _mm_fp32_cast_cache_enabled()
        or t.dtype is not torch.bfloat16
        or t.requires_grad
    ):
        return t.to(torch.float32)

    min_numel = max(_get_env_int("FLAGGEMS_ARM_MM_FP32_CAST_MIN_NUMEL", 4096), 0)
    if t.numel() < min_numel:
        return t.to(torch.float32)

    max_bytes = max(_get_env_int("FLAGGEMS_ARM_MM_FP32_CAST_MAX_BYTES", 2**31), 0)
    if max_bytes <= 0:
        return t.to(torch.float32)

    key = _fp32_cast_key(t)
    cached = _MM_FP32_CAST_CACHE.get(key)
    if cached is not None:
        _MM_FP32_CAST_CACHE.move_to_end(key)  # mark as most recently used
        return cached

    fp32_t = t.to(torch.float32)
    fp32_bytes = _tensor_nbytes(fp32_t)
    max_tensor_bytes = max(
        _get_env_int("FLAGGEMS_ARM_MM_FP32_CAST_MAX_TENSOR_BYTES", 2**30), 0
    )
    # Oversized results are returned uncached. Past this check
    # fp32_bytes <= max_bytes, so the entry always fits once older entries
    # are evicted.
    if (
        max_tensor_bytes > 0 and fp32_bytes > max_tensor_bytes
    ) or fp32_bytes > max_bytes:
        return fp32_t

    max_entries = max(_get_env_int("FLAGGEMS_ARM_MM_FP32_CAST_MAX_ENTRIES", 64), 1)
    while _MM_FP32_CAST_CACHE and (
        _MM_FP32_CAST_CACHE_BYTES + fp32_bytes > max_bytes
        or len(_MM_FP32_CAST_CACHE) >= max_entries
    ):
        _, evicted = _MM_FP32_CAST_CACHE.popitem(last=False)
        _MM_FP32_CAST_CACHE_BYTES -= _tensor_nbytes(evicted)

    _MM_FP32_CAST_CACHE[key] = fp32_t
    _MM_FP32_CAST_CACHE_BYTES += fp32_bytes
    return fp32_t
def _launch_mm_m1_kernel(a, b, c, N, K):
    """Launch mm_m1_kernel for an M == 1 product; return whether it launched.

    Only row 0 of *a* and *c* is addressed: just their dim-1 strides are
    passed, along with both strides of *b*. Returns False, without
    launching, when _select_mm_m1_config has no config for (N, K).
    """
    tile = _select_mm_m1_config(N, K)
    if tile is None:
        return False
    block_n, block_k = tile
    launch_grid = (triton.cdiv(N, block_n),)
    mm_m1_kernel[launch_grid](
        a,
        b,
        c,
        N,
        K,
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(1),
        BLOCK_N=block_n,
        BLOCK_K=block_k,
        EVEN_K=K % block_k == 0,
    )
    return True
def _launch_mm_m1_transposed_rhs_kernel(a, b, c, N, K):
    """Launch mm_m1_transposed_rhs_kernel for an M == 1 product.

    Intended for a *b* in transposed (column-major) layout. Returns True
    after launching, or False without launching if no tile config is
    available.
    """
    tile = _select_mm_m1_transposed_config(N, K)
    if tile is None:
        return False
    block_n, block_k = tile
    launch_grid = (triton.cdiv(N, block_n),)
    mm_m1_transposed_rhs_kernel[launch_grid](
        a,
        b,
        c,
        N,
        K,
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(1),
        BLOCK_N=block_n,
        BLOCK_K=block_k,
        EVEN_K=K % block_k == 0,
    )
    return True
def mm(a, b):
    """Return ``a @ b`` for 2-D tensors on the ARM backend as a new tensor.

    Dispatch order:
      1. N < 256, M <= 8, both operands float32 (or both float64) on CPU:
         NumPy BLAS, because Triton launch overhead dominates these shapes.
      2. M == 1: GEMV fast paths (optional prepacked RHS, transposed-RHS
         kernel, plain M1 kernel), gated by env flags and shape tables.
      3. M > 1 with a bfloat16 operand: ATen via ``torch.addmm(beta=0)``.
      4. Otherwise the generic tiled Triton kernel. bfloat16 operands (only
         possible with M == 1 at this point) are computed in fp32 and the
         result is cast back.

    Except on the NumPy path, the result dtype is
    get_higher_dtype(a.dtype, b.dtype). On the NumPy path it is the shared
    input dtype.
    """
    logging.debug("GEMS MM")
    device = a.device
    # handle non-contiguous inputs if necessary
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # Small-shape fallback: use numpy BLAS for shapes where Triton has excessive
    # overhead (e.g., k/v_proj decode M=1, N=128, K=896). Both operands must
    # share the dtype (b.numpy() raises for bfloat16, and np.dot would promote
    # mixed inputs), and .numpy() only works on CPU tensors.
    if (
        N < 256
        and M <= 8
        and a.dtype in (torch.float32, torch.float64)
        and b.dtype is a.dtype
        and device.type == "cpu"
        and b.device.type == "cpu"
    ):
        import numpy as np

        return torch.from_numpy(np.dot(a.detach().numpy(), b.detach().numpy()))
    # allocates output
    c_dtype = get_higher_dtype(a.dtype, b.dtype)
    use_fp32_kernel = a.dtype is torch.bfloat16 or b.dtype is torch.bfloat16
    if M == 1:
        # Keep decode-path tensors in native dtype to avoid expensive full-tensor
        # bf16<->fp32 copies; kernels accumulate in fp32 internally. With a bf16
        # operand the kernels write an fp32 scratch output, cast to c_dtype on
        # return.
        c_kernel = torch.empty(
            (M, N),
            device=device,
            dtype=(torch.float32 if use_fp32_kernel else c_dtype),
        )
        if (
            _m1_transposed_fastpath_enabled()
            and _use_m1_transposed_fastpath_shape(N, K)
            and _is_rhs_transposed_layout(b)
        ):
            # Optional: copy the transposed RHS into a cached contiguous buffer
            # and run the row-major M1 kernel on it.
            packed_rhs = _maybe_get_prepacked_rhs(b)
            if packed_rhs is not None and _launch_mm_m1_kernel(
                a, packed_rhs, c_kernel, N, K
            ):
                return c_kernel.to(c_dtype) if use_fp32_kernel else c_kernel
            if _launch_mm_m1_transposed_rhs_kernel(a, b, c_kernel, N, K):
                return c_kernel.to(c_dtype) if use_fp32_kernel else c_kernel
        if _m1_fastpath_enabled() and _launch_mm_m1_kernel(a, b, c_kernel, N, K):
            return c_kernel.to(c_dtype) if use_fp32_kernel else c_kernel

    # M>1 BF16: fallback to ATen native mm (ARM BFMMLA, 3-5x faster than Triton).
    # Cannot call torch.mm() here (infinite recursion via torch.library override).
    # torch.addmm(beta=0) bypasses aten::mm dispatch and uses ATen BFMMLA directly.
    # With beta=0 the uninitialised bias is ignored (NaN/inf are not propagated).
    if M > 1 and use_fp32_kernel:
        return torch.addmm(
            torch.empty(N, device=device, dtype=c_dtype), a, b, beta=0, alpha=1
        )

    # Generic Triton path. A bf16 operand only gets here with M == 1 (every
    # M > 1 bf16 call returned via addmm above), so upcast both operands to
    # fp32; the fp32 copy of b goes through the cast cache.
    if use_fp32_kernel:
        a_kernel = a.to(torch.float32)
        b_kernel = _maybe_get_cached_fp32(b)
        kernel_out_dtype = torch.float32
    else:
        a_kernel = a
        b_kernel = b
        kernel_out_dtype = c_dtype
    c_kernel = torch.empty((M, N), device=device, dtype=kernel_out_dtype)

    BLOCK_M, BLOCK_N, BLOCK_K = _select_mm_config(M, N, K)
    EVEN_K = K % BLOCK_K == 0
    # launch kernel
    grid = lambda META: (
        triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
        1,
    )
    mm_kernel[grid](
        a_kernel,
        b_kernel,
        c_kernel,
        M,
        N,
        K,
        a_kernel.stride(0),
        a_kernel.stride(1),
        b_kernel.stride(0),
        b_kernel.stride(1),
        c_kernel.stride(0),
        c_kernel.stride(1),
        dot_out_dtype=tl.float32,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        GROUP_M=8,
        SPLIT_K=1,
        EVEN_K=EVEN_K,
    )
    return c_kernel.to(c_dtype) if use_fp32_kernel else c_kernel
def mm_out(a, b, *, out):
    """Write ``a @ b`` into *out* (shape (M, N)) and return *out*.

    Dispatch mirrors mm(), minus the NumPy small-shape fallback:
      * M == 1: GEMV fast paths.
      * M > 1 with a bfloat16 operand: ATen ``torch.addmm(beta=0, out=out)``.
      * Otherwise: the generic Triton kernel.

    When a bfloat16 operand takes a Triton path, the kernel writes an fp32
    scratch tensor that is then copied (with dtype conversion) into *out*.
    """
    logging.debug("GEMS MM_OUT")
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()

    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    assert out is not None, "out tensor is required"
    assert out.shape == (M, N), "incompatible out shape"
    use_fp32_kernel = a.dtype is torch.bfloat16 or b.dtype is torch.bfloat16
    if M == 1:
        # Operands stay in native dtype; with a bf16 operand the kernels write
        # an fp32 scratch buffer that is copied into `out` afterwards.
        out_kernel = (
            torch.empty((M, N), device=out.device, dtype=torch.float32)
            if use_fp32_kernel
            else out
        )
        if (
            _m1_transposed_fastpath_enabled()
            and _use_m1_transposed_fastpath_shape(N, K)
            and _is_rhs_transposed_layout(b)
        ):
            packed_rhs = _maybe_get_prepacked_rhs(b)
            if packed_rhs is not None and _launch_mm_m1_kernel(
                a, packed_rhs, out_kernel, N, K
            ):
                if use_fp32_kernel:
                    out.copy_(out_kernel)  # copy_ converts to out.dtype
                return out
            if _launch_mm_m1_transposed_rhs_kernel(a, b, out_kernel, N, K):
                if use_fp32_kernel:
                    out.copy_(out_kernel)
                return out
        if _m1_fastpath_enabled() and _launch_mm_m1_kernel(a, b, out_kernel, N, K):
            if use_fp32_kernel:
                out.copy_(out_kernel)
            return out

    # M>1 BF16: fallback to ATen native mm (see mm() for rationale). With
    # beta=0 the uninitialised bias is ignored.
    if M > 1 and use_fp32_kernel:
        torch.addmm(
            torch.empty(N, device=out.device, dtype=out.dtype),
            a,
            b,
            beta=0,
            alpha=1,
            out=out,
        )
        return out

    # Generic Triton path. A bf16 operand only gets here with M == 1 (every
    # M > 1 bf16 call returned via addmm above), so upcast both operands to
    # fp32 and compute into an fp32 scratch tensor.
    if use_fp32_kernel:
        a_kernel = a.to(torch.float32)
        b_kernel = _maybe_get_cached_fp32(b)
        out_kernel = torch.empty((M, N), device=out.device, dtype=torch.float32)
    else:
        a_kernel = a
        b_kernel = b
        out_kernel = out

    BLOCK_M, BLOCK_N, BLOCK_K = _select_mm_config(M, N, K)
    EVEN_K = K % BLOCK_K == 0

    grid = lambda META: (
        triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
        1,
    )
    mm_kernel[grid](
        a_kernel,
        b_kernel,
        out_kernel,
        M,
        N,
        K,
        a_kernel.stride(0),
        a_kernel.stride(1),
        b_kernel.stride(0),
        b_kernel.stride(1),
        out_kernel.stride(0),
        out_kernel.stride(1),
        dot_out_dtype=tl.float32,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        GROUP_M=8,
        SPLIT_K=1,
        EVEN_K=EVEN_K,
    )
    if use_fp32_kernel:
        out.copy_(out_kernel)  # copy_ converts to out.dtype
    return out