Coverage for src/flag_gems/runtime/backend/_spacemit/ops/conv2d.py: 0%

108 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6import triton.language.extra.smt as smt 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10 

logger = logging.getLogger(__name__)


@libentry()
@triton.jit
def fused_im2col_bmm_kernel(
    input_ptr,
    weight_ptr,
    bias_ptr,
    output_ptr,
    im2col_buf_ptr,
    N,
    C,
    IH,
    IW,
    KH,
    KW,
    OC,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    dilation_h,
    dilation_w,
    OH,
    OW,
    GEMM_M,
    GEMM_K,
    KK,
    input_stride_n,
    input_stride_h,
    input_stride_w,
    input_stride_c,
    im2col_stride_n,
    im2col_stride_m,
    im2col_stride_k,
    weight_stride_oc,
    weight_stride_k,
    output_stride_n,
    output_stride_oc,
    output_stride_m,
    NUM_IM2COL_BLOCKS: tl.constexpr,
    NUM_BMM_TILES_PER_BATCH: tl.constexpr,
    NUM_TILES_N: tl.constexpr,
    BLOCK_SIZE_C: tl.constexpr,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    SUB_BLK_M: tl.constexpr,
    MICRO_M: tl.constexpr,
    MICRO_K: tl.constexpr,
    MICRO_N: tl.constexpr,
):
    """Fused im2col + batched GEMM for a 2-D convolution over NHWC input.

    The 1-D launch grid is split into two roles by ``NUM_IM2COL_BLOCKS``:

    * ``pid < NUM_IM2COL_BLOCKS``: gather the receptive field of one output
      pixel ``(n, oh, ow)`` from ``input_ptr`` (indexed as ``(N, IH, IW, C)``)
      into one row of ``im2col_buf_ptr`` (indexed as ``(N, GEMM_M, GEMM_K)``),
      writing zeros for taps that fall into the padding, then arrive on a
      global mbarrier.
    * otherwise: wait on that barrier, then compute one ``(TILE_N, TILE_M)``
      tile of ``weight @ im2col[b]^T`` for batch ``b``, optionally add
      ``bias``, and store it into ``output_ptr`` (indexed as
      ``(N, OC, GEMM_M)``).

    Column ``k`` of an im2col row holds channel ``c`` at kernel tap
    ``(kh, kw)`` with ``k = c * KK + kh * KW + kw``. This is the column order
    of a weight tensor ``(OC, C, KH, KW)`` flattened to ``(OC, GEMM_K)``.
    """
    pid = tl.program_id(0)

    # --- im2col role: output pixel handled by this program. ---
    n_im2col = pid // (OH * OW)
    ohow = pid % (OH * OW)
    oh = ohow // OW
    ow = ohow % OW
    # Top-left input coordinate of the (possibly padded) receptive field.
    window_h = oh * stride_h - pad_h
    window_w = ow * stride_w - pad_w

    # --- GEMM role: (batch, m-tile, n-tile) handled by this program. ---
    # Clamped at 0 so the values stay in range for im2col programs, which
    # compute them but never use them.
    bmm_pid = tl.maximum(pid - NUM_IM2COL_BLOCKS, 0)
    pid_b = bmm_pid // NUM_BMM_TILES_PER_BATCH
    local_tile = bmm_pid % NUM_BMM_TILES_PER_BATCH
    pid_m = local_tile // NUM_TILES_N
    pid_n = local_tile % NUM_TILES_N
    block_m = pid_m * TILE_M
    block_n = pid_n * TILE_N

    bar = smt.global_mbarrier(0)
    is_im2col = pid < NUM_IM2COL_BLOCKS

    if is_im2col:
        input_block_ptr = tl.make_block_ptr(
            base=input_ptr,
            shape=(N, IH, IW, C),
            strides=(input_stride_n, input_stride_h, input_stride_w, input_stride_c),
            offsets=(n_im2col, 0, 0, 0),
            block_shape=(1, 1, 1, BLOCK_SIZE_C),
            order=(3, 2, 1, 0),
        )
        # View each im2col row (GEMM_K = C * KK columns) as a (C, KK) matrix so
        # that a block of BLOCK_SIZE_C channels for one tap is written with a
        # channel stride of KK. Consecutive channels of the same tap are KK
        # columns apart, not adjacent. The boundary check on the C dimension
        # also drops the channel-tail lanes of the last block.
        col_block_ptr = tl.make_block_ptr(
            base=im2col_buf_ptr,
            shape=(N, GEMM_M, C, KK),
            strides=(
                im2col_stride_n,
                im2col_stride_m,
                im2col_stride_k * KK,
                im2col_stride_k,
            ),
            offsets=(n_im2col, ohow, 0, 0),
            block_shape=(1, 1, BLOCK_SIZE_C, 1),
            order=(3, 2, 1, 0),
        )

        for kh in range(KH):
            for kw in range(KW):
                h = window_h + kh * dilation_h
                w = window_w + kw * dilation_w
                valid = (h >= 0) & (h < IH) & (w >= 0) & (w < IW)
                tap = kh * KW + kw
                for c_start in range(0, C, BLOCK_SIZE_C):
                    if valid:
                        input_ptr_cur = tl.advance(
                            input_block_ptr, (0, h, w, c_start)
                        )
                        vals = tl.load(input_ptr_cur, boundary_check=(0, 1, 2, 3))
                        vals = tl.reshape(vals, (1, 1, BLOCK_SIZE_C, 1))
                    else:
                        # Tap lies in the zero padding around the image.
                        vals = tl.zeros(
                            (1, 1, BLOCK_SIZE_C, 1), dtype=input_ptr.dtype.element_ty
                        )
                    col_ptr_cur = tl.advance(col_block_ptr, (0, 0, c_start, tap))
                    tl.store(col_ptr_cur, vals, boundary_check=(0, 1, 2, 3))
        # Signal that this im2col row is complete.
        smt.barrier_arrive(bar)

    else:
        # The first GEMM program arms the barrier with the number of im2col
        # arrivals to wait for.
        if pid == NUM_IM2COL_BLOCKS:
            smt.barrier_set_expect(bar, NUM_IM2COL_BLOCKS)

        smt.barrier_wait(bar)
        a_ptr = tl.make_block_ptr(
            base=im2col_buf_ptr,
            shape=(N, GEMM_M, GEMM_K),
            strides=(im2col_stride_n, im2col_stride_m, im2col_stride_k),
            offsets=(pid_b, block_m, 0),
            block_shape=(1, TILE_M, TILE_K),
            order=(2, 1, 0),
        )

        b_ptr = tl.make_block_ptr(
            base=weight_ptr,
            shape=(OC, GEMM_K),
            strides=(weight_stride_oc, weight_stride_k),
            offsets=(block_n, 0),
            block_shape=(TILE_N, TILE_K),
            order=(1, 0),
        )

        if HAS_BIAS:
            bias_block_ptr = tl.make_block_ptr(
                base=bias_ptr,
                shape=(OC,),
                strides=(1,),
                offsets=(block_n,),
                block_shape=(TILE_N,),
                order=(0,),
            )
            bias_vals = tl.load(bias_block_ptr, boundary_check=(0,))
        # Select this batch's output slice regardless of HAS_BIAS.
        output_ptr = output_ptr + pid_b * output_stride_n

        # TILE_K may exceed GEMM_K (it is rounded up by the caller). Load the
        # out-of-range K columns as zeros so they contribute nothing to the
        # dot product instead of undefined values.
        a_tile = tl.load(a_ptr, boundary_check=(0, 1, 2), padding_option="zero")
        a_tile = tl.trans(tl.reshape(a_tile, (TILE_M, TILE_K)))
        b_descriptor_load = smt.descriptor_load(b_ptr, (0, 0))
        b = smt.view(b_descriptor_load, (0, 0), (TILE_N, TILE_K), (MICRO_N, MICRO_K))
        # Number of SUB_BLK_M-wide column slices that overlap valid rows of
        # this M tile (the last tile may be partial).
        sub_num = (min(TILE_M, GEMM_M - TILE_M * pid_m) + SUB_BLK_M - 1) // SUB_BLK_M
        for s in smt.parallel(0, sub_num):
            a = smt.view(
                a_tile, (0, s * SUB_BLK_M), (TILE_K, SUB_BLK_M), (MICRO_K, MICRO_M)
            )
            acc = smt.dot(b, a)
            acc = smt.view(acc, (0, 0), (TILE_N, SUB_BLK_M), (1, 1))
            if HAS_BIAS:
                acc += bias_vals[:, None]
            acc = acc.to(output_ptr.dtype.element_ty)
            o_ptr = tl.make_block_ptr(
                base=output_ptr,
                shape=(OC, GEMM_M),
                strides=(output_stride_oc, output_stride_m),
                offsets=(block_n, block_m + s * SUB_BLK_M),
                block_shape=(TILE_N, SUB_BLK_M),
                order=(1, 0),
            )
            tl.store(o_ptr, acc, boundary_check=(0, 1))

179 

180 

def _pair(value, name):
    """Normalize an int or 1/2-element sequence conv argument to ``(h, w)``.

    Raises:
        ValueError: if ``value`` is a sequence whose length is not 1 or 2.
    """
    if isinstance(value, int):
        return value, value
    items = tuple(value)
    if len(items) == 1:
        return items[0], items[0]
    if len(items) == 2:
        return items[0], items[1]
    raise ValueError(
        f"conv2d: expected {name} to be an int or a pair of ints, got {value!r}"
    )


def conv2d(input, weight, bias=None, padding=0, stride=1, dilation=1, groups=1):
    """2-D convolution computed with the fused im2col + GEMM SpacemiT kernel.

    Follows the argument conventions of ``torch.nn.functional.conv2d``.

    Args:
        input: ``(N, C, H, W)`` tensor, or an unbatched ``(C, H, W)`` tensor.
        weight: ``(OC, C // groups, KH, KW)`` filter tensor.
        bias: optional ``(OC,)`` tensor added to every output position.
        padding: int, ``(pad_h, pad_w)``, ``"valid"``, or ``"same"``. Only the
            symmetric case of ``"same"`` is supported.
        stride: int or ``(stride_h, stride_w)``.
        dilation: int or ``(dilation_h, dilation_w)``.
        groups: number of channel groups. ``groups > 1`` is computed as one
            dense convolution per group.

    Returns:
        ``(N, OC, OH, OW)`` tensor (``(OC, OH, OW)`` for unbatched input) with
        the dtype and device of ``input``.

    Raises:
        ValueError: on inconsistent shapes/arguments or an empty output size.
        NotImplementedError: if ``padding="same"`` needs asymmetric padding.
    """
    logger.debug("GEMS_SPACEMIT CONV2D")

    if input.dim() == 3:
        # Unbatched input: run as a batch of one, then drop the batch dim.
        return conv2d(
            input.unsqueeze(0), weight, bias, padding, stride, dilation, groups
        ).squeeze(0)
    if input.dim() != 4:
        raise ValueError(f"conv2d: expected 3-D or 4-D input, got {input.dim()}-D")
    if weight.dim() != 4:
        raise ValueError(f"conv2d: expected 4-D weight, got {weight.dim()}-D")

    N, C, H, W = input.shape
    OC, C_per_group, KH, KW = weight.shape

    if groups < 1 or C % groups != 0 or OC % groups != 0:
        raise ValueError(
            f"conv2d: groups={groups} must be positive and divide both "
            f"in_channels={C} and out_channels={OC}"
        )
    if C_per_group * groups != C:
        raise ValueError(
            f"conv2d: weight expects {C_per_group * groups} input channels "
            f"(groups={groups}), but input has {C}"
        )

    if groups > 1:
        # The kernel computes a single dense group; run each group on its own
        # channel slice and concatenate the results along the channel dim.
        bias_chunks = bias.chunk(groups) if bias is not None else (None,) * groups
        outputs = [
            conv2d(x, w, b, padding, stride, dilation, 1)
            for x, w, b in zip(
                input.chunk(groups, dim=1), weight.chunk(groups, dim=0), bias_chunks
            )
        ]
        return torch.cat(outputs, dim=1)

    str_h, str_w = _pair(stride, "stride")
    dil_h, dil_w = _pair(dilation, "dilation")
    if isinstance(padding, str):
        if padding == "valid":
            pad_h, pad_w = 0, 0
        elif padding == "same":
            if (str_h, str_w) != (1, 1):
                raise ValueError("conv2d: padding='same' requires stride 1")
            total_h = dil_h * (KH - 1)
            total_w = dil_w * (KW - 1)
            if total_h % 2 or total_w % 2:
                # The kernel only supports equal padding on both sides.
                raise NotImplementedError(
                    "conv2d: padding='same' with asymmetric padding is not supported"
                )
            pad_h, pad_w = total_h // 2, total_w // 2
        else:
            raise ValueError(f"conv2d: unsupported padding mode {padding!r}")
    else:
        pad_h, pad_w = _pair(padding, "padding")

    OH = (H + 2 * pad_h - dil_h * (KH - 1) - 1) // str_h + 1
    OW = (W + 2 * pad_w - dil_w * (KW - 1) - 1) // str_w + 1
    if OH <= 0 or OW <= 0:
        raise ValueError(
            f"conv2d: computed output size ({OH}, {OW}) is too small for "
            f"input ({H}, {W}) and kernel ({KH}, {KW})"
        )

    output = torch.empty((N, OC, OH, OW), dtype=input.dtype, device=input.device)
    if output.numel() == 0:
        # Nothing to compute (N == 0 or OC == 0); avoid an empty launch grid.
        return output

    GEMM_M = OH * OW
    KK = KH * KW
    GEMM_K = C * KK

    im2col_buf = torch.empty(
        (N, GEMM_M, GEMM_K), dtype=input.dtype, device=input.device
    )

    input_nhwc = input.permute(0, 2, 3, 1).contiguous()
    # reshape (not view) so non-contiguous weights such as channels_last work;
    # the resulting column order is (C, KH, KW).
    weight_flat = weight.reshape(OC, -1).contiguous()

    # One im2col program per output pixel.
    NUM_IM2COL_BLOCKS = N * OH * OW

    TILE_M = 128
    TILE_N = 128
    TILE_K = triton.next_power_of_2(GEMM_K)
    BLOCK_SIZE_C = 32
    SUB_BLK_M = 32
    MICRO_M = 8
    MICRO_K = 8
    MICRO_N = 16

    num_tiles_m = triton.cdiv(GEMM_M, TILE_M)
    num_tiles_n = triton.cdiv(OC, TILE_N)
    NUM_BMM_TILES_PER_BATCH = num_tiles_m * num_tiles_n
    NUM_BMM_BLOCKS = N * NUM_BMM_TILES_PER_BATCH

    grid = (NUM_IM2COL_BLOCKS + NUM_BMM_BLOCKS,)

    if bias is not None:
        bias_ptr = bias.contiguous()
    else:
        # Placeholder argument; never read because HAS_BIAS is False.
        bias_ptr = torch.empty(0, device=input.device, dtype=input.dtype)

    output_3d = output.view(N, OC, GEMM_M)

    with torch_device_fn.device(input.device):
        fused_im2col_bmm_kernel[grid](
            input_nhwc,
            weight_flat,
            bias_ptr,
            output_3d,
            im2col_buf,
            N,
            C,
            H,
            W,
            KH,
            KW,
            OC,
            str_h,
            str_w,
            pad_h,
            pad_w,
            dil_h,
            dil_w,
            OH,
            OW,
            GEMM_M,
            GEMM_K,
            KK,
            input_nhwc.stride(0),
            input_nhwc.stride(1),
            input_nhwc.stride(2),
            input_nhwc.stride(3),
            im2col_buf.stride(0),
            im2col_buf.stride(1),
            im2col_buf.stride(2),
            weight_flat.stride(0),
            weight_flat.stride(1),
            output_3d.stride(0),
            output_3d.stride(1),
            output_3d.stride(2),
            NUM_IM2COL_BLOCKS=NUM_IM2COL_BLOCKS,
            NUM_BMM_TILES_PER_BATCH=NUM_BMM_TILES_PER_BATCH,
            NUM_TILES_N=num_tiles_n,
            BLOCK_SIZE_C=BLOCK_SIZE_C,
            TILE_M=TILE_M,
            TILE_N=TILE_N,
            TILE_K=TILE_K,
            HAS_BIAS=(bias is not None),
            SUB_BLK_M=SUB_BLK_M,
            MICRO_M=MICRO_M,
            MICRO_K=MICRO_K,
            MICRO_N=MICRO_N,
        )

    return output