Source code for torchrl.weight_update._distributed

from __future__ import annotations

import random
import socket
import time
import weakref
from datetime import timedelta
from typing import Any

import torch
from tensordict import TensorDictBase
from torchrl._utils import logger as torchrl_logger

from torchrl.weight_update.utils import _resolve_model

from torchrl.weight_update.weight_sync_schemes import (
    TransportBackend,
    WeightStrategy,
    WeightSyncScheme,
)


class DistributedWeightSyncScheme(WeightSyncScheme):
    """Weight synchronization for torch.distributed.

    This scheme uses torch.distributed primitives (send/recv) to synchronize
    weights across distributed workers. Each worker gets its own transport,
    following the same pattern as multiprocess collectors.

    The scheme can create its own TCPStore for coordination if one is not
    provided. Use `get_store_info()` after `init_on_sender()` to get connection
    details for workers.

    Args:
        backend (str): The distributed backend ("gloo", "nccl", etc.)
        sync (bool): If True, weight updates are synchronous (blocking receive).
            If False, a background thread monitors the store and applies weight
            updates automatically. Defaults to True.
        timeout (float): Timeout in seconds for TCPStore operations.
            Defaults to 3600.0 (1 hour).
    """

    def __init__(
        self,
        backend: str = "gloo",
        sync: bool = True,
        timeout: float = 3600.0,
    ):
        super().__init__()
        self.backend = backend
        self.sync = sync
        self._timeout = timeout
        self._store = None
        self._store_info = None
        self._num_workers = None

    def __getstate__(self):
        """Custom serialization - exclude non-picklable objects."""
        state = super().__getstate__()
        # TCPStore cannot be pickled - remove it but keep _store_info
        state["_store"] = None
        # Thread and Event cannot be pickled
        state["_background_thread"] = None
        state["_stop_event"] = None
        # Transports contain references to store/groups - exclude them
        # The receiver will create its own transport in init_on_receiver
        state["_sender_transports"] = {}
        state["_receiver_transport"] = None
        return state

    def __setstate__(self, state):
        """Custom deserialization."""
        super().__setstate__(state)

    def _init_on_sender_impl(
        self,
        *,
        model_id: str,
        context: Any = None,
        num_workers: int,
        model: Any = None,
        weights: Any = None,
        **kwargs,
    ) -> None:
        if kwargs:
            raise RuntimeError(f"Unexpected kwargs: {kwargs.keys()}")
        self.model_id = model_id
        self._num_workers = num_workers
        # Attach context so we can resolve the model and prepare
        # weights on demand via scheme.prepare_weights().
        weights_buffer = None
        if context is not None:
            self.context = context
        if weights is not None:
            self.weights = weights
            weights_buffer = weights
        if model is not None:
            self.model = model
        else:
            # resolve model
            try:
                model = self.model
            except (AttributeError, ValueError):
                pass
        if weights_buffer is None and model is not None:
            weights_buffer = self._get_weights_buffer_from_model(model)

        # Get base tcp_port from context if available to avoid port conflicts.
        # The DistributedDataCollector uses tcp_port for init and tcp_port+1 for its store,
        # so we use tcp_port+2 for the weight sync scheme's store.
        base_tcp_port = (
            getattr(context, "tcp_port", None) if context is not None else None
        )
        self._store = self._make_store(
            is_master=True, num_workers=num_workers, base_tcp_port=base_tcp_port
        )

        for i in range(num_workers):
            rank = i + 1  # Workers are 1-indexed in distributed
            transport = self.create_transport(
                store=self._store,
                rank=rank,
                weights_buffer=weights_buffer,
                sync=self.sync,
            )
            self._register_worker_sender(worker_idx=i, transport=transport)

    def _make_store(
        self,
        is_master: bool,
        num_workers: int | None = None,
        store_info: dict | None = None,
        base_tcp_port: int | str | None = None,
        max_retries: int = 10,
    ) -> torch.distributed.TCPStore:
        """Create a TCPStore for weight synchronization.

        Args:
            is_master: If True, creates the store as master (server).
                If False, connects as client.
            num_workers: Number of workers (required for master).
            store_info: Dictionary with 'host' and 'port' keys (required for client).
            base_tcp_port: Base TCP port from the collector. If provided, the store
                will use base_tcp_port + 2 to avoid conflicts with the collector's
                stores (which use base_tcp_port and base_tcp_port + 1).
            max_retries: Maximum number of retry attempts for handling port conflicts.

        Returns:
            The created TCPStore.
        """
        if is_master:
            # Create as master (server)
            if num_workers is None:
                raise ValueError(
                    "num_workers is required when creating store as master"
                )
            hostname = socket.gethostname()
            host = socket.gethostbyname(hostname)

            # Use base_tcp_port + 2 if available (to avoid conflicts with collector's
            # tcp_port and tcp_port + 1), otherwise find a free port dynamically.
            initial_port = int(base_tcp_port) + 2 if base_tcp_port is not None else None

            last_error = None
            for attempt in range(max_retries):
                if initial_port is None or attempt > 0:
                    # Find a free port dynamically
                    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                        s.bind(("", 0))
                        self._store_port = s.getsockname()[1]
                else:
                    self._store_port = initial_port

                try:
                    store = torch.distributed.TCPStore(
                        host_name=host,
                        port=self._store_port,
                        is_master=True,
                        timeout=timedelta(seconds=self._timeout),
                        wait_for_workers=False,  # Don't block - workers may not be started yet
                    )
                    self._store_info = {"host": host, "port": self._store_port}
                    return store
                except (RuntimeError, OSError) as e:
                    error_msg = str(e).lower()
                    if (
                        "address already in use" in error_msg
                        or "eaddrinuse" in error_msg
                    ):
                        last_error = e
                        # Add small random delay to reduce collision probability
                        time.sleep(random.uniform(0.01, 0.1))
                        continue
                    # For other errors, re-raise immediately
                    raise

            raise RuntimeError(
                f"DistributedWeightSyncScheme: Failed to create TCPStore after {max_retries} attempts. "
                f"Last error: {last_error}"
            )
        else:
            # Connect as client
            if store_info is None:
                raise ValueError("store_info is required when connecting as client")
            store = torch.distributed.TCPStore(
                host_name=store_info["host"],
                port=store_info["port"],
                is_master=False,
                timeout=timedelta(seconds=self._timeout),
            )
            return store

    def _init_on_receiver_impl(
        self,
        *,
        model_id: str,
        context: Any = None,
        store_info: dict | None = None,
        worker_idx: int | None = None,
        **kwargs,
    ) -> None:
        """Initialize scheme on the worker (receiver) side.

        Expected keyword arguments (as provided by collectors):
            - model_id: str              # e.g. "policy"
            - context: Any               # collector / inner collector
            - store_info: dict | None    # {"host": ..., "port": ...} to create store
            - worker_idx: int | None     # worker rank (1-indexed)
        """
        if context is None:
            raise ValueError(
                "DistributedWeightSyncScheme.init_on_receiver requires a 'context' "
                "providing access to the model to be synchronized."
            )
        if worker_idx is None:
            raise RuntimeError("rank was not provided.")
        if kwargs:
            raise RuntimeError(f"Unexpected kwargs: {kwargs.keys()}")

        # Store model_id and context on scheme
        self.model_id = model_id
        self.context = context

        # Connect to master's TCPStore as a client, using the store info captured
        # by init_on_sender() and carried over through serialization.
        info = self._store_info
        if info is None:
            raise RuntimeError(
                "TCPStore info not available. init_on_sender() must be called first "
                "on the sender side, before passing the scheme to the receiver."
            )
        self._store = self._make_store(is_master=False, store_info=info)

        if (model := getattr(self, "model", None)) is not None:
            self.model = model
            weights_buffer = self._get_weights_buffer_from_model(model)
        else:
            raise RuntimeError("Couldn't find weights")

        self._receiver_transport = self.create_transport(
            store=self._store,
            rank=worker_idx,
            weights_buffer=weights_buffer,
            sync=self.sync,
        )

        # Store worker_idx for synchronize_weights
        self._worker_idx = worker_idx

        # Note: Background thread for async mode is started in connect() after init_process_group

    def _wait_for_instruction(self, timeout: float | None = None) -> str | None:
        """Block until an instruction arrives via TCPStore.

        Args:
            timeout: Maximum time to wait for instruction (seconds).
                None means block indefinitely.

        Returns:
            The instruction string (e.g., "receive", "stop"), or None if stop
            event is set or timeout expires.
        """
        key = f"NODE_{self._worker_idx}_in"
        start_time = time.monotonic()

        while True:
            if self._stop_event is not None and self._stop_event.is_set():
                return None

            try:
                instruction = self._store.get(key)
                self._store.delete_key(key)
                # Decode bytes to string
                return (
                    instruction.decode()
                    if isinstance(instruction, bytes)
                    else instruction
                )
            except RuntimeError:
                # Key doesn't exist yet, continue polling
                pass

            # Check timeout
            if timeout is not None:
                elapsed = time.monotonic() - start_time
                if elapsed >= timeout:
                    return None

            time.sleep(0.01)

    def _send_instruction(
        self,
        instruction: str = "receive",
        worker_ids: int | list[int] | None = None,
    ) -> None:
        """Send instruction to receiver(s) via TCPStore.

        Args:
            instruction: The instruction to send (default: "receive").
            worker_ids: Which workers to send to (None = all workers).
        """
        if self._store is None:
            raise RuntimeError(
                "Store not initialized. init_on_sender() must be called first."
            )

        if worker_ids is None:
            target_workers = list(range(self._num_workers)) if self._num_workers else []
        elif isinstance(worker_ids, int):
            target_workers = [worker_ids]
        else:
            target_workers = list(worker_ids)

        # Map instruction to TCPStore format
        store_instruction = (
            b"update_weights" if instruction == "receive" else instruction.encode()
        )
        for worker_idx in target_workers:
            rank = worker_idx + 1  # Workers are 1-indexed in distributed
            self._store.set(f"NODE_{rank}_in", store_instruction)

    def _send_ack(self, message: str = "updated") -> None:
        """Send acknowledgment back to sender via TCPStore.

        Args:
            message: The acknowledgment message (default: "updated").
        """
        if self._store is None or self._worker_idx is None:
            return
        self._store.set(f"NODE_{self._worker_idx}_out", message.encode())

    def _wait_for_ack(
        self,
        worker_ids: int | list[int] | None = None,
        timeout: float | None = None,
    ) -> None:
        """Wait for acknowledgment from receiver(s) via TCPStore.

        Args:
            worker_ids: Which workers to wait for (None = all workers).
            timeout: Maximum time to wait (seconds). None means block indefinitely.
""" if self._store is None: return if worker_ids is None: target_workers = list(range(self._num_workers)) if self._num_workers else [] elif isinstance(worker_ids, int): target_workers = [worker_ids] else: target_workers = list(worker_ids) for worker_idx in target_workers: rank = worker_idx + 1 key = f"NODE_{rank}_out" try: status = self._store.get(key) if status != b"updated": torchrl_logger.warning( f"Unexpected ack from worker {worker_idx}: {status}" ) self._store.delete_key(key) except Exception as e: torchrl_logger.warning( f"Timeout waiting for ack from worker {worker_idx}: {e}" ) def _background_receive_loop(self): """Background thread loop that waits for instructions and receives weights. This loop: 1. Waits for an instruction via TCPStore 2. Receives weights via torch.distributed 3. Sends an acknowledgment back 4. Repeats until stop event is set """ while not self._stop_event.is_set(): try: instruction = self._wait_for_instruction() if instruction is None: continue if instruction in ("receive", "update_weights"): # Receive weights via torch.distributed weights = self._receiver_transport.receive_weights( model=self.model, strategy=self._strategy, ) if weights is not None: # Cascade weight update to sub-collectors if context supports it model_id = self._model_id or "policy" if self.context is not None and hasattr( self.context, "update_policy_weights_" ): self.context.update_policy_weights_( model_id=model_id, policy_or_weights=weights ) # Send acknowledgment self._send_ack("updated") elif instruction == "stop": break else: torchrl_logger.warning( f"DistributedWeightSyncScheme: Unknown instruction: {instruction}" ) except Exception as e: if not self._stop_event.is_set(): torchrl_logger.warning( f"DistributedWeightSyncScheme: Background receiver error: {e}" ) def _setup_connection_and_weights_on_sender_impl( self, *, worker_idx: int | None = None, weights: Any | None = None ) -> None: """Send initial weights to all workers during connect(). If the sender has a stateful model (weights available), send them to all workers so they start with the correct weights. Note: This uses direct torch.distributed send/recv without TCPStore signaling to avoid interfering with the main collection loop. """ # Initialize torch.distributed process group if not already done # This is a collective operation - all workers must call it if not torch.distributed.is_initialized(): torch.distributed.init_process_group( backend=self.backend, rank=0, # Sender is always rank 0 world_size=self._num_workers + 1, timeout=timedelta(seconds=self._timeout), ) # Check if we have weights to send if weights is None and getattr(self, "model", None) is None: self._store.set("STATELESS_MODEL", b"1") return self._store.set("STATELESS_MODEL", b"0") # Prepare weights from model weights = self._get_weights_buffer_from_model(self.model) if weights is None or weights.is_empty(): return # Send to all workers using direct torch.distributed (no TCPStore signaling) for i, transport in enumerate(self._iterate_transports()): if worker_idx is not None and i != worker_idx: continue transport.send_initial_weights(weights) def _setup_connection_and_weights_on_receiver_impl( self, *, worker_idx: int | None = None ) -> None: """Receive initial weights from sender during connect(). The receiver always has a model that needs weights, so we block waiting for the initial weights from the sender. 
""" # Use stored worker_idx if not provided if worker_idx is None: worker_idx = self._worker_idx # Initialize torch.distributed process group if not already done # This is a collective operation - sender and all workers must call it if not torch.distributed.is_initialized(): torch.distributed.init_process_group( backend=self.backend, rank=worker_idx, world_size=self._num_workers + 1, timeout=timedelta(seconds=self._timeout), ) if self._receiver_transport is None: torchrl_logger.warning( "DistributedWeightSyncScheme: No receiver transport, skipping initial weight sync" ) return stateless_model = self.receiver_transport._store.get("STATELESS_MODEL") if stateless_model not in (b"0", b"1"): raise RuntimeError(f"Invalid STATELESS_MODEL value: {stateless_model}") if stateless_model != b"1": # Receive initial weights (blocking, no TCPStore coordination) weights = self._receiver_transport.receive_initial_weights() if weights is not None and self.model is not None: self._strategy.apply_weights(self.model, weights, inplace=False) # Start background receiver thread AFTER initial weight sync is complete # This prevents the background thread from consuming the initial sync messages if self._background_thread is None: self._start_background_receiver()

    def shutdown(self) -> None:
        """Stop background receiver thread and clean up."""
        # Check if already shutdown
        if getattr(self, "_is_shutdown", False):
            return
        self._is_shutdown = True
        # Let base class handle background thread cleanup
        super().shutdown()

    @property
    def model(self) -> Any | None:
        """Get the model associated with this scheme.

        Returns:
            The model if set, None otherwise.
        """
        if self._model_ref is not None:
            return self._model_ref()
        if self._model_id is not None:
            model = _resolve_model(self.context, self._model_id)
            if model is None:
                if self._model_id == "policy":
                    torchrl_logger.debug("Creating policy from factory.")
                    model = self.context.policy_factory[0]()
                    self.context.policy = model
                else:
                    raise AttributeError(
                        f"Model {self._model_id} was `None` in context {self.context}"
                    )
            self._model_ref = weakref.ref(model)
            return model

    @model.setter
    def model(self, value: Any):
        """Set the model for this scheme.

        Args:
            value: The model to set. If None, the setter is a no-op.
        """
        if value is None:
            return
        self._model_ref = weakref.ref(value)

    def create_transport(self, **kwargs) -> TransportBackend:
        """Create distributed transport for a specific worker."""
        return DistributedTransport(**kwargs)
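
Usage sketch. The class docstring above refers to the public lifecycle hooks `init_on_sender()`, `get_store_info()`, `init_on_receiver()` and `connect()`; the `_impl` methods shown here are the bodies those hooks dispatch to. The snippet below illustrates that flow for two workers under the assumption that the public hooks mirror the `_impl` keyword arguments; `trainer_collector` and `worker_collector` are placeholders for whatever object a real collector passes as `context`.

# Hypothetical two-worker setup (placeholders: trainer_collector, worker_collector).
scheme = DistributedWeightSyncScheme(backend="gloo", sync=True)

# --- sender process (rank 0) ---
scheme.init_on_sender(model_id="policy", context=trainer_collector, num_workers=2)
store_info = scheme.get_store_info()  # {"host": ..., "port": ...} advertised to workers

# --- worker process (rank 1..num_workers), operating on the unpickled scheme ---
scheme.init_on_receiver(model_id="policy", context=worker_collector, worker_idx=1)
scheme.connect()  # joins the process group and receives the initial weights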

class DistributedTransport:
    """torch.distributed transport for communicating with a single distributed worker.

    This transport handles weight updates for ONE specific distributed worker
    via torch.distributed send/recv. Multiple transports are created for
    multiple workers, following the same pattern as multiprocess collectors.
    """

    def __init__(
        self,
        *,
        weights_buffer: TensorDictBase,
        store: torch.distributed.Store = None,
        rank: int | None = None,
        sync: bool = True,
    ):
        """Initialize the DistributedTransport.

        Args:
            weights_buffer (TensorDictBase): a tensor buffer of weights.
            store (torch.distributed.Store): A (TCP)Store for communication.
            rank (int): Worker rank (1-indexed).
            sync (bool): Whether to use synchronous weight updates.
        """
        self._store = store
        self._rank = rank
        self._sync = sync
        self._weights_buffer = weights_buffer

    def send_weights(self, weights: Any) -> None:
        """Send weights to the distributed worker."""
        if self._store is None or self._rank is None:
            return

        # Instruct worker to expect weight update
        self._store.set(f"NODE_{self._rank}_in", b"update_weights")

        # Send weights via torch.distributed
        if self._sync:
            weights.send(self._rank)
        else:
            weights.isend(self._rank)

        # Wait for acknowledgment
        status = self._store.get(f"NODE_{self._rank}_out")
        if status != b"updated":
            raise RuntimeError(f"Expected 'updated' but got status {status}.")
        self._store.delete_key(f"NODE_{self._rank}_out")

    def send_weights_async(self, weights: Any) -> None:
        """Send weights to distributed worker without waiting for acknowledgment.

        Use wait_ack() to wait for acknowledgment after sending to all workers.
        """
        if self._store is None or self._rank is None:
            return

        # Instruct worker to expect weight update
        self._store.set(f"NODE_{self._rank}_in", b"update_weights")

        # Send weights via torch.distributed
        if self._sync:
            weights.send(self._rank)
        else:
            weights.isend(self._rank)

    def wait_ack(self) -> None:
        """Wait for acknowledgment from distributed worker."""
        if self._store is None or self._rank is None:
            return
        status = self._store.get(f"NODE_{self._rank}_out")
        if status != b"updated":
            raise RuntimeError(f"Expected 'updated' but got status {status}.")
        self._store.delete_key(f"NODE_{self._rank}_out")

    def receive_weights(
        self,
        timeout: float | None = None,
        *,
        weights: Any = None,
        model: Any = None,
        strategy: WeightStrategy | None = None,
    ) -> Any | None:
        r"""Receive weights via torch.distributed and apply them to the model.

        The surrounding collector loop is responsible for checking the TCPStore
        for the \"update_weights\" instruction. When this method is called we
        assume that a weight update has been requested and the sender has
        already performed the corresponding ``send()``.

        Args:
            timeout: Maximum time to wait for weights (seconds).
                If None, blocks until weights are received.
            weights: Pre-allocated weight buffer to receive into.
            model: The model to apply weights to.
            strategy: Strategy for applying weights to the model.

        Returns:
            The received weights, or None if timeout expires.
        """
        if self._store is None or self._rank is None:
            return None

        # Use provided weights buffer or fallback to stored one
        weights_buffer = weights if weights is not None else self._weights_buffer

        # Receive weights via torch.distributed into the buffer
        if self._sync or timeout is None:
            # Blocking receive - no timeout support
            if self._sync:
                weights_buffer.recv(src=0)
            else:
                weights_buffer.irecv(src=0)
        else:
            # Non-blocking receive with timeout support
            futures = weights_buffer.irecv(src=0, return_premature=True)
            if futures:
                start_time = time.monotonic()
                while True:
                    # Check if all futures are complete
                    all_complete = all(f.is_completed() for f in futures)
                    if all_complete:
                        break
                    # Check timeout
                    elapsed = time.monotonic() - start_time
                    if elapsed >= timeout:
                        # Timeout expired before receiving all weights
                        return None
                    # Small sleep to avoid busy-waiting
                    time.sleep(0.001)

        # Apply weights if model and strategy provided
        if model is not None and strategy is not None:
            strategy.apply_weights(model, weights_buffer)

        return weights_buffer

    def send_initial_weights(self, weights: Any) -> None:
        """Send initial weights during connect() without TCPStore signaling.

        This is used for the initial weight sync during connect() to avoid
        interfering with the main collection loop's TCPStore-based coordination.
        """
        if self._rank is None:
            return
        # Note: No TCPStore signaling for initial sync - just direct send/recv
        if self._sync:
            weights.send(self._rank)
        else:
            weights.isend(self._rank)

    def receive_initial_weights(self) -> Any:
        """Receive initial weights during connect() without TCPStore signaling.

        This is used for the initial weight sync during connect() to avoid
        interfering with the main collection loop's TCPStore-based coordination.

        Returns:
            The received weights TensorDict.
        """
        if self._sync:
            self._weights_buffer.recv(src=0)
        else:
            self._weights_buffer.irecv(src=0)
        return self._weights_buffer

    def setup_connection_and_weights_on_sender(self) -> None:
        """No-op for DistributedTransport - handled by scheme."""

    def setup_connection_and_weights_on_receiver(
        self,
        *,
        worker_idx: int,
        weights: Any = None,
        model: Any = None,
        strategy: WeightStrategy | None = None,
    ) -> Any:
        """No-op for DistributedTransport - handled by scheme."""
        return None
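
Protocol sketch. Taken together, the scheme and transport implement a small TCPStore + send/recv handshake: the sender sets `NODE_{rank}_in`, pushes the tensors with the TensorDict `send()`, and waits for `NODE_{rank}_out == b"updated"`, while the worker drains the instruction key, calls `recv(src=0)` into its pre-allocated buffer, and posts the acknowledgment. The snippet below pairs the two transport endpoints directly. It is a minimal sketch, not a complete program: it assumes an already-initialized "gloo" process group with world_size=2, a shared TCPStore `store`, a `policy` module present in both processes, and a `WeightStrategy` instance `strategy` on the worker, all of which are placeholders.

from tensordict import TensorDict

# Snapshot of the policy parameters/buffers used as the payload and as the
# shape/dtype template for the receive buffer (placeholder: `policy`).
weights_td = TensorDict.from_module(policy)

# --- rank 0 (sender process) ---
sender_transport = DistributedTransport(
    weights_buffer=weights_td, store=store, rank=1, sync=True
)
# Sets NODE_1_in = b"update_weights", send()s the tensors to rank 1, then blocks
# until the worker posts NODE_1_out = b"updated".
sender_transport.send_weights(weights_td)

# --- rank 1 (worker process) ---
worker_transport = DistributedTransport(
    weights_buffer=weights_td.clone(), store=store, rank=1, sync=True
)
# The collector loop has already consumed the NODE_1_in instruction; this call
# recv()s into the buffer from rank 0 and applies it to `policy` via `strategy`.
worker_transport.receive_weights(model=policy, strategy=strategy)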
