Coverage for src/flag_gems/ops/group_gemm.py: 14%

273 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import libentry, libtuner 

8from flag_gems.utils.device_info import get_device_capability, get_sm_count 

9 

10logger = logging.getLogger(__name__) 

11 

12 

def supports_tma():
    """Report whether the tensor-descriptor (TMA) kernel variants may be used.

    Returns:
        bool: True when the first element returned by
        ``get_device_capability()`` (the major version) is at least 9.
    """
    major = get_device_capability()[0]
    return major >= 9

15 

16 

# Probe the installed Triton for device-side tensor descriptors. When the
# attribute is missing, the flag is False and the function handle is None.
_support_device_tensor_descriptor = hasattr(tl, "make_tensor_descriptor")
make_tensor_descriptor_fn = getattr(tl, "make_tensor_descriptor", None)

23 

# Probe the installed Triton for the host-side TensorDescriptor class; the
# flag records whether the import succeeded.
try:
    from triton.tools.tensor_descriptor import TensorDescriptor
except ImportError:
    _support_host_tensor_descriptor = False
else:
    _support_host_tensor_descriptor = True

30 

31 

@triton.jit
def grouped_launch(
    pid, m, n, block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr
):
    """Translate a linear tile id into (row-tile, column-tile) indices.

    Row tiles are handled in bands of up to ``group_m`` rows; consecutive
    ``pid`` values stay inside one band before moving to the next.

    Args:
        pid: linear tile index within an ``m`` x ``n`` problem.
        m, n: problem extents along rows and columns.
        block_m, block_n: tile extents along rows and columns.
        group_m: maximum number of row tiles per band.

    Returns:
        ``(pid_m, pid_n)``: the row-tile and column-tile index of ``pid``.
    """
    tiles_m = tl.cdiv(m, block_m)
    tiles_n = tl.cdiv(n, block_n)

    tiles_per_band = group_m * tiles_n
    band = pid // tiles_per_band
    first_row_tile = band * group_m
    # The final band may hold fewer than group_m row tiles.
    rows_in_band = tl.minimum(tiles_m - first_row_tile, group_m)
    pid_m = first_row_tile + (pid % rows_in_band)
    pid_n = (pid % tiles_per_band) // rows_in_band

    return pid_m, pid_n

46 

47 

def matmul_tma_set_block_size_hook(nargs):
    """Set the block shapes of the a/b/c descriptors from the chosen tile sizes.

    Args:
        nargs: mapping that holds ``BLOCK_M``, ``BLOCK_N``, ``BLOCK_K`` and the
            descriptor objects under ``a_desc``, ``b_desc`` and ``c_desc``.
            Each descriptor's ``block_shape`` attribute is overwritten in place.
    """
    block_m = nargs["BLOCK_M"]
    block_n = nargs["BLOCK_N"]
    block_k = nargs["BLOCK_K"]
    # a is (M, K), b is (K, N), c is (M, N).
    shapes = (
        ("a_desc", [block_m, block_k]),
        ("b_desc", [block_k, block_n]),
        ("c_desc", [block_m, block_n]),
    )
    for desc_key, block_shape in shapes:
        nargs[desc_key].block_shape = block_shape

55 

56 

def get_autotune_config(pre_hook=None):
    """Build the list of candidate ``triton.Config`` objects for the tuner.

    Args:
        pre_hook: optional callable forwarded unchanged to every config.

    Returns:
        list: seven ``triton.Config`` instances, in a fixed order.
    """
    # Columns: BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M, num_stages, num_warps
    candidates = (
        (128, 256, 64, 8, 3, 8),
        (128, 128, 128, 8, 2, 4),
        (256, 128, 128, 8, 3, 4),
        (256, 128, 64, 8, 3, 8),
        (64, 256, 32, 4, 4, 4),
        (128, 64, 32, 4, 4, 4),
        (256, 256, 64, 8, 3, 8),
    )
    return [
        triton.Config(
            {"BLOCK_M": bm, "BLOCK_N": bn, "BLOCK_K": bk, "GROUP_M": gm},
            num_stages=stages,
            num_warps=warps,
            pre_hook=pre_hook,
        )
        for bm, bn, bk, gm, stages, warps in candidates
    ]

102 

103 

@libentry()
@libtuner(configs=get_autotune_config(), key=["M", "N", "K"])
@triton.jit
def grouped_gemm_tma_kernel(
    M,
    N,
    K,
    group_a_ptrs,
    group_b_ptrs,
    group_c_ptrs,
    group_out_ptrs,
    group_gemm_sizes,
    g_lds,
    group_size,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    alpha: tl.constexpr,
    beta: tl.constexpr,
):
    """Grouped GEMM built on device-side tensor descriptors.

    For each group ``g`` the kernel writes ``beta * c + alpha * (a @ b)`` to
    that group's output. Tiles of all groups share one linear numbering; a
    program starts at its own program id and advances by the number of
    launched programs.

    Args:
        M, N, K: not read in the body; they appear only in the tuner key.
        group_a_ptrs, group_b_ptrs, group_c_ptrs, group_out_ptrs: per-group
            base addresses, reinterpreted here as bfloat16 pointers.
        group_gemm_sizes: three values per group, read as (m, n, k).
        g_lds: three values per group, read as (lda, ldb, ldc); ``ldc`` is
            also used as the row stride of the output.
        group_size: number of groups.
        BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M: tile sizes and band height.
        alpha, beta: compile-time scaling factors.
    """
    num_progs = tl.num_programs(0)
    tile_id = tl.program_id(0)
    tiles_before = 0
    for g in range(group_size):
        dims = group_gemm_sizes + g * 3
        gm = tl.load(dims)
        gn = tl.load(dims + 1)
        gk = tl.load(dims + 2)
        tiles_here = tl.cdiv(gm, BLOCK_M) * tl.cdiv(gn, BLOCK_N)

        tiles_after = tiles_before + tiles_here
        if tile_id >= tiles_before and tile_id < tiles_after:
            lds = g_lds + g * 3
            lda = tl.load(lds)
            ldb = tl.load(lds + 1)
            ldc = tl.load(lds + 2)

            a_base = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.bfloat16))
            b_base = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.bfloat16))
            c_base = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.bfloat16))
            out_base = tl.load(group_out_ptrs + g).to(tl.pointer_type(tl.bfloat16))

            a_desc = make_tensor_descriptor_fn(
                a_base,
                shape=[gm, gk],
                strides=[lda, 1],
                block_shape=[BLOCK_M, BLOCK_K],
            )
            b_desc = make_tensor_descriptor_fn(
                b_base,
                shape=[gk, gn],
                strides=[ldb, 1],
                block_shape=[BLOCK_K, BLOCK_N],
            )
            c_desc = make_tensor_descriptor_fn(
                c_base,
                shape=[gm, gn],
                strides=[ldc, 1],
                block_shape=[BLOCK_M, BLOCK_N],
            )
            out_desc = make_tensor_descriptor_fn(
                out_base,
                shape=[gm, gn],
                strides=[ldc, 1],
                block_shape=[BLOCK_M, BLOCK_N],
            )

            # Number of tiles of this group that fall to this program.
            my_tiles = tl.cdiv(tiles_after - tile_id, num_progs)
            for _ in tl.range(my_tiles):
                local_tile = tile_id - tiles_before
                tile_m, tile_n = grouped_launch(
                    local_tile, gm, gn, BLOCK_M, BLOCK_N, GROUP_M
                )

                row_off = tile_m * BLOCK_M
                col_off = tile_n * BLOCK_N

                acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                for k_step in range(tl.cdiv(gk, BLOCK_K)):
                    k_off = k_step * BLOCK_K
                    a_tile = a_desc.load([row_off, k_off])
                    b_tile = b_desc.load([k_off, col_off])
                    acc = tl.dot(a_tile, b_tile, acc=acc, allow_tf32=False)

                c_tile = c_desc.load([row_off, col_off])
                acc = c_tile * beta + acc * alpha

                out_desc.store([row_off, col_off], acc.to(c_desc.dtype))

                tile_id += num_progs

        tiles_before = tiles_after

202 

203 

@libentry()
@libtuner(configs=get_autotune_config(), key=["M", "N", "K"])
@triton.jit
def grouped_gemm_kernel(
    M,
    N,
    K,
    group_a_ptrs,
    group_b_ptrs,
    group_c_ptrs,
    group_out_ptrs,
    group_gemm_sizes,
    g_lds,
    group_size,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    alpha: tl.constexpr,
    beta: tl.constexpr,
):
    """Grouped GEMM built on block pointers.

    For each group ``g`` the kernel writes ``beta * c + alpha * (a @ b)`` to
    that group's output. Tiles of all groups share one linear numbering; a
    program starts at its own program id and advances by the number of
    launched programs.

    Args:
        M, N, K: not read in the body; they appear only in the tuner key.
        group_a_ptrs, group_b_ptrs, group_c_ptrs, group_out_ptrs: per-group
            base addresses, reinterpreted here as bfloat16 pointers.
        group_gemm_sizes: three values per group, read as (m, n, k).
        g_lds: three values per group, read as (lda, ldb, ldc); ``ldc`` is
            also used as the row stride of the output.
        group_size: number of groups.
        BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M: tile sizes and band height.
        alpha, beta: compile-time scaling factors.
    """
    num_progs = tl.num_programs(0)
    tile_id = tl.program_id(0)
    tiles_before = 0
    for g in range(group_size):
        dims = group_gemm_sizes + g * 3
        gm = tl.load(dims)
        gn = tl.load(dims + 1)
        gk = tl.load(dims + 2)
        tiles_here = tl.cdiv(gm, BLOCK_M) * tl.cdiv(gn, BLOCK_N)
        tiles_after = tiles_before + tiles_here
        if tile_id >= tiles_before and tile_id < tiles_after:
            lds = g_lds + g * 3
            lda = tl.load(lds)
            ldb = tl.load(lds + 1)
            ldc = tl.load(lds + 2)

            a_base = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.bfloat16))
            b_base = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.bfloat16))
            c_base = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.bfloat16))
            out_base = tl.load(group_out_ptrs + g).to(tl.pointer_type(tl.bfloat16))

            # Number of tiles of this group that fall to this program.
            my_tiles = tl.cdiv(tiles_after - tile_id, num_progs)
            for _ in tl.range(my_tiles):
                local_tile = tile_id - tiles_before
                tile_m, tile_n = grouped_launch(
                    local_tile, gm, gn, BLOCK_M, BLOCK_N, GROUP_M
                )

                row_off = tile_m * BLOCK_M
                col_off = tile_n * BLOCK_N

                a_blk = tl.make_block_ptr(
                    base=a_base,
                    shape=(gm, gk),
                    strides=(lda, 1),
                    offsets=(row_off, 0),
                    block_shape=(BLOCK_M, BLOCK_K),
                    order=(1, 0),
                )
                b_blk = tl.make_block_ptr(
                    base=b_base,
                    shape=(gk, gn),
                    strides=(ldb, 1),
                    offsets=(0, col_off),
                    block_shape=(BLOCK_K, BLOCK_N),
                    order=(1, 0),
                )

                acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                for _k in range(tl.cdiv(gk, BLOCK_K)):
                    a_tile = tl.load(a_blk, boundary_check=(0, 1))
                    b_tile = tl.load(b_blk, boundary_check=(0, 1))
                    acc = tl.dot(a_tile, b_tile, acc=acc, allow_tf32=False)
                    # Step both operands one BLOCK_K along the reduction axis.
                    a_blk = tl.advance(a_blk, (0, BLOCK_K))
                    b_blk = tl.advance(b_blk, (BLOCK_K, 0))

                c_blk = tl.make_block_ptr(
                    base=c_base,
                    shape=(gm, gn),
                    strides=(ldc, 1),
                    offsets=(row_off, col_off),
                    block_shape=(BLOCK_M, BLOCK_N),
                    order=(1, 0),
                )
                out_blk = tl.make_block_ptr(
                    base=out_base,
                    shape=(gm, gn),
                    strides=(ldc, 1),
                    offsets=(row_off, col_off),
                    block_shape=(BLOCK_M, BLOCK_N),
                    order=(1, 0),
                )

                c_tile = tl.load(c_blk, boundary_check=(0, 1))
                acc = c_tile * beta + acc * alpha

                result = acc.to(c_blk.dtype.element_ty)
                tl.store(out_blk, result, boundary_check=(0, 1))

                tile_id += num_progs

        tiles_before = tiles_after

310 

311 

@libentry()
@libtuner(
    configs=get_autotune_config(matmul_tma_set_block_size_hook), key=["M", "N", "K"]
)
@triton.jit
def grouped_mm_tma_kernel(
    a_desc,
    b_desc,
    c_desc,
    C,
    offs,
    num_groups: tl.constexpr,
    M,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_cm: tl.constexpr,
    stride_cn: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """Grouped matmul over row ranges of A, using host-built descriptors.

    ``offs[g]`` is read as the end row of group ``g``; the group starts where
    the previous one ended (0 for the first). Rows of group ``g`` in A are
    multiplied by the ``K`` rows of B that start at row ``g * K``.

    Args:
        a_desc, b_desc, c_desc: tensor descriptors for A, B and the output.
        C: raw output pointer, used for the masked store of partial tiles.
        offs: per-group end-row offsets.
        num_groups: number of groups.
        M: not read in the body; it appears only in the tuner key.
        N, K: output column count and reduction length.
        stride_cm, stride_cn: output strides used by the masked store.
        BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M: tile sizes and band height.
    """
    num_progs = tl.num_programs(axis=0)
    tile_id = tl.program_id(axis=0)
    n_tiles = tl.cdiv(N, BLOCK_N)
    tiles_before = 0
    row_begin = 0
    row_end = 0

    for group_idx in tl.range(num_groups):
        row_end = tl.load(offs + group_idx).to(tl.int32)
        m = row_end - row_begin
        tiles_here = tl.cdiv(m, BLOCK_M) * n_tiles

        tiles_after = tiles_before + tiles_here
        if tile_id >= tiles_before and tile_id < tiles_after:
            # Number of tiles of this group that fall to this program.
            my_tiles = tl.cdiv(tiles_after - tile_id, num_progs)
            for _ in tl.range(my_tiles):
                local_tile = tile_id - tiles_before
                tile_m, tile_n = grouped_launch(
                    local_tile, m, N, BLOCK_M, BLOCK_N, GROUP_M
                )

                row_off = row_begin + tile_m * BLOCK_M
                col_off = tile_n * BLOCK_N
                b_row0 = group_idx * K

                acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

                for k in tl.range(0, tl.cdiv(K, BLOCK_K)):
                    a_tile = a_desc.load([row_off, k * BLOCK_K])
                    b_tile = b_desc.load([b_row0 + k * BLOCK_K, col_off])
                    acc = tl.dot(a_tile, b_tile, acc=acc, allow_tf32=False)

                result = acc.to(c_desc.dtype)

                if row_off + BLOCK_M <= row_end:
                    # Tile lies fully inside the group's rows.
                    c_desc.store([row_off, col_off], result)
                else:
                    # Tile crosses the group's last row: mask out rows at or
                    # beyond row_end so they are not written.
                    rows = row_off + tl.arange(0, BLOCK_M)
                    cols = col_off + tl.arange(0, BLOCK_N)
                    out_ptrs = C + stride_cm * rows[:, None] + stride_cn * cols[None, :]
                    keep = (rows[:, None] < row_end) & (cols[None, :] < N)
                    tl.store(out_ptrs, result, mask=keep)

                tile_id += num_progs

        tiles_before = tiles_after
        row_begin = row_end

384 

385 

@libentry()
@libtuner(configs=get_autotune_config(), key=["M", "N", "K"])
@triton.jit
def grouped_mm_kernel(
    A,
    B,
    C,
    offs,
    num_groups: tl.constexpr,
    M,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am: tl.constexpr,
    stride_ak: tl.constexpr,
    stride_bk: tl.constexpr,
    stride_bn: tl.constexpr,
    stride_cm: tl.constexpr,
    stride_cn: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """Grouped matmul over row ranges of A, using block pointers.

    ``offs[g]`` is read as the end row of group ``g``; the group starts where
    the previous one ended (0 for the first). B is addressed as a
    ``(num_groups * K, N)`` matrix and group ``g`` uses the ``K`` rows that
    start at row ``g * K``.

    Args:
        A, B, C: pointers to the two inputs and the output.
        offs: per-group end-row offsets.
        num_groups: number of groups.
        M, N, K: total rows of A/C, output columns, and reduction length.
        stride_am, stride_ak: strides of A.
        stride_bk, stride_bn: strides of B along its K and N axes.
        stride_cm, stride_cn: strides of C.
        BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M: tile sizes and band height.
    """
    num_progs = tl.num_programs(axis=0)
    tile_id = tl.program_id(axis=0)
    n_tiles = tl.cdiv(N, BLOCK_N)
    tiles_before = 0
    row_begin = 0
    row_end = 0

    for group_idx in tl.range(num_groups):
        row_end = tl.load(offs + group_idx).to(tl.int32)
        m = row_end - row_begin
        tiles_here = tl.cdiv(m, BLOCK_M) * n_tiles

        tiles_after = tiles_before + tiles_here
        if tile_id >= tiles_before and tile_id < tiles_after:
            # Number of tiles of this group that fall to this program.
            my_tiles = tl.cdiv(tiles_after - tile_id, num_progs)
            for _ in tl.range(my_tiles):
                local_tile = tile_id - tiles_before
                tile_m, tile_n = grouped_launch(
                    local_tile, m, N, BLOCK_M, BLOCK_N, GROUP_M
                )

                row_off = row_begin + tile_m * BLOCK_M
                col_off = tile_n * BLOCK_N
                b_row0 = group_idx * K

                a_blk = tl.make_block_ptr(
                    base=A,
                    shape=(M, K),
                    strides=(stride_am, stride_ak),
                    offsets=(row_off, 0),
                    block_shape=(BLOCK_M, BLOCK_K),
                    order=(1, 0),
                )
                b_blk = tl.make_block_ptr(
                    base=B,
                    shape=(num_groups * K, N),
                    strides=(stride_bk, stride_bn),
                    offsets=(b_row0, col_off),
                    block_shape=(BLOCK_K, BLOCK_N),
                    order=(1, 0),
                )

                acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

                for k in tl.range(0, tl.cdiv(K, BLOCK_K)):
                    a_tile = tl.load(a_blk, boundary_check=(0, 1))
                    b_tile = tl.load(b_blk, boundary_check=(0, 1))
                    acc = tl.dot(a_tile, b_tile, acc=acc, allow_tf32=False)

                    # Step both operands one BLOCK_K along the reduction axis.
                    a_blk = tl.advance(a_blk, (0, BLOCK_K))
                    b_blk = tl.advance(b_blk, (BLOCK_K, 0))

                result = acc.to(C.dtype.element_ty)

                if row_off + BLOCK_M <= row_end:
                    # Tile lies fully inside the group's rows.
                    c_blk = tl.make_block_ptr(
                        base=C,
                        shape=(M, N),
                        strides=(stride_cm, stride_cn),
                        offsets=(row_off, col_off),
                        block_shape=(BLOCK_M, BLOCK_N),
                        order=(1, 0),
                    )
                    tl.store(c_blk, result, boundary_check=(0, 1))
                else:
                    # Tile crosses the group's last row: mask out rows at or
                    # beyond row_end so they are not written.
                    rows = row_off + tl.arange(0, BLOCK_M)
                    cols = col_off + tl.arange(0, BLOCK_N)
                    out_ptrs = C + stride_cm * rows[:, None] + stride_cn * cols[None, :]
                    keep = (rows[:, None] < row_end) & (cols[None, :] < N)
                    tl.store(out_ptrs, result, mask=keep)

                tile_id += num_progs

        tiles_before = tiles_after
        row_begin = row_end

489 

490 

def group_gemm(group_A, group_B, group_C, offs_table, alpha=1, beta=0):
    """Launch one grouped GEMM over the problems described by ``offs_table``.

    Args:
        group_A, group_B, group_C: tensors holding the packed operands; each
            group's operand is taken as ``tensor[index]`` and passed to the
            kernel by its data pointer.
        offs_table: one entry per group, indexed as
            ``[M_g, N_g, K_g, a_index, b_index, c_index]``.
        alpha: scale applied to ``a @ b`` in the kernel.
        beta: scale applied to the existing ``c`` values in the kernel.

    Returns:
        torch.Tensor: a new tensor with the shape of ``group_C`` and the dtype
        and device of ``group_A``. Each group's result is written at the same
        index used for its C operand.

    Note:
        Leading dimensions are passed as ``(K_g, N_g, N_g)`` for a, b and
        c/out. The kernels in this file reinterpret every operand pointer as
        bfloat16.
    """
    A_addrs, B_addrs, C_addrs, out_addrs = [], [], [], []
    group_sizes, group_lds = [], []
    group_size = len(offs_table)
    M, N = group_C.shape
    K = group_A.shape[1]
    device = group_A.device
    group_out = torch.empty((M, N), device=device, dtype=group_A.dtype)

    for i in range(group_size):
        entry = offs_table[i]
        M_g, N_g, K_g = entry[0], entry[1], entry[2]
        A_g = group_A[entry[3]]
        B_g = group_B[entry[4]]
        C_g = group_C[entry[5]]
        # The output reuses the C index so results line up with group_C.
        out_g = group_out[entry[5]]
        group_sizes.extend([M_g, N_g, K_g])
        group_lds.extend([K_g, N_g, N_g])
        A_addrs.append(A_g.data_ptr())
        B_addrs.append(B_g.data_ptr())
        C_addrs.append(C_g.data_ptr())
        out_addrs.append(out_g.data_ptr())

    d_a_ptrs = torch.tensor(A_addrs, device=device)
    d_b_ptrs = torch.tensor(B_addrs, device=device)
    d_c_ptrs = torch.tensor(C_addrs, device=device)
    d_output_ptrs = torch.tensor(out_addrs, device=device)
    d_g_sizes = torch.tensor(group_sizes, dtype=torch.int32, device=device)
    d_g_lds = torch.tensor(group_lds, dtype=torch.int32, device=device)
    NUM_SMS = get_sm_count()

    use_tma = _support_device_tensor_descriptor and supports_tma()
    if use_tma:

        def alloc_fn(size, alignment, stream):
            return torch.empty(size, device=group_A.device, dtype=torch.int8)

        # Registered before launching the descriptor-based kernel.
        triton.set_allocator(alloc_fn)
        kernel = grouped_gemm_tma_kernel
    else:
        kernel = grouped_gemm_kernel

    # One program per SM count; both kernels take the same arguments.
    kernel[(NUM_SMS,)](
        M,
        N,
        K,
        d_a_ptrs,
        d_b_ptrs,
        d_c_ptrs,
        d_output_ptrs,
        d_g_sizes,
        d_g_lds,
        group_size,
        alpha=alpha,
        beta=beta,
    )

    return group_out

562 

563 

def group_mm(A: torch.Tensor, B: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
    """Multiply consecutive row ranges of ``A`` by the matching slice of ``B``.

    Args:
        A: 2-D tensor of shape ``(M, K)``.
        B: 3-D tensor of shape ``(num_groups, K, N)``.
        offs: tensor with ``num_groups`` elements; the kernels read ``offs[g]``
            as the end row of group ``g`` in ``A``.

    Returns:
        torch.Tensor: a new ``(M, N)`` tensor created with ``A.new_empty``.
        Rows not covered by any group are left as allocated.

    Raises:
        AssertionError: if the ranks, the group count, the reduction
            dimension, or the group stride of ``B`` do not match what the
            kernels assume.
    """
    assert A.dim() == 2
    assert B.dim() == 3
    M, K = A.shape

    num_groups, BK, N = B.shape
    strideBK, strideBN = B.stride(1), B.stride(2)

    assert num_groups == offs.numel()
    # The kernels step through B with A's K, so the two must agree; otherwise
    # they read the wrong rows of B (or past its end).
    assert BK == K, f"B.shape[1] ({BK}) must equal A.shape[1] ({K})"
    # B is addressed as a flat (num_groups * K, N) matrix with row stride
    # B.stride(1). That is only valid when each group starts exactly K rows
    # after the previous one.
    assert num_groups <= 1 or B.stride(0) == BK * strideBK, (
        "B must be laid out so that B.stride(0) == B.shape[1] * B.stride(1)"
    )

    NUM_SMS = get_sm_count()
    C = A.new_empty(M, N)
    if _support_host_tensor_descriptor and supports_tma():
        # Placeholder block shape; the tuner's pre-hook overwrites it with the
        # selected BLOCK_* sizes before launch.
        dummy_block = [1, 1]

        a_desc = TensorDescriptor(A, A.shape, A.stride(), dummy_block)
        b_desc = TensorDescriptor(
            B, [num_groups * K, N], [strideBK, strideBN], dummy_block
        )
        c_desc = TensorDescriptor(C, C.shape, C.stride(), dummy_block)

        grouped_mm_tma_kernel[(NUM_SMS,)](
            a_desc,
            b_desc,
            c_desc,
            C,
            offs,
            num_groups,
            M,
            N,
            K,
            C.stride(0),
            C.stride(1),
        )
    else:
        grouped_mm_kernel[(NUM_SMS,)](
            A,
            B,
            C,
            offs,
            num_groups,
            M,
            N,
            K,
            A.stride(0),
            A.stride(1),
            strideBK,
            strideBN,
            C.stride(0),
            C.stride(1),
        )

    return C