Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/mm.py: 0%
432 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
2import os
3from typing import Optional
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.ops.mm_streamk import streamk_mm
11from flag_gems.runtime import torch_device_fn
12from flag_gems.utils import libentry, libtuner
13from flag_gems.utils import triton_lang_extension as ext
14from flag_gems.utils.device_info import get_device_capability, get_sm_count
15from flag_gems.utils.triton_version_utils import HAS_TLE, HAS_TLE_DEVICE_MESH
# Module logger, registered under this backend module's dotted path.
logger = logging.getLogger("flag_gems.runtime.backend._nvidia.hopper.ops.mm")

# NOTE(review): not referenced by any code visible in this module; presumably
# consumed elsewhere — confirm before removing.
CACHE_USAGE_THRESHOLD = 0.8

# Tuning-space YAML consulted when USE_FLAGTUNE=1; it sits one directory above
# the directory containing this file.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(os.path.dirname(__file__), os.pardir, "mm_hopper_expand.yaml")
)

# Extra bytes added on top of the estimated A/B tile footprint when autotune
# configs are filtered against the device shared-memory limit.
_SHARED_MEM_SAFETY_MARGIN_BYTES = 1024
def _get_shared_memory_limit_bytes():
    """Look up the opt-in per-block shared-memory limit of the current CUDA device.

    Returns:
        int | None: ``shared_memory_per_block_optin`` of the active device, or
        ``None`` when CUDA is unavailable or the property query raises.
    """
    try:
        if torch.cuda.is_available():
            props = torch.cuda.get_device_properties(torch.cuda.current_device())
            return props.shared_memory_per_block_optin
        return None
    except Exception:
        # Best effort: a failed query simply means "limit unknown".
        return None
def _estimate_tma_shared_memory_bytes(block_m, block_n, block_k, num_stages):
    """Estimate the shared-memory bytes of a config's multi-stage A/B tile ring.

    Each pipeline stage holds one ``block_m x block_k`` A tile and one
    ``block_k x block_n`` B tile. Every element is costed at 4 bytes (fp32
    size), which over-estimates half-precision tiles. A fixed safety margin is
    added on top.

    NOTE(review): shared memory the compiler may use beyond the A/B ring
    (e.g. for the epilogue) is not counted.
    """
    elems_per_stage = block_k * (block_m + block_n)
    return num_stages * elems_per_stage * 4 + _SHARED_MEM_SAFETY_MARGIN_BYTES
# Tile / launch parameters of the cluster "remote A" GEMM path. They were
# previously duplicated verbatim in both branches below; they do not depend on
# device-mesh support, so they are defined once here.
TLE_CLUSTER_SIZE = 2
TLE_REMOTE_BM = 64
TLE_REMOTE_BN = 256
TLE_REMOTE_BK = 64
TLE_REMOTE_NUM_WARPS = 8
TLE_REMOTE_NUM_STAGES = 2
TLE_REMOTE_A_SLOTS = 2

if HAS_TLE_DEVICE_MESH:
    import triton.experimental.tle.language as tle_exp

    # Mesh with a single "cluster_x" axis; its size is tied to
    # TLE_CLUSTER_SIZE so the kernel's CLUSTER_SIZE and the mesh cannot drift.
    BLOCK_CLUSTER_MESH = tle_exp.device_mesh(
        {"block_cluster": [("cluster_x", TLE_CLUSTER_SIZE)]}
    )
else:
    # Without device-mesh support the cluster path is disabled
    # (callers check BLOCK_CLUSTER_MESH is not None).
    tle_exp = None
    BLOCK_CLUSTER_MESH = None
def is_tma_compatible(a, b, N, K):
    """
    Report whether a, b and the N/K sizes meet TMA's dtype and alignment rules.

    TMA (Tensor Memory Accelerator) accesses memory in 16-byte (128-bit)
    units, so N and K must be multiples of ``16 / element_size``:
    - FP16/BF16 inputs (2 bytes/element): multiples of 8
    - FP32 inputs (4 bytes/element): multiples of 4
    Any other dtype combination (including mixing half and fp32) is rejected.

    Args:
        a, b: Input tensors (only their dtypes are inspected)
        N, K: Matrix dimensions

    Returns:
        bool: True if compatible with TMA's alignment requirements
    """
    half_types = (torch.float16, torch.bfloat16)
    if a.dtype in half_types and b.dtype in half_types:
        alignment = 8
    elif a.dtype == torch.float32 and b.dtype == torch.float32:
        alignment = 4
    else:
        return False
    return N % alignment == 0 and K % alignment == 0
@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of b strictly below a, i.e. (ceil(a / b) - 1) * b.
    # NOTE: for a == 0 this yields -b, so callers must not rely on it for K == 0.
    return (tl.cdiv(a, b) - 1) * b
def matmul_tma_set_block_size_hook(nargs):
    """Autotune pre-hook: size the host TMA descriptors' blocks to the chosen tile.

    The A and B descriptors are built over the transposed view when the operand
    is not row-major, so their block shape is swapped accordingly. C always
    uses ``[BLOCK_M, BLOCK_N]``. Mutates the descriptors in ``nargs`` in place.
    """
    bm, bn, bk = nargs["BLOCK_M"], nargs["BLOCK_N"], nargs["BLOCK_K"]
    nargs["a_desc"].block_shape = [bm, bk] if nargs["A_ROW_MAJOR"] else [bk, bm]
    nargs["b_desc"].block_shape = [bk, bn] if nargs["B_ROW_MAJOR"] else [bn, bk]
    nargs["c_desc"].block_shape = [bm, bn]
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    # Add 'stride_am' and 'stride_bk' to trigger autotune for tensors with the same shape but different strides.
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["default", "default", "default", "default", "default"],
    warmup=5,
    rep=10,
)
@triton.jit
def mm_kernel_general(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Tiled C = A @ B; each program computes one BLOCK_M x BLOCK_N tile of C.

    Fast path: when M, N, K divide their tile sizes AND A, B, C are dense
    row-major (strides (K, 1), (N, 1), (N, 1)), tiles move through device-side
    tensor descriptors. The descriptors hard-code those strides, so any other
    layout must take the pointer path.

    Pointer path: arbitrary strides. Full K blocks load unmasked, and the last
    K block is peeled and masked. Out-of-range rows/cols are wrapped modulo M/N
    for loads and masked at the store.

    Accumulation is float64 when IS_FP64, otherwise float32. Mismatched input
    dtypes are cast to C's element type before tl.dot, and the result is cast
    to C's element type before the store.
    """
    # matrix multiplication
    pid = ext.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    tiles_divide = (M % BLOCK_M == 0) and (N % BLOCK_N == 0) and (K % BLOCK_K == 0)
    # The descriptors below assume these exact strides; previously a transposed
    # or padded operand (or a strided output) silently used the wrong layout.
    dense_row_major = (
        (stride_ak == 1)
        and (stride_am == K)
        and (stride_bn == 1)
        and (stride_bk == N)
        and (stride_cn == 1)
        and (stride_cm == N)
    )
    if tiles_divide and dense_row_major:
        # offset
        offset_am = pid_m * BLOCK_M
        offset_bn = pid_n * BLOCK_N
        offset_k = 0

        a_desc = tl.make_tensor_descriptor(
            base=A,
            shape=[M, K],
            strides=[K, 1],
            block_shape=[BLOCK_M, BLOCK_K],
        )

        # row-major
        b_desc = tl.make_tensor_descriptor(
            base=B,
            shape=[K, N],
            strides=[N, 1],
            block_shape=[BLOCK_K, BLOCK_N],
        )

        c_desc = tl.make_tensor_descriptor(
            base=C,
            shape=[M, N],
            strides=[N, 1],
            block_shape=[BLOCK_M, BLOCK_N],
        )

        if IS_FP64:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
        else:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_K)):
            a = a_desc.load([offset_am.to(tl.int32), offset_k.to(tl.int32)])
            b = b_desc.load([offset_k.to(tl.int32), offset_bn.to(tl.int32)])
            # Same mixed-dtype handling as the pointer path below.
            if a.dtype != b.dtype:
                a = a.to(C.dtype.element_ty)
                b = b.to(C.dtype.element_ty)
            if IS_FP64:
                acc += tl.dot(a, b, allow_tf32=False)
            else:
                acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
            offset_k += BLOCK_K

        # Cast to the output's dtype (not A's): C may be wider than A,
        # e.g. bf16 x bf16 -> fp32.
        acc = acc.to(C.dtype.element_ty)
        c_desc.store([offset_am.to(tl.int32), offset_bn.to(tl.int32)], acc)

    else:
        # do matrix multiplication
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
        rm = rm.to(tl.int64)
        rn = rn.to(tl.int64)
        prev_multiple = prev_multiple_of(K, BLOCK_K)

        if IS_FP64:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
        else:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for start_k in range(0, prev_multiple, BLOCK_K):
            rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
            a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
            b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
            if a.dtype != b.dtype:
                a = a.to(C.dtype.element_ty)
                b = b.to(C.dtype.element_ty)
            if IS_FP64:
                acc += tl.dot(a, b, allow_tf32=False)
            else:
                acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

        # loop peeling
        rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
        mask_k = rk < K
        a = tl.load(
            A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
            mask=mask_k[None, :],
            other=0.0,
        )
        b = tl.load(
            B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
            mask=mask_k[:, None],
            other=0.0,
        )
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        if IS_FP64:
            acc += tl.dot(a, b, allow_tf32=False)
        else:
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

        acc = acc.to(C.dtype.element_ty)
        # rematerialize rm and rn to save registers
        rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
        rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
        offsets = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        # handles write-back with reduction-splitting
        tl.store(offsets, acc, mask=mask)
def matmul_get_configs(pre_hook=matmul_tma_set_block_size_hook):
    """Build the default autotune space for the host-TMA general mm kernel.

    The space is the cartesian product of:
    - BLOCK_M in {32, 64, 128, 256}
    - BLOCK_N and BLOCK_K in {32, 64, 128}
    - num_stages in {2, 3, 4}
    - num_warps in {4, 8}

    When the device's opt-in shared-memory limit is known, configs whose
    estimated A/B tile footprint exceeds it are dropped. If that would drop
    every config, the unfiltered space is returned with a warning.

    Args:
        pre_hook: hook attached to every ``triton.Config``. The default sets
            the TMA descriptor block shapes from the chosen tile sizes.

    Returns:
        list[triton.Config]
    """
    search_space = [
        triton.Config(
            {"BLOCK_M": bm, "BLOCK_N": bn, "BLOCK_K": bk},
            num_stages=stages,
            num_warps=warps,
            pre_hook=pre_hook,
        )
        for bm in [32, 64, 128, 256]
        for bn in [32, 64, 128]
        for bk in [32, 64, 128]
        for stages in [2, 3, 4]
        for warps in [4, 8]
    ]

    limit = _get_shared_memory_limit_bytes()
    if limit is None:
        # Limit unknown (e.g. no CUDA): keep everything.
        return search_space

    def _fits(cfg):
        tile = cfg.kwargs
        needed = _estimate_tma_shared_memory_bytes(
            tile["BLOCK_M"], tile["BLOCK_N"], tile["BLOCK_K"], cfg.num_stages
        )
        return needed <= limit

    fitting = [cfg for cfg in search_space if _fits(cfg)]
    if fitting:
        return fitting
    logger.warning(
        "No mm_general_tma config fits shared memory limit (%s bytes); falling back to unfiltered configs.",
        limit,
    )
    return search_space
@libentry()
@libtuner(
    # With USE_FLAGTUNE=1 the search space and per-key strategy are read from
    # the expand YAML; otherwise the built-in matmul_get_configs() space and a
    # fixed strategy list (one entry per key) are used.
    configs=runtime.ops_get_configs(
        "mm_general_tma",
        pre_hook=matmul_tma_set_block_size_hook,
        yaml_path=EXPAND_CONFIG_FILENAME,
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else matmul_get_configs(),
    key=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    strategy=runtime.get_expand_config(
        "mm_general_tma", yaml_path=EXPAND_CONFIG_FILENAME
    )["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
)
@triton.jit
def mm_kernel_general_host_tma(
    a_desc,
    b_desc,
    c_desc,
    M,
    N,
    K,
    # The stride arguments are not read in the body (the host-built
    # descriptors carry the layout); stride_am / stride_bk still serve as
    # autotune keys above.
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    # True when the descriptor was built over the operand itself; False when it
    # was built over its transpose (see the matching block-size pre-hook).
    A_ROW_MAJOR: tl.constexpr,
    B_ROW_MAJOR: tl.constexpr,
    # Input dtype name; only used to separate autotune cache entries.
    dtype: tl.constexpr,
    # NOTE(review): not referenced in the kernel body.
    enable_warp_specialization=True,
):
    """C tile = A tile-row @ B tile-column via host-built TMA descriptors.

    Each program computes one BLOCK_M x BLOCK_N tile of C, with program ids
    grouped GROUP_M rows at a time, and accumulates in float32. Column-major
    operands are loaded in their transposed block shape and transposed back
    in registers. Half-precision inputs use a plain tl.dot; other inputs
    (fp32 here) use ``input_precision="tf32x3"``.
    """
    pid = tl.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)

    # Grouped ordering: GROUP_M consecutive row tiles share each column sweep.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    offset_am = (pid_m * BLOCK_M).to(tl.int32)
    offset_bn = (pid_n * BLOCK_N).to(tl.int32)
    iters = tl.cdiv(K, BLOCK_K)
    for k in range(iters):
        offset_ak = (k * BLOCK_K).to(tl.int32)

        if A_ROW_MAJOR:
            a = a_desc.load([offset_am, offset_ak])
        else:
            # Descriptor is over A^T, so coordinates are (k, m).
            a_t = a_desc.load([offset_ak, offset_am])
            a = tl.trans(a_t)

        if B_ROW_MAJOR:
            b = b_desc.load([offset_ak, offset_bn])
        else:
            # Descriptor is over B^T, so coordinates are (n, k).
            b_t = b_desc.load([offset_bn, offset_ak])
            b = tl.trans(b_t)

        if a_desc.dtype == tl.float16 or a_desc.dtype == tl.bfloat16:
            accumulator = tl.dot(a, b, acc=accumulator, allow_tf32=False)
        else:
            accumulator = tl.dot(a, b, acc=accumulator, input_precision="tf32x3")

    c = accumulator.to(c_desc.dtype)
    c_desc.store([offset_am, offset_bn], c)
def get_higher_dtype(a, b):
    """Return the higher-ranked of two dtypes in float16 < bfloat16 < float32 < float64.

    Identical dtypes are returned immediately without validation. Otherwise
    both must be one of those four dtypes (checked with ``assert``).

    NOTE(review): bfloat16 ranks above float16 here, so mixing the two yields
    bfloat16 rather than the float32 that ``torch.promote_types`` would give.
    """
    ranking = (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    if a is b:
        return a
    assert a in ranking
    assert b in ranking
    return b if ranking.index(b) > ranking.index(a) else a
def _tma_descriptor_compatible(t):
    """Return True if 2-D tensor ``t`` can back a host-side TMA TensorDescriptor.

    ``general_mm`` builds the descriptor over ``t`` when its last dimension is
    unit-stride and over ``t.T`` otherwise, so one of the two strides must be 1.
    The other (leading) stride and the base address must be 16-byte aligned.
    """
    if t.stride(1) == 1:
        leading_stride = t.stride(0)
    elif t.stride(0) == 1:
        leading_stride = t.stride(1)
    else:
        return False
    return (leading_stride * t.element_size()) % 16 == 0 and t.data_ptr() % 16 == 0


def general_mm(a, b, c, M, N, K, op_name="mm"):
    """Compute ``c = a @ b`` with the general tiled kernels and return ``c``.

    Uses ``mm_kernel_general_host_tma`` when Triton provides
    ``TensorDescriptor``, the dtypes/sizes pass ``is_tma_compatible``, and a, b
    and c each have a TMA-compatible layout (unit stride in one dimension,
    16-byte aligned leading stride and base). Otherwise launches
    ``mm_kernel_general``.

    Args:
        a: (M, K) input.
        b: (K, N) input.
        c: preallocated (M, N) output, written in place.
        M, N, K: problem sizes.
        op_name: label used only in the debug log.
    """
    # TODO: Remove this debug message
    logger.debug(
        "GEMS MM-hopper, [op]: %s, [mm scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        op_name,
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    # Broadcast tensors from expand() have stride=0, incompatible with TMA
    if 0 in a.stride():
        a = a.contiguous()
    if 0 in b.stride():
        b = b.contiguous()
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )

    try:
        # triton 3.5.0
        from triton.tools.tensor_descriptor import TensorDescriptor
    except ImportError:
        TensorDescriptor = None

    # is_tma_compatible only inspects dtypes and N/K. The per-tensor check also
    # covers e.g. a column-major A (leading stride M) or a strided output,
    # which TensorDescriptor would otherwise reject with an assertion.
    use_host_tma = (
        TensorDescriptor is not None
        and is_tma_compatible(a, b, N, K)
        and all(_tma_descriptor_compatible(t) for t in (a, b, c))
    )
    # Accumulate in fp64 whenever either input is fp64, so that a mixed
    # fp32/fp64 pair (cast to the wider C dtype in-kernel) keeps a consistent
    # accumulator type.
    is_fp64 = a.dtype == torch.float64 or b.dtype == torch.float64

    if use_host_tma:
        a_row_major = a.stride(1) == 1
        b_row_major = b.stride(1) == 1
        # Real block shapes are installed per config by
        # matmul_tma_set_block_size_hook.
        dummy_block = [1, 1]

        if a_row_major:
            a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block)
        else:
            a_desc = TensorDescriptor(a, a.T.shape, a.T.stride(), dummy_block)
        if b_row_major:
            b_desc = TensorDescriptor(b, b.shape, b.stride(), dummy_block)
        else:
            b_desc = TensorDescriptor(b, b.T.shape, b.T.stride(), dummy_block)
        c_desc = TensorDescriptor(c, c.shape, c.stride(), dummy_block)

        input_dtype = a.dtype
        dtype_str = str(input_dtype).split(".")[-1]

        with torch_device_fn.device(a.device):
            mm_kernel_general_host_tma[grid](
                a_desc,
                b_desc,
                c_desc,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                GROUP_M=8,
                A_ROW_MAJOR=a_row_major,
                B_ROW_MAJOR=b_row_major,
                dtype=dtype_str,
            )
    else:
        # Capture only the device: closing over ``a`` itself would keep the
        # input tensor alive inside Triton's global allocator until replaced.
        device = a.device

        def alloc_fn(size: int, align: int, stream: Optional[int]):
            return torch.empty(size, dtype=torch.int8, device=device)

        triton.set_allocator(alloc_fn)

        with torch_device_fn.device(a.device):
            mm_kernel_general[grid](
                a,
                b,
                c,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                GROUP_M=8,
                IS_FP64=is_fp64,
            )
    return c
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "gemv", pre_hook=None, yaml_path=EXPAND_CONFIG_FILENAME
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else [
        triton.Config(
            {"BLOCK_M": 32, "BLOCK_K": 256},
        )
    ],
    key=["M", "K", "stride_am", "stride_bk"],
    strategy=runtime.get_expand_config("gemv", yaml_path=EXPAND_CONFIG_FILENAME)[
        "strategy"
    ]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "default"],
    warmup=5,
    rep=10,
)
@triton.jit
def gemv_kernel(
    A,
    B,
    C,
    M,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    IS_FP64: tl.constexpr = False,
    stride_cm=1,
):
    """Optimized kernel for matrix-vector multiplication (N=1 case)

    Computes C[m] = sum_k A[m, k] * B[k]; each program handles BLOCK_M rows.
    ``stride_cm`` is the element stride between consecutive rows of C. It
    defaults to 1 (dense output); a strided (M, 1) ``out`` previously had its
    row stride ignored.
    """
    pid = tl.program_id(0)

    # Each program handles BLOCK_M rows
    row_start = pid * BLOCK_M
    row_offset = row_start + tl.arange(0, BLOCK_M)
    row_mask = row_offset < M

    # Accumulator for this block of rows
    if IS_FP64:
        acc = tl.zeros((BLOCK_M,), dtype=tl.float64)
    else:
        acc = tl.zeros((BLOCK_M,), dtype=tl.float32)

    # Iterate over K dimension
    for k_start in range(0, K, BLOCK_K):
        k_offset = k_start + tl.arange(0, BLOCK_K)
        k_mask = k_offset < K

        # Load block from matrix A: [BLOCK_M, BLOCK_K]
        a_ptrs = A + row_offset[:, None] * stride_am + k_offset[None, :] * stride_ak
        a = tl.load(a_ptrs, mask=row_mask[:, None] & k_mask[None, :], other=0.0)

        # Load block from vector B: [BLOCK_K]
        b_ptrs = B + k_offset * stride_bk
        b = tl.load(b_ptrs, mask=k_mask, other=0.0)

        # Accumulate: sum over K dimension
        if IS_FP64:
            acc += tl.sum(a * b[None, :], axis=1)
        else:
            acc += tl.sum(a.to(tl.float32) * b.to(tl.float32)[None, :], axis=1)

    # Store result (honouring C's row stride)
    c_ptrs = C + row_offset * stride_cm
    acc = acc.to(C.dtype.element_ty)
    tl.store(c_ptrs, acc, mask=row_mask)


def gemv_mm(a, b, c, M, K):
    """Optimized matrix-vector multiplication for N=1 case

    Writes ``a (M, K) @ b (K, 1)`` into ``c`` (shape (M, 1), any row stride)
    and returns ``c``.
    """
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: gemv (N=1), [shape info]: [%s, %s, 1](M, K, N)",
        M,
        K,
    )

    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)

    with torch_device_fn.device(a.device):
        gemv_kernel[grid](
            a,
            b,
            c,
            M,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            # Match the other mm paths: accumulate in fp64 if either input is fp64.
            IS_FP64=(a.dtype == torch.float64 or b.dtype == torch.float64),
            stride_cm=c.stride(0),
        )
    return c
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "mm_splitk",
        pre_hook=None,
        yaml_path=EXPAND_CONFIG_FILENAME,
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else runtime.get_tuned_config("mm_splitk"),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    reset_to_zero=["C"],
    strategy=runtime.get_expand_config("mm_splitk", yaml_path=EXPAND_CONFIG_FILENAME)[
        "strategy"
    ]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=10,
)
@triton.jit
def mm_kernel_splitk(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Split-K GEMM: axis 0 picks the C tile, axis 1 picks a slice of the K loop.

    Each program accumulates its share of BLOCK_K iterations and atomically
    adds the partial tile into C, so C must be zeroed beforehand.

    IS_FP64 selects a float64 accumulator. Without it, fp64 inputs made tl.dot
    return fp64 and the loop-carried float32 ``acc`` changed type, which does
    not compile. Mismatched input dtypes are cast to C's element type, as in
    mm_kernel_general.
    """
    pid = tl.program_id(0)
    pid_k = tl.program_id(1)

    grid_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // grid_n
    pid_n = pid % grid_n

    offset_am = pid_m * BLOCK_M
    offset_bn = pid_n * BLOCK_N
    offs_am = offset_am + tl.arange(0, BLOCK_M)
    offs_bn = offset_bn + tl.arange(0, BLOCK_N)

    # Contiguous range [k_start, k_end) of BLOCK_K iterations for this split.
    total_k_iters = tl.cdiv(K, BLOCK_K)
    k_per_split = tl.cdiv(total_k_iters, SPLIT_K)
    k_start = pid_k * k_per_split
    k_end = min((pid_k + 1) * k_per_split, total_k_iters)

    if IS_FP64:
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
    else:
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(k_start, k_end):
        offset_k = k * BLOCK_K
        offs_k = offset_k + tl.arange(0, BLOCK_K)

        a = tl.load(
            A + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak,
            mask=(offs_am[:, None] < M) & (offs_k[None, :] < K),
            other=0.0,
        )
        b = tl.load(
            B + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn,
            mask=(offs_k[:, None] < K) & (offs_bn[None, :] < N),
            other=0.0,
        )
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        if IS_FP64:
            acc += tl.dot(a, b, allow_tf32=False)
        else:
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    offs_cm = offset_am + tl.arange(0, BLOCK_M)
    offs_cn = offset_bn + tl.arange(0, BLOCK_N)
    c_ptrs = C + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
    mask = (offs_cm < M)[:, None] & (offs_cn < N)[None, :]
    tl.atomic_add(c_ptrs, acc, mask=mask)


def splitk_mm(a, b, c, M, N, K, op_name="mm"):
    """Launch the split-K kernel computing ``c += a @ b`` and return ``c``.

    ``c`` must already be zero (callers in this module call ``c.zero_()``).
    """
    logger.debug(
        "GEMS MM-hopper, [op]: %s, [mm scenario]: splitk, [shape info]: [-, %s, %s, %s](batch, M, N, K)",
        op_name,
        M,
        N,
        K,
    )
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
        META["SPLIT_K"],
    )
    with torch_device_fn.device(a.device):
        mm_kernel_splitk[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            IS_FP64=(a.dtype == torch.float64 or b.dtype == torch.float64),
        )
    return c
def streamk_scenario(a, b, M, N, K):
    """Decide whether the stream-K kernel should handle this problem.

    Requires all of the following:
    - compute capability 8.x
    - fp16/bf16 inputs
    - contiguous a and b
    - a reduction dimension that dominates: K > 5 * M and K > 5 * N
    """
    # TODO: these thresholds may change according to real benchmark results.
    # The best stream-K configuration has only been tested on A100
    # (capability[0] == 8); optimal settings for other devices still need to
    # be determined through real testing.
    tensors = (a, b)
    half_types = [torch.float16, torch.bfloat16]
    return (
        get_device_capability()[0] == 8
        and all(t.dtype in half_types for t in tensors)
        and all(t.is_contiguous() for t in tensors)
        and K > M * 5
        and K > N * 5
    )
if HAS_TLE:

    # GEMM over a cluster of CLUSTER_SIZE programs that share one BM row block.
    # Only cluster rank 0 loads each A tile from global memory into ``a_buf``.
    # Every rank then reads it through ``tle_exp.remote(a_buf, 0, scope=mesh)``
    # (presumably a view of rank 0's shared-memory buffer — confirm against the
    # TLE docs) and multiplies it with its own BN-wide column tile of B.
    @triton.jit
    def _cluster_remote_gemm_kernel(
        a_ptr,
        b_ptr,
        c_ptr,
        M,
        N,
        K,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        mesh: tl.constexpr,
        BM: tl.constexpr,
        BN: tl.constexpr,
        BK: tl.constexpr,
        DOT_K: tl.constexpr,
        CLUSTER_SIZE: tl.constexpr,
        USE_MASK: tl.constexpr,
        A_SLOTS: tl.constexpr,
        USE_NV_MMA_SMEM_LAYOUT: tl.constexpr,
    ):
        pid = tl.program_id(0)
        cluster_rank = tle_exp.shard_id(mesh, "cluster_x")
        cluster_id = pid // CLUSTER_SIZE

        # Each cluster owns one BM row block and a group of CLUSTER_SIZE
        # adjacent BN column tiles; the rank selects the tile within the group.
        # NOTE(review): when cdiv(N, BN) is not a multiple of CLUSTER_SIZE,
        # higher ranks of the last group get a pid_n past the last N tile.
        # Their B loads / C stores are bounds-checked only when USE_MASK is
        # set, so the launcher must enable USE_MASK in that case.
        num_pid_n = tl.cdiv(N, BN)
        num_pid_n_group = tl.cdiv(num_pid_n, CLUSTER_SIZE)
        pid_m = cluster_id // num_pid_n_group
        pid_ng = cluster_id % num_pid_n_group
        pid_n = pid_ng * CLUSTER_SIZE + cluster_rank

        offs_m = pid_m * BM + tl.arange(0, BM)
        offs_n = pid_n * BN + tl.arange(0, BN)
        offs_k = tl.arange(0, BK)
        # Local row indices 0..BM-1 within the shared A tile, broadcast into the
        # index grids used to address a_buf: (BM, BK) for writes and
        # (DOT_K, BM) for the transposed reads below.
        a_row_base = offs_m - pid_m * BM
        a_rows_full = tl.broadcast_to(a_row_base[:, None], (BM, BK))
        a_cols_full = tl.broadcast_to(tl.arange(0, BK)[None, :], (BM, BK))
        a_rows_t = tl.broadcast_to(a_row_base[None, :], (DOT_K, BM))
        # Ring of A_SLOTS fp16 A tiles, allocated with scope tle_exp.gpu.smem.
        a_buf = tle_exp.gpu.alloc(
            [A_SLOTS, BM, BK],
            dtype=tl.float16,
            layout=None,
            scope=tle_exp.gpu.smem,
            nv_mma_shared_layout=USE_NV_MMA_SMEM_LAYOUT,
        )
        a_buf_remote = tle_exp.remote(a_buf, 0, scope=mesh)

        acc = tl.zeros((BM, BN), dtype=tl.float32)
        # Prologue: rank 0 fills slot 0 with the first K block of A.
        slot0 = 0
        slot0_full = tl.zeros((BM, BK), dtype=tl.int32) + slot0
        if cluster_rank == 0:
            a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
            if USE_MASK:
                a_mask_tile = (offs_m[:, None] < M) & (offs_k[None, :] < K)
                a_tile = tl.load(a_ptrs, mask=a_mask_tile, other=0.0)
            else:
                a_tile = tl.load(a_ptrs)
            a_local_ptr_tile = tle_exp.gpu.local_ptr(
                a_buf, (slot0_full, a_rows_full, a_cols_full)
            )
            if USE_MASK:
                tl.store(a_local_ptr_tile, a_tile, mask=a_mask_tile)
            else:
                tl.store(a_local_ptr_tile, a_tile)

        # Make the prologue tile visible to all ranks before anyone reads it.
        tle_exp.distributed_barrier(mesh)

        for k0 in range(0, K, BK):
            iter_idx = k0 // BK
            slot = iter_idx % A_SLOTS

            # Consume the current A slot in DOT_K-wide chunks. A is read as a
            # (DOT_K, BM) chunk and transposed to (BM, DOT_K) for tl.dot.
            for ks in range(0, BK, DOT_K):
                k_local = ks + tl.arange(0, DOT_K)
                a_cols_t = tl.broadcast_to(k_local[:, None], (DOT_K, BM))
                slot_dot_t = tl.zeros((DOT_K, BM), dtype=tl.int32) + slot
                a_ptr_remote = tle_exp.gpu.local_ptr(
                    a_buf_remote, (slot_dot_t, a_rows_t, a_cols_t)
                )
                if USE_MASK:
                    a_mask_t = ((k0 + k_local)[:, None] < K) & (offs_m[None, :] < M)
                    a = tl.trans(tl.load(a_ptr_remote, mask=a_mask_t, other=0.0))
                else:
                    a = tl.trans(tl.load(a_ptr_remote))

                b_ptrs = (
                    b_ptr
                    + (k0 + k_local)[:, None] * stride_bk
                    + offs_n[None, :] * stride_bn
                )
                if USE_MASK:
                    b_mask = ((k0 + k_local)[:, None] < K) & (offs_n[None, :] < N)
                    b = tl.load(b_ptrs, mask=b_mask, other=0.0)
                else:
                    b = tl.load(b_ptrs)
                acc = tl.dot(a, b, acc)

            # With a single slot, the next write overwrites the tile just read,
            # so every rank must finish reading first.
            if A_SLOTS == 1:
                tle_exp.distributed_barrier(mesh)

            # Rank 0 prefetches the next K block of A into the next ring slot.
            next_k0 = k0 + BK
            has_next = next_k0 < K
            next_iter = iter_idx + 1
            next_slot = next_iter % A_SLOTS
            next_slot_full = tl.zeros((BM, BK), dtype=tl.int32) + next_slot
            if has_next and cluster_rank == 0:
                a_ptrs = (
                    a_ptr
                    + offs_m[:, None] * stride_am
                    + (next_k0 + offs_k)[None, :] * stride_ak
                )
                if USE_MASK:
                    a_mask_tile = (offs_m[:, None] < M) & (
                        (next_k0 + offs_k)[None, :] < K
                    )
                    a_tile = tl.load(a_ptrs, mask=a_mask_tile, other=0.0)
                else:
                    a_tile = tl.load(a_ptrs)
                a_local_ptr_tile = tle_exp.gpu.local_ptr(
                    a_buf, (next_slot_full, a_rows_full, a_cols_full)
                )
                if USE_MASK:
                    tl.store(a_local_ptr_tile, a_tile, mask=a_mask_tile)
                else:
                    tl.store(a_local_ptr_tile, a_tile)

            # Publish the prefetched slot before the next iteration reads it.
            tle_exp.distributed_barrier(mesh)

        c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
        if USE_MASK:
            c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
            tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)
        else:
            tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty))
def _select_remote_dot_k(bk: int) -> int:
    """Return the K-chunk (DOT_K) each tl.dot of the remote kernel consumes.

    Only a chunk of 16 is supported, so ``bk`` must be a multiple of 16.

    Raises:
        ValueError: if ``bk`` is not divisible by 16.
    """
    if bk % 16:
        raise ValueError(f"BK must be divisible by 16 for remote dot path, got BK={bk}")
    return 16
def _grid_cluster_remote(
    M: int,
    N: int,
    BM: int,
    BN: int,
    cluster_size: int = TLE_CLUSTER_SIZE,
) -> tuple:
    """Launch grid for the cluster-remote kernel: (M tiles x N-tile groups,).

    N tiles are grouped ``cluster_size`` at a time, rounding the last group up.

    NOTE(review): there is no extra factor of ``cluster_size`` here, while the
    kernel derives ``cluster_id = pid // CLUSTER_SIZE``. Presumably the
    device-mesh launch expands each grid entry into a cluster of programs —
    confirm.
    """
    n_tiles = triton.cdiv(N, BN)
    n_groups = triton.cdiv(n_tiles, cluster_size)
    m_tiles = triton.cdiv(M, BM)
    return (m_tiles * n_groups,)
def _run_cluster_remote(
    a: torch.Tensor,
    b: torch.Tensor,
    c: torch.Tensor,
    bm: int,
    bn: int,
    bk: int,
    num_warps: int,
    num_stages: int,
) -> None:
    """Launch ``_cluster_remote_gemm_kernel`` computing ``c = a @ b``.

    ``a`` is (M, K) and ``b`` is (K, N). Tile sizes and launch parameters
    come from the caller; the cluster size is TLE_CLUSTER_SIZE.
    """
    M, K = a.shape
    N = b.shape[1]
    dot_k = _select_remote_dot_k(bk)
    num_n_tiles = triton.cdiv(N, bn)
    use_mask = (
        (M % bm != 0)
        or (N % bn != 0)
        or (K % bk != 0)
        # A partially filled last cluster group gives its higher ranks a
        # column tile entirely past N (e.g. N=768, bn=256: tiles 0..2, but
        # rank 1 of group 1 gets tile 3). Only the masked path bounds-checks
        # those B loads and C stores; unmasked, the stores spill into the
        # next row of C.
        or (num_n_tiles % TLE_CLUSTER_SIZE != 0)
    )
    a_slots = TLE_REMOTE_A_SLOTS
    # NOTE(review): empirical layout choice; the reason for these thresholds is
    # not documented here.
    use_nv_mma_smem_layout = (bk == 32) or (bk == 64 and num_stages <= 2)
    _cluster_remote_gemm_kernel[_grid_cluster_remote(M, N, bm, bn)](
        a,
        b,
        c,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        mesh=BLOCK_CLUSTER_MESH,
        BM=bm,
        BN=bn,
        BK=bk,
        DOT_K=dot_k,
        CLUSTER_SIZE=TLE_CLUSTER_SIZE,
        USE_MASK=use_mask,
        A_SLOTS=a_slots,
        USE_NV_MMA_SMEM_LAYOUT=use_nv_mma_smem_layout,
        num_ctas=1,
        num_warps=num_warps,
        num_stages=num_stages,
    )
def cluster_remote_mm_scenario(a, b, c, M, N, K):
    """Decide whether the cluster-remote fp16 GEMM path applies.

    Requires all of the following:
    - TLE and a device mesh are available
    - compute capability >= 9
    - a, b and c are CUDA float16 tensors
    - a and b are contiguous
    - M, N, K are at least one remote tile (TLE_REMOTE_BM/BN/BK) each
    """
    capability = get_device_capability()
    operands = (a, b, c)
    return (
        HAS_TLE
        and BLOCK_CLUSTER_MESH is not None
        and capability[0] >= 9
        and all(t.is_cuda for t in operands)
        and all(t.dtype == torch.float16 for t in operands)
        and a.is_contiguous()
        and b.is_contiguous()
        and M >= TLE_REMOTE_BM
        and N >= TLE_REMOTE_BN
        and K >= TLE_REMOTE_BK
    )
def cluster_remote_mm(a, b, c, M, N, K):
    """Compute ``c = a @ b`` with the cluster-remote kernel and return ``c``.

    Uses the fixed TLE_REMOTE_* tile and launch configuration.
    """
    # The original call passed M as the log *message* and the rest as %-args,
    # which fails to format whenever debug logging is enabled.
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: cluster_remote, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    with torch_device_fn.device(a.device):
        _run_cluster_remote(
            a,
            b,
            c,
            TLE_REMOTE_BM,
            TLE_REMOTE_BN,
            TLE_REMOTE_BK,
            TLE_REMOTE_NUM_WARPS,
            TLE_REMOTE_NUM_STAGES,
        )
    return c
def _contiguous_if_strided(t):
    """Return ``t.contiguous()`` when both 2-D strides exceed 1, else ``t``.

    Layouts with a unit (or zero) stride in either dimension are passed
    through unchanged.
    """
    if t.stride(0) > 1 and t.stride(1) > 1:
        return t.contiguous()
    return t


def _mm_dispatch(a, b, c, M, N, K):
    """Route ``c = a @ b`` to the most suitable kernel and return ``c``."""
    if M == 0 or N == 0:
        # Nothing to compute for an empty output.
        return c
    if K == 0:
        # Empty reduction: the product is all zeros (as in torch.mm).
        # mm_kernel_general's peeled tail would otherwise index
        # prev_multiple_of(0, BLOCK_K) = -BLOCK_K, i.e. negative K offsets.
        return c.zero_()
    # Optimize for N=1 case (matrix-vector multiplication)
    if N == 1:
        return gemv_mm(a, b, c, M, K)
    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, c, M, N, K, sm_count=sm_count)
    if HAS_TLE and BLOCK_CLUSTER_MESH is not None:
        if cluster_remote_mm_scenario(a, b, c, M, N, K):
            return cluster_remote_mm(a, b, c, M, N, K)
    # Use splitk for small M and N with a long reduction; split-K accumulates
    # atomically, so the output must start at zero.
    if M < 2048 and N < 2048 and K >= 4096:
        c.zero_()
        return splitk_mm(a, b, c, M, N, K)
    return general_mm(a, b, c, M, N, K)


def mm(a, b):
    """Return ``a @ b`` for 2-D ``a`` (M, K) and ``b`` (K, N).

    The new (M, N) result uses ``get_higher_dtype`` of the input dtypes and
    lives on ``a``'s device.
    """
    device = a.device
    # handle non-contiguous inputs if necessary
    a = _contiguous_if_strided(a)
    b = _contiguous_if_strided(b)
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # allocates output
    c_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=device, dtype=c_dtype)
    return _mm_dispatch(a, b, c, M, N, K)


def mm_out(a, b, *, out):
    """Write ``a @ b`` into the preallocated ``out`` and return it."""
    # handle non-contiguous inputs if necessary
    a = _contiguous_if_strided(a)
    b = _contiguous_if_strided(b)
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    return _mm_dispatch(a, b, out, M, N, K)
def router_gemm(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """bf16 x bf16 -> fp32 GEMM for MoE router gate. weight shape: (N, K).

    Returns a new float32 (M, N) tensor equal to ``x @ weight.T``.
    """
    # Same layout normalisation as mm(): densify operands with no unit stride.
    # weight previously skipped this, so a doubly-strided weight reached the
    # host TMA path with a non-unit last stride.
    if x.stride(0) > 1 and x.stride(1) > 1:
        x = x.contiguous()
    if weight.stride(0) > 1 and weight.stride(1) > 1:
        weight = weight.contiguous()
    assert x.shape[1] == weight.shape[1], "incompatible dimensions"
    M, K = x.shape
    N = weight.shape[0]
    c = torch.empty((M, N), device=x.device, dtype=torch.float32)
    if M == 0 or N == 0 or K == 0:
        # Empty output, or empty reduction (all zeros); the kernels assume K >= 1.
        return c.zero_()
    b = weight.t()
    if M < 2048 and N < 2048 and K >= 4096:
        # split-K accumulates atomically into a zeroed output
        c.zero_()
        return splitk_mm(x, b, c, M, N, K, op_name="router_gemm")
    return general_mm(x, b, c, M, N, K, op_name="router_gemm")