torch.split
- torch.split(tensor, split_size_or_sections, dim=0)
Splits the tensor into chunks. Each chunk is a view of the original tensor.
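If split_size_or_sections is an integer, then tensor will be split into equally sized chunks (if possible). The last chunk will be smaller if the tensor size along the given dimension dim is not divisible by split_size_or_sections.
If split_size_or_sections is a list, then tensor will be split into len(split_size_or_sections) chunks with sizes in dim according to split_size_or_sections.
- Parameters
  - tensor (Tensor): tensor to split.
  - split_size_or_sections (int or list(int)): size of a single chunk, or a list of sizes for each chunk.
  - dim (int): dimension along which to split the tensor.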
- Return type
tuple[torch.Tensor, …]
Example:
>>> a = torch.arange(10).reshape(5, 2)
>>> a
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7],
        [8, 9]])
>>> torch.split(a, 2)
(tensor([[0, 1],
         [2, 3]]),
 tensor([[4, 5],
         [6, 7]]),
 tensor([[8, 9]]))
>>> torch.split(a, [1, 4])
(tensor([[0, 1]]),
 tensor([[2, 3],
         [4, 5],
         [6, 7],
         [8, 9]]))
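
The example above only splits along the default dim=0. The following is a minimal sketch, not part of the original example, that illustrates three further points from the description: splitting along another dimension, the fact that each chunk is a view of the original tensor, and what appears to happen when a list of sizes does not sum to the size of the dimension. The variable names (cols, first, rest, raised) are illustrative only, and the RuntimeError in the last step is the behavior assumed for eager mode.

```python
import torch

a = torch.arange(10).reshape(5, 2)

# Split along dim=1: each chunk keeps all 5 rows and holds 1 column.
cols = torch.split(a, 1, dim=1)
print([tuple(c.shape) for c in cols])  # [(5, 1), (5, 1)]

# Chunks are views: an in-place write to a chunk is visible in `a`.
first, rest = torch.split(a, [1, 4])
first[0, 0] = 100
print(a[0, 0].item())  # 100

# Assumption: in eager mode, a list of sizes that does not sum to the
# size of `a` along `dim` (here 1 + 3 != 5) is rejected with a RuntimeError.
raised = False
try:
    torch.split(a, [1, 3])
except RuntimeError:
    raised = True
print(raised)
```

Because the chunks share storage with the input, call .clone() on a chunk before modifying it in place if the original tensor must stay unchanged.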