Coverage for src/flag_gems/ops/conv_transpose1d.py: 54%
82 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.utils import libentry
10logger = logging.getLogger(__name__)
def conv_transpose1d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    output_padding: int,
    dilation: int,
) -> int:
    """
    Compute the output length of a 1D transposed convolution.

    The result follows the output-length formula documented for
    ``torch.nn.ConvTranspose1d``::

        L_out = (L_in - 1) * stride - 2 * padding
                + dilation * (kernel_size - 1) + output_padding + 1

    Args:
        in_size: Length of the input signal (L_in).
        kernel_size: Number of taps in the kernel.
        stride: Upsampling stride.
        padding: Implicit padding removed from both ends of the output.
        output_padding: Extra length appended to one end of the output.
        dilation: Spacing between kernel taps.

    Returns:
        The output length L_out.
    """
    # Distance covered by placing the (in_size) inputs `stride` apart.
    upsampled_span = (in_size - 1) * stride
    # Number of output positions touched by one dilated kernel application.
    kernel_extent = dilation * (kernel_size - 1) + 1
    return upsampled_span + kernel_extent + output_padding - 2 * padding
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("conv_transpose1d"),
    # NOTE(review): dilation_width is not part of the autotune key, so a config
    # tuned for one dilation is reused for another (perf only; dilation is a
    # constexpr, so the kernel itself is still recompiled per value).
    key=[
        "batch_size",
        "in_channels",
        "input_width",
        "out_channels",
        "out_width",
        "kernel_width",
        "stride_width",
        "padding_width",
        "groups",
    ],
)
@triton.jit
def conv_transpose1d_forward_kernel(
    input_pointer,
    weight_pointer,
    output_pointer,
    bias_pointer,
    batch_size,
    input_width,
    out_channels,
    out_width,
    input_n_stride,
    input_c_stride,
    input_w_stride,
    weight_ic_stride,
    weight_oc_stride,
    weight_w_stride,
    output_n_stride,
    output_c_stride,
    output_w_stride,
    in_channels: tl.constexpr,
    kernel_width: tl.constexpr,
    stride_width: tl.constexpr,
    padding_width: tl.constexpr,
    dilation_width: tl.constexpr,
    groups: tl.constexpr,
    BLOCK_N_OW: tl.constexpr,
    BLOCK_IC: tl.constexpr,
    BLOCK_OC: tl.constexpr,
):
    """
    Triton kernel for 1D transposed convolution forward pass.

    For transposed convolution:
    - input has shape (N, in_channels, in_width)
    - weight has shape (in_channels, out_channels/groups, kernel_width)
    - output has shape (N, out_channels, out_width)

    NOTE: the ``in_channels`` argument of this kernel is the number of input
    channels PER GROUP (the wrapper passes ``in_channels // groups``), while
    ``out_channels`` is the TOTAL number of output channels.

    The output at position o is computed by summing contributions from all input
    positions i where the kernel at position k could have produced output at o:
        o = i * stride - padding + k * dilation
    =>  i = (o + padding - k * dilation) / stride   (must be integer)

    Launch grid (set by the wrapper):
    - axis 0: tiles of BLOCK_N_OW flattened (batch, out_w) positions
    - axis 1: tiles of BLOCK_OC output channels within one group
    - axis 2: group index
    BLOCK_N_OW / BLOCK_IC / BLOCK_OC come from the autotune configs.

    NOTE(review): tl.dot generally requires every block dimension to be >= 16;
    presumably the tuned configs satisfy this — confirm.
    NOTE(review): offsets are products of strides and block indices; for very
    large tensors these may overflow 32-bit arithmetic — TODO confirm.
    """
    pid_n_ow = tl.program_id(0)
    pid_oc = tl.program_id(1)
    pid_group = tl.program_id(2)

    # Split the flattened (batch, out_w) index of each row of this tile.
    n_ow_offset = pid_n_ow * BLOCK_N_OW + tl.arange(0, BLOCK_N_OW)
    batch_idx = n_ow_offset // out_width
    out_w_idx = n_ow_offset % out_width

    # Output channel offset within this group
    out_channels_per_group = out_channels // groups
    # in_channels is already in_channels_per_group (passed from wrapper)
    in_channels_per_group = in_channels
    oc_offset = pid_oc * BLOCK_OC + tl.arange(0, BLOCK_OC)

    # float32 accumulator regardless of the input dtype.
    accum = tl.zeros((BLOCK_N_OW, BLOCK_OC), dtype=tl.float32)

    # Base pointers for this group: rows select the batch sample, the group's
    # first input channel is folded into both input and weight bases.
    input_base = (
        input_pointer
        + (input_n_stride * batch_idx)[:, None]
        + (input_c_stride * pid_group * in_channels_per_group)
    )
    weight_base = (
        weight_pointer
        + (weight_ic_stride * pid_group * in_channels_per_group)
        + (weight_oc_stride * oc_offset)[None, :]
    )

    # Single flattened loop over (input-channel block, kernel tap) pairs.
    BLOCK_IC_COUNT = (in_channels_per_group + BLOCK_IC - 1) // BLOCK_IC
    for ic_k in range(BLOCK_IC_COUNT * kernel_width):
        ic_block = (ic_k // kernel_width) * BLOCK_IC
        k = ic_k % kernel_width

        ic_offset = ic_block + tl.arange(0, BLOCK_IC)

        # For transposed conv: out_w = in_w * stride - padding + k * dilation
        # So: in_w = (out_w + padding - k * dilation) / stride
        # We need in_w to be a valid integer index

        # Calculate the input position that contributes to this output
        numerator = out_w_idx + padding_width - k * dilation_width

        # Only exact multiples of the stride map to a real input sample.
        # A negative numerator never contributes: if it is an exact multiple,
        # in_w_idx is negative and rejected by the `in_w_idx >= 0` mask term;
        # otherwise the remainder is non-zero and is_divisible is False.
        is_divisible = (numerator % stride_width) == 0
        in_w_idx = numerator // stride_width

        # Load an input tile [BLOCK_N_OW, BLOCK_IC]; invalid positions read 0.
        curr_input_pointer = (
            input_base
            + (input_c_stride * ic_offset)[None, :]
            + (input_w_stride * in_w_idx)[:, None]
        )
        input_mask = (
            (batch_idx < batch_size)[:, None]
            & (ic_offset < in_channels_per_group)[None, :]
            & is_divisible[:, None]
            & (in_w_idx >= 0)[:, None]
            & (in_w_idx < input_width)[:, None]
        )
        input_block = tl.load(curr_input_pointer, mask=input_mask, other=0.0)

        # Load weight tile [BLOCK_IC, BLOCK_OC] for kernel tap k.
        # Weight shape: (in_channels, out_channels/groups, kernel_width)
        curr_weight_pointer = (
            weight_base
            + (weight_ic_stride * ic_offset)[:, None]
            + (weight_w_stride * k)
        )
        weight_mask = (ic_offset < in_channels_per_group)[:, None] & (
            oc_offset < out_channels_per_group
        )[None, :]
        weight_block = tl.load(curr_weight_pointer, mask=weight_mask, other=0.0)

        # Accumulate: input_block is [BLOCK_N_OW, BLOCK_IC], weight_block is [BLOCK_IC, BLOCK_OC]
        # Both operands are upcast to float32 and TF32 is disabled.
        accum += tl.dot(
            input_block.to(tl.float32), weight_block.to(tl.float32), allow_tf32=False
        )

    # Add bias. It is always loaded: this kernel has no "no bias" path, and
    # the wrapper passes a zero tensor when the caller gives bias=None.
    # Bias is indexed with unit stride, so it must be contiguous.
    bias_ptr = bias_pointer + pid_group * out_channels_per_group + oc_offset
    bias_mask = oc_offset < out_channels_per_group
    bias = tl.load(bias_ptr, mask=bias_mask, other=0.0).to(tl.float32)
    accum += bias[None, :]

    # Store output. accum is float32; tl.store converts it to the element
    # type of output_pointer (the wrapper allocates output with input.dtype).
    output_ptr = (
        output_pointer
        + (output_n_stride * batch_idx)[:, None]
        + (output_c_stride * (pid_group * out_channels_per_group + oc_offset))[None, :]
        + (output_w_stride * out_w_idx)[:, None]
    )
    output_mask = (
        (batch_idx < batch_size)[:, None]
        & (oc_offset < out_channels_per_group)[None, :]
        & (out_w_idx < out_width)[:, None]
    )
    tl.store(output_ptr, accum, mask=output_mask)
def _as_single(value):
    """
    Normalize a 1D convolution hyper-parameter to a scalar.

    ``stride``/``padding``/``output_padding``/``dilation`` may be given either
    as a scalar or as a list/tuple; for a sequence only the first element is
    used.
    """
    if isinstance(value, (list, tuple)):
        return value[0]
    return value


def conv_transpose1d(
    input,
    weight,
    bias=None,
    stride=1,
    padding=0,
    output_padding=0,
    groups=1,
    dilation=1,
):
    """
    Applies a 1D transposed convolution operator over an input signal.

    Args:
        input: Input tensor of shape (N, in_channels, L_in), or an unbatched
            tensor of shape (in_channels, L_in).
        weight: Filters of shape (in_channels, out_channels/groups, kernel_width)
        bias: Optional bias of shape (out_channels). Default: None
        stride: Stride of the convolution (int or 1-element sequence).
            Must be > 0. Default: 1
        padding: Zero-padding added to both sides. Must be >= 0. Default: 0
        output_padding: Additional size added to output shape. Must be >= 0
            and smaller than either stride or dilation. Default: 0
        groups: Number of blocked connections. Must be > 0 and divide
            in_channels. Default: 1
        dilation: Spacing between kernel elements. Must be > 0. Default: 1

    Returns:
        Output tensor of shape (N, out_channels, L_out) — or
        (out_channels, L_out) for unbatched input — with the dtype and device
        of ``input``.

    Raises:
        AssertionError: if tensor ranks/shapes or hyper-parameters are invalid.
    """
    logger.debug("GEMS CONV_TRANSPOSE1D")

    assert input.ndim in (
        2,
        3,
    ), f"Input must be 2D or 3D, received shape {input.shape}"
    assert weight.ndim == 3, f"Weights must be 3D, received shape {weight.shape}"
    assert (
        bias is None or bias.ndim == 1
    ), f"Bias must be 1D, received shape {bias.shape}"

    # Unbatched (C_in, L_in) input is processed as a batch of one; the batch
    # dimension is dropped again before returning.
    is_unbatched = input.ndim == 2
    if is_unbatched:
        input = input.unsqueeze(0)

    # Parse stride, padding, output_padding, dilation
    stride_width = _as_single(stride)
    padding_width = _as_single(padding)
    output_padding_width = _as_single(output_padding)
    dilation_width = _as_single(dilation)

    # The kernel computes `numerator % stride_width` and `// stride_width`, so
    # a zero stride must never reach it.
    assert stride_width > 0, f"stride must be positive, received {stride_width}"
    assert (
        dilation_width > 0
    ), f"dilation must be positive, received {dilation_width}"
    assert (
        padding_width >= 0
    ), f"padding must be non-negative, received {padding_width}"
    assert (
        output_padding_width >= 0
    ), f"output_padding must be non-negative, received {output_padding_width}"
    assert (
        output_padding_width < stride_width or output_padding_width < dilation_width
    ), (
        f"output_padding ({output_padding_width}) must be smaller than either "
        f"stride ({stride_width}) or dilation ({dilation_width})"
    )
    assert groups > 0, f"groups must be positive, received {groups}"

    batch_size, in_channels, input_width = input.shape
    in_channels_weight, out_channels_per_group, kernel_width = weight.shape

    assert (
        in_channels == in_channels_weight
    ), f"Input channels ({in_channels}) must match weight in_channels ({in_channels_weight})"
    assert (
        in_channels % groups == 0
    ), f"in_channels ({in_channels}) must be divisible by groups ({groups})"

    out_channels = out_channels_per_group * groups

    assert (
        bias is None or bias.shape[0] == out_channels
    ), f"Bias shape ({bias.shape}) doesn't match out_channels ({out_channels})"

    # Calculate output size
    out_width = conv_transpose1d_output_size(
        input_width,
        kernel_width,
        stride_width,
        padding_width,
        output_padding_width,
        dilation_width,
    )

    # Allocate output
    output_dtype = input.dtype
    output = torch.empty(
        (batch_size, out_channels, out_width),
        device=input.device,
        dtype=output_dtype,
    )

    # Nothing to compute (e.g. N == 0 or out_channels == 0): skip the launch.
    if output.numel() == 0:
        return output.squeeze(0) if is_unbatched else output

    # Grid: (batch * out_width blocks, out_channels blocks, groups)
    grid = lambda META: (
        triton.cdiv(batch_size * out_width, META["BLOCK_N_OW"]),
        triton.cdiv(out_channels_per_group, META["BLOCK_OC"]),
        groups,
    )

    # The kernel always adds a bias and indexes it with unit stride, so pass
    # zeros when there is none and force a contiguous layout otherwise.
    if bias is None:
        bias_pointer = torch.zeros(
            out_channels, device=input.device, dtype=output_dtype
        )
    else:
        bias_pointer = bias.contiguous()

    # Ensure contiguous tensors
    input_contig = input.contiguous()
    weight_contig = weight.contiguous()

    # The kernel's `in_channels` parameter is the per-group channel count.
    in_channels_per_group = in_channels // groups

    conv_transpose1d_forward_kernel[grid](
        input_contig,
        weight_contig,
        output,
        bias_pointer,
        batch_size,
        input_width,
        out_channels,
        out_width,
        *input_contig.stride(),
        *weight_contig.stride(),
        *output.stride(),
        in_channels_per_group,
        kernel_width,
        stride_width,
        padding_width,
        dilation_width,
        groups=groups,
    )

    return output.squeeze(0) if is_unbatched else output