torch.cuda.comm.scatter
- torch.cuda.comm.scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out=None)[source]
Scatters tensor across multiple GPUs.
- Parameters
tensor (Tensor) – tensor to scatter. Can be on CPU or GPU.
devices (Iterable[torch.device, str or int], optional) – an iterable of GPU devices among which to scatter.
chunk_sizes (Iterable[int], optional) – sizes of chunks to be placed on each device. It should match devices in length and sum to tensor.size(dim). If not specified, tensor will be divided into equal chunks.
dim (int, optional) – A dimension along which to chunk tensor. Default: 0.
streams (Iterable[torch.cuda.Stream], optional) – an iterable of Streams, among which to execute the scatter. If not specified, the default stream will be utilized.
out (Sequence[Tensor], optional, keyword-only) – the GPU tensors to store output results. Sizes of these tensors must match that of tensor, except for dim, where the total size must sum to tensor.size(dim).
Note
Exactly one of devices and out must be specified. When out is specified, chunk_sizes must not be specified and will be inferred from sizes of out.
- Returns
- If devices is specified, a tuple containing chunks of tensor, placed on devices.
- If out is specified, a tuple containing out tensors, each containing a chunk of tensor.
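A minimal usage sketch showing both calling forms. It assumes a machine with at least two CUDA devices; the tensor shape, device indices, and chunk sizes are illustrative only.

```python
import torch
import torch.cuda.comm as comm

# Assumes at least two CUDA devices are available.
if torch.cuda.device_count() >= 2:
    t = torch.arange(8).reshape(4, 2)  # CPU tensor to scatter along dim 0

    # Form 1: pass devices. With chunk_sizes unspecified, the 4 rows are
    # split evenly: rows 0-1 land on cuda:0, rows 2-3 on cuda:1.
    chunks = comm.scatter(t, devices=[0, 1], dim=0)

    # Form 2: pass preallocated out tensors instead of devices.
    # Chunk sizes (here 1 row and 3 rows) are inferred from the out sizes,
    # which must sum to t.size(0) along dim.
    out = (torch.empty(1, 2, dtype=t.dtype, device="cuda:0"),
           torch.empty(3, 2, dtype=t.dtype, device="cuda:1"))
    chunks = comm.scatter(t, dim=0, out=out)
```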