Distributed RPC Framework

The distributed RPC framework provides mechanisms for multi-machine model training through a set of primitives to allow for remote communication, and a higher-level API to automatically differentiate models split across several machines.

Warning

The RPC API is experimental and subject to change.

RPC and RRef Framework

Before using RPC and distributed autograd primitives, the framework must be initialized by calling init_rpc(), which initializes the RPC framework, the RRef framework, and distributed autograd. By default, this also initializes the ProcessGroup (init_process_group()) backend for RPC communication. The ProcessGroup backend internally uses gloo for communication.

torch.distributed.rpc.init_rpc(name, backend=BackendType.PROCESS_GROUP, rank=-1, world_size=None, rpc_backend_options=None)

Initializes RPC primitives such as the local RPC agent and distributed autograd.

Initializes the local RPC agent which immediately makes the current process ready to send and receive RPCs. This method also properly initializes a default process group backend that uses gloo for collective communication.

Parameters
  • backend (Enum) – type of RPC backend implementation. Currently, process group backend is the only available backend implementation. (default: RpcBackend.PROCESS_GROUP).

  • name (str) – a globally unique name of this node (e.g., Trainer3, ParameterServer2, Master, Worker1). The name can only contain digits, letters, underscores, and/or dashes, and must be shorter than 128 characters.

  • rank (python:int) – a globally unique id/rank of this node.

  • world_size (python:int) – The number of workers in the group.

  • rpc_backend_options (RpcBackendOptions) – The options passed to the RpcAgent constructor.
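
The naming rule for the name parameter can be checked before calling init_rpc(). The following is a minimal sketch, not part of torch.distributed.rpc: is_valid_worker_name is a hypothetical helper, and it assumes that "letters" and "digits" mean ASCII characters.

```python
import re

# Hypothetical helper, not part of torch.distributed.rpc. It only mirrors the
# naming rule stated in the parameter description, assuming ASCII characters.
_WORKER_NAME_RE = re.compile(r"[A-Za-z0-9_-]+")
_MAX_NAME_LEN = 128  # names must be strictly shorter than this


def is_valid_worker_name(name: str) -> bool:
    """Return True if `name` satisfies the documented naming rule."""
    if len(name) >= _MAX_NAME_LEN:
        return False
    # fullmatch requires the whole string to consist of allowed characters,
    # and the `+` quantifier rejects the empty string.
    return _WORKER_NAME_RE.fullmatch(name) is not None
```

Running such a check up front reports a bad name locally, for example a name containing a space, before any worker attempts to join the group.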

RRef

An RRef (Remote REFerence) is a reference to a value of some type T (e.g. Tensor) on a remote worker. This handle keeps the referenced remote value alive on the owner, but there is no implication that the value will be transferred to the local worker in the future. RRefs can be used in multi-machine training by holding references to nn.Modules that exist on other workers, and calling the appropriate functions to retrieve or modify their parameters during training. See Remote Reference Protocol for more details.

class torch.distributed.rpc.RRef

A class encapsulating a reference to a value of some type on a remote worker. This handle will keep the referenced remote value alive on the owner worker.

is_owner(self: torch.distributed.rpc.RRef) → bool

Returns whether or not the current node is the owner of this RRef.

local_value(self: torch.distributed.rpc.RRef) → object

If the current node is the owner, returns a reference to the local value. Otherwise, throws an exception.

owner(self: torch.distributed.rpc.RRef) → torch.distributed.rpc.WorkerInfo

Returns worker information of the node that owns this RRef.

to_here(self: torch.distributed.rpc.RRef) → object

Blocking call that copies the value of the RRef from the owner to the local node and returns it. If the current node is the owner, returns a reference to the local value.
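
The difference between local_value() and to_here() can be pictured with a small in-process model. This is only a toy sketch of the semantics described above, not the torch.distributed.rpc.RRef class: ToyRRef is hypothetical, and because there is no real process identity here, each method takes the name of the calling worker as an explicit argument (the real methods take none).

```python
import copy


class ToyRRef:
    """In-process stand-in for an RRef; NOT the torch.distributed.rpc class."""

    def __init__(self, owner_name, value):
        self._owner_name = owner_name
        self._value = value

    def is_owner(self, current_worker):
        # True only when the caller is the worker that holds the value.
        return current_worker == self._owner_name

    def local_value(self, current_worker):
        # Only the owner may access the value by reference.
        if not self.is_owner(current_worker):
            raise RuntimeError("local_value() called on a non-owner worker")
        return self._value

    def to_here(self, current_worker):
        # The owner gets a reference; any other worker gets a copy.
        if self.is_owner(current_worker):
            return self._value
        return copy.deepcopy(self._value)


rref = ToyRRef("worker1", [1.0, 2.0])
on_owner = rref.to_here("worker1")  # same object as the owner's value
on_user = rref.to_here("worker0")   # an independent copy
```

In this model, mutating on_user leaves the owner's value unchanged, which mirrors the fact that to_here() copies the value to a non-owner worker.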

RPC and RRef primitives

This library provides primitives allowing users to create and modify references (RRefs) to remote data as well as remotely execute functions.
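
The examples in this section show the code for "worker 0" and "worker 1" as two separate processes. One possible way to start both from a single script on one machine is sketched below. It assumes that torch.multiprocessing.spawn is used as the launcher and that the rendezvous address and port are supplied through the MASTER_ADDR and MASTER_PORT environment variables; the values "localhost" and "29500" and the names run_worker, main, and WORLD_SIZE are illustrative choices, not part of the RPC API.

```python
import os

import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp

WORLD_SIZE = 2  # illustrative: one caller and one callee


def run_worker(rank: int, world_size: int) -> None:
    # Rendezvous settings; address and port are assumptions for a
    # single-machine run and are only set if not already present.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")

    rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
    if rank == 0:
        # Blocking call: runs torch.add on worker1 and returns the result.
        result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
        print(result)
    # Graceful shutdown blocks until every worker has reached this call,
    # so worker1 stays alive long enough to serve the request.
    rpc.shutdown()


def main() -> None:
    # spawn passes the process index (0 or 1) as the first argument.
    mp.spawn(run_worker, args=(WORLD_SIZE,), nprocs=WORLD_SIZE, join=True)
```

Calling main() under an `if __name__ == "__main__":` guard starts both workers. The guard matters because spawned processes re-import the launching module.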

torch.distributed.rpc.rpc_sync(to, func, args=None, kwargs=None)

Make a blocking RPC call to run function func on worker to. RPC messages are sent and received in parallel to execution of Python code. This method is thread-safe.

Parameters
  • to (str or WorkerInfo) – id or name of the destination worker.

  • func (callable) – any callable function. builtin functions (like torch.add()) can be sent over RPC more efficiently.

  • args (tuple) – the argument tuple for the func invocation.

  • kwargs (dict) – a dictionary of keyword arguments for the func invocation.

Returns

Returns the result of running func on args and kwargs.

Example:

On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
>>> rpc.shutdown()

On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()

torch.distributed.rpc.rpc_async(to, func, args=None, kwargs=None)

Make a non-blocking RPC call to run function func on worker to. RPC messages are sent and received in parallel to execution of Python code. This method is thread-safe and immediately returns a torch.distributed.FutureMessage that can be waited on.

Parameters
  • to (str or WorkerInfo) – id or name of the destination worker.

  • func (callable) – any callable function. builtin functions (like torch.add()) can be sent over RPC more efficiently.

  • args (tuple) – the argument tuple for the func invocation.

  • kwargs (dict) – a dictionary of keyword arguments for the func invocation.

Returns

Returns a torch.distributed.FutureMessage object that can be waited on. When completed, the return value of func on args and kwargs can be retrieved from the FutureMessage object.
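
The submit-now, wait-later pattern is not specific to RPC. As a local analogy only, using Python's standard concurrent.futures module rather than torch.distributed.rpc, the sketch below submits two calls without blocking and blocks only when the results are needed. Here Future.result() plays the role that FutureMessage.wait() plays for rpc_async(), and slow_add is a hypothetical helper.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def slow_add(a, b):
    # Stand-in for work that takes a while on another worker.
    time.sleep(0.1)
    return a + b


with ThreadPoolExecutor(max_workers=2) as pool:
    # Both submissions return immediately with a future.
    fut1 = pool.submit(slow_add, 1, 2)
    fut2 = pool.submit(min, 1, 2)
    # result() blocks until the corresponding call has finished.
    result = fut1.result() + fut2.result()
```

Because both calls are in flight before either result is requested, the total wait is roughly that of the slower call rather than the sum of the two.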

Example:

On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
>>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
>>> result = fut1.wait() + fut2.wait()
>>> rpc.shutdown()

On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()

torch.distributed.rpc.remote(to, func, args=None, kwargs=None)

Make a remote call to run func on worker to and return an RRef to the result value immediately. Worker to will be the owner of the returned RRef, and the worker calling remote is a user. The owner manages the global reference count of its RRef, and the owner RRef is only destructed when globally there are no living references to it.

Parameters
  • to (str or WorkerInfo) – id or name of the destination worker.

  • func (callable) – the callable to run on the destination worker, e.g. a builtin function such as torch.add().

  • args (tuple) – the argument tuple for the func invocation.

  • kwargs (dict) – a dictionary of keyword arguments for the func invocation.

Returns

A user RRef instance to the result value. Use the blocking API torch.distributed.rpc.RRef.to_here() to retrieve the result value locally.

Example:

On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
>>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
>>> x = rref1.to_here() + rref2.to_here()
>>> rpc.shutdown()

On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()

torch.distributed.rpc.get_worker_info(worker_name=None)

Get WorkerInfo of a given worker name. Use this WorkerInfo to avoid passing an expensive string on every invocation.

Parameters

worker_name (str) – the string name of a worker. If None, return the WorkerInfo of the current worker. (default None)

Returns

WorkerInfo instance for the given worker_name or WorkerInfo of the current worker if worker_name is None.

torch.distributed.rpc.shutdown(graceful=True)

Perform a shutdown of the RPC agent, and then destroy the RPC agent. This stops the local agent from accepting outstanding requests, and shuts down the RPC framework by terminating all RPC threads. If graceful=True, this call blocks until all local and remote RPC processes have reached this method, and waits for all outstanding work to complete. If graceful=False, this is a local shutdown that does not wait for other RPC processes to reach this method.

Parameters

graceful (bool) – Whether to do a graceful shutdown or not. If True, this call blocks until all local and remote RPC processes have reached this method, and waits for all outstanding work to complete.

Example:

On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> # do some work
>>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1))
>>> # ready to shutdown
>>> rpc.shutdown()

On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> # wait for worker 0 to finish work, and then shutdown.
>>> rpc.shutdown()

Distributed Autograd Framework

This module provides an RPC-based distributed autograd framework that can be used for applications such as model parallel training. In short, applications may send and receive gradient recording tensors over RPC. In the forward pass, we record when gradient recording tensors are sent over RPC and during the backward pass we use this information to perform a distributed backward pass using RPC. For more details see Distributed Autograd Design.

class torch.distributed.autograd.context

Context object to wrap forward and backward passes when using distributed autograd. The context_id generated in the with statement is required to uniquely identify a distributed backward pass on all workers. Each worker stores metadata associated with this context_id, which is required to correctly execute a distributed autograd pass.

Example:

>>> import torch
>>> import torch.distributed.autograd as dist_autograd
>>> import torch.distributed.rpc as rpc
>>> with dist_autograd.context() as context_id:
>>>     t1 = torch.rand((3, 3), requires_grad=True)
>>>     t2 = torch.rand((3, 3), requires_grad=True)
>>>     loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
>>>     dist_autograd.backward([loss])

torch.distributed.autograd.backward(roots: List[Tensor]) → None

Kicks off the distributed backward pass using the provided roots. This currently implements the FAST mode algorithm which assumes all RPC messages sent in the same distributed autograd context across workers would be part of the autograd graph during the backward pass.

We use the provided roots to discover the autograd graph and compute appropriate dependencies. This method blocks until the entire autograd computation is done.

We accumulate the gradients in the appropriate torch.distributed.autograd.context on each of the nodes. The autograd context used is the current autograd context of this node when torch.distributed.autograd.backward() is called. If there is no valid autograd context, we throw an error. You can retrieve the accumulated gradients using the get_gradients() API.

Parameters

roots (list) – Tensors which represent the roots of the autograd computation. All the tensors should be scalars.

Example:

>>> import torch.distributed.autograd as dist_autograd
>>> # model, loss_func, and target are assumed to be defined elsewhere.
>>> with dist_autograd.context() as context_id:
>>>     pred = model.forward()
>>>     loss = loss_func(pred, target)
>>>     dist_autograd.backward([loss])

torch.distributed.autograd.get_gradients(context_id: int) → Dict[Tensor, Tensor]

Retrieves a map from Tensor to the appropriate gradient for that Tensor accumulated in the provided context_id as part of the distributed autograd backward pass.

Parameters

context_id (python:int) – The autograd context id for which we should retrieve the gradients.

Returns

A map where the key is the Tensor and the value is the associated gradient for that Tensor.
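
As a worked case, suppose the loss is the sum of all elements of t1 + t2 for two 3×3 tensors, which is the setup used in the example that follows. The expected contents of the returned map can then be derived by hand:

```latex
\mathrm{loss} = \sum_{i=1}^{3}\sum_{j=1}^{3} \bigl(t_{1,ij} + t_{2,ij}\bigr),
\qquad
\frac{\partial\,\mathrm{loss}}{\partial t_{1,ij}}
  = \frac{\partial\,\mathrm{loss}}{\partial t_{2,ij}} = 1 .
```

Under that assumption, the entries for t1 and t2 in the map would each be a 3×3 tensor of ones, regardless of the random values in t1 and t2.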

Example:

>>> import torch
>>> import torch.distributed.autograd as dist_autograd
>>> with dist_autograd.context() as context_id:
>>>     t1 = torch.rand((3, 3), requires_grad=True)
>>>     t2 = torch.rand((3, 3), requires_grad=True)
>>>     loss = t1 + t2
>>>     dist_autograd.backward([loss.sum()])
>>>     grads = dist_autograd.get_gradients(context_id)
>>>     print(grads[t1])
>>>     print(grads[t2])

Distributed Optimizer

torch.distributed.optim exposes DistributedOptimizer, which takes a list of remote parameters (RRef) and runs the optimizer locally on the workers where the parameters live. The distributed optimizer can use any of the local optimizer algorithms to apply the gradients on each worker.

class torch.distributed.optim.DistributedOptimizer(optimizer_class, params_rref, *args, **kwargs)

DistributedOptimizer takes remote references to parameters scattered across workers and applies the given optimizer locally for each parameter.

This class uses get_gradients() in order to retrieve the gradients for specific parameters.

Concurrent calls to step(), either from the same or different clients, will be serialized on each worker – as each worker’s optimizer can only work on one set of gradients at a time. However, there is no guarantee that the full forward-backward-optimizer sequence will execute for one client at a time. This means that the gradients being applied may not correspond to the latest forward pass executed on a given worker. Also, there is no guaranteed ordering across workers.

Parameters
  • optimizer_class (optim.Optimizer) – the class of optimizer to instantiate on each worker.

  • params_rref (list[RRef]) – list of RRefs to local or remote parameters to optimize.

  • args – arguments to pass to the optimizer constructor on each worker.

  • kwargs – arguments to pass to the optimizer constructor on each worker.

Example:

>>> import torch
>>> import torch.distributed.autograd as dist_autograd
>>> import torch.distributed.rpc as rpc
>>> from torch import optim
>>> from torch.distributed.optim import DistributedOptimizer
>>>
>>> with dist_autograd.context() as context_id:
>>>     # Forward pass.
>>>     rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
>>>     rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
>>>     loss = rref1.to_here() + rref2.to_here()
>>>
>>>     # Backward pass.
>>>     dist_autograd.backward([loss.sum()])
>>>
>>>     # Optimizer.
>>>     dist_optim = DistributedOptimizer(
>>>         optim.SGD,
>>>         [rref1, rref2],
>>>         lr=0.05,
>>>     )
>>>     dist_optim.step()

step()

Performs a single optimization step.

This will call torch.optim.Optimizer.step() on each worker containing parameters to be optimized, and will block until all workers return. The current distributed autograd context will be used globally.
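
The idea of one local optimizer step per owning worker can be pictured with a toy, single-process model. This is a sketch under stated assumptions and not how DistributedOptimizer is implemented: (owner, name) pairs stand in for RRefs, plain floats stand in for parameters and gradients, and the update is plain SGD without momentum, value = value - lr * grad. All names and numbers below are illustrative.

```python
from collections import defaultdict

# (owner, parameter name) pairs stand in for RRefs to remote parameters.
param_refs = [("worker1", "w1"), ("worker2", "w2"), ("worker1", "b1")]
values = {"w1": 1.0, "w2": -2.0, "b1": 0.5}  # illustrative parameter values
grads = {"w1": 0.2, "w2": -0.4, "b1": 1.0}   # illustrative gradients
lr = 0.05

# Group parameters by the worker that owns them.
by_owner = defaultdict(list)
for owner, name in param_refs:
    by_owner[owner].append(name)

# One "local" step per owner. In the real framework, each step would run on
# the owning worker; here everything runs in a single process.
for owner, names in by_owner.items():
    for name in names:
        values[name] -= lr * grads[name]
```

Grouping first means each owner is visited once per step, which mirrors the description above that step() is called on each worker containing parameters to be optimized.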
