Coverage for src/flag_gems/ops/group_gemm.py: 14%

272 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import libentry, libtuner 

8 

9logger = logging.getLogger(__name__) 

10 

11 

def supports_tma():
    """Report whether the current CUDA device has compute capability major >= 9.

    Returns:
        bool: True when the major compute capability of the current CUDA
        device is at least 9; False otherwise, including when no CUDA
        device is available (instead of raising from the capability query).
    """
    # Querying the capability without a usable CUDA device raises, so
    # treat "no device" as "feature not supported".
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 9

14 

15 

# Device-side tensor descriptors are only usable when the installed
# triton.language exposes `make_tensor_descriptor`. Keep a handle to it
# (or None when it is missing) together with an availability flag.
make_tensor_descriptor_fn = getattr(tl, "make_tensor_descriptor", None)
_support_device_tensor_descriptor = hasattr(tl, "make_tensor_descriptor")

22 

# Host-side tensor descriptors live in an optional Triton module; record
# whether the import succeeded so callers can pick a fallback path.
try:
    from triton.tools.tensor_descriptor import TensorDescriptor
except ImportError:
    _support_host_tensor_descriptor = False
else:
    _support_host_tensor_descriptor = True

29 

30 

@triton.jit
def grouped_launch(
    pid, m, n, block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr
):
    """Map a linear tile id onto (row-tile, column-tile) indices.

    The tile grid of an (m, n) output is walked in bands of up to
    `group_m` row-tiles; every band covers all column-tiles before the
    next band starts.

    Args:
        pid: linear tile id within the (m, n) problem.
        m, n: problem extents along rows and columns.
        block_m, block_n: tile extents along rows and columns.
        group_m: maximum number of row-tiles per band.

    Returns:
        (pid_m, pid_n): row-tile and column-tile index for `pid`.
    """
    tiles_along_m = tl.cdiv(m, block_m)
    tiles_along_n = tl.cdiv(n, block_n)

    # A full band holds `group_m` row-tiles times every column-tile.
    tiles_per_band = group_m * tiles_along_n
    band = pid // tiles_per_band
    first_row_tile = band * group_m
    # The final band may contain fewer than `group_m` row-tiles.
    rows_in_band = tl.minimum(tiles_along_m - first_row_tile, group_m)

    pid_m = first_row_tile + (pid % rows_in_band)
    pid_n = (pid % tiles_per_band) // rows_in_band
    return pid_m, pid_n

45 

46 

def matmul_tma_set_block_size_hook(nargs):
    """Set the block shape of the A/B/C descriptors from the chosen tile sizes.

    Intended as an autotuner pre-hook: `nargs` maps argument names to
    values and must contain `BLOCK_M`, `BLOCK_N`, `BLOCK_K` plus the
    descriptor objects `a_desc`, `b_desc` and `c_desc`, whose
    `block_shape` attribute is overwritten in place.
    """
    block_m, block_n, block_k = (
        nargs[key] for key in ("BLOCK_M", "BLOCK_N", "BLOCK_K")
    )
    shapes_by_descriptor = (
        ("a_desc", [block_m, block_k]),
        ("b_desc", [block_k, block_n]),
        ("c_desc", [block_m, block_n]),
    )
    for descriptor_name, block_shape in shapes_by_descriptor:
        nargs[descriptor_name].block_shape = block_shape

54 

55 

def get_autotune_config(pre_hook=None):
    """Build the list of `triton.Config` candidates used for autotuning.

    Args:
        pre_hook: optional callable attached to every config as its
            `pre_hook`.

    Returns:
        list: one `triton.Config` per entry of the search space below.
    """
    # Columns: BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M, num_stages, num_warps
    search_space = (
        (128, 256, 64, 8, 3, 8),
        (128, 128, 128, 8, 2, 4),
        (256, 128, 128, 8, 3, 4),
        (256, 128, 64, 8, 3, 8),
        (64, 256, 32, 4, 4, 4),
        (128, 64, 32, 4, 4, 4),
        (256, 256, 64, 8, 3, 8),
    )
    return [
        triton.Config(
            {
                "BLOCK_M": block_m,
                "BLOCK_N": block_n,
                "BLOCK_K": block_k,
                "GROUP_M": group_m,
            },
            num_stages=stages,
            num_warps=warps,
            pre_hook=pre_hook,
        )
        for block_m, block_n, block_k, group_m, stages, warps in search_space
    ]

101 

102 

@libentry()
@libtuner(configs=get_autotune_config(), key=["M", "N", "K"])
@triton.jit
def grouped_gemm_tma_kernel(
    M,
    N,
    K,
    group_a_ptrs,
    group_b_ptrs,
    group_c_ptrs,
    group_out_ptrs,
    group_gemm_sizes,
    g_lds,
    group_size,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    alpha: tl.constexpr,
    beta: tl.constexpr,
):
    """Grouped GEMM using device-side tensor descriptors.

    For every problem g this computes `out_g = C_g * beta + (A_g @ B_g) * alpha`.
    Output tiles of all problems are numbered consecutively; each program
    starts at its own program id and advances by the number of launched
    programs until the tiles are exhausted.

    Args:
        M, N, K: not read in the body; they serve as the autotune key.
        group_a_ptrs, group_b_ptrs, group_c_ptrs, group_out_ptrs: arrays
            holding one address per problem; each address is
            reinterpreted as a pointer to bfloat16.
        group_gemm_sizes: flat array, three entries (m, n, k) per problem.
        g_lds: flat array, three entries (lda, ldb, ldc) per problem; the
            innermost stride of every operand is taken to be 1.
        group_size: number of problems.
        BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M: tiling parameters.
        alpha, beta: compile-time scaling factors.
    """
    tile_idx = tl.program_id(0)
    num_programs = tl.num_programs(0)
    problem_start = 0
    for g in range(group_size):
        # Problem g's extents live at group_gemm_sizes[3 * g : 3 * g + 3].
        sizes_ptr = group_gemm_sizes + g * 3
        gm = tl.load(sizes_ptr)
        gn = tl.load(sizes_ptr + 1)
        gk = tl.load(sizes_ptr + 2)

        problem_end = problem_start + tl.cdiv(gm, BLOCK_M) * tl.cdiv(gn, BLOCK_N)
        if tile_idx >= problem_start and tile_idx < problem_end:
            lds_ptr = g_lds + g * 3
            lda = tl.load(lds_ptr)
            ldb = tl.load(lds_ptr + 1)
            ldc = tl.load(lds_ptr + 2)

            a_base = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.bfloat16))
            b_base = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.bfloat16))
            c_base = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.bfloat16))
            out_base = tl.load(group_out_ptrs + g).to(tl.pointer_type(tl.bfloat16))

            a_desc = make_tensor_descriptor_fn(
                a_base,
                shape=[gm, gk],
                strides=[lda, 1],
                block_shape=[BLOCK_M, BLOCK_K],
            )
            b_desc = make_tensor_descriptor_fn(
                b_base,
                shape=[gk, gn],
                strides=[ldb, 1],
                block_shape=[BLOCK_K, BLOCK_N],
            )
            c_desc = make_tensor_descriptor_fn(
                c_base,
                shape=[gm, gn],
                strides=[ldc, 1],
                block_shape=[BLOCK_M, BLOCK_N],
            )
            # The output reuses C's leading dimension.
            out_desc = make_tensor_descriptor_fn(
                out_base,
                shape=[gm, gn],
                strides=[ldc, 1],
                block_shape=[BLOCK_M, BLOCK_N],
            )

            # Number of tiles of this problem handled by this program.
            tiles_for_program = tl.cdiv(problem_end - tile_idx, num_programs)
            for _ in tl.range(tiles_for_program):
                pid_m, pid_n = grouped_launch(
                    tile_idx - problem_start, gm, gn, BLOCK_M, BLOCK_N, GROUP_M
                )
                row_off = pid_m * BLOCK_M
                col_off = pid_n * BLOCK_N

                acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                for kk in range(0, tl.cdiv(gk, BLOCK_K)):
                    k_off = kk * BLOCK_K
                    a_tile = a_desc.load([row_off, k_off])
                    b_tile = b_desc.load([k_off, col_off])
                    acc = tl.dot(a_tile, b_tile, acc=acc, allow_tf32=False)

                c_in = c_desc.load([row_off, col_off])
                acc = c_in * beta + acc * alpha

                out_desc.store([row_off, col_off], acc.to(c_desc.dtype))

                tile_idx += num_programs

        problem_start = problem_end

201 

202 

@libentry()
@libtuner(configs=get_autotune_config(), key=["M", "N", "K"])
@triton.jit
def grouped_gemm_kernel(
    M,
    N,
    K,
    group_a_ptrs,
    group_b_ptrs,
    group_c_ptrs,
    group_out_ptrs,
    group_gemm_sizes,
    g_lds,
    group_size,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    alpha: tl.constexpr,
    beta: tl.constexpr,
):
    """Grouped GEMM using block pointers (no tensor descriptors).

    For every problem g this computes `out_g = C_g * beta + (A_g @ B_g) * alpha`.
    Output tiles of all problems are numbered consecutively; each program
    starts at its own program id and advances by the number of launched
    programs until the tiles are exhausted.

    Args:
        M, N, K: not read in the body; they serve as the autotune key.
        group_a_ptrs, group_b_ptrs, group_c_ptrs, group_out_ptrs: arrays
            holding one address per problem; each address is
            reinterpreted as a pointer to bfloat16.
        group_gemm_sizes: flat array, three entries (m, n, k) per problem.
        g_lds: flat array, three entries (lda, ldb, ldc) per problem; the
            innermost stride of every operand is taken to be 1.
        group_size: number of problems.
        BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M: tiling parameters.
        alpha, beta: compile-time scaling factors.
    """
    tile_idx = tl.program_id(0)
    num_programs = tl.num_programs(0)
    problem_start = 0
    for g in range(group_size):
        # Problem g's extents live at group_gemm_sizes[3 * g : 3 * g + 3].
        sizes_ptr = group_gemm_sizes + g * 3
        gm = tl.load(sizes_ptr)
        gn = tl.load(sizes_ptr + 1)
        gk = tl.load(sizes_ptr + 2)

        problem_end = problem_start + tl.cdiv(gm, BLOCK_M) * tl.cdiv(gn, BLOCK_N)
        if tile_idx >= problem_start and tile_idx < problem_end:
            lds_ptr = g_lds + g * 3
            lda = tl.load(lds_ptr)
            ldb = tl.load(lds_ptr + 1)
            ldc = tl.load(lds_ptr + 2)

            a_base = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.bfloat16))
            b_base = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.bfloat16))
            c_base = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.bfloat16))
            out_base = tl.load(group_out_ptrs + g).to(tl.pointer_type(tl.bfloat16))

            # Number of tiles of this problem handled by this program.
            tiles_for_program = tl.cdiv(problem_end - tile_idx, num_programs)
            for _ in tl.range(tiles_for_program):
                pid_m, pid_n = grouped_launch(
                    tile_idx - problem_start, gm, gn, BLOCK_M, BLOCK_N, GROUP_M
                )
                row_off = pid_m * BLOCK_M
                col_off = pid_n * BLOCK_N

                a_blk = tl.make_block_ptr(
                    base=a_base,
                    shape=(gm, gk),
                    strides=(lda, 1),
                    offsets=(row_off, 0),
                    block_shape=(BLOCK_M, BLOCK_K),
                    order=(1, 0),
                )
                b_blk = tl.make_block_ptr(
                    base=b_base,
                    shape=(gk, gn),
                    strides=(ldb, 1),
                    offsets=(0, col_off),
                    block_shape=(BLOCK_K, BLOCK_N),
                    order=(1, 0),
                )

                acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                for kk in range(0, tl.cdiv(gk, BLOCK_K)):
                    # Out-of-range elements feed the dot product, so they
                    # must be an explicit zero rather than an unspecified
                    # value (the default when no padding option is given).
                    a_tile = tl.load(
                        a_blk, boundary_check=(0, 1), padding_option="zero"
                    )
                    b_tile = tl.load(
                        b_blk, boundary_check=(0, 1), padding_option="zero"
                    )
                    acc = tl.dot(a_tile, b_tile, acc=acc, allow_tf32=False)
                    a_blk = tl.advance(a_blk, (0, BLOCK_K))
                    b_blk = tl.advance(b_blk, (BLOCK_K, 0))

                c_blk = tl.make_block_ptr(
                    base=c_base,
                    shape=(gm, gn),
                    strides=(ldc, 1),
                    offsets=(row_off, col_off),
                    block_shape=(BLOCK_M, BLOCK_N),
                    order=(1, 0),
                )
                # The output reuses C's leading dimension.
                out_blk = tl.make_block_ptr(
                    base=out_base,
                    shape=(gm, gn),
                    strides=(ldc, 1),
                    offsets=(row_off, col_off),
                    block_shape=(BLOCK_M, BLOCK_N),
                    order=(1, 0),
                )

                # Out-of-range elements of C are dropped again by the
                # boundary-checked store below, so no padding is needed.
                c_in = tl.load(c_blk, boundary_check=(0, 1))
                acc = c_in * beta + acc * alpha

                tl.store(
                    out_blk,
                    acc.to(c_blk.dtype.element_ty),
                    boundary_check=(0, 1),
                )

                tile_idx += num_programs

        problem_start = problem_end

309 

310 

@libentry()
@libtuner(
    configs=get_autotune_config(matmul_tma_set_block_size_hook), key=["M", "N", "K"]
)
@triton.jit
def grouped_mm_tma_kernel(
    a_desc,
    b_desc,
    c_desc,
    C,
    offs,
    num_groups: tl.constexpr,
    M,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_cm: tl.constexpr,
    stride_cn: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """Grouped matmul over row-partitioned A using host-built descriptors.

    Group g owns the rows of A between the previous entry of `offs`
    (0 for the first group) and `offs[g]`, and multiplies them with rows
    `[g * K, (g + 1) * K)` of the matrix behind `b_desc`. Tiles of all
    groups are numbered consecutively; each program starts at its own
    program id and advances by the number of launched programs.

    Args:
        a_desc, b_desc, c_desc: tensor descriptors for A, B and C.
        C: raw pointer to the output, used for masked stores of tiles
            that cross the end of their group.
        offs: per-group end row (exclusive), read as int32.
        num_groups: number of groups.
        M: not read in the body; part of the autotune key.
        N, K: output columns and reduction length.
        stride_cm, stride_cn: strides of C.
        BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M: tiling parameters.
    """
    num_programs = tl.num_programs(axis=0)
    tile_idx = tl.program_id(axis=0)
    num_n_tiles = tl.cdiv(N, BLOCK_N)
    problem_start = 0
    row_begin = 0
    row_end = 0

    for group_idx in tl.range(num_groups):
        row_end = tl.load(offs + group_idx).to(tl.int32)
        m = row_end - row_begin

        problem_end = problem_start + tl.cdiv(m, BLOCK_M) * num_n_tiles
        if tile_idx >= problem_start and tile_idx < problem_end:
            # Number of tiles of this group handled by this program.
            tiles_for_program = tl.cdiv(problem_end - tile_idx, num_programs)
            for _ in tl.range(tiles_for_program):
                pid_m, pid_n = grouped_launch(
                    tile_idx - problem_start, m, N, BLOCK_M, BLOCK_N, GROUP_M
                )

                offs_am = row_begin + pid_m * BLOCK_M
                offs_bn = pid_n * BLOCK_N
                offs_bk = group_idx * K

                acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                for k in tl.range(0, tl.cdiv(K, BLOCK_K)):
                    a_tile = a_desc.load([offs_am, k * BLOCK_K])
                    b_tile = b_desc.load([offs_bk + k * BLOCK_K, offs_bn])
                    acc = tl.dot(a_tile, b_tile, acc=acc, allow_tf32=False)

                c_tile = acc.to(c_desc.dtype)

                if offs_am + BLOCK_M > row_end:
                    # The tile runs past this group's last row: mask the
                    # store so rows of the next group are not overwritten.
                    offs_cm = offs_am + tl.arange(0, BLOCK_M)
                    offs_cn = offs_bn + tl.arange(0, BLOCK_N)
                    c_ptrs = (
                        C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
                    )
                    c_mask = (offs_cm[:, None] < row_end) & (offs_cn[None, :] < N)
                    tl.store(c_ptrs, c_tile, mask=c_mask)
                else:
                    c_desc.store([offs_am, offs_bn], c_tile)

                tile_idx += num_programs

        problem_start = problem_end
        row_begin = row_end

383 

384 

@libentry()
@libtuner(configs=get_autotune_config(), key=["M", "N", "K"])
@triton.jit
def grouped_mm_kernel(
    A,
    B,
    C,
    offs,
    num_groups: tl.constexpr,
    M,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am: tl.constexpr,
    stride_ak: tl.constexpr,
    stride_bk: tl.constexpr,
    stride_bn: tl.constexpr,
    stride_cm: tl.constexpr,
    stride_cn: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """Grouped matmul over row-partitioned A using block pointers.

    Group g owns the rows of A between the previous entry of `offs`
    (0 for the first group) and `offs[g]`, and multiplies them with rows
    `[g * K, (g + 1) * K)` of B, which is addressed as a flat
    `(num_groups * K, N)` matrix. Tiles of all groups are numbered
    consecutively; each program starts at its own program id and advances
    by the number of launched programs.

    Args:
        A, B, C: pointers to the operands and the output.
        offs: per-group end row (exclusive), read as int32.
        num_groups: number of groups.
        M: total number of rows of A and C.
        N, K: output columns and reduction length.
        stride_am, stride_ak: strides of A.
        stride_bk, stride_bn: strides of the flattened B.
        stride_cm, stride_cn: strides of C.
        BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M: tiling parameters.
    """
    num_programs = tl.num_programs(axis=0)
    tile_idx = tl.program_id(axis=0)
    num_n_tiles = tl.cdiv(N, BLOCK_N)
    problem_start = 0
    row_begin = 0
    row_end = 0

    for group_idx in tl.range(num_groups):
        row_end = tl.load(offs + group_idx).to(tl.int32)
        m = row_end - row_begin

        problem_end = problem_start + tl.cdiv(m, BLOCK_M) * num_n_tiles
        if tile_idx >= problem_start and tile_idx < problem_end:
            # Number of tiles of this group handled by this program.
            tiles_for_program = tl.cdiv(problem_end - tile_idx, num_programs)
            for _ in tl.range(tiles_for_program):
                pid_m, pid_n = grouped_launch(
                    tile_idx - problem_start, m, N, BLOCK_M, BLOCK_N, GROUP_M
                )

                offs_am = row_begin + pid_m * BLOCK_M
                offs_bn = pid_n * BLOCK_N
                offs_bk = group_idx * K

                a_block_ptr = tl.make_block_ptr(
                    base=A,
                    shape=(M, K),
                    strides=(stride_am, stride_ak),
                    offsets=(offs_am, 0),
                    block_shape=(BLOCK_M, BLOCK_K),
                    order=(1, 0),
                )
                b_block_ptr = tl.make_block_ptr(
                    base=B,
                    shape=(num_groups * K, N),
                    strides=(stride_bk, stride_bn),
                    offsets=(offs_bk, offs_bn),
                    block_shape=(BLOCK_K, BLOCK_N),
                    order=(1, 0),
                )

                acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                for k in tl.range(0, tl.cdiv(K, BLOCK_K)):
                    # Out-of-range elements feed the dot product, so they
                    # must be an explicit zero rather than an unspecified
                    # value (the default when no padding option is given).
                    # Zero columns of A past K also cancel the rows of the
                    # next group that the B tile reads when K is not a
                    # multiple of BLOCK_K.
                    a_tile = tl.load(
                        a_block_ptr, boundary_check=(0, 1), padding_option="zero"
                    )
                    b_tile = tl.load(
                        b_block_ptr, boundary_check=(0, 1), padding_option="zero"
                    )
                    acc = tl.dot(a_tile, b_tile, acc=acc, allow_tf32=False)

                    a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_K))
                    b_block_ptr = tl.advance(b_block_ptr, (BLOCK_K, 0))

                c_tile = acc.to(C.dtype.element_ty)

                if offs_am + BLOCK_M > row_end:
                    # The tile runs past this group's last row: mask the
                    # store so rows of the next group are not overwritten.
                    offs_cm = offs_am + tl.arange(0, BLOCK_M)
                    offs_cn = offs_bn + tl.arange(0, BLOCK_N)
                    c_ptrs = (
                        C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
                    )
                    c_mask = (offs_cm[:, None] < row_end) & (offs_cn[None, :] < N)
                    tl.store(c_ptrs, c_tile, mask=c_mask)
                else:
                    c_block_ptr = tl.make_block_ptr(
                        base=C,
                        shape=(M, N),
                        strides=(stride_cm, stride_cn),
                        offsets=(offs_am, offs_bn),
                        block_shape=(BLOCK_M, BLOCK_N),
                        order=(1, 0),
                    )
                    tl.store(c_block_ptr, c_tile, boundary_check=(0, 1))

                tile_idx += num_programs

        problem_start = problem_end
        row_begin = row_end

488 

489 

def group_gemm(group_A, group_B, group_C, offs_table, alpha=1, beta=0):
    """Run a batch of independent GEMMs in a single kernel launch.

    For each entry of `offs_table` the launched kernel computes
    `out_g = C_g * beta + (A_g @ B_g) * alpha`.

    Args:
        group_A, group_B, group_C: bfloat16 tensors holding all operands.
            `group_C` must be 2-D; `group_A.shape[1]` is used as K for
            the autotune key.
        offs_table: sequence with one entry per problem. Each entry is
            read as `(M_g, N_g, K_g, a_index, b_index, c_index)`, where
            the three indices select the start of the problem's operands
            along the first dimension of `group_A`, `group_B` and
            `group_C`. Leading dimensions are taken to be `K_g` for A and
            `N_g` for B, C and the output, i.e. every sub-matrix is
            assumed to be densely packed in row-major order.
        alpha: scale applied to `A_g @ B_g`.
        beta: scale applied to `C_g`.

    Returns:
        torch.Tensor: a new tensor with the shape of `group_C` and the
        dtype of `group_A`. Regions not covered by any problem are left
        uninitialized (all of it when `offs_table` is empty).

    Raises:
        TypeError: if any operand is not bfloat16. The kernels
            reinterpret every operand address as a bfloat16 pointer, so
            other dtypes would silently produce garbage.
    """
    for arg_name, tensor in (
        ("group_A", group_A),
        ("group_B", group_B),
        ("group_C", group_C),
    ):
        if tensor.dtype != torch.bfloat16:
            raise TypeError(
                f"group_gemm only supports torch.bfloat16 operands, "
                f"but {arg_name} has dtype {tensor.dtype}"
            )

    device = group_A.device
    M, N = group_C.shape
    K = group_A.shape[1]
    group_out = torch.empty((M, N), device=device, dtype=group_A.dtype)

    group_size = len(offs_table)
    if group_size == 0:
        # Nothing to compute; avoid launching with empty pointer tables.
        return group_out

    A_addrs = []
    B_addrs = []
    C_addrs = []
    out_addrs = []
    group_sizes = []
    group_lds = []
    for entry in offs_table:
        M_g, N_g, K_g = entry[0], entry[1], entry[2]
        group_sizes += [M_g, N_g, K_g]
        # lda, ldb, ldc for densely packed row-major sub-matrices.
        group_lds += [K_g, N_g, N_g]
        A_addrs.append(group_A[entry[3]].data_ptr())
        B_addrs.append(group_B[entry[4]].data_ptr())
        C_addrs.append(group_C[entry[5]].data_ptr())
        # The output shares C's layout, hence the same index.
        out_addrs.append(group_out[entry[5]].data_ptr())

    # Addresses are 64-bit; state the dtype instead of relying on inference.
    d_a_ptrs = torch.tensor(A_addrs, dtype=torch.int64, device=device)
    d_b_ptrs = torch.tensor(B_addrs, dtype=torch.int64, device=device)
    d_c_ptrs = torch.tensor(C_addrs, dtype=torch.int64, device=device)
    d_output_ptrs = torch.tensor(out_addrs, dtype=torch.int64, device=device)
    d_g_sizes = torch.tensor(group_sizes, dtype=torch.int32, device=device)
    d_g_lds = torch.tensor(group_lds, dtype=torch.int32, device=device)
    # One program per multiprocessor; each program loops over several tiles.
    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count

    if _support_device_tensor_descriptor and supports_tma():

        def alloc_fn(size, alignment, stream):
            return torch.empty(size, device=group_A.device, dtype=torch.int8)

        # The descriptor-based kernel is launched with this allocator
        # registered, as in the original code path.
        triton.set_allocator(alloc_fn)
        kernel = grouped_gemm_tma_kernel
    else:
        kernel = grouped_gemm_kernel

    kernel[(NUM_SMS,)](
        M,
        N,
        K,
        d_a_ptrs,
        d_b_ptrs,
        d_c_ptrs,
        d_output_ptrs,
        d_g_sizes,
        d_g_lds,
        group_size,
        alpha=alpha,
        beta=beta,
    )

    return group_out

561 

562 

def group_mm(A: torch.Tensor, B: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
    """Multiply row groups of `A` with the matching matrices of `B`.

    Group g consists of the rows of `A` from the previous entry of `offs`
    (0 for the first group) up to `offs[g]` (exclusive); those rows are
    multiplied with `B[g]`.

    Args:
        A: 2-D tensor of shape (M, K).
        B: 3-D tensor of shape (num_groups, K, N).
        offs: tensor with `num_groups` entries giving the end row of
            every group.

    Returns:
        torch.Tensor: new tensor of shape (M, N) created with
        `A.new_empty`. Rows at or beyond the last entry of `offs` are not
        written and stay uninitialized.

    Raises:
        AssertionError: if the ranks of `A`/`B` or the length of `offs`
            do not match the description above.
        ValueError: if the K dimensions of `A` and `B` differ.
    """
    assert A.dim() == 2
    assert B.dim() == 3
    M, K = A.shape

    num_groups, BK, N = B.shape
    if BK != K:
        raise ValueError(
            f"K dimension mismatch: A has shape {tuple(A.shape)} "
            f"but B has shape {tuple(B.shape)}"
        )
    assert num_groups == offs.numel()

    # Both kernels address B as a flat (num_groups * K, N) matrix with a
    # single row stride. That is only valid when stepping to the next
    # group equals stepping K rows; otherwise (e.g. a transposed view)
    # fall back to a contiguous copy.
    if num_groups > 1 and B.stride(0) != K * B.stride(1):
        B = B.contiguous()
    strideBK, strideBN = B.stride(1), B.stride(2)

    # One program per multiprocessor; each program loops over several tiles.
    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
    C = A.new_empty(M, N)
    if _support_host_tensor_descriptor and supports_tma():
        # Placeholder block shape; the autotuner pre-hook overwrites it
        # with the tile sizes of the selected config before launch.
        dummy_block = [1, 1]

        a_desc = TensorDescriptor(A, A.shape, A.stride(), dummy_block)
        b_desc = TensorDescriptor(
            B, [num_groups * K, N], [strideBK, strideBN], dummy_block
        )
        c_desc = TensorDescriptor(C, C.shape, C.stride(), dummy_block)

        grouped_mm_tma_kernel[(NUM_SMS,)](
            a_desc,
            b_desc,
            c_desc,
            C,
            offs,
            num_groups,
            M,
            N,
            K,
            C.stride(0),
            C.stride(1),
        )
    else:
        grouped_mm_kernel[(NUM_SMS,)](
            A,
            B,
            C,
            offs,
            num_groups,
            M,
            N,
            K,
            A.stride(0),
            A.stride(1),
            strideBK,
            strideBN,
            C.stride(0),
            C.stride(1),
        )

    return C