Coverage for src/flag_gems/ops/group_gemm.py: 14%
273 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import libentry, libtuner
8from flag_gems.utils.device_info import get_device_capability, get_sm_count
# Logger for this module, named after the module itself.
logger = logging.getLogger(name=__name__)
def supports_tma():
    """Return True when the current device reports compute capability 9.x or newer.

    Only the major version returned by ``get_device_capability()`` is inspected.
    """
    major_version = get_device_capability()[0]
    return major_version >= 9
# Device-side tensor descriptors (``tl.make_tensor_descriptor``) may be absent
# on older Triton releases; record availability and fall back to ``None``.
try:
    make_tensor_descriptor_fn = tl.make_tensor_descriptor
except AttributeError:
    _support_device_tensor_descriptor = False
    make_tensor_descriptor_fn = None
else:
    _support_device_tensor_descriptor = True

# Host-side ``TensorDescriptor`` objects (used by ``group_mm``) come from
# ``triton.tools.tensor_descriptor``; record whether that import succeeds.
try:
    from triton.tools.tensor_descriptor import TensorDescriptor
except ImportError:
    _support_host_tensor_descriptor = False
else:
    _support_host_tensor_descriptor = True
@triton.jit
def grouped_launch(
    pid, m, n, block_m: tl.constexpr, block_n: tl.constexpr, group_m: tl.constexpr
):
    """Map a linear tile id to a ``(pid_m, pid_n)`` tile coordinate.

    Tiles are visited in bands of ``group_m`` tile-rows. Inside a band the
    tile-row index changes fastest, so consecutive ids share a tile-column.
    The final band may contain fewer than ``group_m`` tile-rows.

    Args:
        pid: linear tile id within one ``m x n`` problem.
        m, n: problem extents.
        block_m, block_n: tile sizes.
        group_m: number of tile-rows per band.

    Returns:
        ``(pid_m, pid_n)``: tile-row and tile-column indices.
    """
    tiles_per_band = group_m * tl.cdiv(n, block_n)
    band_first_row = (pid // tiles_per_band) * group_m
    # The last band is clipped to the tile-rows that actually remain.
    band_rows = tl.minimum(tl.cdiv(m, block_m) - band_first_row, group_m)
    tile_row = band_first_row + pid % band_rows
    tile_col = (pid % tiles_per_band) // band_rows
    return tile_row, tile_col
def matmul_tma_set_block_size_hook(nargs):
    """Autotuner pre-hook: size the tensor descriptors' blocks for the chosen config.

    Reads ``BLOCK_M``/``BLOCK_N``/``BLOCK_K`` from ``nargs``. It then sets
    ``block_shape`` in place on the ``a_desc`` (M x K), ``b_desc`` (K x N) and
    ``c_desc`` (M x N) entries.

    Args:
        nargs: argument-name -> value mapping handed to a ``triton.Config``
            pre-hook (must contain the three block sizes and three descriptors).
    """
    tile_m, tile_n, tile_k = (nargs[key] for key in ("BLOCK_M", "BLOCK_N", "BLOCK_K"))
    for desc_name, shape in (
        ("a_desc", [tile_m, tile_k]),
        ("b_desc", [tile_k, tile_n]),
        ("c_desc", [tile_m, tile_n]),
    ):
        nargs[desc_name].block_shape = shape
def get_autotune_config(pre_hook=None):
    """Build the list of candidate ``triton.Config`` objects used by this module's kernels.

    Args:
        pre_hook: optional callable attached to every config as its ``pre_hook``
            (the descriptor-based ``grouped_mm_tma_kernel`` passes
            ``matmul_tma_set_block_size_hook`` here).

    Returns:
        A new list of configs, in a fixed order.
    """
    # (BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M, num_stages, num_warps)
    tile_table = (
        (128, 256, 64, 8, 3, 8),
        (128, 128, 128, 8, 2, 4),
        (256, 128, 128, 8, 3, 4),
        (256, 128, 64, 8, 3, 8),
        (64, 256, 32, 4, 4, 4),
        (128, 64, 32, 4, 4, 4),
        (256, 256, 64, 8, 3, 8),
    )
    return [
        triton.Config(
            {"BLOCK_M": bm, "BLOCK_N": bn, "BLOCK_K": bk, "GROUP_M": gm},
            num_stages=stages,
            num_warps=warps,
            pre_hook=pre_hook,
        )
        for bm, bn, bk, gm, stages, warps in tile_table
    ]
@libentry()
@libtuner(configs=get_autotune_config(), key=["M", "N", "K"])
@triton.jit
def grouped_gemm_tma_kernel(
    M,
    N,
    K,
    group_a_ptrs,
    group_b_ptrs,
    group_c_ptrs,
    group_out_ptrs,
    group_gemm_sizes,
    g_lds,
    group_size,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    alpha: tl.constexpr,
    beta: tl.constexpr,
):
    """Grouped GEMM ``out_g = alpha * A_g @ B_g + beta * C_g`` via device-built tensor descriptors.

    The output tiles of all ``group_size`` problems are numbered consecutively.
    Each program starts at tile ``program_id(0)`` and advances by
    ``num_programs(0)`` until it has passed the last problem.

    Args:
        M, N, K: overall extents; only used as autotune keys, not in the body.
        group_a_ptrs, group_b_ptrs, group_c_ptrs, group_out_ptrs: one base
            address per group. Each is reinterpreted as a bfloat16 pointer, so
            this kernel only handles bfloat16 operands.
        group_gemm_sizes: flattened ``(m, n, k)`` triple per group.
        g_lds: flattened ``(lda, ldb, ldc)`` row strides per group. The column
            stride of every operand is 1, and ``out`` reuses ``ldc``.
        group_size: number of problems.
        BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M: tile sizes and the M-grouping
            factor passed to ``grouped_launch``.
        alpha, beta: compile-time scalars. When ``beta == 0``, C is never read
            (BLAS convention), so NaN/Inf values in C cannot reach ``out``.
    """
    tile_idx = tl.program_id(0)
    total_grid = tl.num_programs(0)
    last_problem_end = 0
    for g in range(group_size):
        gm = tl.load(group_gemm_sizes + g * 3)
        gn = tl.load(group_gemm_sizes + g * 3 + 1)
        gk = tl.load(group_gemm_sizes + g * 3 + 2)
        num_m_tiles = tl.cdiv(gm, BLOCK_M)
        num_n_tiles = tl.cdiv(gn, BLOCK_N)
        num_tiles = num_m_tiles * num_n_tiles

        # Global tile ids [last_problem_end, current_problem_end) belong to problem g.
        current_problem_end = last_problem_end + num_tiles
        if tile_idx >= last_problem_end and tile_idx < current_problem_end:
            lda = tl.load(g_lds + g * 3)
            ldb = tl.load(g_lds + g * 3 + 1)
            ldc = tl.load(g_lds + g * 3 + 2)

            a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.bfloat16))
            b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.bfloat16))
            c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.bfloat16))
            out_ptr = tl.load(group_out_ptrs + g).to(tl.pointer_type(tl.bfloat16))

            a_desc = make_tensor_descriptor_fn(
                a_ptr,
                shape=[gm, gk],
                strides=[lda, 1],
                block_shape=[BLOCK_M, BLOCK_K],
            )

            b_desc = make_tensor_descriptor_fn(
                b_ptr,
                shape=[gk, gn],
                strides=[ldb, 1],
                block_shape=[BLOCK_K, BLOCK_N],
            )

            c_desc = make_tensor_descriptor_fn(
                c_ptr,
                shape=[gm, gn],
                strides=[ldc, 1],
                block_shape=[BLOCK_M, BLOCK_N],
            )

            out_desc = make_tensor_descriptor_fn(
                out_ptr,
                shape=[gm, gn],
                strides=[ldc, 1],
                block_shape=[BLOCK_M, BLOCK_N],
            )
            # Number of this program's tiles (stride total_grid) inside problem g.
            loop_count = (current_problem_end - tile_idx + total_grid - 1) // total_grid
            for _ in tl.range(loop_count):
                tile_idx_in_gemm = tile_idx - last_problem_end
                tile_m_idx, tile_n_idx = grouped_launch(
                    tile_idx_in_gemm, gm, gn, BLOCK_M, BLOCK_N, GROUP_M
                )

                offs_am = tile_m_idx * BLOCK_M
                offs_bn = tile_n_idx * BLOCK_N

                accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                for kk in range(0, tl.cdiv(gk, BLOCK_K)):
                    a = a_desc.load([offs_am, kk * BLOCK_K])
                    b = b_desc.load([kk * BLOCK_K, offs_bn])
                    accumulator = tl.dot(a, b, acc=accumulator, allow_tf32=False)

                offs_cm = tile_m_idx * BLOCK_M
                offs_cn = tile_n_idx * BLOCK_N

                # beta is a constexpr, so this branch is resolved at compile time.
                if beta != 0:
                    ori_c = c_desc.load([offs_cm, offs_cn])
                    accumulator = ori_c * beta + accumulator * alpha
                else:
                    # Do not touch C: 0 * NaN would otherwise poison the result.
                    accumulator = accumulator * alpha

                c = accumulator.to(c_desc.dtype)
                out_desc.store([offs_cm, offs_cn], c)

                tile_idx += total_grid

        last_problem_end = current_problem_end
@libentry()
@libtuner(configs=get_autotune_config(), key=["M", "N", "K"])
@triton.jit
def grouped_gemm_kernel(
    M,
    N,
    K,
    group_a_ptrs,
    group_b_ptrs,
    group_c_ptrs,
    group_out_ptrs,
    group_gemm_sizes,
    g_lds,
    group_size,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    alpha: tl.constexpr,
    beta: tl.constexpr,
):
    """Grouped GEMM ``out_g = alpha * A_g @ B_g + beta * C_g`` using block pointers.

    This is the non-descriptor variant selected by ``group_gemm`` when device
    tensor descriptors / TMA are unavailable. The output tiles of all
    ``group_size`` problems are numbered consecutively. Each program starts at
    tile ``program_id(0)`` and advances by ``num_programs(0)``.

    Args:
        M, N, K: overall extents; only used as autotune keys, not in the body.
        group_a_ptrs, group_b_ptrs, group_c_ptrs, group_out_ptrs: one base
            address per group. Each is reinterpreted as a bfloat16 pointer, so
            this kernel only handles bfloat16 operands.
        group_gemm_sizes: flattened ``(m, n, k)`` triple per group.
        g_lds: flattened ``(lda, ldb, ldc)`` row strides per group. The column
            stride of every operand is 1, and ``out`` reuses ``ldc``.
        group_size: number of problems.
        BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M: tile sizes and the M-grouping
            factor passed to ``grouped_launch``.
        alpha, beta: compile-time scalars. When ``beta == 0``, C is never read
            (BLAS convention), so NaN/Inf values in C cannot reach ``out``.
    """
    tile_idx = tl.program_id(0)
    total_grid = tl.num_programs(0)
    last_problem_end = 0
    for g in range(group_size):
        gm = tl.load(group_gemm_sizes + g * 3)
        gn = tl.load(group_gemm_sizes + g * 3 + 1)
        gk = tl.load(group_gemm_sizes + g * 3 + 2)
        num_m_tiles = tl.cdiv(gm, BLOCK_M)
        num_n_tiles = tl.cdiv(gn, BLOCK_N)
        num_tiles = num_m_tiles * num_n_tiles
        # Global tile ids [last_problem_end, current_problem_end) belong to problem g.
        current_problem_end = last_problem_end + num_tiles
        if tile_idx >= last_problem_end and tile_idx < current_problem_end:
            lda = tl.load(g_lds + g * 3)
            ldb = tl.load(g_lds + g * 3 + 1)
            ldc = tl.load(g_lds + g * 3 + 2)

            a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.bfloat16))
            b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.bfloat16))
            c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.bfloat16))
            out_ptr = tl.load(group_out_ptrs + g).to(tl.pointer_type(tl.bfloat16))

            # Number of this program's tiles (stride total_grid) inside problem g.
            loop_count = (current_problem_end - tile_idx + total_grid - 1) // total_grid
            for _ in tl.range(loop_count):
                tile_idx_in_gemm = tile_idx - last_problem_end
                tile_m_idx, tile_n_idx = grouped_launch(
                    tile_idx_in_gemm, gm, gn, BLOCK_M, BLOCK_N, GROUP_M
                )

                offs_am = tile_m_idx * BLOCK_M
                offs_bn = tile_n_idx * BLOCK_N

                a_ptrs = tl.make_block_ptr(
                    base=a_ptr,
                    shape=(gm, gk),
                    strides=(lda, 1),
                    offsets=(offs_am, 0),
                    block_shape=(BLOCK_M, BLOCK_K),
                    order=(1, 0),
                )
                b_ptrs = tl.make_block_ptr(
                    base=b_ptr,
                    shape=(gk, gn),
                    strides=(ldb, 1),
                    offsets=(0, offs_bn),
                    block_shape=(BLOCK_K, BLOCK_N),
                    order=(1, 0),
                )

                accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                for kk in range(0, tl.cdiv(gk, BLOCK_K)):
                    a = tl.load(a_ptrs, boundary_check=(0, 1))
                    b = tl.load(b_ptrs, boundary_check=(0, 1))
                    accumulator = tl.dot(a, b, acc=accumulator, allow_tf32=False)
                    a_ptrs = tl.advance(a_ptrs, (0, BLOCK_K))
                    b_ptrs = tl.advance(b_ptrs, (BLOCK_K, 0))

                offs_cm = tile_m_idx * BLOCK_M
                offs_cn = tile_n_idx * BLOCK_N

                out_ptrs = tl.make_block_ptr(
                    base=out_ptr,
                    shape=(gm, gn),
                    strides=(ldc, 1),
                    offsets=(offs_cm, offs_cn),
                    block_shape=(BLOCK_M, BLOCK_N),
                    order=(1, 0),
                )

                # beta is a constexpr, so this branch is resolved at compile time.
                if beta != 0:
                    c_ptrs = tl.make_block_ptr(
                        base=c_ptr,
                        shape=(gm, gn),
                        strides=(ldc, 1),
                        offsets=(offs_cm, offs_cn),
                        block_shape=(BLOCK_M, BLOCK_N),
                        order=(1, 0),
                    )
                    ori_c = tl.load(c_ptrs, boundary_check=(0, 1))
                    accumulator = ori_c * beta + accumulator * alpha
                else:
                    # Do not touch C: 0 * NaN would otherwise poison the result.
                    accumulator = accumulator * alpha

                # out and C share the same (bfloat16) element type.
                c = accumulator.to(out_ptrs.dtype.element_ty)
                tl.store(out_ptrs, c, boundary_check=(0, 1))

                tile_idx += total_grid

        last_problem_end = current_problem_end
@libentry()
@libtuner(
    configs=get_autotune_config(matmul_tma_set_block_size_hook), key=["M", "N", "K"]
)
@triton.jit
def grouped_mm_tma_kernel(
    a_desc,
    b_desc,
    c_desc,
    C,
    offs,
    num_groups: tl.constexpr,
    M,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_cm: tl.constexpr,
    stride_cn: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """Grouped matmul over row groups of A, using host-built tensor descriptors.

    Group ``g`` covers rows ``[offs[g-1], offs[g])`` of A and C; the first group
    starts at row 0. It multiplies them by rows ``[g * K, (g + 1) * K)`` of the
    2-D matrix that ``b_desc`` describes. Output tiles of all groups are
    numbered consecutively. Each program starts at tile ``program_id(0)`` and
    advances by ``num_programs(0)``.

    Args:
        a_desc, b_desc, c_desc: tensor descriptors for A, B and C. Their
            ``block_shape`` is filled in by the configs' pre-hook
            (``matmul_tma_set_block_size_hook``).
        C: raw output pointer, used only for the masked store of tiles that
            cross a group boundary.
        offs: ``num_groups`` cumulative row end offsets.
        M: total rows of A; only used as an autotune key here.
        N, K: output columns and reduction depth shared by all groups.
        stride_cm, stride_cn: strides of C for the masked store.
        BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M: tile sizes and the M-grouping
            factor passed to ``grouped_launch``.
    """
    total_grid = tl.num_programs(axis=0)
    tile_idx = tl.program_id(axis=0)
    # Every group has the same N, so the tile-column count is fixed.
    num_n_tiles = tl.cdiv(N, BLOCK_N)
    last_problem_end = 0
    group_start = 0
    group_end = 0

    for group_idx in tl.range(num_groups):
        group_end = tl.load(offs + group_idx).to(tl.int32)
        m = group_end - group_start
        num_m_tiles = tl.cdiv(m, BLOCK_M)
        num_tiles = num_m_tiles * num_n_tiles

        # Global tile ids [last_problem_end, current_problem_end) belong to this group.
        current_problem_end = last_problem_end + num_tiles
        if tile_idx >= last_problem_end and tile_idx < current_problem_end:
            # Number of this program's tiles (stride total_grid) inside the group.
            loop_count = (current_problem_end - tile_idx + total_grid - 1) // total_grid
            for _ in tl.range(loop_count):
                tile_idx_in_gemm = tile_idx - last_problem_end
                tile_m_idx, tile_n_idx = grouped_launch(
                    tile_idx_in_gemm, m, N, BLOCK_M, BLOCK_N, GROUP_M
                )

                # Row offset is absolute within A/C; B's rows for this group start at group_idx * K.
                offs_am = group_start + tile_m_idx * BLOCK_M
                offs_bn = tile_n_idx * BLOCK_N
                offs_bk = group_idx * K

                accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

                for k in tl.range(0, tl.cdiv(K, BLOCK_K)):
                    a = a_desc.load([offs_am, k * BLOCK_K])
                    b = b_desc.load([offs_bk + k * BLOCK_K, offs_bn])
                    accumulator = tl.dot(a, b, acc=accumulator, allow_tf32=False)

                c = accumulator.to(c_desc.dtype)

                # A tile may extend past group_end into the next group's rows.
                # Only tiles fully inside the group use the descriptor store;
                # the rest use a masked store that stops at group_end.
                if offs_am + BLOCK_M <= group_end:
                    c_desc.store([offs_am, offs_bn], c)
                else:
                    offs_cm = offs_am + tl.arange(0, BLOCK_M)
                    offs_cn = offs_bn + tl.arange(0, BLOCK_N)
                    c_ptrs = (
                        C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
                    )
                    c_mask = (offs_cm[:, None] < group_end) & (offs_cn[None, :] < N)
                    tl.store(c_ptrs, c, mask=c_mask)

                tile_idx += total_grid

        last_problem_end = current_problem_end
        group_start = group_end
@libentry()
@libtuner(configs=get_autotune_config(), key=["M", "N", "K"])
@triton.jit
def grouped_mm_kernel(
    A,
    B,
    C,
    offs,
    num_groups: tl.constexpr,
    M,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am: tl.constexpr,
    stride_ak: tl.constexpr,
    stride_bk: tl.constexpr,
    stride_bn: tl.constexpr,
    stride_cm: tl.constexpr,
    stride_cn: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    """Grouped matmul over row groups of A, using block pointers.

    Group ``g`` covers rows ``[offs[g-1], offs[g])`` of A and C; the first group
    starts at row 0. It multiplies them by rows ``[g * K, (g + 1) * K)`` of B,
    which is addressed as a ``(num_groups * K, N)`` matrix with strides
    ``(stride_bk, stride_bn)``. Output tiles of all groups are numbered
    consecutively. Each program starts at tile ``program_id(0)`` and advances
    by ``num_programs(0)``.

    Args:
        A, B, C: base pointers of A ``(M, K)``, B and C ``(M, N)``.
        offs: ``num_groups`` cumulative row end offsets.
        M: total rows of A and C.
        N, K: output columns and reduction depth shared by all groups.
        stride_*: element strides of A, (flattened) B and C.
        BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M: tile sizes and the M-grouping
            factor passed to ``grouped_launch``.
    """
    total_grid = tl.num_programs(axis=0)
    tile_idx = tl.program_id(axis=0)
    # Every group has the same N, so the tile-column count is fixed.
    num_n_tiles = tl.cdiv(N, BLOCK_N)
    last_problem_end = 0
    group_start = 0
    group_end = 0

    for group_idx in tl.range(num_groups):
        group_end = tl.load(offs + group_idx).to(tl.int32)
        m = group_end - group_start
        num_m_tiles = tl.cdiv(m, BLOCK_M)
        num_tiles = num_m_tiles * num_n_tiles

        # Global tile ids [last_problem_end, current_problem_end) belong to this group.
        current_problem_end = last_problem_end + num_tiles
        if tile_idx >= last_problem_end and tile_idx < current_problem_end:
            # Number of this program's tiles (stride total_grid) inside the group.
            loop_count = (current_problem_end - tile_idx + total_grid - 1) // total_grid
            for _ in tl.range(loop_count):
                tile_idx_in_gemm = tile_idx - last_problem_end
                tile_m_idx, tile_n_idx = grouped_launch(
                    tile_idx_in_gemm, m, N, BLOCK_M, BLOCK_N, GROUP_M
                )

                # Row offset is absolute within A/C; B's rows for this group start at group_idx * K.
                offs_am = group_start + tile_m_idx * BLOCK_M
                offs_bn = tile_n_idx * BLOCK_N
                offs_bk = group_idx * K

                a_block_ptr = tl.make_block_ptr(
                    base=A,
                    shape=(M, K),
                    strides=(stride_am, stride_ak),
                    offsets=(offs_am, 0),
                    block_shape=(BLOCK_M, BLOCK_K),
                    order=(1, 0),
                )

                b_block_ptr = tl.make_block_ptr(
                    base=B,
                    shape=(num_groups * K, N),
                    strides=(stride_bk, stride_bn),
                    offsets=(offs_bk, offs_bn),
                    block_shape=(BLOCK_K, BLOCK_N),
                    order=(1, 0),
                )

                accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

                for k in tl.range(0, tl.cdiv(K, BLOCK_K)):
                    a = tl.load(a_block_ptr, boundary_check=(0, 1))
                    b = tl.load(b_block_ptr, boundary_check=(0, 1))
                    accumulator = tl.dot(a, b, acc=accumulator, allow_tf32=False)

                    a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_K))
                    b_block_ptr = tl.advance(b_block_ptr, (BLOCK_K, 0))

                c = accumulator.to(C.dtype.element_ty)

                c_block_ptr = tl.make_block_ptr(
                    base=C,
                    shape=(M, N),
                    strides=(stride_cm, stride_cn),
                    offsets=(offs_am, offs_bn),
                    block_shape=(BLOCK_M, BLOCK_N),
                    order=(1, 0),
                )

                # A tile may extend past group_end into the next group's rows.
                # Only tiles fully inside the group use the block-pointer store;
                # the rest use a masked store that stops at group_end.
                if offs_am + BLOCK_M <= group_end:
                    tl.store(c_block_ptr, c, boundary_check=(0, 1))
                else:
                    offs_cm = offs_am + tl.arange(0, BLOCK_M)
                    offs_cn = offs_bn + tl.arange(0, BLOCK_N)
                    c_ptrs = (
                        C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
                    )
                    c_mask = (offs_cm[:, None] < group_end) & (offs_cn[None, :] < N)
                    tl.store(c_ptrs, c, mask=c_mask)

                tile_idx += total_grid

        last_problem_end = current_problem_end
        group_start = group_end
def group_gemm(group_A, group_B, group_C, offs_table, alpha=1, beta=0):
    """Run a batch of independent GEMMs ``out_g = alpha * A_g @ B_g + beta * C_g``.

    Each entry of ``offs_table`` describes one problem as
    ``(M_g, N_g, K_g, a_index, b_index, c_index)``. The operands are
    ``A_g = group_A[a_index]``, ``B_g = group_B[b_index]`` and
    ``C_g = group_C[c_index]``. The result goes to the matching region
    ``group_out[c_index]`` of a freshly allocated output.

    Args:
        group_A, group_B, group_C: bfloat16 tensors holding all operands.
            ``group_C`` must be 2-D; its shape fixes the output shape.
        offs_table: per-problem descriptors as above.
        alpha, beta: scalars. They are compile-time constants of the kernels,
            so every distinct value triggers a new compilation.

    Returns:
        A new ``(M, N)`` tensor, where ``M, N = group_C.shape``. Only regions
        addressed by ``offs_table`` are written; the rest is uninitialized.

    Raises:
        TypeError: if any operand is not ``torch.bfloat16``. Both kernels
            reinterpret every operand pointer as bfloat16, so other dtypes
            would silently produce garbage.
    """
    for name, tensor in (("group_A", group_A), ("group_B", group_B), ("group_C", group_C)):
        if tensor.dtype != torch.bfloat16:
            raise TypeError(
                f"group_gemm only supports torch.bfloat16 operands, "
                f"but {name} has dtype {tensor.dtype}"
            )

    device = group_A.device
    M, N = group_C.shape
    K = group_A.shape[1]
    group_out = torch.empty((M, N), device=device, dtype=group_A.dtype)

    group_size = len(offs_table)
    if group_size == 0:
        # Nothing to compute; also avoids building empty (float) pointer tables.
        return group_out

    A_addrs, B_addrs, C_addrs, out_addrs = [], [], [], []
    group_sizes, group_lds = [], []
    for entry in offs_table:
        M_g, N_g, K_g = entry[0], entry[1], entry[2]
        group_sizes += [M_g, N_g, K_g]
        # NOTE(review): leading dimensions are taken as (K_g, N_g, N_g), which
        # assumes every A_g / B_g / C_g / out_g view is row-contiguous with
        # exactly those widths -- confirm against the callers' offs_table.
        group_lds += [K_g, N_g, N_g]
        A_addrs.append(group_A[entry[3]].data_ptr())
        B_addrs.append(group_B[entry[4]].data_ptr())
        C_addrs.append(group_C[entry[5]].data_ptr())
        out_addrs.append(group_out[entry[5]].data_ptr())

    # Raw addresses are passed as int64 and cast to pointers inside the kernels.
    d_a_ptrs = torch.tensor(A_addrs, dtype=torch.int64, device=device)
    d_b_ptrs = torch.tensor(B_addrs, dtype=torch.int64, device=device)
    d_c_ptrs = torch.tensor(C_addrs, dtype=torch.int64, device=device)
    d_output_ptrs = torch.tensor(out_addrs, dtype=torch.int64, device=device)
    d_g_sizes = torch.tensor(group_sizes, dtype=torch.int32, device=device)
    d_g_lds = torch.tensor(group_lds, dtype=torch.int32, device=device)
    NUM_SMS = get_sm_count()

    if _support_device_tensor_descriptor and supports_tma():

        def alloc_fn(size, alignment, stream):
            # Presumably backs the descriptors created by tl.make_tensor_descriptor.
            return torch.empty(size, device=device, dtype=torch.int8)

        triton.set_allocator(alloc_fn)
        grouped_gemm_tma_kernel[(NUM_SMS,)](
            M,
            N,
            K,
            d_a_ptrs,
            d_b_ptrs,
            d_c_ptrs,
            d_output_ptrs,
            d_g_sizes,
            d_g_lds,
            group_size,
            alpha=alpha,
            beta=beta,
        )
    else:
        grouped_gemm_kernel[(NUM_SMS,)](
            M,
            N,
            K,
            d_a_ptrs,
            d_b_ptrs,
            d_c_ptrs,
            d_output_ptrs,
            d_g_sizes,
            d_g_lds,
            group_size,
            alpha=alpha,
            beta=beta,
        )

    return group_out
def group_mm(A: torch.Tensor, B: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
    """Grouped matmul: rows of ``A`` are split into groups, each multiplied by its own ``B[g]``.

    Args:
        A: 2-D tensor of shape ``(M, K)``.
        B: 3-D tensor of shape ``(num_groups, K, N)``; group ``g`` uses ``B[g]``.
        offs: ``num_groups`` cumulative row end offsets into ``A``. Group ``g``
            covers rows ``[offs[g-1], offs[g])``, with the first group starting
            at row 0.

    Returns:
        A new ``(M, N)`` tensor with ``A``'s dtype, where
        ``C[rows_g] = A[rows_g] @ B[g]``. Rows at or beyond ``offs[-1]`` are not
        written and stay uninitialized.

    Raises:
        AssertionError: on rank mismatches, when ``B.shape[1] != K``, or when
            ``offs`` does not have ``num_groups`` elements.
    """
    assert A.dim() == 2
    assert B.dim() == 3
    M, K = A.shape

    num_groups, BK, N = B.shape
    assert BK == K, f"group_mm: A has K={K} columns but B has {BK} rows per group"
    assert num_groups == offs.numel()

    # Both kernels address B as a flat (num_groups * K, N) matrix in which group
    # g starts at row g * K. That only matches B's real layout when
    # B.stride(0) == K * B.stride(1). Transposed or expanded B needs a copy.
    if num_groups > 1 and B.stride(0) != K * B.stride(1):
        B = B.contiguous()
    strideBK, strideBN = B.stride(1), B.stride(2)

    NUM_SMS = get_sm_count()
    C = A.new_empty(M, N)
    if _support_host_tensor_descriptor and supports_tma():
        # Placeholder block shape; the autotune pre-hook sets the real one per config.
        dummy_block = [1, 1]

        a_desc = TensorDescriptor(A, A.shape, A.stride(), dummy_block)
        b_desc = TensorDescriptor(
            B, [num_groups * K, N], [strideBK, strideBN], dummy_block
        )
        c_desc = TensorDescriptor(C, C.shape, C.stride(), dummy_block)

        grouped_mm_tma_kernel[(NUM_SMS,)](
            a_desc,
            b_desc,
            c_desc,
            C,
            offs,
            num_groups,
            M,
            N,
            K,
            C.stride(0),
            C.stride(1),
        )
    else:
        grouped_mm_kernel[(NUM_SMS,)](
            A,
            B,
            C,
            offs,
            num_groups,
            M,
            N,
            K,
            A.stride(0),
            A.stride(1),
            strideBK,
            strideBN,
            C.stride(0),
            C.stride(1),
        )

    return C