Coverage for src/flag_gems/ops/conv_transpose2d.py: 45%
750 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
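"""Triton implementation of ``torch.nn.functional.conv_transpose2d``.

The implementation dispatches on parameter regimes only: a direct tiled path
for common dense ``groups == 1`` cases, a pointwise 1x1 path, a scatter path
for sparse-output cases whose kernel taps never overlap in the output (stride
at least the dilated kernel extent), and a full residue path covering the rest
of the supported PyTorch API surface. Dispatch thresholds are expressed in
terms of dtype, stride, padding, kernel size and channel-count ranges rather
than constants tied to particular benchmark shapes.
"""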
9import logging
11import torch
12import triton
13import triton.language as tl
15from flag_gems.utils import libentry
logger = logging.getLogger(__name__)

# Dtype families accepted by the Triton paths: the half-precision pair, and
# the general set that additionally includes float32.
_TRITON_DIRECT_LOWP_DTYPES = (torch.float16, torch.bfloat16)
_GENERAL_TRITON_DTYPES = (torch.float32, *_TRITON_DIRECT_LOWP_DTYPES)

# Eligibility limits for the direct tiled (dense, groups == 1) path.
_DIRECT_TILED_FAMILY_MAX_CHANNELS = 256
_DIRECT_TILED_FAMILY_MAX_KERNEL = 5
_DIRECT_TILED_FAMILY_MAX_STRIDE = 4
# Minimum batch * H_in * W_in required (among other conditions) before the
# stride-2 / output_padding=1 / 3x3 case is routed to the direct tiled path.
_DIRECT_TILED_OUTPUT_PADDING_MIN_INPUT_ELEMENTS = 1024
# Default launch schedule for the direct tiled path.
# NOTE(review): presumably (BLOCK_NHW, BLOCK_CI, BLOCK_CO, num_warps) —
# confirm against the launcher.
_DIRECT_TILED_DEFAULT_SCHEDULE = (64, 32, 32, 4)

# Channel cap for the specialised stride-2 / padding-1 / 3x3 kernel.
_DIRECT_STRIDE2_PAD1_3X3_MAX_CHANNELS = 256
31def _pair(value):
32 if isinstance(value, (list, tuple)):
33 if len(value) != 2:
34 raise RuntimeError("expected a single int or a pair of ints")
35 return int(value[0]), int(value[1])
36 return int(value), int(value)
39def _direct_tiled_family_params(
40 input,
41 weight,
42 bias,
43 stride_h,
44 stride_w,
45 padding_h,
46 padding_w,
47 output_padding_h,
48 output_padding_w,
49 groups,
50 dilation_h,
51 dilation_w,
52):
53 if bias is not None or groups != 1:
54 return None
55 if (dilation_h, dilation_w) != (1, 1):
56 return None
57 if input.dtype not in _GENERAL_TRITON_DTYPES or weight.dtype != input.dtype:
58 return None
59 if input.device.type != "cuda" or weight.device != input.device:
60 return None
61 if input.dtype is torch.bfloat16 and not torch.cuda.is_bf16_supported():
62 return None
63 if input.dim() != 4 or weight.dim() != 4:
64 return None
65 if not input.is_contiguous() or not weight.is_contiguous():
66 return None
67 if stride_h != stride_w or padding_h != padding_w:
68 return None
69 if output_padding_h != output_padding_w:
70 return None
71 if stride_h <= 0 or stride_h > _DIRECT_TILED_FAMILY_MAX_STRIDE:
72 return None
73 if padding_h < 0 or output_padding_h < 0:
74 return None
76 batch, input_channels, input_height, input_width = input.shape
77 weight_input_channels, output_channels, weight_height, weight_width = weight.shape
78 if batch <= 0 or input_height <= 0 or input_width <= 0:
79 return None
80 if input_channels != weight_input_channels:
81 return None
82 if input_channels < 16 or output_channels < 16:
83 return None
84 if (
85 input_channels > _DIRECT_TILED_FAMILY_MAX_CHANNELS
86 or output_channels > _DIRECT_TILED_FAMILY_MAX_CHANNELS
87 ):
88 return None
89 if (
90 weight_height <= 0
91 or weight_width <= 0
92 or weight_height > _DIRECT_TILED_FAMILY_MAX_KERNEL
93 or weight_width > _DIRECT_TILED_FAMILY_MAX_KERNEL
94 ):
95 return None
96 output_height = (
97 (input_height - 1) * stride_h - 2 * padding_h + weight_height + output_padding_h
98 )
99 output_width = (
100 (input_width - 1) * stride_w - 2 * padding_w + weight_width + output_padding_w
101 )
102 if output_height <= 0 or output_width <= 0:
103 return None
104 return (
105 batch,
106 input_channels,
107 input_height,
108 input_width,
109 output_channels,
110 weight_height,
111 weight_width,
112 stride_h,
113 padding_h,
114 )
117def _can_use_direct_tiled_family(
118 input,
119 direct_tiled_family_params,
120 output_padding_h,
121):
122 if direct_tiled_family_params is None:
123 return False
124 (
125 batch,
126 input_channels,
127 input_height,
128 input_width,
129 output_channels,
130 weight_height,
131 weight_width,
132 stride_h,
133 _padding_h,
134 ) = direct_tiled_family_params
136 if output_padding_h == 0 and stride_h <= 2:
137 return True
138 input_elements = batch * input_height * input_width
139 if (
140 input.dtype in _GENERAL_TRITON_DTYPES
141 and stride_h == 2
142 and output_padding_h == 1
143 and weight_height == 3
144 and weight_width == 3
145 and input_channels >= 64
146 and output_channels <= 64
147 and input_elements >= _DIRECT_TILED_OUTPUT_PADDING_MIN_INPUT_ELEMENTS
148 ):
149 return True
150 if stride_h >= 3 and output_padding_h == 0:
151 if weight_height >= 5 or weight_width >= 5:
152 return True
153 if input.dtype in _TRITON_DIRECT_LOWP_DTYPES:
154 return True
155 return False
158def _unsupported_conv_transpose2d(
159 input,
160 weight,
161 bias,
162 stride_h,
163 stride_w,
164 padding_h,
165 padding_w,
166 output_padding_h,
167 output_padding_w,
168 groups,
169 dilation_h,
170 dilation_w,
171):
172 bias_dtype = None if bias is None else bias.dtype
173 raise NotImplementedError(
174 "flag_gems.conv_transpose2d supports 3D or 4D CUDA input tensors "
175 "and 4D CUDA weight tensors with float32, float16, or bfloat16 dtype; got "
176 f"input_shape={tuple(input.shape)}, weight_shape={tuple(weight.shape)}, "
177 f"input_dtype={input.dtype}, weight_dtype={weight.dtype}, bias_dtype={bias_dtype}, "
178 f"input_device={input.device}, weight_device={weight.device}, "
179 f"stride=({stride_h}, {stride_w}), padding=({padding_h}, {padding_w}), "
180 f"output_padding=({output_padding_h}, {output_padding_w}), groups={groups}, "
181 f"dilation=({dilation_h}, {dilation_w})"
182 )
185def _validate_conv_transpose2d_args(
186 input,
187 weight,
188 bias,
189 stride_h,
190 stride_w,
191 padding_h,
192 padding_w,
193 output_padding_h,
194 output_padding_w,
195 groups,
196 dilation_h,
197 dilation_w,
198):
199 if input.device.type != "cuda" or weight.device != input.device:
200 return False
201 if input.dim() != 4 or weight.dim() != 4:
202 return False
203 if not input.is_contiguous() or not weight.is_contiguous():
204 return False
205 if input.dtype not in _GENERAL_TRITON_DTYPES or weight.dtype != input.dtype:
206 return False
207 if input.dtype is torch.bfloat16 and not torch.cuda.is_bf16_supported():
208 return False
209 if bias is not None:
210 if bias.device != input.device or bias.dtype != input.dtype:
211 return False
212 if bias.dim() != 1 or not bias.is_contiguous():
213 return False
214 if groups <= 0:
215 raise RuntimeError("groups must be a positive integer")
216 if stride_h <= 0 or stride_w <= 0:
217 raise RuntimeError("non-positive stride is not supported")
218 if dilation_h <= 0 or dilation_w <= 0:
219 raise RuntimeError("dilation should be greater than zero")
220 if padding_h < 0 or padding_w < 0:
221 raise RuntimeError("negative padding is not supported")
222 if output_padding_h < 0 or output_padding_w < 0:
223 raise RuntimeError("negative output_padding is not supported")
224 if output_padding_h >= stride_h and output_padding_h >= dilation_h:
225 raise RuntimeError(
226 "output padding must be smaller than either stride or dilation"
227 )
228 if output_padding_w >= stride_w and output_padding_w >= dilation_w:
229 raise RuntimeError(
230 "output padding must be smaller than either stride or dilation"
231 )
233 input_channels = input.shape[1]
234 weight_input_channels = weight.shape[0]
235 output_channels_per_group = weight.shape[1]
236 weight_height = weight.shape[2]
237 weight_width = weight.shape[3]
238 if (
239 input_channels <= 0
240 or output_channels_per_group <= 0
241 or weight_height <= 0
242 or weight_width <= 0
243 ):
244 raise RuntimeError(
245 "non-empty input channels and weight dimensions are required"
246 )
247 if input_channels != weight_input_channels:
248 raise RuntimeError(
249 "expected input channel dimension to match weight input channels"
250 )
251 if input_channels % groups != 0:
252 raise RuntimeError("input channels must be divisible by groups")
253 output_channels = output_channels_per_group * groups
254 if bias is not None and bias.numel() != output_channels:
255 raise RuntimeError("expected bias to have one element per output channel")
257 input_height = input.shape[2]
258 input_width = input.shape[3]
259 output_height = (
260 (input_height - 1) * stride_h
261 - 2 * padding_h
262 + dilation_h * (weight_height - 1)
263 + output_padding_h
264 + 1
265 )
266 output_width = (
267 (input_width - 1) * stride_w
268 - 2 * padding_w
269 + dilation_w * (weight_width - 1)
270 + output_padding_w
271 + 1
272 )
273 if output_height <= 0 or output_width <= 0:
274 raise RuntimeError("calculated output size is too small")
275 return True
278def _can_use_scatter_no_overlap(
279 input,
280 weight,
281 stride_h,
282 stride_w,
283 dilation_h,
284 dilation_w,
285 groups,
286):
287 batch, input_channels, input_height, input_width = input.shape
288 _, output_channels_per_group, weight_height, weight_width = weight.shape
289 if batch <= 0 or input_height <= 0 or input_width <= 0:
290 return False
291 effective_kernel_h = (weight_height - 1) * dilation_h + 1
292 effective_kernel_w = (weight_width - 1) * dilation_w + 1
293 if stride_h < effective_kernel_h or stride_w < effective_kernel_w:
294 return False
296 input_channels_per_group = input_channels // groups
297 if input_channels_per_group > 128 or output_channels_per_group > 128:
298 return False
299 return weight_height * weight_width <= 25
302def _can_use_stride2_pad1_3x3_direct(
303 input,
304 weight,
305 bias,
306 stride_h,
307 stride_w,
308 padding_h,
309 padding_w,
310 output_padding_h,
311 output_padding_w,
312 groups,
313 dilation_h,
314 dilation_w,
315):
316 if bias is not None or groups != 1:
317 return False
318 if (dilation_h, dilation_w) != (1, 1):
319 return False
320 if (output_padding_h, output_padding_w) != (0, 0):
321 return False
322 if (stride_h, stride_w) != (2, 2) or (padding_h, padding_w) != (1, 1):
323 return False
324 if input.dim() != 4 or weight.dim() != 4:
325 return False
326 if input.device.type != "cuda" or weight.device != input.device:
327 return False
328 if input.dtype not in _GENERAL_TRITON_DTYPES or weight.dtype != input.dtype:
329 return False
330 if input.dtype is torch.bfloat16 and not torch.cuda.is_bf16_supported():
331 return False
332 if not input.is_contiguous() or not weight.is_contiguous():
333 return False
335 batch, input_channels, input_height, input_width = input.shape
336 weight_input_channels, output_channels, weight_height, weight_width = weight.shape
337 if batch <= 0 or input_height <= 0 or input_width <= 0:
338 return False
339 if input_channels != weight_input_channels:
340 return False
341 if (weight_height, weight_width) != (3, 3):
342 return False
343 if input_channels < 16 or output_channels < 16:
344 return False
345 if (
346 input_channels > _DIRECT_STRIDE2_PAD1_3X3_MAX_CHANNELS
347 or output_channels > _DIRECT_STRIDE2_PAD1_3X3_MAX_CHANNELS
348 ):
349 return False
350 if input.dtype is torch.float32:
351 return True
352 if (
353 input.dtype is torch.float16
354 and input_channels <= 32
355 and output_channels >= 64
356 and input_height <= 16
357 ):
358 return True
359 return (
360 input.dtype is torch.bfloat16
361 and input_channels >= 64
362 and output_channels <= 32
363 and input_height <= 16
364 )
@libentry()
@triton.jit
def _conv_transpose2d_direct_kernel(
    input_pointer,
    weight_pointer,
    output_pointer,
    batch_size: tl.constexpr,
    input_height: tl.constexpr,
    input_width: tl.constexpr,
    output_channels: tl.constexpr,
    output_height: tl.constexpr,
    output_width: tl.constexpr,
    input_n_stride: tl.constexpr,
    input_c_stride: tl.constexpr,
    input_height_stride: tl.constexpr,
    input_width_stride: tl.constexpr,
    weight_ci_stride: tl.constexpr,
    weight_co_stride: tl.constexpr,
    weight_height_stride: tl.constexpr,
    weight_width_stride: tl.constexpr,
    output_n_stride: tl.constexpr,
    output_c_stride: tl.constexpr,
    output_height_stride: tl.constexpr,
    output_width_stride: tl.constexpr,
    input_channels: tl.constexpr,
    weight_height: tl.constexpr,
    weight_width: tl.constexpr,
    stride_height: tl.constexpr,
    stride_width: tl.constexpr,
    padding_height: tl.constexpr,
    padding_width: tl.constexpr,
    BLOCK_NHW: tl.constexpr,
    BLOCK_CI: tl.constexpr,
    BLOCK_CO: tl.constexpr,
):
    """Bias-free, ungrouped, unit-dilation transposed conv over residue classes.

    Program axes:

    * axis 0 tiles BLOCK_NHW "compact" output positions;
    * axis 1 tiles BLOCK_CO output channels;
    * axis 2 selects the output residue class, with
      ``pid_subgrid = residue_h * stride_width + residue_w`` and
      ``oh = compact_h * stride_height + residue_h`` (likewise for ``ow``).
      NOTE(review): assumes the launcher uses
      ``stride_height * stride_width`` programs on axis 2 — confirm.

    Within one residue class, only kernel taps with
    ``kh % stride_height == (residue_h + padding_height) % stride_height``
    (likewise for ``kw``) map onto an integer input coordinate. The other
    taps are skipped.

    Each contributing tap adds a ``[BLOCK_NHW, BLOCK_CI] x [BLOCK_CI,
    BLOCK_CO]`` ``tl.dot`` into a float32 accumulator. Kernel offsets are
    not scaled by any dilation. All tensors are addressed through the
    explicit element strides passed in; the weight strides are named
    (ci, co, kh, kw).
    """
    pid_nhw = tl.program_id(0)
    pid_co = tl.program_id(1)
    pid_subgrid = tl.program_id(2)

    # Decode the output residue class handled by this program.
    output_residue_h = pid_subgrid // stride_width
    output_residue_w = pid_subgrid % stride_width
    compact_height: tl.constexpr = (output_height + stride_height - 1) // stride_height
    compact_width: tl.constexpr = (output_width + stride_width - 1) // stride_width

    # Compact offsets enumerate (n, compact_h, compact_w); positions beyond the
    # real output or batch are masked below.
    compact_offsets = pid_nhw * BLOCK_NHW + tl.arange(0, BLOCK_NHW)
    compact_plane: tl.constexpr = compact_height * compact_width
    compact_nh = compact_offsets // compact_width
    compact_h = compact_nh % compact_height
    compact_w = compact_offsets % compact_width
    n = compact_offsets // compact_plane
    oh = compact_h * stride_height + output_residue_h
    ow = compact_w * stride_width + output_residue_w
    co_offsets = pid_co * BLOCK_CO + tl.arange(0, BLOCK_CO)

    accum = tl.zeros((BLOCK_NHW, BLOCK_CO), dtype=tl.float32)
    ci_blocks: tl.constexpr = tl.cdiv(input_channels, BLOCK_CI)
    # Tap index residues that reach an input pixel for this output class.
    height_residue = (output_residue_h + padding_height) % stride_height
    width_residue = (output_residue_w + padding_width) % stride_width
    for kh in range(weight_height):
        if kh % stride_height == height_residue:
            ih_unstrided = oh + padding_height - kh
            ih = ih_unstrided // stride_height
            valid_h = (ih_unstrided >= 0) & (ih < input_height)
            for kw in range(weight_width):
                if kw % stride_width == width_residue:
                    iw_unstrided = ow + padding_width - kw
                    iw = iw_unstrided // stride_width
                    valid_hw = (
                        (n < batch_size)
                        & valid_h
                        & (iw_unstrided >= 0)
                        & (iw < input_width)
                        & (oh < output_height)
                        & (ow < output_width)
                    )
                    for ci_base in range(ci_blocks):
                        ci_offsets = ci_base * BLOCK_CI + tl.arange(0, BLOCK_CI)
                        input_offsets = (
                            n[:, None] * input_n_stride
                            + ci_offsets[None, :] * input_c_stride
                            + ih[:, None] * input_height_stride
                            + iw[:, None] * input_width_stride
                        )
                        weight_offsets = (
                            ci_offsets[:, None] * weight_ci_stride
                            + co_offsets[None, :] * weight_co_stride
                            + kh * weight_height_stride
                            + kw * weight_width_stride
                        )
                        input_mask = valid_hw[:, None] & (
                            ci_offsets[None, :] < input_channels
                        )
                        weight_mask = (ci_offsets[:, None] < input_channels) & (
                            co_offsets[None, :] < output_channels
                        )
                        input_block = tl.load(
                            input_pointer + input_offsets, mask=input_mask, other=0.0
                        )
                        weight_block = tl.load(
                            weight_pointer + weight_offsets, mask=weight_mask, other=0.0
                        )
                        # NOTE(review): "tf32x3" presumably trades some speed
                        # for near-fp32 accuracy on float32 inputs — confirm.
                        accum += tl.dot(
                            input_block,
                            weight_block,
                            input_precision="tf32x3",
                        )

    output_offsets = (
        n[:, None] * output_n_stride
        + co_offsets[None, :] * output_c_stride
        + oh[:, None] * output_height_stride
        + ow[:, None] * output_width_stride
    )
    output_mask = (
        (n[:, None] < batch_size)
        & (oh[:, None] < output_height)
        & (ow[:, None] < output_width)
        & (co_offsets[None, :] < output_channels)
    )
    # The float32 accumulator is stored directly; tl.store converts it to the
    # output pointer's element type.
    tl.store(output_pointer + output_offsets, accum, mask=output_mask)
@libentry()
@triton.jit
def _conv_transpose2d_stride2_pad1_3x3_kernel(
    input_pointer,
    weight_pointer,
    output_pointer,
    batch_size: tl.constexpr,
    input_height: tl.constexpr,
    input_width: tl.constexpr,
    output_channels: tl.constexpr,
    output_height: tl.constexpr,
    output_width: tl.constexpr,
    compact_height: tl.constexpr,
    compact_width: tl.constexpr,
    input_n_stride: tl.constexpr,
    input_c_stride: tl.constexpr,
    input_height_stride: tl.constexpr,
    input_width_stride: tl.constexpr,
    weight_ci_stride: tl.constexpr,
    weight_co_stride: tl.constexpr,
    weight_height_stride: tl.constexpr,
    weight_width_stride: tl.constexpr,
    output_n_stride: tl.constexpr,
    output_c_stride: tl.constexpr,
    output_height_stride: tl.constexpr,
    output_width_stride: tl.constexpr,
    input_channels: tl.constexpr,
    BLOCK_NHW: tl.constexpr,
    BLOCK_CI: tl.constexpr,
    BLOCK_CO: tl.constexpr,
):
    """Transposed conv specialised for stride 2, padding 1 and a 3x3 kernel.

    Stride, padding and kernel size are hard-coded (the literals 2, 1 and 3
    below).

    Program axes:

    * axis 0 packs ``pid_nhw * 4 + phase``, where
      ``phase = residue_h * 2 + residue_w`` selects the output residue class
      ``(oh % 2, ow % 2)``;
    * axis 1 tiles BLOCK_CO output channels.

    For each residue class at most two taps per axis can contribute:
    ``kh = (residue_h + 1) % 2 + 2 * slot``. A tap index of 3 is masked out
    through ``valid_kh`` / ``valid_kw``. Accumulation is float32 via
    ``tl.dot``.

    ``compact_height`` / ``compact_width`` come from the caller.
    NOTE(review): presumably ``ceil(output_size / 2)`` — confirm.
    """
    # Split axis 0 into the residue phase (low 2 bits) and the position tile.
    pid_raw = tl.program_id(0)
    phase = pid_raw % 4
    pid_nhw = pid_raw // 4
    pid_co = tl.program_id(1)

    residue_h = phase // 2
    residue_w = phase % 2
    compact_offsets = pid_nhw * BLOCK_NHW + tl.arange(0, BLOCK_NHW)
    compact_plane: tl.constexpr = compact_height * compact_width
    compact_nh = compact_offsets // compact_width
    compact_h = compact_nh % compact_height
    compact_w = compact_offsets % compact_width
    n = compact_offsets // compact_plane
    oh = compact_h * 2 + residue_h
    ow = compact_w * 2 + residue_w
    co_offsets = pid_co * BLOCK_CO + tl.arange(0, BLOCK_CO)

    accum = tl.zeros((BLOCK_NHW, BLOCK_CO), dtype=tl.float32)
    ci_blocks: tl.constexpr = tl.cdiv(input_channels, BLOCK_CI)
    # With stride 2 and padding 1, contributing taps satisfy
    # kh % 2 == (residue_h + 1) % 2 (likewise for kw).
    height_residue = (residue_h + 1) % 2
    width_residue = (residue_w + 1) % 2
    for kh_slot in range(2):
        kh = height_residue + kh_slot * 2
        valid_kh = kh < 3
        ih_unstrided = oh + 1 - kh
        ih = ih_unstrided // 2
        valid_h = valid_kh & (ih_unstrided >= 0) & (ih < input_height)
        for kw_slot in range(2):
            kw = width_residue + kw_slot * 2
            valid_kw = kw < 3
            iw_unstrided = ow + 1 - kw
            iw = iw_unstrided // 2
            valid_hw = (
                (n < batch_size)
                & valid_h
                & valid_kw
                & (iw_unstrided >= 0)
                & (iw < input_width)
                & (oh < output_height)
                & (ow < output_width)
            )
            for ci_base in range(ci_blocks):
                ci_offsets = ci_base * BLOCK_CI + tl.arange(0, BLOCK_CI)
                input_offsets = (
                    n[:, None] * input_n_stride
                    + ci_offsets[None, :] * input_c_stride
                    + ih[:, None] * input_height_stride
                    + iw[:, None] * input_width_stride
                )
                weight_offsets = (
                    ci_offsets[:, None] * weight_ci_stride
                    + co_offsets[None, :] * weight_co_stride
                    + kh * weight_height_stride
                    + kw * weight_width_stride
                )
                input_mask = valid_hw[:, None] & (ci_offsets[None, :] < input_channels)
                # Out-of-range tap (index 3) loads zeros for the whole block.
                weight_mask = (
                    (ci_offsets[:, None] < input_channels)
                    & (co_offsets[None, :] < output_channels)
                    & valid_kh
                    & valid_kw
                )
                input_block = tl.load(
                    input_pointer + input_offsets, mask=input_mask, other=0.0
                )
                weight_block = tl.load(
                    weight_pointer + weight_offsets, mask=weight_mask, other=0.0
                )
                accum += tl.dot(
                    input_block,
                    weight_block,
                    input_precision="tf32x3",
                )

    output_offsets = (
        n[:, None] * output_n_stride
        + co_offsets[None, :] * output_c_stride
        + oh[:, None] * output_height_stride
        + ow[:, None] * output_width_stride
    )
    output_mask = (
        (n[:, None] < batch_size)
        & (oh[:, None] < output_height)
        & (ow[:, None] < output_width)
        & (co_offsets[None, :] < output_channels)
    )
    tl.store(output_pointer + output_offsets, accum, mask=output_mask)
@libentry()
@triton.jit
def _conv_transpose2d_residue_kernel(
    input_pointer,
    weight_pointer,
    bias_pointer,
    output_pointer,
    batch_size: tl.constexpr,
    input_channels: tl.constexpr,
    input_height: tl.constexpr,
    input_width: tl.constexpr,
    output_channels: tl.constexpr,
    output_height: tl.constexpr,
    output_width: tl.constexpr,
    weight_height: tl.constexpr,
    weight_width: tl.constexpr,
    output_channels_per_group: tl.constexpr,
    input_channels_per_group: tl.constexpr,
    stride_height: tl.constexpr,
    stride_width: tl.constexpr,
    padding_height: tl.constexpr,
    padding_width: tl.constexpr,
    dilation_height: tl.constexpr,
    dilation_width: tl.constexpr,
    has_bias: tl.constexpr,
    n_subgrids: tl.constexpr,
    BLOCK_NHW: tl.constexpr,
    BLOCK_CI: tl.constexpr,
    BLOCK_CO: tl.constexpr,
):
    """Grouped / dilated / biased transposed conv over output residue classes.

    Program axes:

    * axis 0 tiles BLOCK_NHW compact output positions;
    * axis 1 tiles BLOCK_CO output channels within one group;
    * axis 2 packs ``group * n_subgrids + subgrid``, with
      ``subgrid = residue_h * stride_width + residue_w``.

    For one residue class, only taps whose dilated offset
    ``kh * dilation_height`` is congruent to ``residue_h + padding_height``
    modulo ``stride_height`` (likewise for width) reach an integer input
    coordinate.

    The offset arithmetic assumes a contiguous NCHW input and output and a
    contiguous ``(C_in, C_out / groups, kH, kW)`` weight. The bias, when
    present, seeds the float32 accumulator.
    """
    pid_nhw = tl.program_id(0)
    pid_co_in_group = tl.program_id(1)
    pid_phase_group = tl.program_id(2)

    # Decode (group, residue class) from axis 2.
    pid_subgrid = pid_phase_group % n_subgrids
    group = pid_phase_group // n_subgrids
    output_residue_h = pid_subgrid // stride_width
    output_residue_w = pid_subgrid % stride_width
    compact_height: tl.constexpr = (output_height + stride_height - 1) // stride_height
    compact_width: tl.constexpr = (output_width + stride_width - 1) // stride_width
    compact_plane: tl.constexpr = compact_height * compact_width

    compact_offsets = pid_nhw * BLOCK_NHW + tl.arange(0, BLOCK_NHW)
    compact_nh = compact_offsets // compact_width
    compact_h = compact_nh % compact_height
    compact_w = compact_offsets % compact_width
    n = compact_offsets // compact_plane
    oh = compact_h * stride_height + output_residue_h
    ow = compact_w * stride_width + output_residue_w

    co_in_offsets = pid_co_in_group * BLOCK_CO + tl.arange(0, BLOCK_CO)
    co_offsets = group * output_channels_per_group + co_in_offsets

    accum = tl.zeros((BLOCK_NHW, BLOCK_CO), dtype=tl.float32)
    if has_bias:
        bias_values = tl.load(
            bias_pointer + co_offsets,
            mask=co_in_offsets < output_channels_per_group,
            other=0.0,
        ).to(tl.float32)
        accum += bias_values[None, :]

    ci_blocks: tl.constexpr = tl.cdiv(input_channels_per_group, BLOCK_CI)
    height_residue = (output_residue_h + padding_height) % stride_height
    width_residue = (output_residue_w + padding_width) % stride_width
    for kh in range(weight_height):
        # `kh` is the induction variable of a runtime `range` loop, so this is
        # a runtime value (not tl.constexpr); the `if` below is a dynamic
        # branch, as in the direct kernel.
        kh_residue = (kh * dilation_height) % stride_height
        if kh_residue == height_residue:
            ih_unstrided = oh + padding_height - kh * dilation_height
            ih = ih_unstrided // stride_height
            valid_h = (n < batch_size) & (ih_unstrided >= 0) & (ih < input_height)
            for kw in range(weight_width):
                kw_residue = (kw * dilation_width) % stride_width
                if kw_residue == width_residue:
                    iw_unstrided = ow + padding_width - kw * dilation_width
                    iw = iw_unstrided // stride_width
                    valid_hw = (
                        valid_h
                        & (iw_unstrided >= 0)
                        & (iw < input_width)
                        & (oh < output_height)
                        & (ow < output_width)
                    )
                    for ci_base in range(ci_blocks):
                        ci_in_offsets = ci_base * BLOCK_CI + tl.arange(0, BLOCK_CI)
                        ci_offsets = group * input_channels_per_group + ci_in_offsets
                        input_offsets = (
                            n[:, None] * input_channels + ci_offsets[None, :]
                        ) * input_height
                        input_offsets = (
                            input_offsets + ih[:, None]
                        ) * input_width + iw[:, None]
                        weight_offsets = (
                            ci_offsets[:, None] * output_channels_per_group
                            + co_in_offsets[None, :]
                        ) * weight_height
                        weight_offsets = (weight_offsets + kh) * weight_width + kw
                        input_mask = valid_hw[:, None] & (
                            ci_in_offsets[None, :] < input_channels_per_group
                        )
                        weight_mask = (
                            ci_in_offsets[:, None] < input_channels_per_group
                        ) & (co_in_offsets[None, :] < output_channels_per_group)
                        input_block = tl.load(
                            input_pointer + input_offsets, mask=input_mask, other=0.0
                        )
                        weight_block = tl.load(
                            weight_pointer + weight_offsets, mask=weight_mask, other=0.0
                        )
                        accum += tl.dot(
                            input_block,
                            weight_block,
                            input_precision="tf32x3",
                        )

    # Contiguous NCHW output offsets.
    output_offsets = n[:, None] * output_channels + co_offsets[None, :]
    output_offsets = (output_offsets * output_height + oh[:, None]) * output_width
    output_offsets = output_offsets + ow[:, None]
    output_mask = (
        (n[:, None] < batch_size)
        & (oh[:, None] < output_height)
        & (ow[:, None] < output_width)
        & (co_in_offsets[None, :] < output_channels_per_group)
        & (co_offsets[None, :] < output_channels)
    )
    tl.store(output_pointer + output_offsets, accum, mask=output_mask)
@libentry()
@triton.jit
def _conv_transpose2d_general_kernel(
    input_pointer,
    weight_pointer,
    bias_pointer,
    output_pointer,
    total_elements: tl.constexpr,
    batch_size: tl.constexpr,
    input_channels: tl.constexpr,
    input_height: tl.constexpr,
    input_width: tl.constexpr,
    output_channels: tl.constexpr,
    output_height: tl.constexpr,
    output_width: tl.constexpr,
    weight_height: tl.constexpr,
    weight_width: tl.constexpr,
    output_channels_per_group: tl.constexpr,
    input_channels_per_group: tl.constexpr,
    stride_height: tl.constexpr,
    stride_width: tl.constexpr,
    padding_height: tl.constexpr,
    padding_width: tl.constexpr,
    dilation_height: tl.constexpr,
    dilation_width: tl.constexpr,
    has_bias: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Element-wise transposed conv: one output element per lane.

    ``offsets`` index a contiguous NCHW output of ``total_elements``
    elements. Each lane decodes its ``(n, co, oh, ow)`` and then sums
    ``input * weight`` over:

    * the input channels of its group;
    * every kernel tap whose unstrided input coordinate is divisible by the
      stride and lands inside the input.

    The bias, if any, seeds the float32 accumulator. Input and weight
    offsets assume contiguous NCHW input and a contiguous
    ``(C_in, C_out / groups, kH, kW)`` weight.
    """
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < total_elements

    # Decode the flat NCHW output index.
    tmp = offsets // output_width
    ow = offsets - tmp * output_width
    tmp2 = tmp // output_height
    oh = tmp - tmp2 * output_height
    n = tmp2 // output_channels
    co = tmp2 - n * output_channels

    group = co // output_channels_per_group
    co_in_group = co - group * output_channels_per_group
    accum = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    if has_bias:
        bias = tl.load(bias_pointer + co, mask=mask, other=0.0).to(tl.float32)
        accum += bias

    for ci_in_group in tl.range(0, input_channels_per_group):
        ci = group * input_channels_per_group + ci_in_group
        for kh in tl.static_range(0, weight_height):
            # A tap contributes only if it lands exactly on an input row.
            ih_unstrided = oh + padding_height - kh * dilation_height
            ih = ih_unstrided // stride_height
            valid_h = (ih_unstrided % stride_height == 0) & (ih >= 0)
            valid_h = valid_h & (ih < input_height)
            for kw in tl.static_range(0, weight_width):
                iw_unstrided = ow + padding_width - kw * dilation_width
                iw = iw_unstrided // stride_width
                valid = mask & valid_h
                valid = valid & (iw_unstrided % stride_width == 0)
                valid = valid & (iw >= 0) & (iw < input_width)

                input_offsets = (n * input_channels + ci) * input_height + ih
                input_offsets = input_offsets * input_width + iw
                weight_offsets = (
                    ci * output_channels_per_group + co_in_group
                ) * weight_height
                weight_offsets = (weight_offsets + kh) * weight_width + kw
                input_values = tl.load(
                    input_pointer + input_offsets, mask=valid, other=0.0
                ).to(tl.float32)
                weight_values = tl.load(
                    weight_pointer + weight_offsets, mask=valid, other=0.0
                ).to(tl.float32)
                accum += input_values * weight_values

    tl.store(output_pointer + offsets, accum, mask=mask)
@libentry()
@triton.jit
def _conv_transpose2d_residue_static_kernel(
    input_pointer,
    weight_pointer,
    bias_pointer,
    output_pointer,
    batch_size: tl.constexpr,
    input_channels: tl.constexpr,
    input_height: tl.constexpr,
    input_width: tl.constexpr,
    output_channels: tl.constexpr,
    output_height: tl.constexpr,
    output_width: tl.constexpr,
    compact_height: tl.constexpr,
    compact_width: tl.constexpr,
    weight_height: tl.constexpr,
    weight_width: tl.constexpr,
    output_channels_per_group: tl.constexpr,
    input_channels_per_group: tl.constexpr,
    stride_height: tl.constexpr,
    stride_width: tl.constexpr,
    padding_height: tl.constexpr,
    padding_width: tl.constexpr,
    dilation_height: tl.constexpr,
    dilation_width: tl.constexpr,
    has_bias: tl.constexpr,
    output_residue_h: tl.constexpr,
    output_residue_w: tl.constexpr,
    co_blocks_per_group: tl.constexpr,
    BLOCK_NHW: tl.constexpr,
    BLOCK_CI: tl.constexpr,
    BLOCK_CO: tl.constexpr,
):
    """Residue-class transposed conv with the residue class fixed at compile time.

    ``output_residue_h`` / ``output_residue_w`` are constexpr. The kernel
    therefore covers outputs with ``oh = compact_h * stride_height +
    output_residue_h`` (likewise for ``ow``). The tap-selection tests inside
    the ``tl.static_range`` loops are resolved during compilation, giving one
    specialisation per residue class.

    Program axes:

    * axis 0 tiles BLOCK_NHW compact output positions;
    * axis 1 packs ``group * co_blocks_per_group + co_block``.

    The offset arithmetic assumes a contiguous NCHW input and output and a
    contiguous ``(C_in, C_out / groups, kH, kW)`` weight. The bias, when
    present, seeds the float32 accumulator.
    """
    pid_nhw = tl.program_id(0)
    pid_gco = tl.program_id(1)

    compact_offsets = pid_nhw * BLOCK_NHW + tl.arange(0, BLOCK_NHW)
    compact_plane: tl.constexpr = compact_height * compact_width
    compact_nh = compact_offsets // compact_width
    compact_h = compact_nh % compact_height
    compact_w = compact_offsets % compact_width
    n = compact_offsets // compact_plane
    oh = compact_h * stride_height + output_residue_h
    ow = compact_w * stride_width + output_residue_w

    # Split axis 1 into (group, output-channel block within the group).
    group = pid_gco // co_blocks_per_group
    pid_co_in_group = pid_gco - group * co_blocks_per_group
    co_in_offsets = pid_co_in_group * BLOCK_CO + tl.arange(0, BLOCK_CO)
    co_offsets = group * output_channels_per_group + co_in_offsets

    accum = tl.zeros((BLOCK_NHW, BLOCK_CO), dtype=tl.float32)
    if has_bias:
        bias_values = tl.load(
            bias_pointer + co_offsets,
            mask=co_in_offsets < output_channels_per_group,
            other=0.0,
        ).to(tl.float32)
        accum += bias_values[None, :]

    ci_blocks: tl.constexpr = tl.cdiv(input_channels_per_group, BLOCK_CI)
    height_residue: tl.constexpr = (output_residue_h + padding_height) % stride_height
    width_residue: tl.constexpr = (output_residue_w + padding_width) % stride_width
    for kh in tl.static_range(0, weight_height):
        # Compile-time test: only taps aligned with this residue class remain.
        if (kh * dilation_height) % stride_height == height_residue:
            ih_unstrided = oh + padding_height - kh * dilation_height
            ih = ih_unstrided // stride_height
            valid_h = (n < batch_size) & (ih_unstrided >= 0) & (ih < input_height)
            for kw in tl.static_range(0, weight_width):
                if (kw * dilation_width) % stride_width == width_residue:
                    iw_unstrided = ow + padding_width - kw * dilation_width
                    iw = iw_unstrided // stride_width
                    valid_hw = (
                        valid_h
                        & (iw_unstrided >= 0)
                        & (iw < input_width)
                        & (oh < output_height)
                        & (ow < output_width)
                    )
                    for ci_base in range(ci_blocks):
                        ci_in_offsets = ci_base * BLOCK_CI + tl.arange(0, BLOCK_CI)
                        ci_offsets = group * input_channels_per_group + ci_in_offsets
                        input_offsets = (
                            n[:, None] * input_channels + ci_offsets[None, :]
                        ) * input_height
                        input_offsets = (
                            input_offsets + ih[:, None]
                        ) * input_width + iw[:, None]
                        weight_offsets = (
                            ci_offsets[:, None] * output_channels_per_group
                            + co_in_offsets[None, :]
                        ) * weight_height
                        weight_offsets = (weight_offsets + kh) * weight_width + kw
                        input_mask = valid_hw[:, None] & (
                            ci_in_offsets[None, :] < input_channels_per_group
                        )
                        weight_mask = (
                            ci_in_offsets[:, None] < input_channels_per_group
                        ) & (co_in_offsets[None, :] < output_channels_per_group)
                        input_block = tl.load(
                            input_pointer + input_offsets, mask=input_mask, other=0.0
                        )
                        weight_block = tl.load(
                            weight_pointer + weight_offsets, mask=weight_mask, other=0.0
                        )
                        accum += tl.dot(
                            input_block,
                            weight_block,
                            input_precision="tf32x3",
                        )

    output_offsets = n[:, None] * output_channels + co_offsets[None, :]
    output_offsets = (output_offsets * output_height + oh[:, None]) * output_width
    output_offsets = output_offsets + ow[:, None]
    output_mask = (
        (n[:, None] < batch_size)
        & (oh[:, None] < output_height)
        & (ow[:, None] < output_width)
        & (co_in_offsets[None, :] < output_channels_per_group)
        & (co_offsets[None, :] < output_channels)
    )
    tl.store(output_pointer + output_offsets, accum, mask=output_mask)
@libentry()
@triton.jit
def _conv_transpose2d_scatter_init_kernel(
    bias_pointer,
    output_pointer,
    total_elements: tl.constexpr,
    output_channels: tl.constexpr,
    output_height: tl.constexpr,
    output_width: tl.constexpr,
    has_bias: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Fill the first ``total_elements`` elements of an output buffer.

    When ``has_bias`` is set, each element receives ``bias[co]`` with
    ``co = (offset // (output_height * output_width)) % output_channels``,
    i.e. contiguous NCHW indexing. Otherwise it receives zeros. Values are
    produced as float32 and stored through ``output_pointer``.
    NOTE(review): presumably this runs before the scatter kernel writes its
    contributions — confirm against the launcher.
    """
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < total_elements
    values = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    if has_bias:
        # Channel index of each flat NCHW offset.
        spatial_size: tl.constexpr = output_height * output_width
        co = (offsets // spatial_size) % output_channels
        values = tl.load(bias_pointer + co, mask=mask, other=0.0).to(tl.float32)
    tl.store(output_pointer + offsets, values, mask=mask)
@libentry()
@triton.jit
def _conv_transpose2d_scatter_no_overlap_kernel(
    input_pointer,
    weight_pointer,
    bias_pointer,
    output_pointer,
    batch_size: tl.constexpr,
    input_channels: tl.constexpr,
    input_height: tl.constexpr,
    input_width: tl.constexpr,
    output_channels: tl.constexpr,
    output_height: tl.constexpr,
    output_width: tl.constexpr,
    weight_height: tl.constexpr,
    weight_width: tl.constexpr,
    output_channels_per_group: tl.constexpr,
    input_channels_per_group: tl.constexpr,
    stride_height: tl.constexpr,
    stride_width: tl.constexpr,
    padding_height: tl.constexpr,
    padding_width: tl.constexpr,
    dilation_height: tl.constexpr,
    dilation_width: tl.constexpr,
    has_bias: tl.constexpr,
    BLOCK_NHW: tl.constexpr,
    BLOCK_CI: tl.constexpr,
    BLOCK_CO: tl.constexpr,
):
    """Scatter one kernel tap's contribution directly into the output.

    Program grid:
      * axis 0 -- a BLOCK_NHW tile of flattened (n, ih, iw) input positions;
      * axis 1 -- a BLOCK_CO tile of output channels within one group;
      * axis 2 -- (group, kh, kw), flattened with kw varying fastest.

    For every input position the kernel computes the one output pixel
    ``oh = ih * stride - padding + kh * dilation`` (same for ow) that this
    tap reaches. It reduces over the group's input channels with ``tl.dot``
    in float32, adds bias when ``has_bias`` is set, and writes with a plain
    ``tl.store``.

    Because the write is a store and not an atomic add, the result is only
    correct when no two (input position, tap) pairs land on the same output
    pixel. NOTE(review): that is presumably what ``_can_use_scatter_no_overlap``
    checks; it is not visible here, so confirm. Output pixels that no tap
    reaches keep their previous contents. The host wrapper pre-fills them via
    ``_conv_transpose2d_scatter_init_kernel``.

    NOTE(review): offsets are built from ``tl.arange`` (int32) times integer
    constexprs. Tensors with more than 2**31 elements would likely overflow.
    Confirm whether callers bound tensor sizes.
    """
    pid_nhw = tl.program_id(0)
    pid_co = tl.program_id(1)
    pid_gkk = tl.program_id(2)

    # Decode axis 2 into (group, kh, kw); kw varies fastest.
    kw = pid_gkk % weight_width
    tmp = pid_gkk // weight_width
    kh = tmp % weight_height
    group = tmp // weight_height

    # Decode the flat tile index into (n, ih, iw) of the NCHW input.
    nhw_offsets = pid_nhw * BLOCK_NHW + tl.arange(0, BLOCK_NHW)
    iw = nhw_offsets % input_width
    tmp = nhw_offsets // input_width
    ih = tmp % input_height
    n = tmp // input_height

    # Output pixel written by this tap. Pixels that fall outside the output
    # (e.g. into the cropped padding) are masked out entirely.
    oh = ih * stride_height - padding_height + kh * dilation_height
    ow = iw * stride_width - padding_width + kw * dilation_width
    # The `n < batch_size` term is implied by the first bound.
    valid_nhw = (nhw_offsets < batch_size * input_height * input_width) & (
        n < batch_size
    )
    valid_nhw = valid_nhw & (oh >= 0) & (oh < output_height)
    valid_nhw = valid_nhw & (ow >= 0) & (ow < output_width)

    co_in_group = pid_co * BLOCK_CO + tl.arange(0, BLOCK_CO)
    co = group * output_channels_per_group + co_in_group
    ci_in_group_base = tl.arange(0, BLOCK_CI)

    accum = tl.zeros((BLOCK_NHW, BLOCK_CO), dtype=tl.float32)
    ci_blocks: tl.constexpr = tl.cdiv(input_channels_per_group, BLOCK_CI)
    for ci_block in range(ci_blocks):
        ci_in_group = ci_block * BLOCK_CI + ci_in_group_base
        ci = group * input_channels_per_group + ci_in_group
        # [BLOCK_NHW, BLOCK_CI] tile gathered from the contiguous NCHW input.
        input_offsets = (n[:, None] * input_channels + ci[None, :]) * input_height
        input_offsets = (input_offsets + ih[:, None]) * input_width + iw[:, None]
        # [BLOCK_CI, BLOCK_CO] tile of the fixed tap (kh, kw), indexing weight
        # as a contiguous (C_in, C_out / groups, kH, kW) array.
        weight_offsets = (
            ci[:, None] * output_channels_per_group + co_in_group[None, :]
        ) * weight_height
        weight_offsets = (weight_offsets + kh) * weight_width + kw

        ci_mask = ci_in_group < input_channels_per_group
        co_mask = co_in_group < output_channels_per_group
        input_block = tl.load(
            input_pointer + input_offsets,
            mask=valid_nhw[:, None] & ci_mask[None, :],
            other=0.0,
        )
        weight_block = tl.load(
            weight_pointer + weight_offsets,
            mask=ci_mask[:, None] & co_mask[None, :],
            other=0.0,
        )
        # "tf32x3" asks Triton for its higher-accuracy 3xTF32 emulation when
        # the operands are float32.
        accum += tl.dot(
            input_block,
            weight_block,
            input_precision="tf32x3",
        )

    # Bias is added here as well as by the pre-fill. The store below
    # overwrites that pre-filled value, so bias is counted once per pixel.
    if has_bias:
        bias = tl.load(
            bias_pointer + co,
            mask=co_in_group < output_channels_per_group,
            other=0.0,
        ).to(tl.float32)
        accum += bias[None, :]

    output_offsets = (n[:, None] * output_channels + co[None, :]) * output_height
    output_offsets = (output_offsets + oh[:, None]) * output_width + ow[:, None]
    output_mask = valid_nhw[:, None] & (
        co_in_group[None, :] < output_channels_per_group
    )
    tl.store(output_pointer + output_offsets, accum, mask=output_mask)
@libentry()
@triton.jit
def _conv_transpose2d_1x1_kernel(
    input_pointer,
    weight_pointer,
    bias_pointer,
    output_pointer,
    batch_size: tl.constexpr,
    input_channels: tl.constexpr,
    input_height: tl.constexpr,
    input_width: tl.constexpr,
    output_channels: tl.constexpr,
    output_channels_per_group: tl.constexpr,
    input_channels_per_group: tl.constexpr,
    has_bias: tl.constexpr,
    co_blocks_per_group: tl.constexpr,
    BLOCK_NHW: tl.constexpr,
    BLOCK_CI: tl.constexpr,
    BLOCK_CO: tl.constexpr,
):
    """Pointwise (1x1 kernel, stride 1, no padding) transposed convolution.

    Program grid:
      * axis 0 -- a BLOCK_NHW tile of flattened (n, ih, iw) positions;
      * axis 1 -- (group, output-channel block), with ``co_blocks_per_group``
        blocks per group.

    Output pixel (n, co, h, w) depends only on input pixel (n, :, h, w) of the
    same group. Each program therefore computes a small matmul of an input
    tile against the weight, accumulated in float32 and seeded with the bias
    when ``has_bias`` is set. The output uses the input's spatial sizes.
    """
    pid_nhw = tl.program_id(0)
    pid_gco = tl.program_id(1)

    # Split axis 1 into the group index and the channel block inside it.
    group = pid_gco // co_blocks_per_group
    pid_co_in_group = pid_gco - group * co_blocks_per_group
    co_in_offsets = pid_co_in_group * BLOCK_CO + tl.arange(0, BLOCK_CO)
    co_offsets = group * output_channels_per_group + co_in_offsets

    # Decode the flat tile index into (n, ih, iw) of the NCHW input.
    nhw_offsets = pid_nhw * BLOCK_NHW + tl.arange(0, BLOCK_NHW)
    iw = nhw_offsets % input_width
    tmp = nhw_offsets // input_width
    ih = tmp % input_height
    n = tmp // input_height
    # The `n < batch_size` term is implied by the first bound.
    valid_nhw = (nhw_offsets < batch_size * input_height * input_width) & (
        n < batch_size
    )

    accum = tl.zeros((BLOCK_NHW, BLOCK_CO), dtype=tl.float32)
    if has_bias:
        bias_values = tl.load(
            bias_pointer + co_offsets,
            mask=co_in_offsets < output_channels_per_group,
            other=0.0,
        ).to(tl.float32)
        accum += bias_values[None, :]

    ci_blocks: tl.constexpr = tl.cdiv(input_channels_per_group, BLOCK_CI)
    # `ci_base` is the channel-block index, not an element offset.
    for ci_base in range(ci_blocks):
        ci_in_offsets = ci_base * BLOCK_CI + tl.arange(0, BLOCK_CI)
        ci_offsets = group * input_channels_per_group + ci_in_offsets
        input_offsets = n[:, None] * input_channels + ci_offsets[None, :]
        input_offsets = (input_offsets * input_height + ih[:, None]) * input_width
        input_offsets = input_offsets + iw[:, None]
        # With a 1x1 kernel the weight is indexed as a contiguous
        # (C_in, C_out / groups) matrix.
        weight_offsets = (
            ci_offsets[:, None] * output_channels_per_group + co_in_offsets[None, :]
        )
        ci_mask = ci_in_offsets < input_channels_per_group
        co_mask = co_in_offsets < output_channels_per_group
        input_block = tl.load(
            input_pointer + input_offsets,
            mask=valid_nhw[:, None] & ci_mask[None, :],
            other=0.0,
        )
        weight_block = tl.load(
            weight_pointer + weight_offsets,
            mask=ci_mask[:, None] & co_mask[None, :],
            other=0.0,
        )
        accum += tl.dot(input_block, weight_block, input_precision="tf32x3")

    # Output spatial size equals the input's, so the input H/W are reused.
    output_offsets = n[:, None] * output_channels + co_offsets[None, :]
    output_offsets = (output_offsets * input_height + ih[:, None]) * input_width
    output_offsets = output_offsets + iw[:, None]
    output_mask = valid_nhw[:, None] & (
        co_in_offsets[None, :] < output_channels_per_group
    )
    tl.store(output_pointer + output_offsets, accum, mask=output_mask)
def _can_use_pointwise_1x1(
    weight,
    stride_h,
    stride_w,
    padding_h,
    padding_w,
    output_padding_h,
    output_padding_w,
):
    """Return True when the transposed conv reduces to a per-pixel matmul.

    That is the case for a 1x1 kernel with unit stride, no padding and no
    output padding. Dilation multiplies ``kernel - 1 == 0`` and therefore
    has no effect, so it is not inspected.
    """
    if weight.shape[2] != 1 or weight.shape[3] != 1:
        return False
    if stride_h != 1 or stride_w != 1:
        return False
    return (
        padding_h == 0
        and padding_w == 0
        and output_padding_h == 0
        and output_padding_w == 0
    )
def _conv_transpose2d_pointwise_1x1(input, weight, bias, groups):
    """Launch the 1x1 / stride-1 / unpadded transposed-conv kernel.

    The output keeps the input's spatial size and has
    ``weight.shape[1] * groups`` channels. An empty output is returned
    without launching anything.
    """
    batch, input_channels, input_height, input_width = input.shape
    _, out_per_group, _, _ = weight.shape
    output = torch.empty(
        (batch, out_per_group * groups, input_height, input_width),
        device=input.device,
        dtype=input.dtype,
    )
    if output.numel() == 0:
        return output

    in_per_group = input_channels // groups
    is_fp32 = input.dtype is torch.float32
    # Smaller tiles for fp32; a 16-wide CI tile whenever the group is narrow.
    block_nhw = 64 if is_fp32 else 128
    block_ci = 16 if (is_fp32 or in_per_group <= 16) else 32
    block_co = 32 if out_per_group > 16 else 16
    co_blocks_per_group = triton.cdiv(out_per_group, block_co)

    has_bias = bias is not None
    # Any valid pointer works when there is no bias; the kernel only loads
    # from it under `has_bias`.
    bias_pointer = bias if has_bias else input
    grid = (
        triton.cdiv(batch * input_height * input_width, block_nhw),
        groups * co_blocks_per_group,
    )
    _conv_transpose2d_1x1_kernel[grid](
        input,
        weight,
        bias_pointer,
        output,
        batch,
        input_channels,
        input_height,
        input_width,
        out_per_group * groups,
        out_per_group,
        in_per_group,
        has_bias,
        co_blocks_per_group,
        BLOCK_NHW=block_nhw,
        BLOCK_CI=block_ci,
        BLOCK_CO=block_co,
        num_warps=4,
    )
    return output
def _conv_transpose2d_scatter_no_overlap(
    input,
    weight,
    bias,
    stride_h,
    stride_w,
    padding_h,
    padding_w,
    dilation_h,
    dilation_w,
    output_padding_h,
    output_padding_w,
    groups,
):
    """Transposed conv as a bias pre-fill followed by a per-tap scatter.

    Two launches are made:

    1. ``_conv_transpose2d_scatter_init_kernel`` writes bias (or zeros) into
       every output element.
    2. ``_conv_transpose2d_scatter_no_overlap_kernel`` overwrites each output
       pixel reached by some (input position, kernel tap) pair.

    This is only correct when taps never collide on an output pixel. The
    caller is responsible for choosing this path in that regime.
    """
    batch, input_channels, input_height, input_width = input.shape
    _, out_per_group, kernel_h, kernel_w = weight.shape
    output_channels = out_per_group * groups
    out_h = (
        (input_height - 1) * stride_h
        - 2 * padding_h
        + dilation_h * (kernel_h - 1)
        + output_padding_h
        + 1
    )
    out_w = (
        (input_width - 1) * stride_w
        - 2 * padding_w
        + dilation_w * (kernel_w - 1)
        + output_padding_w
        + 1
    )
    output = torch.empty(
        (batch, output_channels, out_h, out_w),
        device=input.device,
        dtype=input.dtype,
    )
    total_elements = output.numel()
    if total_elements == 0:
        return output

    has_bias = bias is not None
    # Dummy pointer when bias is absent; both kernels only read it under
    # `has_bias`.
    bias_pointer = bias if has_bias else input

    init_block = 1024
    init_grid = (triton.cdiv(total_elements, init_block),)
    _conv_transpose2d_scatter_init_kernel[init_grid](
        bias_pointer,
        output,
        total_elements,
        output_channels,
        out_h,
        out_w,
        has_bias,
        BLOCK_SIZE=init_block,
        num_warps=4,
    )

    in_per_group = input_channels // groups
    is_fp32 = input.dtype is torch.float32
    if in_per_group <= 16:
        block_ci = 16
    elif is_fp32 and in_per_group <= 64:
        block_ci = 32
    else:
        block_ci = 64
    block_co = 32 if out_per_group > 16 else 16
    block_nhw = 32 if (is_fp32 or out_per_group >= 64) else 64

    # axis 0: input-position tiles, axis 1: channel tiles within a group,
    # axis 2: one program column per (group, kh, kw) tap.
    scatter_grid = (
        triton.cdiv(batch * input_height * input_width, block_nhw),
        triton.cdiv(out_per_group, block_co),
        groups * kernel_h * kernel_w,
    )
    _conv_transpose2d_scatter_no_overlap_kernel[scatter_grid](
        input,
        weight,
        bias_pointer,
        output,
        batch,
        input_channels,
        input_height,
        input_width,
        output_channels,
        out_h,
        out_w,
        kernel_h,
        kernel_w,
        out_per_group,
        in_per_group,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
        has_bias,
        BLOCK_NHW=block_nhw,
        BLOCK_CI=block_ci,
        BLOCK_CO=block_co,
        num_warps=4,
        num_stages=3,
    )
    return output
def conv_transpose2d(
    input,
    weight,
    bias=None,
    stride=1,
    padding=0,
    output_padding=0,
    groups=1,
    dilation=1,
):
    """Triton-backed counterpart of ``torch.nn.functional.conv_transpose2d``.

    ``stride``, ``padding``, ``output_padding`` and ``dilation`` accept an int
    or a pair of ints. A 3-D ``input`` is treated as unbatched: a batch
    dimension of size 1 is added for the computation and removed from the
    result. Input, weight and bias are made contiguous before dispatch.
    """
    logger.debug("GEMS CONV_TRANSPOSE2D")

    stride_h, stride_w = _pair(stride)
    padding_h, padding_w = _pair(padding)
    output_padding_h, output_padding_w = _pair(output_padding)
    dilation_h, dilation_w = _pair(dilation)

    unbatched = input.dim() == 3
    if unbatched:
        input = input.unsqueeze(0)

    # Tensor.contiguous() hands back the same tensor when it is already
    # contiguous, so these only copy when a copy is actually needed.
    input = input.contiguous()
    weight = weight.contiguous()
    if bias is not None:
        bias = bias.contiguous()

    result = _conv_transpose2d_4d_dispatch(
        input,
        weight,
        bias,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        output_padding_h,
        output_padding_w,
        groups,
        dilation_h,
        dilation_w,
    )
    return result.squeeze(0) if unbatched else result
def _conv_transpose2d_4d_dispatch(
    input,
    weight,
    bias,
    stride_h,
    stride_w,
    padding_h,
    padding_w,
    output_padding_h,
    output_padding_w,
    groups,
    dilation_h,
    dilation_w,
):
    """Route a batched (4-D) call to the first implementation that accepts it.

    Order tried:

    1. the specialised stride-2 / pad-1 / 3x3 kernel;
    2. the direct tiled family;
    3. after argument validation, in turn: the pointwise 1x1 path, the
       no-overlap scatter path, and the general residue path.

    Arguments that fail validation go to ``_unsupported_conv_transpose2d``.
    """
    # Argument order shared by the predicates, validator and fallback.
    conv_args = (
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        output_padding_h,
        output_padding_w,
        groups,
        dilation_h,
        dilation_w,
    )

    if _can_use_stride2_pad1_3x3_direct(input, weight, bias, *conv_args):
        return _conv_transpose2d_stride2_pad1_3x3(input, weight)

    tiled_params = _direct_tiled_family_params(input, weight, bias, *conv_args)
    if _can_use_direct_tiled_family(input, tiled_params, output_padding_h):
        return _conv_transpose2d_direct(
            input,
            weight,
            stride_h,
            stride_w,
            padding_h,
            padding_w,
            dilation_h,
            dilation_w,
            output_padding_h,
            output_padding_w,
        )

    if not _validate_conv_transpose2d_args(input, weight, bias, *conv_args):
        return _unsupported_conv_transpose2d(input, weight, bias, *conv_args)

    if _can_use_pointwise_1x1(
        weight,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        output_padding_h,
        output_padding_w,
    ):
        return _conv_transpose2d_pointwise_1x1(input, weight, bias, groups)

    # Argument order used by the scatter and general launchers.
    launch_args = (
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
        output_padding_h,
        output_padding_w,
        groups,
    )
    if _can_use_scatter_no_overlap(
        input, weight, stride_h, stride_w, dilation_h, dilation_w, groups
    ):
        return _conv_transpose2d_scatter_no_overlap(input, weight, bias, *launch_args)
    return _conv_transpose2d_general(input, weight, bias, *launch_args)
def _select_stride2_pad1_3x3_schedule(input_dtype, input_channels, output_channels):
    """Return ``(BLOCK_NHW, BLOCK_CI, BLOCK_CO, num_warps)`` for the 3x3 kernel.

    The first matching rule wins; otherwise ``(64, 32, 32, 4)`` is used.
    """
    if input_dtype is torch.float32:
        return 64, 16, 16, 4
    if input_channels <= 32 and output_channels >= 64:
        # Few inputs fanned out to many outputs: widen the output-channel tile.
        return 128, 16, 64, 4
    if input_channels >= 64 and output_channels <= 32:
        # Channel-reducing layer: default tiles, more warps.
        return 64, 32, 32, 8
    if input_dtype is torch.bfloat16 and input_channels >= 128:
        return 128, 16, 16, 8
    return 64, 32, 32, 4
def _conv_transpose2d_stride2_pad1_3x3(input, weight):
    """Launch the specialised 3x3 / stride-2 / padding-1 transposed-conv kernel.

    Each output spatial size is ``2 * size - 1``. This is the general formula
    for stride 2, padding 1, a 3-wide kernel, dilation 1 and no output
    padding. This function never reads bias or those parameters. The caller
    must only route matching, bias-free calls here.
    """
    batch, input_channels, input_height, input_width = input.shape
    _, output_channels, _, _ = weight.shape
    out_h = input_height * 2 - 1
    out_w = input_width * 2 - 1
    output = torch.empty(
        (batch, output_channels, out_h, out_w),
        device=input.device,
        dtype=input.dtype,
    )
    if output.numel() == 0:
        return output

    schedule = _select_stride2_pad1_3x3_schedule(
        input.dtype,
        input_channels,
        output_channels,
    )
    block_nhw, block_ci, block_co, num_warps = schedule
    # ceil(out / 2): size of the larger parity class along each axis.
    compact_h = (out_h + 1) // 2
    compact_w = (out_w + 1) // 2
    # The factor 4 presumably covers the four (row, column) parity classes of
    # a stride-2 output; see the kernel for how axis 0 is decoded.
    tiles_per_class = triton.cdiv(batch * compact_h * compact_w, block_nhw)
    grid = (tiles_per_class * 4, triton.cdiv(output_channels, block_co))
    _conv_transpose2d_stride2_pad1_3x3_kernel[grid](
        input,
        weight,
        output,
        batch,
        input_height,
        input_width,
        output_channels,
        out_h,
        out_w,
        compact_h,
        compact_w,
        *input.stride(),
        *weight.stride(),
        *output.stride(),
        input_channels,
        BLOCK_NHW=block_nhw,
        BLOCK_CI=block_ci,
        BLOCK_CO=block_co,
        num_warps=num_warps,
    )
    return output
def _select_conv_transpose2d_direct_schedule(
    input_dtype,
    input_channels,
    output_channels,
    weight_height,
    weight_width,
    stride_h,
    output_padding_h,
):
    """Return ``(BLOCK_NHW, BLOCK_CI, BLOCK_CO, num_warps)`` for the direct kernel.

    Selection runs in three stages:

    1. Start from ``_DIRECT_TILED_DEFAULT_SCHEDULE`` and apply a
       dtype/shape adjustment.
    2. Override the whole tile for <=3x3 kernels at stride 1 or 2 with
       skewed channel counts.
    3. Cap BLOCK_NHW and BLOCK_CI when ``output_padding_h`` is non-zero.
    """
    block_nhw, block_ci, block_co, num_warps = _DIRECT_TILED_DEFAULT_SCHEDULE

    low_precision = input_dtype is torch.bfloat16 or input_dtype is torch.float16
    large_kernel = weight_height >= 5 or weight_width >= 5
    channel_reduction = input_channels >= 64 and output_channels <= 32

    # Stage 1: dtype/shape adjustment (first match wins).
    if low_precision and stride_h >= 3:
        block_nhw, block_ci, block_co, num_warps = 128, 16, 16, 8
    elif input_dtype is torch.bfloat16 and input_channels >= 128:
        block_nhw, block_ci, block_co, num_warps = 256, 16, 16, 8
    elif low_precision and large_kernel:
        block_nhw, block_ci = 128, 16
    elif input_dtype is torch.float32 and large_kernel:
        block_ci = 16
    elif channel_reduction:
        block_ci = 64
        if stride_h == 1:
            num_warps = 8

    # Stage 2: full tile override for small kernels with skewed channels.
    small_kernel = weight_height <= 3 and weight_width <= 3
    if small_kernel and stride_h == 1 and input_channels >= 64 and output_channels <= 64:
        block_nhw, block_ci, block_co, num_warps = 256, 16, 32, 8
    elif small_kernel and stride_h == 2 and input_channels <= 32 and output_channels >= 64:
        block_nhw, block_ci, block_co, num_warps = 128, 16, 64, 4
    elif small_kernel and stride_h == 2 and channel_reduction:
        block_nhw, block_ci, block_co, num_warps = 32, 16, 32, 8

    # Stage 3: keep tiles moderate when output padding is present.
    if output_padding_h:
        block_nhw = min(block_nhw, 128)
        block_ci = min(block_ci, 32)

    return block_nhw, block_ci, block_co, num_warps
def _conv_transpose2d_direct(
    input,
    weight,
    stride_h,
    stride_w,
    padding_h,
    padding_w,
    dilation_h,
    dilation_w,
    output_padding_h,
    output_padding_w,
):
    """Launch the direct tiled kernel for a dense, bias-free transposed conv.

    This function never receives ``bias`` or ``groups``. The caller routes
    here only when ``_direct_tiled_family_params`` accepts the call, and that
    helper rejects a bias and ``groups != 1``.

    Grid layout:
      * axis 0 tiles ``batch * ceil(H_out / stride_h) * ceil(W_out / stride_w)``;
      * axis 1 tiles output channels;
      * axis 2 has one entry per (row, column) stride sub-grid.

    Returns a new (N, C_out, H_out, W_out) tensor with the input's dtype and
    device. An empty output is returned without launching the kernel.
    """
    batch, input_channels, input_height, input_width = input.shape
    _, output_channels, weight_height, weight_width = weight.shape
    output_height = (
        (input_height - 1) * stride_h
        - 2 * padding_h
        + dilation_h * (weight_height - 1)
        + output_padding_h
        + 1
    )
    output_width = (
        (input_width - 1) * stride_w
        - 2 * padding_w
        + dilation_w * (weight_width - 1)
        + output_padding_w
        + 1
    )
    output = torch.empty(
        (batch, output_channels, output_height, output_width),
        device=input.device,
        dtype=input.dtype,
    )
    # Match the other launchers: with nothing to write (e.g. batch == 0 or
    # zero output channels), skip schedule selection and the kernel launch.
    if output.numel() == 0:
        return output

    compact_height = triton.cdiv(output_height, stride_h)
    compact_width = triton.cdiv(output_width, stride_w)
    max_sub_spatial = batch * compact_height * compact_width
    n_subgrids = stride_h * stride_w

    block_nhw, block_ci, block_co, num_warps = _select_conv_transpose2d_direct_schedule(
        input.dtype,
        input_channels,
        output_channels,
        weight_height,
        weight_width,
        stride_h,
        output_padding_h,
    )

    grid = (
        triton.cdiv(max_sub_spatial, block_nhw),
        triton.cdiv(output_channels, block_co),
        n_subgrids,
    )
    _conv_transpose2d_direct_kernel[grid](
        input,
        weight,
        output,
        batch,
        input_height,
        input_width,
        output_channels,
        output_height,
        output_width,
        *input.stride(),
        *weight.stride(),
        *output.stride(),
        input_channels,
        weight_height,
        weight_width,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        BLOCK_NHW=block_nhw,
        BLOCK_CI=block_ci,
        BLOCK_CO=block_co,
        num_warps=num_warps,
    )
    return output
def _conv_transpose2d_general(
    input,
    weight,
    bias,
    stride_h,
    stride_w,
    padding_h,
    padding_w,
    dilation_h,
    dilation_w,
    output_padding_h,
    output_padding_w,
    groups,
):
    """Fallback for validated arguments that no specialised path accepted.

    Forwards every argument unchanged, in the same order, to
    ``_conv_transpose2d_residue``.
    """
    forwarded = (
        input,
        weight,
        bias,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
        output_padding_h,
        output_padding_w,
        groups,
    )
    return _conv_transpose2d_residue(*forwarded)
def _conv_transpose2d_residue(
    input,
    weight,
    bias,
    stride_h,
    stride_w,
    padding_h,
    padding_w,
    dilation_h,
    dilation_w,
    output_padding_h,
    output_padding_w,
    groups,
):
    """General transposed convolution organised by stride residue classes.

    Two launch strategies are used.

    * Static path: low-precision inputs, >=5x5 kernels, stride 2, dilation 1,
      >=64 input and <=32 output channels per group. Launches
      ``_conv_transpose2d_residue_static_kernel`` once per (residue_h,
      residue_w) pair. ``compact_height`` / ``compact_width`` count the output
      rows / columns congruent to that residue modulo the stride. Empty
      classes, which occur when an output extent is smaller than the stride,
      are skipped.
    * Otherwise: a single launch of ``_conv_transpose2d_residue_kernel`` whose
      grid axis 2 spans ``groups * stride_h * stride_w``.

    Returns a new (N, C_out, H_out, W_out) tensor with the input's dtype and
    device. An empty output is returned without launching anything.
    """
    batch, input_channels, input_height, input_width = input.shape
    _, output_channels_per_group, weight_height, weight_width = weight.shape
    output_channels = output_channels_per_group * groups
    output_height = (
        (input_height - 1) * stride_h
        - 2 * padding_h
        + dilation_h * (weight_height - 1)
        + output_padding_h
        + 1
    )
    output_width = (
        (input_width - 1) * stride_w
        - 2 * padding_w
        + dilation_w * (weight_width - 1)
        + output_padding_w
        + 1
    )
    output = torch.empty(
        (batch, output_channels, output_height, output_width),
        device=input.device,
        dtype=input.dtype,
    )
    total_elements = output.numel()
    if total_elements == 0:
        return output

    input_channels_per_group = input_channels // groups
    has_bias = bias is not None
    # Dummy pointer when bias is absent. NOTE(review): the residue kernels are
    # defined elsewhere; presumably they only read it when has_bias is True.
    bias_pointer = bias if has_bias else input

    if (
        input.dtype in _TRITON_DIRECT_LOWP_DTYPES
        and weight_height >= 5
        and weight_width >= 5
        and stride_h == 2
        and stride_w == 2
        and dilation_h == 1
        and dilation_w == 1
        and input_channels_per_group >= 64
        and output_channels_per_group <= 32
    ):
        block_nhw = 256
        block_ci = 16
        block_co = 32
        co_blocks_per_group = triton.cdiv(output_channels_per_group, block_co)
        for residue_h in range(stride_h):
            # Number of output rows oh with oh % stride_h == residue_h.
            compact_height = (output_height + stride_h - 1 - residue_h) // stride_h
            if compact_height == 0:
                # Empty residue class: nothing to write, so do not compile or
                # launch a kernel specialised for a zero-sized extent.
                continue
            for residue_w in range(stride_w):
                compact_width = (output_width + stride_w - 1 - residue_w) // stride_w
                if compact_width == 0:
                    continue
                grid = (
                    triton.cdiv(batch * compact_height * compact_width, block_nhw),
                    groups * co_blocks_per_group,
                )
                _conv_transpose2d_residue_static_kernel[grid](
                    input,
                    weight,
                    bias_pointer,
                    output,
                    batch,
                    input_channels,
                    input_height,
                    input_width,
                    output_channels,
                    output_height,
                    output_width,
                    compact_height,
                    compact_width,
                    weight_height,
                    weight_width,
                    output_channels_per_group,
                    input_channels_per_group,
                    stride_h,
                    stride_w,
                    padding_h,
                    padding_w,
                    dilation_h,
                    dilation_w,
                    has_bias,
                    residue_h,
                    residue_w,
                    co_blocks_per_group,
                    BLOCK_NHW=block_nhw,
                    BLOCK_CI=block_ci,
                    BLOCK_CO=block_co,
                    num_warps=4,
                    num_stages=2,
                )
        return output

    block_nhw = 64
    block_ci = 32
    block_co = 32
    num_warps = 4
    if input.dtype is torch.float32:
        block_ci = 16
        block_co = 16
    elif input_channels_per_group <= 16:
        block_ci = 16
    if output_channels_per_group <= 16:
        block_co = 16
    if (
        weight_height >= 5
        and weight_width >= 5
        and stride_h == 2
        and stride_w == 2
        and input_channels_per_group >= 64
        and output_channels_per_group <= 32
    ):
        block_nhw = 128
        block_ci = 64 if input.dtype is not torch.float32 else 32
        block_co = 16
        num_warps = 8
    if stride_h * stride_w >= 4 and input.dtype is not torch.float32:
        block_nhw = 128
        num_warps = 8

    # Axis 0 is sized for the largest residue class (ceil(extent / stride)).
    compact_height = triton.cdiv(output_height, stride_h)
    compact_width = triton.cdiv(output_width, stride_w)
    max_sub_spatial = batch * compact_height * compact_width
    n_subgrids = stride_h * stride_w
    co_blocks_per_group = triton.cdiv(output_channels_per_group, block_co)
    grid = (
        triton.cdiv(max_sub_spatial, block_nhw),
        co_blocks_per_group,
        groups * n_subgrids,
    )
    _conv_transpose2d_residue_kernel[grid](
        input,
        weight,
        bias_pointer,
        output,
        batch,
        input_channels,
        input_height,
        input_width,
        output_channels,
        output_height,
        output_width,
        weight_height,
        weight_width,
        output_channels_per_group,
        input_channels_per_group,
        stride_h,
        stride_w,
        padding_h,
        padding_w,
        dilation_h,
        dilation_w,
        has_bias,
        n_subgrids,
        BLOCK_NHW=block_nhw,
        BLOCK_CI=block_ci,
        BLOCK_CO=block_co,
        num_warps=num_warps,
    )
    return output