torch.split
- torch.split(tensor, split_size_or_sections, dim=0)
- Splits the tensor into chunks. Each chunk is a view of the original tensor.

  If split_size_or_sections is an integer, then tensor will be split into equally sized chunks (if possible). The last chunk will be smaller if the tensor size along the given dimension dim is not divisible by split_size_or_sections.

  If split_size_or_sections is a list, then tensor will be split into len(split_size_or_sections) chunks with sizes in dim according to split_size_or_sections.

- Parameters:
  - tensor (Tensor): the tensor to split.
  - split_size_or_sections (int or list of ints): the size of a single chunk, or a list of sizes for each chunk.
  - dim (int): the dimension along which to split the tensor. Default: 0.
- Return type: tuple[Tensor, ...]
- Example:

    >>> a = torch.arange(10).reshape(5, 2)
    >>> a
    tensor([[0, 1],
            [2, 3],
            [4, 5],
            [6, 7],
            [8, 9]])
    >>> torch.split(a, 2)
    (tensor([[0, 1],
             [2, 3]]),
     tensor([[4, 5],
             [6, 7]]),
     tensor([[8, 9]]))
    >>> torch.split(a, [1, 4])
    (tensor([[0, 1]]),
     tensor([[2, 3],
             [4, 5],
             [6, 7],
             [8, 9]]))
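
The description above makes three claims that a short script can check: the last chunk is smaller when the size is not divisible, dim selects the axis being split, and each chunk is a view of the original tensor. The following is a minimal sketch rather than part of the official example; the variable names (chunks, cols, shares_memory, roundtrip_ok) are illustrative only.

```python
import math

import torch

a = torch.arange(10).reshape(5, 2)

# Integer split size along dim 0: 5 rows in chunks of 2 gives
# ceil(5 / 2) = 3 chunks, with the last chunk holding the remainder.
chunks = torch.split(a, 2)
expected_chunks = math.ceil(a.shape[0] / 2)
num_chunks = len(chunks)
chunk_rows = [c.shape[0] for c in chunks]

# Splitting along dim=1 instead yields one chunk per column here.
cols = torch.split(a, 1, dim=1)
col_shapes = [tuple(c.shape) for c in cols]

# Chunks are views: an in-place write to a chunk shows up in the original.
chunks[0][0, 0] = 100
shares_memory = a[0, 0].item() == 100

# Concatenating the chunks along the split dimension rebuilds the tensor.
roundtrip_ok = torch.equal(torch.cat(chunks, dim=0), a)

print(num_chunks, chunk_rows, col_shapes, shares_memory, roundtrip_ok)
```

Because the chunks share storage with the input, call .clone() on a chunk before modifying it if the original tensor must stay unchanged.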