Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/mm.py: 0%
97 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry, libtuner
# Module logger: a child of the package-wide "flag_gems" logger named after
# this module's dotted path, with any leading dots stripped from the name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    key=["M", "N", "K", "stride_am", "stride_bk", "stride_ak", "stride_bn"],
    # One strategy entry per tuning key above (seven keys, seven entries).
    strategy=["align32"] * 7,
    warmup=1,
    rep=2,
)
@triton.heuristics(runtime.get_heuristic_config("mm"))
@triton.jit
def mm_kernel(
    A,
    B,
    C,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    dot_out_dtype: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    UPCAST: tl.constexpr,
):
    """Compute one ``BLOCK_M x BLOCK_N`` tile of ``C = A @ B``.

    Grid axis 0 enumerates output tiles (in a grouped order); grid axis 1 is
    the split-K slice index.  Each program walks the K dimension in steps of
    ``BLOCK_K * SPLIT_K``, starting at ``program_id(1) * BLOCK_K``.

    Args:
        A, B, C: Base pointers of the two operands and of the output.
        M, N, K: Problem sizes; A is indexed as (M, K), B as (K, N) and C as
            (M, N) through the strides below.
        stride_am, stride_ak: Element strides of A along M and K.
        stride_bk, stride_bn: Element strides of B along K and N.
        stride_cm, stride_cn: Element strides of C along M and N.
        dot_out_dtype: Accumulator dtype, also passed to ``tl.dot``.
        BLOCK_M, BLOCK_N, BLOCK_K: Tile sizes.
        GROUP_M: Number of tile rows placed in one scheduling group.
        SPLIT_K: Number of programs sharing one output tile along K.  When it
            is not 1 the tile is accumulated into C with ``tl.atomic_add``,
            i.e. it is added to whatever C already holds.
        EVEN_K: When true, loads are not masked along K.
            NOTE(review): presumably provided by the "mm" heuristics config
            when K divides evenly -- confirm there.
        UPCAST: When true, program ids are widened to int64 before they are
            used in offset arithmetic.
    """
    # Program ids, optionally promoted to 64-bit for the offset math below.
    if UPCAST:
        pid = tl.program_id(0).to(tl.int64)
        pid_k = tl.program_id(1).to(tl.int64)
    else:
        pid = tl.program_id(0)
        pid_k = tl.program_id(1)

    num_tiles_m = tl.cdiv(M, BLOCK_M)
    num_tiles_n = tl.cdiv(N, BLOCK_N)

    # Re-order program ids for better L2 performance: consecutive ids first
    # walk down up to GROUP_M tile rows before moving to the next tile column.
    tiles_per_group = GROUP_M * num_tiles_n
    group_id = pid // tiles_per_group
    first_tile_m = group_id * GROUP_M
    rows_in_group = min(num_tiles_m - first_tile_m, GROUP_M)
    pid_m = first_tile_m + (pid % rows_in_group)
    pid_n = (pid % tiles_per_group) // (rows_in_group)

    # Element offsets covered by this program along each dimension.
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)

    # Pointer grids for the first A tile (BLOCK_M x BLOCK_K) and the first
    # B tile (BLOCK_K x BLOCK_N).
    a_ptrs = A + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            # Only the M / N edges need masking.
            a_tile = tl.load(a_ptrs, mask=(offs_m < M)[:, None], other=0.0)
            b_tile = tl.load(b_ptrs, mask=(offs_n < N)[None, :], other=0.0)
        else:
            # Additionally zero-fill the part of the tile that lies past K.
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            a_tile = tl.load(
                a_ptrs,
                mask=(offs_k[None, :] < k_remaining) & (offs_m < M)[:, None],
                other=0.0,
            )
            b_tile = tl.load(
                b_ptrs,
                mask=(offs_k[:, None] < k_remaining) & (offs_n < N)[None, :],
                other=0.0,
            )

        # Mixed input dtypes: bring both tiles to the output element type.
        if a_tile.dtype != b_tile.dtype:
            a_tile = a_tile.to(C.dtype.element_ty)
            b_tile = b_tile.to(C.dtype.element_ty)

        acc += tl.dot(a_tile, b_tile, out_dtype=dot_out_dtype, allow_tf32=False)

        # Advance both operands to this program's next K slice.
        a_ptrs += BLOCK_K * SPLIT_K * stride_ak
        b_ptrs += BLOCK_K * SPLIT_K * stride_bk

    acc = acc.to(C.dtype.element_ty)

    # Rematerialize the M / N offsets to save registers.
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = C + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
    out_mask = (offs_m < M)[:, None] & (offs_n < N)[None, :]

    # Write-back: a plain store when one program owns the tile, otherwise an
    # atomic accumulation shared by the SPLIT_K programs of that tile.
    if SPLIT_K == 1:
        tl.store(c_ptrs, acc, mask=out_mask)
    else:
        tl.atomic_add(c_ptrs, acc, mask=out_mask)
# Floating-point dtypes accepted for mixed-dtype inputs, listed from lowest to
# highest rank; a later entry wins over an earlier one.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]


def get_higher_dtype(a, b):
    """Return whichever of two dtypes ranks higher in ``_ordered_datatypes``.

    Args:
        a: First dtype.
        b: Second dtype.

    Returns:
        ``a`` when both arguments are the same object (no membership check is
        performed in that case); otherwise the argument that appears later in
        ``_ordered_datatypes``.

    Raises:
        AssertionError: If the dtypes differ and either one is not listed in
            ``_ordered_datatypes``.
    """
    if a is b:
        return a

    # Raised explicitly instead of via ``assert`` so the check is still
    # enforced under ``python -O``; without it two unsupported dtypes would
    # fall through and silently yield ``None``.
    for dtype in (a, b):
        if dtype not in _ordered_datatypes:
            raise AssertionError(f"unsupported dtype: {dtype}")

    return max(a, b, key=_ordered_datatypes.index)
def mm(a, b):
    """Multiply two 2-D tensors with ``mm_kernel`` and return a new tensor.

    Args:
        a: Left operand, unpacked as ``(M, K)``.
        b: Right operand, unpacked as ``(K, N)``.

    Returns:
        A freshly allocated ``(M, N)`` tensor on ``a.device`` whose dtype is
        ``get_higher_dtype(a.dtype, b.dtype)``.

    Raises:
        AssertionError: If ``a.shape[1] != b.shape[0]``.
    """
    logger.debug("GEMS_TSINGMICRO MM")
    device = a.device

    # An operand is copied only when both of its strides exceed 1; otherwise
    # it is handed to the kernel as-is and addressed through its strides.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()

    # Raised explicitly (not via ``assert``) so the check survives
    # ``python -O``; a mismatched launch would read outside the operands.
    if a.shape[1] != b.shape[0]:
        raise AssertionError("incompatible dimensions")
    M, K = a.shape
    _, N = b.shape

    # NOTE(review): the output is uninitialized, while the kernel accumulates
    # with atomic_add when SPLIT_K != 1 -- confirm the tuned "mm" configs keep
    # SPLIT_K == 1 or zero the output before launch.
    c = torch.empty(
        (M, N), device=device, dtype=get_higher_dtype(a.dtype, b.dtype)
    )
    dot_out_dtype = tl.float32

    # Ask the kernel for 64-bit program ids once any size * stride product
    # reaches 2**31.
    UPCAST = (
        M * max(a.stride(0), c.stride(0)) >= 1 << 31
        or N * max(b.stride(1), c.stride(1)) >= 1 << 31
        or K * max(a.stride(1), b.stride(0)) >= 1 << 31
    )

    def grid(META):
        # Axis 0: one program per output tile; axis 1: split-K slices.
        return (
            triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
            META["SPLIT_K"],
        )

    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            dot_out_dtype=dot_out_dtype,
            GROUP_M=8,
            UPCAST=UPCAST,
        )
    return c
def mm_out(a, b, *, out):
    """Multiply two 2-D tensors with ``mm_kernel``, writing into ``out``.

    Args:
        a: Left operand, unpacked as ``(M, K)``.
        b: Right operand, unpacked as ``(K, N)``.
        out: Destination tensor (keyword-only).  It is addressed through its
            own strides and the accumulator is cast to its element type.

    Returns:
        ``out``, the same tensor object that was passed in.

    Raises:
        AssertionError: If ``a.shape[1] != b.shape[0]``.
    """
    logger.debug("GEMS_TSINGMICRO MM_OUT")

    # An operand is copied only when both of its strides exceed 1; otherwise
    # it is handed to the kernel as-is and addressed through its strides.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()

    # Raised explicitly (not via ``assert``) so the check survives
    # ``python -O``; a mismatched launch would read outside the operands.
    if a.shape[1] != b.shape[0]:
        raise AssertionError("incompatible dimensions")
    M, K = a.shape
    _, N = b.shape

    # NOTE(review): ``out`` is not validated here.  The kernel masks writes by
    # M and N only, so an ``out`` smaller than (M, N) would be written out of
    # bounds, and with SPLIT_K != 1 results are added to its prior contents --
    # confirm callers guarantee shape and initialization.
    dot_out_dtype = tl.float32

    # Ask the kernel for 64-bit program ids once any size * stride product
    # reaches 2**31.
    UPCAST = (
        M * max(a.stride(0), out.stride(0)) >= 1 << 31
        or N * max(b.stride(1), out.stride(1)) >= 1 << 31
        or K * max(a.stride(1), b.stride(0)) >= 1 << 31
    )

    def grid(META):
        # Axis 0: one program per output tile; axis 1: split-K slices.
        return (
            triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
            META["SPLIT_K"],
        )

    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            out,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            out.stride(0),
            out.stride(1),
            dot_out_dtype=dot_out_dtype,
            GROUP_M=8,
            UPCAST=UPCAST,
        )
    return out