Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/mm.py: 0%
432 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import os
3from typing import Optional
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.ops.mm_streamk import streamk_mm
11from flag_gems.runtime import torch_device_fn
12from flag_gems.utils import libentry, libtuner
13from flag_gems.utils import triton_lang_extension as ext
14from flag_gems.utils.device_info import get_device_capability, get_sm_count
15from flag_gems.utils.triton_version_utils import HAS_TLE, HAS_TLE_DEVICE_MESH
logger = logging.getLogger("flag_gems.runtime.backend._nvidia.hopper.ops.mm")

# NOTE(review): not referenced by any code visible in this module; kept in
# case other modules import it -- confirm before removing.
CACHE_USAGE_THRESHOLD = 0.8

# FlagTune "expand" YAML consumed by the libtuner decorators in this module;
# it sits one directory above this ops/ package.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(
        os.path.dirname(__file__),
        "..",
        "mm_hopper_expand.yaml",
    )
)

# Extra headroom (bytes) added to every shared-memory estimate when filtering
# autotune configs against the device's opt-in per-block limit.
_SHARED_MEM_SAFETY_MARGIN_BYTES = 1024
def _get_shared_memory_limit_bytes():
    """Best-effort query of the opt-in per-block shared-memory limit.

    Returns:
        ``shared_memory_per_block_optin`` of the current CUDA device as
        reported by ``torch.cuda.get_device_properties``, or ``None`` when
        CUDA is unavailable or the query raises for any reason.
    """
    try:
        if torch.cuda.is_available():
            device_index = torch.cuda.current_device()
            props = torch.cuda.get_device_properties(device_index)
            return props.shared_memory_per_block_optin
        return None
    except Exception:
        # Deliberately best-effort: callers treat None as "limit unknown".
        return None
def _estimate_tma_shared_memory_bytes(
    block_m, block_n, block_k, num_stages, bytes_per_element=4
):
    """Estimate shared memory used by a pipelined A/B tile configuration.

    Each pipeline stage holds one ``block_m x block_k`` tile of A and one
    ``block_k x block_n`` tile of B. ``_SHARED_MEM_SAFETY_MARGIN_BYTES`` is
    added on top. No output (``block_m x block_n``) tile is accounted for.

    Args:
        block_m, block_n, block_k: Tile sizes of the candidate config.
        num_stages: Number of software-pipeline stages.
        bytes_per_element: Operand element size in bytes. The default of 4
            (fp32) over-estimates for fp16/bf16, which keeps the filter
            conservative when the operand dtype is not yet known.

    Returns:
        Estimated shared-memory footprint in bytes.
    """
    per_stage_bytes = (block_m * block_k + block_k * block_n) * bytes_per_element
    return per_stage_bytes * num_stages + _SHARED_MEM_SAFETY_MARGIN_BYTES
# Launch/tiling constants for the cluster-remote GEMM path. They do not depend
# on whether the TLE device-mesh extension is importable, so they are defined
# once instead of being duplicated in both branches below.
TLE_CLUSTER_SIZE = 2
TLE_REMOTE_BM = 64
TLE_REMOTE_BN = 256
TLE_REMOTE_BK = 64
TLE_REMOTE_NUM_WARPS = 8
TLE_REMOTE_NUM_STAGES = 2
TLE_REMOTE_A_SLOTS = 2

if HAS_TLE_DEVICE_MESH:
    import triton.experimental.tle.language as tle_exp

    # One-axis block-cluster mesh whose "cluster_x" extent is tied to
    # TLE_CLUSTER_SIZE so the two can never disagree.
    BLOCK_CLUSTER_MESH = tle_exp.device_mesh(
        {"block_cluster": [("cluster_x", TLE_CLUSTER_SIZE)]}
    )
else:
    # Without the extension the cluster-remote path is disabled; callers
    # check BLOCK_CLUSTER_MESH is not None before using it.
    tle_exp = None
    BLOCK_CLUSTER_MESH = None
def is_tma_compatible(a, b, N, K):
    """
    Check if tensors are compatible with TMA (Tensor Memory Accelerator).

    TMA requires 128-bit (16-byte) alignment for memory access:
    - For FP16/BF16 (2 bytes/element): N and K must be multiples of 8
      (8 elements × 2 bytes = 16 bytes)
    - For FP32 (4 bytes/element): N and K must be multiples of 4
      (4 elements × 4 bytes = 16 bytes)

    Additionally:
    - ``a`` and ``b`` must have the same dtype. The host-TMA kernel feeds
      both tiles straight into ``tl.dot``, which rejects mixed fp16/bf16
      operands.
    - Each operand must be describable the way ``general_mm`` builds its
      descriptors:
      - a unit-stride dimension;
      - a 16-byte aligned leading stride;
      - a 16-byte aligned base address.
      This matters for a column-major ``a``, whose descriptor leading stride
      is M (not covered by the N/K checks), and for offset/strided views.

    Args:
        a, b: Input tensors
        N, K: Matrix dimensions

    Returns:
        bool: True if compatible with TMA's alignment requirements
    """
    if a.dtype != b.dtype:
        return False
    if a.dtype in (torch.float16, torch.bfloat16):
        elems_per_16_bytes = 8
    elif a.dtype == torch.float32:
        elems_per_16_bytes = 4
    else:
        return False
    if N % elems_per_16_bytes != 0 or K % elems_per_16_bytes != 0:
        return False
    return _is_tma_operand_layout_ok(a) and _is_tma_operand_layout_ok(b)


def _is_tma_operand_layout_ok(t):
    """Return True if 2-D ``t`` can back a host TMA descriptor as general_mm builds it.

    A row-major operand (``stride(1) == 1``) is described with ``stride(0)``
    as its leading stride. Otherwise the transposed view is used, which
    needs ``stride(0) == 1`` and uses ``stride(1)`` as the leading stride.
    """
    if t.dim() != 2:
        return False
    if t.stride(1) == 1:
        leading_stride = t.stride(0)
    elif t.stride(0) == 1:
        leading_stride = t.stride(1)
    else:
        # No unit-stride dimension: TMA needs a contiguous innermost dim.
        return False
    if (leading_stride * t.element_size()) % 16 != 0:
        return False
    # Tracing tensors have no real storage to inspect; Triton's own
    # TensorDescriptor skips the base-pointer check for them as well.
    if type(t).__name__ in ("FakeTensor", "FunctionalTensor"):
        return True
    return t.data_ptr() % 16 == 0
@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of b that is strictly less than a (assumes b > 0):
    # round a up to a multiple of b, then step back by one b.
    return (tl.cdiv(a, b) - 1) * b
def matmul_tma_set_block_size_hook(nargs):
    """Autotune pre-hook: size the host TMA descriptors to the chosen tile.

    A column-major operand is described through its transpose, so its block
    shape is swapped: A uses [BLOCK_K, BLOCK_M] and B uses [BLOCK_N, BLOCK_K].
    C always uses [BLOCK_M, BLOCK_N].
    """
    bm, bn, bk = nargs["BLOCK_M"], nargs["BLOCK_N"], nargs["BLOCK_K"]
    nargs["a_desc"].block_shape = [bm, bk] if nargs["A_ROW_MAJOR"] else [bk, bm]
    nargs["b_desc"].block_shape = [bk, bn] if nargs["B_ROW_MAJOR"] else [bn, bk]
    nargs["c_desc"].block_shape = [bm, bn]
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    # Add 'stride_am' and 'stride_bk' to trigger autotune for tensors with the same shape but different strides.
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["default", "default", "default", "default", "default"],
    warmup=5,
    rep=10,
)
@triton.jit
def mm_kernel_general(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Compute one BLOCK_M x BLOCK_N tile of C = A @ B.

    Two code paths:
    * Device-side TMA (``tl.make_tensor_descriptor``), taken only when every
      tile is full AND A, B, C are dense row-major. The descriptors below
      hard-code strides ``[K, 1]``, ``[N, 1]`` and ``[N, 1]``, so any other
      layout would read or write the wrong elements.
    * A generic pointer path that honours all strides and masks the K tail.

    On both paths, operands of different dtypes are cast to C's element type
    before ``tl.dot``, and the accumulator is converted to C's element type
    before the store. ``general_mm`` installs a global-scratch allocator
    before launching; the device-side descriptors presumably rely on it.
    """
    # matrix multiplication
    pid = ext.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    # Binary `and`s only, so the expression compiles on Triton versions that
    # reject chained boolean operators.
    full_tiles = (M % BLOCK_M == 0) and ((N % BLOCK_N == 0) and (K % BLOCK_K == 0))
    ab_row_major = ((stride_am == K) and (stride_ak == 1)) and (
        (stride_bk == N) and (stride_bn == 1)
    )
    c_row_major = (stride_cm == N) and (stride_cn == 1)

    if full_tiles and (ab_row_major and c_row_major):
        # offset
        offset_am = pid_m * BLOCK_M
        offset_bn = pid_n * BLOCK_N
        offset_k = 0

        a_desc = tl.make_tensor_descriptor(
            base=A,
            shape=[M, K],
            strides=[K, 1],
            block_shape=[BLOCK_M, BLOCK_K],
        )

        # row-major
        b_desc = tl.make_tensor_descriptor(
            base=B,
            shape=[K, N],
            strides=[N, 1],
            block_shape=[BLOCK_K, BLOCK_N],
        )

        c_desc = tl.make_tensor_descriptor(
            base=C,
            shape=[M, N],
            strides=[N, 1],
            block_shape=[BLOCK_M, BLOCK_N],
        )

        if IS_FP64:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
        else:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_K)):
            a = a_desc.load([offset_am.to(tl.int32), offset_k.to(tl.int32)])
            b = b_desc.load([offset_k.to(tl.int32), offset_bn.to(tl.int32)])
            # tl.dot needs matching operand dtypes (same rule as the generic path)
            if a.dtype != b.dtype:
                a = a.to(C.dtype.element_ty)
                b = b.to(C.dtype.element_ty)
            if IS_FP64:
                acc += tl.dot(a, b, allow_tf32=False)
            else:
                acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
            offset_k += BLOCK_K

        # Store in C's dtype (may differ from A's, e.g. bf16 inputs -> fp32 out)
        acc = acc.to(C.dtype.element_ty)
        c_desc.store([offset_am.to(tl.int32), offset_bn.to(tl.int32)], acc)

    else:
        # do matrix multiplication
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        # Wrap out-of-range rows/cols so full-tile loads stay in bounds;
        # the wrapped results are masked out at the store.
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
        rm = rm.to(tl.int64)
        rn = rn.to(tl.int64)
        prev_multiple = prev_multiple_of(K, BLOCK_K)

        if IS_FP64:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
        else:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for start_k in range(0, prev_multiple, BLOCK_K):
            rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
            a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
            b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
            if a.dtype != b.dtype:
                a = a.to(C.dtype.element_ty)
                b = b.to(C.dtype.element_ty)
            if IS_FP64:
                acc += tl.dot(a, b, allow_tf32=False)
            else:
                acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

        # loop peeling: the last (possibly partial) K chunk is masked
        rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
        mask_k = rk < K
        a = tl.load(
            A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
            mask=mask_k[None, :],
            other=0.0,
        )
        b = tl.load(
            B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
            mask=mask_k[:, None],
            other=0.0,
        )
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        if IS_FP64:
            acc += tl.dot(a, b, allow_tf32=False)
        else:
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

        acc = acc.to(C.dtype.element_ty)
        # rematerialize rm and rn to save registers
        rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
        rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
        offsets = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        # masked write-back: rows/cols past M/N (wrapped above) are dropped
        tl.store(offsets, acc, mask=mask)
def matmul_get_configs(pre_hook=matmul_tma_set_block_size_hook):
    """Build the autotune search space for the host-TMA mm kernel.

    Enumerates BLOCK_M in {32, 64, 128, 256}, BLOCK_N and BLOCK_K in
    {32, 64, 128}, num_stages in {2, 3, 4} and num_warps in {4, 8}, each with
    ``pre_hook`` attached. When the device's opt-in shared-memory limit is
    known, configs whose estimated footprint exceeds it are dropped. If that
    would drop everything, a warning is logged and the full list is returned.
    """
    candidates = []
    for BM in (32, 64, 128, 256):
        for BN in (32, 64, 128):
            for BK in (32, 64, 128):
                for s in (2, 3, 4):
                    for w in (4, 8):
                        candidates.append(
                            triton.Config(
                                {"BLOCK_M": BM, "BLOCK_N": BN, "BLOCK_K": BK},
                                num_stages=s,
                                num_warps=w,
                                pre_hook=pre_hook,
                            )
                        )

    limit = _get_shared_memory_limit_bytes()
    if limit is None:
        return candidates

    def fits(cfg):
        kw = cfg.kwargs
        needed = _estimate_tma_shared_memory_bytes(
            kw["BLOCK_M"], kw["BLOCK_N"], kw["BLOCK_K"], cfg.num_stages
        )
        return needed <= limit

    fitting = [cfg for cfg in candidates if fits(cfg)]
    if fitting:
        return fitting
    logger.warning(
        "No mm_general_tma config fits shared memory limit (%s bytes); falling back to unfiltered configs.",
        limit,
    )
    return candidates
@libentry()
@libtuner(
    configs=matmul_get_configs(),
    key=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    strategy=["align32", "align32", "align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
    flagtune_op_name="mm",
    flagtune_expand_op_name="mm_general_tma",
    flagtune_yaml_path=EXPAND_CONFIG_FILENAME,
    flagtune_pre_hook=matmul_tma_set_block_size_hook,
)
@triton.jit
def mm_kernel_general_host_tma(
    a_desc,
    b_desc,
    c_desc,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    A_ROW_MAJOR: tl.constexpr,
    B_ROW_MAJOR: tl.constexpr,
    dtype: tl.constexpr,
    enable_warp_specialization=True,
):
    """Compute one BLOCK_M x BLOCK_N tile of C = A @ B via host-built TMA descriptors.

    a_desc / b_desc / c_desc are created on the host by ``general_mm``. Their
    block shapes are set per autotune config by
    ``matmul_tma_set_block_size_hook``. A column-major operand is described
    through its transpose, so its tile is loaded as [BLOCK_K, BLOCK_M] (A) or
    [BLOCK_N, BLOCK_K] (B) and transposed in registers.

    The stride_* arguments are not read by this body. stride_am, stride_bk
    and ``dtype`` matter only as autotune keys (see ``key=`` above), and
    ``enable_warp_specialization`` is not read either.
    """
    pid = tl.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)

    # Grouped program ordering (same scheme as mm_kernel_general): GROUP_M
    # consecutive row tiles are visited for each column tile to improve L2 reuse.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    offset_am = (pid_m * BLOCK_M).to(tl.int32)
    offset_bn = (pid_n * BLOCK_N).to(tl.int32)
    iters = tl.cdiv(K, BLOCK_K)
    # NOTE(review): there are no explicit masks; partial tiles at the M/N/K
    # edges presumably rely on TMA's out-of-bounds handling (zero-filled loads,
    # clipped stores) -- confirm against the descriptor padding mode.
    for k in range(iters):
        offset_ak = (k * BLOCK_K).to(tl.int32)

        if A_ROW_MAJOR:
            a = a_desc.load([offset_am, offset_ak])
        else:
            # descriptor covers A^T ([K, M]); load [BLOCK_K, BLOCK_M] and transpose
            a_t = a_desc.load([offset_ak, offset_am])
            a = tl.trans(a_t)

        if B_ROW_MAJOR:
            b = b_desc.load([offset_ak, offset_bn])
        else:
            # descriptor covers B^T ([N, K]); load [BLOCK_N, BLOCK_K] and transpose
            b_t = b_desc.load([offset_bn, offset_ak])
            b = tl.trans(b_t)

        if a_desc.dtype == tl.float16 or a_desc.dtype == tl.bfloat16:
            accumulator = tl.dot(a, b, acc=accumulator, allow_tf32=False)
        else:
            # fp32 inputs: 3xTF32 emulation for near-fp32 accuracy on tensor cores
            accumulator = tl.dot(a, b, acc=accumulator, input_precision="tf32x3")

    # Convert to the output descriptor's dtype (may differ from the inputs', e.g. fp32 out)
    c = accumulator.to(c_desc.dtype)
    c_desc.store([offset_am, offset_bn], c)
def get_higher_dtype(a, b):
    """Return the higher-ranked of two floating dtypes, used for the matmul output.

    Ranking (low -> high): float16 < bfloat16 < float32 < float64. This is a
    simple ranking, not ``torch.promote_types`` semantics: float16 combined
    with bfloat16 yields bfloat16. Identical dtypes are returned unchanged
    without validation.

    Raises:
        TypeError: if the dtypes differ and either is not one of the four
            ranked dtypes. This is an explicit raise rather than ``assert``,
            which is stripped under ``python -O``.
    """
    _ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32, torch.float64]

    if a is b:
        return a

    for dt in (a, b):
        if dt not in _ordered_datatypes:
            raise TypeError(f"get_higher_dtype: unsupported dtype {dt}")

    return max(a, b, key=_ordered_datatypes.index)
def general_mm(a, b, c, M, N, K, op_name="mm"):
    """Compute ``c = a @ b`` with the general (non split-K) Hopper kernels.

    Uses ``mm_kernel_general_host_tma`` when Triton provides the host-side
    ``TensorDescriptor`` and A, B and C all satisfy the TMA layout rules.
    Otherwise it falls back to ``mm_kernel_general``, for which a Triton
    allocator is installed first (its device-side descriptors need scratch).

    Args:
        a: (M, K) input.
        b: (K, N) input.
        c: (M, N) output, written in place.
        M, N, K: Problem sizes.
        op_name: Label used only in the debug log.

    Returns:
        ``c``.
    """
    # TODO: Remove this debug message
    logger.debug(
        "GEMS MM-hopper, [op]: %s, [mm scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        op_name,
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    # Broadcast tensors from expand() have stride=0, incompatible with TMA
    if 0 in a.stride():
        a = a.contiguous()
    if 0 in b.stride():
        b = b.contiguous()
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )
    TensorDescriptor = _load_host_tensor_descriptor()
    use_host_tma = (
        TensorDescriptor is not None
        and is_tma_compatible(a, b, N, K)
        and _is_tma_output_compatible(c)
    )
    if use_host_tma:
        a_row_major = a.stride(1) == 1
        b_row_major = b.stride(1) == 1
        # Real block shapes are filled in per config by matmul_tma_set_block_size_hook.
        dummy_block = [1, 1]

        # A column-major operand is described through its transpose's shape/strides.
        if a_row_major:
            a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block)
        else:
            a_desc = TensorDescriptor(a, a.T.shape, a.T.stride(), dummy_block)
        if b_row_major:
            b_desc = TensorDescriptor(b, b.shape, b.stride(), dummy_block)
        else:
            b_desc = TensorDescriptor(b, b.T.shape, b.T.stride(), dummy_block)
        c_desc = TensorDescriptor(c, c.shape, c.stride(), dummy_block)

        input_dtype = a.dtype
        dtype_str = str(input_dtype).split(".")[-1]

        with torch_device_fn.device(a.device):
            mm_kernel_general_host_tma[grid](
                a_desc,
                b_desc,
                c_desc,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                GROUP_M=8,
                A_ROW_MAJOR=a_row_major,
                B_ROW_MAJOR=b_row_major,
                dtype=dtype_str,
            )
    else:

        def alloc_fn(size: int, align: int, stream: Optional[int]):
            return torch.empty(size, dtype=torch.int8, device=a.device)

        triton.set_allocator(alloc_fn)

        with torch_device_fn.device(a.device):
            mm_kernel_general[grid](
                a,
                b,
                c,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                GROUP_M=8,
                IS_FP64=a.dtype == torch.float64,
            )
    return c


def _load_host_tensor_descriptor():
    """Return Triton's host-side ``TensorDescriptor`` class, or None if unavailable.

    The submodule is imported explicitly rather than reached as an attribute
    of ``triton.tools``: ``import triton`` may not load it, and older Triton
    releases do not ship it at all.
    """
    try:
        from triton.tools.tensor_descriptor import TensorDescriptor
    except ImportError:
        return None
    return TensorDescriptor


def _is_tma_output_compatible(c):
    """Return True if output ``c`` can be described by a host TMA descriptor as-is.

    ``general_mm`` describes ``c`` with its own shape and strides. That needs:
    - a unit-stride last dimension;
    - a 16-byte aligned leading stride;
    - a 16-byte aligned base address.
    For example, a transposed ``out`` passed to ``mm_out`` fails this check
    and uses the fallback kernel instead of tripping a descriptor assertion.
    """
    if c.dim() != 2 or c.stride(1) != 1:
        return False
    if (c.stride(0) * c.element_size()) % 16 != 0:
        return False
    # Tracing tensors have no real storage whose address could be checked.
    if type(c).__name__ in ("FakeTensor", "FunctionalTensor"):
        return True
    return c.data_ptr() % 16 == 0
@libentry()
@libtuner(
    configs=[
        triton.Config(
            {"BLOCK_M": 32, "BLOCK_K": 256},
        )
    ],
    key=["M", "K", "stride_am", "stride_bk"],
    strategy=["align32", "align32", "align32", "default"],
    warmup=5,
    rep=10,
    flagtune_op_name="mm",
    flagtune_expand_op_name="gemv",
    flagtune_yaml_path=EXPAND_CONFIG_FILENAME,
    flagtune_pre_hook=None,
)
@triton.jit
def gemv_kernel(
    A,
    B,
    C,
    M,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Optimized kernel for matrix-vector multiplication (N=1 case)

    Each program reduces BLOCK_M rows of A against vector B over all of K and
    writes row i's result to ``C + i``, i.e. C is assumed to have unit stride
    along M. Address offsets are formed in int64 so that ``row * stride_am``
    cannot wrap when A has more than 2**31 elements.
    """
    pid = tl.program_id(0)

    # Each program handles BLOCK_M rows
    row_start = pid * BLOCK_M
    row_offset = row_start + tl.arange(0, BLOCK_M)
    row_mask = row_offset < M
    # int64 copy used only for pointer arithmetic (avoids int32 overflow)
    row_offset_i64 = row_offset.to(tl.int64)

    # Accumulator for this block of rows
    if IS_FP64:
        acc = tl.zeros((BLOCK_M,), dtype=tl.float64)
    else:
        acc = tl.zeros((BLOCK_M,), dtype=tl.float32)

    # Iterate over K dimension
    for k_start in range(0, K, BLOCK_K):
        k_offset = k_start + tl.arange(0, BLOCK_K)
        k_mask = k_offset < K
        k_offset_i64 = k_offset.to(tl.int64)

        # Load block from matrix A: [BLOCK_M, BLOCK_K]
        a_ptrs = (
            A
            + row_offset_i64[:, None] * stride_am
            + k_offset_i64[None, :] * stride_ak
        )
        a = tl.load(a_ptrs, mask=row_mask[:, None] & k_mask[None, :], other=0.0)

        # Load block from vector B: [BLOCK_K]
        b_ptrs = B + k_offset_i64 * stride_bk
        b = tl.load(b_ptrs, mask=k_mask, other=0.0)

        # Accumulate: sum over K dimension
        if IS_FP64:
            acc += tl.sum(a * b[None, :], axis=1)
        else:
            acc += tl.sum(a.to(tl.float32) * b.to(tl.float32)[None, :], axis=1)

    # Store result
    c_ptrs = C + row_offset_i64
    acc = acc.to(C.dtype.element_ty)
    tl.store(c_ptrs, acc, mask=row_mask)
def gemv_mm(a, b, c, M, K):
    """Optimized matrix-vector multiplication for N=1 case

    Args:
        a: (M, K) matrix.
        b: (K, 1) vector-like operand; only ``b.stride(0)`` is used.
        c: (M, 1) output, written in place (may be a strided view).
        M, K: Problem sizes.

    Returns:
        ``c``.
    """
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: gemv (N=1), [shape info]: [%s, %s, 1](M, K, N)",
        M,
        K,
    )

    # gemv_kernel writes row i at C + i (unit stride along M). A caller-provided
    # output (e.g. mm_out's `out` being a column slice) may have a different
    # row stride, so compute into a dense buffer and copy back in that case.
    needs_staging = M > 1 and c.stride(0) != 1
    target = (
        torch.empty(c.shape, device=c.device, dtype=c.dtype) if needs_staging else c
    )

    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)

    with torch_device_fn.device(a.device):
        gemv_kernel[grid](
            a,
            b,
            target,
            M,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            IS_FP64=a.dtype == torch.float64,
        )
    if needs_staging:
        c.copy_(target)
    return c
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm_splitk"),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    reset_to_zero=["C"],
    strategy=["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=10,
    flagtune_op_name="mm",
    flagtune_expand_op_name="mm_splitk",
    flagtune_yaml_path=EXPAND_CONFIG_FILENAME,
    flagtune_pre_hook=None,
)
@triton.jit
def mm_kernel_splitk(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
):
    """Split-K matmul: program (tile, pid_k) atomically adds a partial product into C.

    Axis 0 enumerates output tiles row-major over the (M, N) tile grid. Axis 1
    enumerates SPLIT_K contiguous ranges of BLOCK_K-wide K chunks. Partials
    are combined with ``tl.atomic_add``, so C must be zero before launch (the
    autotuner resets it via ``reset_to_zero``; callers also zero it).

    Accumulation is float64 when C is float64, because ``tl.dot`` on fp64
    inputs yields fp64, which cannot be added into an fp32 loop-carried
    accumulator. Otherwise it is float32. Operands of different dtypes are
    cast to C's element type first, as ``tl.dot`` requires matching dtypes.
    Indices are int64 so offsets such as ``row * stride_am`` cannot wrap for
    very large K.
    """
    pid = tl.program_id(0)
    pid_k = tl.program_id(1)

    grid_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // grid_n
    pid_n = pid % grid_n

    offset_am = pid_m * BLOCK_M
    offset_bn = pid_n * BLOCK_N
    offs_am = (offset_am + tl.arange(0, BLOCK_M)).to(tl.int64)
    offs_bn = (offset_bn + tl.arange(0, BLOCK_N)).to(tl.int64)

    # This program's share of the K chunks: [k_start, k_end)
    total_k_iters = tl.cdiv(K, BLOCK_K)
    k_per_split = tl.cdiv(total_k_iters, SPLIT_K)
    k_start = pid_k * k_per_split
    k_end = min((pid_k + 1) * k_per_split, total_k_iters)

    if C.dtype.element_ty == tl.float64:
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
    else:
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(k_start, k_end):
        offs_k = (k * BLOCK_K + tl.arange(0, BLOCK_K)).to(tl.int64)

        a = tl.load(
            A + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak,
            mask=(offs_am[:, None] < M) & (offs_k[None, :] < K),
            other=0.0,
        )
        b = tl.load(
            B + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn,
            mask=(offs_k[:, None] < K) & (offs_bn[None, :] < N),
            other=0.0,
        )
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        if C.dtype.element_ty == tl.float64:
            acc += tl.dot(a, b, allow_tf32=False)
        else:
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    c_ptrs = C + offs_am[:, None] * stride_cm + offs_bn[None, :] * stride_cn
    mask = (offs_am < M)[:, None] & (offs_bn < N)[None, :]
    tl.atomic_add(c_ptrs, acc, mask=mask)
def splitk_mm(a, b, c, M, N, K, op_name="mm"):
    """Launch ``mm_kernel_splitk`` to accumulate ``a @ b`` into ``c``.

    ``c`` must already be zero because the kernel atomically adds partial
    products into it. The grid is 2-D:
    - axis 0 has one program per (BLOCK_M x BLOCK_N) output tile;
    - axis 1 has SPLIT_K programs, one per K partition.

    Returns:
        ``c``.
    """
    logger.debug(
        "GEMS MM-hopper, [op]: %s, [mm scenario]: splitk, [shape info]: [-, %s, %s, %s](batch, M, N, K)",
        op_name,
        M,
        N,
        K,
    )

    def grid(meta):
        output_tiles = triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"])
        return (output_tiles, meta["SPLIT_K"])

    with torch_device_fn.device(a.device):
        stride_args = (
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
        )
        mm_kernel_splitk[grid](a, b, c, M, N, K, *stride_args)
    return c
def streamk_scenario(a, b, M, N, K):
    """Return True when the stream-K kernel should handle this problem.

    TODO: these thresholds may change according to real benchmark results.
    The stream-K settings have so far only been tuned on A100 (compute
    capability major == 8); the optimal settings for other devices still
    need to be determined through real testing.

    Requires fp16/bf16 contiguous operands and a K that is more than five
    times both M and N.
    """
    if get_device_capability()[0] != 8:
        return False
    half_types = [torch.float16, torch.bfloat16]
    if a.dtype not in half_types or b.dtype not in half_types:
        return False
    if not a.is_contiguous() or not b.is_contiguous():
        return False
    return K > M * 5 and K > N * 5
if HAS_TLE:

    @triton.jit
    def _cluster_remote_gemm_kernel(
        a_ptr,
        b_ptr,
        c_ptr,
        M,
        N,
        K,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        mesh: tl.constexpr,
        BM: tl.constexpr,
        BN: tl.constexpr,
        BK: tl.constexpr,
        DOT_K: tl.constexpr,
        CLUSTER_SIZE: tl.constexpr,
        USE_MASK: tl.constexpr,
        A_SLOTS: tl.constexpr,
        USE_NV_MMA_SMEM_LAYOUT: tl.constexpr,
    ):
        """Cluster-cooperative GEMM (TLE experimental extension): C = A @ B.

        The CLUSTER_SIZE programs of a cluster share one BM-row band of A and
        each compute a different BN-wide column tile of C. Only cluster rank 0
        loads A tiles from global memory into its ``a_buf`` ring buffer
        (A_SLOTS slots of [BM, BK], fp16). Every rank reads them through
        ``tle_exp.remote(a_buf, 0, scope=mesh)``, presumably a cluster-wide
        view of rank 0's buffer (TODO confirm TLE semantics). B is loaded
        from global memory by each rank directly.

        All ranks must reach every ``distributed_barrier``, so those calls are
        kept outside the rank-0-only branches.
        """
        pid = tl.program_id(0)
        cluster_rank = tle_exp.shard_id(mesh, "cluster_x")
        cluster_id = pid // CLUSTER_SIZE

        # N tiles are grouped CLUSTER_SIZE at a time; each rank takes one of them.
        # NOTE(review): when cdiv(N, BN) is not a multiple of CLUSTER_SIZE the
        # last group's higher ranks get pid_n past N; their B loads / C stores
        # are only bounded when USE_MASK is set.
        num_pid_n = tl.cdiv(N, BN)
        num_pid_n_group = tl.cdiv(num_pid_n, CLUSTER_SIZE)
        pid_m = cluster_id // num_pid_n_group
        pid_ng = cluster_id % num_pid_n_group
        pid_n = pid_ng * CLUSTER_SIZE + cluster_rank

        offs_m = pid_m * BM + tl.arange(0, BM)
        offs_n = pid_n * BN + tl.arange(0, BN)
        offs_k = tl.arange(0, BK)
        # Tile-local row indices (equal to tl.arange(0, BM)) used to address a_buf.
        a_row_base = offs_m - pid_m * BM
        # Index grids for writing a full [BM, BK] tile into a_buf ...
        a_rows_full = tl.broadcast_to(a_row_base[:, None], (BM, BK))
        a_cols_full = tl.broadcast_to(tl.arange(0, BK)[None, :], (BM, BK))
        # ... and for reading a transposed [DOT_K, BM] slice of it.
        a_rows_t = tl.broadcast_to(a_row_base[None, :], (DOT_K, BM))
        a_buf = tle_exp.gpu.alloc(
            [A_SLOTS, BM, BK],
            dtype=tl.float16,
            layout=None,
            scope=tle_exp.gpu.smem,
            nv_mma_shared_layout=USE_NV_MMA_SMEM_LAYOUT,
        )
        a_buf_remote = tle_exp.remote(a_buf, 0, scope=mesh)

        acc = tl.zeros((BM, BN), dtype=tl.float32)
        # Prologue: rank 0 stages the first A tile (k0 = 0) into slot 0.
        slot0 = 0
        slot0_full = tl.zeros((BM, BK), dtype=tl.int32) + slot0
        if cluster_rank == 0:
            a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
            if USE_MASK:
                a_mask_tile = (offs_m[:, None] < M) & (offs_k[None, :] < K)
                a_tile = tl.load(a_ptrs, mask=a_mask_tile, other=0.0)
            else:
                a_tile = tl.load(a_ptrs)
            a_local_ptr_tile = tle_exp.gpu.local_ptr(
                a_buf, (slot0_full, a_rows_full, a_cols_full)
            )
            if USE_MASK:
                tl.store(a_local_ptr_tile, a_tile, mask=a_mask_tile)
            else:
                tl.store(a_local_ptr_tile, a_tile)

        # Make the staged tile visible to every rank before the first read.
        tle_exp.distributed_barrier(mesh)

        for k0 in range(0, K, BK):
            iter_idx = k0 // BK
            slot = iter_idx % A_SLOTS

            # Consume the current A tile in DOT_K-wide slices.
            for ks in range(0, BK, DOT_K):
                k_local = ks + tl.arange(0, DOT_K)
                a_cols_t = tl.broadcast_to(k_local[:, None], (DOT_K, BM))
                slot_dot_t = tl.zeros((DOT_K, BM), dtype=tl.int32) + slot
                a_ptr_remote = tle_exp.gpu.local_ptr(
                    a_buf_remote, (slot_dot_t, a_rows_t, a_cols_t)
                )
                if USE_MASK:
                    # Masked-off a_buf entries were never written (the store
                    # above is masked), so the read must be masked as well.
                    a_mask_t = ((k0 + k_local)[:, None] < K) & (offs_m[None, :] < M)
                    a = tl.trans(tl.load(a_ptr_remote, mask=a_mask_t, other=0.0))
                else:
                    a = tl.trans(tl.load(a_ptr_remote))

                b_ptrs = (
                    b_ptr
                    + (k0 + k_local)[:, None] * stride_bk
                    + offs_n[None, :] * stride_bn
                )
                if USE_MASK:
                    b_mask = ((k0 + k_local)[:, None] < K) & (offs_n[None, :] < N)
                    b = tl.load(b_ptrs, mask=b_mask, other=0.0)
                else:
                    b = tl.load(b_ptrs)
                acc = tl.dot(a, b, acc)

            # With a single slot the prefetch below overwrites the slot just
            # consumed, so every rank must finish reading it first.
            if A_SLOTS == 1:
                tle_exp.distributed_barrier(mesh)

            # Rank 0 prefetches the next A tile into the next ring slot.
            next_k0 = k0 + BK
            has_next = next_k0 < K
            next_iter = iter_idx + 1
            next_slot = next_iter % A_SLOTS
            next_slot_full = tl.zeros((BM, BK), dtype=tl.int32) + next_slot
            if has_next and cluster_rank == 0:
                a_ptrs = (
                    a_ptr
                    + offs_m[:, None] * stride_am
                    + (next_k0 + offs_k)[None, :] * stride_ak
                )
                if USE_MASK:
                    a_mask_tile = (offs_m[:, None] < M) & (
                        (next_k0 + offs_k)[None, :] < K
                    )
                    a_tile = tl.load(a_ptrs, mask=a_mask_tile, other=0.0)
                else:
                    a_tile = tl.load(a_ptrs)
                a_local_ptr_tile = tle_exp.gpu.local_ptr(
                    a_buf, (next_slot_full, a_rows_full, a_cols_full)
                )
                if USE_MASK:
                    tl.store(a_local_ptr_tile, a_tile, mask=a_mask_tile)
                else:
                    tl.store(a_local_ptr_tile, a_tile)

            # Publish the prefetched tile before the next iteration reads it.
            tle_exp.distributed_barrier(mesh)

        c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
        if USE_MASK:
            c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
            tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)
        else:
            tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty))
def _select_remote_dot_k(bk: int) -> int:
    """Return the K-slice width of each ``tl.dot`` in the cluster-remote kernel.

    Raises:
        ValueError: if ``bk`` is not a multiple of 16.
    """
    if bk % 16 != 0:
        raise ValueError(f"BK must be divisible by 16 for remote dot path, got BK={bk}")
    return 16
def _grid_cluster_remote(
    M: int,
    N: int,
    BM: int,
    BN: int,
    cluster_size: int = TLE_CLUSTER_SIZE,
) -> tuple:
    """1-D launch grid for the cluster-remote kernel.

    N tiles are grouped ``cluster_size`` at a time, and the grid size is
    (number of M tiles) * (number of N-tile groups).
    """
    m_tiles = triton.cdiv(M, BM)
    n_tile_groups = triton.cdiv(triton.cdiv(N, BN), cluster_size)
    return (m_tiles * n_tile_groups,)
def _run_cluster_remote(
    a: torch.Tensor,
    b: torch.Tensor,
    c: torch.Tensor,
    bm: int,
    bn: int,
    bk: int,
    num_warps: int,
    num_stages: int,
) -> None:
    """Launch ``_cluster_remote_gemm_kernel`` computing ``c = a @ b``.

    Args:
        a: (M, K) input.
        b: (K, N) input.
        c: (M, N) output.
        bm, bn, bk: Tile sizes.
        num_warps, num_stages: Triton launch parameters.

    Raises:
        ValueError: if ``bk`` is not a multiple of 16 (from _select_remote_dot_k).
    """
    M, K = a.shape
    N = b.shape[1]
    dot_k = _select_remote_dot_k(bk)
    # The kernel hands consecutive N tiles to the ranks of a cluster. If the
    # number of N tiles is not a multiple of TLE_CLUSTER_SIZE, the last
    # cluster has a rank whose column tile lies entirely past N, even when
    # N % bn == 0. That rank's loads/stores are bounds-checked only under
    # USE_MASK, so require N to be a multiple of bn * TLE_CLUSTER_SIZE for the
    # unmasked path (this also covers N % bn != 0).
    use_mask = (
        (M % bm != 0)
        or (N % (bn * TLE_CLUSTER_SIZE) != 0)
        or (K % bk != 0)
    )
    a_slots = TLE_REMOTE_A_SLOTS
    use_nv_mma_smem_layout = (bk == 32) or (bk == 64 and num_stages <= 2)
    _cluster_remote_gemm_kernel[_grid_cluster_remote(M, N, bm, bn)](
        a,
        b,
        c,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        mesh=BLOCK_CLUSTER_MESH,
        BM=bm,
        BN=bn,
        BK=bk,
        DOT_K=dot_k,
        CLUSTER_SIZE=TLE_CLUSTER_SIZE,
        USE_MASK=use_mask,
        A_SLOTS=a_slots,
        USE_NV_MMA_SMEM_LAYOUT=use_nv_mma_smem_layout,
        num_ctas=1,
        num_warps=num_warps,
        num_stages=num_stages,
    )
def cluster_remote_mm_scenario(a, b, c, M, N, K):
    """Return True if the TLE cluster-remote fp16 GEMM path applies.

    Requires all of the following:
    - the TLE extension and its block-cluster mesh;
    - compute capability major >= 9;
    - CUDA fp16 tensors for a, b and c;
    - contiguous a and b;
    - problem sizes at least one remote tile in each dimension.
    """
    capability = get_device_capability()
    if not (HAS_TLE and BLOCK_CLUSTER_MESH is not None and capability[0] >= 9):
        return False
    operands = (a, b, c)
    if not all(t.is_cuda for t in operands):
        return False
    if not all(t.dtype == torch.float16 for t in operands):
        return False
    if not (a.is_contiguous() and b.is_contiguous()):
        return False
    return M >= TLE_REMOTE_BM and N >= TLE_REMOTE_BN and K >= TLE_REMOTE_BK
def cluster_remote_mm(a, b, c, M, N, K):
    """Compute ``c = a @ b`` with the TLE cluster-remote kernel and return ``c``."""
    # The original call passed M as the log "format string", which makes
    # logging report a formatting error whenever debug output is enabled.
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: cluster_remote, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    with torch_device_fn.device(a.device):
        _run_cluster_remote(
            a,
            b,
            c,
            TLE_REMOTE_BM,
            TLE_REMOTE_BN,
            TLE_REMOTE_BK,
            TLE_REMOTE_NUM_WARPS,
            TLE_REMOTE_NUM_STAGES,
        )
    return c
def _normalize_mm_inputs(a, b):
    """Make operands with no unit-stride dimension contiguous and check shapes.

    Returns:
        The (possibly copied) ``(a, b)`` pair.
    """
    # handle non-contiguous inputs if necessary
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    return a, b


def _mm_dispatch(a, b, c, M, N, K):
    """Choose and run the Hopper matmul path for ``c = a @ b``; return ``c``.

    Shared by ``mm`` and ``mm_out`` so the two entry points cannot drift.
    Paths are tried in order:
    1. gemv for N == 1;
    2. stream-K;
    3. TLE cluster-remote;
    4. split-K when M and N are small and K is large;
    5. otherwise the general kernels.
    """
    # Optimize for N=1 case (matrix-vector multiplication)
    if N == 1:
        return gemv_mm(a, b, c, M, K)
    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, c, M, N, K, sm_count=sm_count)
    if HAS_TLE and BLOCK_CLUSTER_MESH is not None:
        if cluster_remote_mm_scenario(a, b, c, M, N, K):
            return cluster_remote_mm(a, b, c, M, N, K)
    # Use split-K for small M and N with a long K reduction; the kernel
    # accumulates atomically, so the output must start at zero.
    if M < 2048 and N < 2048 and K >= 4096:
        c.zero_()
        return splitk_mm(a, b, c, M, N, K)
    return general_mm(a, b, c, M, N, K)


def mm(a, b):
    """Return ``a @ b`` in a newly allocated tensor of the higher input dtype."""
    device = a.device
    a, b = _normalize_mm_inputs(a, b)
    M, K = a.shape
    _, N = b.shape
    # allocates output
    c_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=device, dtype=c_dtype)
    return _mm_dispatch(a, b, c, M, N, K)


def mm_out(a, b, *, out):
    """Compute ``a @ b`` into the caller-provided ``out`` tensor and return it."""
    a, b = _normalize_mm_inputs(a, b)
    M, K = a.shape
    _, N = b.shape
    return _mm_dispatch(a, b, out, M, N, K)
def router_gemm(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """bf16 x bf16 -> fp32 GEMM for MoE router gate. weight shape: (N, K).

    Computes ``x @ weight.T`` into a freshly allocated float32 (M, N) tensor.
    Split-K is used for small M/N with K >= 4096; otherwise the general path.
    """
    if x.stride(0) > 1 and x.stride(1) > 1:
        x = x.contiguous()
    M, K = x.shape
    N = weight.shape[0]
    out = torch.empty((M, N), device=x.device, dtype=torch.float32)
    weight_t = weight.t()
    use_splitk = M < 2048 and N < 2048 and K >= 4096
    if not use_splitk:
        return general_mm(x, weight_t, out, M, N, K, op_name="router_gemm")
    # split-K accumulates atomically into the output, so it must start at zero
    out.zero_()
    return splitk_mm(x, weight_t, out, M, N, K, op_name="router_gemm")