
Flatten

class torch.nn.Flatten(start_dim=1, end_dim=-1)

Flattens a contiguous range of dims of a tensor into a single dim. Intended for use inside nn.Sequential; torch.flatten() is the functional form.
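
As a minimal sketch (the tensor shape is an arbitrary assumption), the default module can be compared against torch.flatten with the same dim range and against a manual reshape that keeps the batch dim:

```python
import torch
from torch import nn

x = torch.randn(4, 3, 2, 5)  # batch of 4, trailing dims (3, 2, 5)

# Module form with defaults: start_dim=1, end_dim=-1.
out_module = nn.Flatten()(x)

# Functional form with the same dim range made explicit.
out_functional = torch.flatten(x, start_dim=1)

# Manual reshape: keep the batch dim, merge everything after it.
out_reshape = x.reshape(x.size(0), -1)

print(out_module.shape)  # torch.Size([4, 30])
print(torch.equal(out_module, out_functional), torch.equal(out_module, out_reshape))  # True True
```

Note the differing defaults: nn.Flatten starts at dim 1 so the batch dim is preserved, while torch.flatten defaults to start_dim=0 and would merge the batch dim too.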

Shape:
  • Input: $(N, *\text{dims})$

  • Output: $(N, \prod *\text{dims})$ (for the default case).
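
The default shape rule can be checked with a small sketch. The helper `expected_flatten_shape` is hypothetical (not part of PyTorch) and simply keeps N while multiplying the remaining dims:

```python
import math

import torch
from torch import nn

def expected_flatten_shape(shape):
    """Hypothetical helper: output shape of the default Flatten (start_dim=1, end_dim=-1)."""
    n, *dims = shape
    return (n, math.prod(dims))

x = torch.randn(8, 16, 7, 7)
out = nn.Flatten()(x)

print(expected_flatten_shape(x.shape))  # (8, 784)  since 16 * 7 * 7 = 784
print(tuple(out.shape))                 # (8, 784)
```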

Parameters
  • start_dim – first dim to flatten (default = 1).

  • end_dim – last dim to flatten (default = -1).
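
To illustrate non-default ranges, here is a sketch on an arbitrary 4-D tensor (the shape is an assumption). Only dims start_dim through end_dim, inclusive, are merged; negative indices count from the end:

```python
import torch
from torch import nn

x = torch.randn(2, 3, 4, 5)

# Keep batch and channel dims, merge the two spatial dims: (2, 3, 4*5).
spatial = nn.Flatten(start_dim=2)(x)

# Merge only dims 1..2, leaving the last dim alone: (2, 3*4, 5).
middle = nn.Flatten(start_dim=1, end_dim=2)(x)

# end_dim=-2 refers to the same dim as end_dim=2 on a 4-D input.
middle_neg = nn.Flatten(start_dim=1, end_dim=-2)(x)

# start_dim=0 merges everything, including the batch dim.
everything = nn.Flatten(start_dim=0)(x)

print(spatial.shape)     # torch.Size([2, 3, 20])
print(middle.shape)      # torch.Size([2, 12, 5])
print(middle_neg.shape)  # torch.Size([2, 12, 5])
print(everything.shape)  # torch.Size([120])
```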

Examples:
>>> import torch
>>> from torch import nn
>>> input = torch.randn(32, 1, 5, 5)
>>> m = nn.Sequential(
...     nn.Conv2d(1, 32, 5, 1, 1),
...     nn.Flatten()
... )
>>> output = m(input)
>>> output.size()
torch.Size([32, 288])
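
The 288 in this example follows from the convolution's output size. A sketch of the arithmetic, using a hypothetical helper `conv2d_out_size` that implements the standard Conv2d output-size formula for one spatial dim, $\lfloor (H_{in} + 2p - d(k - 1) - 1)/s \rfloor + 1$:

```python
import torch
from torch import nn

def conv2d_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Hypothetical helper: Conv2d output size along one spatial dim."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

# Conv2d(1, 32, 5, 1, 1): 32 output channels, kernel 5, stride 1, padding 1.
h = w = conv2d_out_size(5, kernel_size=5, stride=1, padding=1)  # (5 + 2 - 4 - 1) // 1 + 1 = 3
expected_features = 32 * h * w                                  # 32 * 3 * 3 = 288

m = nn.Sequential(nn.Conv2d(1, 32, 5, 1, 1), nn.Flatten())
out = m(torch.randn(32, 1, 5, 5))
print(h, expected_features, tuple(out.shape))  # 3 288 (32, 288)
```

This is the usual way to size the in_features of an nn.Linear placed after nn.Flatten.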
