Coverage for src/flag_gems/ops/group_gemm.py: 14%
272 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import libentry, libtuner
# Module-level logger for this file, named after the module.
logger = logging.getLogger(name=__name__)
def supports_tma():
    """Return True if the current CUDA device has compute capability 9.x or newer.

    Used as a gate for the TMA (tensor-descriptor) code paths. The query is
    made with no explicit device, i.e. against the *current* CUDA device.
    """
    major, _minor = torch.cuda.get_device_capability()
    return major >= 9
# Device-side tensor descriptors (created inside a kernel) need a Triton build
# that exposes ``tl.make_tensor_descriptor``; kernels call it through
# ``make_tensor_descriptor_fn``, which is None when unsupported.
_support_device_tensor_descriptor = hasattr(tl, "make_tensor_descriptor")
make_tensor_descriptor_fn = (
    tl.make_tensor_descriptor if _support_device_tensor_descriptor else None
)

# Host-side tensor descriptors (built in Python and passed to a kernel) come
# from ``triton.tools.tensor_descriptor`` when that module is available.
try:
    from triton.tools.tensor_descriptor import TensorDescriptor
except ImportError:
    _support_host_tensor_descriptor = False
else:
    _support_host_tensor_descriptor = True
@triton.jit
def grouped_launch(
    pid, m, n, block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr
):
    """Map a linear tile id ``pid`` to a ``(row_tile, col_tile)`` pair.

    Row-tiles are processed in bands of ``group_m``. Inside a band the row
    index varies fastest, so consecutive ``pid`` values share the same
    column-tile index. The final band may hold fewer than ``group_m``
    row-tiles when the row-tile count is not a multiple of it.
    """
    num_pid_m = tl.cdiv(m, block_m)
    num_pid_n = tl.cdiv(n, block_n)

    tiles_per_band = group_m * num_pid_n
    band = pid // tiles_per_band
    first_row_tile = band * group_m
    # Height of this band: group_m, except possibly for the last band.
    band_height = tl.minimum(num_pid_m - first_row_tile, group_m)

    pid_m = first_row_tile + (pid % band_height)
    pid_n = (pid % tiles_per_band) // band_height
    return pid_m, pid_n
def matmul_tma_set_block_size_hook(nargs):
    """Autotune pre-hook that sizes the host TMA descriptors to the chosen tile.

    ``nargs`` is the kernel-argument mapping. The selected ``BLOCK_M``,
    ``BLOCK_N`` and ``BLOCK_K`` are read from it, and ``a_desc``, ``b_desc`` and
    ``c_desc`` get a matching ``block_shape`` (``[M, K]``, ``[K, N]`` and
    ``[M, N]`` respectively). The descriptors are otherwise built with a
    placeholder block shape by the caller.
    """
    bm, bn, bk = nargs["BLOCK_M"], nargs["BLOCK_N"], nargs["BLOCK_K"]
    tile_shapes = {
        "a_desc": [bm, bk],
        "b_desc": [bk, bn],
        "c_desc": [bm, bn],
    }
    for desc_name, shape in tile_shapes.items():
        nargs[desc_name].block_shape = shape
def get_autotune_config(pre_hook=None):
    """Return the autotuning search space shared by the grouped GEMM kernels.

    Each candidate is ``(BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M, num_stages,
    num_warps)``. ``pre_hook`` is attached unchanged to every config.
    """
    candidates = (
        # BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M, stages, warps
        (128, 256, 64, 8, 3, 8),
        (128, 128, 128, 8, 2, 4),
        (256, 128, 128, 8, 3, 4),
        (256, 128, 64, 8, 3, 8),
        (64, 256, 32, 4, 4, 4),
        (128, 64, 32, 4, 4, 4),
        (256, 256, 64, 8, 3, 8),
    )
    return [
        triton.Config(
            {"BLOCK_M": bm, "BLOCK_N": bn, "BLOCK_K": bk, "GROUP_M": gm},
            num_stages=stages,
            num_warps=warps,
            pre_hook=pre_hook,
        )
        for bm, bn, bk, gm, stages, warps in candidates
    ]
@libentry()
@libtuner(configs=get_autotune_config(), key=["M", "N", "K"])
@triton.jit
def grouped_gemm_tma_kernel(
    M,
    N,
    K,
    group_a_ptrs,
    group_b_ptrs,
    group_c_ptrs,
    group_out_ptrs,
    group_gemm_sizes,
    g_lds,
    group_size,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    alpha: tl.constexpr,
    beta: tl.constexpr,
):
    """Persistent grouped GEMM ``out_g = alpha * A_g @ B_g + beta * C_g`` using
    device-side tensor descriptors.

    Per group ``g``:
    - ``group_gemm_sizes[3g:3g+3]`` holds ``(m, n, k)``.
    - ``g_lds[3g:3g+3]`` holds the row strides ``(lda, ldb, ldc)``.
    - ``group_*_ptrs[g]`` hold the raw operand addresses.

    The tiles of all groups form one global index space. Each program starts
    at its program id and advances by ``num_programs`` tiles. ``M``, ``N`` and
    ``K`` are only used as autotuning keys.

    NOTE: every operand address is reinterpreted as ``bfloat16``.
    """
    tile_idx = tl.program_id(0)
    total_grid = tl.num_programs(0)
    last_problem_end = 0
    for g in range(group_size):
        gm = tl.load(group_gemm_sizes + g * 3)
        gn = tl.load(group_gemm_sizes + g * 3 + 1)
        gk = tl.load(group_gemm_sizes + g * 3 + 2)
        num_m_tiles = tl.cdiv(gm, BLOCK_M)
        num_n_tiles = tl.cdiv(gn, BLOCK_N)
        num_tiles = num_m_tiles * num_n_tiles

        # Global tile ids [last_problem_end, current_problem_end) belong to group g.
        current_problem_end = last_problem_end + num_tiles
        if tile_idx >= last_problem_end and tile_idx < current_problem_end:
            lda = tl.load(g_lds + g * 3)
            ldb = tl.load(g_lds + g * 3 + 1)
            ldc = tl.load(g_lds + g * 3 + 2)

            a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.bfloat16))
            b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.bfloat16))
            out_ptr = tl.load(group_out_ptrs + g).to(tl.pointer_type(tl.bfloat16))

            a_desc = make_tensor_descriptor_fn(
                a_ptr,
                shape=[gm, gk],
                strides=[lda, 1],
                block_shape=[BLOCK_M, BLOCK_K],
            )
            b_desc = make_tensor_descriptor_fn(
                b_ptr,
                shape=[gk, gn],
                strides=[ldb, 1],
                block_shape=[BLOCK_K, BLOCK_N],
            )
            out_desc = make_tensor_descriptor_fn(
                out_ptr,
                shape=[gm, gn],
                strides=[ldc, 1],
                block_shape=[BLOCK_M, BLOCK_N],
            )
            # ``beta`` is a compile-time constant. When it is 0, C is never
            # touched (no descriptor, no loads). This matches BLAS / torch.addmm
            # semantics, where NaN/Inf in C must not leak into the result.
            if beta != 0:
                c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.bfloat16))
                c_desc = make_tensor_descriptor_fn(
                    c_ptr,
                    shape=[gm, gn],
                    strides=[ldc, 1],
                    block_shape=[BLOCK_M, BLOCK_N],
                )

            # Number of this program's tiles that fall inside group g.
            loop_count = (current_problem_end - tile_idx + total_grid - 1) // total_grid
            for _ in tl.range(loop_count):
                tile_idx_in_gemm = tile_idx - last_problem_end
                tile_m_idx, tile_n_idx = grouped_launch(
                    tile_idx_in_gemm, gm, gn, BLOCK_M, BLOCK_N, GROUP_M
                )
                offs_m = tile_m_idx * BLOCK_M
                offs_n = tile_n_idx * BLOCK_N

                accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                for kk in range(0, tl.cdiv(gk, BLOCK_K)):
                    a = a_desc.load([offs_m, kk * BLOCK_K])
                    b = b_desc.load([kk * BLOCK_K, offs_n])
                    accumulator = tl.dot(a, b, acc=accumulator, allow_tf32=False)

                accumulator = accumulator * alpha
                if beta != 0:
                    ori_c = c_desc.load([offs_m, offs_n])
                    # Scale C in fp32 rather than bf16 to avoid an extra rounding.
                    accumulator += ori_c.to(tl.float32) * beta

                out_desc.store([offs_m, offs_n], accumulator.to(out_desc.dtype))

                tile_idx += total_grid

        last_problem_end = current_problem_end
@libentry()
@libtuner(configs=get_autotune_config(), key=["M", "N", "K"])
@triton.jit
def grouped_gemm_kernel(
    M,
    N,
    K,
    group_a_ptrs,
    group_b_ptrs,
    group_c_ptrs,
    group_out_ptrs,
    group_gemm_sizes,
    g_lds,
    group_size,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    alpha: tl.constexpr,
    beta: tl.constexpr,
):
    """Persistent grouped GEMM ``out_g = alpha * A_g @ B_g + beta * C_g`` using
    block pointers. This is the fallback when device-side tensor descriptors
    or TMA are unavailable.

    Per group ``g``:
    - ``group_gemm_sizes[3g:3g+3]`` holds ``(m, n, k)``.
    - ``g_lds[3g:3g+3]`` holds the row strides ``(lda, ldb, ldc)``.
    - ``group_*_ptrs[g]`` hold the raw operand addresses.

    The tiles of all groups form one global index space. Each program starts
    at its program id and advances by ``num_programs`` tiles. ``M``, ``N`` and
    ``K`` are only used as autotuning keys.

    NOTE: every operand address is reinterpreted as ``bfloat16``.
    """
    tile_idx = tl.program_id(0)
    total_grid = tl.num_programs(0)
    last_problem_end = 0
    for g in range(group_size):
        gm = tl.load(group_gemm_sizes + g * 3)
        gn = tl.load(group_gemm_sizes + g * 3 + 1)
        gk = tl.load(group_gemm_sizes + g * 3 + 2)
        num_m_tiles = tl.cdiv(gm, BLOCK_M)
        num_n_tiles = tl.cdiv(gn, BLOCK_N)
        num_tiles = num_m_tiles * num_n_tiles

        # Global tile ids [last_problem_end, current_problem_end) belong to group g.
        current_problem_end = last_problem_end + num_tiles
        if tile_idx >= last_problem_end and tile_idx < current_problem_end:
            lda = tl.load(g_lds + g * 3)
            ldb = tl.load(g_lds + g * 3 + 1)
            ldc = tl.load(g_lds + g * 3 + 2)

            a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.bfloat16))
            b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.bfloat16))
            out_ptr = tl.load(group_out_ptrs + g).to(tl.pointer_type(tl.bfloat16))

            # Number of this program's tiles that fall inside group g.
            loop_count = (current_problem_end - tile_idx + total_grid - 1) // total_grid
            for _ in tl.range(loop_count):
                tile_idx_in_gemm = tile_idx - last_problem_end
                tile_m_idx, tile_n_idx = grouped_launch(
                    tile_idx_in_gemm, gm, gn, BLOCK_M, BLOCK_N, GROUP_M
                )
                offs_m = tile_m_idx * BLOCK_M
                offs_n = tile_n_idx * BLOCK_N

                a_ptrs = tl.make_block_ptr(
                    base=a_ptr,
                    shape=(gm, gk),
                    strides=(lda, 1),
                    offsets=(offs_m, 0),
                    block_shape=(BLOCK_M, BLOCK_K),
                    order=(1, 0),
                )
                b_ptrs = tl.make_block_ptr(
                    base=b_ptr,
                    shape=(gk, gn),
                    strides=(ldb, 1),
                    offsets=(0, offs_n),
                    block_shape=(BLOCK_K, BLOCK_N),
                    order=(1, 0),
                )

                accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                for kk in range(0, tl.cdiv(gk, BLOCK_K)):
                    # Without padding_option, out-of-bounds elements of a
                    # block-pointer load are undefined. Zero padding keeps the
                    # K tail (gk % BLOCK_K != 0) from corrupting the accumulator.
                    a = tl.load(a_ptrs, boundary_check=(0, 1), padding_option="zero")
                    b = tl.load(b_ptrs, boundary_check=(0, 1), padding_option="zero")
                    accumulator = tl.dot(a, b, acc=accumulator, allow_tf32=False)
                    a_ptrs = tl.advance(a_ptrs, (0, BLOCK_K))
                    b_ptrs = tl.advance(b_ptrs, (BLOCK_K, 0))

                accumulator = accumulator * alpha
                # ``beta`` is a compile-time constant. When it is 0, C is never
                # read, matching BLAS / torch.addmm semantics (no NaN/Inf leak).
                if beta != 0:
                    c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.bfloat16))
                    c_ptrs = tl.make_block_ptr(
                        base=c_ptr,
                        shape=(gm, gn),
                        strides=(ldc, 1),
                        offsets=(offs_m, offs_n),
                        block_shape=(BLOCK_M, BLOCK_N),
                        order=(1, 0),
                    )
                    ori_c = tl.load(c_ptrs, boundary_check=(0, 1))
                    # Scale C in fp32 rather than bf16 to avoid an extra rounding.
                    accumulator += ori_c.to(tl.float32) * beta

                out_ptrs = tl.make_block_ptr(
                    base=out_ptr,
                    shape=(gm, gn),
                    strides=(ldc, 1),
                    offsets=(offs_m, offs_n),
                    block_shape=(BLOCK_M, BLOCK_N),
                    order=(1, 0),
                )
                tl.store(
                    out_ptrs,
                    accumulator.to(out_ptr.dtype.element_ty),
                    boundary_check=(0, 1),
                )

                tile_idx += total_grid

        last_problem_end = current_problem_end
@libentry()
@libtuner(
    configs=get_autotune_config(matmul_tma_set_block_size_hook), key=["M", "N", "K"]
)
@triton.jit
def grouped_mm_tma_kernel(
    a_desc,
    b_desc,
    c_desc,
    C,
    offs,
    num_groups: tl.constexpr,
    M,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_cm: tl.constexpr,
    stride_cn: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """Persistent grouped matmul over row-segmented A using host TMA descriptors.

    Group ``g`` covers rows ``[offs[g-1], offs[g])`` of A and C (the first group
    starts at row 0) and multiplies them by rows ``[g*K, (g+1)*K)`` of the
    flattened B descriptor.

    Full tiles are written through ``c_desc``. Tiles that extend past the
    group's last row fall back to a masked pointer store through ``C``.

    NOTE(review): if K % BLOCK_K != 0, the last k-step reads B rows of the next
    group. A's columns past K lie outside ``a_desc`` and are zero-filled by
    TMA, so they contribute 0 unless those B rows contain Inf/NaN.
    """
    total_grid = tl.num_programs(axis=0)
    tile_idx = tl.program_id(axis=0)
    num_n_tiles = tl.cdiv(N, BLOCK_N)
    last_problem_end = 0
    group_start = 0
    group_end = 0

    for group_idx in tl.range(num_groups):
        # This group owns rows [group_start, group_end).
        group_end = tl.load(offs + group_idx).to(tl.int32)
        m = group_end - group_start
        num_m_tiles = tl.cdiv(m, BLOCK_M)
        num_tiles = num_m_tiles * num_n_tiles

        current_problem_end = last_problem_end + num_tiles
        if tile_idx >= last_problem_end and tile_idx < current_problem_end:
            # Number of this program's tiles that fall inside this group.
            loop_count = (current_problem_end - tile_idx + total_grid - 1) // total_grid
            for _ in tl.range(loop_count):
                tile_idx_in_gemm = tile_idx - last_problem_end
                tile_m_idx, tile_n_idx = grouped_launch(
                    tile_idx_in_gemm, m, N, BLOCK_M, BLOCK_N, GROUP_M
                )

                offs_am = group_start + tile_m_idx * BLOCK_M
                offs_bn = tile_n_idx * BLOCK_N
                offs_bk = group_idx * K

                accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

                for k in tl.range(0, tl.cdiv(K, BLOCK_K)):
                    a = a_desc.load([offs_am, k * BLOCK_K])
                    b = b_desc.load([offs_bk + k * BLOCK_K, offs_bn])
                    accumulator = tl.dot(a, b, acc=accumulator, allow_tf32=False)

                c = accumulator.to(c_desc.dtype)

                if offs_am + BLOCK_M <= group_end:
                    c_desc.store([offs_am, offs_bn], c)
                else:
                    # Masked store for the group's last, partial row-tile. Row
                    # offsets are widened to int64 so that ``row * stride_cm``
                    # cannot overflow int32 when M * N exceeds 2**31.
                    offs_cm = offs_am + tl.arange(0, BLOCK_M)
                    offs_cn = offs_bn + tl.arange(0, BLOCK_N)
                    c_ptrs = (
                        C
                        + stride_cm * offs_cm.to(tl.int64)[:, None]
                        + stride_cn * offs_cn[None, :]
                    )
                    c_mask = (offs_cm[:, None] < group_end) & (offs_cn[None, :] < N)
                    tl.store(c_ptrs, c, mask=c_mask)

                tile_idx += total_grid

        last_problem_end = current_problem_end
        group_start = group_end
@libentry()
@libtuner(configs=get_autotune_config(), key=["M", "N", "K"])
@triton.jit
def grouped_mm_kernel(
    A,
    B,
    C,
    offs,
    num_groups: tl.constexpr,
    M,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am: tl.constexpr,
    stride_ak: tl.constexpr,
    stride_bk: tl.constexpr,
    stride_bn: tl.constexpr,
    stride_cm: tl.constexpr,
    stride_cn: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """Persistent grouped matmul over row-segmented A using block pointers.

    Group ``g`` covers rows ``[offs[g-1], offs[g])`` of A and C (the first group
    starts at row 0). It multiplies them by the ``(K, N)`` slice of B that starts
    ``g * K`` rows (of stride ``stride_bk``) from ``B``.

    Full tiles use a block-pointer store. Tiles that extend past the group's
    last row fall back to a masked pointer store.
    """
    total_grid = tl.num_programs(axis=0)
    tile_idx = tl.program_id(axis=0)
    num_n_tiles = tl.cdiv(N, BLOCK_N)
    last_problem_end = 0
    group_start = 0
    group_end = 0

    for group_idx in tl.range(num_groups):
        # This group owns rows [group_start, group_end).
        group_end = tl.load(offs + group_idx).to(tl.int32)
        m = group_end - group_start
        num_m_tiles = tl.cdiv(m, BLOCK_M)
        num_tiles = num_m_tiles * num_n_tiles

        current_problem_end = last_problem_end + num_tiles
        if tile_idx >= last_problem_end and tile_idx < current_problem_end:
            # Base of this group's (K, N) slice of B. Bounding the block pointer
            # to exactly K rows zero-pads the K tail instead of reading the
            # next group's rows. The int64 cast keeps group_idx * K * stride_bk
            # from overflowing int32.
            b_group = B + group_idx.to(tl.int64) * (K * stride_bk)

            # Number of this program's tiles that fall inside this group.
            loop_count = (current_problem_end - tile_idx + total_grid - 1) // total_grid
            for _ in tl.range(loop_count):
                tile_idx_in_gemm = tile_idx - last_problem_end
                tile_m_idx, tile_n_idx = grouped_launch(
                    tile_idx_in_gemm, m, N, BLOCK_M, BLOCK_N, GROUP_M
                )

                offs_am = group_start + tile_m_idx * BLOCK_M
                offs_bn = tile_n_idx * BLOCK_N

                a_block_ptr = tl.make_block_ptr(
                    base=A,
                    shape=(M, K),
                    strides=(stride_am, stride_ak),
                    offsets=(offs_am, 0),
                    block_shape=(BLOCK_M, BLOCK_K),
                    order=(1, 0),
                )

                b_block_ptr = tl.make_block_ptr(
                    base=b_group,
                    shape=(K, N),
                    strides=(stride_bk, stride_bn),
                    offsets=(0, offs_bn),
                    block_shape=(BLOCK_K, BLOCK_N),
                    order=(1, 0),
                )

                accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

                for k in tl.range(0, tl.cdiv(K, BLOCK_K)):
                    # Without padding_option, out-of-bounds elements of a
                    # block-pointer load are undefined. Zero padding keeps the
                    # K tail (K % BLOCK_K != 0) from corrupting the accumulator.
                    a = tl.load(a_block_ptr, boundary_check=(0, 1), padding_option="zero")
                    b = tl.load(b_block_ptr, boundary_check=(0, 1), padding_option="zero")
                    accumulator = tl.dot(a, b, acc=accumulator, allow_tf32=False)

                    a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_K))
                    b_block_ptr = tl.advance(b_block_ptr, (BLOCK_K, 0))

                c = accumulator.to(C.dtype.element_ty)

                c_block_ptr = tl.make_block_ptr(
                    base=C,
                    shape=(M, N),
                    strides=(stride_cm, stride_cn),
                    offsets=(offs_am, offs_bn),
                    block_shape=(BLOCK_M, BLOCK_N),
                    order=(1, 0),
                )

                if offs_am + BLOCK_M <= group_end:
                    tl.store(c_block_ptr, c, boundary_check=(0, 1))
                else:
                    # Masked store for the group's last, partial row-tile. Row
                    # offsets are widened to int64 so that ``row * stride_cm``
                    # cannot overflow int32 when M * N exceeds 2**31.
                    offs_cm = offs_am + tl.arange(0, BLOCK_M)
                    offs_cn = offs_bn + tl.arange(0, BLOCK_N)
                    c_ptrs = (
                        C
                        + stride_cm * offs_cm.to(tl.int64)[:, None]
                        + stride_cn * offs_cn[None, :]
                    )
                    c_mask = (offs_cm[:, None] < group_end) & (offs_cn[None, :] < N)
                    tl.store(c_ptrs, c, mask=c_mask)

                tile_idx += total_grid

        last_problem_end = current_problem_end
        group_start = group_end
def group_gemm(group_A, group_B, group_C, offs_table, alpha=1, beta=0):
    """Compute ``out_g = alpha * A_g @ B_g + beta * C_g`` for every group.

    Args:
        group_A: 2-D tensor. ``group_A.shape[1]`` is forwarded as ``K`` (an
            autotuning key only).
        group_B: tensor holding the B operands.
        group_C: 2-D tensor. Its shape ``(M, N)`` is also the output shape.
        offs_table: one entry per group, read as
            ``(M_g, N_g, K_g, a_index, b_index, c_index)``. The last three
            index ``group_A``, ``group_B`` and ``group_C``/the output.
        alpha, beta: scalars passed to the kernels as compile-time constants.
            With ``beta == 0``, C is not read.

    Returns:
        A new ``(M, N)`` tensor with ``group_A``'s dtype. It is allocated with
        ``torch.empty``, so elements no group writes are left uninitialized.

    Raises:
        ValueError: if any operand is not ``torch.bfloat16``. The kernels
            reinterpret every operand address as a bf16 pointer.
    """
    for name, operand in (("group_A", group_A), ("group_B", group_B), ("group_C", group_C)):
        if operand.dtype != torch.bfloat16:
            raise ValueError(
                "group_gemm only supports torch.bfloat16 operands, "
                f"but {name} has dtype {operand.dtype}"
            )

    group_size = len(offs_table)
    M, N = group_C.shape
    K = group_A.shape[1]
    device = group_A.device
    group_out = torch.empty((M, N), device=device, dtype=group_A.dtype)
    if group_size == 0:
        # Nothing to compute: skip building address tables and launching.
        return group_out

    A_addrs, B_addrs, C_addrs, out_addrs = [], [], [], []
    group_sizes, group_lds = [], []
    for entry in offs_table:
        M_g, N_g, K_g = entry[0], entry[1], entry[2]
        A_g = group_A[entry[3]]
        B_g = group_B[entry[4]]
        C_g = group_C[entry[5]]
        # The C index also selects the matching region of the output buffer.
        out_g = group_out[entry[5]]
        group_sizes += [M_g, N_g, K_g]
        # Leading dimensions assume each per-group operand is densely packed
        # row-major (lda = K_g, ldb = ldc = N_g).
        # NOTE(review): confirm callers never pass strided sub-views here.
        group_lds += [K_g, N_g, N_g]
        A_addrs.append(A_g.data_ptr())
        B_addrs.append(B_g.data_ptr())
        C_addrs.append(C_g.data_ptr())
        out_addrs.append(out_g.data_ptr())

    # Use explicit int64 so each entry can be cast back to a pointer in-kernel.
    d_a_ptrs = torch.tensor(A_addrs, dtype=torch.int64, device=device)
    d_b_ptrs = torch.tensor(B_addrs, dtype=torch.int64, device=device)
    d_c_ptrs = torch.tensor(C_addrs, dtype=torch.int64, device=device)
    d_output_ptrs = torch.tensor(out_addrs, dtype=torch.int64, device=device)
    d_g_sizes = torch.tensor(group_sizes, dtype=torch.int32, device=device)
    d_g_lds = torch.tensor(group_lds, dtype=torch.int32, device=device)
    # Size the persistent grid for the device the operands live on.
    NUM_SMS = torch.cuda.get_device_properties(device).multi_processor_count

    if _support_device_tensor_descriptor and supports_tma():
        # Device-side tensor descriptors obtain scratch memory via Triton's allocator.
        def alloc_fn(size, alignment, stream):
            return torch.empty(size, device=device, dtype=torch.int8)

        triton.set_allocator(alloc_fn)
        kernel = grouped_gemm_tma_kernel
    else:
        kernel = grouped_gemm_kernel

    kernel[(NUM_SMS,)](
        M,
        N,
        K,
        d_a_ptrs,
        d_b_ptrs,
        d_c_ptrs,
        d_output_ptrs,
        d_g_sizes,
        d_g_lds,
        group_size,
        alpha=alpha,
        beta=beta,
    )

    return group_out
def group_mm(A: torch.Tensor, B: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
    """Grouped matmul with row-segmented ``A`` and one ``B`` matrix per group.

    Group ``g`` multiplies rows ``[offs[g - 1], offs[g])`` of ``A`` (the first
    group starts at row 0) by ``B[g]``. The result goes into the same rows of
    the output.

    Args:
        A: 2-D tensor of shape ``(M, K)``.
        B: 3-D tensor of shape ``(num_groups, K, N)``.
        offs: ``num_groups`` cumulative end-row offsets into ``A``. The kernels
            read it in-kernel, so presumably it lives on ``A.device``.

    Returns:
        A new ``(M, N)`` tensor with ``A``'s dtype. Rows not covered by any
        group are left uninitialized (the buffer comes from ``new_empty``).
    """
    assert A.dim() == 2
    assert B.dim() == 3
    M, K = A.shape

    num_groups, BK, N = B.shape
    assert BK == K, f"B.shape[1] ({BK}) must equal A.shape[1] ({K})"
    assert num_groups == offs.numel()

    # Both kernels address B as one (num_groups * K, N) matrix with row stride
    # B.stride(1). That is only valid if consecutive groups are exactly K rows
    # apart, which fails for e.g. a per-group transposed view; copy in that case.
    if num_groups > 1 and B.stride(0) != K * B.stride(1):
        B = B.contiguous()
    strideBK, strideBN = B.stride(1), B.stride(2)

    # Size the persistent grid for the device the operands live on.
    NUM_SMS = torch.cuda.get_device_properties(A.device).multi_processor_count
    C = A.new_empty(M, N)
    if _support_host_tensor_descriptor and supports_tma():
        # Placeholder block shape; matmul_tma_set_block_size_hook resets it
        # for each autotuning config.
        dummy_block = [1, 1]

        a_desc = TensorDescriptor(A, A.shape, A.stride(), dummy_block)
        b_desc = TensorDescriptor(
            B, [num_groups * K, N], [strideBK, strideBN], dummy_block
        )
        c_desc = TensorDescriptor(C, C.shape, C.stride(), dummy_block)

        grouped_mm_tma_kernel[(NUM_SMS,)](
            a_desc,
            b_desc,
            c_desc,
            C,
            offs,
            num_groups,
            M,
            N,
            K,
            C.stride(0),
            C.stride(1),
        )
    else:
        grouped_mm_kernel[(NUM_SMS,)](
            A,
            B,
            C,
            offs,
            num_groups,
            M,
            N,
            K,
            A.stride(0),
            A.stride(1),
            strideBK,
            strideBN,
            C.stride(0),
            C.stride(1),
        )

    return C