Coverage for src/flag_gems/ops/col2im.py: 52%
83 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
2from typing import List
4import torch
5import triton
6import triton.language as tl
8from flag_gems.utils import libentry
# Module-level logger; col2im() emits a debug record on every call.
logger = logging.getLogger(__name__)
# Triton kernel that scatters-by-gathering sliding-block columns back into an
# (N, C, out_h, out_w) image: every output pixel sums all column entries that
# map onto it.
#
# Launch layout (see the grid in col2im()):
#   axis 0 -> one program per (batch, channel) pair, pid_nc = n * channels + c
#   axis 1 -> one program per BLOCK_H x BLOCK_W tile of the output plane,
#             tiles enumerated row-major over ceil(out_w / BLOCK_W) columns
#
# Autotuning picks a (BLOCK_H, BLOCK_W, num_stages, num_warps) config per
# distinct value of the `key` arguments.
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 16}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 16}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 32}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_H": 32, "BLOCK_W": 32}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 8}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 8, "BLOCK_W": 16}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_H": 64, "BLOCK_W": 16}, num_stages=2, num_warps=8),
        triton.Config({"BLOCK_H": 16, "BLOCK_W": 64}, num_stages=2, num_warps=8),
    ],
    key=["out_h", "out_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
)
@triton.jit
def col2im_kernel(
    input_ptr,
    output_ptr,
    # Input tensor info: strides of the (N, C*kernel_h*kernel_w, L) input
    in_stride_n,
    in_stride_ck,
    in_stride_l,
    # Output tensor info: strides of the (N, C, out_h, out_w) output
    out_stride_n,
    out_stride_c,
    out_stride_h,
    out_stride_w,
    # Shapes
    # NOTE: batch_size is not read in this body; the batch extent is implied
    # by grid axis 0.
    batch_size,
    channels,
    out_h,
    out_w,
    L_h,  # number of sliding-block positions along H
    L_w,  # number of sliding-block positions along W
    # Kernel parameters (compile-time constants, so the kh/kw loops unroll)
    kernel_h: tl.constexpr,
    kernel_w: tl.constexpr,
    stride_h: tl.constexpr,
    stride_w: tl.constexpr,
    padding_h: tl.constexpr,
    padding_w: tl.constexpr,
    dilation_h: tl.constexpr,
    dilation_w: tl.constexpr,
    # Tiling
    BLOCK_H: tl.constexpr,
    BLOCK_W: tl.constexpr,
):
    # Each program handles one (batch, channel) slice and a block of output positions
    pid_nc = tl.program_id(0)
    pid_hw = tl.program_id(1)

    # Decode the flat tile id into (tile row, tile column).
    num_w_blocks = tl.cdiv(out_w, BLOCK_W)
    h_block_idx = pid_hw // num_w_blocks
    w_block_idx = pid_hw % num_w_blocks

    n_idx = pid_nc // channels
    c_idx = pid_nc % channels

    # Output row / column coordinates covered by this tile (may run past
    # out_h / out_w on edge tiles; those lanes are dropped by out_mask).
    h_out_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
    w_out_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)

    # Accumulator for output values
    sum_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.float32)

    # Base pointer to input for this batch
    # NOTE(review): offsets here and below are products of program ids and
    # strides; for very large tensors they may exceed int32 range — confirm
    # against the tensor sizes this op is expected to handle.
    input_base_ptr = input_ptr + n_idx * in_stride_n

    # Iterate over kernel positions
    for kh in tl.static_range(0, kernel_h):
        for kw in tl.static_range(0, kernel_w):
            # Compute the numerators for l_h and l_w
            # l_h * stride_h = h + padding_h - kh * dilation_h
            # l_w * stride_w = w + padding_w - kw * dilation_w
            h_num = h_out_offsets[:, None] + padding_h - kh * dilation_h
            w_num = w_out_offsets[None, :] + padding_w - kw * dilation_w

            # Check divisibility by stride: only pixels that land exactly on
            # a block origin (times stride) receive this (kh, kw) tap.
            # h_num / w_num may be negative. A negative non-multiple fails
            # this test under either rounding convention for %, and a negative
            # exact multiple divides to a negative l_h / l_w that the bounds
            # check below rejects, so the result does not depend on how
            # integer // and % round negative values.
            h_valid = (h_num % stride_h) == 0
            w_valid = (w_num % stride_w) == 0

            # Compute l_h and l_w
            l_h = h_num // stride_h
            l_w = w_num // stride_w

            # Check bounds for l_h and l_w
            l_h_valid = (l_h >= 0) & (l_h < L_h)
            l_w_valid = (l_w >= 0) & (l_w < L_w)

            # Combined mask
            valid_mask = h_valid & w_valid & l_h_valid & l_w_valid

            # Compute input index
            # c_k = c * kernel_h * kernel_w + kh * kernel_w + kw
            c_k = c_idx * kernel_h * kernel_w + kh * kernel_w + kw
            # l = l_h * L_w + l_w  (blocks enumerated row-major)
            l_idx = l_h * L_w + l_w

            # Compute input offset
            input_offset = c_k * in_stride_ck + l_idx * in_stride_l

            # Load input value (use 0 for invalid positions)
            input_val = tl.load(
                input_base_ptr + input_offset, mask=valid_mask, other=0.0
            )

            # Accumulate
            sum_acc += input_val

    # Store output
    out_base_ptr = output_ptr + n_idx * out_stride_n + c_idx * out_stride_c
    out_offset = (
        h_out_offsets[:, None] * out_stride_h + w_out_offsets[None, :] * out_stride_w
    )

    # Only lanes inside the out_h x out_w plane are written.
    out_mask = (h_out_offsets[:, None] < out_h) & (w_out_offsets[None, :] < out_w)
    tl.store(
        out_base_ptr + out_offset,
        sum_acc.to(output_ptr.type.element_ty),
        mask=out_mask,
    )
136def _parse_col2im_params(output_size, kernel_size, dilation, padding, stride):
137 """Parse and validate col2im parameters."""
139 def _to_pair(val, name):
140 if isinstance(val, int):
141 return val, val
142 if isinstance(val, (list, tuple)) and len(val) == 2:
143 return tuple(val)
144 raise ValueError(f"Invalid {name}: {val}")
146 out_h, out_w = _to_pair(output_size, "output_size")
147 kernel_h, kernel_w = _to_pair(kernel_size, "kernel_size")
148 dilation_h, dilation_w = _to_pair(dilation, "dilation")
149 padding_h, padding_w = _to_pair(padding, "padding")
150 stride_h, stride_w = _to_pair(stride, "stride")
152 if stride_h <= 0 or stride_w <= 0:
153 raise ValueError(f"stride must be positive, got ({stride_h}, {stride_w})")
154 if padding_h < 0 or padding_w < 0:
155 raise ValueError(
156 f"padding must be non-negative, got ({padding_h}, {padding_w})"
157 )
158 if dilation_h <= 0 or dilation_w <= 0:
159 raise ValueError(f"dilation must be positive, got ({dilation_h}, {dilation_w})")
161 return (
162 out_h,
163 out_w,
164 kernel_h,
165 kernel_w,
166 dilation_h,
167 dilation_w,
168 padding_h,
169 padding_w,
170 stride_h,
171 stride_w,
172 )
def col2im(
    input: torch.Tensor,
    output_size: List[int],
    kernel_size: List[int],
    dilation: List[int],
    padding: List[int],
    stride: List[int],
) -> torch.Tensor:
    """
    Combines an array of sliding local blocks into a large containing tensor.

    This is the reverse operation of im2col (unfold): every output pixel is the
    sum of all block entries that overlap it.

    Args:
        input: Input tensor of shape (N, C * kernel_h * kernel_w, L), or the
            unbatched form (C * kernel_h * kernel_w, L), where L is the number
            of sliding blocks.
        output_size: Shape of the output spatial dimensions (height, width).
        kernel_size: Size of the sliding blocks (height, width).
        dilation: Dilation of the sliding blocks (height, width).
        padding: Padding added to both sides of the input (height, width).
        stride: Stride of the sliding blocks (height, width).

    Returns:
        Output tensor of shape (N, C, output_h, output_w), or
        (C, output_h, output_w) for unbatched input. Same dtype and device as
        ``input``.

    Raises:
        ValueError: on invalid parameters, wrong input rank, an L that does not
            match the sliding-block grid, a non-positive number of sliding
            blocks, or a channel dimension not divisible by
            kernel_h * kernel_w.
    """
    logger.debug("GEMS COL2IM")

    # Parse parameters
    (
        out_h,
        out_w,
        kernel_h,
        kernel_w,
        dilation_h,
        dilation_w,
        padding_h,
        padding_w,
        stride_h,
        stride_w,
    ) = _parse_col2im_params(output_size, kernel_size, dilation, padding, stride)

    # Input shape validation. Unbatched 2D input is treated as a batch of one
    # (as aten::col2im / torch.nn.functional.fold do); the synthetic batch dim
    # is removed again before returning.
    if input.dim() not in (2, 3):
        raise ValueError(f"Expected 2D or 3D input, got {input.dim()}D")
    batched = input.dim() == 3
    if not batched:
        input = input.unsqueeze(0)

    batch_size, ck, L = input.shape

    # Number of sliding-block positions per axis (torch.nn.Fold formula).
    # Python's // floors, so a kernel larger than the padded image gives
    # a value <= 0 here.
    L_h = (out_h + 2 * padding_h - dilation_h * (kernel_h - 1) - 1) // stride_h + 1
    L_w = (out_w + 2 * padding_w - dilation_w * (kernel_w - 1) - 1) // stride_w + 1
    expected_L = L_h * L_w

    if L != expected_L:
        raise ValueError(
            f"Input size mismatch: expected L={expected_L} (L_h={L_h}, L_w={L_w}), got L={L}"
        )
    # Checked separately: two negative counts can multiply to a matching L,
    # and a zero count would silently produce an all-zero output.
    if L_h <= 0 or L_w <= 0:
        raise ValueError(
            f"Calculated shape of the array of sliding blocks (L_h={L_h}, L_w={L_w}) "
            f"is non-positive for output_size=({out_h}, {out_w})"
        )

    # Calculate channels
    kernel_size_total = kernel_h * kernel_w
    if ck % kernel_size_total != 0:
        raise ValueError(
            f"Input dimension 1 ({ck}) must be divisible by kernel_size ({kernel_size_total})"
        )
    channels = ck // kernel_size_total

    # Make input contiguous
    input = input.contiguous()

    # Allocate output; the kernel writes every in-bounds element, so no
    # zero-initialisation is required.
    output = torch.empty(
        (batch_size, channels, out_h, out_w),
        device=input.device,
        dtype=input.dtype,
    )

    if output.numel() > 0:
        # Axis 0: one program per (batch, channel); axis 1: one per output tile.
        grid = lambda meta: (
            batch_size * channels,
            triton.cdiv(out_h, meta["BLOCK_H"]) * triton.cdiv(out_w, meta["BLOCK_W"]),
        )

        col2im_kernel[grid](
            input,
            output,
            # Input strides
            input.stride(0),
            input.stride(1),
            input.stride(2),
            # Output strides
            output.stride(0),
            output.stride(1),
            output.stride(2),
            output.stride(3),
            # Shapes
            batch_size,
            channels,
            out_h,
            out_w,
            L_h,
            L_w,
            # Kernel parameters
            kernel_h,
            kernel_w,
            stride_h,
            stride_w,
            padding_h,
            padding_w,
            dilation_h,
            dilation_w,
        )

    return output if batched else output.squeeze(0)