ConvTranspose2d
- class torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros', device=None, dtype=None)[source]
- Applies a 2D transposed convolution operator over an input image composed of several input planes.
- This module can be seen as the gradient of Conv2d with respect to its input (a sketch at the end of this page demonstrates this relationship). It is also known as a fractionally-strided convolution or a deconvolution (although it is not an actual deconvolution operation, as it does not compute a true inverse of convolution). For more information, see the visualizations here and the Deconvolutional Networks paper.
- This module supports TensorFloat32.
- On certain ROCm devices, when using float16 inputs this module will use different precision for backward.
- stride controls the stride for the cross-correlation.
- padding controls the amount of implicit zero padding on both sides for dilation * (kernel_size - 1) - padding number of points. See note below for details.
- output_padding controls the additional size added to one side of the output shape. See note below for details.
- dilation controls the spacing between the kernel points; also known as the à trous algorithm. It is harder to describe, but the link here has a nice visualization of what dilation does.
- groups controls the connections between inputs and outputs. in_channels and out_channels must both be divisible by groups. For example,
- At groups=1, all inputs are convolved to all outputs.
- At groups=2, the operation becomes equivalent to having two conv layers side by side, each seeing half the input channels and producing half the output channels, with both outputs subsequently concatenated (the sketch after this list demonstrates this equivalence).
- At groups=in_channels, each input channel is convolved with its own set of filters (of size out_channels / in_channels).
 
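A minimal numerical check of the groups=2 equivalence described above (shapes are illustrative, not from the docs): splitting the input channels and filters in half, running two ungrouped transposed convolutions, and concatenating the results channel-wise gives the same output as the grouped operation.

>>> import torch
>>> import torch.nn.functional as F
>>> x = torch.randn(1, 4, 8, 8)
>>> weight = torch.randn(4, 3, 3, 3)  # layout: (in_channels, out_channels // groups, kH, kW)
>>> grouped = F.conv_transpose2d(x, weight, groups=2)
>>> top = F.conv_transpose2d(x[:, :2], weight[:2])
>>> bottom = F.conv_transpose2d(x[:, 2:], weight[2:])
>>> torch.allclose(grouped, torch.cat([top, bottom], dim=1), atol=1e-5)
True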
- The parameters kernel_size, stride, padding, output_padding can either be:
- a single int – in which case the same value is used for the height and width dimensions
- a tuple of two ints – in which case, the first int is used for the height dimension, and the second int for the width dimension
- Note
- The padding argument effectively adds dilation * (kernel_size - 1) - padding amount of zero padding to both sides of the input. This is set so that when a Conv2d and a ConvTranspose2d are initialized with the same parameters, they are inverses of each other in regard to the input and output shapes. However, when stride > 1, Conv2d maps multiple input shapes to the same output shape. output_padding is provided to resolve this ambiguity by effectively increasing the calculated output shape on one side. Note that output_padding is only used to find the output shape, but does not actually add zero-padding to the output. The sketch below illustrates this ambiguity.
- Note
- In some circumstances when given tensors on a CUDA device and using CuDNN, this operator may select a nondeterministic algorithm to increase performance. If this is undesirable, you can try to make the operation deterministic (potentially at a performance cost) by setting torch.backends.cudnn.deterministic = True. See Reproducibility for more information.
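A minimal sketch of the ambiguity (shapes chosen for illustration, not from the docs): with stride=2, a Conv2d maps both 7x7 and 8x8 inputs to 4x4, and output_padding selects which of the two shapes the transposed convolution recovers.

>>> import torch
>>> from torch import nn
>>> conv = nn.Conv2d(1, 1, 3, stride=2, padding=1)
>>> conv(torch.randn(1, 1, 7, 7)).shape
torch.Size([1, 1, 4, 4])
>>> conv(torch.randn(1, 1, 8, 8)).shape
torch.Size([1, 1, 4, 4])
>>> # output_padding picks which input shape the transposed convolution recovers
>>> nn.ConvTranspose2d(1, 1, 3, stride=2, padding=1)(torch.randn(1, 1, 4, 4)).shape
torch.Size([1, 1, 7, 7])
>>> nn.ConvTranspose2d(1, 1, 3, stride=2, padding=1, output_padding=1)(torch.randn(1, 1, 4, 4)).shape
torch.Size([1, 1, 8, 8])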
- Parameters
- in_channels (int) – Number of channels in the input image
- out_channels (int) – Number of channels produced by the convolution 
- stride (int or tuple, optional) – Stride of the convolution. Default: 1 
- padding (int or tuple, optional) – dilation * (kernel_size - 1) - padding amount of zero padding will be added to both sides of each dimension in the input. Default: 0
- output_padding (int or tuple, optional) – Additional size added to one side of each dimension in the output shape. Default: 0 
- groups (int, optional) – Number of blocked connections from input channels to output channels. Default: 1 
- bias (bool, optional) – If True, adds a learnable bias to the output. Default: True
- dilation (int or tuple, optional) – Spacing between kernel elements. Default: 1 
 
 - Shape:
- Input: (N, C_in, H_in, W_in) or (C_in, H_in, W_in)
- Output: (N, C_out, H_out, W_out) or (C_out, H_out, W_out), where
  H_out = (H_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1
  W_out = (W_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1
 
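The formulas above can be checked directly against a module; a minimal sketch (parameter values are illustrative):

>>> import torch
>>> from torch import nn
>>> m = nn.ConvTranspose2d(16, 33, kernel_size=3, stride=2, padding=1, output_padding=1, dilation=2)
>>> (10 - 1) * 2 - 2 * 1 + 2 * (3 - 1) + 1 + 1  # H_out from the formula, with H_in = 10
22
>>> m(torch.randn(1, 16, 10, 10)).shape
torch.Size([1, 33, 22, 22])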
- Variables
- weight (Tensor) – the learnable weights of the module of shape (in_channels, out_channels / groups, kernel_size[0], kernel_size[1]). The values of these weights are sampled from U(-sqrt(k), sqrt(k)) where k = groups / (C_out * kernel_size[0] * kernel_size[1])
- bias (Tensor) – the learnable bias of the module of shape (out_channels). If bias is True, then the values of these weights are sampled from U(-sqrt(k), sqrt(k)) where k = groups / (C_out * kernel_size[0] * kernel_size[1])
- Examples:

>>> # With square kernels and equal stride
>>> m = nn.ConvTranspose2d(16, 33, 3, stride=2)
>>> # non-square kernels and unequal stride and with padding
>>> m = nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
>>> input = torch.randn(20, 16, 50, 100)
>>> output = m(input)
>>> # exact output size can also be specified as an argument
>>> input = torch.randn(1, 16, 12, 12)
>>> downsample = nn.Conv2d(16, 16, 3, stride=2, padding=1)
>>> upsample = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
>>> h = downsample(input)
>>> h.size()
torch.Size([1, 16, 6, 6])
>>> output = upsample(h, output_size=input.size())
>>> output.size()
torch.Size([1, 16, 12, 12])
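As noted at the top of this page, this module can be seen as the gradient of Conv2d with respect to its input. A minimal sketch of that relationship (shapes are illustrative, not from the docs): backpropagating a tensor of ones through a conv2d reproduces exactly a conv_transpose2d with the same weight tensor.

>>> import torch
>>> import torch.nn.functional as F
>>> x = torch.randn(1, 3, 8, 8, requires_grad=True)
>>> weight = torch.randn(5, 3, 3, 3)  # conv2d layout: (out_channels, in_channels, kH, kW)
>>> y = F.conv2d(x, weight)
>>> grad_x, = torch.autograd.grad(y.sum(), x)
>>> # the same tensor serves as a conv_transpose2d weight, whose layout is
>>> # (in_channels, out_channels, kH, kW) from the transposed op's point of view
>>> torch.allclose(grad_x, F.conv_transpose2d(torch.ones_like(y), weight), atol=1e-5)
True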