Distributed communication package - torch.distributed¶
torch.distributed provides an MPI-like interface for exchanging tensor data across multi-machine networks. It supports a few different backends and initialization methods.
Currently torch.distributed supports three backends, each with different capabilities. The table below shows which functions are available for use with CPU / CUDA tensors. MPI supports CUDA only if the implementation used to build PyTorch supports it.
| Backend | tcp | | gloo | | mpi | |
|---|---|---|---|---|---|---|
| Device | CPU | GPU | CPU | GPU | CPU | GPU |
| send | ✓ | ✘ | ✘ | ✘ | ✓ | ? |
| recv | ✓ | ✘ | ✘ | ✘ | ✓ | ? |
| broadcast | ✓ | ✘ | ✓ | ✓ | ✓ | ? |
| all_reduce | ✓ | ✘ | ✓ | ✓ | ✓ | ? |
| reduce | ✓ | ✘ | ✘ | ✘ | ✓ | ? |
| all_gather | ✓ | ✘ | ✘ | ✘ | ✓ | ? |
| gather | ✓ | ✘ | ✘ | ✘ | ✓ | ? |
| scatter | ✓ | ✘ | ✘ | ✘ | ✓ | ? |
| barrier | ✓ | ✘ | ✓ | ✓ | ✓ | ? |
Basics¶
The torch.distributed package provides PyTorch support and communication primitives
for multiprocess parallelism across several computation nodes running on one or more
machines. The class torch.nn.parallel.DistributedDataParallel()
builds on this
functionality to provide synchronous distributed training as a wrapper around any
PyTorch model. This differs from the kinds of parallelism provided by
Multiprocessing package - torch.multiprocessing and torch.nn.DataParallel()
in that it supports
multiple network-connected machines and in that the user must explicitly launch a separate
copy of the main training script for each process.
In the single-machine synchronous case, torch.distributed or the torch.nn.parallel.DistributedDataParallel() wrapper may still have advantages over other approaches to data-parallelism, including torch.nn.DataParallel():
- Each process maintains its own optimizer and performs a complete optimization step with each iteration. While this may appear redundant, since the gradients have already been gathered together and averaged across processes and are thus the same for every process, this means that no parameter broadcast step is needed, reducing time spent transferring tensors between nodes.
- Each process contains an independent Python interpreter, eliminating the extra interpreter overhead and “GIL-thrashing” that comes from driving several execution threads, model replicas, or GPUs from a single Python process. This is especially important for models that make heavy use of the Python runtime, including models with recurrent layers or many small components.
Initialization¶
The package needs to be initialized using the torch.distributed.init_process_group()
function before calling any other methods. This blocks until all processes have
joined.
- torch.distributed.init_process_group(backend, init_method='env://', **kwargs)[source]¶
Initializes the distributed package.
Parameters:
- backend (str) – Name of the backend to use. Depending on build-time configuration, valid values include: tcp, mpi and gloo.
- init_method (str, optional) – URL specifying how to initialize the package.
- world_size (int, optional) – Number of processes participating in the job.
- rank (int, optional) – Rank of the current process.
- group_name (str, optional) – Group name. See description of init methods.
To enable backend == mpi, PyTorch needs to be built from source on a system that supports MPI.
- torch.distributed.get_rank()[source]¶
Returns the rank of the current process.
Rank is a unique identifier assigned to each process within a distributed group. Ranks are always consecutive integers ranging from 0 to world_size - 1.
- torch.distributed.get_world_size()[source]¶
Returns the number of processes in the distributed group.
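For illustration, a minimal sketch of initialization followed by these queries; the address, rank, and world size below are example values only, and in practice the rank usually comes from a launcher:

import torch.distributed as dist

# Example values only; each process must be started with its own rank.
dist.init_process_group(backend='tcp',
                        init_method='tcp://10.1.1.20:23456',  # address of the rank 0 machine
                        rank=0,
                        world_size=4)

print(dist.get_rank())        # unique per process, in 0 .. world_size - 1
print(dist.get_world_size())  # 4 in this example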
Currently three initialization methods are supported:
TCP initialization¶
There are two ways to initialize using TCP, both requiring a network address reachable from all processes and a desired world_size. The first way requires specifying an address that belongs to the rank 0 process. This first way of initialization requires that all processes have manually specified ranks. Alternatively, the address has to be a valid IP multicast address, in which case ranks can be assigned automatically. Multicast initialization also supports a group_name argument, which allows you to use the same address for multiple jobs, as long as they use different group names.
import torch.distributed as dist
# Use address of one of the machines
dist.init_process_group(init_method='tcp://10.1.1.20:23456', rank=args.rank, world_size=4)
# or a multicast address - rank will be assigned automatically if unspecified
dist.init_process_group(init_method='tcp://[ff15:1e18:5d4c:4cf0:d02d:b659:53ba:b0a7]:23456', world_size=4)
Environment variable initialization¶
This method will read the configuration from environment variables, allowing one to fully customize how the information is obtained. The variables to be set are:
- MASTER_PORT - required; has to be a free port on machine with rank 0
- MASTER_ADDR - required (except for rank 0); address of rank 0 node
- WORLD_SIZE - required; can be set either here, or in a call to init function
- RANK - required; can be set either here, or in a call to init function
The machine with rank 0 will be used to set up all connections.
This is the default method, meaning that init_method does not have to be specified (or can be env://).
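As a sketch, the same initialization driven entirely by environment variables; the address, port, sizes, and backend below are example values that would normally be set by a launcher or job scheduler:

import os
import torch.distributed as dist

# Example values only; a real launcher would set these per process.
os.environ['MASTER_ADDR'] = '10.1.1.20'
os.environ['MASTER_PORT'] = '23456'
os.environ['WORLD_SIZE'] = '4'
os.environ['RANK'] = '0'

# env:// is the default, so init_method could also be omitted here.
dist.init_process_group(backend='gloo', init_method='env://')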
Groups¶
By default collectives operate on the default group (also called the world) and require all processes to enter the distributed function call. However, some workloads can benefit from more fine-grained communication. This is where distributed groups come into play. The new_group() function can be used to create new groups, with arbitrary subsets of all processes. It returns an opaque group handle that can be given as a group argument to all collectives (collectives are distributed functions to exchange information in certain well-known programming patterns).
- torch.distributed.new_group(ranks=None)[source]¶
Creates a new distributed group.
This function requires that all processes in the main group (i.e. all processes that are part of the distributed job) enter this function, even if they are not going to be members of the group. Additionally, groups should be created in the same order in all processes.
Parameters: ranks (list[int]) – List of ranks of group members.
Returns: A handle of distributed group that can be given to collective calls.
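For example, a minimal sketch that creates a sub-group of the first three ranks, assuming a job with at least three processes that has already called init_process_group():

import torch
import torch.distributed as dist

# Every process in the job must call new_group(), even those not in `ranks`.
group = dist.new_group(ranks=[0, 1, 2])

t = torch.ones(2, 2)
if dist.get_rank() in (0, 1, 2):
    # Only members of the group take part in collectives on it.
    dist.all_reduce(t, op=dist.reduce_op.SUM, group=group)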
Point-to-point communication¶
- torch.distributed.recv(tensor, src=None)[source]¶
Receives a tensor synchronously.
Parameters:
- tensor (Tensor) – Tensor to fill with received data.
- src (int, optional) – Source rank. Will receive from any process if unspecified.
Returns: Sender rank.
isend() and irecv() return distributed request objects when used. In general, the type of this object is unspecified as they should never be created manually, but they are guaranteed to support two methods:
- is_completed() - returns True if the operation has finished
- wait() - will block the process until the operation is finished. is_completed() is guaranteed to return True once it returns.
When using the MPI backend, isend() and irecv() have non-overtaking semantics, which provide some guarantees on message ordering. For more detail, see
http://mpi-forum.org/docs/mpi-2.2/mpi22-report/node54.htm#Node54
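A minimal sketch of blocking and non-blocking point-to-point exchange, assuming a job of world size 2 that has already been initialized:

import torch
import torch.distributed as dist

rank = dist.get_rank()
t = torch.zeros(4)

# Blocking send/recv between rank 0 and rank 1.
if rank == 0:
    t += 1
    dist.send(t, dst=1)
else:
    sender = dist.recv(t, src=0)   # returns the sender's rank

# Non-blocking variant: start the transfer, do other work, then wait.
if rank == 0:
    req = dist.isend(t, dst=1)
else:
    req = dist.irecv(t, src=0)
req.wait()                          # blocks until the operation has finished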
Collective functions¶
- torch.distributed.broadcast(tensor, src, group=<object object>)[source]¶
Broadcasts the tensor to the whole group.
tensor must have the same number of elements in all processes participating in the collective.
Parameters:
- tensor (Tensor) – Data to be sent if src is the rank of the current process, and tensor to be used to save received data otherwise.
- src (int) – Source rank.
- group (optional) – Group of the collective.
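For example, a minimal sketch assuming the process group has already been initialized:

import torch
import torch.distributed as dist

t = torch.zeros(5)
if dist.get_rank() == 0:
    t.fill_(3.14)             # rank 0 holds the data to broadcast

dist.broadcast(t, src=0)      # afterwards every rank's `t` holds 3.14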
- torch.distributed.all_reduce(tensor, op=<object object>, group=<object object>)[source]¶
Reduces the tensor data across all machines in such a way that all get the final result.
After the call tensor is going to be bitwise identical in all processes.
Parameters:
- tensor (Tensor) – Input and output of the collective. The function operates in-place.
- op (optional) – One of the values from torch.distributed.reduce_op enum. Specifies an operation used for element-wise reductions.
- group (optional) – Group of the collective.
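For example, assuming an initialized process group:

import torch
import torch.distributed as dist

# Each rank contributes its own rank number; after the call every rank
# holds the sum 0 + 1 + ... + (world_size - 1).
t = torch.Tensor([dist.get_rank()])
dist.all_reduce(t, op=dist.reduce_op.SUM)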
- torch.distributed.reduce(tensor, dst, op=<object object>, group=<object object>)[source]¶
Reduces the tensor data across all machines.
Only the process with rank dst is going to receive the final result.
Parameters:
- tensor (Tensor) – Input and output of the collective. The function operates in-place.
- dst (int) – Destination rank.
- op (optional) – One of the values from torch.distributed.reduce_op enum. Specifies an operation used for element-wise reductions.
- group (optional) – Group of the collective.
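For example, assuming an initialized process group:

import torch
import torch.distributed as dist

# Sum everyone's tensor; only rank 0 is guaranteed to hold the final sum.
t = torch.ones(3)
dist.reduce(t, dst=0, op=dist.reduce_op.SUM)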
- torch.distributed.all_gather(tensor_list, tensor, group=<object object>)[source]¶
Gathers tensors from the whole group in a list.
Parameters:
- tensor_list (list[Tensor]) – Output list. It should contain correctly-sized tensors to be used for output of the collective.
- tensor (Tensor) – Tensor to be broadcast from the current process.
- group (optional) – Group of the collective.
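For example, assuming an initialized process group:

import torch
import torch.distributed as dist

world_size = dist.get_world_size()

# Every rank contributes one tensor and receives everyone's tensors.
t = torch.Tensor([dist.get_rank()])
gathered = [torch.zeros(1) for _ in range(world_size)]
dist.all_gather(gathered, t)
# gathered[i] now holds rank i's tensor on every process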
- torch.distributed.gather(tensor, **kwargs)[source]¶
Gathers a list of tensors in a single process.
Parameters:
- tensor (Tensor) – Input tensor.
- dst (int) – Destination rank. Required in all processes except the one that is receiving the data.
- gather_list (list[Tensor]) – List of appropriately-sized tensors to use for received data. Required only in the receiving process.
- group (optional) – Group of the collective.
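For example, a sketch in which rank 0 collects one tensor from every process, assuming an initialized process group:

import torch
import torch.distributed as dist

t = torch.Tensor([dist.get_rank()])

if dist.get_rank() == 0:
    # Only the receiving process supplies gather_list.
    gather_list = [torch.zeros(1) for _ in range(dist.get_world_size())]
    dist.gather(t, gather_list=gather_list)
    # gather_list[i] now holds rank i's tensor
else:
    dist.gather(t, dst=0)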