Coverage for src/flag_gems/ops/conv_transpose2d.py: 45%

750 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1"""Triton implementation of ``torch.nn.functional.conv_transpose2d``. 

2 

3The implementation uses semantic, parameter-regime dispatch only: a direct 

4tiled path for common dense group=1 cases, a pointwise 1x1 path, a scatter path 

5for no-overlap sparse-output cases, and a full residue path for the supported 

6PyTorch API surface. There are no shape-specific dispatch constants. 

7""" 

8 

9import logging 

10 

11import torch 

12import triton 

13import triton.language as tl 

14 

15from flag_gems.utils import libentry 

16 

logger = logging.getLogger(__name__)

# Reduced-precision dtypes; consulted by _can_use_direct_tiled_family when
# deciding whether a stride >= 3 problem may take the direct tiled path.
_TRITON_DIRECT_LOWP_DTYPES = (torch.float16, torch.bfloat16)

# Every dtype the eligibility/validation helpers in this module accept.
_GENERAL_TRITON_DTYPES = (torch.float32, torch.float16, torch.bfloat16)

# Upper bounds checked by _direct_tiled_family_params (channels, kernel extent,
# stride); problems beyond them are not part of the direct tiled family.
_DIRECT_TILED_FAMILY_MAX_CHANNELS = 256
_DIRECT_TILED_FAMILY_MAX_KERNEL = 5
_DIRECT_TILED_FAMILY_MAX_STRIDE = 4
# Minimum batch * input_height * input_width for the stride-2 /
# output_padding-1 / 3x3 case in _can_use_direct_tiled_family.
_DIRECT_TILED_OUTPUT_PADDING_MIN_INPUT_ELEMENTS = 1024
# NOTE(review): presumably (BLOCK_NHW, BLOCK_CI, BLOCK_CO, num_warps) for the
# direct kernel launch -- the consumer is not visible here; confirm.
_DIRECT_TILED_DEFAULT_SCHEDULE = (64, 32, 32, 4)
# Channel cap checked by _can_use_stride2_pad1_3x3_direct.
_DIRECT_STRIDE2_PAD1_3X3_MAX_CHANNELS = 256

29 

30 

31def _pair(value): 

32 if isinstance(value, (list, tuple)): 

33 if len(value) != 2: 

34 raise RuntimeError("expected a single int or a pair of ints") 

35 return int(value[0]), int(value[1]) 

36 return int(value), int(value) 

37 

38 

39def _direct_tiled_family_params( 

40 input, 

41 weight, 

42 bias, 

43 stride_h, 

44 stride_w, 

45 padding_h, 

46 padding_w, 

47 output_padding_h, 

48 output_padding_w, 

49 groups, 

50 dilation_h, 

51 dilation_w, 

52): 

53 if bias is not None or groups != 1: 

54 return None 

55 if (dilation_h, dilation_w) != (1, 1): 

56 return None 

57 if input.dtype not in _GENERAL_TRITON_DTYPES or weight.dtype != input.dtype: 

58 return None 

59 if input.device.type != "cuda" or weight.device != input.device: 

60 return None 

61 if input.dtype is torch.bfloat16 and not torch.cuda.is_bf16_supported(): 

62 return None 

63 if input.dim() != 4 or weight.dim() != 4: 

64 return None 

65 if not input.is_contiguous() or not weight.is_contiguous(): 

66 return None 

67 if stride_h != stride_w or padding_h != padding_w: 

68 return None 

69 if output_padding_h != output_padding_w: 

70 return None 

71 if stride_h <= 0 or stride_h > _DIRECT_TILED_FAMILY_MAX_STRIDE: 

72 return None 

73 if padding_h < 0 or output_padding_h < 0: 

74 return None 

75 

76 batch, input_channels, input_height, input_width = input.shape 

77 weight_input_channels, output_channels, weight_height, weight_width = weight.shape 

78 if batch <= 0 or input_height <= 0 or input_width <= 0: 

79 return None 

80 if input_channels != weight_input_channels: 

81 return None 

82 if input_channels < 16 or output_channels < 16: 

83 return None 

84 if ( 

85 input_channels > _DIRECT_TILED_FAMILY_MAX_CHANNELS 

86 or output_channels > _DIRECT_TILED_FAMILY_MAX_CHANNELS 

87 ): 

88 return None 

89 if ( 

90 weight_height <= 0 

91 or weight_width <= 0 

92 or weight_height > _DIRECT_TILED_FAMILY_MAX_KERNEL 

93 or weight_width > _DIRECT_TILED_FAMILY_MAX_KERNEL 

94 ): 

95 return None 

96 output_height = ( 

97 (input_height - 1) * stride_h - 2 * padding_h + weight_height + output_padding_h 

98 ) 

99 output_width = ( 

100 (input_width - 1) * stride_w - 2 * padding_w + weight_width + output_padding_w 

101 ) 

102 if output_height <= 0 or output_width <= 0: 

103 return None 

104 return ( 

105 batch, 

106 input_channels, 

107 input_height, 

108 input_width, 

109 output_channels, 

110 weight_height, 

111 weight_width, 

112 stride_h, 

113 padding_h, 

114 ) 

115 

116 

117def _can_use_direct_tiled_family( 

118 input, 

119 direct_tiled_family_params, 

120 output_padding_h, 

121): 

122 if direct_tiled_family_params is None: 

123 return False 

124 ( 

125 batch, 

126 input_channels, 

127 input_height, 

128 input_width, 

129 output_channels, 

130 weight_height, 

131 weight_width, 

132 stride_h, 

133 _padding_h, 

134 ) = direct_tiled_family_params 

135 

136 if output_padding_h == 0 and stride_h <= 2: 

137 return True 

138 input_elements = batch * input_height * input_width 

139 if ( 

140 input.dtype in _GENERAL_TRITON_DTYPES 

141 and stride_h == 2 

142 and output_padding_h == 1 

143 and weight_height == 3 

144 and weight_width == 3 

145 and input_channels >= 64 

146 and output_channels <= 64 

147 and input_elements >= _DIRECT_TILED_OUTPUT_PADDING_MIN_INPUT_ELEMENTS 

148 ): 

149 return True 

150 if stride_h >= 3 and output_padding_h == 0: 

151 if weight_height >= 5 or weight_width >= 5: 

152 return True 

153 if input.dtype in _TRITON_DIRECT_LOWP_DTYPES: 

154 return True 

155 return False 

156 

157 

158def _unsupported_conv_transpose2d( 

159 input, 

160 weight, 

161 bias, 

162 stride_h, 

163 stride_w, 

164 padding_h, 

165 padding_w, 

166 output_padding_h, 

167 output_padding_w, 

168 groups, 

169 dilation_h, 

170 dilation_w, 

171): 

172 bias_dtype = None if bias is None else bias.dtype 

173 raise NotImplementedError( 

174 "flag_gems.conv_transpose2d supports 3D or 4D CUDA input tensors " 

175 "and 4D CUDA weight tensors with float32, float16, or bfloat16 dtype; got " 

176 f"input_shape={tuple(input.shape)}, weight_shape={tuple(weight.shape)}, " 

177 f"input_dtype={input.dtype}, weight_dtype={weight.dtype}, bias_dtype={bias_dtype}, " 

178 f"input_device={input.device}, weight_device={weight.device}, " 

179 f"stride=({stride_h}, {stride_w}), padding=({padding_h}, {padding_w}), " 

180 f"output_padding=({output_padding_h}, {output_padding_w}), groups={groups}, " 

181 f"dilation=({dilation_h}, {dilation_w})" 

182 ) 

183 

184 

185def _validate_conv_transpose2d_args( 

186 input, 

187 weight, 

188 bias, 

189 stride_h, 

190 stride_w, 

191 padding_h, 

192 padding_w, 

193 output_padding_h, 

194 output_padding_w, 

195 groups, 

196 dilation_h, 

197 dilation_w, 

198): 

199 if input.device.type != "cuda" or weight.device != input.device: 

200 return False 

201 if input.dim() != 4 or weight.dim() != 4: 

202 return False 

203 if not input.is_contiguous() or not weight.is_contiguous(): 

204 return False 

205 if input.dtype not in _GENERAL_TRITON_DTYPES or weight.dtype != input.dtype: 

206 return False 

207 if input.dtype is torch.bfloat16 and not torch.cuda.is_bf16_supported(): 

208 return False 

209 if bias is not None: 

210 if bias.device != input.device or bias.dtype != input.dtype: 

211 return False 

212 if bias.dim() != 1 or not bias.is_contiguous(): 

213 return False 

214 if groups <= 0: 

215 raise RuntimeError("groups must be a positive integer") 

216 if stride_h <= 0 or stride_w <= 0: 

217 raise RuntimeError("non-positive stride is not supported") 

218 if dilation_h <= 0 or dilation_w <= 0: 

219 raise RuntimeError("dilation should be greater than zero") 

220 if padding_h < 0 or padding_w < 0: 

221 raise RuntimeError("negative padding is not supported") 

222 if output_padding_h < 0 or output_padding_w < 0: 

223 raise RuntimeError("negative output_padding is not supported") 

224 if output_padding_h >= stride_h and output_padding_h >= dilation_h: 

225 raise RuntimeError( 

226 "output padding must be smaller than either stride or dilation" 

227 ) 

228 if output_padding_w >= stride_w and output_padding_w >= dilation_w: 

229 raise RuntimeError( 

230 "output padding must be smaller than either stride or dilation" 

231 ) 

232 

233 input_channels = input.shape[1] 

234 weight_input_channels = weight.shape[0] 

235 output_channels_per_group = weight.shape[1] 

236 weight_height = weight.shape[2] 

237 weight_width = weight.shape[3] 

238 if ( 

239 input_channels <= 0 

240 or output_channels_per_group <= 0 

241 or weight_height <= 0 

242 or weight_width <= 0 

243 ): 

244 raise RuntimeError( 

245 "non-empty input channels and weight dimensions are required" 

246 ) 

247 if input_channels != weight_input_channels: 

248 raise RuntimeError( 

249 "expected input channel dimension to match weight input channels" 

250 ) 

251 if input_channels % groups != 0: 

252 raise RuntimeError("input channels must be divisible by groups") 

253 output_channels = output_channels_per_group * groups 

254 if bias is not None and bias.numel() != output_channels: 

255 raise RuntimeError("expected bias to have one element per output channel") 

256 

257 input_height = input.shape[2] 

258 input_width = input.shape[3] 

259 output_height = ( 

260 (input_height - 1) * stride_h 

261 - 2 * padding_h 

262 + dilation_h * (weight_height - 1) 

263 + output_padding_h 

264 + 1 

265 ) 

266 output_width = ( 

267 (input_width - 1) * stride_w 

268 - 2 * padding_w 

269 + dilation_w * (weight_width - 1) 

270 + output_padding_w 

271 + 1 

272 ) 

273 if output_height <= 0 or output_width <= 0: 

274 raise RuntimeError("calculated output size is too small") 

275 return True 

276 

277 

278def _can_use_scatter_no_overlap( 

279 input, 

280 weight, 

281 stride_h, 

282 stride_w, 

283 dilation_h, 

284 dilation_w, 

285 groups, 

286): 

287 batch, input_channels, input_height, input_width = input.shape 

288 _, output_channels_per_group, weight_height, weight_width = weight.shape 

289 if batch <= 0 or input_height <= 0 or input_width <= 0: 

290 return False 

291 effective_kernel_h = (weight_height - 1) * dilation_h + 1 

292 effective_kernel_w = (weight_width - 1) * dilation_w + 1 

293 if stride_h < effective_kernel_h or stride_w < effective_kernel_w: 

294 return False 

295 

296 input_channels_per_group = input_channels // groups 

297 if input_channels_per_group > 128 or output_channels_per_group > 128: 

298 return False 

299 return weight_height * weight_width <= 25 

300 

301 

302def _can_use_stride2_pad1_3x3_direct( 

303 input, 

304 weight, 

305 bias, 

306 stride_h, 

307 stride_w, 

308 padding_h, 

309 padding_w, 

310 output_padding_h, 

311 output_padding_w, 

312 groups, 

313 dilation_h, 

314 dilation_w, 

315): 

316 if bias is not None or groups != 1: 

317 return False 

318 if (dilation_h, dilation_w) != (1, 1): 

319 return False 

320 if (output_padding_h, output_padding_w) != (0, 0): 

321 return False 

322 if (stride_h, stride_w) != (2, 2) or (padding_h, padding_w) != (1, 1): 

323 return False 

324 if input.dim() != 4 or weight.dim() != 4: 

325 return False 

326 if input.device.type != "cuda" or weight.device != input.device: 

327 return False 

328 if input.dtype not in _GENERAL_TRITON_DTYPES or weight.dtype != input.dtype: 

329 return False 

330 if input.dtype is torch.bfloat16 and not torch.cuda.is_bf16_supported(): 

331 return False 

332 if not input.is_contiguous() or not weight.is_contiguous(): 

333 return False 

334 

335 batch, input_channels, input_height, input_width = input.shape 

336 weight_input_channels, output_channels, weight_height, weight_width = weight.shape 

337 if batch <= 0 or input_height <= 0 or input_width <= 0: 

338 return False 

339 if input_channels != weight_input_channels: 

340 return False 

341 if (weight_height, weight_width) != (3, 3): 

342 return False 

343 if input_channels < 16 or output_channels < 16: 

344 return False 

345 if ( 

346 input_channels > _DIRECT_STRIDE2_PAD1_3X3_MAX_CHANNELS 

347 or output_channels > _DIRECT_STRIDE2_PAD1_3X3_MAX_CHANNELS 

348 ): 

349 return False 

350 if input.dtype is torch.float32: 

351 return True 

352 if ( 

353 input.dtype is torch.float16 

354 and input_channels <= 32 

355 and output_channels >= 64 

356 and input_height <= 16 

357 ): 

358 return True 

359 return ( 

360 input.dtype is torch.bfloat16 

361 and input_channels >= 64 

362 and output_channels <= 32 

363 and input_height <= 16 

364 ) 

365 

366 

@libentry()
@triton.jit
def _conv_transpose2d_direct_kernel(
    input_pointer,
    weight_pointer,
    output_pointer,
    batch_size: tl.constexpr,
    input_height: tl.constexpr,
    input_width: tl.constexpr,
    output_channels: tl.constexpr,
    output_height: tl.constexpr,
    output_width: tl.constexpr,
    input_n_stride: tl.constexpr,
    input_c_stride: tl.constexpr,
    input_height_stride: tl.constexpr,
    input_width_stride: tl.constexpr,
    weight_ci_stride: tl.constexpr,
    weight_co_stride: tl.constexpr,
    weight_height_stride: tl.constexpr,
    weight_width_stride: tl.constexpr,
    output_n_stride: tl.constexpr,
    output_c_stride: tl.constexpr,
    output_height_stride: tl.constexpr,
    output_width_stride: tl.constexpr,
    input_channels: tl.constexpr,
    weight_height: tl.constexpr,
    weight_width: tl.constexpr,
    stride_height: tl.constexpr,
    stride_width: tl.constexpr,
    padding_height: tl.constexpr,
    padding_width: tl.constexpr,
    BLOCK_NHW: tl.constexpr,
    BLOCK_CI: tl.constexpr,
    BLOCK_CO: tl.constexpr,
):
    """Transposed convolution without bias, groups or dilation.

    Each program computes a (BLOCK_NHW, BLOCK_CO) output tile for one output
    "phase", i.e. the output pixels whose (oh, ow) share the same remainder
    modulo (stride_height, stride_width). Grid axis 0 tiles the flattened
    compact (n, h, w) index, axis 1 tiles output channels and axis 2 selects
    the phase. All tensors are addressed through the explicit strides passed
    in; accumulation happens in float32.
    """
    pid_nhw = tl.program_id(0)
    pid_co = tl.program_id(1)
    pid_subgrid = tl.program_id(2)

    # Split the phase id into a (row, column) remainder pair.
    # NOTE(review): presumably the launcher uses stride_height * stride_width
    # programs on axis 2 -- the launch code is not visible here; confirm.
    output_residue_h = pid_subgrid // stride_width
    output_residue_w = pid_subgrid % stride_width
    # Number of output rows/columns per phase, rounded up.
    compact_height: tl.constexpr = (output_height + stride_height - 1) // stride_height
    compact_width: tl.constexpr = (output_width + stride_width - 1) // stride_width

    # Decompose the flat compact index into (n, compact_h, compact_w) and map
    # it back to real output coordinates within this phase.
    compact_offsets = pid_nhw * BLOCK_NHW + tl.arange(0, BLOCK_NHW)
    compact_plane: tl.constexpr = compact_height * compact_width
    compact_nh = compact_offsets // compact_width
    compact_h = compact_nh % compact_height
    compact_w = compact_offsets % compact_width
    n = compact_offsets // compact_plane
    oh = compact_h * stride_height + output_residue_h
    ow = compact_w * stride_width + output_residue_w
    co_offsets = pid_co * BLOCK_CO + tl.arange(0, BLOCK_CO)

    accum = tl.zeros((BLOCK_NHW, BLOCK_CO), dtype=tl.float32)
    ci_blocks: tl.constexpr = tl.cdiv(input_channels, BLOCK_CI)
    # A tap kh contributes only if (oh + padding - kh) is a multiple of the
    # stride; since oh has a fixed remainder in this phase, that is a condition
    # on kh alone.
    height_residue = (output_residue_h + padding_height) % stride_height
    width_residue = (output_residue_w + padding_width) % stride_width
    for kh in range(weight_height):
        if kh % stride_height == height_residue:
            ih_unstrided = oh + padding_height - kh
            ih = ih_unstrided // stride_height
            # The explicit ">= 0" test on the unstrided value guards negative
            # coordinates before the divided index is range-checked.
            valid_h = (ih_unstrided >= 0) & (ih < input_height)
            for kw in range(weight_width):
                if kw % stride_width == width_residue:
                    iw_unstrided = ow + padding_width - kw
                    iw = iw_unstrided // stride_width
                    valid_hw = (
                        (n < batch_size)
                        & valid_h
                        & (iw_unstrided >= 0)
                        & (iw < input_width)
                        & (oh < output_height)
                        & (ow < output_width)
                    )
                    # Reduce over input channels in BLOCK_CI chunks.
                    for ci_base in range(ci_blocks):
                        ci_offsets = ci_base * BLOCK_CI + tl.arange(0, BLOCK_CI)
                        input_offsets = (
                            n[:, None] * input_n_stride
                            + ci_offsets[None, :] * input_c_stride
                            + ih[:, None] * input_height_stride
                            + iw[:, None] * input_width_stride
                        )
                        weight_offsets = (
                            ci_offsets[:, None] * weight_ci_stride
                            + co_offsets[None, :] * weight_co_stride
                            + kh * weight_height_stride
                            + kw * weight_width_stride
                        )
                        input_mask = valid_hw[:, None] & (
                            ci_offsets[None, :] < input_channels
                        )
                        weight_mask = (ci_offsets[:, None] < input_channels) & (
                            co_offsets[None, :] < output_channels
                        )
                        # Masked-out elements load as 0.0 and so add nothing.
                        input_block = tl.load(
                            input_pointer + input_offsets, mask=input_mask, other=0.0
                        )
                        weight_block = tl.load(
                            weight_pointer + weight_offsets, mask=weight_mask, other=0.0
                        )
                        # (BLOCK_NHW, BLOCK_CI) x (BLOCK_CI, BLOCK_CO) product.
                        accum += tl.dot(
                            input_block,
                            weight_block,
                            input_precision="tf32x3",
                        )

    output_offsets = (
        n[:, None] * output_n_stride
        + co_offsets[None, :] * output_c_stride
        + oh[:, None] * output_height_stride
        + ow[:, None] * output_width_stride
    )
    # Drop lanes that fall outside the batch, the output plane or the channels.
    output_mask = (
        (n[:, None] < batch_size)
        & (oh[:, None] < output_height)
        & (ow[:, None] < output_width)
        & (co_offsets[None, :] < output_channels)
    )
    tl.store(output_pointer + output_offsets, accum, mask=output_mask)

487 

488 

@libentry()
@triton.jit
def _conv_transpose2d_stride2_pad1_3x3_kernel(
    input_pointer,
    weight_pointer,
    output_pointer,
    batch_size: tl.constexpr,
    input_height: tl.constexpr,
    input_width: tl.constexpr,
    output_channels: tl.constexpr,
    output_height: tl.constexpr,
    output_width: tl.constexpr,
    compact_height: tl.constexpr,
    compact_width: tl.constexpr,
    input_n_stride: tl.constexpr,
    input_c_stride: tl.constexpr,
    input_height_stride: tl.constexpr,
    input_width_stride: tl.constexpr,
    weight_ci_stride: tl.constexpr,
    weight_co_stride: tl.constexpr,
    weight_height_stride: tl.constexpr,
    weight_width_stride: tl.constexpr,
    output_n_stride: tl.constexpr,
    output_c_stride: tl.constexpr,
    output_height_stride: tl.constexpr,
    output_width_stride: tl.constexpr,
    input_channels: tl.constexpr,
    BLOCK_NHW: tl.constexpr,
    BLOCK_CI: tl.constexpr,
    BLOCK_CO: tl.constexpr,
):
    """Transposed convolution hard-coded for stride 2, padding 1, 3x3 kernel.

    The constants 2 (stride), 1 (padding) and 3 (kernel extent) are baked into
    the index arithmetic below. Grid axis 0 interleaves the four output phases
    with the compact (n, h, w) tiles; axis 1 tiles output channels. There is
    no bias, group or dilation handling. Accumulation is in float32.
    """
    # Axis 0 packs (tile, phase): the low two bits select one of the 2x2 phases.
    pid_raw = tl.program_id(0)
    phase = pid_raw % 4
    pid_nhw = pid_raw // 4
    pid_co = tl.program_id(1)

    residue_h = phase // 2
    residue_w = phase % 2
    # compact_height / compact_width are supplied by the caller rather than
    # derived here.
    compact_offsets = pid_nhw * BLOCK_NHW + tl.arange(0, BLOCK_NHW)
    compact_plane: tl.constexpr = compact_height * compact_width
    compact_nh = compact_offsets // compact_width
    compact_h = compact_nh % compact_height
    compact_w = compact_offsets % compact_width
    n = compact_offsets // compact_plane
    oh = compact_h * 2 + residue_h
    ow = compact_w * 2 + residue_w
    co_offsets = pid_co * BLOCK_CO + tl.arange(0, BLOCK_CO)

    accum = tl.zeros((BLOCK_NHW, BLOCK_CO), dtype=tl.float32)
    ci_blocks: tl.constexpr = tl.cdiv(input_channels, BLOCK_CI)
    # Parity of the taps that can contribute to this phase (padding is 1).
    height_residue = (residue_h + 1) % 2
    width_residue = (residue_w + 1) % 2
    # At most two taps per axis share a parity: residue and residue + 2.
    # When the second would be tap index 3 it is masked off via valid_kh /
    # valid_kw instead of being skipped with a branch.
    for kh_slot in range(2):
        kh = height_residue + kh_slot * 2
        valid_kh = kh < 3
        ih_unstrided = oh + 1 - kh
        ih = ih_unstrided // 2
        valid_h = valid_kh & (ih_unstrided >= 0) & (ih < input_height)
        for kw_slot in range(2):
            kw = width_residue + kw_slot * 2
            valid_kw = kw < 3
            iw_unstrided = ow + 1 - kw
            iw = iw_unstrided // 2
            valid_hw = (
                (n < batch_size)
                & valid_h
                & valid_kw
                & (iw_unstrided >= 0)
                & (iw < input_width)
                & (oh < output_height)
                & (ow < output_width)
            )
            # Reduce over input channels in BLOCK_CI chunks.
            for ci_base in range(ci_blocks):
                ci_offsets = ci_base * BLOCK_CI + tl.arange(0, BLOCK_CI)
                input_offsets = (
                    n[:, None] * input_n_stride
                    + ci_offsets[None, :] * input_c_stride
                    + ih[:, None] * input_height_stride
                    + iw[:, None] * input_width_stride
                )
                weight_offsets = (
                    ci_offsets[:, None] * weight_ci_stride
                    + co_offsets[None, :] * weight_co_stride
                    + kh * weight_height_stride
                    + kw * weight_width_stride
                )
                input_mask = valid_hw[:, None] & (ci_offsets[None, :] < input_channels)
                # The weight load is also masked by tap validity so an
                # out-of-range tap index is never dereferenced.
                weight_mask = (
                    (ci_offsets[:, None] < input_channels)
                    & (co_offsets[None, :] < output_channels)
                    & valid_kh
                    & valid_kw
                )
                input_block = tl.load(
                    input_pointer + input_offsets, mask=input_mask, other=0.0
                )
                weight_block = tl.load(
                    weight_pointer + weight_offsets, mask=weight_mask, other=0.0
                )
                accum += tl.dot(
                    input_block,
                    weight_block,
                    input_precision="tf32x3",
                )

    output_offsets = (
        n[:, None] * output_n_stride
        + co_offsets[None, :] * output_c_stride
        + oh[:, None] * output_height_stride
        + ow[:, None] * output_width_stride
    )
    output_mask = (
        (n[:, None] < batch_size)
        & (oh[:, None] < output_height)
        & (ow[:, None] < output_width)
        & (co_offsets[None, :] < output_channels)
    )
    tl.store(output_pointer + output_offsets, accum, mask=output_mask)

607 

608 

@libentry()
@triton.jit
def _conv_transpose2d_residue_kernel(
    input_pointer,
    weight_pointer,
    bias_pointer,
    output_pointer,
    batch_size: tl.constexpr,
    input_channels: tl.constexpr,
    input_height: tl.constexpr,
    input_width: tl.constexpr,
    output_channels: tl.constexpr,
    output_height: tl.constexpr,
    output_width: tl.constexpr,
    weight_height: tl.constexpr,
    weight_width: tl.constexpr,
    output_channels_per_group: tl.constexpr,
    input_channels_per_group: tl.constexpr,
    stride_height: tl.constexpr,
    stride_width: tl.constexpr,
    padding_height: tl.constexpr,
    padding_width: tl.constexpr,
    dilation_height: tl.constexpr,
    dilation_width: tl.constexpr,
    has_bias: tl.constexpr,
    n_subgrids: tl.constexpr,
    BLOCK_NHW: tl.constexpr,
    BLOCK_CI: tl.constexpr,
    BLOCK_CO: tl.constexpr,
):
    """Phase-decomposed transposed convolution with bias, groups and dilation.

    Each program computes a (BLOCK_NHW, BLOCK_CO) output tile for one group
    and one output phase (fixed remainder of (oh, ow) modulo the strides).
    Grid axis 0 tiles the compact (n, h, w) index, axis 1 tiles the output
    channels within a group, and axis 2 packs (group, phase). Offsets are
    computed from the dimension sizes rather than from strides, so the index
    arithmetic corresponds to densely packed (N, C, H, W) input/output and
    (C_in, C_out_per_group, KH, KW) weight.
    """
    pid_nhw = tl.program_id(0)
    pid_co_in_group = tl.program_id(1)
    pid_phase_group = tl.program_id(2)

    # Unpack axis 2 into the phase index and the group index.
    pid_subgrid = pid_phase_group % n_subgrids
    group = pid_phase_group // n_subgrids
    output_residue_h = pid_subgrid // stride_width
    output_residue_w = pid_subgrid % stride_width
    compact_height: tl.constexpr = (output_height + stride_height - 1) // stride_height
    compact_width: tl.constexpr = (output_width + stride_width - 1) // stride_width
    compact_plane: tl.constexpr = compact_height * compact_width

    # Flat compact index -> (n, compact_h, compact_w) -> output coordinates.
    compact_offsets = pid_nhw * BLOCK_NHW + tl.arange(0, BLOCK_NHW)
    compact_nh = compact_offsets // compact_width
    compact_h = compact_nh % compact_height
    compact_w = compact_offsets % compact_width
    n = compact_offsets // compact_plane
    oh = compact_h * stride_height + output_residue_h
    ow = compact_w * stride_width + output_residue_w

    # Channel index within the group, and the absolute output channel.
    co_in_offsets = pid_co_in_group * BLOCK_CO + tl.arange(0, BLOCK_CO)
    co_offsets = group * output_channels_per_group + co_in_offsets

    accum = tl.zeros((BLOCK_NHW, BLOCK_CO), dtype=tl.float32)
    if has_bias:
        # Seed the accumulator with the per-channel bias (broadcast over rows).
        bias_values = tl.load(
            bias_pointer + co_offsets,
            mask=co_in_offsets < output_channels_per_group,
            other=0.0,
        ).to(tl.float32)
        accum += bias_values[None, :]

    ci_blocks: tl.constexpr = tl.cdiv(input_channels_per_group, BLOCK_CI)
    # A tap contributes only when (kh * dilation) has this remainder modulo
    # the stride, so that (oh + padding - kh * dilation) divides evenly.
    height_residue = (output_residue_h + padding_height) % stride_height
    width_residue = (output_residue_w + padding_width) % stride_width
    for kh in range(weight_height):
        # NOTE(review): kh_residue is annotated tl.constexpr although kh comes
        # from a plain `range` loop; _conv_transpose2d_residue_static_kernel
        # uses tl.static_range for the same test. Confirm this compiles as
        # intended on the supported Triton versions.
        kh_residue: tl.constexpr = (kh * dilation_height) % stride_height
        if kh_residue == height_residue:
            ih_unstrided = oh + padding_height - kh * dilation_height
            ih = ih_unstrided // stride_height
            valid_h = (n < batch_size) & (ih_unstrided >= 0) & (ih < input_height)
            for kw in range(weight_width):
                kw_residue: tl.constexpr = (kw * dilation_width) % stride_width
                if kw_residue == width_residue:
                    iw_unstrided = ow + padding_width - kw * dilation_width
                    iw = iw_unstrided // stride_width
                    valid_hw = (
                        valid_h
                        & (iw_unstrided >= 0)
                        & (iw < input_width)
                        & (oh < output_height)
                        & (ow < output_width)
                    )
                    # Reduce over this group's input channels in chunks.
                    for ci_base in range(ci_blocks):
                        ci_in_offsets = ci_base * BLOCK_CI + tl.arange(0, BLOCK_CI)
                        ci_offsets = group * input_channels_per_group + ci_in_offsets
                        # ((n * C_in + ci) * H_in + ih) * W_in + iw
                        input_offsets = (
                            n[:, None] * input_channels + ci_offsets[None, :]
                        ) * input_height
                        input_offsets = (
                            input_offsets + ih[:, None]
                        ) * input_width + iw[:, None]
                        # ((ci * C_out_per_group + co) * KH + kh) * KW + kw
                        weight_offsets = (
                            ci_offsets[:, None] * output_channels_per_group
                            + co_in_offsets[None, :]
                        ) * weight_height
                        weight_offsets = (weight_offsets + kh) * weight_width + kw
                        input_mask = valid_hw[:, None] & (
                            ci_in_offsets[None, :] < input_channels_per_group
                        )
                        weight_mask = (
                            ci_in_offsets[:, None] < input_channels_per_group
                        ) & (co_in_offsets[None, :] < output_channels_per_group)
                        input_block = tl.load(
                            input_pointer + input_offsets, mask=input_mask, other=0.0
                        )
                        weight_block = tl.load(
                            weight_pointer + weight_offsets, mask=weight_mask, other=0.0
                        )
                        accum += tl.dot(
                            input_block,
                            weight_block,
                            input_precision="tf32x3",
                        )

    # ((n * C_out + co) * H_out + oh) * W_out + ow
    output_offsets = n[:, None] * output_channels + co_offsets[None, :]
    output_offsets = (output_offsets * output_height + oh[:, None]) * output_width
    output_offsets = output_offsets + ow[:, None]
    # The per-group channel mask keeps a partially filled channel tile from
    # writing into the next group's channels.
    output_mask = (
        (n[:, None] < batch_size)
        & (oh[:, None] < output_height)
        & (ow[:, None] < output_width)
        & (co_in_offsets[None, :] < output_channels_per_group)
        & (co_offsets[None, :] < output_channels)
    )
    tl.store(output_pointer + output_offsets, accum, mask=output_mask)

735 

736 

@libentry()
@triton.jit
def _conv_transpose2d_general_kernel(
    input_pointer,
    weight_pointer,
    bias_pointer,
    output_pointer,
    total_elements: tl.constexpr,
    batch_size: tl.constexpr,
    input_channels: tl.constexpr,
    input_height: tl.constexpr,
    input_width: tl.constexpr,
    output_channels: tl.constexpr,
    output_height: tl.constexpr,
    output_width: tl.constexpr,
    weight_height: tl.constexpr,
    weight_width: tl.constexpr,
    output_channels_per_group: tl.constexpr,
    input_channels_per_group: tl.constexpr,
    stride_height: tl.constexpr,
    stride_width: tl.constexpr,
    padding_height: tl.constexpr,
    padding_width: tl.constexpr,
    dilation_height: tl.constexpr,
    dilation_width: tl.constexpr,
    has_bias: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Element-wise (gather-style) transposed convolution.

    Each lane owns one flat output element. It walks every input channel of
    its group and every kernel tap, accumulating in float32 the products for
    taps that map onto a valid input pixel. Offsets are derived from the
    dimension sizes, i.e. the arithmetic corresponds to densely packed
    (N, C, H, W) input/output and (C_in, C_out_per_group, KH, KW) weight.
    """
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < total_elements

    # Peel the flat offset apart into (n, co, oh, ow), innermost first.
    tmp = offsets // output_width
    ow = offsets - tmp * output_width
    tmp2 = tmp // output_height
    oh = tmp - tmp2 * output_height
    n = tmp2 // output_channels
    co = tmp2 - n * output_channels

    group = co // output_channels_per_group
    co_in_group = co - group * output_channels_per_group
    accum = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    if has_bias:
        bias = tl.load(bias_pointer + co, mask=mask, other=0.0).to(tl.float32)
        accum += bias

    for ci_in_group in tl.range(0, input_channels_per_group):
        ci = group * input_channels_per_group + ci_in_group
        for kh in tl.static_range(0, weight_height):
            ih_unstrided = oh + padding_height - kh * dilation_height
            ih = ih_unstrided // stride_height
            # The tap hits an input row only if the unstrided coordinate is a
            # multiple of the stride and the resulting row is in range.
            valid_h = (ih_unstrided % stride_height == 0) & (ih >= 0)
            valid_h = valid_h & (ih < input_height)
            for kw in tl.static_range(0, weight_width):
                iw_unstrided = ow + padding_width - kw * dilation_width
                iw = iw_unstrided // stride_width
                valid = mask & valid_h
                valid = valid & (iw_unstrided % stride_width == 0)
                valid = valid & (iw >= 0) & (iw < input_width)

                # ((n * C_in + ci) * H_in + ih) * W_in + iw
                input_offsets = (n * input_channels + ci) * input_height + ih
                input_offsets = input_offsets * input_width + iw
                # ((ci * C_out_per_group + co) * KH + kh) * KW + kw
                weight_offsets = (
                    ci * output_channels_per_group + co_in_group
                ) * weight_height
                weight_offsets = (weight_offsets + kh) * weight_width + kw
                # Both loads share the same mask; invalid lanes contribute 0.
                input_values = tl.load(
                    input_pointer + input_offsets, mask=valid, other=0.0
                ).to(tl.float32)
                weight_values = tl.load(
                    weight_pointer + weight_offsets, mask=valid, other=0.0
                ).to(tl.float32)
                accum += input_values * weight_values

    tl.store(output_pointer + offsets, accum, mask=mask)

812 

813 

@libentry()
@triton.jit
def _conv_transpose2d_residue_static_kernel(
    input_pointer,
    weight_pointer,
    bias_pointer,
    output_pointer,
    batch_size: tl.constexpr,
    input_channels: tl.constexpr,
    input_height: tl.constexpr,
    input_width: tl.constexpr,
    output_channels: tl.constexpr,
    output_height: tl.constexpr,
    output_width: tl.constexpr,
    compact_height: tl.constexpr,
    compact_width: tl.constexpr,
    weight_height: tl.constexpr,
    weight_width: tl.constexpr,
    output_channels_per_group: tl.constexpr,
    input_channels_per_group: tl.constexpr,
    stride_height: tl.constexpr,
    stride_width: tl.constexpr,
    padding_height: tl.constexpr,
    padding_width: tl.constexpr,
    dilation_height: tl.constexpr,
    dilation_width: tl.constexpr,
    has_bias: tl.constexpr,
    output_residue_h: tl.constexpr,
    output_residue_w: tl.constexpr,
    co_blocks_per_group: tl.constexpr,
    BLOCK_NHW: tl.constexpr,
    BLOCK_CI: tl.constexpr,
    BLOCK_CO: tl.constexpr,
):
    """Phase-decomposed transposed convolution for one compile-time phase.

    The output phase (output_residue_h, output_residue_w) and the compact
    plane size are constexpr arguments, so the tap-selection tests below use
    only compile-time values inside tl.static_range loops. Grid axis 0 tiles
    the compact (n, h, w) index; axis 1 packs (group, channel block). Supports
    bias, groups and dilation. Offsets are derived from the dimension sizes,
    i.e. the arithmetic corresponds to densely packed (N, C, H, W)
    input/output and (C_in, C_out_per_group, KH, KW) weight.
    """
    pid_nhw = tl.program_id(0)
    pid_gco = tl.program_id(1)

    # Flat compact index -> (n, compact_h, compact_w) -> output coordinates.
    compact_offsets = pid_nhw * BLOCK_NHW + tl.arange(0, BLOCK_NHW)
    compact_plane: tl.constexpr = compact_height * compact_width
    compact_nh = compact_offsets // compact_width
    compact_h = compact_nh % compact_height
    compact_w = compact_offsets % compact_width
    n = compact_offsets // compact_plane
    oh = compact_h * stride_height + output_residue_h
    ow = compact_w * stride_width + output_residue_w

    # Unpack axis 1 into the group and the channel block within that group.
    group = pid_gco // co_blocks_per_group
    pid_co_in_group = pid_gco - group * co_blocks_per_group
    co_in_offsets = pid_co_in_group * BLOCK_CO + tl.arange(0, BLOCK_CO)
    co_offsets = group * output_channels_per_group + co_in_offsets

    accum = tl.zeros((BLOCK_NHW, BLOCK_CO), dtype=tl.float32)
    if has_bias:
        # Seed the accumulator with the per-channel bias (broadcast over rows).
        bias_values = tl.load(
            bias_pointer + co_offsets,
            mask=co_in_offsets < output_channels_per_group,
            other=0.0,
        ).to(tl.float32)
        accum += bias_values[None, :]

    ci_blocks: tl.constexpr = tl.cdiv(input_channels_per_group, BLOCK_CI)
    # Remainder that (tap * dilation) must have modulo the stride for the tap
    # to land on an input pixel in this phase.
    height_residue: tl.constexpr = (output_residue_h + padding_height) % stride_height
    width_residue: tl.constexpr = (output_residue_w + padding_width) % stride_width
    for kh in tl.static_range(0, weight_height):
        if (kh * dilation_height) % stride_height == height_residue:
            ih_unstrided = oh + padding_height - kh * dilation_height
            ih = ih_unstrided // stride_height
            valid_h = (n < batch_size) & (ih_unstrided >= 0) & (ih < input_height)
            for kw in tl.static_range(0, weight_width):
                if (kw * dilation_width) % stride_width == width_residue:
                    iw_unstrided = ow + padding_width - kw * dilation_width
                    iw = iw_unstrided // stride_width
                    valid_hw = (
                        valid_h
                        & (iw_unstrided >= 0)
                        & (iw < input_width)
                        & (oh < output_height)
                        & (ow < output_width)
                    )
                    # Reduce over this group's input channels in chunks.
                    for ci_base in range(ci_blocks):
                        ci_in_offsets = ci_base * BLOCK_CI + tl.arange(0, BLOCK_CI)
                        ci_offsets = group * input_channels_per_group + ci_in_offsets
                        # ((n * C_in + ci) * H_in + ih) * W_in + iw
                        input_offsets = (
                            n[:, None] * input_channels + ci_offsets[None, :]
                        ) * input_height
                        input_offsets = (
                            input_offsets + ih[:, None]
                        ) * input_width + iw[:, None]
                        # ((ci * C_out_per_group + co) * KH + kh) * KW + kw
                        weight_offsets = (
                            ci_offsets[:, None] * output_channels_per_group
                            + co_in_offsets[None, :]
                        ) * weight_height
                        weight_offsets = (weight_offsets + kh) * weight_width + kw
                        input_mask = valid_hw[:, None] & (
                            ci_in_offsets[None, :] < input_channels_per_group
                        )
                        weight_mask = (
                            ci_in_offsets[:, None] < input_channels_per_group
                        ) & (co_in_offsets[None, :] < output_channels_per_group)
                        input_block = tl.load(
                            input_pointer + input_offsets, mask=input_mask, other=0.0
                        )
                        weight_block = tl.load(
                            weight_pointer + weight_offsets, mask=weight_mask, other=0.0
                        )
                        accum += tl.dot(
                            input_block,
                            weight_block,
                            input_precision="tf32x3",
                        )

    # ((n * C_out + co) * H_out + oh) * W_out + ow
    output_offsets = n[:, None] * output_channels + co_offsets[None, :]
    output_offsets = (output_offsets * output_height + oh[:, None]) * output_width
    output_offsets = output_offsets + ow[:, None]
    # The per-group channel mask keeps a partially filled channel tile from
    # writing into the next group's channels.
    output_mask = (
        (n[:, None] < batch_size)
        & (oh[:, None] < output_height)
        & (ow[:, None] < output_width)
        & (co_in_offsets[None, :] < output_channels_per_group)
        & (co_offsets[None, :] < output_channels)
    )
    tl.store(output_pointer + output_offsets, accum, mask=output_mask)

936 

937 

@libentry()
@triton.jit
def _conv_transpose2d_scatter_init_kernel(
    bias_pointer,
    output_pointer,
    total_elements: tl.constexpr,
    output_channels: tl.constexpr,
    output_height: tl.constexpr,
    output_width: tl.constexpr,
    has_bias: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Pre-fill pass for the scatter path: writes zero to every output element,
    # or the per-channel bias when has_bias is set. Each program instance
    # handles BLOCK_SIZE consecutive flat offsets of the output buffer.
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Tail guard: the last program may extend past the end of the buffer.
    mask = offsets < total_elements
    values = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    if has_bias:
        spatial_size: tl.constexpr = output_height * output_width
        # Recover the channel index from the flat offset. This decomposition
        # assumes a dense (N, C, H, W) output layout -- the host wrapper in
        # this file allocates the output with torch.empty in that shape.
        co = (offsets // spatial_size) % output_channels
        values = tl.load(bias_pointer + co, mask=mask, other=0.0).to(tl.float32)
    # Values are held as float32; the store targets whatever element type the
    # output pointer has.
    tl.store(output_pointer + offsets, values, mask=mask)

958 

959 

@libentry()
@triton.jit
def _conv_transpose2d_scatter_no_overlap_kernel(
    input_pointer,
    weight_pointer,
    bias_pointer,
    output_pointer,
    batch_size: tl.constexpr,
    input_channels: tl.constexpr,
    input_height: tl.constexpr,
    input_width: tl.constexpr,
    output_channels: tl.constexpr,
    output_height: tl.constexpr,
    output_width: tl.constexpr,
    weight_height: tl.constexpr,
    weight_width: tl.constexpr,
    output_channels_per_group: tl.constexpr,
    input_channels_per_group: tl.constexpr,
    stride_height: tl.constexpr,
    stride_width: tl.constexpr,
    padding_height: tl.constexpr,
    padding_width: tl.constexpr,
    dilation_height: tl.constexpr,
    dilation_width: tl.constexpr,
    has_bias: tl.constexpr,
    BLOCK_NHW: tl.constexpr,
    BLOCK_CI: tl.constexpr,
    BLOCK_CO: tl.constexpr,
):
    # Scatter formulation of transposed convolution: each program takes a tile
    # of input positions (axis 0), a tile of output channels within one group
    # (axis 1) and a single (group, kh, kw) kernel tap (axis 2), reduces over
    # the group's input channels and writes the result to the output element
    # that tap maps to.
    pid_nhw = tl.program_id(0)
    pid_co = tl.program_id(1)
    pid_gkk = tl.program_id(2)

    # Decode axis 2 as ((group * weight_height) + kh) * weight_width + kw.
    kw = pid_gkk % weight_width
    tmp = pid_gkk // weight_width
    kh = tmp % weight_height
    group = tmp // weight_height

    # Decode flat input positions into (n, ih, iw).
    nhw_offsets = pid_nhw * BLOCK_NHW + tl.arange(0, BLOCK_NHW)
    iw = nhw_offsets % input_width
    tmp = nhw_offsets // input_width
    ih = tmp % input_height
    n = tmp // input_height

    # Output coordinate written by this tap for each input position.
    oh = ih * stride_height - padding_height + kh * dilation_height
    ow = iw * stride_width - padding_width + kw * dilation_width
    valid_nhw = (nhw_offsets < batch_size * input_height * input_width) & (
        n < batch_size
    )
    # Drop positions whose target falls outside the output (cropped by padding).
    valid_nhw = valid_nhw & (oh >= 0) & (oh < output_height)
    valid_nhw = valid_nhw & (ow >= 0) & (ow < output_width)

    co_in_group = pid_co * BLOCK_CO + tl.arange(0, BLOCK_CO)
    co = group * output_channels_per_group + co_in_group
    ci_in_group_base = tl.arange(0, BLOCK_CI)

    accum = tl.zeros((BLOCK_NHW, BLOCK_CO), dtype=tl.float32)
    ci_blocks: tl.constexpr = tl.cdiv(input_channels_per_group, BLOCK_CI)
    for ci_block in range(ci_blocks):
        ci_in_group = ci_block * BLOCK_CI + ci_in_group_base
        ci = group * input_channels_per_group + ci_in_group
        # Flat input index ((n * C_in + ci) * H + ih) * W + iw.
        input_offsets = (n[:, None] * input_channels + ci[None, :]) * input_height
        input_offsets = (input_offsets + ih[:, None]) * input_width + iw[:, None]
        # Flat weight index ((ci * C_out_per_group + co) * KH + kh) * KW + kw.
        weight_offsets = (
            ci[:, None] * output_channels_per_group + co_in_group[None, :]
        ) * weight_height
        weight_offsets = (weight_offsets + kh) * weight_width + kw

        ci_mask = ci_in_group < input_channels_per_group
        co_mask = co_in_group < output_channels_per_group
        input_block = tl.load(
            input_pointer + input_offsets,
            mask=valid_nhw[:, None] & ci_mask[None, :],
            other=0.0,
        )
        weight_block = tl.load(
            weight_pointer + weight_offsets,
            mask=ci_mask[:, None] & co_mask[None, :],
            other=0.0,
        )
        # (BLOCK_NHW, BLOCK_CI) x (BLOCK_CI, BLOCK_CO), accumulated in float32.
        accum += tl.dot(
            input_block,
            weight_block,
            input_precision="tf32x3",
        )

    if has_bias:
        # Bias is added once, after the channel reduction.
        bias = tl.load(
            bias_pointer + co,
            mask=co_in_group < output_channels_per_group,
            other=0.0,
        ).to(tl.float32)
        accum += bias[None, :]

    output_offsets = (n[:, None] * output_channels + co[None, :]) * output_height
    output_offsets = (output_offsets + oh[:, None]) * output_width + ow[:, None]
    output_mask = valid_nhw[:, None] & (
        co_in_group[None, :] < output_channels_per_group
    )
    # Plain (non-atomic) store that overwrites the destination.
    # NOTE(review): correctness presumably relies on the dispatcher selecting
    # this kernel only when no two taps hit the same output element -- confirm
    # against _can_use_scatter_no_overlap.
    tl.store(output_pointer + output_offsets, accum, mask=output_mask)

1060 

1061 

@libentry()
@triton.jit
def _conv_transpose2d_1x1_kernel(
    input_pointer,
    weight_pointer,
    bias_pointer,
    output_pointer,
    batch_size: tl.constexpr,
    input_channels: tl.constexpr,
    input_height: tl.constexpr,
    input_width: tl.constexpr,
    output_channels: tl.constexpr,
    output_channels_per_group: tl.constexpr,
    input_channels_per_group: tl.constexpr,
    has_bias: tl.constexpr,
    co_blocks_per_group: tl.constexpr,
    BLOCK_NHW: tl.constexpr,
    BLOCK_CI: tl.constexpr,
    BLOCK_CO: tl.constexpr,
):
    # Pointwise path: the output has the same spatial size as the input, so
    # each program computes a (BLOCK_NHW, BLOCK_CO) tile as a per-group channel
    # contraction at identical (n, ih, iw) positions.
    pid_nhw = tl.program_id(0)
    pid_gco = tl.program_id(1)

    # Axis 1 enumerates (group, output-channel block within the group).
    group = pid_gco // co_blocks_per_group
    pid_co_in_group = pid_gco - group * co_blocks_per_group
    co_in_offsets = pid_co_in_group * BLOCK_CO + tl.arange(0, BLOCK_CO)
    co_offsets = group * output_channels_per_group + co_in_offsets

    # Decode flat spatial positions into (n, ih, iw).
    nhw_offsets = pid_nhw * BLOCK_NHW + tl.arange(0, BLOCK_NHW)
    iw = nhw_offsets % input_width
    tmp = nhw_offsets // input_width
    ih = tmp % input_height
    n = tmp // input_height
    valid_nhw = (nhw_offsets < batch_size * input_height * input_width) & (
        n < batch_size
    )

    accum = tl.zeros((BLOCK_NHW, BLOCK_CO), dtype=tl.float32)
    if has_bias:
        # Seed the accumulator with the bias so no separate add is needed later.
        bias_values = tl.load(
            bias_pointer + co_offsets,
            mask=co_in_offsets < output_channels_per_group,
            other=0.0,
        ).to(tl.float32)
        accum += bias_values[None, :]

    ci_blocks: tl.constexpr = tl.cdiv(input_channels_per_group, BLOCK_CI)
    for ci_base in range(ci_blocks):
        ci_in_offsets = ci_base * BLOCK_CI + tl.arange(0, BLOCK_CI)
        ci_offsets = group * input_channels_per_group + ci_in_offsets
        # Flat input index ((n * C_in + ci) * H + ih) * W + iw.
        input_offsets = n[:, None] * input_channels + ci_offsets[None, :]
        input_offsets = (input_offsets * input_height + ih[:, None]) * input_width
        input_offsets = input_offsets + iw[:, None]
        # With a 1x1 kernel the weight index reduces to ci * C_out_per_group + co.
        weight_offsets = (
            ci_offsets[:, None] * output_channels_per_group + co_in_offsets[None, :]
        )
        ci_mask = ci_in_offsets < input_channels_per_group
        co_mask = co_in_offsets < output_channels_per_group
        input_block = tl.load(
            input_pointer + input_offsets,
            mask=valid_nhw[:, None] & ci_mask[None, :],
            other=0.0,
        )
        weight_block = tl.load(
            weight_pointer + weight_offsets,
            mask=ci_mask[:, None] & co_mask[None, :],
            other=0.0,
        )
        accum += tl.dot(input_block, weight_block, input_precision="tf32x3")

    # Output shares the input's spatial extent, hence input_height/input_width
    # are reused for the output index.
    output_offsets = n[:, None] * output_channels + co_offsets[None, :]
    output_offsets = (output_offsets * input_height + ih[:, None]) * input_width
    output_offsets = output_offsets + iw[:, None]
    output_mask = valid_nhw[:, None] & (
        co_in_offsets[None, :] < output_channels_per_group
    )
    tl.store(output_pointer + output_offsets, accum, mask=output_mask)

1139 

1140 

1141def _can_use_pointwise_1x1( 

1142 weight, 

1143 stride_h, 

1144 stride_w, 

1145 padding_h, 

1146 padding_w, 

1147 output_padding_h, 

1148 output_padding_w, 

1149): 

1150 return ( 

1151 weight.shape[2] == 1 

1152 and weight.shape[3] == 1 

1153 and stride_h == 1 

1154 and stride_w == 1 

1155 and padding_h == 0 

1156 and padding_w == 0 

1157 and output_padding_h == 0 

1158 and output_padding_w == 0 

1159 ) 

1160 

1161 

1162def _conv_transpose2d_pointwise_1x1(input, weight, bias, groups): 

1163 batch, input_channels, input_height, input_width = input.shape 

1164 _, output_channels_per_group, _weight_height, _weight_width = weight.shape 

1165 output_channels = output_channels_per_group * groups 

1166 output = torch.empty( 

1167 (batch, output_channels, input_height, input_width), 

1168 device=input.device, 

1169 dtype=input.dtype, 

1170 ) 

1171 if output.numel() == 0: 

1172 return output 

1173 

1174 input_channels_per_group = input_channels // groups 

1175 block_nhw = 128 if input.dtype is not torch.float32 else 64 

1176 block_ci = 16 if input.dtype is torch.float32 else 32 

1177 if input_channels_per_group <= 16: 

1178 block_ci = 16 

1179 block_co = 16 if output_channels_per_group <= 16 else 32 

1180 co_blocks_per_group = triton.cdiv(output_channels_per_group, block_co) 

1181 grid = ( 

1182 triton.cdiv(batch * input_height * input_width, block_nhw), 

1183 groups * co_blocks_per_group, 

1184 ) 

1185 bias_pointer = bias if bias is not None else input 

1186 _conv_transpose2d_1x1_kernel[grid]( 

1187 input, 

1188 weight, 

1189 bias_pointer, 

1190 output, 

1191 batch, 

1192 input_channels, 

1193 input_height, 

1194 input_width, 

1195 output_channels, 

1196 output_channels_per_group, 

1197 input_channels_per_group, 

1198 bias is not None, 

1199 co_blocks_per_group, 

1200 BLOCK_NHW=block_nhw, 

1201 BLOCK_CI=block_ci, 

1202 BLOCK_CO=block_co, 

1203 num_warps=4, 

1204 ) 

1205 return output 

1206 

1207 

1208def _conv_transpose2d_scatter_no_overlap( 

1209 input, 

1210 weight, 

1211 bias, 

1212 stride_h, 

1213 stride_w, 

1214 padding_h, 

1215 padding_w, 

1216 dilation_h, 

1217 dilation_w, 

1218 output_padding_h, 

1219 output_padding_w, 

1220 groups, 

1221): 

1222 batch, input_channels, input_height, input_width = input.shape 

1223 _, output_channels_per_group, weight_height, weight_width = weight.shape 

1224 output_channels = output_channels_per_group * groups 

1225 output_height = ( 

1226 (input_height - 1) * stride_h 

1227 - 2 * padding_h 

1228 + dilation_h * (weight_height - 1) 

1229 + output_padding_h 

1230 + 1 

1231 ) 

1232 output_width = ( 

1233 (input_width - 1) * stride_w 

1234 - 2 * padding_w 

1235 + dilation_w * (weight_width - 1) 

1236 + output_padding_w 

1237 + 1 

1238 ) 

1239 output = torch.empty( 

1240 (batch, output_channels, output_height, output_width), 

1241 device=input.device, 

1242 dtype=input.dtype, 

1243 ) 

1244 total_elements = output.numel() 

1245 if total_elements == 0: 

1246 return output 

1247 

1248 init_block = 1024 

1249 bias_pointer = bias if bias is not None else input 

1250 _conv_transpose2d_scatter_init_kernel[(triton.cdiv(total_elements, init_block),)]( 

1251 bias_pointer, 

1252 output, 

1253 total_elements, 

1254 output_channels, 

1255 output_height, 

1256 output_width, 

1257 bias is not None, 

1258 BLOCK_SIZE=init_block, 

1259 num_warps=4, 

1260 ) 

1261 

1262 input_channels_per_group = input_channels // groups 

1263 if input_channels_per_group <= 16: 

1264 block_ci = 16 

1265 elif input_channels_per_group <= 64: 

1266 block_ci = 64 if input.dtype is not torch.float32 else 32 

1267 else: 

1268 block_ci = 64 

1269 block_co = 16 if output_channels_per_group <= 16 else 32 

1270 block_nhw = 32 if input.dtype is torch.float32 else 64 

1271 if output_channels_per_group >= 64: 

1272 block_nhw = 32 

1273 

1274 input_nhw = batch * input_height * input_width 

1275 grid = ( 

1276 triton.cdiv(input_nhw, block_nhw), 

1277 triton.cdiv(output_channels_per_group, block_co), 

1278 groups * weight_height * weight_width, 

1279 ) 

1280 _conv_transpose2d_scatter_no_overlap_kernel[grid]( 

1281 input, 

1282 weight, 

1283 bias_pointer, 

1284 output, 

1285 batch, 

1286 input_channels, 

1287 input_height, 

1288 input_width, 

1289 output_channels, 

1290 output_height, 

1291 output_width, 

1292 weight_height, 

1293 weight_width, 

1294 output_channels_per_group, 

1295 input_channels_per_group, 

1296 stride_h, 

1297 stride_w, 

1298 padding_h, 

1299 padding_w, 

1300 dilation_h, 

1301 dilation_w, 

1302 bias is not None, 

1303 BLOCK_NHW=block_nhw, 

1304 BLOCK_CI=block_ci, 

1305 BLOCK_CO=block_co, 

1306 num_warps=4, 

1307 num_stages=3, 

1308 ) 

1309 return output 

1310 

1311 

def conv_transpose2d(
    input,
    weight,
    bias=None,
    stride=1,
    padding=0,
    output_padding=0,
    groups=1,
    dilation=1,
):
    """Entry point mirroring ``torch.nn.functional.conv_transpose2d``.

    ``stride``, ``padding``, ``output_padding`` and ``dilation`` may each be
    an int or a pair. A 3-D ``input`` is treated as unbatched: a batch
    dimension is added before dispatch and removed from the result. All
    tensors are made contiguous before being handed to the dispatcher.
    """
    logger.debug("GEMS CONV_TRANSPOSE2D")

    stride_h, stride_w = _pair(stride)
    padding_h, padding_w = _pair(padding)
    output_padding_h, output_padding_w = _pair(output_padding)
    dilation_h, dilation_w = _pair(dilation)

    unbatched = input.dim() == 3
    if unbatched:
        input = input.unsqueeze(0)

    # Tensor.contiguous() hands back the same tensor when it is already
    # contiguous, so no explicit is_contiguous() check is needed.
    input = input.contiguous()
    weight = weight.contiguous()
    if bias is not None:
        bias = bias.contiguous()

    result = _conv_transpose2d_4d_dispatch(
        input,
        weight,
        bias,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        output_padding_h,
        output_padding_w,
        groups,
        dilation_h,
        dilation_w,
    )
    return result.squeeze(0) if unbatched else result

1357 

1358 

def _conv_transpose2d_4d_dispatch(
    input,
    weight,
    bias,
    stride_h,
    stride_w,
    padding_h,
    padding_w,
    output_padding_h,
    output_padding_w,
    groups,
    dilation_h,
    dilation_w,
):
    """Pick an implementation for a batched (4-D) call and run it.

    Paths are tried in this order: the stride-2 / pad-1 / 3x3 direct kernel,
    the direct tiled family, then -- only if argument validation passes --
    the pointwise 1x1 path, the no-overlap scatter path and finally the
    general path. Calls that fail validation are handed to
    ``_unsupported_conv_transpose2d``.
    """
    # Argument order used by the eligibility / validation helpers.
    api_order = (
        input,
        weight,
        bias,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        output_padding_h,
        output_padding_w,
        groups,
        dilation_h,
        dilation_w,
    )
    # Argument order used by the scatter and general launchers.
    launch_order = (
        input,
        weight,
        bias,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
        output_padding_h,
        output_padding_w,
        groups,
    )

    if _can_use_stride2_pad1_3x3_direct(*api_order):
        return _conv_transpose2d_stride2_pad1_3x3(input, weight)

    tiled_params = _direct_tiled_family_params(*api_order)
    if _can_use_direct_tiled_family(input, tiled_params, output_padding_h):
        return _conv_transpose2d_direct(
            input,
            weight,
            stride_h,
            stride_w,
            padding_h,
            padding_w,
            dilation_h,
            dilation_w,
            output_padding_h,
            output_padding_w,
        )

    if not _validate_conv_transpose2d_args(*api_order):
        return _unsupported_conv_transpose2d(*api_order)

    if _can_use_pointwise_1x1(
        weight,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        output_padding_h,
        output_padding_w,
    ):
        return _conv_transpose2d_pointwise_1x1(input, weight, bias, groups)

    if _can_use_scatter_no_overlap(
        input,
        weight,
        stride_h,
        stride_w,
        dilation_h,
        dilation_w,
        groups,
    ):
        return _conv_transpose2d_scatter_no_overlap(*launch_order)

    return _conv_transpose2d_general(*launch_order)

1495 

1496 

1497def _select_stride2_pad1_3x3_schedule(input_dtype, input_channels, output_channels): 

1498 block_nhw = 64 

1499 block_ci = 32 

1500 block_co = 32 

1501 num_warps = 4 

1502 

1503 if input_dtype is torch.float32: 

1504 block_ci = 16 

1505 block_co = 16 

1506 elif input_channels <= 32 and output_channels >= 64: 

1507 block_nhw = 128 

1508 block_ci = 16 

1509 block_co = 64 

1510 elif input_channels >= 64 and output_channels <= 32: 

1511 block_nhw = 64 

1512 block_ci = 32 

1513 block_co = 32 

1514 num_warps = 8 

1515 elif input_dtype is torch.bfloat16 and input_channels >= 128: 

1516 block_nhw = 128 

1517 block_ci = 16 

1518 block_co = 16 

1519 num_warps = 8 

1520 

1521 return block_nhw, block_ci, block_co, num_warps 

1522 

1523 

1524def _conv_transpose2d_stride2_pad1_3x3(input, weight): 

1525 batch, input_channels, input_height, input_width = input.shape 

1526 _, output_channels, _weight_height, _weight_width = weight.shape 

1527 output_height = input_height * 2 - 1 

1528 output_width = input_width * 2 - 1 

1529 output = torch.empty( 

1530 (batch, output_channels, output_height, output_width), 

1531 device=input.device, 

1532 dtype=input.dtype, 

1533 ) 

1534 if output.numel() == 0: 

1535 return output 

1536 

1537 block_nhw, block_ci, block_co, num_warps = _select_stride2_pad1_3x3_schedule( 

1538 input.dtype, 

1539 input_channels, 

1540 output_channels, 

1541 ) 

1542 compact_height = (output_height + 1) // 2 

1543 compact_width = (output_width + 1) // 2 

1544 grid = ( 

1545 triton.cdiv(batch * compact_height * compact_width, block_nhw) * 4, 

1546 triton.cdiv(output_channels, block_co), 

1547 ) 

1548 _conv_transpose2d_stride2_pad1_3x3_kernel[grid]( 

1549 input, 

1550 weight, 

1551 output, 

1552 batch, 

1553 input_height, 

1554 input_width, 

1555 output_channels, 

1556 output_height, 

1557 output_width, 

1558 compact_height, 

1559 compact_width, 

1560 *input.stride(), 

1561 *weight.stride(), 

1562 *output.stride(), 

1563 input_channels, 

1564 BLOCK_NHW=block_nhw, 

1565 BLOCK_CI=block_ci, 

1566 BLOCK_CO=block_co, 

1567 num_warps=num_warps, 

1568 ) 

1569 return output 

1570 

1571 

def _select_conv_transpose2d_direct_schedule(
    input_dtype,
    input_channels,
    output_channels,
    weight_height,
    weight_width,
    stride_h,
    output_padding_h,
):
    """Choose ``(block_nhw, block_ci, block_co, num_warps)`` for the direct kernel.

    Selection happens in three stages, each able to override the previous:
    a dtype-specific adjustment of the default schedule, a shape-specific
    override for small (<= 3x3) kernels with skewed channel counts, and a
    final clamp applied when output padding is in use.
    """
    block_nhw, block_ci, block_co, num_warps = _DIRECT_TILED_DEFAULT_SCHEDULE

    is_bf16 = input_dtype is torch.bfloat16
    large_kernel = weight_height >= 5 or weight_width >= 5
    many_in_few_out = input_channels >= 64 and output_channels <= 32
    small_kernel = weight_height <= 3 and weight_width <= 3

    # Stage 1: dtype-driven adjustment. The float16 and bfloat16 ladders are
    # identical except for the extra wide-input rule that applies to bfloat16.
    if is_bf16 or input_dtype is torch.float16:
        if stride_h >= 3:
            block_nhw, block_ci, block_co, num_warps = 128, 16, 16, 8
        elif is_bf16 and input_channels >= 128:
            block_nhw, block_ci, block_co, num_warps = 256, 16, 16, 8
        elif large_kernel:
            block_nhw, block_ci = 128, 16
        elif many_in_few_out:
            block_ci = 64
            if stride_h == 1:
                num_warps = 8
    elif input_dtype is torch.float32 and large_kernel:
        block_ci = 16
    elif many_in_few_out:
        block_ci = 64
        if stride_h == 1:
            num_warps = 8

    # Stage 2: small-kernel overrides keyed on stride and channel skew.
    if small_kernel and stride_h == 1 and input_channels >= 64 and output_channels <= 64:
        block_nhw, block_ci, block_co, num_warps = 256, 16, 32, 8
    elif small_kernel and stride_h == 2 and input_channels <= 32 and output_channels >= 64:
        block_nhw, block_ci, block_co, num_warps = 128, 16, 64, 4
    elif small_kernel and stride_h == 2 and many_in_few_out:
        block_nhw, block_ci, block_co, num_warps = 32, 16, 32, 8

    # Stage 3: cap the tile sizes whenever output padding is non-zero.
    if output_padding_h:
        block_nhw = min(block_nhw, 128)
        block_ci = min(block_ci, 32)

    return block_nhw, block_ci, block_co, num_warps

1658 

1659 

1660def _conv_transpose2d_direct( 

1661 input, 

1662 weight, 

1663 stride_h, 

1664 stride_w, 

1665 padding_h, 

1666 padding_w, 

1667 dilation_h, 

1668 dilation_w, 

1669 output_padding_h, 

1670 output_padding_w, 

1671): 

1672 batch, input_channels, input_height, input_width = input.shape 

1673 _, output_channels, weight_height, weight_width = weight.shape 

1674 output_height = ( 

1675 (input_height - 1) * stride_h 

1676 - 2 * padding_h 

1677 + dilation_h * (weight_height - 1) 

1678 + output_padding_h 

1679 + 1 

1680 ) 

1681 output_width = ( 

1682 (input_width - 1) * stride_w 

1683 - 2 * padding_w 

1684 + dilation_w * (weight_width - 1) 

1685 + output_padding_w 

1686 + 1 

1687 ) 

1688 output = torch.empty( 

1689 (batch, output_channels, output_height, output_width), 

1690 device=input.device, 

1691 dtype=input.dtype, 

1692 ) 

1693 compact_height = triton.cdiv(output_height, stride_h) 

1694 compact_width = triton.cdiv(output_width, stride_w) 

1695 max_sub_spatial = batch * compact_height * compact_width 

1696 n_subgrids = stride_h * stride_w 

1697 

1698 block_nhw, block_ci, block_co, num_warps = _select_conv_transpose2d_direct_schedule( 

1699 input.dtype, 

1700 input_channels, 

1701 output_channels, 

1702 weight_height, 

1703 weight_width, 

1704 stride_h, 

1705 output_padding_h, 

1706 ) 

1707 

1708 grid = ( 

1709 triton.cdiv(max_sub_spatial, block_nhw), 

1710 triton.cdiv(output_channels, block_co), 

1711 n_subgrids, 

1712 ) 

1713 _conv_transpose2d_direct_kernel[grid]( 

1714 input, 

1715 weight, 

1716 output, 

1717 batch, 

1718 input_height, 

1719 input_width, 

1720 output_channels, 

1721 output_height, 

1722 output_width, 

1723 *input.stride(), 

1724 *weight.stride(), 

1725 *output.stride(), 

1726 input_channels, 

1727 weight_height, 

1728 weight_width, 

1729 stride_h, 

1730 stride_w, 

1731 padding_h, 

1732 padding_w, 

1733 BLOCK_NHW=block_nhw, 

1734 BLOCK_CI=block_ci, 

1735 BLOCK_CO=block_co, 

1736 num_warps=num_warps, 

1737 ) 

1738 return output 

1739 

1740 

def _conv_transpose2d_general(
    input,
    weight,
    bias,
    stride_h,
    stride_w,
    padding_h,
    padding_w,
    dilation_h,
    dilation_w,
    output_padding_h,
    output_padding_w,
    groups,
):
    """General-case path: forwards every argument to the residue implementation."""
    return _conv_transpose2d_residue(
        input=input,
        weight=weight,
        bias=bias,
        stride_h=stride_h,
        stride_w=stride_w,
        padding_h=padding_h,
        padding_w=padding_w,
        dilation_h=dilation_h,
        dilation_w=dilation_w,
        output_padding_h=output_padding_h,
        output_padding_w=output_padding_w,
        groups=groups,
    )

1769 

1770 

1771def _conv_transpose2d_residue( 

1772 input, 

1773 weight, 

1774 bias, 

1775 stride_h, 

1776 stride_w, 

1777 padding_h, 

1778 padding_w, 

1779 dilation_h, 

1780 dilation_w, 

1781 output_padding_h, 

1782 output_padding_w, 

1783 groups, 

1784): 

1785 batch, input_channels, input_height, input_width = input.shape 

1786 _, output_channels_per_group, weight_height, weight_width = weight.shape 

1787 output_channels = output_channels_per_group * groups 

1788 output_height = ( 

1789 (input_height - 1) * stride_h 

1790 - 2 * padding_h 

1791 + dilation_h * (weight_height - 1) 

1792 + output_padding_h 

1793 + 1 

1794 ) 

1795 output_width = ( 

1796 (input_width - 1) * stride_w 

1797 - 2 * padding_w 

1798 + dilation_w * (weight_width - 1) 

1799 + output_padding_w 

1800 + 1 

1801 ) 

1802 output = torch.empty( 

1803 (batch, output_channels, output_height, output_width), 

1804 device=input.device, 

1805 dtype=input.dtype, 

1806 ) 

1807 total_elements = output.numel() 

1808 if total_elements == 0: 

1809 return output 

1810 

1811 input_channels_per_group = input_channels // groups 

1812 if ( 

1813 input.dtype in _TRITON_DIRECT_LOWP_DTYPES 

1814 and weight_height >= 5 

1815 and weight_width >= 5 

1816 and stride_h == 2 

1817 and stride_w == 2 

1818 and dilation_h == 1 

1819 and dilation_w == 1 

1820 and input_channels_per_group >= 64 

1821 and output_channels_per_group <= 32 

1822 ): 

1823 block_nhw = 256 

1824 block_ci = 16 

1825 block_co = 32 

1826 co_blocks_per_group = triton.cdiv(output_channels_per_group, block_co) 

1827 bias_pointer = bias if bias is not None else input 

1828 for residue_h in range(stride_h): 

1829 compact_height = (output_height + stride_h - 1 - residue_h) // stride_h 

1830 for residue_w in range(stride_w): 

1831 compact_width = (output_width + stride_w - 1 - residue_w) // stride_w 

1832 grid = ( 

1833 triton.cdiv(batch * compact_height * compact_width, block_nhw), 

1834 groups * co_blocks_per_group, 

1835 ) 

1836 _conv_transpose2d_residue_static_kernel[grid]( 

1837 input, 

1838 weight, 

1839 bias_pointer, 

1840 output, 

1841 batch, 

1842 input_channels, 

1843 input_height, 

1844 input_width, 

1845 output_channels, 

1846 output_height, 

1847 output_width, 

1848 compact_height, 

1849 compact_width, 

1850 weight_height, 

1851 weight_width, 

1852 output_channels_per_group, 

1853 input_channels_per_group, 

1854 stride_h, 

1855 stride_w, 

1856 padding_h, 

1857 padding_w, 

1858 dilation_h, 

1859 dilation_w, 

1860 bias is not None, 

1861 residue_h, 

1862 residue_w, 

1863 co_blocks_per_group, 

1864 BLOCK_NHW=block_nhw, 

1865 BLOCK_CI=block_ci, 

1866 BLOCK_CO=block_co, 

1867 num_warps=4, 

1868 num_stages=2, 

1869 ) 

1870 return output 

1871 

1872 block_nhw = 64 

1873 block_ci = 32 

1874 block_co = 32 

1875 num_warps = 4 

1876 if input.dtype is torch.float32: 

1877 block_ci = 16 

1878 block_co = 16 

1879 elif input_channels_per_group <= 16: 

1880 block_ci = 16 

1881 if output_channels_per_group <= 16: 

1882 block_co = 16 

1883 if ( 

1884 weight_height >= 5 

1885 and weight_width >= 5 

1886 and stride_h == 2 

1887 and stride_w == 2 

1888 and input_channels_per_group >= 64 

1889 and output_channels_per_group <= 32 

1890 ): 

1891 block_nhw = 128 

1892 block_ci = 64 if input.dtype is not torch.float32 else 32 

1893 block_co = 16 

1894 num_warps = 8 

1895 if stride_h * stride_w >= 4 and input.dtype is not torch.float32: 

1896 block_nhw = 128 

1897 num_warps = 8 

1898 

1899 compact_height = triton.cdiv(output_height, stride_h) 

1900 compact_width = triton.cdiv(output_width, stride_w) 

1901 max_sub_spatial = batch * compact_height * compact_width 

1902 n_subgrids = stride_h * stride_w 

1903 co_blocks_per_group = triton.cdiv(output_channels_per_group, block_co) 

1904 grid = ( 

1905 triton.cdiv(max_sub_spatial, block_nhw), 

1906 co_blocks_per_group, 

1907 groups * n_subgrids, 

1908 ) 

1909 bias_pointer = bias if bias is not None else input 

1910 _conv_transpose2d_residue_kernel[grid]( 

1911 input, 

1912 weight, 

1913 bias_pointer, 

1914 output, 

1915 batch, 

1916 input_channels, 

1917 input_height, 

1918 input_width, 

1919 output_channels, 

1920 output_height, 

1921 output_width, 

1922 weight_height, 

1923 weight_width, 

1924 output_channels_per_group, 

1925 input_channels // groups, 

1926 stride_h, 

1927 stride_w, 

1928 padding_h, 

1929 padding_w, 

1930 dilation_h, 

1931 dilation_w, 

1932 bias is not None, 

1933 n_subgrids, 

1934 BLOCK_NHW=block_nhw, 

1935 BLOCK_CI=block_ci, 

1936 BLOCK_CO=block_co, 

1937 num_warps=num_warps, 

1938 ) 

1939 return output