Source code for torchrl.weight_update._rpc

from __future__ import annotations

import time
import weakref
from typing import Any

from torchrl._utils import logger as torchrl_logger

from torchrl.weight_update.utils import _resolve_model
from torchrl.weight_update.weight_sync_schemes import (
    TransportBackend,
    WeightStrategy,
    WeightSyncScheme,
)


class RPCWeightSyncScheme(WeightSyncScheme):
    """Weight synchronization for torch.distributed.rpc.

    This scheme uses RPC calls to synchronize weights across distributed
    workers. Each remote collector gets its own transport, following the same
    pattern as multiprocess collectors.
    """

    def _init_on_sender_impl(
        self,
        *,
        model_id: str,
        context: Any = None,
        num_workers: int,
    ) -> None:
        # Store model_id and context on scheme
        self.model_id = model_id
        if context is not None:
            self.context = context
        else:
            raise RuntimeError(f"Expected a context for {type(self).__name__}.")

        collector_infos = getattr(self.context, "collector_infos", None)
        collector_rrefs = getattr(self.context, "collector_rrefs", None)
        collector_class = getattr(self.context, "collector_class", None)
        if (
            collector_infos is None
            or collector_rrefs is None
            or collector_class is None
        ):
            raise RuntimeError(
                "RPCWeightSyncScheme requires a context with the following attributes: "
                "(context.collector_infos, context.collector_rrefs, context.collector_class)"
            )

        # Create transports for each remote collector
        # worker_rank is i+1 because rank 0 is the main/trainer process
        for i in range(num_workers):
            worker_rank = i + 1
            transport = self.create_transport(
                collector_info=collector_infos[i],
                collector_rref=collector_rrefs[i],
                collector_class=collector_class,
                worker_rank=worker_rank,
            )
            self._register_worker_sender(worker_idx=i, transport=transport)

    def _init_on_receiver_impl(
        self, *, model_id: str, context: Any = None, worker_idx: int | None = None
    ) -> None:
        """Initialize scheme on the worker (receiver) side.

        Expected kwargs (as provided by collectors):
            - model_id: str           # e.g. "policy"
            - context: Any            # collector / inner collector
            - worker_idx: int | None  # worker index (optional)
        """
        if context is None:
            raise ValueError(
                "RPCWeightSyncScheme.init_on_receiver requires a 'context' "
                "providing access to the model to be synchronized."
            )

        # Store model_id and context on scheme
        self.model_id = model_id
        self.worker_idx = worker_idx
        self.context = context
        # Access weights to set up missing elements
        self.weights  # noqa
        self._receiver_transport = RPCTransport(worker_rank=worker_idx)

    @property
    def model(self) -> Any | None:
        if self._model_ref is not None:
            return self._model_ref()
        if self._model_id is not None:
            model = _resolve_model(self.context, self._model_id)
            if model is None:
                if self._model_id == "policy":
                    torchrl_logger.debug("Creating policy from factory.")
                    model = self.context.policy_factory[0]()
                    self.context.policy = model
                else:
                    raise AttributeError(
                        f"Model {self._model_id} was `None` in context {self.context}"
                    )
            self._model_ref = weakref.ref(model)
            return model

    @model.setter
    def model(self, value: Any):
        if value is None:
            return
        self._model_ref = weakref.ref(value)
    def create_transport(
        self,
        *,
        collector_info=None,
        collector_rref=None,
        collector_class=None,
        worker_rank=None,
        **kwargs,
    ) -> TransportBackend:
        """Create RPC-based transport for a specific remote collector.

        Args:
            collector_info: RPC worker info for the remote collector.
            collector_rref: RPC remote reference to the collector.
            collector_class: Class of the remote collector.
            worker_rank: The torch.distributed rank of the remote worker.
            **kwargs: Additional transport configuration.

        Returns:
            RPCTransport configured for this specific remote collector.
        """
        return RPCTransport(
            collector_info=collector_info,
            collector_rref=collector_rref,
            collector_class=collector_class,
            worker_rank=worker_rank,
        )
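# Illustrative sketch (not part of the original module): sender-side setup.
# In practice the distributed collector (e.g. RPCDataCollector) acts as the
# ``context`` and already exposes ``collector_infos``, ``collector_rrefs``
# and ``collector_class``; the scheme's public init path ends up calling
# _init_on_sender_impl() with that context. The function and argument names
# below are hypothetical and only show which attributes the scheme expects.
def _example_init_on_sender(rpc_collector, num_workers: int) -> RPCWeightSyncScheme:
    scheme = RPCWeightSyncScheme()
    # Registers one RPCTransport per remote collector; worker i is assigned
    # torch.distributed rank i + 1 (rank 0 is the trainer process).
    scheme._init_on_sender_impl(
        model_id="policy",
        context=rpc_collector,  # must expose collector_infos/_rrefs/_class
        num_workers=num_workers,
    )
    return scheme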
class RPCTransport:
    """RPC transport for communicating with a single RPC remote collector.

    This transport handles weight updates for ONE specific remote collector
    via torch.distributed primitives (send/recv) with RPC used for signaling.
    Multiple transports are created for multiple collectors, following the
    same pattern as the DistributedDataCollector.
    """

    def __init__(
        self,
        collector_info=None,
        collector_rref=None,
        collector_class=None,
        worker_rank=None,
    ):
        self._collector_info = collector_info
        self._collector_rref = collector_rref
        self._collector_class = collector_class
        self._worker_rank = worker_rank  # The torch.distributed rank of this worker
        self._pending_future = None
        self._pending_send = None
    def send_weights(self, weights: Any) -> None:
        """Send weights to the remote collector using torch.distributed.

        Uses torch.distributed.send() for the actual weight transfer and RPC
        for signaling the remote collector to receive.

        Order is critical to avoid deadlock:
        1. Signal receiver via RPC to start recv() (non-blocking)
        2. Send weights via torch.distributed (blocking until recv completes)
        """
        if self._collector_info is None or self._collector_rref is None:
            return
        if self._worker_rank is None:
            raise RuntimeError("worker_rank must be set for RPC transport")

        # Step 1: Signal the remote collector via RPC to start receiving (async)
        # Use rref.rpc_async() to properly call the instance method on the remote object
        future = self._collector_rref.rpc_async()._receive_weights_scheme()

        # Step 2: Send weights via torch.distributed (blocks until receiver calls recv())
        weights.send(self._worker_rank)

        # Step 3: Wait for RPC to complete (receiver has applied weights)
        future.wait()
    def send_weights_async(self, weights: Any) -> None:
        """Send weights to remote collector asynchronously.

        Uses torch.distributed.isend() for the actual weight transfer and RPC
        for signaling. Use wait_ack() to wait for completion.

        Order is critical to avoid deadlock:
        1. Signal receiver via RPC to start recv() (non-blocking)
        2. Send weights via torch.distributed.isend() (non-blocking)
        3. wait_ack() waits for both to complete
        """
        if self._collector_info is None or self._collector_rref is None:
            return
        if self._worker_rank is None:
            raise RuntimeError("worker_rank must be set for RPC transport")

        # Step 1: Signal the remote collector via RPC to start receiving (async)
        # Use rref.rpc_async() to properly call the instance method on the remote object
        self._pending_future = (
            self._collector_rref.rpc_async()._receive_weights_scheme()
        )

        # Step 2: Send weights asynchronously via torch.distributed
        # Store the Work handle for wait_ack()
        weights.isend(self._worker_rank)
    def wait_ack(self) -> None:
        """Wait for both the RPC call and the distributed send to complete."""
        # Wait for the RPC call to complete
        if hasattr(self, "_pending_future") and self._pending_future is not None:
            self._pending_future.wait()
            del self._pending_future
    def receive_weights(
        self,
        timeout: float | None = None,
        *,
        weights: Any = None,
        model: Any = None,
        strategy: WeightStrategy | None = None,
    ) -> Any | None:
        """Receive weights from sender using torch.distributed.

        Args:
            timeout: Maximum time to wait for weights (seconds).
                If None, blocks until weights are received.
            weights: Pre-allocated weight buffer to receive into.
            model: The model to apply weights to.
            strategy: Strategy for applying weights to the model.

        Returns:
            The received weights, or None if timeout expires.
        """
        if weights is None:
            return None

        if timeout is None:
            # Blocking receive
            weights.recv(0)
        else:
            # Non-blocking receive with timeout support
            futures = weights.irecv(src=0, return_premature=True)
            if futures:
                start_time = time.monotonic()
                while True:
                    # Check if all futures are complete
                    all_complete = all(f.is_completed() for f in futures)
                    if all_complete:
                        break
                    # Check timeout
                    elapsed = time.monotonic() - start_time
                    if elapsed >= timeout:
                        # Timeout expired before receiving all weights
                        return None
                    # Small sleep to avoid busy-waiting
                    time.sleep(0.001)

        # Apply the received weights to the model
        if model is not None and strategy is not None:
            strategy.apply_weights(model, weights)

        return weights
    def setup_connection_and_weights_on_sender(self) -> None:
        """No-op for RPCTransport - weights are sent via send_weights()."""
    def setup_connection_and_weights_on_receiver(
        self,
        *,
        worker_idx: int,
        weights: Any = None,
        model: Any = None,
        strategy: WeightStrategy | None = None,
    ) -> Any:
        """No-op for RPCTransport - weights are received via receive_weights()."""
        return None
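# Illustrative sketch (not part of the original module): the point-to-point
# handshake between one RPCTransport on the trainer (rank 0) and its matching
# worker (rank i + 1). ``policy_weights`` and ``weight_buffer`` are assumed to
# be TensorDict-like objects exposing send()/recv()/isend()/irecv(), as used
# by the methods above; the function names below are hypothetical.
def _example_trainer_side(transport: RPCTransport, policy_weights) -> None:
    # Blocking update: RPC signal first, then the torch.distributed send,
    # then wait for the remote side to confirm it applied the weights.
    transport.send_weights(policy_weights)

    # Or overlap the transfer with other trainer work:
    transport.send_weights_async(policy_weights)
    # ... do other trainer-side work here ...
    transport.wait_ack()


def _example_worker_side(transport: RPCTransport, weight_buffer, policy, strategy) -> None:
    # Runs on the worker once the trainer signals via RPC: receive from rank 0
    # into the pre-allocated buffer and apply the result to the policy.
    received = transport.receive_weights(
        timeout=10.0,
        weights=weight_buffer,
        model=policy,
        strategy=strategy,
    )
    if received is None:
        raise TimeoutError("weights did not arrive within 10 seconds")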
