Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/mm.py: 0%

432 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2import os 

3from typing import Optional 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.ops.mm_streamk import streamk_mm 

11from flag_gems.runtime import torch_device_fn 

12from flag_gems.utils import libentry, libtuner 

13from flag_gems.utils import triton_lang_extension as ext 

14from flag_gems.utils.device_info import get_device_capability, get_sm_count 

15from flag_gems.utils.triton_version_utils import HAS_TLE, HAS_TLE_DEVICE_MESH 

16 

logger = logging.getLogger("flag_gems.runtime.backend._nvidia.hopper.ops.mm")

# NOTE(review): not referenced in the visible part of this module; presumably a
# cache-occupancy ratio consumed elsewhere -- confirm before relying on it.
CACHE_USAGE_THRESHOLD = 0.8

# FlagTune "expand" config shared by the libtuner-decorated kernels below. It
# lives one directory above the directory containing this module.
_ops_dir = os.path.dirname(__file__)
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(_ops_dir, "..", "mm_hopper_expand.yaml")
)
del _ops_dir

# Headroom (1 KiB) added on top of the tile estimate when filtering TMA configs
# against the device's shared-memory limit.
_SHARED_MEM_SAFETY_MARGIN_BYTES = 1 << 10

23 

24 

def _get_shared_memory_limit_bytes():
    """Return the current CUDA device's opt-in shared memory per block, in bytes.

    This is a best-effort probe: it yields ``None`` when CUDA is unavailable or
    when querying the device properties fails for any reason.
    """
    try:
        if torch.cuda.is_available():
            props = torch.cuda.get_device_properties(torch.cuda.current_device())
            return props.shared_memory_per_block_optin
    except Exception:
        # Deliberately swallowed: callers treat "unknown" as "don't filter".
        return None
    return None

35 

36 

def _estimate_tma_shared_memory_bytes(
    block_m,
    block_n,
    block_k,
    num_stages,
    *,
    bytes_per_element=4,
    safety_margin_bytes=None,
):
    """Estimate shared memory needed by a pipelined TMA matmul configuration.

    Each pipeline stage holds one ``block_m x block_k`` A tile and one
    ``block_k x block_n`` B tile.

    Args:
        block_m, block_n, block_k: tile sizes.
        num_stages: number of pipeline stages.
        bytes_per_element: element size in bytes. The default of 4 (fp32) is
            conservative for 16-bit inputs, so filtering with it errs on the
            side of rejecting configurations.
        safety_margin_bytes: extra headroom added to the estimate; ``None``
            uses the module-level ``_SHARED_MEM_SAFETY_MARGIN_BYTES``.

    Returns:
        Estimated shared-memory footprint in bytes.
    """
    if safety_margin_bytes is None:
        safety_margin_bytes = _SHARED_MEM_SAFETY_MARGIN_BYTES
    tile_bytes = (block_m * block_k + block_k * block_n) * bytes_per_element
    return tile_bytes * num_stages + safety_margin_bytes

41 

42 

# Tile and launch parameters for the TLE cluster-remote GEMM path. They are the
# same whether or not the experimental device-mesh API is present, so they are
# defined once; host helpers (e.g. a default argument of _grid_cluster_remote)
# reference them even when the TLE path is disabled.
TLE_CLUSTER_SIZE = 2
TLE_REMOTE_BM = 64
TLE_REMOTE_BN = 256
TLE_REMOTE_BK = 64
TLE_REMOTE_NUM_WARPS = 8
TLE_REMOTE_NUM_STAGES = 2
TLE_REMOTE_A_SLOTS = 2

if HAS_TLE_DEVICE_MESH:
    import triton.experimental.tle.language as tle_exp

    # One-dimensional block-cluster mesh with TLE_CLUSTER_SIZE ranks along "cluster_x".
    BLOCK_CLUSTER_MESH = tle_exp.device_mesh(
        {"block_cluster": [("cluster_x", TLE_CLUSTER_SIZE)]}
    )
else:
    # Sentinels checked by the dispatch code before taking the TLE path.
    tle_exp = None
    BLOCK_CLUSTER_MESH = None

64 

65 

def _has_tma_compatible_layout(t):
    """Return True if 2-D tensor ``t`` can back a host TMA descriptor.

    ``general_mm`` describes ``t`` as-is when ``t.stride(1) == 1`` and through
    its transpose when ``t.stride(0) == 1``; a tensor with neither unit stride
    cannot be described. The remaining (leading) stride and the base address
    must both be 16-byte aligned.
    """
    stride_0, stride_1 = t.stride()
    if stride_1 == 1:
        leading_stride = stride_0
    elif stride_0 == 1:
        leading_stride = stride_1
    else:
        return False
    return (leading_stride * t.element_size()) % 16 == 0 and t.data_ptr() % 16 == 0


def is_tma_compatible(a, b, N, K):
    """
    Check if tensors are compatible with TMA (Tensor Memory Accelerator).

    TMA requires 128-bit (16-byte) alignment for memory access:
    - For FP16/BF16 (2 bytes/element): N and K must be multiples of 8
      (8 elements x 2 bytes = 16 bytes)
    - For FP32 (4 bytes/element): N and K must be multiples of 4
      (4 elements x 4 bytes = 16 bytes)

    Beyond the shape checks, each operand must have one unit stride, a 16-byte
    aligned leading stride (this covers M for column-major ``a``), and a 16-byte
    aligned base pointer. Both operands must share one dtype, because the host
    TMA kernel feeds them to ``tl.dot`` unconverted.

    Args:
        a, b: Input tensors
        N, K: Matrix dimensions

    Returns:
        bool: True if compatible with TMA's alignment requirements
    """
    if a.dtype != b.dtype:
        return False
    if a.dtype in (torch.float16, torch.bfloat16):
        alignment = 8
    elif a.dtype == torch.float32:
        alignment = 4
    else:
        return False
    if N % alignment != 0 or K % alignment != 0:
        return False
    return _has_tma_compatible_layout(a) and _has_tma_compatible_layout(b)

94 

95 

@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of b that is strictly less than a: step back one b from
    # ceil(a / b) * b.
    return (tl.cdiv(a, b) - 1) * b

100 

101 

def matmul_tma_set_block_size_hook(nargs):
    """Autotune pre-hook: size the host TMA descriptors for the chosen config.

    The descriptors are created with a placeholder block shape. This hook
    rewrites ``block_shape`` on ``a_desc``, ``b_desc`` and ``c_desc`` from the
    config's BLOCK_M/N/K. A column-major operand (``*_ROW_MAJOR`` false) is
    described through its transpose, so its block shape is swapped.
    """
    bm, bn, bk = nargs["BLOCK_M"], nargs["BLOCK_N"], nargs["BLOCK_K"]
    nargs["a_desc"].block_shape = [bm, bk] if nargs["A_ROW_MAJOR"] else [bk, bm]
    nargs["b_desc"].block_shape = [bk, bn] if nargs["B_ROW_MAJOR"] else [bn, bk]
    nargs["c_desc"].block_shape = [bm, bn]

117 

118 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    # Add 'stride_am' and 'stride_bk' to trigger autotune for tensors with the same shape but different strides.
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["default", "default", "default", "default", "default"],
    warmup=5,
    rep=10,
)
@triton.jit
def mm_kernel_general(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Tiled C = A @ B with a device-side TMA fast path and a stride-aware fallback.

    Each program computes one BLOCK_M x BLOCK_N tile of C. Program ids are
    re-ordered in groups of GROUP_M tile-rows for better L2 reuse.

    Fast path: taken only when every dimension is a multiple of its block size
    *and* A, B and C are dense row-major. That is the only layout the in-kernel
    descriptors (strides ``[K, 1]`` / ``[N, 1]`` / ``[N, 1]``) can describe.
    Fallback: pointer arithmetic using the real strides, with the partial last
    K tile handled by masked loads (loop peeling).

    Accumulates in fp64 when IS_FP64, otherwise in fp32, and casts the result to
    C's element type. Operands of differing dtypes are first cast to C's type.
    """
    # matrix multiplication
    pid = ext.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    # The descriptors below hard-code row-major strides. Using them for any other
    # layout (a column-major operand, or a strided ``out`` from mm_out) would
    # read/write the wrong elements, so such inputs take the stride-aware path.
    if (
        M % BLOCK_M == 0
        and N % BLOCK_N == 0
        and K % BLOCK_K == 0
        and stride_am == K
        and stride_ak == 1
        and stride_bk == N
        and stride_bn == 1
        and stride_cm == N
        and stride_cn == 1
    ):
        # offset
        offset_am = pid_m * BLOCK_M
        offset_bn = pid_n * BLOCK_N
        offset_k = 0

        a_desc = tl.make_tensor_descriptor(
            base=A,
            shape=[M, K],
            strides=[K, 1],
            block_shape=[BLOCK_M, BLOCK_K],
        )

        # row-major
        b_desc = tl.make_tensor_descriptor(
            base=B,
            shape=[K, N],
            strides=[N, 1],
            block_shape=[BLOCK_K, BLOCK_N],
        )

        c_desc = tl.make_tensor_descriptor(
            base=C,
            shape=[M, N],
            strides=[N, 1],
            block_shape=[BLOCK_M, BLOCK_N],
        )

        if IS_FP64:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
        else:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_K)):
            a = a_desc.load([offset_am.to(tl.int32), offset_k.to(tl.int32)])
            b = b_desc.load([offset_k.to(tl.int32), offset_bn.to(tl.int32)])
            # Same rule as the fallback path: unify mixed operand dtypes to C's type.
            if a.dtype != b.dtype:
                a = a.to(C.dtype.element_ty)
                b = b.to(C.dtype.element_ty)
            if IS_FP64:
                acc += tl.dot(a, b, allow_tf32=False)
            else:
                acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
            offset_k += BLOCK_K

        # Cast to C's element type, not A's: C may be wider than A (e.g. an fp32
        # output for bf16 inputs), and c_desc stores C's dtype.
        acc = acc.to(C.dtype.element_ty)
        c_desc.store([offset_am.to(tl.int32), offset_bn.to(tl.int32)], acc)

    else:
        # do matrix multiplication
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        # Wrap out-of-range rows/cols back into bounds so the main-loop loads need
        # no mask; those lanes are discarded by the masked store below.
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
        rm = rm.to(tl.int64)
        rn = rn.to(tl.int64)
        prev_multiple = prev_multiple_of(K, BLOCK_K)

        if IS_FP64:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
        else:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # Full K tiles, loaded without a K mask.
        for start_k in range(0, prev_multiple, BLOCK_K):
            rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
            a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
            b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
            if a.dtype != b.dtype:
                a = a.to(C.dtype.element_ty)
                b = b.to(C.dtype.element_ty)
            if IS_FP64:
                acc += tl.dot(a, b, allow_tf32=False)
            else:
                acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

        # loop peeling: last (possibly partial) K tile, masked and zero-filled
        rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
        mask_k = rk < K
        a = tl.load(
            A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
            mask=mask_k[None, :],
            other=0.0,
        )
        b = tl.load(
            B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
            mask=mask_k[:, None],
            other=0.0,
        )
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        if IS_FP64:
            acc += tl.dot(a, b, allow_tf32=False)
        else:
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

        acc = acc.to(C.dtype.element_ty)
        # rematerialize rm and rn to save registers
        rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
        rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
        offsets = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        # handles write-back with reduction-splitting
        tl.store(offsets, acc, mask=mask)

266 

267 

def matmul_get_configs(pre_hook=matmul_tma_set_block_size_hook):
    """Build the autotuning search space for the host-TMA mm kernel.

    The candidate set is the cross product of BLOCK_M, BLOCK_N, BLOCK_K,
    num_stages and num_warps, in that nesting order. Each candidate carries
    ``pre_hook``. When the device's shared-memory limit is known, candidates
    whose estimated footprint exceeds it are dropped. If that would leave
    nothing, a warning is logged and the full set is returned.
    """
    search_space = [
        (bm, bn, bk, stages, warps)
        for bm in (32, 64, 128, 256)
        for bn in (32, 64, 128)
        for bk in (32, 64, 128)
        for stages in (2, 3, 4)
        for warps in (4, 8)
    ]
    candidates = [
        triton.Config(
            {"BLOCK_M": bm, "BLOCK_N": bn, "BLOCK_K": bk},
            num_stages=stages,
            num_warps=warps,
            pre_hook=pre_hook,
        )
        for bm, bn, bk, stages, warps in search_space
    ]

    limit = _get_shared_memory_limit_bytes()
    if limit is None:
        return candidates

    def _fits(cfg):
        meta = cfg.kwargs
        needed = _estimate_tma_shared_memory_bytes(
            meta["BLOCK_M"], meta["BLOCK_N"], meta["BLOCK_K"], cfg.num_stages
        )
        return needed <= limit

    fitting = [cfg for cfg in candidates if _fits(cfg)]
    if fitting:
        return fitting
    logger.warning(
        "No mm_general_tma config fits shared memory limit (%s bytes); falling back to unfiltered configs.",
        limit,
    )
    return candidates

304 

305 

@libentry()
@libtuner(
    # Evaluated once at import time; the shared-memory filter inside
    # matmul_get_configs queries the CUDA device that is current at that point.
    configs=matmul_get_configs(),
    # "dtype" is part of the key so tuned configs are cached per input dtype.
    key=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    strategy=["align32", "align32", "align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
    flagtune_op_name="mm",
    flagtune_expand_op_name="mm_general_tma",
    flagtune_yaml_path=EXPAND_CONFIG_FILENAME,
    # Rewrites the host descriptors' block_shape for each candidate config.
    flagtune_pre_hook=matmul_tma_set_block_size_hook,
)
@triton.jit
def mm_kernel_general_host_tma(
    a_desc,
    b_desc,
    c_desc,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    A_ROW_MAJOR: tl.constexpr,
    B_ROW_MAJOR: tl.constexpr,
    dtype: tl.constexpr,
    enable_warp_specialization=True,
):
    """C = A @ B where A, B and C are passed as host-built tensor descriptors.

    Each program computes one BLOCK_M x BLOCK_N tile of C, with program ids
    re-ordered in groups of GROUP_M tile-rows. A column-major operand
    (``A_ROW_MAJOR`` / ``B_ROW_MAJOR`` false) is described through its
    transpose, so its tile is loaded with swapped coordinates and transposed
    back with ``tl.trans``. Accumulation is in fp32. 16-bit inputs use a plain
    ``tl.dot``, and other dtypes use ``input_precision="tf32x3"``.

    The stride_* arguments and ``dtype`` are not read in the body (``dtype``
    only feeds the autotune key). NOTE(review): ``enable_warp_specialization``
    also appears unused here -- confirm before removing.

    NOTE(review): there is no K-tail masking. The loop presumably relies on the
    descriptors zero-filling out-of-bounds elements (TensorDescriptor's default
    padding); confirm for the Triton version in use.
    """
    pid = tl.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)

    # Grouped program-id ordering (same scheme as mm_kernel_general).
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Descriptor coordinates are passed as int32.
    offset_am = (pid_m * BLOCK_M).to(tl.int32)
    offset_bn = (pid_n * BLOCK_N).to(tl.int32)
    iters = tl.cdiv(K, BLOCK_K)
    for k in range(iters):
        offset_ak = (k * BLOCK_K).to(tl.int32)

        if A_ROW_MAJOR:
            a = a_desc.load([offset_am, offset_ak])
        else:
            # Descriptor covers A^T (K x M): load a K x M tile, then transpose.
            a_t = a_desc.load([offset_ak, offset_am])
            a = tl.trans(a_t)

        if B_ROW_MAJOR:
            b = b_desc.load([offset_ak, offset_bn])
        else:
            # Descriptor covers B^T (N x K): load an N x K tile, then transpose.
            b_t = b_desc.load([offset_bn, offset_ak])
            b = tl.trans(b_t)

        if a_desc.dtype == tl.float16 or a_desc.dtype == tl.bfloat16:
            accumulator = tl.dot(a, b, acc=accumulator, allow_tf32=False)
        else:
            # fp32 inputs: 3xTF32 emulation for close-to-fp32 accuracy.
            accumulator = tl.dot(a, b, acc=accumulator, input_precision="tf32x3")

    c = accumulator.to(c_desc.dtype)
    c_desc.store([offset_am, offset_bn], c)

377 

378 

def get_higher_dtype(a, b):
    """Return whichever of the floating dtypes ``a`` / ``b`` ranks higher.

    The ranking is float16 < bfloat16 < float32 < float64. Identical inputs
    are returned unchanged; otherwise both must be one of those dtypes
    (checked with ``assert``).
    """
    promotion_order = [torch.float16, torch.bfloat16, torch.float32, torch.float64]

    if a is b:
        return a

    assert a in promotion_order
    assert b in promotion_order

    # The first of the two found in the order is the lower one; return the other.
    return next(
        (b if dtype is a else a for dtype in promotion_order if dtype is a or dtype is b),
        None,
    )

394 

def _is_tma_output_compatible(c):
    """Return True if output ``c`` can back the row-major host TMA store descriptor.

    ``general_mm`` builds C's descriptor from ``c.shape`` / ``c.stride()``
    directly, so the last dimension must be contiguous, and both the row stride
    and the base pointer must be 16-byte aligned.
    """
    return (
        c.stride(1) == 1
        and (c.stride(0) * c.element_size()) % 16 == 0
        and c.data_ptr() % 16 == 0
    )


def general_mm(a, b, c, M, N, K, op_name="mm"):
    """Compute ``c = a @ b`` with the general tiled kernels and return ``c``.

    Uses ``mm_kernel_general_host_tma`` when the installed Triton provides a
    host-side ``TensorDescriptor`` and ``a``, ``b`` and ``c`` all meet its
    layout/alignment requirements. Otherwise it falls back to
    ``mm_kernel_general``.

    Args:
        a: (M, K) left operand.
        b: (K, N) right operand.
        c: (M, N) preallocated output, written in place.
        M, N, K: problem dimensions.
        op_name: label used only in the debug log.

    Returns:
        ``c``.
    """
    # TODO: Remove this debug message
    logger.debug(
        "GEMS MM-hopper, [op]: %s, [mm scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        op_name,
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    # Broadcast tensors from expand() have stride=0, incompatible with TMA
    if 0 in a.stride():
        a = a.contiguous()
    if 0 in b.stride():
        b = b.contiguous()
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )

    # Host-side descriptors (triton 3.5.0). Importing, rather than probing
    # ``triton.tools.tensor_descriptor`` as an attribute, also copes with Triton
    # builds where that submodule does not exist or was never loaded.
    try:
        from triton.tools.tensor_descriptor import TensorDescriptor
    except ImportError:
        TensorDescriptor = None

    if (
        TensorDescriptor is not None
        and is_tma_compatible(a, b, N, K)
        and _is_tma_output_compatible(c)
    ):
        a_row_major = a.stride(1) == 1
        b_row_major = b.stride(1) == 1
        # Placeholder; matmul_tma_set_block_size_hook sets the real block shape per config.
        dummy_block = [1, 1]

        # Column-major operands are described through their transpose.
        if a_row_major:
            a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block)
        else:
            a_desc = TensorDescriptor(a, a.T.shape, a.T.stride(), dummy_block)
        if b_row_major:
            b_desc = TensorDescriptor(b, b.shape, b.stride(), dummy_block)
        else:
            b_desc = TensorDescriptor(b, b.T.shape, b.T.stride(), dummy_block)
        c_desc = TensorDescriptor(c, c.shape, c.stride(), dummy_block)

        input_dtype = a.dtype
        dtype_str = str(input_dtype).split(".")[-1]

        with torch_device_fn.device(a.device):
            mm_kernel_general_host_tma[grid](
                a_desc,
                b_desc,
                c_desc,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                GROUP_M=8,
                A_ROW_MAJOR=a_row_major,
                B_ROW_MAJOR=b_row_major,
                dtype=dtype_str,
            )
    else:

        # mm_kernel_general builds device-side descriptors, which need scratch
        # memory from a Triton allocator.
        def alloc_fn(size: int, align: int, stream: Optional[int]):
            return torch.empty(size, dtype=torch.int8, device=a.device)

        triton.set_allocator(alloc_fn)

        with torch_device_fn.device(a.device):
            mm_kernel_general[grid](
                a,
                b,
                c,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                GROUP_M=8,
                IS_FP64=a.dtype == torch.float64,
            )
    return c

481 

482 

@libentry()
@libtuner(
    configs=[
        triton.Config(
            {"BLOCK_M": 32, "BLOCK_K": 256},
        )
    ],
    key=["M", "K", "stride_am", "stride_bk"],
    strategy=["align32", "align32", "align32", "default"],
    warmup=5,
    rep=10,
    flagtune_op_name="mm",
    flagtune_expand_op_name="gemv",
    flagtune_yaml_path=EXPAND_CONFIG_FILENAME,
    flagtune_pre_hook=None,
)
@triton.jit
def gemv_kernel(
    A,
    B,
    C,
    M,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Optimized kernel for matrix-vector multiplication (N=1 case)

    Each program computes BLOCK_M consecutive entries of C = A @ B. A is
    addressed with (stride_am, stride_ak), and B is read along K with
    stride_bk. Products are reduced with ``tl.sum`` rather than ``tl.dot``,
    accumulating in fp64 when IS_FP64 and in fp32 otherwise. The result is
    cast to C's element type.

    NOTE(review): result row ``i`` is stored at ``C + i``, i.e. the kernel
    assumes consecutive output rows are adjacent in memory (unit row stride).
    """
    pid = tl.program_id(0)

    # Each program handles BLOCK_M rows
    row_start = pid * BLOCK_M
    row_offset = row_start + tl.arange(0, BLOCK_M)
    row_mask = row_offset < M

    # Accumulator for this block of rows
    if IS_FP64:
        acc = tl.zeros((BLOCK_M,), dtype=tl.float64)
    else:
        acc = tl.zeros((BLOCK_M,), dtype=tl.float32)

    # Iterate over K dimension
    for k_start in range(0, K, BLOCK_K):
        k_offset = k_start + tl.arange(0, BLOCK_K)
        k_mask = k_offset < K

        # Load block from matrix A: [BLOCK_M, BLOCK_K]
        a_ptrs = A + row_offset[:, None] * stride_am + k_offset[None, :] * stride_ak
        a = tl.load(a_ptrs, mask=row_mask[:, None] & k_mask[None, :], other=0.0)

        # Load block from vector B: [BLOCK_K]
        b_ptrs = B + k_offset * stride_bk
        b = tl.load(b_ptrs, mask=k_mask, other=0.0)

        # Accumulate: sum over K dimension (masked lanes were loaded as 0).
        if IS_FP64:
            acc += tl.sum(a * b[None, :], axis=1)
        else:
            acc += tl.sum(a.to(tl.float32) * b.to(tl.float32)[None, :], axis=1)

    # Store result (unit row stride assumed for C; see docstring)
    c_ptrs = C + row_offset
    acc = acc.to(C.dtype.element_ty)
    tl.store(c_ptrs, acc, mask=row_mask)

550 

551 

def gemv_mm(a, b, c, M, K):
    """Optimized matrix-vector multiplication for N=1 case

    Computes ``c = a @ b`` for ``a`` of shape (M, K) and ``b`` of shape (K, 1)
    into the (M, 1) output ``c``, and returns ``c``.

    ``gemv_kernel`` writes result row ``i`` to ``C + i``, i.e. it only honours
    a unit row stride for C. When ``c`` is a strided view (for example a
    single column of a larger matrix handed to ``mm_out``), the result is
    computed into a contiguous scratch buffer and copied back.
    """
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: gemv (N=1), [shape info]: [%s, %s, 1](M, K, N)",
        M,
        K,
    )

    # With M == 1 only row 0 is written, so the row stride is irrelevant.
    needs_staging = M > 1 and c.stride(0) != 1
    target = (
        torch.empty((M, 1), dtype=c.dtype, device=c.device) if needs_staging else c
    )

    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)

    with torch_device_fn.device(a.device):
        gemv_kernel[grid](
            a,
            b,
            target,
            M,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            IS_FP64=a.dtype == torch.float64,
        )
    if needs_staging:
        c.copy_(target)
    return c

575 

576 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm_splitk"),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    # Partial sums are atomically added into C, so C must start from zero for
    # every tuning run as well.
    reset_to_zero=["C"],
    strategy=["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=10,
    flagtune_op_name="mm",
    flagtune_expand_op_name="mm_splitk",
    flagtune_yaml_path=EXPAND_CONFIG_FILENAME,
    flagtune_pre_hook=None,
)
@triton.jit
def mm_kernel_splitk(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
):
    """Split-K C += A @ B: each program reduces one slice of K for one C tile.

    Grid axis 0 enumerates BLOCK_M x BLOCK_N output tiles (row-major over
    tiles); axis 1 selects one of SPLIT_K contiguous ranges of BLOCK_K
    iterations. Each program accumulates its partial product in fp32 and
    ``tl.atomic_add``s it into C, so C must be zeroed before launch.

    NOTE(review): the accumulator is fp32, so fp64 operands are presumably not
    supported here (``tl.dot`` on fp64 yields fp64) -- confirm.
    """
    pid = tl.program_id(0)
    pid_k = tl.program_id(1)

    grid_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // grid_n
    pid_n = pid % grid_n

    offset_am = pid_m * BLOCK_M
    offset_bn = pid_n * BLOCK_N
    offs_am = offset_am + tl.arange(0, BLOCK_M)
    offs_bn = offset_bn + tl.arange(0, BLOCK_N)

    # This program's share of the K tiles: [k_start, k_end).
    total_k_iters = tl.cdiv(K, BLOCK_K)
    k_per_split = tl.cdiv(total_k_iters, SPLIT_K)
    k_start = pid_k * k_per_split
    k_end = min((pid_k + 1) * k_per_split, total_k_iters)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(k_start, k_end):
        offset_k = k * BLOCK_K
        offs_k = offset_k + tl.arange(0, BLOCK_K)

        # Fully masked loads: M/N edges and the K tail are zero-filled.
        a = tl.load(
            A + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak,
            mask=(offs_am[:, None] < M) & (offs_k[None, :] < K),
            other=0.0,
        )
        b = tl.load(
            B + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn,
            mask=(offs_k[:, None] < K) & (offs_bn[None, :] < N),
            other=0.0,
        )
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    offs_cm = offset_am + tl.arange(0, BLOCK_M)
    offs_cn = offset_bn + tl.arange(0, BLOCK_N)
    c_ptrs = C + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
    mask = (offs_cm < M)[:, None] & (offs_cn < N)[None, :]
    # Combine this K slice with the other SPLIT_K - 1 slices of the same tile.
    tl.atomic_add(c_ptrs, acc, mask=mask)

648 

649 

def splitk_mm(a, b, c, M, N, K, op_name="mm"):
    """Accumulate ``a @ b`` into ``c`` with the split-K kernel and return ``c``.

    The kernel atomically adds partial results, so ``c`` must already be
    zero-filled by the caller. ``op_name`` only labels the debug log.
    """
    logger.debug(
        "GEMS MM-hopper, [op]: %s, [mm scenario]: splitk, [shape info]: [-, %s, %s, %s](batch, M, N, K)",
        op_name,
        M,
        N,
        K,
    )

    def grid(meta):
        # axis 0: output tiles; axis 1: K splits.
        output_tiles = triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"])
        return (output_tiles, meta["SPLIT_K"])

    # (stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn)
    strides = [t.stride(dim) for t in (a, b, c) for dim in (0, 1)]
    with torch_device_fn.device(a.device):
        mm_kernel_splitk[grid](a, b, c, M, N, K, *strides)
    return c

678 

679 

def streamk_scenario(a, b, M, N, K):
    """Return True when the Stream-K kernel should handle this problem.

    Requires a compute-capability 8.x device, contiguous fp16/bf16 operands,
    and a reduction dimension K more than five times larger than both M and N.
    """
    # TODO: this may change according to real benchmark results.
    # The Stream-K settings have only been validated on A100 (capability[0] == 8);
    # the best choice for other devices still has to be measured.
    half_precision = (torch.float16, torch.bfloat16)
    if get_device_capability()[0] != 8:
        return False
    if a.dtype not in half_precision or b.dtype not in half_precision:
        return False
    if not (a.is_contiguous() and b.is_contiguous()):
        return False
    return K > M * 5 and K > N * 5

694 

695 

if HAS_TLE:

    @triton.jit
    def _cluster_remote_gemm_kernel(
        a_ptr,
        b_ptr,
        c_ptr,
        M,
        N,
        K,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        mesh: tl.constexpr,
        BM: tl.constexpr,
        BN: tl.constexpr,
        BK: tl.constexpr,
        DOT_K: tl.constexpr,
        CLUSTER_SIZE: tl.constexpr,
        USE_MASK: tl.constexpr,
        A_SLOTS: tl.constexpr,
        USE_NV_MMA_SMEM_LAYOUT: tl.constexpr,
    ):
        """GEMM in which the programs of one block cluster share a single A tile.

        A cluster owns one BM-row block of A and a group of CLUSTER_SIZE
        adjacent BN-wide tiles of B/C; each rank computes one of those C tiles.
        Rank 0 streams A from global memory into a ring buffer of A_SLOTS fp16
        tiles (``tle_exp.gpu.smem`` scope), and every rank reads A from rank 0's
        buffer through ``tle_exp.remote``. B is loaded by each rank directly.
        Accumulation is in fp32, and the result is cast to C's element type.

        ``tle_exp.distributed_barrier`` separates staging from consumption. It
        runs after the prologue, at the end of every K step, and additionally
        before the prefetch when A_SLOTS == 1.

        The A buffer is allocated as fp16, so A must be fp16.

        NOTE(review): ``cluster_id = pid // CLUSTER_SIZE`` implies
        CLUSTER_SIZE programs per cluster, while ``_grid_cluster_remote``
        returns one program per cluster; this is only consistent if the TLE
        launch expands the grid by the cluster size -- confirm.
        """
        pid = tl.program_id(0)
        # Rank of this program along the mesh's "cluster_x" axis.
        cluster_rank = tle_exp.shard_id(mesh, "cluster_x")
        cluster_id = pid // CLUSTER_SIZE

        # N tiles are grouped CLUSTER_SIZE at a time; each cluster takes one
        # (M tile, N group) pair and each rank one N tile of that group.
        num_pid_n = tl.cdiv(N, BN)
        num_pid_n_group = tl.cdiv(num_pid_n, CLUSTER_SIZE)
        pid_m = cluster_id // num_pid_n_group
        pid_ng = cluster_id % num_pid_n_group
        # Not clamped: in a ragged last group this can exceed num_pid_n - 1,
        # which is only safe when USE_MASK is set.
        pid_n = pid_ng * CLUSTER_SIZE + cluster_rank

        offs_m = pid_m * BM + tl.arange(0, BM)
        offs_n = pid_n * BN + tl.arange(0, BN)
        offs_k = tl.arange(0, BK)
        # Tile-local (row, col) indices used to address the shared A buffer.
        a_row_base = offs_m - pid_m * BM
        a_rows_full = tl.broadcast_to(a_row_base[:, None], (BM, BK))
        a_cols_full = tl.broadcast_to(tl.arange(0, BK)[None, :], (BM, BK))
        a_rows_t = tl.broadcast_to(a_row_base[None, :], (DOT_K, BM))
        # Ring buffer of A_SLOTS fp16 tiles of shape (BM, BK).
        a_buf = tle_exp.gpu.alloc(
            [A_SLOTS, BM, BK],
            dtype=tl.float16,
            layout=None,
            scope=tle_exp.gpu.smem,
            nv_mma_shared_layout=USE_NV_MMA_SMEM_LAYOUT,
        )
        # View of rank 0's a_buf; all ranks (rank 0 included) read A through it.
        a_buf_remote = tle_exp.remote(a_buf, 0, scope=mesh)

        acc = tl.zeros((BM, BN), dtype=tl.float32)
        slot0 = 0
        slot0_full = tl.zeros((BM, BK), dtype=tl.int32) + slot0
        # Prologue: rank 0 stages the first A tile (k = 0) into slot 0.
        if cluster_rank == 0:
            a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
            if USE_MASK:
                a_mask_tile = (offs_m[:, None] < M) & (offs_k[None, :] < K)
                a_tile = tl.load(a_ptrs, mask=a_mask_tile, other=0.0)
            else:
                a_tile = tl.load(a_ptrs)
            a_local_ptr_tile = tle_exp.gpu.local_ptr(
                a_buf, (slot0_full, a_rows_full, a_cols_full)
            )
            if USE_MASK:
                tl.store(a_local_ptr_tile, a_tile, mask=a_mask_tile)
            else:
                tl.store(a_local_ptr_tile, a_tile)

        # The staged tile must be visible cluster-wide before anyone reads it.
        tle_exp.distributed_barrier(mesh)

        for k0 in range(0, K, BK):
            iter_idx = k0 // BK
            slot = iter_idx % A_SLOTS

            # Consume the current BK slice in DOT_K-wide sub-steps.
            for ks in range(0, BK, DOT_K):
                k_local = ks + tl.arange(0, DOT_K)
                a_cols_t = tl.broadcast_to(k_local[:, None], (DOT_K, BM))
                slot_dot_t = tl.zeros((DOT_K, BM), dtype=tl.int32) + slot
                # A is fetched from rank 0's buffer as a (DOT_K, BM) block and
                # transposed back to (BM, DOT_K) for tl.dot. NOTE(review): the
                # reason for the transposed access is not visible here.
                a_ptr_remote = tle_exp.gpu.local_ptr(
                    a_buf_remote, (slot_dot_t, a_rows_t, a_cols_t)
                )
                if USE_MASK:
                    a_mask_t = ((k0 + k_local)[:, None] < K) & (offs_m[None, :] < M)
                    a = tl.trans(tl.load(a_ptr_remote, mask=a_mask_t, other=0.0))
                else:
                    a = tl.trans(tl.load(a_ptr_remote))

                b_ptrs = (
                    b_ptr
                    + (k0 + k_local)[:, None] * stride_bk
                    + offs_n[None, :] * stride_bn
                )
                if USE_MASK:
                    b_mask = ((k0 + k_local)[:, None] < K) & (offs_n[None, :] < N)
                    b = tl.load(b_ptrs, mask=b_mask, other=0.0)
                else:
                    b = tl.load(b_ptrs)
                acc = tl.dot(a, b, acc)

            if A_SLOTS == 1:
                # Single slot: every rank must be done reading it before rank 0
                # overwrites it with the next tile below.
                tle_exp.distributed_barrier(mesh)

            # Rank 0 prefetches the next A tile into the next ring slot.
            next_k0 = k0 + BK
            has_next = next_k0 < K
            next_iter = iter_idx + 1
            next_slot = next_iter % A_SLOTS
            next_slot_full = tl.zeros((BM, BK), dtype=tl.int32) + next_slot
            if has_next and cluster_rank == 0:
                a_ptrs = (
                    a_ptr
                    + offs_m[:, None] * stride_am
                    + (next_k0 + offs_k)[None, :] * stride_ak
                )
                if USE_MASK:
                    a_mask_tile = (offs_m[:, None] < M) & (
                        (next_k0 + offs_k)[None, :] < K
                    )
                    a_tile = tl.load(a_ptrs, mask=a_mask_tile, other=0.0)
                else:
                    a_tile = tl.load(a_ptrs)
                a_local_ptr_tile = tle_exp.gpu.local_ptr(
                    a_buf, (next_slot_full, a_rows_full, a_cols_full)
                )
                if USE_MASK:
                    tl.store(a_local_ptr_tile, a_tile, mask=a_mask_tile)
                else:
                    tl.store(a_local_ptr_tile, a_tile)

            # End of step: the prefetched tile is published and the current slot
            # is released before the next iteration.
            tle_exp.distributed_barrier(mesh)

        c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
        if USE_MASK:
            c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
            tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)
        else:
            tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty))

834 

835 

def _select_remote_dot_k(bk: int) -> int:
    """Return the K width of each ``tl.dot`` sub-step in the cluster-remote kernel.

    Always 16; ``bk`` must be a multiple of it.

    Raises:
        ValueError: if ``bk`` is not divisible by 16.
    """
    if bk % 16:
        raise ValueError(f"BK must be divisible by 16 for remote dot path, got BK={bk}")
    return 16

840 

841 

def _grid_cluster_remote(
    M: int,
    N: int,
    BM: int,
    BN: int,
    cluster_size: int = TLE_CLUSTER_SIZE,
) -> tuple:
    """Return the 1-D launch grid for the cluster-remote GEMM.

    The size is (number of BM row tiles) x (number of groups of
    ``cluster_size`` BN column tiles).
    """
    row_tiles = triton.cdiv(M, BM)
    column_groups = triton.cdiv(triton.cdiv(N, BN), cluster_size)
    return (row_tiles * column_groups,)

852 

853 

def _run_cluster_remote(
    a: torch.Tensor,
    b: torch.Tensor,
    c: torch.Tensor,
    bm: int,
    bn: int,
    bk: int,
    num_warps: int,
    num_stages: int,
) -> None:
    """Launch ``_cluster_remote_gemm_kernel`` to compute ``c = a @ b`` in place.

    Args:
        a: (M, K) operand.
        b: (K, N) operand.
        c: (M, N) output.
        bm, bn, bk: tile sizes; ``bk`` must be a multiple of 16.
        num_warps, num_stages: Triton launch parameters.

    Raises:
        ValueError: from ``_select_remote_dot_k`` when ``bk`` is not a
            multiple of 16.
    """
    M, K = a.shape
    N = b.shape[1]
    dot_k = _select_remote_dot_k(bk)
    # The kernel maps cluster rank r to N tile ``pid_ng * CLUSTER_SIZE + r``
    # without clamping. When the number of N tiles is not a multiple of the
    # cluster size, a rank of the last cluster gets a tile that lies entirely
    # past N. Unmasked, it would read B and write C out of bounds (for a
    # contiguous c, columns past N alias the next row). Masking must therefore
    # be enabled even when N itself is a multiple of bn.
    ragged_last_cluster = triton.cdiv(N, bn) % TLE_CLUSTER_SIZE != 0
    use_mask = (
        (M % bm != 0) or (N % bn != 0) or (K % bk != 0) or ragged_last_cluster
    )
    a_slots = TLE_REMOTE_A_SLOTS
    # NOTE(review): heuristic for the A buffer's shared-memory layout; the
    # rationale is not visible here.
    use_nv_mma_smem_layout = (bk == 32) or (bk == 64 and num_stages <= 2)
    _cluster_remote_gemm_kernel[_grid_cluster_remote(M, N, bm, bn)](
        a,
        b,
        c,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        mesh=BLOCK_CLUSTER_MESH,
        BM=bm,
        BN=bn,
        BK=bk,
        DOT_K=dot_k,
        CLUSTER_SIZE=TLE_CLUSTER_SIZE,
        USE_MASK=use_mask,
        A_SLOTS=a_slots,
        USE_NV_MMA_SMEM_LAYOUT=use_nv_mma_smem_layout,
        num_ctas=1,
        num_warps=num_warps,
        num_stages=num_stages,
    )

896 

897 

def cluster_remote_mm_scenario(a, b, c, M, N, K):
    """Return True when the TLE cluster-remote GEMM can handle this problem.

    Requires the TLE device-mesh support, an sm90+ device, fp16 CUDA tensors
    for ``a``, ``b`` and ``c``, contiguous inputs, and a problem at least one
    TLE tile (TLE_REMOTE_BM x TLE_REMOTE_BN x TLE_REMOTE_BK) in each dimension.
    """
    capability = get_device_capability()
    if not (HAS_TLE and BLOCK_CLUSTER_MESH is not None and capability[0] >= 9):
        return False
    operands = (a, b, c)
    if not all(t.is_cuda for t in operands):
        return False
    if any(t.dtype != torch.float16 for t in operands):
        return False
    if not (a.is_contiguous() and b.is_contiguous()):
        return False
    return M >= TLE_REMOTE_BM and N >= TLE_REMOTE_BN and K >= TLE_REMOTE_BK

916 

917 

def cluster_remote_mm(a, b, c, M, N, K):
    """Compute ``c = a @ b`` with the TLE cluster-remote kernel and return ``c``.

    Uses the fixed TLE_REMOTE_* tile and launch parameters. Callers are
    expected to have checked ``cluster_remote_mm_scenario`` first.
    """
    # The original call passed M as the log *message* with N, K, ... as its
    # arguments, so an enabled debug log failed inside logging's %-formatting.
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: cluster_remote, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    with torch_device_fn.device(a.device):
        _run_cluster_remote(
            a,
            b,
            c,
            TLE_REMOTE_BM,
            TLE_REMOTE_BN,
            TLE_REMOTE_BK,
            TLE_REMOTE_NUM_WARPS,
            TLE_REMOTE_NUM_STAGES,
        )
    return c

938 

939 

def mm(a, b):
    """Return ``a @ b`` for 2-D ``a`` (M, K) and ``b`` (K, N).

    Dispatch order:
    1. Empty shapes return immediately (all zeros when only K is empty).
    2. N == 1 uses the gemv kernel.
    3. The Stream-K kernel is used when ``streamk_scenario`` holds.
    4. The TLE cluster-remote kernel is used when available and applicable.
    5. Split-K handles small M/N with a long K (not for fp64).
    6. Otherwise the general tiled kernels are used.
    """
    device = a.device
    # handle non-contiguous inputs if necessary
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # allocates output
    c_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=device, dtype=c_dtype)

    # Degenerate shapes: nothing to launch; an empty reduction is all zeros.
    if M == 0 or N == 0:
        return c
    if K == 0:
        return c.zero_()

    # Optimize for N=1 case (matrix-vector multiplication)
    if N == 1:
        return gemv_mm(a, b, c, M, K)
    # l2_cache_size = get_l2_cache_size()
    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, c, M, N, K, sm_count=sm_count)
    if HAS_TLE and BLOCK_CLUSTER_MESH is not None:
        if cluster_remote_mm_scenario(a, b, c, M, N, K):
            return cluster_remote_mm(a, b, c, M, N, K)
    # Use splitk for small M/N with a long K. The split-K kernel accumulates in
    # fp32 while tl.dot on fp64 operands yields fp64, so fp64 goes to
    # general_mm (which has an fp64 accumulator) instead.
    if (
        M < 2048
        and N < 2048
        and K >= 4096
        and torch.float64 not in (a.dtype, b.dtype)
    ):
        c.zero_()
        return splitk_mm(a, b, c, M, N, K)
    return general_mm(a, b, c, M, N, K)

970 

971 

def mm_out(a, b, *, out):
    """Compute ``a @ b`` into the preallocated ``out`` and return ``out``.

    Uses the same dispatch as ``mm``. Empty shapes return ``out`` immediately,
    and ``out`` is zero-filled when only K is empty. fp64 operands never take
    the split-K path.
    """
    # handle non-contiguous inputs if necessary
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape

    # Degenerate shapes: nothing to launch; an empty reduction is all zeros.
    if M == 0 or N == 0:
        return out
    if K == 0:
        return out.zero_()

    # Optimize for N=1 case (matrix-vector multiplication)
    if N == 1:
        return gemv_mm(a, b, out, M, K)
    # l2_cache_size = get_l2_cache_size()
    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, out, M, N, K, sm_count=sm_count)
    if HAS_TLE and BLOCK_CLUSTER_MESH is not None:
        if cluster_remote_mm_scenario(a, b, out, M, N, K):
            return cluster_remote_mm(a, b, out, M, N, K)
    # Use splitk for small M/N with a long K. Its fp32 accumulator cannot take
    # fp64 tl.dot results, so fp64 goes to general_mm instead.
    if (
        M < 2048
        and N < 2048
        and K >= 4096
        and torch.float64 not in (a.dtype, b.dtype)
    ):
        out.zero_()
        return splitk_mm(a, b, out, M, N, K)
    return general_mm(a, b, out, M, N, K)

998 

999 

def router_gemm(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """bf16 x bf16 -> fp32 GEMM for MoE router gate. weight shape: (N, K).

    Returns ``x @ weight.T`` as a new fp32 tensor of shape (M, N), where ``x``
    is (M, K).
    """
    if x.stride(0) > 1 and x.stride(1) > 1:
        x = x.contiguous()
    # weight is consumed through its transpose, so it gets the same
    # normalization as x: general_mm's TMA path needs one unit stride per operand.
    if weight.stride(0) > 1 and weight.stride(1) > 1:
        weight = weight.contiguous()
    M, K = x.shape
    N = weight.shape[0]
    assert weight.shape[1] == K, "incompatible dimensions"
    c = torch.empty((M, N), device=x.device, dtype=torch.float32)
    b = weight.t()
    if M < 2048 and N < 2048 and K >= 4096:
        # split-K accumulates atomically into c, so it must start from zero.
        c.zero_()
        return splitk_mm(x, b, c, M, N, K, op_name="router_gemm")
    return general_mm(x, b, c, M, N, K, op_name="router_gemm")