Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/mm.py: 0%

432 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2import os 

3from typing import Optional 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.ops.mm_streamk import streamk_mm 

11from flag_gems.runtime import torch_device_fn 

12from flag_gems.utils import libentry, libtuner 

13from flag_gems.utils import triton_lang_extension as ext 

14from flag_gems.utils.device_info import get_device_capability, get_sm_count 

15from flag_gems.utils.triton_version_utils import HAS_TLE, HAS_TLE_DEVICE_MESH 

16 

# Module-level logger; the name is spelled out as a literal rather than taken from __name__.
logger = logging.getLogger("flag_gems.runtime.backend._nvidia.hopper.ops.mm")
# NOTE(review): not referenced anywhere in the visible part of this file — confirm it is still needed.
CACHE_USAGE_THRESHOLD = 0.8
# Normalised path to "mm_hopper_expand.yaml" located one directory above this module's
# directory; handed to the runtime config helpers as ``yaml_path``.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "..", "mm_hopper_expand.yaml")
)
# Extra bytes added on top of every shared-memory estimate made in this module.
_SHARED_MEM_SAFETY_MARGIN_BYTES = 1024

23 

24 

25def _get_shared_memory_limit_bytes(): 

26 """Return per-block opt-in shared-memory limit for current CUDA device.""" 

27 try: 

28 if not torch.cuda.is_available(): 

29 return None 

30 return torch.cuda.get_device_properties( 

31 torch.cuda.current_device() 

32 ).shared_memory_per_block_optin 

33 except Exception: 

34 return None 

35 

36 

37def _estimate_tma_shared_memory_bytes(block_m, block_n, block_k, num_stages): 

38 bytes_per_element = 4 

39 tile_bytes = (block_m * block_k + block_k * block_n) * bytes_per_element 

40 return tile_bytes * num_stages + _SHARED_MEM_SAFETY_MARGIN_BYTES 

41 

42 

# Launch parameters of the cluster-remote GEMM path. They carry the same values
# whether or not the TLE device-mesh API is available, so they are assigned once.
TLE_CLUSTER_SIZE = 2
TLE_REMOTE_BM = 64
TLE_REMOTE_BN = 256
TLE_REMOTE_BK = 64
TLE_REMOTE_NUM_WARPS = 8
TLE_REMOTE_NUM_STAGES = 2
TLE_REMOTE_A_SLOTS = 2

if HAS_TLE_DEVICE_MESH:
    import triton.experimental.tle.language as tle_exp

    # One mesh axis named "cluster_x" whose extent is the cluster size.
    BLOCK_CLUSTER_MESH = tle_exp.device_mesh(
        {"block_cluster": [("cluster_x", TLE_CLUSTER_SIZE)]}
    )
else:
    # Placeholders so the names always exist at module level.
    tle_exp = None
    BLOCK_CLUSTER_MESH = None

64 

65 

def is_tma_compatible(a, b, N, K):
    """
    Check if tensors are compatible with TMA (Tensor Memory Accelerator).

    TMA requires 128-bit (16-byte) alignment of the leading stride of every
    tensor descriptor:
        - For FP16/BF16 (2 bytes/element): the dimension must be a multiple of 8
          (8 elements × 2 bytes = 16 bytes)
        - For FP32 (4 bytes/element): the dimension must be a multiple of 4
          (4 elements × 4 bytes = 16 bytes)

    N and K are always checked. When ``a`` is not row-major (``a.stride(1) != 1``)
    the caller in this file describes it through ``a.T``, whose leading stride is M
    for a dense transposed tensor, so M (``a.shape[0]``) must then satisfy the same
    alignment. This case was previously accepted unchecked.

    Args:
        a, b: Input tensors (2-D); both must be FP16/BF16, or both FP32
        N, K: Matrix dimensions

    Returns:
        bool: True if compatible with TMA's alignment requirements
    """
    half_types = (torch.float16, torch.bfloat16)
    if a.dtype in half_types and b.dtype in half_types:
        alignment = 8
    elif a.dtype == torch.float32 and b.dtype == torch.float32:
        alignment = 4
    else:
        # Unsupported or mixed-width element types never take the TMA path.
        return False

    if N % alignment != 0 or K % alignment != 0:
        return False

    # Column-major A: the descriptor is laid out as [K, M], so M becomes the
    # leading stride and must be 16-byte aligned as well.
    if a.stride(1) != 1 and a.shape[0] % alignment != 0:
        return False
    return True

94 

95 

@triton.jit
def prev_multiple_of(a, b):
    """Return the largest multiple of ``b`` that is strictly below ``a`` (positive inputs)."""
    # One block fewer than the number of b-sized blocks needed to cover a.
    return (tl.cdiv(a, b) - 1) * b

100 

101 

def matmul_tma_set_block_size_hook(nargs):
    """Pre-hook that writes the chosen tile sizes into the three tensor descriptors.

    Args:
        nargs: Mapping of kernel arguments. Reads ``BLOCK_M``, ``BLOCK_N``,
            ``BLOCK_K``, ``A_ROW_MAJOR`` and ``B_ROW_MAJOR``; sets ``block_shape``
            on ``a_desc``, ``b_desc`` and ``c_desc``.
    """
    block_m = nargs["BLOCK_M"]
    block_n = nargs["BLOCK_N"]
    block_k = nargs["BLOCK_K"]

    # A is tiled as [M, K] when row-major, otherwise as [K, M].
    a_block = [block_m, block_k] if nargs["A_ROW_MAJOR"] else [block_k, block_m]
    nargs["a_desc"].block_shape = a_block

    # B is tiled as [K, N] when row-major, otherwise as [N, K].
    b_block = [block_k, block_n] if nargs["B_ROW_MAJOR"] else [block_n, block_k]
    nargs["b_desc"].block_shape = b_block

    nargs["c_desc"].block_shape = [block_m, block_n]

117 

118 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    # Add 'stride_am' and 'stride_bk' to trigger autotune for tensors with the same shape but different strides.
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["default", "default", "default", "default", "default"],
    warmup=5,
    rep=10,
)
@triton.jit
def mm_kernel_general(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Compute one ``BLOCK_M x BLOCK_N`` tile of ``C = A @ B`` per program.

    Two code paths exist:

    * Descriptor path: taken only when M, N and K are exact multiples of the block
      sizes AND A, B and C are densely row-major. The descriptors built below
      describe the operands with strides ``[K, 1]``, ``[N, 1]`` and ``[N, 1]``;
      without the layout check a transposed or strided operand would be read from
      the wrong addresses.
    * Pointer path: everything else. Uses the real strides, wraps row/column
      indices with a modulo so loads stay in bounds, masks the last K block and
      masks the final store.

    Args:
        A, B, C: Pointers to the input and output matrices.
        M, N, K: Problem sizes (A is M x K, B is K x N, C is M x N).
        stride_*: Element strides of A, B and C.
        BLOCK_M, BLOCK_N, BLOCK_K: Tile sizes.
        GROUP_M: Number of row tiles grouped together when re-ordering program ids.
        IS_FP64: Accumulate in float64 instead of float32.
    """
    # matrix multiplication
    pid = ext.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    if (
        M % BLOCK_M == 0
        and N % BLOCK_N == 0
        and K % BLOCK_K == 0
        # The descriptors below hard-code dense row-major strides, so the real
        # layout has to match them exactly.
        and stride_am == K
        and stride_ak == 1
        and stride_bk == N
        and stride_bn == 1
        and stride_cm == N
        and stride_cn == 1
    ):
        # offset
        offset_am = pid_m * BLOCK_M
        offset_bn = pid_n * BLOCK_N
        offset_k = 0

        a_desc = tl.make_tensor_descriptor(
            base=A,
            shape=[M, K],
            strides=[K, 1],
            block_shape=[BLOCK_M, BLOCK_K],
        )

        # row-major
        b_desc = tl.make_tensor_descriptor(
            base=B,
            shape=[K, N],
            strides=[N, 1],
            block_shape=[BLOCK_K, BLOCK_N],
        )

        c_desc = tl.make_tensor_descriptor(
            base=C,
            shape=[M, N],
            strides=[N, 1],
            block_shape=[BLOCK_M, BLOCK_N],
        )

        if IS_FP64:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
        else:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_K)):
            a = a_desc.load([offset_am.to(tl.int32), offset_k.to(tl.int32)])
            b = b_desc.load([offset_k.to(tl.int32), offset_bn.to(tl.int32)])
            if IS_FP64:
                acc += tl.dot(a, b, allow_tf32=False)
            else:
                acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
            offset_k += BLOCK_K

        # Cast to the element type of the destination descriptor, not of A.
        acc = acc.to(c_desc.dtype)
        c_desc.store([offset_am.to(tl.int32), offset_bn.to(tl.int32)], acc)

    else:
        # do matrix multiplication
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        # Wrap indices with a modulo so the unmasked loads below never leave the
        # matrices; rows/columns past the edge are dropped by the store mask.
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
        rm = rm.to(tl.int64)
        rn = rn.to(tl.int64)
        prev_multiple = prev_multiple_of(K, BLOCK_K)

        if IS_FP64:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
        else:
            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # All K blocks except the last one are complete, so no mask is needed.
        for start_k in range(0, prev_multiple, BLOCK_K):
            rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
            a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
            b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
            if a.dtype != b.dtype:
                a = a.to(C.dtype.element_ty)
                b = b.to(C.dtype.element_ty)
            if IS_FP64:
                acc += tl.dot(a, b, allow_tf32=False)
            else:
                acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

        # loop peeling: the last K block may be partial and is loaded with a mask
        rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
        mask_k = rk < K
        a = tl.load(
            A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
            mask=mask_k[None, :],
            other=0.0,
        )
        b = tl.load(
            B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
            mask=mask_k[:, None],
            other=0.0,
        )
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        if IS_FP64:
            acc += tl.dot(a, b, allow_tf32=False)
        else:
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

        acc = acc.to(C.dtype.element_ty)
        # rematerialize rm and rn to save registers
        rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
        rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
        offsets = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        # masked write-back: only elements inside the M x N output are stored
        tl.store(offsets, acc, mask=mask)

266 

267 

def matmul_get_configs(pre_hook=matmul_tma_set_block_size_hook):
    """Build the candidate configs for the host-TMA matmul kernel.

    Every combination of the tile sizes, stage counts and warp counts listed
    below is generated. When the device's shared-memory limit is known, configs
    whose estimated footprint exceeds it are dropped; if that would drop all of
    them, a warning is logged and the full list is returned instead.

    Args:
        pre_hook: Hook attached to every generated ``triton.Config``.

    Returns:
        A list of ``triton.Config`` objects.
    """
    candidates = []
    for block_m in (32, 64, 128, 256):
        for block_n in (32, 64, 128):
            for block_k in (32, 64, 128):
                for stages in (2, 3, 4):
                    for warps in (4, 8):
                        candidates.append(
                            triton.Config(
                                {
                                    "BLOCK_M": block_m,
                                    "BLOCK_N": block_n,
                                    "BLOCK_K": block_k,
                                },
                                num_stages=stages,
                                num_warps=warps,
                                pre_hook=pre_hook,
                            )
                        )

    smem_limit = _get_shared_memory_limit_bytes()
    if smem_limit is None:
        # Limit unknown: nothing to filter against.
        return candidates

    def _fits(cfg):
        tiles = cfg.kwargs
        needed = _estimate_tma_shared_memory_bytes(
            tiles["BLOCK_M"],
            tiles["BLOCK_N"],
            tiles["BLOCK_K"],
            cfg.num_stages,
        )
        return needed <= smem_limit

    usable = list(filter(_fits, candidates))
    if usable:
        return usable

    logger.warning(
        "No mm_general_tma config fits shared memory limit (%s bytes); falling back to unfiltered configs.",
        smem_limit,
    )
    return candidates

304 

305 

# Config list and tuning strategy come from the expand YAML when the environment
# variable USE_FLAGTUNE is "1"; otherwise the built-in list from matmul_get_configs()
# and a fixed strategy list are used. Both are evaluated once, at import time.
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "mm_general_tma",
        pre_hook=matmul_tma_set_block_size_hook,
        yaml_path=EXPAND_CONFIG_FILENAME,
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else matmul_get_configs(),
    key=["M", "N", "K", "stride_am", "stride_bk", "dtype"],
    strategy=runtime.get_expand_config(
        "mm_general_tma", yaml_path=EXPAND_CONFIG_FILENAME
    )["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "align32", "align32", "default"],
    warmup=5,
    rep=5,
)
@triton.jit
def mm_kernel_general_host_tma(
    a_desc,
    b_desc,
    c_desc,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    A_ROW_MAJOR: tl.constexpr,
    B_ROW_MAJOR: tl.constexpr,
    dtype: tl.constexpr,
    enable_warp_specialization=True,
):
    """Compute one ``BLOCK_M x BLOCK_N`` tile of ``C = A @ B`` using host-built descriptors.

    Args:
        a_desc, b_desc, c_desc: Tensor descriptors for A, B and C. Their block
            shapes are filled in by the config pre-hook before launch.
        M, N, K: Problem sizes.
        stride_*: Not read in the body; ``stride_am`` and ``stride_bk`` appear in
            the autotune key.
        BLOCK_M, BLOCK_N, BLOCK_K: Tile sizes.
        GROUP_M: Number of row tiles grouped together when re-ordering program ids.
        A_ROW_MAJOR, B_ROW_MAJOR: When False the matching descriptor is indexed
            with swapped coordinates and the loaded tile is transposed.
        dtype: Not read in the body; part of the autotune key.
        enable_warp_specialization: Not read in the body.
    """
    pid = tl.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)

    # Grouped ordering of program ids: GROUP_M row tiles share one band of column tiles.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    # Accumulation is always done in float32 and cast once at the end.
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    offset_am = (pid_m * BLOCK_M).to(tl.int32)
    offset_bn = (pid_n * BLOCK_N).to(tl.int32)
    iters = tl.cdiv(K, BLOCK_K)
    for k in range(iters):
        offset_ak = (k * BLOCK_K).to(tl.int32)

        if A_ROW_MAJOR:
            a = a_desc.load([offset_am, offset_ak])
        else:
            # Descriptor is laid out as [K, M]: load the [BLOCK_K, BLOCK_M] tile and transpose.
            a_t = a_desc.load([offset_ak, offset_am])
            a = tl.trans(a_t)

        if B_ROW_MAJOR:
            b = b_desc.load([offset_ak, offset_bn])
        else:
            # Descriptor is laid out as [N, K]: load the [BLOCK_N, BLOCK_K] tile and transpose.
            b_t = b_desc.load([offset_bn, offset_ak])
            b = tl.trans(b_t)

        if a_desc.dtype == tl.float16 or a_desc.dtype == tl.bfloat16:
            accumulator = tl.dot(a, b, acc=accumulator, allow_tf32=False)
        else:
            # Any other input type uses the "tf32x3" input precision.
            accumulator = tl.dot(a, b, acc=accumulator, input_precision="tf32x3")

    c = accumulator.to(c_desc.dtype)
    c_desc.store([offset_am, offset_bn], c)

383 

384 

def get_higher_dtype(a, b):
    """Return whichever of two dtypes ranks higher.

    The ranking is float16 < bfloat16 < float32 < float64. Identical dtypes are
    returned as-is without consulting the ranking; otherwise both must be one of
    the four ranked dtypes (enforced with ``assert``).
    """
    if a is b:
        return a

    ranking = (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    assert a in ranking
    assert b in ranking

    # a and b differ here, so their ranks differ too.
    return a if ranking.index(a) > ranking.index(b) else b

399 

400 

def general_mm(a, b, c, M, N, K, op_name="mm"):
    """Run the general matmul path and write ``a @ b`` into ``c``.

    The host-TMA kernel is used when Triton provides
    ``triton.tools.tensor_descriptor.TensorDescriptor`` and the inputs pass
    ``is_tma_compatible``; otherwise ``mm_kernel_general`` is launched.

    Args:
        a: Left operand, shape (M, K).
        b: Right operand, shape (K, N).
        c: Pre-allocated output, shape (M, N).
        M, N, K: Problem sizes.
        op_name: Operator name used only in the debug log line.

    Returns:
        ``c``.
    """
    # TODO: Remove this debug message
    logger.debug(
        "GEMS MM-hopper, [op]: %s, [mm scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        op_name,
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    # Broadcast tensors from expand() have stride=0, incompatible with TMA
    if 0 in a.stride():
        a = a.contiguous()
    if 0 in b.stride():
        b = b.contiguous()

    def grid(META):
        return (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)

    # Probe for the host-side descriptor class with a guarded import. Reaching for
    # ``triton.tools.tensor_descriptor`` as an attribute raises AttributeError when
    # the submodule does not exist or has not been imported yet, instead of
    # falling back to the non-TMA kernel.
    try:
        from triton.tools.tensor_descriptor import TensorDescriptor
    except ImportError:
        TensorDescriptor = None

    if TensorDescriptor is not None and is_tma_compatible(a, b, N, K):
        a_row_major = a.stride(1) == 1
        b_row_major = b.stride(1) == 1
        # Placeholder block shape; the config pre-hook replaces it before launch.
        dummy_block = [1, 1]

        # Column-major operands are described through their transpose so the
        # descriptor's last dimension is the contiguous one.
        if a_row_major:
            a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block)
        else:
            a_desc = TensorDescriptor(a, a.T.shape, a.T.stride(), dummy_block)
        if b_row_major:
            b_desc = TensorDescriptor(b, b.shape, b.stride(), dummy_block)
        else:
            b_desc = TensorDescriptor(b, b.T.shape, b.T.stride(), dummy_block)
        c_desc = TensorDescriptor(c, c.shape, c.stride(), dummy_block)

        input_dtype = a.dtype
        # e.g. torch.float16 -> "float16"; used as part of the autotune key.
        dtype_str = str(input_dtype).split(".")[-1]

        with torch_device_fn.device(a.device):
            mm_kernel_general_host_tma[grid](
                a_desc,
                b_desc,
                c_desc,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                GROUP_M=8,
                A_ROW_MAJOR=a_row_major,
                B_ROW_MAJOR=b_row_major,
                dtype=dtype_str,
            )
    else:

        def alloc_fn(size: int, align: int, stream: Optional[int]):
            # Scratch buffer requested by Triton; ``align`` and ``stream`` are ignored.
            return torch.empty(size, dtype=torch.int8, device=a.device)

        triton.set_allocator(alloc_fn)

        with torch_device_fn.device(a.device):
            mm_kernel_general[grid](
                a,
                b,
                c,
                M,
                N,
                K,
                a.stride(0),
                a.stride(1),
                b.stride(0),
                b.stride(1),
                c.stride(0),
                c.stride(1),
                GROUP_M=8,
                IS_FP64=a.dtype == torch.float64,
            )
    return c

487 

488 

# Config list and tuning strategy come from the expand YAML when the environment
# variable USE_FLAGTUNE is "1"; otherwise a single fixed config is used.
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "gemv", pre_hook=None, yaml_path=EXPAND_CONFIG_FILENAME
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else [
        triton.Config(
            {"BLOCK_M": 32, "BLOCK_K": 256},
        )
    ],
    key=["M", "K", "stride_am", "stride_bk"],
    strategy=runtime.get_expand_config("gemv", yaml_path=EXPAND_CONFIG_FILENAME)[
        "strategy"
    ]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "default"],
    warmup=5,
    rep=10,
)
@triton.jit
def gemv_kernel(
    A,
    B,
    C,
    M,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Matrix-vector product kernel for the N=1 case of matmul.

    Each program reduces BLOCK_M rows of A against the vector B, walking K in
    steps of BLOCK_K, and stores BLOCK_M results.

    Args:
        A: Pointer to the matrix, addressed with ``stride_am`` / ``stride_ak``.
        B: Pointer to the vector, addressed with ``stride_bk``.
        C: Pointer to the output; element ``i`` is written at ``C + i``.
        M, K: Number of rows of A and length of the reduction.
        BLOCK_M, BLOCK_K: Tile sizes.
        IS_FP64: Accumulate in float64 instead of float32.
    """
    pid = tl.program_id(0)

    # Each program handles BLOCK_M rows
    row_start = pid * BLOCK_M
    row_offset = row_start + tl.arange(0, BLOCK_M)
    row_mask = row_offset < M

    # Accumulator for this block of rows
    if IS_FP64:
        acc = tl.zeros((BLOCK_M,), dtype=tl.float64)
    else:
        acc = tl.zeros((BLOCK_M,), dtype=tl.float32)

    # Iterate over K dimension
    for k_start in range(0, K, BLOCK_K):
        k_offset = k_start + tl.arange(0, BLOCK_K)
        k_mask = k_offset < K

        # Load block from matrix A: [BLOCK_M, BLOCK_K]
        # NOTE(review): offsets are not widened to int64 here, unlike the pointer path
        # of mm_kernel_general — confirm this cannot overflow for very large A.
        a_ptrs = A + row_offset[:, None] * stride_am + k_offset[None, :] * stride_ak
        a = tl.load(a_ptrs, mask=row_mask[:, None] & k_mask[None, :], other=0.0)

        # Load block from vector B: [BLOCK_K]
        b_ptrs = B + k_offset * stride_bk
        b = tl.load(b_ptrs, mask=k_mask, other=0.0)

        # Accumulate: sum over K dimension
        if IS_FP64:
            acc += tl.sum(a * b[None, :], axis=1)
        else:
            # Both operands are widened to float32 before multiplying.
            acc += tl.sum(a.to(tl.float32) * b.to(tl.float32)[None, :], axis=1)

    # Store result
    # NOTE(review): no output stride is applied, so this assumes consecutive rows of C
    # are one element apart — confirm for caller-supplied outputs.
    c_ptrs = C + row_offset
    acc = acc.to(C.dtype.element_ty)
    tl.store(c_ptrs, acc, mask=row_mask)

560 

561 

def gemv_mm(a, b, c, M, K):
    """Launch ``gemv_kernel`` to compute the N=1 matmul ``a @ b`` into ``c``.

    Args:
        a: Matrix operand with M rows and K columns.
        b: Vector operand; only its first stride is passed to the kernel.
        c: Pre-allocated output.
        M, K: Problem sizes.

    Returns:
        ``c``.
    """
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: gemv (N=1), [shape info]: [%s, %s, 1](M, K, N)",
        M,
        K,
    )

    def launch_grid(meta):
        # One program per block of BLOCK_M rows.
        return (triton.cdiv(M, meta["BLOCK_M"]),)

    with torch_device_fn.device(a.device):
        launcher = gemv_kernel[launch_grid]
        launcher(
            a,
            b,
            c,
            M,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            IS_FP64=a.dtype == torch.float64,
        )
    return c

585 

586 

# Config list and tuning strategy come from the expand YAML when the environment
# variable USE_FLAGTUNE is "1"; otherwise the tuned "mm_splitk" configs are used.
# C is listed in reset_to_zero because the kernel accumulates into it atomically.
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "mm_splitk",
        pre_hook=None,
        yaml_path=EXPAND_CONFIG_FILENAME,
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else runtime.get_tuned_config("mm_splitk"),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    reset_to_zero=["C"],
    strategy=runtime.get_expand_config("mm_splitk", yaml_path=EXPAND_CONFIG_FILENAME)[
        "strategy"
    ]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=10,
)
@triton.jit
def mm_kernel_splitk(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
):
    """Split-K matmul: each program adds a partial ``BLOCK_M x BLOCK_N`` product into C.

    Grid axis 0 selects the output tile and grid axis 1 selects which slice of
    the K loop the program handles. Partial results are combined with
    ``tl.atomic_add``, so C must hold zeros before the launch.

    Args:
        A, B, C: Pointers to the input and output matrices.
        M, N, K: Problem sizes.
        stride_*: Element strides of A, B and C.
        BLOCK_M, BLOCK_N, BLOCK_K: Tile sizes.
        SPLIT_K: Number of slices the K loop is divided into.
    """
    pid = tl.program_id(0)
    pid_k = tl.program_id(1)

    # Plain row-by-row mapping of the tile index to (pid_m, pid_n).
    grid_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // grid_n
    pid_n = pid % grid_n

    offset_am = pid_m * BLOCK_M
    offset_bn = pid_n * BLOCK_N
    offs_am = offset_am + tl.arange(0, BLOCK_M)
    offs_bn = offset_bn + tl.arange(0, BLOCK_N)

    # Divide the K-block iterations evenly (rounded up) between the SPLIT_K slices;
    # the last slice is clamped to the total and may be shorter or empty.
    total_k_iters = tl.cdiv(K, BLOCK_K)
    k_per_split = tl.cdiv(total_k_iters, SPLIT_K)
    k_start = pid_k * k_per_split
    k_end = min((pid_k + 1) * k_per_split, total_k_iters)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(k_start, k_end):
        offset_k = k * BLOCK_K
        offs_k = offset_k + tl.arange(0, BLOCK_K)

        # Out-of-range rows, columns and K positions are loaded as 0.0.
        a = tl.load(
            A + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak,
            mask=(offs_am[:, None] < M) & (offs_k[None, :] < K),
            other=0.0,
        )
        b = tl.load(
            B + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn,
            mask=(offs_k[:, None] < K) & (offs_bn[None, :] < N),
            other=0.0,
        )
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    offs_cm = offset_am + tl.arange(0, BLOCK_M)
    offs_cn = offset_bn + tl.arange(0, BLOCK_N)
    c_ptrs = C + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
    mask = (offs_cm < M)[:, None] & (offs_cn < N)[None, :]
    # The float32 partial sum is added as-is; no cast to C's element type is made here.
    tl.atomic_add(c_ptrs, acc, mask=mask)

664 

665 

def splitk_mm(a, b, c, M, N, K, op_name="mm"):
    """Launch ``mm_kernel_splitk`` to accumulate ``a @ b`` into ``c``.

    The kernel adds into ``c`` atomically; this function does not clear ``c``
    itself.

    Args:
        a: Left operand, shape (M, K).
        b: Right operand, shape (K, N).
        c: Output accumulated into.
        M, N, K: Problem sizes.
        op_name: Operator name used only in the debug log line.

    Returns:
        ``c``.
    """
    logger.debug(
        "GEMS MM-hopper, [op]: %s, [mm scenario]: splitk, [shape info]: [-, %s, %s, %s](batch, M, N, K)",
        op_name,
        M,
        N,
        K,
    )

    def launch_grid(meta):
        # Axis 0: one program per output tile; axis 1: one program per K slice.
        tiles_m = triton.cdiv(M, meta["BLOCK_M"])
        tiles_n = triton.cdiv(N, meta["BLOCK_N"])
        return (tiles_m * tiles_n, meta["SPLIT_K"])

    with torch_device_fn.device(a.device):
        launcher = mm_kernel_splitk[launch_grid]
        launcher(
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
        )
    return c

694 

695 

def streamk_scenario(a, b, M, N, K):
    """Decide whether the stream-K matmul path should be used.

    True only on devices whose major compute capability is 8, for contiguous
    float16/bfloat16 inputs, when K is more than five times both M and N.
    """
    # TODO: this may change at some point according to real benchmark results.
    # Currently, the best configuration for streamk has only been tested on A100(capability[0] == 8).
    # The optimal settings for other devices need to be determined through real testing.
    capability = get_device_capability()
    if capability[0] != 8:
        return False

    half_types = [torch.float16, torch.bfloat16]
    if a.dtype not in half_types:
        return False
    if b.dtype not in half_types:
        return False

    if not a.is_contiguous():
        return False
    if not b.is_contiguous():
        return False

    # The reduction dimension has to dominate both output dimensions.
    return K > M * 5 and K > N * 5

710 

711 

# The kernel uses ``tle_exp`` APIs, so it is only defined when TLE is available.
if HAS_TLE:

    @triton.jit
    def _cluster_remote_gemm_kernel(
        a_ptr,
        b_ptr,
        c_ptr,
        M,
        N,
        K,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        mesh: tl.constexpr,
        BM: tl.constexpr,
        BN: tl.constexpr,
        BK: tl.constexpr,
        DOT_K: tl.constexpr,
        CLUSTER_SIZE: tl.constexpr,
        USE_MASK: tl.constexpr,
        A_SLOTS: tl.constexpr,
        USE_NV_MMA_SMEM_LAYOUT: tl.constexpr,
    ):
        """GEMM in which the programs of a cluster share one staged tile of A.

        All ranks of a cluster work on the same row tile (``pid_m``) and on
        different column tiles (``pid_n``). Only rank 0 loads the A tile from
        global memory into ``a_buf``; every rank then reads A through
        ``a_buf_remote`` (presumably a view of rank 0's buffer — confirm against
        the TLE documentation) and loads its own B tile from global memory.

        Args:
            a_ptr, b_ptr, c_ptr: Pointers to A, B and C.
            M, N, K: Problem sizes.
            stride_*: Element strides of A, B and C.
            mesh: Device mesh describing the cluster.
            BM, BN, BK: Tile sizes.
            DOT_K: K extent of each ``tl.dot`` call; BK is walked in DOT_K steps.
            CLUSTER_SIZE: Number of programs per cluster.
            USE_MASK: Guard every load/store with bounds masks.
            A_SLOTS: Number of A tiles held in ``a_buf``.
            USE_NV_MMA_SMEM_LAYOUT: Forwarded to the buffer allocation.
        """
        pid = tl.program_id(0)
        cluster_rank = tle_exp.shard_id(mesh, "cluster_x")
        cluster_id = pid // CLUSTER_SIZE

        # Column tiles are handed out in groups of CLUSTER_SIZE, one per rank.
        num_pid_n = tl.cdiv(N, BN)
        num_pid_n_group = tl.cdiv(num_pid_n, CLUSTER_SIZE)
        pid_m = cluster_id // num_pid_n_group
        pid_ng = cluster_id % num_pid_n_group
        pid_n = pid_ng * CLUSTER_SIZE + cluster_rank

        offs_m = pid_m * BM + tl.arange(0, BM)
        offs_n = pid_n * BN + tl.arange(0, BN)
        offs_k = tl.arange(0, BK)
        # Index grids into a_buf: (BM, BK) for writing a tile, (DOT_K, BM) for
        # reading a transposed slice of it.
        a_row_base = offs_m - pid_m * BM
        a_rows_full = tl.broadcast_to(a_row_base[:, None], (BM, BK))
        a_cols_full = tl.broadcast_to(tl.arange(0, BK)[None, :], (BM, BK))
        a_rows_t = tl.broadcast_to(a_row_base[None, :], (DOT_K, BM))
        # The staging buffer is always float16.
        a_buf = tle_exp.gpu.alloc(
            [A_SLOTS, BM, BK],
            dtype=tl.float16,
            layout=None,
            scope=tle_exp.gpu.smem,
            nv_mma_shared_layout=USE_NV_MMA_SMEM_LAYOUT,
        )
        a_buf_remote = tle_exp.remote(a_buf, 0, scope=mesh)

        acc = tl.zeros((BM, BN), dtype=tl.float32)
        # Prologue: rank 0 stages the first A tile (K offset 0) into slot 0.
        slot0 = 0
        slot0_full = tl.zeros((BM, BK), dtype=tl.int32) + slot0
        if cluster_rank == 0:
            a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
            if USE_MASK:
                a_mask_tile = (offs_m[:, None] < M) & (offs_k[None, :] < K)
                a_tile = tl.load(a_ptrs, mask=a_mask_tile, other=0.0)
            else:
                a_tile = tl.load(a_ptrs)
            a_local_ptr_tile = tle_exp.gpu.local_ptr(
                a_buf, (slot0_full, a_rows_full, a_cols_full)
            )
            if USE_MASK:
                tl.store(a_local_ptr_tile, a_tile, mask=a_mask_tile)
            else:
                tl.store(a_local_ptr_tile, a_tile)

        # All ranks wait here before the first read of the staged tile.
        tle_exp.distributed_barrier(mesh)

        for k0 in range(0, K, BK):
            iter_idx = k0 // BK
            # Slots are used round-robin.
            slot = iter_idx % A_SLOTS

            for ks in range(0, BK, DOT_K):
                k_local = ks + tl.arange(0, DOT_K)
                a_cols_t = tl.broadcast_to(k_local[:, None], (DOT_K, BM))
                slot_dot_t = tl.zeros((DOT_K, BM), dtype=tl.int32) + slot
                a_ptr_remote = tle_exp.gpu.local_ptr(
                    a_buf_remote, (slot_dot_t, a_rows_t, a_cols_t)
                )
                # The slice is read as (DOT_K, BM) and transposed to (BM, DOT_K) for the dot.
                if USE_MASK:
                    a_mask_t = ((k0 + k_local)[:, None] < K) & (offs_m[None, :] < M)
                    a = tl.trans(tl.load(a_ptr_remote, mask=a_mask_t, other=0.0))
                else:
                    a = tl.trans(tl.load(a_ptr_remote))

                b_ptrs = (
                    b_ptr
                    + (k0 + k_local)[:, None] * stride_bk
                    + offs_n[None, :] * stride_bn
                )
                if USE_MASK:
                    b_mask = ((k0 + k_local)[:, None] < K) & (offs_n[None, :] < N)
                    b = tl.load(b_ptrs, mask=b_mask, other=0.0)
                else:
                    b = tl.load(b_ptrs)
                acc = tl.dot(a, b, acc)

            # With a single slot the next tile overwrites the one just read, so
            # every rank must be done reading before rank 0 writes.
            if A_SLOTS == 1:
                tle_exp.distributed_barrier(mesh)

            # Rank 0 stages the A tile of the next iteration, if there is one.
            next_k0 = k0 + BK
            has_next = next_k0 < K
            next_iter = iter_idx + 1
            next_slot = next_iter % A_SLOTS
            next_slot_full = tl.zeros((BM, BK), dtype=tl.int32) + next_slot
            if has_next and cluster_rank == 0:
                a_ptrs = (
                    a_ptr
                    + offs_m[:, None] * stride_am
                    + (next_k0 + offs_k)[None, :] * stride_ak
                )
                if USE_MASK:
                    a_mask_tile = (offs_m[:, None] < M) & (
                        (next_k0 + offs_k)[None, :] < K
                    )
                    a_tile = tl.load(a_ptrs, mask=a_mask_tile, other=0.0)
                else:
                    a_tile = tl.load(a_ptrs)
                a_local_ptr_tile = tle_exp.gpu.local_ptr(
                    a_buf, (next_slot_full, a_rows_full, a_cols_full)
                )
                if USE_MASK:
                    tl.store(a_local_ptr_tile, a_tile, mask=a_mask_tile)
                else:
                    tl.store(a_local_ptr_tile, a_tile)

            # Executed by every rank on every iteration, after the staging step.
            tle_exp.distributed_barrier(mesh)

        c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
        if USE_MASK:
            c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
            tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)
        else:
            tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty))

850 

851 

852def _select_remote_dot_k(bk: int) -> int: 

853 if bk % 16 == 0: 

854 return 16 

855 raise ValueError(f"BK must be divisible by 16 for remote dot path, got BK={bk}") 

856 

857 

def _grid_cluster_remote(
    M: int,
    N: int,
    BM: int,
    BN: int,
    cluster_size: int = TLE_CLUSTER_SIZE,
) -> tuple:
    """Return the 1-D launch grid for the cluster-remote kernel.

    Column tiles are grouped ``cluster_size`` at a time; the grid size is the
    number of row tiles multiplied by the number of such groups.
    """
    row_tiles = triton.cdiv(M, BM)
    column_tile_groups = triton.cdiv(triton.cdiv(N, BN), cluster_size)
    return (row_tiles * column_tile_groups,)

868 

869 

def _run_cluster_remote(
    a: torch.Tensor,
    b: torch.Tensor,
    c: torch.Tensor,
    bm: int,
    bn: int,
    bk: int,
    num_warps: int,
    num_stages: int,
) -> None:
    """Launch ``_cluster_remote_gemm_kernel`` for ``c = a @ b``.

    Args:
        a: Left operand, shape (M, K).
        b: Right operand, shape (K, N).
        c: Pre-allocated output, shape (M, N).
        bm, bn, bk: Tile sizes passed to the kernel.
        num_warps, num_stages: Launch options passed to the kernel.

    Raises:
        ValueError: If ``bk`` is not a multiple of 16.
    """
    M, K = a.shape
    N = b.shape[1]
    dot_k = _select_remote_dot_k(bk)

    # The kernel assigns column tiles in groups of TLE_CLUSTER_SIZE. When the
    # number of column tiles is not a multiple of the cluster size, the last
    # cluster contains a rank whose column tile lies entirely past N. Bounds
    # masks must then be enabled even if M, N and K are exact multiples of the
    # tile sizes, otherwise that rank loads and stores out of bounds.
    has_partial_tile = (M % bm != 0) or (N % bn != 0) or (K % bk != 0)
    has_surplus_rank = triton.cdiv(N, bn) % TLE_CLUSTER_SIZE != 0
    use_mask = has_partial_tile or has_surplus_rank

    a_slots = TLE_REMOTE_A_SLOTS
    use_nv_mma_smem_layout = (bk == 32) or (bk == 64 and num_stages <= 2)
    _cluster_remote_gemm_kernel[_grid_cluster_remote(M, N, bm, bn)](
        a,
        b,
        c,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        mesh=BLOCK_CLUSTER_MESH,
        BM=bm,
        BN=bn,
        BK=bk,
        DOT_K=dot_k,
        CLUSTER_SIZE=TLE_CLUSTER_SIZE,
        USE_MASK=use_mask,
        A_SLOTS=a_slots,
        USE_NV_MMA_SMEM_LAYOUT=use_nv_mma_smem_layout,
        num_ctas=1,
        num_warps=num_warps,
        num_stages=num_stages,
    )

912 

913 

def cluster_remote_mm_scenario(a, b, c, M, N, K):
    """Decide whether the cluster-remote GEMM path can handle this problem.

    Requires TLE with a device mesh, a device of major capability 9 or higher,
    CUDA float16 tensors for ``a``, ``b`` and ``c``, contiguous inputs, and
    problem sizes of at least one remote tile in every dimension.
    """
    capability = get_device_capability()
    operands = (a, b, c)
    return (
        HAS_TLE
        and BLOCK_CLUSTER_MESH is not None
        and capability[0] >= 9
        # Tensor checks run in the order a, b, c and stop at the first failure.
        and all(t.is_cuda for t in operands)
        and all(t.dtype == torch.float16 for t in operands)
        and a.is_contiguous()
        and b.is_contiguous()
        and M >= TLE_REMOTE_BM
        and N >= TLE_REMOTE_BN
        and K >= TLE_REMOTE_BK
    )

932 

933 

def cluster_remote_mm(a, b, c, M, N, K):
    """Compute ``c = a @ b`` with the cluster-remote GEMM kernel.

    Uses the module-level ``TLE_REMOTE_*`` tile sizes and launch options.

    Args:
        a: Left operand, shape (M, K).
        b: Right operand, shape (K, N).
        c: Pre-allocated output, shape (M, N).
        M, N, K: Problem sizes (used for logging only).

    Returns:
        ``c``.
    """
    # The format string must come first: without it the logging module treats M
    # as the message and fails to apply the remaining arguments to it.
    logger.debug(
        "GEMS MM-hopper, [mm scenario]: cluster_remote, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
        "[A column-major]: %s, [B column-major]: %s",
        M,
        N,
        K,
        a.stride(0) == 1,
        b.stride(0) == 1,
    )
    with torch_device_fn.device(a.device):
        _run_cluster_remote(
            a,
            b,
            c,
            TLE_REMOTE_BM,
            TLE_REMOTE_BN,
            TLE_REMOTE_BK,
            TLE_REMOTE_NUM_WARPS,
            TLE_REMOTE_NUM_STAGES,
        )
    return c

954 

955 

def mm(a, b):
    """Matrix-multiply two 2-D tensors and return a newly allocated result.

    The result dtype is the higher of the two input dtypes. The kernel is chosen
    in this order: gemv (N == 1), stream-K, cluster-remote, split-K
    (M < 2048, N < 2048, K >= 4096), and finally the general path.
    """
    device = a.device
    # handle non-contiguous inputs if necessary
    if min(a.stride(0), a.stride(1)) > 1:
        a = a.contiguous()
    if min(b.stride(0), b.stride(1)) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # allocates output
    result = torch.empty(
        (M, N), device=device, dtype=get_higher_dtype(a.dtype, b.dtype)
    )

    # Optimize for N=1 case (matrix-vector multiplication)
    if N == 1:
        return gemv_mm(a, b, result, M, K)

    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, result, M, N, K, sm_count=sm_count)

    if (
        HAS_TLE
        and BLOCK_CLUSTER_MESH is not None
        and cluster_remote_mm_scenario(a, b, result, M, N, K)
    ):
        return cluster_remote_mm(a, b, result, M, N, K)

    # Use splitk for small M
    use_splitk = M < 2048 and N < 2048 and K >= 4096
    if not use_splitk:
        return general_mm(a, b, result, M, N, K)

    # The split-K kernel accumulates into the output, so clear it first.
    result.zero_()
    return splitk_mm(a, b, result, M, N, K)

986 

987 

def mm_out(a, b, *, out):
    """Matrix-multiply two 2-D tensors into the caller-supplied ``out``.

    Kernel selection follows the same order as ``mm``: gemv (N == 1), stream-K,
    cluster-remote, split-K (M < 2048, N < 2048, K >= 4096), then the general
    path. ``out`` is returned.
    """
    # handle non-contiguous inputs if necessary
    a_is_strided = a.stride(0) > 1 and a.stride(1) > 1
    if a_is_strided:
        a = a.contiguous()
    b_is_strided = b.stride(0) > 1 and b.stride(1) > 1
    if b_is_strided:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape

    # Optimize for N=1 case (matrix-vector multiplication)
    if N == 1:
        return gemv_mm(a, b, out, M, K)

    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        return streamk_mm(a, b, out, M, N, K, sm_count=sm_count)

    tle_ready = HAS_TLE and BLOCK_CLUSTER_MESH is not None
    if tle_ready and cluster_remote_mm_scenario(a, b, out, M, N, K):
        return cluster_remote_mm(a, b, out, M, N, K)

    # Use splitk for small M
    if M >= 2048 or N >= 2048 or K < 4096:
        return general_mm(a, b, out, M, N, K)

    # The split-K kernel accumulates into the output, so clear it first.
    out.zero_()
    return splitk_mm(a, b, out, M, N, K)

1014 

1015 

def router_gemm(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """bf16 x bf16 -> fp32 GEMM for MoE router gate. weight shape: (N, K).

    Computes ``x @ weight.t()`` into a new float32 tensor of shape (M, N), using
    the split-K path when M < 2048, N < 2048 and K >= 4096, and the general path
    otherwise.
    """
    if min(x.stride(0), x.stride(1)) > 1:
        x = x.contiguous()
    M, K = x.shape
    N = weight.shape[0]
    result = torch.empty((M, N), device=x.device, dtype=torch.float32)
    # Transposed view of the weight, shape (K, N); no copy is made.
    weight_t = weight.t()

    use_splitk = M < 2048 and N < 2048 and K >= 4096
    if not use_splitk:
        return general_mm(x, weight_t, result, M, N, K, op_name="router_gemm")

    # The split-K kernel accumulates into the output, so clear it first.
    result.zero_()
    return splitk_mm(x, weight_t, result, M, N, K, op_name="router_gemm")