Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/mm.py: 0%
432 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
2import os
3from typing import Optional
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.ops.mm_streamk import streamk_mm
11from flag_gems.runtime import torch_device_fn
12from flag_gems.utils import libentry, libtuner
13from flag_gems.utils import triton_lang_extension as ext
14from flag_gems.utils.device_info import get_device_capability, get_sm_count
15from flag_gems.utils.triton_version_utils import HAS_TLE, HAS_TLE_DEVICE_MESH
logger = logging.getLogger("flag_gems.runtime.backend._nvidia.hopper.ops.mm")

# NOTE(review): not referenced in this module's visible code; kept as a
# public constant in case other modules import it.
CACHE_USAGE_THRESHOLD = 0.8

# Expanded autotuning search space handed to the flagtune-enabled libtuner
# decorators in this module; the YAML sits one directory above this file.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(os.path.dirname(__file__), os.pardir, "mm_hopper_expand.yaml")
)

# Extra bytes added on top of every TMA shared-memory estimate as headroom.
_SHARED_MEM_SAFETY_MARGIN_BYTES = 1024
def _get_shared_memory_limit_bytes():
    """Query the opt-in per-block shared-memory limit of the current CUDA device.

    Returns:
        The ``shared_memory_per_block_optin`` property (bytes) of the current
        device, or ``None`` when CUDA is unavailable or the query fails for
        any reason.
    """
    try:
        if torch.cuda.is_available():
            device_index = torch.cuda.current_device()
            props = torch.cuda.get_device_properties(device_index)
            return props.shared_memory_per_block_optin
        return None
    except Exception:
        # Deliberately best-effort: this runs while building autotune configs
        # at import time, so a failed device query must not raise.
        return None
def _estimate_tma_shared_memory_bytes(
    block_m, block_n, block_k, num_stages, bytes_per_element=4
):
    """Estimate the shared memory (bytes) a pipelined A/B tile config needs.

    Every pipeline stage buffers one ``block_m x block_k`` tile of A and one
    ``block_k x block_n`` tile of B; a fixed safety margin is added on top.

    Args:
        block_m, block_n, block_k: tile sizes of the config.
        num_stages: number of software-pipeline stages.
        bytes_per_element: element size used for the estimate.  The default
            of 4 is conservative for every dtype the TMA path accepts
            (fp16/bf16/fp32); pass the real element size for a tighter bound.

    Returns:
        Estimated shared-memory footprint in bytes.

    NOTE(review): the output tile staged for the TMA store is not counted;
    the safety margin is the only slack for it.
    """
    elements_per_stage = block_m * block_k + block_k * block_n
    return (
        elements_per_stage * bytes_per_element * num_stages
        + _SHARED_MEM_SAFETY_MARGIN_BYTES
    )
# Launch/tile parameters of the TLE cluster "remote A" GEMM path.  They do
# not depend on whether device meshes are available, so they are defined once,
# unconditionally (previously they were duplicated in both branches below).
TLE_CLUSTER_SIZE = 2
TLE_REMOTE_BM = 64
TLE_REMOTE_BN = 256
TLE_REMOTE_BK = 64
TLE_REMOTE_NUM_WARPS = 8
TLE_REMOTE_NUM_STAGES = 2
TLE_REMOTE_A_SLOTS = 2

if HAS_TLE_DEVICE_MESH:
    import triton.experimental.tle.language as tle_exp

    # Single-axis block-cluster mesh.  Its size is tied to TLE_CLUSTER_SIZE,
    # which is also what the kernel launch passes as CLUSTER_SIZE.
    BLOCK_CLUSTER_MESH = tle_exp.device_mesh(
        {"block_cluster": [("cluster_x", TLE_CLUSTER_SIZE)]}
    )
else:
    # Without device-mesh support the cluster path is disabled: callers check
    # ``BLOCK_CLUSTER_MESH is not None`` before using it.
    tle_exp = None
    BLOCK_CLUSTER_MESH = None
def _tma_operand_layout_ok(t):
    """Return True if 2-D tensor ``t`` can back a host-side TMA descriptor.

    A row-major operand (``stride(1) == 1``) is described with ``t.stride()``
    and a column-major one (``stride(0) == 1``) with ``t.T.stride()``.  One
    stride must therefore be 1, and the other (leading) stride must be a
    multiple of 16 bytes.  The storage offset must be 16-byte aligned too.
    NOTE(review): the base allocation itself is assumed 16-byte aligned
    (true for the PyTorch CUDA caching allocator; confirm for external
    buffers).
    """
    elem_bytes = t.element_size()
    if t.stride(1) == 1:
        leading_stride = t.stride(0)
    elif t.stride(0) == 1:
        leading_stride = t.stride(1)
    else:
        return False
    return (leading_stride * elem_bytes) % 16 == 0 and (
        t.storage_offset() * elem_bytes
    ) % 16 == 0


def is_tma_compatible(a, b, N, K):
    """
    Check if tensors are compatible with TMA (Tensor Memory Accelerator).

    TMA requires 128-bit (16-byte) alignment for memory access:
    - For FP16/BF16 (2 bytes/element): N and K must be multiples of 8
      (8 elements × 2 bytes = 16 bytes)
    - For FP32 (4 bytes/element): N and K must be multiples of 4
      (4 elements × 4 bytes = 16 bytes)

    In addition:
    - ``a`` and ``b`` must share one dtype: the host-TMA kernel passes the
      loaded tiles straight to ``tl.dot`` without a cast.
    - Each operand must be row- or column-major with a 16-byte aligned
      leading stride and storage offset.  This also covers the M-dependent
      leading stride of a column-major ``a`` and padded or sliced rows.

    Args:
        a, b: Input tensors (2-D)
        N, K: Matrix dimensions

    Returns:
        bool: True if compatible with TMA's alignment requirements
    """
    if a.dtype != b.dtype:
        return False
    if a.dtype in (torch.float16, torch.bfloat16):
        elems_per_16_bytes = 8
    elif a.dtype == torch.float32:
        elems_per_16_bytes = 4
    else:
        return False
    if N % elems_per_16_bytes != 0 or K % elems_per_16_bytes != 0:
        return False
    return _tma_operand_layout_ok(a) and _tma_operand_layout_ok(b)
@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of ``b`` strictly below ``a`` (for a > 0): take one
    # block fewer than ceil(a / b) blocks of size ``b``.
    return (tl.cdiv(a, b) - 1) * b
def matmul_tma_set_block_size_hook(nargs):
    """Autotuner pre-hook: size the TMA descriptors' blocks to the chosen tiles.

    Descriptors are created with a placeholder block shape; this rewrites
    ``block_shape`` on ``a_desc``/``b_desc``/``c_desc`` in ``nargs`` from the
    config's BLOCK_M/BLOCK_N/BLOCK_K.  Column-major operands (described in
    transposed form) get the transposed block shape.
    """
    bm, bn, bk = nargs["BLOCK_M"], nargs["BLOCK_N"], nargs["BLOCK_K"]
    nargs["a_desc"].block_shape = [bm, bk] if nargs["A_ROW_MAJOR"] else [bk, bm]
    nargs["b_desc"].block_shape = [bk, bn] if nargs["B_ROW_MAJOR"] else [bn, bk]
    nargs["c_desc"].block_shape = [bm, bn]
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    # Add 'stride_am' and 'stride_bk' to trigger autotune for tensors with the same shape but different strides.
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["default", "default", "default", "default", "default"],
    warmup=5,
    rep=10,
)
@triton.jit
def mm_kernel_general(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Tiled ``C = A @ B`` for arbitrary strides, one output tile per program.

    Fast path: when every dimension is a multiple of its block size AND all
    three operands are dense row-major, tiles move through device-side tensor
    descriptors.  Otherwise a strided pointer path with a masked K tail runs.
    Operands of different dtypes are cast to C's element type before
    ``tl.dot``.  Accumulation is fp64 when ``IS_FP64``, else fp32.
    """
    # matrix multiplication
    pid = ext.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    # The descriptors below hard-code dense row-major strides ([K, 1], [N, 1],
    # [N, 1]).  Callers may pass transposed views or a strided output, so
    # verify the real strides before trusting that layout.
    if (
        M % BLOCK_M == 0
        and N % BLOCK_N == 0
        and K % BLOCK_K == 0
        and stride_ak == 1
        and stride_am == K
        and stride_bn == 1
        and stride_bk == N
        and stride_cn == 1
        and stride_cm == N
    ):
        # offset
        offset_am = pid_m * BLOCK_M
        offset_bn = pid_n * BLOCK_N
        offset_k = 0

        a_desc = tl.make_tensor_descriptor(
            base=A,
            shape=[M, K],
            strides=[K, 1],
            block_shape=[BLOCK_M, BLOCK_K],
        )
        b_desc = tl.make_tensor_descriptor(
            base=B,
            shape=[K, N],
            strides=[N, 1],
            block_shape=[BLOCK_K, BLOCK_N],
        )
        c_desc = tl.make_tensor_descriptor(
            base=C,
            shape=[M, N],
            strides=[N, 1],
            block_shape=[BLOCK_M, BLOCK_N],
        )

        if IS_FP64:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
        else:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_K)):
            a = a_desc.load([offset_am.to(tl.int32), offset_k.to(tl.int32)])
            b = b_desc.load([offset_k.to(tl.int32), offset_bn.to(tl.int32)])
            # tl.dot needs matching operand dtypes; unify mixed inputs the
            # same way the strided path does.
            if a.dtype != b.dtype:
                a = a.to(C.dtype.element_ty)
                b = b.to(C.dtype.element_ty)
            if IS_FP64:
                acc += tl.dot(a, b, allow_tf32=False)
            else:
                acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
            offset_k += BLOCK_K

        # Cast to the *output* element type: it can differ from A's dtype
        # (mixed-precision inputs, or an fp32 output for half inputs).
        acc = acc.to(C.dtype.element_ty)
        c_desc.store([offset_am.to(tl.int32), offset_bn.to(tl.int32)], acc)

    else:
        # do matrix multiplication
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        # Wrap out-of-range rows/cols back into bounds for the loads; the
        # final store is masked with the un-wrapped indices.
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
        rm = rm.to(tl.int64)
        rn = rn.to(tl.int64)
        prev_multiple = prev_multiple_of(K, BLOCK_K)

        if IS_FP64:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
        else:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # Unmasked main loop over all full K blocks except the last one.
        for start_k in range(0, prev_multiple, BLOCK_K):
            rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
            a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
            b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
            if a.dtype != b.dtype:
                a = a.to(C.dtype.element_ty)
                b = b.to(C.dtype.element_ty)
            if IS_FP64:
                acc += tl.dot(a, b, allow_tf32=False)
            else:
                acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

        # loop peeling: last (possibly partial) K block, masked along K
        rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
        mask_k = rk < K
        a = tl.load(
            A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
            mask=mask_k[None, :],
            other=0.0,
        )
        b = tl.load(
            B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
            mask=mask_k[:, None],
            other=0.0,
        )
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        if IS_FP64:
            acc += tl.dot(a, b, allow_tf32=False)
        else:
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

        acc = acc.to(C.dtype.element_ty)
        # rematerialize rm and rn to save registers
        rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
        rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
        offsets = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        # handles write-back with reduction-splitting
        tl.store(offsets, acc, mask=mask)
def matmul_get_configs(pre_hook=matmul_tma_set_block_size_hook):
    """Build the autotuning search space for the host-TMA mm kernel.

    Enumerates every combination of BLOCK_M in {32, 64, 128, 256},
    BLOCK_N and BLOCK_K in {32, 64, 128}, 2-4 pipeline stages and 4 or 8
    warps, with GROUP_M fixed at 8 and ``pre_hook`` attached to each config.

    When the device's opt-in shared-memory limit is known, configs whose
    estimated footprint exceeds it are dropped.  If that would drop every
    config, a warning is logged and the full list is returned instead.
    """
    all_configs = []
    for block_m in (32, 64, 128, 256):
        for block_n in (32, 64, 128):
            for block_k in (32, 64, 128):
                for stages in (2, 3, 4):
                    for warps in (4, 8):
                        meta = {
                            "BLOCK_M": block_m,
                            "BLOCK_N": block_n,
                            "BLOCK_K": block_k,
                            "GROUP_M": 8,
                        }
                        all_configs.append(
                            triton.Config(
                                meta,
                                num_stages=stages,
                                num_warps=warps,
                                pre_hook=pre_hook,
                            )
                        )

    smem_limit = _get_shared_memory_limit_bytes()
    if smem_limit is None:
        return all_configs

    fitting = []
    for cfg in all_configs:
        needed = _estimate_tma_shared_memory_bytes(
            cfg.kwargs["BLOCK_M"],
            cfg.kwargs["BLOCK_N"],
            cfg.kwargs["BLOCK_K"],
            cfg.num_stages,
        )
        if needed <= smem_limit:
            fitting.append(cfg)

    if fitting:
        return fitting
    logger.warning(
        "No mm_general_tma config fits shared memory limit (%s bytes); falling back to unfiltered configs.",
        smem_limit,
    )
    return all_configs
@libentry()
@libtuner(
    configs=matmul_get_configs(),
    key=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    strategy=["align32", "align32", "align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
    flagtune_op_name="mm",
    flagtune_expand_op_name="mm_general_tma",
    flagtune_yaml_path=EXPAND_CONFIG_FILENAME,
    flagtune_pre_hook=matmul_tma_set_block_size_hook,
)
@triton.jit
def mm_kernel_general_host_tma(
    a_desc,
    b_desc,
    c_desc,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    A_ROW_MAJOR: tl.constexpr,
    B_ROW_MAJOR: tl.constexpr,
    dtype: tl.constexpr,
    enable_warp_specialization=True,
):
    """Tiled ``C = A @ B`` through host-built TMA tensor descriptors.

    The descriptors arrive with a placeholder block shape; the autotuner
    pre-hook (``matmul_tma_set_block_size_hook``) sets it to the chosen
    BLOCK_* sizes before launch.  When ``A_ROW_MAJOR``/``B_ROW_MAJOR`` is
    False, that operand's descriptor describes the transposed matrix.  Its
    tiles are then loaded transposed and flipped back with ``tl.trans``.

    The ``stride_*`` arguments are not used for addressing here (the
    descriptors carry the layout).  ``stride_am``, ``stride_bk`` and
    ``dtype`` serve as autotuning keys.  ``enable_warp_specialization`` is
    accepted but not referenced in the body.
    """
    pid = tl.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)

    # Grouped launch order: GROUP_M consecutive row-tiles share each column
    # sweep (same re-ordering as mm_kernel_general).
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    offset_am = (pid_m * BLOCK_M).to(tl.int32)
    offset_bn = (pid_n * BLOCK_N).to(tl.int32)
    iters = tl.cdiv(K, BLOCK_K)
    for k in range(iters):
        offset_ak = (k * BLOCK_K).to(tl.int32)

        if A_ROW_MAJOR:
            a = a_desc.load([offset_am, offset_ak])
        else:
            # Descriptor is over A^T: load a (BLOCK_K, BLOCK_M) tile and flip it.
            a_t = a_desc.load([offset_ak, offset_am])
            a = tl.trans(a_t)

        if B_ROW_MAJOR:
            b = b_desc.load([offset_ak, offset_bn])
        else:
            # Descriptor is over B^T: load a (BLOCK_N, BLOCK_K) tile and flip it.
            b_t = b_desc.load([offset_bn, offset_ak])
            b = tl.trans(b_t)

        # Half-precision operands: plain dot into the fp32 accumulator.
        # Anything else (fp32 on this path) uses tf32x3 input precision.
        # NOTE(review): operands are not cast here, so a and b are assumed to
        # share one dtype (the host-side eligibility check must ensure it).
        if a_desc.dtype == tl.float16 or a_desc.dtype == tl.bfloat16:
            accumulator = tl.dot(a, b, acc=accumulator, allow_tf32=False)
        else:
            accumulator = tl.dot(a, b, acc=accumulator, input_precision="tf32x3")

    # Cast once to the output descriptor's element type and store the tile.
    c = accumulator.to(c_desc.dtype)
    c_desc.store([offset_am, offset_bn], c)
def get_higher_dtype(a, b):
    """Return whichever of two floating dtypes ranks higher.

    The ranking is float16 < bfloat16 < float32 < float64.  Identical
    dtypes are returned unchanged.  Any other dtype fails the assertions.
    """
    ranking = [torch.float16, torch.bfloat16, torch.float32, torch.float64]

    if a is b:
        return a

    assert a in ranking
    assert b in ranking

    # a and b are distinct, so their positions in the ranking differ.
    return b if ranking.index(a) < ranking.index(b) else a
def general_mm(a, b, c, M, N, K, op_name="mm"):
    """Compute ``c = a @ b`` (a: M x K, b: K x N) with the general tiled kernels.

    Prefers the host-side TMA kernel when Triton provides ``TensorDescriptor``
    and all three operands satisfy the TMA layout rules.  Otherwise falls
    back to ``mm_kernel_general``.

    Args:
        a, b: input matrices; operands with a zero (broadcast) stride are
            copied to contiguous memory first.
        c: preallocated output, written in place.
        M, N, K: problem sizes.
        op_name: label used only in the debug log.

    Returns:
        ``c``.
    """
    # TODO: Remove this debug message
    logger.debug(
        "GEMS MM-hopper, [op]: %s, [mm scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        op_name,
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    # Broadcast tensors from expand() have stride=0, incompatible with TMA
    if 0 in a.stride():
        a = a.contiguous()
    if 0 in b.stride():
        b = b.contiguous()
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )

    # Import rather than probing ``triton.tools.tensor_descriptor`` as an
    # attribute: on Triton builds where that submodule is absent (or not yet
    # imported), the attribute access raises AttributeError instead of
    # falling back to the generic kernel.
    try:
        from triton.tools.tensor_descriptor import TensorDescriptor
    except ImportError:
        TensorDescriptor = None

    # The output descriptor is built directly from c.stride(), so c must be
    # row-major with a 16-byte aligned leading stride and storage offset.
    # This matters for caller-provided ``out`` tensors.
    c_elem_bytes = c.element_size()
    c_tma_ok = (
        c.stride(1) == 1
        and (c.stride(0) * c_elem_bytes) % 16 == 0
        and (c.storage_offset() * c_elem_bytes) % 16 == 0
    )

    if TensorDescriptor is not None and c_tma_ok and is_tma_compatible(a, b, N, K):
        a_row_major = a.stride(1) == 1
        b_row_major = b.stride(1) == 1
        # Placeholder block shape; the autotuner pre-hook replaces it with
        # the selected tile sizes.
        dummy_block = [1, 1]

        # Column-major operands are described via their transpose.
        if a_row_major:
            a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block)
        else:
            a_desc = TensorDescriptor(a, a.T.shape, a.T.stride(), dummy_block)
        if b_row_major:
            b_desc = TensorDescriptor(b, b.shape, b.stride(), dummy_block)
        else:
            b_desc = TensorDescriptor(b, b.T.shape, b.T.stride(), dummy_block)
        c_desc = TensorDescriptor(c, c.shape, c.stride(), dummy_block)

        input_dtype = a.dtype
        dtype_str = str(input_dtype).split(".")[-1]

        with torch_device_fn.device(a.device):
            mm_kernel_general_host_tma[grid](
                a_desc,
                b_desc,
                c_desc,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                A_ROW_MAJOR=a_row_major,
                B_ROW_MAJOR=b_row_major,
                dtype=dtype_str,
            )
    else:
        # mm_kernel_general may build device-side tensor descriptors, which
        # need scratch memory from a Triton allocator.
        def alloc_fn(size: int, align: int, stream: Optional[int]):
            return torch.empty(size, dtype=torch.int8, device=a.device)

        triton.set_allocator(alloc_fn)

        with torch_device_fn.device(a.device):
            mm_kernel_general[grid](
                a,
                b,
                c,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                GROUP_M=8,
                IS_FP64=a.dtype == torch.float64,
            )
    return c
@libentry()
@libtuner(
    configs=[
        triton.Config(
            {"BLOCK_M": 32, "BLOCK_K": 256},
        )
    ],
    key=["M", "K", "stride_am", "stride_bk"],
    strategy=["align32", "align32", "align32", "default"],
    warmup=5,
    rep=10,
    flagtune_op_name="mm",
    flagtune_expand_op_name="gemv",
    flagtune_yaml_path=EXPAND_CONFIG_FILENAME,
    flagtune_pre_hook=None,
)
@triton.jit
def gemv_kernel(
    A,
    B,
    C,
    M,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Optimized kernel for matrix-vector multiplication (N=1 case)

    Each program reduces BLOCK_M rows of A against the vector B and writes
    them to ``C + row``, i.e. C must have unit stride along M.  Accumulation
    is fp64 when ``IS_FP64``, else fp32.
    """
    pid = tl.program_id(0)

    # Each program handles BLOCK_M rows
    row_start = pid * BLOCK_M
    row_offset = row_start + tl.arange(0, BLOCK_M)
    row_mask = row_offset < M
    # 64-bit element offsets: row * stride_am can exceed the int32 range
    # for large matrices.
    rows = row_offset.to(tl.int64)

    # Accumulator for this block of rows
    if IS_FP64:
        acc = tl.zeros((BLOCK_M,), dtype=tl.float64)
    else:
        acc = tl.zeros((BLOCK_M,), dtype=tl.float32)

    # Iterate over K dimension
    for k_start in range(0, K, BLOCK_K):
        k_offset = k_start + tl.arange(0, BLOCK_K)
        k_mask = k_offset < K
        # Same reason: k * stride_ak is large for a column-major A.
        cols = k_offset.to(tl.int64)

        # Load block from matrix A: [BLOCK_M, BLOCK_K]
        a_ptrs = A + rows[:, None] * stride_am + cols[None, :] * stride_ak
        a = tl.load(a_ptrs, mask=row_mask[:, None] & k_mask[None, :], other=0.0)

        # Load block from vector B: [BLOCK_K]
        b_ptrs = B + cols * stride_bk
        b = tl.load(b_ptrs, mask=k_mask, other=0.0)

        # Accumulate: sum over K dimension
        if IS_FP64:
            acc += tl.sum(a * b[None, :], axis=1)
        else:
            acc += tl.sum(a.to(tl.float32) * b.to(tl.float32)[None, :], axis=1)

    # Store result
    c_ptrs = C + rows
    acc = acc.to(C.dtype.element_ty)
    tl.store(c_ptrs, acc, mask=row_mask)
def gemv_mm(a, b, c, M, K):
    """Optimized matrix-vector multiplication for N=1 case

    Computes ``c = a @ b`` for ``a`` (M x K) and ``b`` (K x 1) with
    ``gemv_kernel`` and returns ``c``.
    """
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: gemv (N=1), [shape info]: [%s, %s, 1](M, K, N)",
        M,
        K,
    )

    # gemv_kernel takes no output stride: it writes row i at ``C + i``.  A
    # caller-provided ``out`` may be a strided (M, 1) view (e.g. a column of
    # a larger tensor), so compute into a dense scratch column and copy back.
    dest = (
        c
        if c.is_contiguous()
        else torch.empty((M, 1), dtype=c.dtype, device=c.device)
    )

    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)

    with torch_device_fn.device(a.device):
        gemv_kernel[grid](
            a,
            b,
            dest,
            M,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            IS_FP64=a.dtype == torch.float64,
        )
    if dest is not c:
        c.copy_(dest)
    return c
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm_splitk"),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    reset_to_zero=["C"],
    strategy=["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=10,
    flagtune_op_name="mm",
    flagtune_expand_op_name="mm_splitk",
    flagtune_yaml_path=EXPAND_CONFIG_FILENAME,
    flagtune_pre_hook=None,
)
@triton.jit
def mm_kernel_splitk(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
):
    """Split-K ``C += A @ B``: program (tile, k) sums one slice of K blocks.

    Grid axis 0 enumerates (M, N) output tiles row-major, and axis 1 the
    SPLIT_K slices.  Partial fp32 sums are combined into C with
    ``tl.atomic_add``, so C must be zero-filled beforehand.  Operands of
    different dtypes are cast to C's element type before ``tl.dot``.
    """
    pid = tl.program_id(0)
    pid_k = tl.program_id(1)

    grid_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // grid_n
    pid_n = pid % grid_n

    offset_am = pid_m * BLOCK_M
    offset_bn = pid_n * BLOCK_N
    # 64-bit offsets: row * stride_am (and k * stride_ak for a column-major
    # A) can exceed the int32 range when K is very large.
    offs_am = (offset_am + tl.arange(0, BLOCK_M)).to(tl.int64)
    offs_bn = (offset_bn + tl.arange(0, BLOCK_N)).to(tl.int64)

    # This program's contiguous range of K-block indices [k_start, k_end).
    total_k_iters = tl.cdiv(K, BLOCK_K)
    k_per_split = tl.cdiv(total_k_iters, SPLIT_K)
    k_start = pid_k * k_per_split
    k_end = min((pid_k + 1) * k_per_split, total_k_iters)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(k_start, k_end):
        offset_k = k * BLOCK_K
        offs_k = (offset_k + tl.arange(0, BLOCK_K)).to(tl.int64)

        a = tl.load(
            A + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak,
            mask=(offs_am[:, None] < M) & (offs_k[None, :] < K),
            other=0.0,
        )
        b = tl.load(
            B + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn,
            mask=(offs_k[:, None] < K) & (offs_bn[None, :] < N),
            other=0.0,
        )
        # tl.dot needs matching operand dtypes (mixed inputs are allowed by
        # the host dispatcher, which gives C the higher dtype).
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    c_ptrs = C + offs_am[:, None] * stride_cm + offs_bn[None, :] * stride_cn
    mask = (offs_am < M)[:, None] & (offs_bn < N)[None, :]
    tl.atomic_add(c_ptrs, acc, mask=mask)
def splitk_mm(a, b, c, M, N, K, op_name="mm"):
    """Run the split-K kernel computing ``c = a @ b`` and return ``c``.

    ``c`` must already be zero-filled: the kernel accumulates partial sums
    into it with atomic adds.  Grid axis 0 covers the (M, N) output tiles and
    axis 1 the SPLIT_K slices chosen by the autotuner.
    """
    logger.debug(
        "GEMS MM-hopper, [op]: %s, [mm scenario]: splitk, [shape info]: [-, %s, %s, %s](batch, M, N, K)",
        op_name,
        M,
        N,
        K,
    )

    def grid(meta):
        output_tiles = triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"])
        return (output_tiles, meta["SPLIT_K"])

    with torch_device_fn.device(a.device):
        mm_kernel_splitk[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
        )
    return c
def streamk_scenario(a, b, M, N, K):
    """Decide whether the stream-K kernel should handle this GEMM.

    Chosen only for contiguous fp16/bf16 operands with a long reduction
    (K more than 5x both M and N) on compute-capability 8.x devices.

    TODO: these conditions may change as real benchmark results come in.
    The best stream-K configuration has only been tested on A100
    (capability[0] == 8); the optimal settings for other devices still
    need to be determined through real testing.
    """
    capability = get_device_capability()
    if capability[0] != 8:
        return False

    half_types = (torch.float16, torch.bfloat16)
    dtypes_ok = a.dtype in half_types and b.dtype in half_types
    layout_ok = a.is_contiguous() and b.is_contiguous()
    long_reduction = K > M * 5 and K > N * 5
    return dtypes_ok and layout_ok and long_reduction
if HAS_TLE:

    @triton.jit
    def _cluster_remote_gemm_kernel(
        a_ptr,
        b_ptr,
        c_ptr,
        M,
        N,
        K,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        mesh: tl.constexpr,
        BM: tl.constexpr,
        BN: tl.constexpr,
        BK: tl.constexpr,
        DOT_K: tl.constexpr,
        CLUSTER_SIZE: tl.constexpr,
        USE_MASK: tl.constexpr,
        A_SLOTS: tl.constexpr,
        USE_NV_MMA_SMEM_LAYOUT: tl.constexpr,
    ):
        """Cluster GEMM where rank 0 stages A tiles in shared memory for all ranks.

        Ranks of one cluster share the same M tile (pid_m) and take
        consecutive N tiles.  Only rank 0 loads A from global memory into its
        ``a_buf`` ring of A_SLOTS slots.  Every rank reads A through
        ``a_buf_remote`` and streams its own B tile, accumulating in fp32.

        NOTE(review): ``a_buf`` is allocated as float16, so this kernel
        presumably expects fp16 A; the host scenario check restricts inputs
        to fp16.  NOTE(review): when cdiv(N, BN) is not a multiple of
        CLUSTER_SIZE, the last cluster's higher ranks get
        pid_n >= cdiv(N, BN).  Only USE_MASK guards their B loads and C
        stores.
        """
        pid = tl.program_id(0)
        # NOTE(review): assumes consecutive program ids form one cluster and
        # that shard_id returns this program's rank within it — confirm with
        # the TLE device-mesh docs.
        cluster_rank = tle_exp.shard_id(mesh, "cluster_x")
        cluster_id = pid // CLUSTER_SIZE

        # Clusters iterate N tiles in groups of CLUSTER_SIZE; the rank picks
        # the tile inside its group.
        num_pid_n = tl.cdiv(N, BN)
        num_pid_n_group = tl.cdiv(num_pid_n, CLUSTER_SIZE)
        pid_m = cluster_id // num_pid_n_group
        pid_ng = cluster_id % num_pid_n_group
        pid_n = pid_ng * CLUSTER_SIZE + cluster_rank

        offs_m = pid_m * BM + tl.arange(0, BM)
        offs_n = pid_n * BN + tl.arange(0, BN)
        offs_k = tl.arange(0, BK)
        # Tile-local row/col indices used to address the shared A buffer:
        # (BM, BK) for the writer and transposed (DOT_K, BM) for readers.
        a_row_base = offs_m - pid_m * BM
        a_rows_full = tl.broadcast_to(a_row_base[:, None], (BM, BK))
        a_cols_full = tl.broadcast_to(tl.arange(0, BK)[None, :], (BM, BK))
        a_rows_t = tl.broadcast_to(a_row_base[None, :], (DOT_K, BM))
        # Ring of A_SLOTS (BM, BK) fp16 A tiles in shared memory.
        a_buf = tle_exp.gpu.alloc(
            [A_SLOTS, BM, BK],
            dtype=tl.float16,
            layout=None,
            scope=tle_exp.gpu.smem,
            nv_mma_shared_layout=USE_NV_MMA_SMEM_LAYOUT,
        )
        # View of a_buf owned by rank 0 (second argument), as seen from any
        # rank in the mesh — presumably a distributed shared memory mapping.
        a_buf_remote = tle_exp.remote(a_buf, 0, scope=mesh)

        acc = tl.zeros((BM, BN), dtype=tl.float32)
        # Prologue: rank 0 stages the first A tile (k0 = 0) in slot 0.
        slot0 = 0
        slot0_full = tl.zeros((BM, BK), dtype=tl.int32) + slot0
        if cluster_rank == 0:
            a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
            if USE_MASK:
                a_mask_tile = (offs_m[:, None] < M) & (offs_k[None, :] < K)
                a_tile = tl.load(a_ptrs, mask=a_mask_tile, other=0.0)
            else:
                a_tile = tl.load(a_ptrs)
            a_local_ptr_tile = tle_exp.gpu.local_ptr(
                a_buf, (slot0_full, a_rows_full, a_cols_full)
            )
            if USE_MASK:
                tl.store(a_local_ptr_tile, a_tile, mask=a_mask_tile)
            else:
                tl.store(a_local_ptr_tile, a_tile)

        # Slot 0 must be fully written before any rank reads it.
        tle_exp.distributed_barrier(mesh)

        for k0 in range(0, K, BK):
            iter_idx = k0 // BK
            slot = iter_idx % A_SLOTS

            # Consume the current slot in DOT_K-wide chunks: read A^T from
            # rank 0's buffer, transpose back, and multiply with this rank's B.
            for ks in range(0, BK, DOT_K):
                k_local = ks + tl.arange(0, DOT_K)
                a_cols_t = tl.broadcast_to(k_local[:, None], (DOT_K, BM))
                slot_dot_t = tl.zeros((DOT_K, BM), dtype=tl.int32) + slot
                a_ptr_remote = tle_exp.gpu.local_ptr(
                    a_buf_remote, (slot_dot_t, a_rows_t, a_cols_t)
                )
                if USE_MASK:
                    a_mask_t = ((k0 + k_local)[:, None] < K) & (offs_m[None, :] < M)
                    a = tl.trans(tl.load(a_ptr_remote, mask=a_mask_t, other=0.0))
                else:
                    a = tl.trans(tl.load(a_ptr_remote))

                b_ptrs = (
                    b_ptr
                    + (k0 + k_local)[:, None] * stride_bk
                    + offs_n[None, :] * stride_bn
                )
                if USE_MASK:
                    b_mask = ((k0 + k_local)[:, None] < K) & (offs_n[None, :] < N)
                    b = tl.load(b_ptrs, mask=b_mask, other=0.0)
                else:
                    b = tl.load(b_ptrs)
                acc = tl.dot(a, b, acc)

            # With a single slot, every rank must finish reading it before
            # rank 0 overwrites it with the next tile.
            if A_SLOTS == 1:
                tle_exp.distributed_barrier(mesh)

            # Rank 0 prefetches the next A tile into the next ring slot.
            next_k0 = k0 + BK
            has_next = next_k0 < K
            next_iter = iter_idx + 1
            next_slot = next_iter % A_SLOTS
            next_slot_full = tl.zeros((BM, BK), dtype=tl.int32) + next_slot
            if has_next and cluster_rank == 0:
                a_ptrs = (
                    a_ptr
                    + offs_m[:, None] * stride_am
                    + (next_k0 + offs_k)[None, :] * stride_ak
                )
                if USE_MASK:
                    a_mask_tile = (offs_m[:, None] < M) & (
                        (next_k0 + offs_k)[None, :] < K
                    )
                    a_tile = tl.load(a_ptrs, mask=a_mask_tile, other=0.0)
                else:
                    a_tile = tl.load(a_ptrs)
                a_local_ptr_tile = tle_exp.gpu.local_ptr(
                    a_buf, (next_slot_full, a_rows_full, a_cols_full)
                )
                if USE_MASK:
                    tl.store(a_local_ptr_tile, a_tile, mask=a_mask_tile)
                else:
                    tl.store(a_local_ptr_tile, a_tile)

            # The prefetched slot must be visible to all ranks before the
            # next iteration reads it.
            tle_exp.distributed_barrier(mesh)

        # Epilogue: each rank writes its own (BM, BN) output tile.
        c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
        if USE_MASK:
            c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
            tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)
        else:
            tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty))
def _select_remote_dot_k(bk: int) -> int:
    """Return the inner dot width used by the cluster remote kernel.

    The K block is consumed in chunks of 16, so ``bk`` must be divisible by
    16; otherwise ValueError is raised.
    """
    if bk % 16 != 0:
        raise ValueError(f"BK must be divisible by 16 for remote dot path, got BK={bk}")
    return 16
def _grid_cluster_remote(
    M: int,
    N: int,
    BM: int,
    BN: int,
    cluster_size: int = TLE_CLUSTER_SIZE,
) -> tuple:
    """Launch grid for the cluster remote kernel.

    Returns a 1-tuple holding (#M tiles) x (#groups of ``cluster_size``
    N tiles).
    """
    m_tiles = triton.cdiv(M, BM)
    n_tile_groups = triton.cdiv(triton.cdiv(N, BN), cluster_size)
    return (m_tiles * n_tile_groups,)
def _run_cluster_remote(
    a: torch.Tensor,
    b: torch.Tensor,
    c: torch.Tensor,
    bm: int,
    bn: int,
    bk: int,
    num_warps: int,
    num_stages: int,
) -> None:
    """Launch ``_cluster_remote_gemm_kernel`` to compute ``c = a @ b`` in place.

    Args:
        a, b, c: (M, K), (K, N) and (M, N) matrices.
        bm, bn, bk: tile sizes; ``bk`` must be divisible by 16.
        num_warps, num_stages: Triton launch parameters.

    Raises:
        ValueError: if ``bk`` is not divisible by 16.
    """
    M, K = a.shape
    N = b.shape[1]
    dot_k = _select_remote_dot_k(bk)
    # Each cluster covers TLE_CLUSTER_SIZE consecutive N tiles.  When the N
    # tile count is not a multiple of the cluster size, the last cluster's
    # upper ranks get pid_n past the last tile.  Only the masked kernel
    # variant keeps their B loads and C stores in bounds, even when M, N and
    # K are all tile-aligned.
    n_tiles_uneven = triton.cdiv(N, bn) % TLE_CLUSTER_SIZE != 0
    use_mask = (M % bm != 0) or (N % bn != 0) or (K % bk != 0) or n_tiles_uneven
    a_slots = TLE_REMOTE_A_SLOTS
    use_nv_mma_smem_layout = (bk == 32) or (bk == 64 and num_stages <= 2)
    _cluster_remote_gemm_kernel[_grid_cluster_remote(M, N, bm, bn)](
        a,
        b,
        c,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        mesh=BLOCK_CLUSTER_MESH,
        BM=bm,
        BN=bn,
        BK=bk,
        DOT_K=dot_k,
        CLUSTER_SIZE=TLE_CLUSTER_SIZE,
        USE_MASK=use_mask,
        A_SLOTS=a_slots,
        USE_NV_MMA_SMEM_LAYOUT=use_nv_mma_smem_layout,
        num_ctas=1,
        num_warps=num_warps,
        num_stages=num_stages,
    )
def cluster_remote_mm_scenario(a, b, c, M, N, K):
    """Decide whether the TLE cluster remote kernel should handle this GEMM.

    It requires:
    - TLE with a device mesh, on a device of compute capability >= 9;
    - fp16 CUDA tensors for a, b and c, with a and b contiguous;
    - a problem at least one remote tile in every dimension.
    """
    capability = get_device_capability()
    tle_ready = HAS_TLE and BLOCK_CLUSTER_MESH is not None and capability[0] >= 9
    on_cuda = a.is_cuda and b.is_cuda and c.is_cuda
    all_fp16 = (
        a.dtype == torch.float16
        and b.dtype == torch.float16
        and c.dtype == torch.float16
    )
    inputs_contiguous = a.is_contiguous() and b.is_contiguous()
    large_enough = M >= TLE_REMOTE_BM and N >= TLE_REMOTE_BN and K >= TLE_REMOTE_BK
    return tle_ready and on_cuda and all_fp16 and inputs_contiguous and large_enough
def cluster_remote_mm(a, b, c, M, N, K):
    """Compute ``c = a @ b`` with the TLE cluster remote kernel and return ``c``.

    Uses the fixed TLE_REMOTE_* tile and launch parameters.  In this module,
    mm/mm_out call it only after ``cluster_remote_mm_scenario`` returns True.
    """
    # A format string is required: logging formats the message as
    # ``msg % args``, and the shape values alone are not a message.
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: cluster_remote, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    with torch_device_fn.device(a.device):
        _run_cluster_remote(
            a,
            b,
            c,
            TLE_REMOTE_BM,
            TLE_REMOTE_BN,
            TLE_REMOTE_BK,
            TLE_REMOTE_NUM_WARPS,
            TLE_REMOTE_NUM_STAGES,
        )
    return c
def mm(a, b):
    """Return ``a @ b`` for a 2-D ``a`` (M x K) and ``b`` (K x N).

    The output dtype is the higher of the two input dtypes.  Dispatch order:
    1. gemv for N == 1;
    2. stream-K;
    3. the TLE cluster kernel;
    4. split-K for small M and N with large K;
    5. otherwise the general tiled kernel.
    """
    device = a.device
    # Copy only operands where neither dimension is unit-stride; plain
    # row- or column-major views are consumed as-is by the kernels.
    a = a.contiguous() if (a.stride(0) > 1 and a.stride(1) > 1) else a
    b = b.contiguous() if (b.stride(0) > 1 and b.stride(1) > 1) else b
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    result = torch.empty((M, N), device=device, dtype=get_higher_dtype(a.dtype, b.dtype))

    if N == 1:
        return gemv_mm(a, b, result, M, K)

    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, result, M, N, K, sm_count=sm_count)

    tle_available = HAS_TLE and BLOCK_CLUSTER_MESH is not None
    if tle_available and cluster_remote_mm_scenario(a, b, result, M, N, K):
        return cluster_remote_mm(a, b, result, M, N, K)

    if M < 2048 and N < 2048 and K >= 4096:
        # Split-K accumulates with atomic adds, so start from zero.
        result.zero_()
        return splitk_mm(a, b, result, M, N, K)
    return general_mm(a, b, result, M, N, K)
def mm_out(a, b, *, out):
    """Compute ``a @ b`` into ``out`` and return it.

    Dispatch mirrors ``mm``:
    1. gemv for N == 1;
    2. stream-K;
    3. the TLE cluster kernel;
    4. split-K for small M and N with large K;
    5. otherwise the general tiled kernel.

    ``out`` is resized to ``(M, N)`` when its shape differs, in the spirit
    of PyTorch's ``out=`` convention.
    """
    # handle non-contiguous inputs if necessary
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape

    # Every kernel addresses ``out`` only through (M, N) and its strides, so
    # a wrongly shaped ``out`` would be written out of bounds.
    if tuple(out.shape) != (M, N):
        out.resize_((M, N))

    # Optimize for N=1 case (matrix-vector multiplication)
    if N == 1:
        return gemv_mm(a, b, out, M, K)
    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, out, M, N, K, sm_count=sm_count)
    if HAS_TLE and BLOCK_CLUSTER_MESH is not None:
        if cluster_remote_mm_scenario(a, b, out, M, N, K):
            return cluster_remote_mm(a, b, out, M, N, K)
    # Use splitk for small M; it accumulates atomically, so zero first.
    if M < 2048 and N < 2048 and K >= 4096:
        out.zero_()
        return splitk_mm(a, b, out, M, N, K)
    return general_mm(a, b, out, M, N, K)
def router_gemm(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """bf16 x bf16 -> fp32 GEMM for MoE router gate. weight shape: (N, K).

    Computes ``x @ weight.T`` into a new fp32 (M, N) tensor.  Split-K is used
    for small M and N with K >= 4096; otherwise the general kernel runs.

    Raises:
        ValueError: if ``x`` or ``weight`` is not 2-D, or their K dims differ
            (the kernels would otherwise read out of bounds).
    """
    if x.dim() != 2 or weight.dim() != 2:
        raise ValueError(
            f"router_gemm expects 2-D x and weight, got {x.dim()}-D and {weight.dim()}-D"
        )
    if x.stride(0) > 1 and x.stride(1) > 1:
        x = x.contiguous()
    M, K = x.shape
    N, weight_k = weight.shape
    if weight_k != K:
        raise ValueError(
            f"router_gemm: x has K={K} but weight has K={weight_k} (expected weight shape (N, K))"
        )
    c = torch.empty((M, N), device=x.device, dtype=torch.float32)
    # Column-major (K, N) view of the weight; no copy is made.
    b = weight.t()
    if M < 2048 and N < 2048 and K >= 4096:
        # Split-K accumulates with atomic adds, so start from zero.
        c.zero_()
        return splitk_mm(x, b, c, M, N, K, op_name="router_gemm")
    return general_mm(x, b, c, M, N, K, op_name="router_gemm")