Coverage for src/flag_gems/runtime/backend/_spacemit/ops/conv2d.py: 0%
108 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6import triton.language.extra.smt as smt
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
11logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def fused_im2col_bmm_kernel(
    input_ptr,
    weight_ptr,
    bias_ptr,
    output_ptr,
    im2col_buf_ptr,
    N,
    C,
    IH,
    IW,
    KH,
    KW,
    OC,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    dilation_h,
    dilation_w,
    OH,
    OW,
    GEMM_M,
    GEMM_K,
    KK,
    input_stride_n,
    input_stride_h,
    input_stride_w,
    input_stride_c,
    im2col_stride_n,
    im2col_stride_m,
    im2col_stride_k,
    weight_stride_oc,
    weight_stride_k,
    output_stride_n,
    output_stride_oc,
    output_stride_m,
    NUM_IM2COL_BLOCKS: tl.constexpr,
    NUM_BMM_TILES_PER_BATCH: tl.constexpr,
    NUM_TILES_N: tl.constexpr,
    BLOCK_SIZE_C: tl.constexpr,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    SUB_BLK_M: tl.constexpr,
    MICRO_M: tl.constexpr,
    MICRO_K: tl.constexpr,
    MICRO_N: tl.constexpr,
):
    """Single-launch conv2d: im2col producers followed by GEMM consumers.

    The 1-D launch grid is split in two by program id:

    * ``pid < NUM_IM2COL_BLOCKS``: the program handles one (batch, output
      pixel) pair and writes the matching row of the scratch buffer addressed
      as ``(N, GEMM_M, GEMM_K)``, then arrives on a global barrier.
    * ``pid >= NUM_IM2COL_BLOCKS``: the program waits on that barrier, then
      computes one ``(TILE_N, TILE_M)`` tile of ``weight x im2col^T`` for one
      batch element and stores it (plus optional bias) into the output,
      addressed per batch as ``(OC, GEMM_M)``.

    The input is addressed as ``(N, IH, IW, C)`` through the four
    ``input_stride_*`` arguments, the weight as ``(OC, GEMM_K)``.

    NOTE(review): the ``smt.*`` calls come from a vendor Triton extension
    whose semantics are not visible in this file; comments about them below
    describe presumed intent and should be confirmed against its docs.
    """
    pid = tl.program_id(0)

    # Indices used by the im2col role: batch index and flattened output pixel.
    n_im2col = pid // (OH * OW)
    ohow = pid % (OH * OW)
    oh = ohow // OW
    ow = ohow % OW
    # Top-left input coordinate of the receptive field (may be negative
    # because of padding).
    window_h = oh * stride_h - pad_h
    window_w = ow * stride_w - pad_w

    # Indices used by the GEMM role. The clamp to 0 keeps these values
    # non-negative in im2col programs, where they are computed but unused.
    bmm_pid = tl.maximum(pid - NUM_IM2COL_BLOCKS, 0)
    pid_b = bmm_pid // NUM_BMM_TILES_PER_BATCH
    local_tile = bmm_pid % NUM_BMM_TILES_PER_BATCH
    pid_m = local_tile // NUM_TILES_N
    pid_n = local_tile % NUM_TILES_N
    block_m = pid_m * TILE_M
    block_n = pid_n * TILE_N
    # presumably a device-wide barrier shared by every program of this launch
    # -- TODO confirm against the smt extension
    bar = smt.global_mbarrier(0)
    is_im2col = pid < NUM_IM2COL_BLOCKS

    if is_im2col:
        # Block pointer over the input, positioned at this program's batch.
        input_block_ptr = tl.make_block_ptr(
            base=input_ptr,
            shape=(N, IH, IW, C),
            strides=(input_stride_n, input_stride_h, input_stride_w, input_stride_c),
            offsets=(n_im2col, 0, 0, 0),
            block_shape=(1, 1, 1, BLOCK_SIZE_C),
            order=(3, 2, 1, 0),
        )
        # Block pointer over the scratch buffer, positioned at the row for
        # this (batch, output pixel).
        output_col_base_ptr = tl.make_block_ptr(
            base=im2col_buf_ptr,
            shape=(N, GEMM_M, GEMM_K),
            strides=(im2col_stride_n, im2col_stride_m, im2col_stride_k),
            offsets=(n_im2col, ohow, 0),
            block_shape=(1, 1, BLOCK_SIZE_C),
            order=(2, 1, 0),
        )

        for kh in range(KH):
            for kw in range(KW):
                # Input coordinate sampled by kernel tap (kh, kw).
                h = window_h + kh * dilation_h
                w = window_w + kw * dilation_w
                valid_h = (h >= 0) & (h < IH)
                valid_w = (w >= 0) & (w < IW)
                valid = valid_h & valid_w
                for c_start in range(0, C, BLOCK_SIZE_C):
                    if valid:
                        input_ptr_cur = tl.advance(input_block_ptr, (0, h, w, c_start))
                        vals = tl.load(input_ptr_cur, boundary_check=(0, 1, 2, 3))
                        vals = tl.reshape(vals, (1, 1, BLOCK_SIZE_C))
                    else:
                        # Tap falls into the padding region: contribute zeros.
                        vals = tl.zeros(
                            (1, 1, BLOCK_SIZE_C), dtype=input_ptr.dtype.element_ty
                        )
                    # NOTE(review): the BLOCK_SIZE_C values are stored at
                    # consecutive K positions (stepping by im2col_stride_k)
                    # starting at col_idx, while col_idx itself advances by
                    # KK per channel. Confirm this matches the K ordering of
                    # the weight rows when KK > 1.
                    col_idx = c_start * KK + kh * KW + kw
                    output_ptr_cur = tl.advance(output_col_base_ptr, (0, 0, col_idx))
                    tl.store(output_ptr_cur, vals, boundary_check=(0, 1, 2))
        # Signal that this program's im2col row has been written.
        smt.barrier_arrive(bar)

    else:
        # Only the first GEMM program sets the barrier's expected count to the
        # number of im2col programs.
        if pid == NUM_IM2COL_BLOCKS:
            smt.barrier_set_expect(bar, NUM_IM2COL_BLOCKS)

        # presumably blocks until all im2col programs have arrived
        # -- TODO confirm ordering guarantees of set_expect vs. wait
        smt.barrier_wait(bar)
        # A operand: TILE_M rows of this batch's im2col matrix, all of K.
        a_ptr = tl.make_block_ptr(
            base=im2col_buf_ptr,
            shape=(N, GEMM_M, GEMM_K),
            strides=(im2col_stride_n, im2col_stride_m, im2col_stride_k),
            offsets=(pid_b, block_m, 0),
            block_shape=(1, TILE_M, TILE_K),
            order=(2, 1, 0),
        )

        # B operand: TILE_N rows of the weight addressed as (OC, GEMM_K).
        b_ptr = tl.make_block_ptr(
            base=weight_ptr,
            shape=(OC, GEMM_K),
            strides=(weight_stride_oc, weight_stride_k),
            offsets=(block_n, 0),
            block_shape=(TILE_N, TILE_K),
            order=(1, 0),
        )

        if HAS_BIAS:
            bias_block_ptr = tl.make_block_ptr(
                base=bias_ptr,
                shape=(OC,),
                strides=(1,),
                offsets=(block_n,),
                block_shape=(TILE_N,),
                order=(0,),
            )
            bias_vals = tl.load(bias_block_ptr, boundary_check=(0,))
        # Move the output base to this program's batch element so the stores
        # below can address a 2-D (OC, GEMM_M) view.
        output_ptr = output_ptr + pid_b * output_stride_n

        # Load the whole (1, TILE_M, TILE_K) slab once, drop the batch axis
        # and transpose it to (TILE_K, TILE_M).
        a_tile = tl.load(a_ptr, boundary_check=(0, 1, 2))
        a_tile = tl.trans(tl.reshape(a_tile, (TILE_M, TILE_K)))
        b_descriptor_load = smt.descriptor_load(b_ptr, (0, 0))
        b = smt.view(b_descriptor_load, (0, 0), (TILE_N, TILE_K), (MICRO_N, MICRO_K))
        # Number of SUB_BLK_M-wide column groups needed to cover the rows of
        # this M tile that lie inside GEMM_M (the last tile may be partial).
        sub_num = (min(TILE_M, GEMM_M - TILE_M * pid_m) + SUB_BLK_M - 1) // SUB_BLK_M
        for s in smt.parallel(0, sub_num):
            # (TILE_K, SUB_BLK_M) window of the transposed im2col tile.
            a = smt.view(
                a_tile, (0, s * SUB_BLK_M), (TILE_K, SUB_BLK_M), (MICRO_K, MICRO_M)
            )
            acc = smt.dot(b, a)
            acc = smt.view(acc, (0, 0), (TILE_N, SUB_BLK_M), (1, 1))
            if HAS_BIAS:
                # One bias value per output channel, broadcast across columns.
                acc += bias_vals[:, None]
            acc = acc.to(output_ptr.dtype.element_ty)
            o_ptr = tl.make_block_ptr(
                base=output_ptr,
                shape=(OC, GEMM_M),
                strides=(output_stride_oc, output_stride_m),
                offsets=(block_n, block_m + s * SUB_BLK_M),
                block_shape=(TILE_N, SUB_BLK_M),
                order=(1, 0),
            )
            tl.store(o_ptr, acc, boundary_check=(0, 1))
def conv2d(input, weight, bias=None, padding=0, stride=1, dilation=1, groups=1):
    """Compute a 2-D convolution with the fused im2col + GEMM Triton kernel.

    Args:
        input: 4-D tensor unpacked as ``(N, C, H, W)``.
        weight: 4-D tensor unpacked as ``(OC, _, KH, KW)``.
        bias: Optional tensor with one value per output channel, or ``None``.
        padding: An int applied to both spatial dims, or a ``(pad_h, pad_w)``
            pair.
        stride: An int or a ``(stride_h, stride_w)`` pair.
        dilation: An int or a ``(dilation_h, dilation_w)`` pair.
        groups: Must be 1; grouped convolution is not implemented here.

    Returns:
        A new ``(N, OC, OH, OW)`` tensor with the dtype and device of
        ``input``.

    Raises:
        NotImplementedError: If ``groups != 1``.
    """
    # The kernel is told every weight row has C * KH * KW elements. With
    # groups > 1 the rows only hold (C // groups) * KH * KW, so it would read
    # past each row. Reject up front instead of returning wrong results.
    if groups != 1:
        raise NotImplementedError(
            f"GEMS_SPACEMIT conv2d only supports groups=1, got groups={groups}"
        )

    logger.debug("GEMS_SPACEMIT CONV2D")

    N, C, H, W = input.shape
    OC, _, KH, KW = weight.shape

    str_h, str_w = (stride, stride) if isinstance(stride, int) else stride
    pad_h, pad_w = (padding, padding) if isinstance(padding, int) else padding
    dil_h, dil_w = (dilation, dilation) if isinstance(dilation, int) else dilation

    OH = (H + 2 * pad_h - dil_h * (KH - 1) - 1) // str_h + 1
    OW = (W + 2 * pad_w - dil_w * (KW - 1) - 1) // str_w + 1

    GEMM_M = OH * OW
    KK = KH * KW
    GEMM_K = C * KK

    output = torch.empty((N, OC, OH, OW), dtype=input.dtype, device=input.device)
    # Nothing to compute for an empty result; skip the scratch allocation and
    # the zero-size launch.
    if output.numel() == 0:
        return output

    # Scratch buffer the kernel fills with the im2col matrix of every batch.
    im2col_buf = torch.empty(
        (N, GEMM_M, GEMM_K), dtype=input.dtype, device=input.device
    )

    # The kernel reads the input through explicit (n, h, w, c) strides.
    input_nhwc = input.permute(0, 2, 3, 1).contiguous()
    # reshape (not view) so non-contiguous weights, e.g. channels_last, are
    # copied into (OC, C * KH * KW) order instead of raising.
    weight_flat = weight.reshape(OC, -1).contiguous()

    # One im2col program per (batch, output pixel).
    NUM_IM2COL_BLOCKS = N * OH * OW

    TILE_M = 128
    TILE_N = 128
    # The kernel loads the whole K extent in a single block.
    TILE_K = triton.next_power_of_2(GEMM_K)
    BLOCK_SIZE_C = 32
    SUB_BLK_M = 32
    MICRO_M = 8
    MICRO_K = 8
    MICRO_N = 16

    num_tiles_m = triton.cdiv(GEMM_M, TILE_M)
    num_tiles_n = triton.cdiv(OC, TILE_N)
    NUM_BMM_TILES_PER_BATCH = num_tiles_m * num_tiles_n
    NUM_BMM_BLOCKS = N * NUM_BMM_TILES_PER_BATCH

    # im2col programs occupy the low program ids, GEMM programs the rest.
    total_blocks = NUM_IM2COL_BLOCKS + NUM_BMM_BLOCKS
    grid = (total_blocks,)

    if bias is not None:
        bias_ptr = bias.contiguous()
    else:
        # Placeholder tensor; the kernel never reads it when HAS_BIAS is False.
        bias_ptr = torch.empty(0, device=input.device, dtype=input.dtype)

    # The kernel writes the output as (N, OC, OH * OW).
    output_3d = output.view(N, OC, GEMM_M)

    with torch_device_fn.device(input.device):
        fused_im2col_bmm_kernel[grid](
            input_nhwc,
            weight_flat,
            bias_ptr,
            output_3d,
            im2col_buf,
            N,
            C,
            H,
            W,
            KH,
            KW,
            OC,
            str_h,
            str_w,
            pad_h,
            pad_w,
            dil_h,
            dil_w,
            OH,
            OW,
            GEMM_M,
            GEMM_K,
            KK,
            input_nhwc.stride(0),
            input_nhwc.stride(1),
            input_nhwc.stride(2),
            input_nhwc.stride(3),
            im2col_buf.stride(0),
            im2col_buf.stride(1),
            im2col_buf.stride(2),
            weight_flat.stride(0),
            weight_flat.stride(1),
            output_3d.stride(0),
            output_3d.stride(1),
            output_3d.stride(2),
            NUM_IM2COL_BLOCKS=NUM_IM2COL_BLOCKS,
            NUM_BMM_TILES_PER_BATCH=NUM_BMM_TILES_PER_BATCH,
            NUM_TILES_N=num_tiles_n,
            BLOCK_SIZE_C=BLOCK_SIZE_C,
            TILE_M=TILE_M,
            TILE_N=TILE_N,
            TILE_K=TILE_K,
            HAS_BIAS=(bias is not None),
            SUB_BLK_M=SUB_BLK_M,
            MICRO_M=MICRO_M,
            MICRO_K=MICRO_K,
            MICRO_N=MICRO_N,
        )

    return output