Coverage for src/flag_gems/ops/mm.py: 32%
283 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.ops.mm_streamk import streamk_mm
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry, libtuner
11from flag_gems.utils import triton_lang_extension as ext
12from flag_gems.utils.device_info import get_device_capability, get_sm_count
13from flag_gems.utils.triton_version_utils import ( # noqa: F401
14 HAS_TLE,
15 HAS_TLE_DEVICE_MESH,
16 _triton_version_at_least,
17)
# Optional TLE (Triton language extension) support. Both names default to
# None and are only replaced when the device-mesh feature is available, so
# the rest of the module can test `BLOCK_CLUSTER_MESH is not None`.
tle_exp = None
BLOCK_CLUSTER_MESH = None
if HAS_TLE_DEVICE_MESH:
    import triton.experimental.tle.language as tle_exp

    # One mesh axis, "cluster_x", of extent 2.
    # NOTE(review): this literal 2 presumably has to stay equal to
    # TLE_CLUSTER_SIZE below - confirm before changing either.
    BLOCK_CLUSTER_MESH = tle_exp.device_mesh({"block_cluster": [("cluster_x", 2)]})

# Not referenced by any code visible in this file section.
CACHE_USAGE_THRESHOLD = 0.8

# Launch parameters of the cluster-remote GEMM path.
TLE_CLUSTER_SIZE = 2  # programs per cluster (CLUSTER_SIZE constexpr)
TLE_REMOTE_BM = 64  # tile rows (BM); also the minimum M for this path
TLE_REMOTE_BN = 256  # tile columns (BN); also the minimum N for this path
TLE_REMOTE_BK = 64  # K-slice per staged A tile (BK); also the minimum K
TLE_REMOTE_NUM_WARPS = 8
TLE_REMOTE_NUM_STAGES = 2
TLE_REMOTE_A_SLOTS = 2  # number of A tile slots in the staging buffer (A_SLOTS)

logger = logging.getLogger(__name__)
@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of b that is strictly smaller than a,
    # i.e. (ceil(a / b) - 1) * b.  For a == 0 this evaluates to -b.
    return (tl.cdiv(a, b) - 1) * b
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    # 'stride_am' and 'stride_bk' are part of the key so that tensors with the
    # same shape but a different memory layout are tuned separately.
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=10,
)
@triton.jit
def mm_kernel_general(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    # Tiled GEMM: each program produces one (BLOCK_M, BLOCK_N) tile of
    # C = A @ B, where A is (M, K) and B is (K, N), both addressed through
    # their element strides.
    tile_id = ext.program_id(0)
    tiles_m = tl.cdiv(M, BLOCK_M)
    tiles_n = tl.cdiv(N, BLOCK_N)

    # Grouped tile ordering: GROUP_M consecutive tile rows are walked
    # together, which is intended to improve L2 reuse.
    tiles_per_group = GROUP_M * tiles_n
    group = tile_id // tiles_per_group
    rows_in_group = min(tiles_m - group * GROUP_M, GROUP_M)
    pid_m = group * GROUP_M + (tile_id % rows_in_group)
    pid_n = (tile_id % tiles_per_group) // (rows_in_group)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Row/column indices used for loads are wrapped modulo M / N, so the
    # loads need no M/N mask; out-of-range lanes are dropped at the store.
    a_rows = tl.max_contiguous(tl.multiple_of(offs_m % M, BLOCK_M), BLOCK_M).to(
        tl.int64
    )
    b_cols = tl.max_contiguous(tl.multiple_of(offs_n % N, BLOCK_N), BLOCK_N).to(
        tl.int64
    )
    offs_m = offs_m.to(tl.int64)
    offs_n = offs_n.to(tl.int64)

    # Everything below k_tail is covered by full, unmasked K blocks.
    # NOTE(review): for K == 0 this is -BLOCK_K, and the tail mask `rk < K`
    # is then true for negative offsets - callers must not launch with K == 0.
    k_tail = prev_multiple_of(K, BLOCK_K)

    if IS_FP64:
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
    else:
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Main loop over full K blocks (no K mask needed).
    for k_start in range(0, k_tail, BLOCK_K):
        rk = (k_start + tl.arange(0, BLOCK_K)).to(tl.int64)
        a = tl.load(A + (a_rows[:, None] * stride_am + rk[None, :] * stride_ak))
        b = tl.load(B + (rk[:, None] * stride_bk + b_cols[None, :] * stride_bn))
        if a.dtype != b.dtype:
            # Mixed-dtype operands are both converted to the output dtype.
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        if IS_FP64:
            acc += tl.dot(a, b, allow_tf32=False)
        else:
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    # Peeled final K block: the only one that needs a K mask.
    rk = (k_tail + tl.arange(0, BLOCK_K)).to(tl.int64)
    k_in_range = rk < K
    a = tl.load(
        A + (a_rows[:, None] * stride_am + rk[None, :] * stride_ak),
        mask=k_in_range[None, :],
        other=0.0,
    )
    b = tl.load(
        B + (rk[:, None] * stride_bk + b_cols[None, :] * stride_bn),
        mask=k_in_range[:, None],
        other=0.0,
    )
    if a.dtype != b.dtype:
        a = a.to(C.dtype.element_ty)
        b = b.to(C.dtype.element_ty)
    if IS_FP64:
        acc += tl.dot(a, b, allow_tf32=False)
    else:
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    acc = acc.to(C.dtype.element_ty)
    # Recompute the output offsets here instead of keeping them live across
    # the K loop (saves registers).
    offs_m = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
    offs_n = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
    c_ptrs = C + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
    in_bounds = (offs_m < M)[:, None] & (offs_n < N)[None, :]
    # Masked write-back of the finished tile.
    tl.store(c_ptrs, acc, mask=in_bounds)
if HAS_TLE:

    @triton.jit
    def _cluster_remote_gemm_kernel(
        a_ptr,
        b_ptr,
        c_ptr,
        M,
        N,
        K,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        mesh: tl.constexpr,
        BM: tl.constexpr,
        BN: tl.constexpr,
        BK: tl.constexpr,
        DOT_K: tl.constexpr,
        CLUSTER_SIZE: tl.constexpr,
        USE_MASK: tl.constexpr,
        A_SLOTS: tl.constexpr,
        USE_NV_MMA_SMEM_LAYOUT: tl.constexpr,
    ):
        # GEMM in which the programs of one cluster share the same M tile:
        # only cluster rank 0 loads A tiles from global memory into `a_buf`,
        # and every rank reads them back through `a_buf_remote`.
        # NOTE(review): the exact semantics of the tle_exp calls (remote
        # handle, barrier scope) are taken from their names - confirm against
        # the TLE documentation.
        pid = tl.program_id(0)
        cluster_rank = tle_exp.shard_id(mesh, "cluster_x")
        cluster_id = pid // CLUSTER_SIZE

        # Each cluster owns one M tile and CLUSTER_SIZE adjacent N tiles.
        n_tiles = tl.cdiv(N, BN)
        n_tile_groups = tl.cdiv(n_tiles, CLUSTER_SIZE)
        pid_m = cluster_id // n_tile_groups
        pid_n_group = cluster_id % n_tile_groups
        pid_n = pid_n_group * CLUSTER_SIZE + cluster_rank

        offs_m = pid_m * BM + tl.arange(0, BM)
        offs_n = pid_n * BN + tl.arange(0, BN)
        offs_k = tl.arange(0, BK)

        # Index grids into the staging buffer: (BM, BK) for writing a whole
        # tile, (DOT_K, BM) for reading a transposed DOT_K-wide slice.
        tile_rows = offs_m - pid_m * BM
        rows_bm_bk = tl.broadcast_to(tile_rows[:, None], (BM, BK))
        cols_bm_bk = tl.broadcast_to(tl.arange(0, BK)[None, :], (BM, BK))
        rows_dk_bm = tl.broadcast_to(tile_rows[None, :], (DOT_K, BM))

        # Staging buffer for A tiles; the element type is fixed to float16.
        a_buf = tle_exp.gpu.alloc(
            [A_SLOTS, BM, BK],
            dtype=tl.float16,
            layout=None,
            scope=tle_exp.gpu.smem,
            nv_mma_shared_layout=USE_NV_MMA_SMEM_LAYOUT,
        )
        # Presumably a view of shard 0's copy of a_buf - TODO confirm.
        a_buf_remote = tle_exp.remote(a_buf, 0, scope=mesh)

        acc = tl.zeros((BM, BN), dtype=tl.float32)

        # Prologue: rank 0 stages the first A tile into slot 0.
        first_slot = 0
        first_slot_idx = tl.zeros((BM, BK), dtype=tl.int32) + first_slot
        if cluster_rank == 0:
            a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
            if USE_MASK:
                a_mask_tile = (offs_m[:, None] < M) & (offs_k[None, :] < K)
                a_tile = tl.load(a_ptrs, mask=a_mask_tile, other=0.0)
            else:
                a_tile = tl.load(a_ptrs)
            a_local_ptr_tile = tle_exp.gpu.local_ptr(
                a_buf, (first_slot_idx, rows_bm_bk, cols_bm_bk)
            )
            if USE_MASK:
                tl.store(a_local_ptr_tile, a_tile, mask=a_mask_tile)
            else:
                tl.store(a_local_ptr_tile, a_tile)

        tle_exp.distributed_barrier(mesh)

        for k0 in range(0, K, BK):
            step = k0 // BK
            cur_slot = step % A_SLOTS

            # Consume the staged A tile in DOT_K-wide slices.
            for ks in range(0, BK, DOT_K):
                k_local = ks + tl.arange(0, DOT_K)
                cols_dk_bm = tl.broadcast_to(k_local[:, None], (DOT_K, BM))
                slot_dk_bm = tl.zeros((DOT_K, BM), dtype=tl.int32) + cur_slot
                a_ptr_remote = tle_exp.gpu.local_ptr(
                    a_buf_remote, (slot_dk_bm, rows_dk_bm, cols_dk_bm)
                )
                # The slice is read as (DOT_K, BM) and transposed to (BM, DOT_K).
                if USE_MASK:
                    a_mask_t = ((k0 + k_local)[:, None] < K) & (offs_m[None, :] < M)
                    a = tl.trans(tl.load(a_ptr_remote, mask=a_mask_t, other=0.0))
                else:
                    a = tl.trans(tl.load(a_ptr_remote))

                b_ptrs = (
                    b_ptr
                    + (k0 + k_local)[:, None] * stride_bk
                    + offs_n[None, :] * stride_bn
                )
                if USE_MASK:
                    b_mask = ((k0 + k_local)[:, None] < K) & (offs_n[None, :] < N)
                    b = tl.load(b_ptrs, mask=b_mask, other=0.0)
                else:
                    b = tl.load(b_ptrs)
                acc = tl.dot(a, b, acc)

            # With a single slot the next tile overwrites the one just read,
            # so an extra barrier is taken before it is refilled.
            if A_SLOTS == 1:
                tle_exp.distributed_barrier(mesh)

            # Rank 0 stages the A tile of the next iteration, if there is one.
            next_k0 = k0 + BK
            has_next = next_k0 < K
            next_step = step + 1
            next_slot = next_step % A_SLOTS
            next_slot_idx = tl.zeros((BM, BK), dtype=tl.int32) + next_slot
            if has_next and cluster_rank == 0:
                a_ptrs = (
                    a_ptr
                    + offs_m[:, None] * stride_am
                    + (next_k0 + offs_k)[None, :] * stride_ak
                )
                if USE_MASK:
                    a_mask_tile = (offs_m[:, None] < M) & (
                        (next_k0 + offs_k)[None, :] < K
                    )
                    a_tile = tl.load(a_ptrs, mask=a_mask_tile, other=0.0)
                else:
                    a_tile = tl.load(a_ptrs)
                a_local_ptr_tile = tle_exp.gpu.local_ptr(
                    a_buf, (next_slot_idx, rows_bm_bk, cols_bm_bk)
                )
                if USE_MASK:
                    tl.store(a_local_ptr_tile, a_tile, mask=a_mask_tile)
                else:
                    tl.store(a_local_ptr_tile, a_tile)

            # Taken by every rank on every iteration, whether or not a tile
            # was staged.
            tle_exp.distributed_barrier(mesh)

        # Write the finished tile, converted to the output element type.
        c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
        if USE_MASK:
            c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
            tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)
        else:
            tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty))
def _select_remote_dot_k(bk: int) -> int:
    """Return the K width of one dot step on the cluster-remote path.

    The width is always 16.

    Raises:
        ValueError: if ``bk`` is not a multiple of 16.
    """
    if bk % 16:
        raise ValueError(f"BK must be divisible by 16 for remote dot path, got BK={bk}")
    return 16
def _grid_cluster_remote(
    M: int,
    N: int,
    BM: int,
    BN: int,
    cluster_size: int = TLE_CLUSTER_SIZE,
) -> tuple[int]:
    """Return the 1-D launch grid for the cluster-remote GEMM kernel.

    The N tiles are grouped ``cluster_size`` at a time (rounding up), and the
    grid holds one entry per (M tile, N-tile group) pair.

    NOTE(review): the kernel derives ``cluster_id = pid // CLUSTER_SIZE``, so
    the launch presumably expands each grid entry to ``cluster_size``
    programs - confirm against the TLE launch semantics.
    """
    tiles_m = triton.cdiv(M, BM)
    tiles_n = triton.cdiv(N, BN)
    n_tile_groups = triton.cdiv(tiles_n, cluster_size)
    return (tiles_m * n_tile_groups,)
def _run_cluster_remote(
    a: torch.Tensor,
    b: torch.Tensor,
    c: torch.Tensor,
    bm: int,
    bn: int,
    bk: int,
    num_warps: int,
    num_stages: int,
) -> None:
    """Launch the cluster-remote GEMM kernel computing ``c = a @ b`` in place.

    Args:
        a: left operand, shape (M, K).
        b: right operand; only ``b.shape[1]`` (N) is read here.
        c: output tensor, written by the kernel.
        bm, bn, bk: tile sizes passed to the kernel as BM, BN, BK.
        num_warps, num_stages: forwarded to the kernel launch.

    Raises:
        ValueError: if ``bk`` is not a multiple of 16.
    """
    M, K = a.shape
    N = b.shape[1]
    dot_k = _select_remote_dot_k(bk)

    # The launch grid rounds the number of N tiles up to a whole number of
    # clusters.  When the tile count is not a multiple of the cluster size,
    # the last cluster contains a tile that lies entirely past N.  That tile
    # must run with masking even if M, N and K are all tile-aligned,
    # otherwise it reads B and writes C out of bounds.
    n_tiles = triton.cdiv(N, bn)
    has_padding_tile = n_tiles % TLE_CLUSTER_SIZE != 0
    use_mask = (
        (M % bm != 0) or (N % bn != 0) or (K % bk != 0) or has_padding_tile
    )

    a_slots = TLE_REMOTE_A_SLOTS
    use_nv_mma_smem_layout = (bk == 32) or (bk == 64 and num_stages <= 2)
    _cluster_remote_gemm_kernel[_grid_cluster_remote(M, N, bm, bn)](
        a,
        b,
        c,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        mesh=BLOCK_CLUSTER_MESH,
        BM=bm,
        BN=bn,
        BK=bk,
        DOT_K=dot_k,
        CLUSTER_SIZE=TLE_CLUSTER_SIZE,
        USE_MASK=use_mask,
        A_SLOTS=a_slots,
        USE_NV_MMA_SMEM_LAYOUT=use_nv_mma_smem_layout,
        num_ctas=1,
        num_warps=num_warps,
        num_stages=num_stages,
    )
def cluster_remote_mm_scenario(a, b, c, M, N, K):
    """Return whether the cluster-remote GEMM path may be used.

    All of the following must hold: TLE and the block-cluster mesh are
    available, the device capability major version is at least 9, ``a``, ``b``
    and ``c`` are CUDA float16 tensors, ``a`` and ``b`` are contiguous, and
    M, N, K are each at least one tile (TLE_REMOTE_BM / BN / BK).
    """
    capability = get_device_capability()

    # Feature and hardware availability.
    if not (HAS_TLE and BLOCK_CLUSTER_MESH is not None and capability[0] >= 9):
        return False

    # Tensor placement and element type.
    tensors = (a, b, c)
    if not all(t.is_cuda for t in tensors):
        return False
    if not all(t.dtype == torch.float16 for t in tensors):
        return False

    # Only the inputs need to be contiguous.
    if not (a.is_contiguous() and b.is_contiguous()):
        return False

    # The problem must span at least one full tile in every dimension.
    return M >= TLE_REMOTE_BM and N >= TLE_REMOTE_BN and K >= TLE_REMOTE_BK
def cluster_remote_mm(a, b, c, M, N, K):
    """Compute ``c = a @ b`` with the cluster-remote kernel and return ``c``.

    The launch uses the fixed TLE_REMOTE_* tile sizes, warp count and stage
    count, and runs inside the device context of ``a``.
    """
    a_col_major = a.stride(0) == 1
    b_col_major = b.stride(0) == 1
    logger.debug(
        "GEMS MM [cluster_remote]: M=%s N=%s K=%s, A_col_major=%s, B_col_major=%s",
        M,
        N,
        K,
        a_col_major,
        b_col_major,
    )
    with torch_device_fn.device(a.device):
        _run_cluster_remote(
            a,
            b,
            c,
            bm=TLE_REMOTE_BM,
            bn=TLE_REMOTE_BN,
            bk=TLE_REMOTE_BK,
            num_warps=TLE_REMOTE_NUM_WARPS,
            num_stages=TLE_REMOTE_NUM_STAGES,
        )
    return c
# Floating-point dtypes in promotion order, lowest first.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32, torch.float64]


def get_higher_dtype(a, b):
    """Return whichever of two dtypes ranks later in ``_ordered_datatypes``.

    Identical dtypes are returned as-is without being checked against the
    list.  Otherwise both must be members of ``_ordered_datatypes``; this is
    enforced with ``assert``.
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    rank_a = _ordered_datatypes.index(a)
    rank_b = _ordered_datatypes.index(b)
    return a if rank_a > rank_b else b
def general_mm(a, b, c, M, N, K):
    """Compute ``c = a @ b`` with the autotuned general kernel and return ``c``.

    Args:
        a: left operand, shape (M, K).
        b: right operand, shape (K, N).
        c: pre-allocated output, shape (M, N); written in place.
        M, N, K: problem sizes matching the tensors above.

    Returns:
        ``c``.
    """
    if K == 0:
        # An empty reduction yields an all-zero product.  The kernel cannot
        # handle this case: prev_multiple_of(0, BLOCK_K) is -BLOCK_K, so the
        # tail block's `rk < K` mask is true for negative offsets and the
        # loads would run out of bounds on the empty operands.
        c.zero_()
        return c

    def grid(META):
        # One program per (BLOCK_M, BLOCK_N) output tile.
        return (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)

    with torch_device_fn.device(a.device):
        mm_kernel_general[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            GROUP_M=8,
            IS_FP64=a.dtype == torch.float64,
        )
    return c
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm_self_transpose"),
    key=["M", "K", "stride_am", "stride_ak"],
    strategy=["align32", "align32", "align32", "align32"],
    warmup=2,
    rep=4,
)
@triton.jit
def mm_kernel_syrk(
    A,
    C,
    M,
    K,
    stride_am,
    stride_ak,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Computes C = A @ A^T for A of shape (M, K).  Only the tiles of the
    # lower triangle (including the diagonal) are launched; every
    # off-diagonal tile is also written transposed into the upper triangle.
    pid = tl.program_id(0)

    # The launch domain is the packed lower triangle:
    #     pid = row * (row + 1) / 2 + col,   0 <= col <= row
    # Solving row^2 + row - 2 * pid = 0 gives
    #     row = (-1 + sqrt(1 + 8 * pid)) / 2
    # floor(...) of that is only a candidate, because fp32 sqrt can be off
    # near triangular-number boundaries; it is corrected by at most one step
    # in either direction using exact integer arithmetic.
    pid_as_float = pid.to(tl.float32)
    pid_m = tl.floor((tl.sqrt(8.0 * pid_as_float + 1.0) - 1.0) / 2.0).to(tl.int32)
    row_start = pid_m * (pid_m + 1) // 2
    pid_m = tl.where(row_start > pid, pid_m - 1, pid_m)
    next_row_start = (pid_m + 1) * (pid_m + 2) // 2
    pid_m = tl.where(next_row_start <= pid, pid_m + 1, pid_m)
    row_start = pid_m * (pid_m + 1) // 2
    pid_n = pid - row_start

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_M + tl.arange(0, BLOCK_M)
    # Load indices are wrapped modulo M; out-of-range lanes are masked only
    # at the store.
    rows_for_a = tl.max_contiguous(tl.multiple_of(offs_m % M, BLOCK_M), BLOCK_M).to(
        tl.int64
    )
    rows_for_at = tl.max_contiguous(tl.multiple_of(offs_n % M, BLOCK_M), BLOCK_M).to(
        tl.int64
    )
    offs_m = offs_m.to(tl.int64)
    offs_n = offs_n.to(tl.int64)

    # NOTE(review): accumulation is float32 for every input dtype; unlike
    # mm_kernel_general there is no IS_FP64 path - confirm this is intended
    # for float64 inputs.
    acc = tl.zeros((BLOCK_M, BLOCK_M), dtype=tl.float32)

    for k_start in range(0, K, BLOCK_K):
        rk = (k_start + tl.arange(0, BLOCK_K)).to(tl.int64)
        k_in_range = rk < K
        a = tl.load(
            A + (rows_for_a[:, None] * stride_am + rk[None, :] * stride_ak),
            mask=k_in_range[None, :],
            other=0.0,
        )
        # The second operand is A read through swapped strides, i.e. A^T.
        b = tl.load(
            A + (rk[:, None] * stride_ak + rows_for_at[None, :] * stride_am),
            mask=k_in_range[:, None],
            other=0.0,
        )
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    tile = acc.to(C.dtype.element_ty)
    c_ptrs = C + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
    in_bounds = (offs_m < M)[:, None] & (offs_n < M)[None, :]
    tl.store(c_ptrs, tile, mask=in_bounds)

    # Strictly-lower tiles are mirrored into the upper triangle.
    if pid_m > pid_n:
        c_t_ptrs = C + (offs_n[:, None] * stride_cm + offs_m[None, :] * stride_cn)
        in_bounds_t = (offs_n < M)[:, None] & (offs_m < M)[None, :]
        tl.store(c_t_ptrs, tl.trans(tile), mask=in_bounds_t)
def is_syrk_transpose_pair(a, b):
    """Return True when ``b`` is a transposed view of the same data as ``a``.

    Both tensors must be 2-D, have mirrored shapes and mirrored strides, and
    share the same storage offset and data pointer.
    """
    if a.ndim != 2 or b.ndim != 2:
        return False

    rows_a, cols_a = a.shape[0], a.shape[1]
    if rows_a != b.shape[1] or cols_a != b.shape[0]:
        return False

    if a.stride(0) != b.stride(1) or a.stride(1) != b.stride(0):
        return False

    same_offset = a.storage_offset() == b.storage_offset()
    return same_offset and a.data_ptr() == b.data_ptr()
def syrk_mm(a, c, M, K):
    """Compute ``c = a @ a.T`` with the SYRK kernel and return ``c``.

    Args:
        a: input matrix, shape (M, K).
        c: pre-allocated output; written in place.
        M, K: dimensions of ``a``.
    """

    def grid(META):
        # With tiles = ceil(M / BLOCK_M) tile rows, the packed lower triangle
        # holds 1 + 2 + ... + tiles = tiles * (tiles + 1) / 2 tiles.
        tiles = triton.cdiv(M, META["BLOCK_M"])
        return (tiles * (tiles + 1) // 2,)

    with torch_device_fn.device(a.device):
        mm_kernel_syrk[grid](
            a,
            c,
            M,
            K,
            a.stride(0),
            a.stride(1),
            c.stride(0),
            c.stride(1),
        )
    return c
def streamk_scenario(a, b, M, N, K):
    """Return whether the stream-K kernel should be used for this problem.

    Requires device capability major version 8, float16/bfloat16 contiguous
    inputs, and K more than five times both M and N.
    """
    # TODO: this may change according to real benchmark results.
    # The best stream-K configuration has so far only been tested on A100
    # (capability[0] == 8); optimal settings for other devices still need to
    # be determined through real testing.
    capability = get_device_capability()
    if capability[0] != 8:
        return False

    half_dtypes = [torch.float16, torch.bfloat16]
    if a.dtype not in half_dtypes or b.dtype not in half_dtypes:
        return False

    if not (a.is_contiguous() and b.is_contiguous()):
        return False

    # The reduction dimension must dominate both output dimensions.
    return K > M * 5 and K > N * 5
def mm(a, b):
    """Return the matrix product ``a @ b`` in a newly allocated tensor.

    Dispatch order: SYRK kernel when ``b`` is a transposed view of ``a``,
    then stream-K, then cluster-remote, then the general kernel.
    """
    logger.debug("GEMS MM")

    device = a.device

    # a @ a.T: handled by the dedicated kernel, output has a's dtype.
    if is_syrk_transpose_pair(a, b):
        rows, depth = a.shape
        result = torch.empty((rows, rows), device=device, dtype=a.dtype)
        return syrk_mm(a, result, rows, depth)

    # An operand is copied only when neither of its strides is 0 or 1.
    if min(a.stride(0), a.stride(1)) > 1:
        a = a.contiguous()
    if min(b.stride(0), b.stride(1)) > 1:
        b = b.contiguous()

    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape

    # The output uses the higher-ranked of the two input dtypes.
    out_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=device, dtype=out_dtype)

    # l2_cache_size = get_l2_cache_size()
    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, c, M, N, K, sm_count=sm_count)
    if cluster_remote_mm_scenario(a, b, c, M, N, K):
        return cluster_remote_mm(a, b, c, M, N, K)
    return general_mm(a, b, c, M, N, K)
def mm_out(a, b, *, out):
    """Write the matrix product ``a @ b`` into ``out`` and return it.

    Uses the same dispatch order as ``mm``: SYRK, stream-K, cluster-remote,
    then the general kernel.  ``out`` is used as given; it is not resized or
    checked here.
    """
    logger.debug("GEMS MM_OUT")

    # a @ a.T: handled by the dedicated kernel.
    if is_syrk_transpose_pair(a, b):
        rows, depth = a.shape
        return syrk_mm(a, out, rows, depth)

    # An operand is copied only when neither of its strides is 0 or 1.
    if min(a.stride(0), a.stride(1)) > 1:
        a = a.contiguous()
    if min(b.stride(0), b.stride(1)) > 1:
        b = b.contiguous()

    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape

    # l2_cache_size = get_l2_cache_size()
    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, out, M, N, K, sm_count=sm_count)
    if cluster_remote_mm_scenario(a, b, out, M, N, K):
        return cluster_remote_mm(a, b, out, M, N, K)
    return general_mm(a, b, out, M, N, K)
def router_gemm(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """bf16 x bf16 -> fp32 GEMM for MoE router gate. weight shape: (N, K).

    Computes ``x @ weight.T`` into a new float32 tensor of shape (M, N).

    Args:
        x: activations, shape (M, K).
        weight: gate weight, shape (N, K); it is transposed and made
            contiguous before the kernel launch.

    Returns:
        A float32 tensor of shape (M, N) on ``x``'s device.

    Raises:
        ValueError: if ``weight`` is not 2-D or its second dimension does not
            match ``x``'s K.
    """
    if x.stride(0) > 1 and x.stride(1) > 1:
        x = x.contiguous()
    M, K = x.shape

    # Without this check a mismatched weight would still be launched with
    # x's K, and the kernel would read weight with the wrong extents.
    if weight.ndim != 2 or weight.shape[1] != K:
        raise ValueError(
            f"incompatible dimensions: x is {tuple(x.shape)}, "
            f"weight must be (N, {K}) but is {tuple(weight.shape)}"
        )

    N = weight.shape[0]
    c = torch.empty((M, N), device=x.device, dtype=torch.float32)
    b = weight.t().contiguous()
    return general_mm(x, b, c, M, N, K)