Coverage for src/flag_gems/ops/col2im.py: 52%

83 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2from typing import List 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.utils import libentry 

9 

# Module-level logger for this op; col2im() emits a debug record through it on every call.
logger = logging.getLogger(name=__name__)

11 

12 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 16}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 16}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 32}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 16}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 16}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 64}, num_stages=2, num_warps=8),
    ],
    key=["out_h", "out_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
)
@triton.jit
def col2im_kernel(
    input_ptr,
    output_ptr,
    # Input tensor info
    in_stride_n,
    in_stride_ck,
    in_stride_l,
    # Output tensor info
    out_stride_n,
    out_stride_c,
    out_stride_h,
    out_stride_w,
    # Shapes
    batch_size,
    channels,
    out_h,
    out_w,
    L_h,
    L_w,
    # Kernel parameters
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # Tiling
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    # Gather-style col2im.
    #
    # Program axis 0 selects one (batch, channel) slice; axis 1 selects a
    # (BLOCK_H, BLOCK_W) tile of output pixels within that slice. For each
    # output pixel (h, w) and kernel tap (kh, kw), the pixel receives
    #   input[n, c * kernel_h * kernel_w + kh * kernel_w + kw, l_h * L_w + l_w]
    # whenever h + padding_h - kh * dilation_h == l_h * stride_h and
    # w + padding_w - kw * dilation_w == l_w * stride_w with 0 <= l_h < L_h,
    # 0 <= l_w < L_w. Every output element is owned by exactly one program,
    # so the result is stored directly without atomics.
    pid_nc = tl.program_id(0)
    pid_hw = tl.program_id(1)

    num_w_blocks = tl.cdiv(out_w, BLOCK_W)
    h_block_idx = pid_hw // num_w_blocks
    w_block_idx = pid_hw % num_w_blocks

    # program_id is int32; promote the slice indices to int64 so that
    # n * stride_n, c_k * stride_ck and c * stride_c cannot overflow int32
    # when the input or output holds >= 2**31 elements.
    n_idx = (pid_nc // channels).to(tl.int64)
    c_idx = (pid_nc % channels).to(tl.int64)

    h_out_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_out_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    # Accumulate in fp32 regardless of the (possibly half-precision) input dtype.
    sum_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.float32)

    # Base pointer to input for this batch
    input_base_ptr = input_ptr + n_idx * in_stride_n

    # Fully unrolled loop over kernel taps (kernel_h / kernel_w are constexpr).
    for kh in tl.static_range(0, kernel_h):
        for kw in tl.static_range(0, kernel_w):
            # l_h * stride_h = h + padding_h - kh * dilation_h
            # l_w * stride_w = w + padding_w - kw * dilation_w
            h_num = h_out_offsets[:, None] + padding_h - kh * dilation_h
            w_num = w_out_offsets[None, :] + padding_w - kw * dilation_w

            # Only positions that land exactly on a sliding-block origin contribute.
            h_valid = (h_num % stride_h) == 0
            w_valid = (w_num % stride_w) == 0

            l_h = h_num // stride_h
            l_w = w_num // stride_w

            # Negative or too-large block coordinates fall outside the unfolded grid.
            l_h_valid = (l_h >= 0) & (l_h < L_h)
            l_w_valid = (l_w >= 0) & (l_w < L_w)

            valid_mask = h_valid & w_valid & l_h_valid & l_w_valid

            # Row of the (C * kernel_h * kernel_w, L) input matrix for this tap.
            c_k = c_idx * kernel_h * kernel_w + kh * kernel_w + kw
            # Column (sliding-block index) in row-major (L_h, L_w) order.
            l_idx = l_h * L_w + l_w

            input_offset = c_k * in_stride_ck + l_idx * in_stride_l

            # Masked-out lanes (including out-of-range addresses) contribute 0.
            input_val = tl.load(
                input_base_ptr + input_offset, mask=valid_mask, other=0.0
            )

            sum_acc += input_val

    # Store the tile, masking the ragged edge of the output plane.
    out_base_ptr = output_ptr + n_idx * out_stride_n + c_idx * out_stride_c
    out_offset = (
        h_out_offsets[:, None] * out_stride_h + w_out_offsets[None, :] * out_stride_w
    )

    out_mask = (h_out_offsets[:, None] < out_h) & (w_out_offsets[None, :] < out_w)
    tl.store(
        out_base_ptr + out_offset,
        sum_acc.to(output_ptr.type.element_ty),
        mask=out_mask,
    )

134 

135 

def _parse_col2im_params(output_size, kernel_size, dilation, padding, stride):
    """Normalize and validate col2im hyper-parameters.

    Each argument may be a single ``int`` (applied to both spatial dims) or a
    2-element list/tuple ``(h, w)``.

    Returns:
        A 10-tuple ``(out_h, out_w, kernel_h, kernel_w, dilation_h, dilation_w,
        padding_h, padding_w, stride_h, stride_w)``.

    Raises:
        ValueError: if an argument is neither an int nor a 2-element
            list/tuple, if kernel_size, stride or dilation is not positive,
            or if padding is negative.
    """

    def _to_pair(val, name):
        if isinstance(val, int):
            return val, val
        if isinstance(val, (list, tuple)) and len(val) == 2:
            return tuple(val)
        raise ValueError(f"Invalid {name}: {val}")

    out_h, out_w = _to_pair(output_size, "output_size")
    kernel_h, kernel_w = _to_pair(kernel_size, "kernel_size")
    dilation_h, dilation_w = _to_pair(dilation, "dilation")
    padding_h, padding_w = _to_pair(padding, "padding")
    stride_h, stride_w = _to_pair(stride, "stride")

    # A zero kernel makes kernel_h * kernel_w == 0, which the caller uses as a
    # modulus/divisor when deriving the channel count; reject it up front.
    if kernel_h <= 0 or kernel_w <= 0:
        raise ValueError(
            f"kernel_size must be positive, got ({kernel_h}, {kernel_w})"
        )
    if stride_h <= 0 or stride_w <= 0:
        raise ValueError(f"stride must be positive, got ({stride_h}, {stride_w})")
    if padding_h < 0 or padding_w < 0:
        raise ValueError(
            f"padding must be non-negative, got ({padding_h}, {padding_w})"
        )
    if dilation_h <= 0 or dilation_w <= 0:
        raise ValueError(f"dilation must be positive, got ({dilation_h}, {dilation_w})")

    return (
        out_h,
        out_w,
        kernel_h,
        kernel_w,
        dilation_h,
        dilation_w,
        padding_h,
        padding_w,
        stride_h,
        stride_w,
    )

173 

174 

def col2im(
    input: torch.Tensor,
    output_size: List[int],
    kernel_size: List[int],
    dilation: List[int],
    padding: List[int],
    stride: List[int],
) -> torch.Tensor:
    """
    Combines an array of sliding local blocks into a large containing tensor.

    This is the reverse operation of im2col (unfold): overlapping block
    contributions are summed into the output.

    Args:
        input: Input tensor of shape (N, C * kernel_h * kernel_w, L), or the
            unbatched form (C * kernel_h * kernel_w, L), where L is the number
            of sliding blocks.
        output_size: Shape of the output spatial dimensions (height, width).
        kernel_size: Size of the sliding blocks (height, width).
        dilation: Dilation of the sliding blocks (height, width).
        padding: Padding added to both sides of the input (height, width).
        stride: Stride of the sliding blocks (height, width).

    Returns:
        Output tensor of shape (N, C, output_h, output_w), or
        (C, output_h, output_w) for an unbatched 2-D input.

    Raises:
        ValueError: on invalid hyper-parameters, an input that is not 2-D or
            3-D, a non-positive sliding-block grid, a block count that does
            not match L, or a dimension 1 not divisible by the kernel area.
    """
    logger.debug("GEMS COL2IM")

    (
        out_h,
        out_w,
        kernel_h,
        kernel_w,
        dilation_h,
        dilation_w,
        padding_h,
        padding_w,
        stride_h,
        stride_w,
    ) = _parse_col2im_params(output_size, kernel_size, dilation, padding, stride)

    # Accept the unbatched (C*kh*kw, L) layout by viewing it as a batch of one;
    # the batch dimension is dropped again before returning.
    if input.dim() not in (2, 3):
        raise ValueError(f"Expected 2D or 3D input, got {input.dim()}D")
    is_batched = input.dim() == 3
    if not is_batched:
        input = input.unsqueeze(0)

    batch_size, ck, L = input.shape

    # Number of sliding-block positions along each spatial dim:
    # L_h = (out_h + 2*padding_h - dilation_h*(kernel_h-1) - 1) // stride_h + 1
    L_h = (out_h + 2 * padding_h - dilation_h * (kernel_h - 1) - 1) // stride_h + 1
    L_w = (out_w + 2 * padding_w - dilation_w * (kernel_w - 1) - 1) // stride_w + 1

    # Both counts must be >= 1; otherwise two negative counts could multiply
    # to a positive expected_L and slip past the size check below.
    if L_h < 1 or L_w < 1:
        raise ValueError(
            f"Calculated sliding-block grid (L_h={L_h}, L_w={L_w}) must have "
            f"positive components for output_size=({out_h}, {out_w}), "
            f"kernel_size=({kernel_h}, {kernel_w}), dilation=({dilation_h}, "
            f"{dilation_w}), padding=({padding_h}, {padding_w}), "
            f"stride=({stride_h}, {stride_w})"
        )
    expected_L = L_h * L_w

    if L != expected_L:
        raise ValueError(
            f"Input size mismatch: expected L={expected_L} (L_h={L_h}, L_w={L_w}), got L={L}"
        )

    kernel_size_total = kernel_h * kernel_w
    if ck % kernel_size_total != 0:
        raise ValueError(
            f"Input dimension 1 ({ck}) must be divisible by kernel_size ({kernel_size_total})"
        )
    channels = ck // kernel_size_total

    input = input.contiguous()

    # Every output element is written by exactly one kernel program, so an
    # uninitialized allocation is safe.
    output = torch.empty(
        (batch_size, channels, out_h, out_w),
        device=input.device,
        dtype=input.dtype,
    )

    if output.numel() == 0:
        return output if is_batched else output.squeeze(0)

    # Axis 0: one program per (batch, channel) slice.
    # Axis 1: one program per (BLOCK_H, BLOCK_W) output tile.
    # NOTE(review): CUDA caps grid axis 1 at 65535 blocks; a very large
    # out_h * out_w with small tiles could exceed it — confirm for the target
    # backend.
    grid = lambda meta: (
        batch_size * channels,
        triton.cdiv(out_h, meta["BLOCK_H"]) * triton.cdiv(out_w, meta["BLOCK_W"]),
    )

    col2im_kernel[grid](
        input,
        output,
        # Input strides
        input.stride(0),
        input.stride(1),
        input.stride(2),
        # Output strides
        output.stride(0),
        output.stride(1),
        output.stride(2),
        output.stride(3),
        # Shapes
        batch_size,
        channels,
        out_h,
        out_w,
        L_h,
        L_w,
        # Kernel parameters
        kernel_h,
        kernel_w,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
    )

    return output if is_batched else output.squeeze(0)