Coverage for src/flag_gems/runtime/backend/_spacemit/ops/mm.py: 0%
102 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import torch
2import triton
3import triton.language as tl
4import triton.language.extra.smt as smt
6from flag_gems import runtime
7from flag_gems.utils import libentry, libtuner
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm_spacemit"),
    key=["M", "N", "K"],
)
@triton.jit
def mm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    SPLIT_M: tl.constexpr,
    SPLIT_N: tl.constexpr,
    SPLIT_K: tl.constexpr,
    SUB_BLK_M: tl.constexpr,
    SUB_BLK_N: tl.constexpr,
    MICRO_M: tl.constexpr,
    MICRO_K: tl.constexpr,
    MICRO_N: tl.constexpr,
    SUB_BLK_K: tl.constexpr,
):
    """Compute one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C = A @ B.

    Program ids 0 and 1 select the tile row and tile column. The constexpr
    flags EVEN_K, SPLIT_M, SPLIT_N and SPLIT_K are tested in that order and
    the first one that is set picks the strategy; if none is set the kernel
    stores nothing.

    Args:
        a_ptr, b_ptr, c_ptr: base pointers of A (M x K), B (K x N), C (M x N).
        M, N, K: problem sizes; also the autotuning key.
        stride_*: element strides of A, B and C along each dimension.
        BLOCK_SIZE_M/N/K: tile extents used for the block pointers.
        SUB_BLK_M/N/K: sub-tile extents used by the SPLIT_* strategies.
        MICRO_M/K/N: micro-tile shape handed to ``smt.view``.

    NOTE(review): the host wrappers visible in this file pass only
    BLOCK_SIZE_K and SUB_BLK_K; the remaining constexprs presumably come from
    the "mm_spacemit" tuned config - confirm.
    NOTE(review): ``smt.*`` are vendor extension ops whose semantics are not
    visible here. ``smt.view(x, offsets, shape, micro_shape)`` presumably
    selects the sub-tile of ``shape`` at ``offsets`` and lays it out in
    ``micro_shape`` packets - TODO confirm against the smt documentation.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Block pointer over the rows of A owned by this program, starting at k = 0.
    a_block_ptr = tl.make_block_ptr(
        base=a_ptr,
        shape=[M, K],
        strides=[stride_am, stride_ak],
        offsets=[pid_m * BLOCK_SIZE_M, 0],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
        order=[1, 0],
    )

    # Block pointer over the columns of B owned by this program, starting at k = 0.
    b_block_ptr = tl.make_block_ptr(
        base=b_ptr,
        shape=[K, N],
        strides=[stride_bk, stride_bn],
        offsets=[0, pid_n * BLOCK_SIZE_N],
        block_shape=[BLOCK_SIZE_K, BLOCK_SIZE_N],
        order=[1, 0],
    )

    if EVEN_K:
        # Single-shot path: one dot over the whole BLOCK_SIZE_K extent.
        a_descriptor_load = smt.descriptor_load(a_block_ptr, (0, 0))
        a = smt.view(
            a_descriptor_load, (0, 0), (BLOCK_SIZE_M, BLOCK_SIZE_K), (MICRO_M, MICRO_K)
        )
        b_descriptor_load = smt.descriptor_load(b_block_ptr, (0, 0))
        b = smt.view(
            b_descriptor_load, (0, 0), (BLOCK_SIZE_K, BLOCK_SIZE_N), (MICRO_K, MICRO_N)
        )
        accumulator = smt.dot(a, b)
        # Re-view the result with (1, 1) micro shape before the store
        # (presumably back to a plain element layout - confirm).
        accumulator = smt.view(
            accumulator, (0, 0), (BLOCK_SIZE_M, BLOCK_SIZE_N), (1, 1)
        )
        c = accumulator.to(c_ptr.dtype.element_ty)
        c_block_ptr = tl.make_block_ptr(
            base=c_ptr,
            shape=[M, N],
            strides=[stride_cm, stride_cn],
            offsets=[pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N],
            block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
            order=[1, 0],
        )
        tl.store(c_block_ptr, c, boundary_check=(0, 1))

    elif SPLIT_M:
        # B is viewed once; the M extent of the tile is processed in
        # SUB_BLK_M-row slices.
        b_descriptor_load = smt.descriptor_load(b_block_ptr, (0, 0))
        b = smt.view(
            b_descriptor_load, (0, 0), (BLOCK_SIZE_K, BLOCK_SIZE_N), (MICRO_K, MICRO_N)
        )
        # Ceil-divide the number of valid rows in this tile (clamped at the
        # matrix edge) by SUB_BLK_M.
        sub_num = (
            min(BLOCK_SIZE_M, M - BLOCK_SIZE_M * pid_m) + SUB_BLK_M - 1
        ) // SUB_BLK_M
        for s in smt.parallel(0, sub_num):
            a_descriptor_load = smt.descriptor_load(a_block_ptr, (0, 0))
            # Rows [s * SUB_BLK_M, (s + 1) * SUB_BLK_M) of this program's A tile.
            a = smt.view(
                a_descriptor_load,
                (s * SUB_BLK_M, 0),
                (SUB_BLK_M, BLOCK_SIZE_K),
                (MICRO_M, MICRO_K),
            )
            accumulator = smt.dot(a, b)
            accumulator = smt.view(
                accumulator, (0, 0), (SUB_BLK_M, BLOCK_SIZE_N), (1, 1)
            )
            c = accumulator.to(c_ptr.dtype.element_ty)
            c_block_ptr = tl.make_block_ptr(
                base=c_ptr,
                shape=[M, N],
                strides=[stride_cm, stride_cn],
                offsets=[pid_m * BLOCK_SIZE_M + s * SUB_BLK_M, pid_n * BLOCK_SIZE_N],
                block_shape=[SUB_BLK_M, BLOCK_SIZE_N],
                order=[1, 0],
            )
            tl.store(c_block_ptr, c, boundary_check=(0, 1))

    elif SPLIT_N:
        # The tile is cut into a sub_num_m x sub_num_n grid of
        # SUB_BLK_M x SUB_BLK_N sub-blocks, flattened into one parallel loop.
        sub_num_m = (
            min(BLOCK_SIZE_M, M - BLOCK_SIZE_M * pid_m) + SUB_BLK_M - 1
        ) // SUB_BLK_M
        sub_num_n = (
            min(BLOCK_SIZE_N, N - BLOCK_SIZE_N * pid_n) + SUB_BLK_N - 1
        ) // SUB_BLK_N
        total_sub_blocks = sub_num_m * sub_num_n
        # Staging buffer for the B tile, filled by the s_m == 0 iterations
        # below and read by every iteration.
        b_alloc_ptr = smt.alloc(shape=[BLOCK_SIZE_K, BLOCK_SIZE_N])
        b_alloc_view_ptr = smt.view(
            b_alloc_ptr, (0, 0), (BLOCK_SIZE_K, BLOCK_SIZE_N), (MICRO_K, MICRO_N)
        )
        # One arrival is expected per column sub-block (see the s_m == 0 branch).
        bar = smt.mbarrier(flag=0, expect_count=sub_num_n)
        for s in smt.parallel(0, total_sub_blocks):
            # Row-major decomposition of the flat index into (s_m, s_n).
            s_m = s // sub_num_n
            s_n = s % sub_num_n
            a_descriptor_load = smt.descriptor_load(a_block_ptr, (0, 0))
            a = smt.view(
                a_descriptor_load,
                (s_m * SUB_BLK_M, 0),
                (SUB_BLK_M, BLOCK_SIZE_K),
                (MICRO_M, MICRO_K),
            )
            b_alloc_sub_ptr = smt.view(
                b_alloc_view_ptr, (0, s_n * SUB_BLK_N), (BLOCK_SIZE_K, SUB_BLK_N)
            )
            if s_m == 0:
                # First sub-block row: copy column sub-block s_n of B into the
                # staging buffer, then signal the barrier.
                b_descriptor_load = smt.descriptor_load(b_block_ptr, (0, 0))
                b = smt.view(
                    b_descriptor_load,
                    (0, s_n * SUB_BLK_N),
                    (BLOCK_SIZE_K, SUB_BLK_N),
                    (MICRO_K, MICRO_N),
                )
                tl.store(b_alloc_sub_ptr, b, boundary_check=(0, 1, 2, 3))
                smt.barrier_arrive(bar)
            else:
                # Other rows wait on the barrier before reading the staged B
                # (presumably until all sub_num_n arrivals happened - confirm).
                smt.barrier_wait(bar, flag=1)

            b_alloc = tl.load(b_alloc_sub_ptr, boundary_check=(0, 1, 2, 3))
            accumulator = smt.dot(a, b_alloc)
            accumulator = smt.view(accumulator, (0, 0), (SUB_BLK_M, SUB_BLK_N), (1, 1))
            c = accumulator.to(c_ptr.dtype.element_ty)
            c_block_ptr = tl.make_block_ptr(
                base=c_ptr,
                shape=[M, N],
                strides=[stride_cm, stride_cn],
                offsets=[
                    pid_m * BLOCK_SIZE_M + s_m * SUB_BLK_M,
                    pid_n * BLOCK_SIZE_N + s_n * SUB_BLK_N,
                ],
                block_shape=[SUB_BLK_M, SUB_BLK_N],
                order=[1, 0],
            )
            tl.store(c_block_ptr, c, boundary_check=(0, 1))

    elif SPLIT_K:
        # Reduction over K in SUB_BLK_K-wide slices, accumulated in place.
        # NOTE(review): the accumulator uses A's element type (and
        # `.type.element_ty`, unlike `.dtype.element_ty` used elsewhere), not a
        # wider accumulation type - confirm this is intended.
        accumulator = tl.zeros(
            (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=a_ptr.type.element_ty
        )
        accumulator = smt.view(
            accumulator, (0, 0), (BLOCK_SIZE_M, BLOCK_SIZE_N), (MICRO_M, MICRO_N)
        )
        # Number of K slices, rounded up.
        sub_num = (K + SUB_BLK_K - 1) // SUB_BLK_K
        for k in tl.range(0, sub_num):
            a_descriptor_load = smt.descriptor_load(a_block_ptr, (0, 0))
            # Columns [k * SUB_BLK_K, (k + 1) * SUB_BLK_K) of the A tile.
            a = smt.view(
                a_descriptor_load,
                (0, k * SUB_BLK_K),
                (BLOCK_SIZE_M, SUB_BLK_K),
                (MICRO_M, MICRO_K),
            )
            b_descriptor_load = smt.descriptor_load(b_block_ptr, (0, 0))
            # Matching rows of the B tile.
            b = smt.view(
                b_descriptor_load,
                (k * SUB_BLK_K, 0),
                (SUB_BLK_K, BLOCK_SIZE_N),
                (MICRO_K, MICRO_N),
            )
            accumulator += smt.dot(a, b)
        accumulator = smt.view(
            accumulator, (0, 0), (BLOCK_SIZE_M, BLOCK_SIZE_N), (1, 1)
        )
        c = accumulator.to(c_ptr.dtype.element_ty)

        c_block_ptr = tl.make_block_ptr(
            base=c_ptr,
            shape=[M, N],
            strides=[stride_cm, stride_cn],
            offsets=[pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N],
            block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
            order=[1, 0],
        )
        tl.store(c_block_ptr, c, boundary_check=(0, 1))
def mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Return the matrix product ``a @ b`` computed with ``mm_kernel``.

    Args:
        a: 2-D left operand of shape (M, K). Copied to a contiguous tensor
            when it is not already contiguous.
        b: 2-D right operand of shape (K, N). Copied to a contiguous tensor
            only when both of its strides are greater than 1.

    Returns:
        A newly allocated (M, N) tensor on ``a``'s device with ``a``'s dtype.

    Raises:
        AssertionError: if ``a.shape[1] != b.shape[0]``.
    """
    if not a.is_contiguous():
        a = a.contiguous()
    # The kernel receives both strides of b, so b is only copied when
    # neither dimension has a stride of at most 1.
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()

    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape

    # allocates output
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)

    # Degenerate shapes need no kernel launch: an empty output has nothing to
    # write, and an empty reduction dimension yields an all-zero product
    # (the same result torch.mm gives). Returning here also avoids deriving
    # BLOCK_SIZE_K / SUB_BLK_K from K == 0.
    if M == 0 or N == 0 or K == 0:
        return c.zero_()

    def grid(META):
        # One program per BLOCK_SIZE_M x BLOCK_SIZE_N output tile.
        return (
            triton.cdiv(M, META["BLOCK_SIZE_M"]),
            triton.cdiv(N, META["BLOCK_SIZE_N"]),
        )

    # The K block spans the whole reduction dimension rounded up to a power
    # of two; the K sub-block is capped at 512.
    BLOCK_SIZE_K = triton.next_power_of_2(K)
    SUB_BLK_K = min(512, BLOCK_SIZE_K)

    # launch kernel
    mm_kernel[grid](
        a,
        b,
        c,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        SUB_BLK_K=SUB_BLK_K,
    )
    return c
def mm_out(a: torch.Tensor, b: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor:
    """Write the matrix product ``a @ b`` into ``out`` using ``mm_kernel``.

    Args:
        a: 2-D left operand of shape (M, K). Copied to a contiguous tensor
            when it is not already contiguous.
        b: 2-D right operand of shape (K, N). Copied to a contiguous tensor
            only when both of its strides are greater than 1.
        out: destination tensor (keyword-only). Its strides are passed to the
            kernel as the strides of the (M, N) result; its shape and dtype
            are not validated here.

    Returns:
        ``out`` itself.

    Raises:
        AssertionError: if ``a.shape[1] != b.shape[0]``.
    """
    if not a.is_contiguous():
        a = a.contiguous()
    # The kernel receives both strides of b, so b is only copied when
    # neither dimension has a stride of at most 1.
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()

    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape

    # Degenerate shapes need no kernel launch: an empty output has nothing to
    # write, and an empty reduction dimension yields an all-zero product
    # (the same result torch.mm gives). Returning here also avoids deriving
    # BLOCK_SIZE_K / SUB_BLK_K from K == 0.
    if M == 0 or N == 0 or K == 0:
        return out.zero_()

    def grid(META):
        # One program per BLOCK_SIZE_M x BLOCK_SIZE_N output tile.
        return (
            triton.cdiv(M, META["BLOCK_SIZE_M"]),
            triton.cdiv(N, META["BLOCK_SIZE_N"]),
        )

    # The K block spans the whole reduction dimension rounded up to a power
    # of two; the K sub-block is capped at 512.
    BLOCK_SIZE_K = triton.next_power_of_2(K)
    SUB_BLK_K = min(512, BLOCK_SIZE_K)

    # launch kernel
    mm_kernel[grid](
        a,
        b,
        out,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        out.stride(0),
        out.stride(1),
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        SUB_BLK_K=SUB_BLK_K,
    )
    return out