Coverage for src/flag_gems/ops/conv_transpose1d.py: 54%

82 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.utils import libentry 

9 

10logger = logging.getLogger(__name__) 

11 

12 

def conv_transpose1d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    output_padding: int,
    dilation: int,
) -> int:
    """
    Compute the output length of a 1D transposed convolution.

    Args:
        in_size: Length of the input along the convolved axis.
        kernel_size: Number of kernel taps.
        stride: Stride of the transposed convolution.
        padding: Padding amount; it is subtracted twice (once per side).
        output_padding: Extra length added to the output.
        dilation: Spacing between kernel taps.

    Returns:
        (in_size - 1) * stride - 2 * padding
        + dilation * (kernel_size - 1) + output_padding + 1
    """
    # Distance between the first and last input positions once strided.
    strided_extent = (in_size - 1) * stride
    # Distance between the first and last kernel taps once dilated.
    kernel_extent = dilation * (kernel_size - 1)
    return strided_extent - 2 * padding + kernel_extent + output_padding + 1

42 

43 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("conv_transpose1d"),
    key=[
        "batch_size",
        "in_channels",
        "input_width",
        "out_channels",
        "out_width",
        "kernel_width",
        "stride_width",
        "padding_width",
        "groups",
    ],
)
@triton.jit
def conv_transpose1d_forward_kernel(
    input_pointer,
    weight_pointer,
    output_pointer,
    bias_pointer,
    batch_size,
    input_width,
    out_channels,
    out_width,
    input_n_stride,
    input_c_stride,
    input_w_stride,
    weight_ic_stride,
    weight_oc_stride,
    weight_w_stride,
    output_n_stride,
    output_c_stride,
    output_w_stride,
    in_channels: tl.constexpr,
    kernel_width: tl.constexpr,
    stride_width: tl.constexpr,
    padding_width: tl.constexpr,
    dilation_width: tl.constexpr,
    groups: tl.constexpr,
    BLOCK_N_OW: tl.constexpr,
    BLOCK_IC: tl.constexpr,
    BLOCK_OC: tl.constexpr,
):
    """
    Forward pass of a 1D transposed convolution as a Triton kernel.

    Tensor layouts:
      - input:  (N, in_channels, in_width)
      - weight: (in_channels, out_channels/groups, kernel_width)
      - output: (N, out_channels, out_width)

    Program axes:
      - axis 0: tile of BLOCK_N_OW flattened (batch, out_width) positions
      - axis 1: tile of BLOCK_OC output channels inside one group
      - axis 2: group index

    An input position i and kernel tap k contribute to output position o when
        o = i * stride - padding + k * dilation
    so for every (o, k) the kernel gathers
        i = (o + padding - k * dilation) / stride
    and keeps it only if the division is exact and i lies in [0, input_width).

    `in_channels` is used here as the per-group input channel count.
    `bias_pointer` is always dereferenced, so it must point at `out_channels`
    valid elements even when the op has no bias.
    """
    tile_n_ow = tl.program_id(0)
    tile_oc = tl.program_id(1)
    group = tl.program_id(2)

    # Per-group channel counts.
    oc_per_group = out_channels // groups
    ic_per_group = in_channels

    # Split the flattened row index into (batch, output column).
    flat_idx = tile_n_ow * BLOCK_N_OW + tl.arange(0, BLOCK_N_OW)
    n_idx = flat_idx // out_width
    ow_idx = flat_idx % out_width

    # Output channels of this tile, relative to the start of the group.
    oc_idx = tile_oc * BLOCK_OC + tl.arange(0, BLOCK_OC)

    # These bounds checks do not depend on the reduction step.
    batch_ok = n_idx < batch_size
    oc_ok = oc_idx < oc_per_group

    # Base addresses: input rows start at this group's first channel, weight
    # columns start at this group's first input channel.
    input_rows = (
        input_pointer
        + (input_n_stride * n_idx)[:, None]
        + (input_c_stride * group * ic_per_group)
    )
    weight_cols = (
        weight_pointer
        + (weight_ic_stride * group * ic_per_group)
        + (weight_oc_stride * oc_idx)[None, :]
    )

    # float32 accumulator regardless of the tensor dtypes.
    accum = tl.zeros((BLOCK_N_OW, BLOCK_OC), dtype=tl.float32)

    # One fused loop over (input-channel block, kernel tap); the tap index
    # varies fastest.
    num_ic_blocks = (ic_per_group + BLOCK_IC - 1) // BLOCK_IC
    for step in range(num_ic_blocks * kernel_width):
        k = step % kernel_width
        ic_idx = (step // kernel_width) * BLOCK_IC + tl.arange(0, BLOCK_IC)
        ic_ok = ic_idx < ic_per_group

        # Invert o = i * stride - padding + k * dilation for i.
        shifted = ow_idx + padding_width - k * dilation_width
        on_stride = (shifted % stride_width) == 0
        iw_idx = shifted // stride_width
        # A negative `shifted` either fails the divisibility test or yields a
        # negative index, which the lower-bound check rejects.
        row_ok = batch_ok & on_stride & (iw_idx >= 0) & (iw_idx < input_width)

        # Input tile: [BLOCK_N_OW, BLOCK_IC]; masked-out lanes read as 0.
        input_tile = tl.load(
            input_rows
            + (input_c_stride * ic_idx)[None, :]
            + (input_w_stride * iw_idx)[:, None],
            mask=row_ok[:, None] & ic_ok[None, :],
            other=0.0,
        )

        # Weight tile: [BLOCK_IC, BLOCK_OC] for kernel tap k.
        weight_tile = tl.load(
            weight_cols
            + (weight_ic_stride * ic_idx)[:, None]
            + (weight_w_stride * k),
            mask=ic_ok[:, None] & oc_ok[None, :],
            other=0.0,
        )

        accum += tl.dot(
            input_tile.to(tl.float32), weight_tile.to(tl.float32), allow_tf32=False
        )

    # Bias is indexed with unit stride by absolute output channel.
    bias_vals = tl.load(
        bias_pointer + group * oc_per_group + oc_idx, mask=oc_ok, other=0.0
    ).to(tl.float32)
    accum += bias_vals[None, :]

    # Write the tile back, dropping rows/columns that fall outside the output.
    out_ptrs = (
        output_pointer
        + (output_n_stride * n_idx)[:, None]
        + (output_c_stride * (group * oc_per_group + oc_idx))[None, :]
        + (output_w_stride * ow_idx)[:, None]
    )
    out_ok = batch_ok[:, None] & oc_ok[None, :] & (ow_idx < out_width)[:, None]
    tl.store(out_ptrs, accum, mask=out_ok)

201 

202 

def conv_transpose1d(
    input,
    weight,
    bias=None,
    stride=1,
    padding=0,
    output_padding=0,
    groups=1,
    dilation=1,
):
    """
    Applies a 1D transposed convolution operator over an input signal.

    Args:
        input: Input tensor of shape (N, in_channels, L_in)
        weight: Filters of shape (in_channels, out_channels/groups, kernel_width)
        bias: Optional bias of shape (out_channels). Default: None
        stride: Stride of the convolution. Default: 1
        padding: Zero-padding added to both sides. Default: 0
        output_padding: Additional size added to output shape. Default: 0
        groups: Number of blocked connections. Default: 1
        dilation: Spacing between kernel elements. Default: 1

    `stride`, `padding`, `output_padding` and `dilation` may each be an int
    or a list/tuple, in which case only the first element is used.

    Returns:
        Output tensor of shape (N, out_channels, L_out), with the dtype and
        device of `input`.

    Raises:
        AssertionError: If the tensor ranks are wrong, the input and weight
            channel counts disagree, `in_channels` is not divisible by
            `groups`, or the bias length does not equal `out_channels`.
    """
    logger.debug("GEMS CONV_TRANSPOSE1D")

    assert input.ndim == 3, f"Input must be 3D, received shape {input.shape}"
    assert weight.ndim == 3, f"Weights must be 3D, received shape {weight.shape}"
    assert (
        bias is None or bias.ndim == 1
    ), f"Bias must be 1D, received shape {bias.shape}"

    def _first(value):
        # Hyper-parameters may arrive as a scalar or as a list/tuple; a 1D op
        # only needs the first entry.
        return value[0] if isinstance(value, (list, tuple)) else value

    stride_width = _first(stride)
    padding_width = _first(padding)
    output_padding_width = _first(output_padding)
    dilation_width = _first(dilation)

    batch_size, in_channels, input_width = input.shape
    in_channels_weight, out_channels_per_group, kernel_width = weight.shape

    assert (
        in_channels == in_channels_weight
    ), f"Input channels ({in_channels}) must match weight in_channels ({in_channels_weight})"
    assert (
        in_channels % groups == 0
    ), f"in_channels ({in_channels}) must be divisible by groups ({groups})"

    out_channels = out_channels_per_group * groups

    assert (
        bias is None or bias.shape[0] == out_channels
    ), f"Bias shape ({bias.shape}) doesn't match out_channels ({out_channels})"

    # Calculate output size
    out_width = conv_transpose1d_output_size(
        input_width,
        kernel_width,
        stride_width,
        padding_width,
        output_padding_width,
        dilation_width,
    )

    # Allocate output
    output_dtype = input.dtype
    output = torch.empty(
        (batch_size, out_channels, out_width),
        device=input.device,
        dtype=output_dtype,
    )

    # Nothing to compute for an empty output; skip the kernel launch.
    if output.numel() == 0:
        return output

    # The kernel addresses input/weight through the strides passed below and
    # addresses bias with an implicit unit stride, so all three are made
    # contiguous. (Previously a strided bias view was passed through as-is and
    # would have been read at the wrong offsets.)
    input_contig = input.contiguous()
    weight_contig = weight.contiguous()
    if bias is None:
        # The kernel always loads bias, so substitute zeros when absent.
        bias_contig = torch.zeros(
            out_channels, device=input.device, dtype=output_dtype
        )
    else:
        bias_contig = bias.contiguous()

    in_channels_per_group = in_channels // groups

    # Grid: (batch * out_width blocks, out_channels blocks, groups)
    def grid(meta):
        return (
            triton.cdiv(batch_size * out_width, meta["BLOCK_N_OW"]),
            triton.cdiv(out_channels_per_group, meta["BLOCK_OC"]),
            groups,
        )

    conv_transpose1d_forward_kernel[grid](
        input_contig,
        weight_contig,
        output,
        bias_contig,
        batch_size,
        input_width,
        out_channels,
        out_width,
        *input_contig.stride(),
        *weight_contig.stride(),
        *output.stride(),
        # The kernel's `in_channels` argument is the per-group count.
        in_channels_per_group,
        kernel_width,
        stride_width,
        padding_width,
        dilation_width,
        groups=groups,
    )

    return output