Coverage for src/flag_gems/ops/mm.py: 40%
159 statements
coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.ops.mm_streamk import streamk_mm
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry, libtuner
11from flag_gems.utils import triton_lang_extension as tle
12from flag_gems.utils.device_info import get_device_capability, get_sm_count
# Module logger; the public entry points below emit a debug line on every call.
logger = logging.getLogger(__name__)

# NOTE(review): not referenced anywhere in this module; presumably a cache-usage
# heuristic threshold consumed elsewhere — confirm before removing.
CACHE_USAGE_THRESHOLD = 0.8
@triton.jit
def prev_multiple_of(a, b):
    # Round `a` up to the next multiple of `b`, then step back by one multiple.
    # For a >= 1 this is the largest multiple of `b` strictly below `a`
    # (a=10, b=4 -> 8; a=8, b=4 -> 4). For a == 0 it yields -b.
    return (tl.cdiv(a, b) - 1) * b
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    # Add 'stride_am' and 'stride_bk' to trigger autotune for tensors with the same shape but different strides.
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=["align32", "align32", "align32", "align32", "align32"],
    warmup=5,
    rep=10,
)
@triton.jit
def mm_kernel_general(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Compute one (BLOCK_M, BLOCK_N) tile of C = A @ B.

    A is addressed with (stride_am, stride_ak), B with (stride_bk, stride_bn)
    and C with (stride_cm, stride_cn). Program ids are grouped GROUP_M tile
    rows at a time. The K loop runs unmasked over [0, prev_multiple) and the
    final (possibly partial) K block is peeled off and loaded with a mask.
    Accumulation is float64 when IS_FP64, float32 otherwise; the result is cast
    to C's element type before the masked store.
    """
    # matrix multiplication
    pid = tle.program_id(0)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Row/column indices wrapped modulo M/N so the unmasked loads below stay in
    # bounds; rows/columns past the edge are discarded by the store mask.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N).to(tl.int64)
    rm = rm.to(tl.int64)
    rn = rn.to(tl.int64)
    prev_multiple = prev_multiple_of(K, BLOCK_K)

    if IS_FP64:
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64)
    else:
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for start_k in range(0, prev_multiple, BLOCK_K):
        rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
        a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
        b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
        # Mixed input element types: bring both operands to C's element type.
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        if IS_FP64:
            acc += tl.dot(a, b, allow_tf32=False)
        else:
            acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    # loop peeling: the last K block, masked against K.
    rk = (prev_multiple + tl.arange(0, BLOCK_K)).to(tl.int64)
    # When K == 0, prev_multiple is -BLOCK_K, so every rk is negative and
    # `rk < K` alone would be all-true, reading before the start of A/B.
    # Requiring rk >= 0 masks that block out entirely, so C comes out as zeros.
    mask_k = (rk >= 0) & (rk < K)
    a = tl.load(
        A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
        mask=mask_k[None, :],
        other=0.0,
    )
    b = tl.load(
        B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
        mask=mask_k[:, None],
        other=0.0,
    )
    if a.dtype != b.dtype:
        a = a.to(C.dtype.element_ty)
        b = b.to(C.dtype.element_ty)
    if IS_FP64:
        acc += tl.dot(a, b, allow_tf32=False)
    else:
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
    rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    tl.store(C, acc, mask=mask)
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32, torch.float64]


def get_higher_dtype(a, b):
    """Return whichever of two dtypes ranks higher in ``_ordered_datatypes``.

    The ranking is float16 < bfloat16 < float32 < float64. Identical dtypes are
    returned as-is (even when not in the ranking); otherwise both dtypes are
    asserted to be members of ``_ordered_datatypes``.
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    # The first ranked dtype matching either argument is the lower of the two,
    # so the *other* argument is the higher one.
    lower = next((dt for dt in _ordered_datatypes if dt is a or dt is b), None)
    if lower is None:
        # Only reachable when assertions are stripped (python -O).
        return None
    return b if lower is a else a
def general_mm(a, b, c, M, N, K):
    """Launch ``mm_kernel_general`` to compute ``c = a @ b`` and return ``c``.

    One program is launched per (BLOCK_M, BLOCK_N) output tile, with the
    kernel's tile-grouping width fixed at GROUP_M=8. The float64 accumulation
    path is selected when ``a`` is float64.
    """

    def grid(meta):
        tiles_m = triton.cdiv(M, meta["BLOCK_M"])
        tiles_n = triton.cdiv(N, meta["BLOCK_N"])
        return (tiles_m * tiles_n,)

    use_fp64 = a.dtype == torch.float64
    with torch_device_fn.device(a.device):
        mm_kernel_general[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            GROUP_M=8,
            IS_FP64=use_fp64,
        )
    return c
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm_self_transpose"),
    key=["M", "K", "stride_am", "stride_ak"],
    strategy=["align32", "align32", "align32", "align32"],
    warmup=2,
    rep=4,
)
@triton.jit
def mm_kernel_syrk(
    A,
    C,
    M,
    K,
    stride_am,
    stride_ak,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Compute C = A @ A^T (M x M) using square BLOCK_M x BLOCK_M tiles.

    Each program owns one lower-triangular tile (pid_m, pid_n) with
    pid_n <= pid_m. It stores that tile and, for off-diagonal tiles, also
    stores the transposed tile at (pid_n, pid_m). The upper-triangle tiles are
    therefore never computed separately.

    Both operands are read from A via (stride_am, stride_ak); C is written via
    (stride_cm, stride_cn). Accumulation is always float32 — unlike
    mm_kernel_general there is no float64 path here.
    """
    pid = tl.program_id(0)

    # Packed lower-triangular launch domain:
    # pid = row * (row + 1) / 2 + col, where 0 <= col <= row.
    #
    # Invert the triangular-number indexing by solving:
    # row^2 + row - 2 * pid = 0
    # => row = (-1 + sqrt(1 + 8 * pid)) / 2
    #
    # We take floor(...) as the candidate row, then apply an integer +/-1 correction
    # because fp32 sqrt can be off near triangular-number boundaries.
    pid_f = pid.to(tl.float32)
    pid_m = tl.floor((tl.sqrt(8.0 * pid_f + 1.0) - 1.0) / 2.0).to(tl.int32)
    tri_start = pid_m * (pid_m + 1) // 2
    # Candidate row overshot: the row's first pid lies past this pid.
    pid_m = tl.where(tri_start > pid, pid_m - 1, pid_m)
    next_tri_start = (pid_m + 1) * (pid_m + 2) // 2
    # Candidate row undershot: this pid already belongs to the next row.
    pid_m = tl.where(next_tri_start <= pid, pid_m + 1, pid_m)
    tri_start = pid_m * (pid_m + 1) // 2
    # Column within the row, so 0 <= pid_n <= pid_m.
    pid_n = pid - tri_start

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_M + tl.arange(0, BLOCK_M)
    # Row indices wrapped modulo M so loads stay in bounds; out-of-range rows
    # are discarded by the store masks below.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M).to(tl.int64)
    ran = tl.max_contiguous(tl.multiple_of(rn % M, BLOCK_M), BLOCK_M).to(tl.int64)
    rm = rm.to(tl.int64)
    rn = rn.to(tl.int64)
    acc = tl.zeros((BLOCK_M, BLOCK_M), dtype=tl.float32)

    # No loop peeling here: every K block is masked against K.
    for start_k in range(0, K, BLOCK_K):
        rk = (start_k + tl.arange(0, BLOCK_K)).to(tl.int64)
        mask_k = rk < K
        a = tl.load(
            A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
            mask=mask_k[None, :],
            other=0.0,
        )
        # Second operand is also read from A, with the stride roles swapped:
        # b[k, n] = A[ran[n], rk[k]], i.e. the transpose of rows `ran` of A.
        b = tl.load(
            A + (rk[:, None] * stride_ak + ran[None, :] * stride_am),
            mask=mask_k[:, None],
            other=0.0,
        )
        acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

    out = acc.to(C.dtype.element_ty)
    c_ptr = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < M)[None, :]
    tl.store(c_ptr, out, mask=mask)

    # Off-diagonal tile: mirror it into the upper triangle (C is symmetric).
    if pid_m > pid_n:
        c_t_ptr = C + (rn[:, None] * stride_cm + rm[None, :] * stride_cn)
        mask_t = (rn < M)[:, None] & (rm < M)[None, :]
        tl.store(c_t_ptr, tl.trans(out), mask=mask_t)
def is_syrk_transpose_pair(a, b):
    """Return True when ``b`` is the transposed view of ``a`` and syrk applies.

    The syrk path computes ``a @ a.T`` while reading only ``a``. It is therefore
    valid only when ``b`` aliases exactly the same elements with swapped
    shape/strides. ``b`` must also share ``a``'s dtype; a same-storage
    ``view(dtype)`` would otherwise be silently reinterpreted as ``a``'s type.
    float64 is excluded because the syrk kernel always accumulates in float32
    and has no float64 path, whereas the general kernel does.
    """
    return (
        a.ndim == 2
        and b.ndim == 2
        and a.dtype == b.dtype
        and a.dtype != torch.float64
        and a.shape[0] == b.shape[1]
        and a.shape[1] == b.shape[0]
        and a.stride(0) == b.stride(1)
        and a.stride(1) == b.stride(0)
        and a.storage_offset() == b.storage_offset()
        and a.data_ptr() == b.data_ptr()
    )
def syrk_mm(a, c, M, K):
    """Compute ``c = a @ a.T`` with ``mm_kernel_syrk`` and return ``c``.

    ``a`` is (M, K) and ``c`` receives the (M, M) result. Only the packed
    lower triangle of tiles is launched; the kernel mirrors off-diagonal tiles
    into the upper triangle itself.
    """

    def grid(meta):
        # With tiles = ceil(M / BLOCK_M) tile rows, the packed lower triangle
        # (diagonal included) holds 1 + 2 + ... + tiles = tiles * (tiles + 1) / 2.
        tiles = triton.cdiv(M, meta["BLOCK_M"])
        return (tiles * (tiles + 1) // 2,)

    launch_args = (a, c, M, K, a.stride(0), a.stride(1), c.stride(0), c.stride(1))
    with torch_device_fn.device(a.device):
        mm_kernel_syrk[grid](*launch_args)
    return c
def streamk_scenario(a, b, M, N, K):
    """Return True when the stream-K kernel should handle this GEMM.

    TODO: these criteria may change according to real benchmark results.
    The stream-K settings have only been tuned on A100 (compute capability
    major version 8); the best settings for other devices still need to be
    determined through real testing.
    """
    # NOTE(review): assumes get_device_capability() returns (major, minor).
    major = get_device_capability()[0]
    half_precision = (torch.float16, torch.bfloat16)

    if major != 8:
        return False
    if a.dtype not in half_precision or b.dtype not in half_precision:
        return False
    if not (a.is_contiguous() and b.is_contiguous()):
        return False
    # Only problems whose K exceeds five times both M and N qualify.
    return K > M * 5 and K > N * 5
def mm(a, b):
    """Return a newly allocated ``a @ b`` for 2-D tensors a (M, K) and b (K, N).

    Dispatch order:
      1. ``b`` is exactly the transposed view of ``a`` -> syrk kernel
         (result dtype is ``a.dtype``).
      2. ``streamk_scenario`` accepts the problem -> stream-K kernel.
      3. Otherwise -> general tiled kernel.
    For paths 2 and 3 the output dtype is ``get_higher_dtype(a.dtype, b.dtype)``.
    """
    logger.debug("GEMS MM")

    device = a.device
    if is_syrk_transpose_pair(a, b):
        rows, inner = a.shape
        result = torch.empty((rows, rows), device=device, dtype=a.dtype)
        return syrk_mm(a, result, rows, inner)

    # Copy an operand only when neither of its dimensions has stride <= 1.
    a, b = (
        t.contiguous() if t.stride(0) > 1 and t.stride(1) > 1 else t
        for t in (a, b)
    )

    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    rows, inner = a.shape
    _, cols = b.shape

    result = torch.empty(
        (rows, cols), device=device, dtype=get_higher_dtype(a.dtype, b.dtype)
    )
    sm_count = get_sm_count()
    if not streamk_scenario(a, b, rows, cols, inner):
        return general_mm(a, b, result, rows, cols, inner)
    return streamk_mm(a, b, result, rows, cols, inner, sm_count=sm_count)
def mm_out(a, b, *, out):
    """Compute ``a @ b`` into the caller-provided ``out`` and return the result.

    Uses the same dispatch as ``mm`` (syrk for an exact transposed-view pair,
    then stream-K when ``streamk_scenario`` accepts, else the general kernel).
    NOTE(review): ``out``'s shape and dtype are not validated here.
    """
    logger.debug("GEMS MM_OUT")

    if is_syrk_transpose_pair(a, b):
        rows, inner = a.shape
        return syrk_mm(a, out, rows, inner)

    # Copy an operand only when neither of its dimensions has stride <= 1.
    a, b = (
        t.contiguous() if t.stride(0) > 1 and t.stride(1) > 1 else t
        for t in (a, b)
    )

    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    rows, inner = a.shape
    _, cols = b.shape

    sm_count = get_sm_count()
    if not streamk_scenario(a, b, rows, cols, inner):
        return general_mm(a, b, out, rows, cols, inner)
    return streamk_mm(a, b, out, rows, cols, inner, sm_count=sm_count)