Coverage for src/flag_gems/ops/mm.py: 32%

283 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.ops.mm_streamk import streamk_mm 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry, libtuner 

11from flag_gems.utils import triton_lang_extension as ext 

12from flag_gems.utils.device_info import get_device_capability, get_sm_count 

13from flag_gems.utils.triton_version_utils import ( # noqa: F401 

14 HAS_TLE, 

15 HAS_TLE_DEVICE_MESH, 

16 _triton_version_at_least, 

17) 

18 

# Optional TLE (Triton experimental language extension) setup. When the device-mesh
# API is available, build the mesh that is later passed as `mesh=` to the
# cluster-remote GEMM kernel; otherwise leave both names as None so that
# cluster_remote_mm_scenario() can reject the path.
if HAS_TLE_DEVICE_MESH:
    import triton.experimental.tle.language as tle_exp

    # One mesh axis named "cluster_x" of size 2.
    # NOTE(review): the literal 2 presumably has to stay in sync with TLE_CLUSTER_SIZE below — confirm.
    BLOCK_CLUSTER_MESH = tle_exp.device_mesh({"block_cluster": [("cluster_x", 2)]})
else:
    tle_exp = None
    BLOCK_CLUSTER_MESH = None

# Not referenced anywhere in the visible part of this module.
CACHE_USAGE_THRESHOLD = 0.8
# Passed as CLUSTER_SIZE to the cluster-remote kernel and used as the default
# cluster size when computing its launch grid.
TLE_CLUSTER_SIZE = 2
# Tile sizes (BM, BN, BK) handed to the cluster-remote kernel; they are also the
# minimum M / N / K accepted by cluster_remote_mm_scenario().
TLE_REMOTE_BM = 64
TLE_REMOTE_BN = 256
TLE_REMOTE_BK = 64
# Launch parameters (num_warps / num_stages) for the cluster-remote kernel.
TLE_REMOTE_NUM_WARPS = 8
TLE_REMOTE_NUM_STAGES = 2
# Number of A-tile slots in the kernel's [A_SLOTS, BM, BK] buffer.
TLE_REMOTE_A_SLOTS = 2

logger = logging.getLogger(__name__)

37 

38 

@triton.jit
def prev_multiple_of(a, b):
    # Returns the largest multiple of b that is strictly smaller than a,
    # i.e. the largest x < a with x % b == 0.
    # Note: for a == 0 the expression evaluates to -b (cdiv(0, b) is 0), so callers
    # that use the result as an offset must cope with a negative value.
    return tl.cdiv(a, b) * b - b

43 

44 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    # Add 'stride_am' and 'stride_bk' to trigger autotune for tensors with the same shape but different strides.
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=10,
)
@triton.jit
def mm_kernel_general(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    # General tiled GEMM: C[M, N] = A[M, K] @ B[K, N].
    #
    # Each program computes one BLOCK_M x BLOCK_N tile of C. A, B and C are base
    # pointers; stride_* are the element strides of the corresponding dimension.
    # IS_FP64 selects a float64 accumulator, otherwise accumulation is float32.
    # The result is cast to C's element type before the masked store.
    pid = ext.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Row/column indices are wrapped modulo M / N so that the loads below need no
    # M/N mask; the surplus rows/columns are discarded by the store mask.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
    rm = rm.to(tl.int64)
    rn = rn.to(tl.int64)
    # Start of the last (possibly partial) K block; everything before it is a full block.
    prev_multiple = prev_multiple_of(K, BLOCK_K)

    if IS_FP64:
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
    else:
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Main loop over full K blocks: no K mask required.
    for start_k in range(0, prev_multiple, BLOCK_K):
        rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
        a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
        b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
        if a.dtype != b.dtype:
            # Mixed input dtypes: promote both operands to the output element type.
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        if IS_FP64:
            acc += tl.dot(a, b, allow_tf32=False)
        else:
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    # loop peeling
    rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
    # For K == 0, prev_multiple is -BLOCK_K and every rk is negative; `rk < K` alone
    # would then be all-true and read before the start of A/B. Requiring rk >= 0
    # masks those lanes so an empty reduction yields zeros.
    mask_k = (rk >= 0) & (rk < K)
    a = tl.load(
        A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
        mask=mask_k[None, :],
        other=0.0,
    )
    b = tl.load(
        B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
        mask=mask_k[:, None],
        other=0.0,
    )
    if a.dtype != b.dtype:
        a = a.to(C.dtype.element_ty)
        b = b.to(C.dtype.element_ty)
    if IS_FP64:
        acc += tl.dot(a, b, allow_tf32=False)
    else:
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
    rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # masked write-back of the output tile (drops rows/columns beyond M / N)
    tl.store(C, acc, mask=mask)

138 

139 

if HAS_TLE:

    @triton.jit
    def _cluster_remote_gemm_kernel(
        a_ptr,
        b_ptr,
        c_ptr,
        M,
        N,
        K,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        mesh: tl.constexpr,
        BM: tl.constexpr,
        BN: tl.constexpr,
        BK: tl.constexpr,
        DOT_K: tl.constexpr,
        CLUSTER_SIZE: tl.constexpr,
        USE_MASK: tl.constexpr,
        A_SLOTS: tl.constexpr,
        USE_NV_MMA_SMEM_LAYOUT: tl.constexpr,
    ):
        # GEMM C[M, N] = A[M, K] @ B[K, N] in which the programs of one cluster share
        # a single A tile: the program with cluster_rank == 0 copies each BM x BK
        # tile of A into `a_buf`, and every program reads it back through
        # `a_buf_remote` while loading its own B tile directly from b_ptr.
        # USE_MASK enables bounds masks for shapes that are not tile multiples.
        # NOTE(review): `tle_exp.remote(a_buf, 0, scope=mesh)` presumably yields a view
        # of shard 0's buffer that is readable from the other shards — confirm
        # against the TLE documentation.
        pid = tl.program_id(0)
        cluster_rank = tle_exp.shard_id(mesh, "cluster_x")
        cluster_id = pid // CLUSTER_SIZE

        # Tile mapping: all programs with the same cluster_id share pid_m and take
        # consecutive N tiles (one per cluster_rank).
        num_pid_n = tl.cdiv(N, BN)
        num_pid_n_group = tl.cdiv(num_pid_n, CLUSTER_SIZE)
        pid_m = cluster_id // num_pid_n_group
        pid_ng = cluster_id % num_pid_n_group
        pid_n = pid_ng * CLUSTER_SIZE + cluster_rank

        offs_m = pid_m * BM + tl.arange(0, BM)
        offs_n = pid_n * BN + tl.arange(0, BN)
        offs_k = tl.arange(0, BK)
        # Tile-local row index (0 .. BM-1) plus index grids used to address a_buf.
        a_row_base = offs_m - pid_m * BM
        a_rows_full = tl.broadcast_to(a_row_base[:, None], (BM, BK))
        a_cols_full = tl.broadcast_to(tl.arange(0, BK)[None, :], (BM, BK))
        a_rows_t = tl.broadcast_to(a_row_base[None, :], (DOT_K, BM))
        # A_SLOTS staging slots of BM x BK elements each. The element type is fixed
        # to float16 here; the host-side scenario check only admits float16 tensors.
        a_buf = tle_exp.gpu.alloc(
            [A_SLOTS, BM, BK],
            dtype=tl.float16,
            layout=None,
            scope=tle_exp.gpu.smem,
            nv_mma_shared_layout=USE_NV_MMA_SMEM_LAYOUT,
        )
        a_buf_remote = tle_exp.remote(a_buf, 0, scope=mesh)

        acc = tl.zeros((BM, BN), dtype=tl.float32)
        slot0 = 0
        slot0_full = tl.zeros((BM, BK), dtype=tl.int32) + slot0
        # Prologue: rank 0 stages the first K tile of A into slot 0.
        if cluster_rank == 0:
            a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
            if USE_MASK:
                a_mask_tile = (offs_m[:, None] < M) & (offs_k[None, :] < K)
                a_tile = tl.load(a_ptrs, mask=a_mask_tile, other=0.0)
            else:
                a_tile = tl.load(a_ptrs)
            a_local_ptr_tile = tle_exp.gpu.local_ptr(
                a_buf, (slot0_full, a_rows_full, a_cols_full)
            )
            if USE_MASK:
                tl.store(a_local_ptr_tile, a_tile, mask=a_mask_tile)
            else:
                tl.store(a_local_ptr_tile, a_tile)

        # Barrier across the mesh before any program reads the staged tile.
        tle_exp.distributed_barrier(mesh)

        for k0 in range(0, K, BK):
            iter_idx = k0 // BK
            # Slot holding the A tile for this K step.
            slot = iter_idx % A_SLOTS

            # Consume the BK-wide tile in DOT_K-wide slices.
            for ks in range(0, BK, DOT_K):
                k_local = ks + tl.arange(0, DOT_K)
                a_cols_t = tl.broadcast_to(k_local[:, None], (DOT_K, BM))
                slot_dot_t = tl.zeros((DOT_K, BM), dtype=tl.int32) + slot
                a_ptr_remote = tle_exp.gpu.local_ptr(
                    a_buf_remote, (slot_dot_t, a_rows_t, a_cols_t)
                )
                # The slice is read as (DOT_K, BM) and transposed to (BM, DOT_K).
                if USE_MASK:
                    a_mask_t = ((k0 + k_local)[:, None] < K) & (offs_m[None, :] < M)
                    a = tl.trans(tl.load(a_ptr_remote, mask=a_mask_t, other=0.0))
                else:
                    a = tl.trans(tl.load(a_ptr_remote))

                b_ptrs = (
                    b_ptr
                    + (k0 + k_local)[:, None] * stride_bk
                    + offs_n[None, :] * stride_bn
                )
                if USE_MASK:
                    b_mask = ((k0 + k_local)[:, None] < K) & (offs_n[None, :] < N)
                    b = tl.load(b_ptrs, mask=b_mask, other=0.0)
                else:
                    b = tl.load(b_ptrs)
                acc = tl.dot(a, b, acc)

            # With a single slot, the prefetch below targets the slot that was just
            # read, so synchronize first.
            if A_SLOTS == 1:
                tle_exp.distributed_barrier(mesh)

            next_k0 = k0 + BK
            has_next = next_k0 < K
            next_iter = iter_idx + 1
            next_slot = next_iter % A_SLOTS
            next_slot_full = tl.zeros((BM, BK), dtype=tl.int32) + next_slot
            # Rank 0 stages the A tile for the next K step into next_slot.
            if has_next and cluster_rank == 0:
                a_ptrs = (
                    a_ptr
                    + offs_m[:, None] * stride_am
                    + (next_k0 + offs_k)[None, :] * stride_ak
                )
                if USE_MASK:
                    a_mask_tile = (offs_m[:, None] < M) & (
                        (next_k0 + offs_k)[None, :] < K
                    )
                    a_tile = tl.load(a_ptrs, mask=a_mask_tile, other=0.0)
                else:
                    a_tile = tl.load(a_ptrs)
                a_local_ptr_tile = tle_exp.gpu.local_ptr(
                    a_buf, (next_slot_full, a_rows_full, a_cols_full)
                )
                if USE_MASK:
                    tl.store(a_local_ptr_tile, a_tile, mask=a_mask_tile)
                else:
                    tl.store(a_local_ptr_tile, a_tile)

            # End-of-step barrier, executed by every program on every iteration.
            tle_exp.distributed_barrier(mesh)

        # Write the accumulated tile back, cast to C's element type.
        c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
        if USE_MASK:
            c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
            tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)
        else:
            tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty))

278 

279 

280def _select_remote_dot_k(bk: int) -> int: 

281 if bk % 16 == 0: 

282 return 16 

283 raise ValueError(f"BK must be divisible by 16 for remote dot path, got BK={bk}") 

284 

285 

def _grid_cluster_remote(
    M: int,
    N: int,
    BM: int,
    BN: int,
    cluster_size: int = TLE_CLUSTER_SIZE,
) -> tuple[int]:
    """Return the 1-D launch grid for the cluster-remote GEMM kernel.

    The grid size is (number of BM-row tiles) x (number of groups of
    ``cluster_size`` BN-column tiles), both rounded up.
    """
    row_tiles = triton.cdiv(M, BM)
    col_tile_groups = triton.cdiv(triton.cdiv(N, BN), cluster_size)
    return (row_tiles * col_tile_groups,)

296 

297 

def _run_cluster_remote(
    a: torch.Tensor,
    b: torch.Tensor,
    c: torch.Tensor,
    bm: int,
    bn: int,
    bk: int,
    num_warps: int,
    num_stages: int,
) -> None:
    """Launch the cluster-remote GEMM kernel, writing ``a @ b`` into ``c``.

    ``bm`` / ``bn`` / ``bk`` are the tile sizes; ``num_warps`` and ``num_stages``
    are forwarded as launch parameters.

    Raises:
        ValueError: if ``bk`` is not divisible by 16 (from ``_select_remote_dot_k``).
    """
    M, K = a.shape
    N = b.shape[1]
    dot_k = _select_remote_dot_k(bk)

    # Bounds masks are only needed when some dimension is not a tile multiple.
    needs_mask = any(
        extent % tile != 0 for extent, tile in ((M, bm), (N, bn), (K, bk))
    )
    if bk == 64:
        nv_mma_layout = num_stages <= 2
    else:
        nv_mma_layout = bk == 32

    launch_grid = _grid_cluster_remote(M, N, bm, bn)
    _cluster_remote_gemm_kernel[launch_grid](
        a,
        b,
        c,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        mesh=BLOCK_CLUSTER_MESH,
        BM=bm,
        BN=bn,
        BK=bk,
        DOT_K=dot_k,
        CLUSTER_SIZE=TLE_CLUSTER_SIZE,
        USE_MASK=needs_mask,
        A_SLOTS=TLE_REMOTE_A_SLOTS,
        USE_NV_MMA_SMEM_LAYOUT=nv_mma_layout,
        num_ctas=1,
        num_warps=num_warps,
        num_stages=num_stages,
    )

340 

341 

def cluster_remote_mm_scenario(a, b, c, M, N, K):
    """Decide whether the cluster-remote GEMM path may be used.

    Requires TLE with a device mesh, device capability major >= 9, CUDA float16
    tensors for ``a``, ``b`` and ``c``, contiguous ``a`` and ``b``, and each of
    M / N / K at least as large as the corresponding TLE_REMOTE_* tile size.
    """
    capability = get_device_capability()
    platform_ok = (
        HAS_TLE and BLOCK_CLUSTER_MESH is not None and capability[0] >= 9
    )
    operands = (a, b, c)
    return (
        platform_ok
        and all(t.is_cuda for t in operands)
        and all(t.dtype == torch.float16 for t in operands)
        and a.is_contiguous()
        and b.is_contiguous()
        and M >= TLE_REMOTE_BM
        and N >= TLE_REMOTE_BN
        and K >= TLE_REMOTE_BK
    )

360 

361 

def cluster_remote_mm(a, b, c, M, N, K):
    """Compute ``a @ b`` into ``c`` with the cluster-remote kernel and return ``c``.

    Uses the module-level TLE_REMOTE_* tile sizes and launch parameters.
    """
    a_col_major = a.stride(0) == 1
    b_col_major = b.stride(0) == 1
    logger.debug(
        "GEMS MM [cluster_remote]: M=%s N=%s K=%s, A_col_major=%s, B_col_major=%s",
        M,
        N,
        K,
        a_col_major,
        b_col_major,
    )
    tile_and_launch_cfg = (
        TLE_REMOTE_BM,
        TLE_REMOTE_BN,
        TLE_REMOTE_BK,
        TLE_REMOTE_NUM_WARPS,
        TLE_REMOTE_NUM_STAGES,
    )
    with torch_device_fn.device(a.device):
        _run_cluster_remote(a, b, c, *tile_and_launch_cfg)
    return c

383 

384 

# Supported floating-point dtypes, ordered from lowest to highest precedence.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32, torch.float64]


def get_higher_dtype(a, b):
    """Return whichever of the two dtypes appears later in ``_ordered_datatypes``.

    Identical dtypes are returned unchanged without any validation; otherwise
    both must be members of ``_ordered_datatypes`` (checked with ``assert``).
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    # The first of the two dtypes met while walking the ordering is the lower
    # one, so the result is the other argument.
    for candidate in _ordered_datatypes:
        if candidate is a:
            return b
        if candidate is b:
            return a

400 

401 

def general_mm(a, b, c, M, N, K):
    """Launch ``mm_kernel_general`` to compute ``a @ b`` into ``c``; returns ``c``.

    One program is launched per (BLOCK_M, BLOCK_N) output tile. The float64
    accumulator is selected when ``a`` is float64.
    """

    def grid(META):
        tiles_m = triton.cdiv(M, META["BLOCK_M"])
        tiles_n = triton.cdiv(N, META["BLOCK_N"])
        return (tiles_m * tiles_n,)

    is_fp64 = a.dtype == torch.float64
    with torch_device_fn.device(a.device):
        mm_kernel_general[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            GROUP_M=8,
            IS_FP64=is_fp64,
        )
    return c

424 

425 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm_self_transpose"),
    key=["M", "K", "stride_am", "stride_ak"],
    strategy=["align32", "align32", "align32", "align32"],
    warmup=2,
    rep=4,
)
@triton.jit
def mm_kernel_syrk(
    A,
    C,
    M,
    K,
    stride_am,
    stride_ak,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Computes C[M, M] = A[M, K] @ A^T using only the lower-triangular tiles:
    # each program produces one BLOCK_M x BLOCK_M tile (pid_m, pid_n) with
    # pid_n <= pid_m and, for off-diagonal tiles, also stores its transpose
    # into the mirrored position.
    # NOTE(review): accumulation is always float32 here, unlike mm_kernel_general
    # which has an IS_FP64 switch — confirm this is intended for float64 inputs.
    pid = tl.program_id(0)

    # Packed lower-triangular launch domain:
    #   pid = row * (row + 1) / 2 + col, where 0 <= col <= row.
    #
    # Invert the triangular-number indexing by solving:
    #   row^2 + row - 2 * pid = 0
    #   => row = (-1 + sqrt(1 + 8 * pid)) / 2
    #
    # We take floor(...) as the candidate row, then apply an integer +/-1 correction
    # because fp32 sqrt can be off near triangular-number boundaries.
    pid_f = pid.to(tl.float32)
    pid_m = tl.floor((tl.sqrt(8.0 * pid_f + 1.0) - 1.0) / 2.0).to(tl.int32)
    # Candidate row too large: step back one row.
    tri_start = pid_m * (pid_m + 1) // 2
    pid_m = tl.where(tri_start > pid, pid_m - 1, pid_m)
    # Candidate row too small: step forward one row.
    next_tri_start = (pid_m + 1) * (pid_m + 2) // 2
    pid_m = tl.where(next_tri_start <= pid, pid_m + 1, pid_m)
    tri_start = pid_m * (pid_m + 1) // 2
    pid_n = pid - tri_start

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_M + tl.arange(0, BLOCK_M)
    # Indices are wrapped modulo M so the loads need no row mask; the surplus
    # rows/columns are dropped by the store masks below.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
    ran = tl.max_contiguous(tl.multiple_of(rn % M, BLOCK_M), BLOCK_M).to(tl.int64)
    rm = rm.to(tl.int64)
    rn = rn.to(tl.int64)
    acc = tl.zeros((BLOCK_M, BLOCK_M), dtype=tl.float32)

    for start_k in range(0, K, BLOCK_K):
        rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
        mask_k = rk < K
        # a[m, k] = A[ram[m], rk[k]]
        a = tl.load(
            A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
            mask=mask_k[None, :],
            other=0.0,
        )
        # b[k, n] = A[ran[n], rk[k]], i.e. a tile of A^T read from the same buffer.
        b = tl.load(
            A + (rk[:, None] * stride_ak + ran[None, :] * stride_am),
            mask=mask_k[:, None],
            other=0.0,
        )
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    out = acc.to(C.dtype.element_ty)
    c_ptr = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < M)[None, :]
    tl.store(c_ptr, out, mask=mask)

    # Off-diagonal tile: mirror the result into the upper triangle.
    if pid_m > pid_n:
        c_t_ptr = C + (rn[:, None] * stride_cm + rm[None, :] * stride_cn)
        mask_t = (rn < M)[:, None] & (rm < M)[None, :]
        tl.store(c_t_ptr, tl.trans(out), mask=mask_t)

499 

500 

def is_syrk_transpose_pair(a, b):
    """Return True when ``b`` is a transposed view of the same memory as ``a``.

    Both must be 2-D, with swapped shapes and swapped strides, equal storage
    offsets and equal data pointers.
    """
    if a.ndim != 2 or b.ndim != 2:
        return False
    if (a.shape[0], a.shape[1]) != (b.shape[1], b.shape[0]):
        return False
    if (a.stride(0), a.stride(1)) != (b.stride(1), b.stride(0)):
        return False
    return a.storage_offset() == b.storage_offset() and a.data_ptr() == b.data_ptr()

512 

513 

def syrk_mm(a, c, M, K):
    """Launch ``mm_kernel_syrk`` to compute ``a @ a.T`` into ``c``; returns ``c``."""

    def grid(META):
        # Number of tile rows is tiles = ceil(M / BLOCK_M).
        # Packed lower triangle contains:
        #   1 + 2 + ... + tiles = tiles * (tiles + 1) / 2
        tiles = triton.cdiv(M, META["BLOCK_M"])
        return (tiles * (tiles + 1) // 2,)

    with torch_device_fn.device(a.device):
        mm_kernel_syrk[grid](
            a,
            c,
            M,
            K,
            a.stride(0),
            a.stride(1),
            c.stride(0),
            c.stride(1),
        )
    return c

535 

536 

def streamk_scenario(a, b, M, N, K):
    """Decide whether the stream-K kernel should be used for this problem.

    True only on devices with capability major == 8, for contiguous
    float16/bfloat16 inputs whose K exceeds both 5 * M and 5 * N.
    """
    # TODO: this may change according to real benchmark results.
    # Currently, the best configuration for streamk has only been tested on A100(capability[0] == 8).
    # The optimal settings for other devices need to be determined through real testing.
    capability = get_device_capability()
    if capability[0] != 8:
        return False

    half_dtypes = [torch.float16, torch.bfloat16]
    if a.dtype not in half_dtypes or b.dtype not in half_dtypes:
        return False
    if not a.is_contiguous():
        return False
    if not b.is_contiguous():
        return False
    return K > M * 5 and K > N * 5

551 

552 

def mm(a, b):
    """Matrix-multiply two 2-D tensors and return a newly allocated result.

    Dispatch order: SYRK kernel when ``b`` is a transposed view of ``a``; then
    stream-K, cluster-remote, and finally the general kernel. The output dtype
    is the higher of the two input dtypes (the SYRK path uses ``a.dtype``).
    """
    logger.debug("GEMS MM")

    device = a.device
    if is_syrk_transpose_pair(a, b):
        rows, depth = a.shape
        gram = torch.empty((rows, rows), device=device, dtype=a.dtype)
        return syrk_mm(a, gram, rows, depth)

    # An operand with no unit stride in either dimension is copied to a
    # contiguous tensor first.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()

    # Shape validation.
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape

    # Output allocation.
    out_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=device, dtype=out_dtype)

    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        result = streamk_mm(a, b, c, M, N, K, sm_count=sm_count)
    elif cluster_remote_mm_scenario(a, b, c, M, N, K):
        result = cluster_remote_mm(a, b, c, M, N, K)
    else:
        result = general_mm(a, b, c, M, N, K)
    return result

580 

581 

def mm_out(a, b, *, out):
    """Matrix-multiply two 2-D tensors, writing the result into ``out``.

    Uses the same dispatch order as ``mm`` (SYRK, stream-K, cluster-remote,
    general); the selected launcher's return value is returned.
    """
    logger.debug("GEMS MM_OUT")

    if is_syrk_transpose_pair(a, b):
        rows, depth = a.shape
        return syrk_mm(a, out, rows, depth)

    # An operand with no unit stride in either dimension is copied to a
    # contiguous tensor first.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()

    # Shape validation.
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape

    sm_count = get_sm_count()
    if streamk_scenario(a, b, M, N, K):
        result = streamk_mm(a, b, out, M, N, K, sm_count=sm_count)
    elif cluster_remote_mm_scenario(a, b, out, M, N, K):
        result = cluster_remote_mm(a, b, out, M, N, K)
    else:
        result = general_mm(a, b, out, M, N, K)
    return result

604 

605 

def router_gemm(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """bf16 x bf16 -> fp32 GEMM for MoE router gate. weight shape: (N, K).

    Returns a new float32 tensor of shape (M, N) holding ``x @ weight.T``,
    computed with the general matmul kernel.
    """
    needs_copy = x.stride(0) > 1 and x.stride(1) > 1
    if needs_copy:
        x = x.contiguous()

    M, K = x.shape
    N = weight.shape[0]
    # Materialize weight^T as a contiguous (K, N) right-hand operand.
    weight_t = weight.t().contiguous()
    result = torch.empty((M, N), device=x.device, dtype=torch.float32)
    return general_mm(x, weight_t, result, M, N, K)