Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/mm.py: 0%

432 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2import os 

3from typing import Optional 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.ops.mm_streamk import streamk_mm 

11from flag_gems.runtime import torch_device_fn 

12from flag_gems.utils import libentry, libtuner 

13from flag_gems.utils import triton_lang_extension as ext 

14from flag_gems.utils.device_info import get_device_capability, get_sm_count 

15from flag_gems.utils.triton_version_utils import HAS_TLE, HAS_TLE_DEVICE_MESH 

16 

# Module logger, named after this module's dotted import path.
logger = logging.getLogger("flag_gems.runtime.backend._nvidia.hopper.ops.mm")
# NOTE(review): not referenced anywhere in the visible part of this module.
CACHE_USAGE_THRESHOLD = 0.8
# Path of the expanded tuning config (a YAML file one directory above this
# file); handed to the autotuner decorators below as `flagtune_yaml_path`.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "..", "mm_hopper_expand.yaml")
)
# Fixed headroom, in bytes, that _estimate_tma_shared_memory_bytes adds on top
# of the computed tile footprint.
_SHARED_MEM_SAFETY_MARGIN_BYTES = 1024

23 

24 

def _get_shared_memory_limit_bytes():
    """Look up the opt-in per-block shared-memory limit of the current CUDA device.

    Returns:
        The ``shared_memory_per_block_optin`` device property, or ``None`` when
        CUDA is unavailable or the query fails for any reason.
    """
    try:
        if torch.cuda.is_available():
            device_index = torch.cuda.current_device()
            properties = torch.cuda.get_device_properties(device_index)
            return properties.shared_memory_per_block_optin
    except Exception:
        # Best-effort probe: a failed query is reported as "limit unknown".
        pass
    return None

35 

36 

def _estimate_tma_shared_memory_bytes(block_m, block_n, block_k, num_stages):
    """Estimate the shared memory, in bytes, needed by one tile configuration.

    Each pipeline stage is costed as one ``block_m x block_k`` tile plus one
    ``block_k x block_n`` tile at 4 bytes per element; a fixed safety margin
    is added on top.
    """
    element_size = 4
    a_tile_elements = block_m * block_k
    b_tile_elements = block_k * block_n
    bytes_per_stage = (a_tile_elements + b_tile_elements) * element_size
    return bytes_per_stage * num_stages + _SHARED_MEM_SAFETY_MARGIN_BYTES

41 

42 

# Only the TLE module handle and the device mesh depend on whether the
# device-mesh feature is available.
if HAS_TLE_DEVICE_MESH:
    import triton.experimental.tle.language as tle_exp

    BLOCK_CLUSTER_MESH = tle_exp.device_mesh({"block_cluster": [("cluster_x", 2)]})
else:
    tle_exp = None
    BLOCK_CLUSTER_MESH = None

# Launch parameters of the cluster-remote GEMM path; the values are the same
# whether or not the device-mesh feature is available.
TLE_CLUSTER_SIZE = 2
TLE_REMOTE_BM = 64
TLE_REMOTE_BN = 256
TLE_REMOTE_BK = 64
TLE_REMOTE_NUM_WARPS = 8
TLE_REMOTE_NUM_STAGES = 2
TLE_REMOTE_A_SLOTS = 2

64 

65 

def is_tma_compatible(a, b, N, K):
    """
    Check if tensors are compatible with TMA (Tensor Memory Accelerator).

    TMA requires 128-bit (16-byte) alignment for memory access:
    - For FP16/BF16 (2 bytes/element): N and K must be multiples of 8
      (8 elements x 2 bytes = 16 bytes)
    - For FP32 (4 bytes/element): N and K must be multiples of 4
      (4 elements x 4 bytes = 16 bytes)

    Both operands must also share one dtype: the host-TMA kernel in this
    module passes the loaded tiles to ``tl.dot`` without any cast, whereas the
    pointer-based kernel casts mismatched operands first. Mixed-dtype inputs
    (e.g. fp16 x bf16) are therefore reported as incompatible so they take the
    pointer-based kernel.

    Args:
        a, b: Input tensors (only ``.dtype`` is inspected)
        N, K: Matrix dimensions

    Returns:
        bool: True if compatible with TMA's alignment requirements
    """
    if a.dtype != b.dtype:
        return False
    if a.dtype in (torch.float16, torch.bfloat16):
        required_multiple = 8
    elif a.dtype == torch.float32:
        required_multiple = 4
    else:
        # Any other dtype (e.g. float64) is not handled by the TMA path.
        return False
    return N % required_multiple == 0 and K % required_multiple == 0

94 

95 

@triton.jit
def prev_multiple_of(a, b):
    # Largest multiple of b that is strictly smaller than a
    # (e.g. a=10, b=4 -> 8; a=8, b=4 -> 4).
    num_chunks = tl.cdiv(a, b)
    return (num_chunks - 1) * b

100 

101 

def matmul_tma_set_block_size_hook(nargs):
    """Set the block shape of the A/B/C descriptors from the chosen tile sizes.

    Used as a config pre-hook. ``nargs`` maps kernel argument names to values;
    the ``a_desc``/``b_desc``/``c_desc`` entries are mutated in place. A
    non-row-major operand gets its block shape transposed.
    """
    block_m, block_n, block_k = nargs["BLOCK_M"], nargs["BLOCK_N"], nargs["BLOCK_K"]

    a_block = [block_m, block_k] if nargs["A_ROW_MAJOR"] else [block_k, block_m]
    nargs["a_desc"].block_shape = a_block

    b_block = [block_k, block_n] if nargs["B_ROW_MAJOR"] else [block_n, block_k]
    nargs["b_desc"].block_shape = b_block

    nargs["c_desc"].block_shape = [block_m, block_n]

117 

118 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    # Add 'stride_am' and 'stride_bk' to trigger autotune for tensors with the same shape but different strides.
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["default", "default", "default", "default", "default"],
    warmup=5,
    rep=10,
)
@triton.jit
def mm_kernel_general(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    # Each program computes one BLOCK_M x BLOCK_N tile of C = A @ B.
    #
    # Two code paths:
    #   * descriptor fast path: used when every dimension is a multiple of its
    #     tile size AND A, B and C are row-major contiguous, because the
    #     descriptors below are built with strides [K, 1] / [N, 1];
    #   * pointer path: handles arbitrary strides and ragged edges.
    pid = ext.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    # The stride checks keep transposed or sliced operands off the fast path:
    # its descriptors ignore the stride arguments and would otherwise read or
    # write the wrong elements.
    if (
        M % BLOCK_M == 0
        and N % BLOCK_N == 0
        and K % BLOCK_K == 0
        and stride_ak == 1
        and stride_am == K
        and stride_bn == 1
        and stride_bk == N
        and stride_cn == 1
        and stride_cm == N
    ):
        # offset
        offset_am = pid_m * BLOCK_M
        offset_bn = pid_n * BLOCK_N
        offset_k = 0

        a_desc = tl.make_tensor_descriptor(
            base=A,
            shape=[M, K],
            strides=[K, 1],
            block_shape=[BLOCK_M, BLOCK_K],
        )

        # row-major
        b_desc = tl.make_tensor_descriptor(
            base=B,
            shape=[K, N],
            strides=[N, 1],
            block_shape=[BLOCK_K, BLOCK_N],
        )

        c_desc = tl.make_tensor_descriptor(
            base=C,
            shape=[M, N],
            strides=[N, 1],
            block_shape=[BLOCK_M, BLOCK_N],
        )

        if IS_FP64:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
        else:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_K)):
            a = a_desc.load([offset_am.to(tl.int32), offset_k.to(tl.int32)])
            b = b_desc.load([offset_k.to(tl.int32), offset_bn.to(tl.int32)])
            # Same mixed-dtype handling as the pointer path below.
            if a.dtype != b.dtype:
                a = a.to(C.dtype.element_ty)
                b = b.to(C.dtype.element_ty)
            if IS_FP64:
                acc += tl.dot(a, b, allow_tf32=False)
            else:
                acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
            offset_k += BLOCK_K

        # Convert to the output's element type (not A's): the tile is stored
        # through the descriptor built on C.
        acc = acc.to(C.dtype.element_ty)
        c_desc.store([offset_am.to(tl.int32), offset_bn.to(tl.int32)], acc)

    else:
        # do matrix multiplication
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        # Row/column indices are wrapped with % so loads never leave the
        # matrix; the out-of-range results are discarded by the store mask.
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
        rm = rm.to(tl.int64)
        rn = rn.to(tl.int64)
        prev_multiple = prev_multiple_of(K, BLOCK_K)

        if IS_FP64:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
        else:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # All K-steps except the last one are full and need no mask.
        for start_k in range(0, prev_multiple, BLOCK_K):
            rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
            a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
            b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
            if a.dtype != b.dtype:
                a = a.to(C.dtype.element_ty)
                b = b.to(C.dtype.element_ty)
            if IS_FP64:
                acc += tl.dot(a, b, allow_tf32=False)
            else:
                acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

        # loop peeling: the last K-step may be partial, so it is masked
        rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
        mask_k = rk < K
        a = tl.load(
            A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
            mask=mask_k[None, :],
            other=0.0,
        )
        b = tl.load(
            B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
            mask=mask_k[:, None],
            other=0.0,
        )
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        if IS_FP64:
            acc += tl.dot(a, b, allow_tf32=False)
        else:
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

        acc = acc.to(C.dtype.element_ty)
        # rematerialize rm and rn to save registers
        rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
        rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
        offsets = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        # handles write-back with reduction-splitting
        tl.store(offsets, acc, mask=mask)

266 

267 

def matmul_get_configs(pre_hook=matmul_tma_set_block_size_hook):
    """Build the autotune candidates for the host-TMA matmul kernel.

    Enumerates every combination of tile sizes, stage counts and warp counts,
    then drops the candidates whose estimated shared-memory footprint exceeds
    the device limit.

    Args:
        pre_hook: Hook attached to every generated ``triton.Config``.

    Returns:
        The candidates that fit. The full, unfiltered list is returned when
        the device limit is unknown or when no candidate fits.
    """
    candidates = []
    for block_m in (32, 64, 128, 256):
        for block_n in (32, 64, 128):
            for block_k in (32, 64, 128):
                for stages in (2, 3, 4):
                    for warps in (4, 8):
                        candidates.append(
                            triton.Config(
                                {
                                    "BLOCK_M": block_m,
                                    "BLOCK_N": block_n,
                                    "BLOCK_K": block_k,
                                    "GROUP_M": 8,
                                },
                                num_stages=stages,
                                num_warps=warps,
                                pre_hook=pre_hook,
                            )
                        )

    limit = _get_shared_memory_limit_bytes()
    if limit is None:
        return candidates

    def fits(cfg):
        tile = cfg.kwargs
        needed = _estimate_tma_shared_memory_bytes(
            tile["BLOCK_M"], tile["BLOCK_N"], tile["BLOCK_K"], cfg.num_stages
        )
        return needed <= limit

    fitting = [cfg for cfg in candidates if fits(cfg)]
    if fitting:
        return fitting

    logger.warning(
        "No mm_general_tma config fits shared memory limit (%s bytes); falling back to unfiltered configs.",
        limit,
    )
    return candidates

304 

305 

@libentry()
@libtuner(
    configs=matmul_get_configs(),
    key=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    strategy=["align32", "align32", "align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
    flagtune_op_name="mm",
    flagtune_expand_op_name="mm_general_tma",
    flagtune_yaml_path=EXPAND_CONFIG_FILENAME,
    flagtune_pre_hook=matmul_tma_set_block_size_hook,
)
@triton.jit
def mm_kernel_general_host_tma(
    a_desc,
    b_desc,
    c_desc,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    A_ROW_MAJOR: tl.constexpr,
    B_ROW_MAJOR: tl.constexpr,
    dtype: tl.constexpr,
    enable_warp_specialization=True,
):
    # One program per BLOCK_M x BLOCK_N output tile; all global-memory traffic
    # goes through the descriptors passed in by the host. The stride
    # arguments, `dtype` and `enable_warp_specialization` are not read in the
    # body (some of them are autotune keys).
    pid = tl.program_id(0)
    num_tiles_m = tl.cdiv(M, BLOCK_M)
    num_tiles_n = tl.cdiv(N, BLOCK_N)

    # Grouped ordering: GROUP_M rows of tiles are walked together.
    tiles_per_group = GROUP_M * num_tiles_n
    group_id = pid // tiles_per_group
    rows_in_group = min(num_tiles_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % rows_in_group)
    pid_n = (pid % tiles_per_group) // rows_in_group

    offset_am = (pid_m * BLOCK_M).to(tl.int32)
    offset_bn = (pid_n * BLOCK_N).to(tl.int32)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(tl.cdiv(K, BLOCK_K)):
        offset_k = (k * BLOCK_K).to(tl.int32)

        # A non-row-major operand is described transposed, so it is loaded
        # with swapped coordinates and transposed back in registers.
        if A_ROW_MAJOR:
            a = a_desc.load([offset_am, offset_k])
        else:
            a = tl.trans(a_desc.load([offset_k, offset_am]))

        if B_ROW_MAJOR:
            b = b_desc.load([offset_k, offset_bn])
        else:
            b = tl.trans(b_desc.load([offset_bn, offset_k]))

        if a_desc.dtype == tl.float16 or a_desc.dtype == tl.bfloat16:
            acc = tl.dot(a, b, acc=acc, allow_tf32=False)
        else:
            acc = tl.dot(a, b, acc=acc, input_precision="tf32x3")

    c_desc.store([offset_am, offset_bn], acc.to(c_desc.dtype))

377 

378 

def get_higher_dtype(a, b):
    """Return whichever of two dtypes ranks higher.

    The ranking is float16 < bfloat16 < float32 < float64. Identical
    arguments are returned unchanged without being validated; otherwise both
    must belong to the ranking.
    """
    if a is b:
        return a

    ranking = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
    assert a in ranking
    assert b in ranking

    return b if ranking.index(a) < ranking.index(b) else a

393 

394 

def general_mm(a, b, c, M, N, K, op_name="mm"):
    """Compute ``c = a @ b`` with the general tiled kernels and return ``c``.

    The host-TMA kernel is used when this Triton build ships
    ``triton.tools.tensor_descriptor.TensorDescriptor`` and the operands pass
    ``is_tma_compatible``; otherwise ``mm_kernel_general`` is launched.

    Args:
        a, b: 2-D input tensors.
        c: Pre-allocated output tensor, written in place.
        M, N, K: Problem sizes (``a`` is M x K, ``b`` is K x N).
        op_name: Operator name used only in the debug log line.

    Returns:
        ``c``.
    """
    # TODO: Remove this debug message
    logger.debug(
        "GEMS MM-hopper, [op]: %s, [mm scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        op_name,
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    # Broadcast tensors from expand() have stride=0, incompatible with TMA
    if 0 in a.stride():
        a = a.contiguous()
    if 0 in b.stride():
        b = b.contiguous()
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
    )

    # Probe for host-side descriptors with a guarded import. An attribute
    # lookup such as ``triton.tools.tensor_descriptor`` raises AttributeError
    # when the submodule does not exist (or was never imported), which would
    # defeat the fallback below.
    try:
        # triton 3.5.0
        from triton.tools.tensor_descriptor import TensorDescriptor
    except ImportError:
        TensorDescriptor = None

    if TensorDescriptor is not None and is_tma_compatible(a, b, N, K):
        a_row_major = a.stride(1) == 1
        b_row_major = b.stride(1) == 1
        # Placeholder block shape; the real one is set per config by the
        # autotuner pre-hook (matmul_tma_set_block_size_hook).
        dummy_block = [1, 1]

        # A non-row-major operand is described through its transpose.
        if a_row_major:
            a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block)
        else:
            a_desc = TensorDescriptor(a, a.T.shape, a.T.stride(), dummy_block)
        if b_row_major:
            b_desc = TensorDescriptor(b, b.shape, b.stride(), dummy_block)
        else:
            b_desc = TensorDescriptor(b, b.T.shape, b.T.stride(), dummy_block)
        c_desc = TensorDescriptor(c, c.shape, c.stride(), dummy_block)

        input_dtype = a.dtype
        dtype_str = str(input_dtype).split(".")[-1]

        with torch_device_fn.device(a.device):
            mm_kernel_general_host_tma[grid](
                a_desc,
                b_desc,
                c_desc,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                A_ROW_MAJOR=a_row_major,
                B_ROW_MAJOR=b_row_major,
                dtype=dtype_str,
            )
    else:

        # mm_kernel_general may build tensor descriptors on the device; the
        # allocator registered here provides their backing storage.
        def alloc_fn(size: int, align: int, stream: Optional[int]):
            return torch.empty(size, dtype=torch.int8, device=a.device)

        triton.set_allocator(alloc_fn)

        with torch_device_fn.device(a.device):
            mm_kernel_general[grid](
                a,
                b,
                c,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                GROUP_M=8,
                IS_FP64=a.dtype == torch.float64,
            )
    return c

480 

481 

@libentry()
@libtuner(
    configs=[
        triton.Config(
            {"BLOCK_M": 32, "BLOCK_K": 256},
        )
    ],
    key=["M", "K", "stride_am", "stride_bk"],
    strategy=["align32", "align32", "align32", "default"],
    warmup=5,
    rep=10,
    flagtune_op_name="mm",
    flagtune_expand_op_name="gemv",
    flagtune_yaml_path=EXPAND_CONFIG_FILENAME,
    flagtune_pre_hook=None,
)
@triton.jit
def gemv_kernel(
    A,
    B,
    C,
    M,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Matrix-vector product for the N=1 case: each program reduces BLOCK_M rows of A against B."""
    pid = tl.program_id(0)

    # Rows owned by this program; rows past M are masked out.
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    rows_in_bounds = rows < M

    # Accumulate in float64 for fp64 inputs, in float32 otherwise.
    if IS_FP64:
        acc = tl.zeros((BLOCK_M,), dtype=tl.float64)
    else:
        acc = tl.zeros((BLOCK_M,), dtype=tl.float32)

    for k_start in range(0, K, BLOCK_K):
        cols = k_start + tl.arange(0, BLOCK_K)
        cols_in_bounds = cols < K

        # [BLOCK_M, BLOCK_K] slab of A and the matching [BLOCK_K] slice of B.
        a = tl.load(
            A + rows[:, None] * stride_am + cols[None, :] * stride_ak,
            mask=rows_in_bounds[:, None] & cols_in_bounds[None, :],
            other=0.0,
        )
        b = tl.load(B + cols * stride_bk, mask=cols_in_bounds, other=0.0)

        if IS_FP64:
            acc += tl.sum(a * b[None, :], axis=1)
        else:
            acc += tl.sum(a.to(tl.float32) * b.to(tl.float32)[None, :], axis=1)

    # Result row i is written to C + i.
    tl.store(C + rows, acc.to(C.dtype.element_ty), mask=rows_in_bounds)

549 

550 

def gemv_mm(a, b, c, M, K):
    """Optimized matrix-vector multiplication for N=1 case.

    Args:
        a: M x K input matrix.
        b: K x 1 input; only its stride along dim 0 is passed to the kernel.
        c: Output tensor, written in place.
        M, K: Problem sizes.

    Returns:
        ``c``.
    """
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: gemv (N=1), [shape info]: [%s, %s, 1](M, K, N)",
        M,
        K,
    )

    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)

    # gemv_kernel stores row i at ``C + i`` and takes no output stride, so it
    # is only correct when consecutive rows of the output are adjacent in
    # memory. A caller-supplied output with another row stride (e.g. a column
    # view of a wider matrix) is staged through a contiguous buffer and copied
    # back afterwards.
    if M <= 1 or c.stride(0) == 1:
        target = c
    else:
        target = torch.empty_like(c, memory_format=torch.contiguous_format)

    with torch_device_fn.device(a.device):
        gemv_kernel[grid](
            a,
            b,
            target,
            M,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            IS_FP64=a.dtype == torch.float64,
        )
    if target is not c:
        c.copy_(target)
    return c

574 

575 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm_splitk"),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    reset_to_zero=["C"],
    strategy=["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=10,
    flagtune_op_name="mm",
    flagtune_expand_op_name="mm_splitk",
    flagtune_yaml_path=EXPAND_CONFIG_FILENAME,
    flagtune_pre_hook=None,
)
@triton.jit
def mm_kernel_splitk(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
):
    # Split-K matmul: grid axis 0 selects the output tile, grid axis 1 selects
    # a contiguous range of K-steps. Each program atomically adds its partial
    # product into C, so C must hold zeros before the launch.
    pid = tl.program_id(0)
    pid_k = tl.program_id(1)

    num_tiles_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // num_tiles_n
    pid_n = pid % num_tiles_n

    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rows_in_bounds = rows < M
    cols_in_bounds = cols < N

    # Range of K-steps handled by this split (empty for trailing splits when
    # the steps do not divide evenly).
    num_k_steps = tl.cdiv(K, BLOCK_K)
    steps_per_split = tl.cdiv(num_k_steps, SPLIT_K)
    first_step = pid_k * steps_per_split
    end_step = min((pid_k + 1) * steps_per_split, num_k_steps)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for step in range(first_step, end_step):
        ks = step * BLOCK_K + tl.arange(0, BLOCK_K)
        ks_in_bounds = ks < K

        a = tl.load(
            A + rows[:, None] * stride_am + ks[None, :] * stride_ak,
            mask=rows_in_bounds[:, None] & ks_in_bounds[None, :],
            other=0.0,
        )
        b = tl.load(
            B + ks[:, None] * stride_bk + cols[None, :] * stride_bn,
            mask=ks_in_bounds[:, None] & cols_in_bounds[None, :],
            other=0.0,
        )
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    c_ptrs = C + rows[:, None] * stride_cm + cols[None, :] * stride_cn
    out_mask = rows_in_bounds[:, None] & cols_in_bounds[None, :]
    tl.atomic_add(c_ptrs, acc, mask=out_mask)

647 

648 

def splitk_mm(a, b, c, M, N, K, op_name="mm"):
    """Launch the split-K kernel to accumulate ``a @ b`` into ``c``; return ``c``.

    The kernel adds into ``c`` atomically, so the caller is expected to pass a
    zero-filled output. ``op_name`` is used only in the debug log line.
    """
    logger.debug(
        "GEMS MM-hopper, [op]: %s, [mm scenario]: splitk, [shape info]: [-, %s, %s, %s](batch, M, N, K)",
        op_name,
        M,
        N,
        K,
    )

    def grid(META):
        # Axis 0: output tiles; axis 1: K splits.
        num_tiles = triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"])
        return (num_tiles, META["SPLIT_K"])

    a_strides = (a.stride(0), a.stride(1))
    b_strides = (b.stride(0), b.stride(1))
    c_strides = (c.stride(0), c.stride(1))
    with torch_device_fn.device(a.device):
        mm_kernel_splitk[grid](a, b, c, M, N, K, *a_strides, *b_strides, *c_strides)
    return c

677 

678 

def streamk_scenario(a, b, M, N, K):
    """Decide whether the stream-K kernel should handle this problem."""
    # TODO: this may change at some point according to real benchmark results.
    # Currently, the best configuration for streamk has only been tested on A100 (capability[0] == 8).
    # The optimal settings for other devices need to be determined through real testing.
    half_dtypes = [torch.float16, torch.bfloat16]
    capability = get_device_capability()
    if capability[0] != 8:
        return False
    if a.dtype not in half_dtypes or b.dtype not in half_dtypes:
        return False
    if not (a.is_contiguous() and b.is_contiguous()):
        return False
    # Only K-dominated shapes qualify.
    return K > M * 5 and K > N * 5

693 

694 

if HAS_TLE:

    @triton.jit
    def _cluster_remote_gemm_kernel(
        a_ptr,
        b_ptr,
        c_ptr,
        M,
        N,
        K,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        mesh: tl.constexpr,
        BM: tl.constexpr,
        BN: tl.constexpr,
        BK: tl.constexpr,
        DOT_K: tl.constexpr,
        CLUSTER_SIZE: tl.constexpr,
        USE_MASK: tl.constexpr,
        A_SLOTS: tl.constexpr,
        USE_NV_MMA_SMEM_LAYOUT: tl.constexpr,
    ):
        # GEMM in which the programs of one cluster share a tile of A.
        #
        # The programs of a cluster work on the same M-tile and on consecutive
        # N-tiles. Only cluster rank 0 loads the A tile from global memory
        # into `a_buf`; every rank then reads A through `a_buf_remote` and
        # loads its own B tile directly from global memory.
        #
        # NOTE(review): the behaviour of tle_exp.shard_id / gpu.alloc / remote /
        # distributed_barrier comes from triton.experimental.tle and is not
        # visible in this file — the comments below describe how they are
        # used here, not their guarantees.
        pid = tl.program_id(0)
        cluster_rank = tle_exp.shard_id(mesh, "cluster_x")
        cluster_id = pid // CLUSTER_SIZE

        # N-tiles are grouped CLUSTER_SIZE at a time; each rank takes one
        # N-tile of its cluster's group.
        num_pid_n = tl.cdiv(N, BN)
        num_pid_n_group = tl.cdiv(num_pid_n, CLUSTER_SIZE)
        pid_m = cluster_id // num_pid_n_group
        pid_ng = cluster_id % num_pid_n_group
        pid_n = pid_ng * CLUSTER_SIZE + cluster_rank

        offs_m = pid_m * BM + tl.arange(0, BM)
        offs_n = pid_n * BN + tl.arange(0, BN)
        offs_k = tl.arange(0, BK)
        # Tile-local index grids used to address the [A_SLOTS, BM, BK] buffer:
        # (BM, BK) grids for writing a whole tile, (DOT_K, BM) grids for
        # reading a transposed DOT_K-wide slice.
        a_row_base = offs_m - pid_m * BM
        a_rows_full = tl.broadcast_to(a_row_base[:, None], (BM, BK))
        a_cols_full = tl.broadcast_to(tl.arange(0, BK)[None, :], (BM, BK))
        a_rows_t = tl.broadcast_to(a_row_base[None, :], (DOT_K, BM))
        # The buffer element type is hard-coded to float16.
        a_buf = tle_exp.gpu.alloc(
            [A_SLOTS, BM, BK],
            dtype=tl.float16,
            layout=None,
            scope=tle_exp.gpu.smem,
            nv_mma_shared_layout=USE_NV_MMA_SMEM_LAYOUT,
        )
        # presumably a handle to shard 0's copy of a_buf — TODO confirm
        a_buf_remote = tle_exp.remote(a_buf, 0, scope=mesh)

        acc = tl.zeros((BM, BN), dtype=tl.float32)
        # Prologue: rank 0 loads the first K-tile of A into slot 0.
        slot0 = 0
        slot0_full = tl.zeros((BM, BK), dtype=tl.int32) + slot0
        if cluster_rank == 0:
            a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
            if USE_MASK:
                a_mask_tile = (offs_m[:, None] < M) & (offs_k[None, :] < K)
                a_tile = tl.load(a_ptrs, mask=a_mask_tile, other=0.0)
            else:
                a_tile = tl.load(a_ptrs)
            a_local_ptr_tile = tle_exp.gpu.local_ptr(
                a_buf, (slot0_full, a_rows_full, a_cols_full)
            )
            if USE_MASK:
                tl.store(a_local_ptr_tile, a_tile, mask=a_mask_tile)
            else:
                tl.store(a_local_ptr_tile, a_tile)

        # Executed by every rank before the first read of the buffer.
        tle_exp.distributed_barrier(mesh)

        for k0 in range(0, K, BK):
            iter_idx = k0 // BK
            # Slots are used round-robin, one per K-tile.
            slot = iter_idx % A_SLOTS

            # Consume the current K-tile in DOT_K-wide slices.
            for ks in range(0, BK, DOT_K):
                k_local = ks + tl.arange(0, DOT_K)
                a_cols_t = tl.broadcast_to(k_local[:, None], (DOT_K, BM))
                slot_dot_t = tl.zeros((DOT_K, BM), dtype=tl.int32) + slot
                a_ptr_remote = tle_exp.gpu.local_ptr(
                    a_buf_remote, (slot_dot_t, a_rows_t, a_cols_t)
                )
                # The slice is read as (DOT_K, BM) and transposed to (BM, DOT_K).
                if USE_MASK:
                    a_mask_t = ((k0 + k_local)[:, None] < K) & (offs_m[None, :] < M)
                    a = tl.trans(tl.load(a_ptr_remote, mask=a_mask_t, other=0.0))
                else:
                    a = tl.trans(tl.load(a_ptr_remote))

                b_ptrs = (
                    b_ptr
                    + (k0 + k_local)[:, None] * stride_bk
                    + offs_n[None, :] * stride_bn
                )
                if USE_MASK:
                    b_mask = ((k0 + k_local)[:, None] < K) & (offs_n[None, :] < N)
                    b = tl.load(b_ptrs, mask=b_mask, other=0.0)
                else:
                    b = tl.load(b_ptrs)
                acc = tl.dot(a, b, acc)

            # With a single slot the next tile overwrites the one just read,
            # so an extra barrier is taken before rank 0 refills it.
            if A_SLOTS == 1:
                tle_exp.distributed_barrier(mesh)

            # Rank 0 prefetches the next K-tile of A into the next slot.
            next_k0 = k0 + BK
            has_next = next_k0 < K
            next_iter = iter_idx + 1
            next_slot = next_iter % A_SLOTS
            next_slot_full = tl.zeros((BM, BK), dtype=tl.int32) + next_slot
            if has_next and cluster_rank == 0:
                a_ptrs = (
                    a_ptr
                    + offs_m[:, None] * stride_am
                    + (next_k0 + offs_k)[None, :] * stride_ak
                )
                if USE_MASK:
                    a_mask_tile = (offs_m[:, None] < M) & (
                        (next_k0 + offs_k)[None, :] < K
                    )
                    a_tile = tl.load(a_ptrs, mask=a_mask_tile, other=0.0)
                else:
                    a_tile = tl.load(a_ptrs)
                a_local_ptr_tile = tle_exp.gpu.local_ptr(
                    a_buf, (next_slot_full, a_rows_full, a_cols_full)
                )
                if USE_MASK:
                    tl.store(a_local_ptr_tile, a_tile, mask=a_mask_tile)
                else:
                    tl.store(a_local_ptr_tile, a_tile)

            # Executed by every rank at the end of every K iteration.
            tle_exp.distributed_barrier(mesh)

        # Write the finished tile in C's element type.
        c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
        if USE_MASK:
            c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
            tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)
        else:
            tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty))

833 

834 

def _select_remote_dot_k(bk: int) -> int:
    """Return the K-slice width used by the remote dot path.

    The width is always 16.

    Raises:
        ValueError: If ``bk`` is not a multiple of 16.
    """
    if bk % 16 != 0:
        raise ValueError(f"BK must be divisible by 16 for remote dot path, got BK={bk}")
    return 16

839 

840 

def _grid_cluster_remote(
    M: int,
    N: int,
    BM: int,
    BN: int,
    cluster_size: int = TLE_CLUSTER_SIZE,
) -> tuple:
    """Return the 1-D launch grid of the cluster-remote GEMM.

    The grid is the number of M-tiles times the number of N-tile groups,
    where a group holds ``cluster_size`` N-tiles (rounded up).
    """
    tiles_m = triton.cdiv(M, BM)
    tiles_n = triton.cdiv(N, BN)
    n_tile_groups = triton.cdiv(tiles_n, cluster_size)
    return (tiles_m * n_tile_groups,)

851 

852 

def _run_cluster_remote(
    a: torch.Tensor,
    b: torch.Tensor,
    c: torch.Tensor,
    bm: int,
    bn: int,
    bk: int,
    num_warps: int,
    num_stages: int,
) -> None:
    """Launch ``_cluster_remote_gemm_kernel`` to write ``a @ b`` into ``c``.

    Args:
        a: M x K input.
        b: K x N input.
        c: Output tensor, written in place.
        bm, bn, bk: Tile sizes; ``bk`` must be a multiple of 16.
        num_warps, num_stages: Launch options forwarded to the kernel.

    Raises:
        ValueError: If ``bk`` is not a multiple of 16.
    """
    M, K = a.shape
    N = b.shape[1]
    dot_k = _select_remote_dot_k(bk)

    # The kernel gives rank r of a cluster the N-tile `group * CLUSTER_SIZE + r`,
    # and the grid rounds the number of N-tiles up to a whole number of
    # groups. When the tile count is not a multiple of the cluster size, the
    # last group therefore contains tile indices that lie entirely past N;
    # masking must be on so those programs do not read or write out of bounds
    # even when M, N and K all divide their tile sizes.
    ragged_last_group = triton.cdiv(N, bn) % TLE_CLUSTER_SIZE != 0
    use_mask = (
        (M % bm != 0) or (N % bn != 0) or (K % bk != 0) or ragged_last_group
    )
    a_slots = TLE_REMOTE_A_SLOTS
    use_nv_mma_smem_layout = (bk == 32) or (bk == 64 and num_stages <= 2)
    _cluster_remote_gemm_kernel[_grid_cluster_remote(M, N, bm, bn)](
        a,
        b,
        c,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        mesh=BLOCK_CLUSTER_MESH,
        BM=bm,
        BN=bn,
        BK=bk,
        DOT_K=dot_k,
        CLUSTER_SIZE=TLE_CLUSTER_SIZE,
        USE_MASK=use_mask,
        A_SLOTS=a_slots,
        USE_NV_MMA_SMEM_LAYOUT=use_nv_mma_smem_layout,
        num_ctas=1,
        num_warps=num_warps,
        num_stages=num_stages,
    )

895 

896 

def cluster_remote_mm_scenario(a, b, c, M, N, K):
    """Decide whether the cluster-remote GEMM path can handle this problem.

    Requires TLE with a device mesh, device capability major >= 9, CUDA
    float16 tensors for both inputs and the output, contiguous inputs, and
    every dimension at least one tile in size.
    """
    capability = get_device_capability()
    return (
        HAS_TLE
        and BLOCK_CLUSTER_MESH is not None
        and capability[0] >= 9
        # all three tensors live on a CUDA device
        and (a.is_cuda and b.is_cuda and c.is_cuda)
        # the path is float16-only
        and (
            a.dtype == torch.float16
            and b.dtype == torch.float16
            and c.dtype == torch.float16
        )
        and (a.is_contiguous() and b.is_contiguous())
        # at least one full tile along every dimension
        and (M >= TLE_REMOTE_BM and N >= TLE_REMOTE_BN and K >= TLE_REMOTE_BK)
    )

915 

916 

def cluster_remote_mm(a, b, c, M, N, K):
    """Compute ``c = a @ b`` with the cluster-remote GEMM kernel and return ``c``.

    Uses the fixed ``TLE_REMOTE_*`` tile and launch parameters.
    """
    # The first positional argument of logger.debug is the format string.
    # Without it, M would be used as the message and the remaining values
    # could not be formatted into it.
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: cluster_remote, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    with torch_device_fn.device(a.device):
        _run_cluster_remote(
            a,
            b,
            c,
            TLE_REMOTE_BM,
            TLE_REMOTE_BN,
            TLE_REMOTE_BK,
            TLE_REMOTE_NUM_WARPS,
            TLE_REMOTE_NUM_STAGES,
        )
    return c

937 

938 

def mm(a, b):
    """Matrix-multiply two 2-D tensors and return a newly allocated result.

    Dispatches to the gemv, stream-K, cluster-remote, split-K or general
    kernel depending on shape, dtype, layout and device.

    Args:
        a: M x K tensor.
        b: K x N tensor.

    Returns:
        An M x N tensor whose dtype is the higher-ranked of the two input
        dtypes (see ``get_higher_dtype``).

    Raises:
        AssertionError: If the inner dimensions differ.
    """
    device = a.device
    # handle non-contiguous inputs if necessary
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # allocates output
    c_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=device, dtype=c_dtype)

    # Empty problems are answered without a launch. With K == 0 the product
    # is all zeros, and the tiled kernels assume K > 0 (for example, the
    # peeled last K-step of mm_kernel_general would start at -BLOCK_K).
    if M == 0 or N == 0 or K == 0:
        return c.zero_()

    # Optimize for N=1 case (matrix-vector multiplication)
    if N == 1:
        return gemv_mm(a, b, c, M, K)
    # l2_cache_size = get_l2_cache_size()
    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, c, M, N, K, sm_count=sm_count)
    if HAS_TLE and BLOCK_CLUSTER_MESH is not None:
        if cluster_remote_mm_scenario(a, b, c, M, N, K):
            return cluster_remote_mm(a, b, c, M, N, K)
    # Use splitk for small M
    if M < 2048 and N < 2048 and K >= 4096:
        # the split-K kernel accumulates into c atomically
        c.zero_()
        return splitk_mm(a, b, c, M, N, K)
    return general_mm(a, b, c, M, N, K)

969 

970 

def mm_out(a, b, *, out):
    """Matrix-multiply two 2-D tensors into a caller-provided output.

    Same dispatch as ``mm``, except that the result is written to ``out``.

    Args:
        a: M x K tensor.
        b: K x N tensor.
        out: Output tensor, written in place.

    Returns:
        ``out``.

    Raises:
        AssertionError: If the inner dimensions differ.
    """
    # handle non-contiguous inputs if necessary
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape

    # Empty problems are answered without a launch. With K == 0 the product
    # is all zeros, and the tiled kernels assume K > 0 (for example, the
    # peeled last K-step of mm_kernel_general would start at -BLOCK_K).
    if M == 0 or N == 0 or K == 0:
        out.zero_()
        return out

    # Optimize for N=1 case (matrix-vector multiplication)
    if N == 1:
        return gemv_mm(a, b, out, M, K)
    # l2_cache_size = get_l2_cache_size()
    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, out, M, N, K, sm_count=sm_count)
    if HAS_TLE and BLOCK_CLUSTER_MESH is not None:
        if cluster_remote_mm_scenario(a, b, out, M, N, K):
            return cluster_remote_mm(a, b, out, M, N, K)
    # Use splitk for small M
    if M < 2048 and N < 2048 and K >= 4096:
        # the split-K kernel accumulates into out atomically
        out.zero_()
        return splitk_mm(a, b, out, M, N, K)
    return general_mm(a, b, out, M, N, K)

997 

998 

def router_gemm(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """bf16 x bf16 -> fp32 GEMM for MoE router gate. weight shape: (N, K).

    Computes ``x @ weight.T`` into a newly allocated float32 tensor.

    Args:
        x: (M, K) activations.
        weight: (N, K) gate weights.

    Returns:
        (M, N) float32 tensor.

    Raises:
        AssertionError: If ``x`` and ``weight`` disagree on K.
    """
    if x.stride(0) > 1 and x.stride(1) > 1:
        x = x.contiguous()
    # Same check as mm(): a K mismatch would make the kernels index `weight`
    # with the wrong extent.
    assert x.shape[1] == weight.shape[1], "incompatible dimensions"
    M, K = x.shape
    N = weight.shape[0]
    c = torch.empty((M, N), device=x.device, dtype=torch.float32)
    # Empty problems need no launch; with K == 0 the product is all zeros.
    if M == 0 or N == 0 or K == 0:
        return c.zero_()
    # (K, N) view of the weights; no copy is made.
    b = weight.t()
    if M < 2048 and N < 2048 and K >= 4096:
        # the split-K kernel accumulates into c atomically
        c.zero_()
        return splitk_mm(x, b, c, M, N, K, op_name="router_gemm")
    return general_mm(x, b, c, M, N, K, op_name="router_gemm")