Unflatten¶
- class torch.nn.Unflatten(dim, unflattened_size)[source]¶
 Unflattens a tensor dim expanding it to a desired shape. For use with
Sequential.dimspecifies the dimension of the input tensor to be unflattened, and it can be either int or str when Tensor or NamedTensor is used, respectively.unflattened_sizeis the new shape of the unflattened dimension of the tensor and it can be a tuple of ints or a list of ints or torch.Size for Tensor input; a NamedShape (tuple of (name, size) tuples) for NamedTensor input.
- Shape:
 Input: , where is the size at dimension
dimand means any number of dimensions including none.Output: , where =
unflattened_sizeand .
- Parameters
 
Examples
>>> input = torch.randn(2, 50) >>> # With tuple of ints >>> m = nn.Sequential( >>> nn.Linear(50, 50), >>> nn.Unflatten(1, (2, 5, 5)) >>> ) >>> output = m(input) >>> output.size() torch.Size([2, 2, 5, 5]) >>> # With torch.Size >>> m = nn.Sequential( >>> nn.Linear(50, 50), >>> nn.Unflatten(1, torch.Size([2, 5, 5])) >>> ) >>> output = m(input) >>> output.size() torch.Size([2, 2, 5, 5]) >>> # With namedshape (tuple of tuples) >>> input = torch.randn(2, 50, names=('N', 'features')) >>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5))) >>> output = unflatten(input) >>> output.size() torch.Size([2, 2, 5, 5])