Source code for torchrl.weight_update._ray

from __future__ import annotations

import os
import socket

import time
import weakref
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, Literal

import torch
from tensordict import TensorDict
from tensordict.base import TensorDictBase

from torchrl._utils import logger as torchrl_logger
from torchrl.weight_update.utils import _resolve_model
from torchrl.weight_update.weight_sync_schemes import (
    TransportBackend,
    WeightStrategy,
    WeightSyncScheme,
)

# Default timeout for torch.distributed operations
_DIST_TIMEOUT = timedelta(seconds=60)


@dataclass
class ConnectionInfo:
    """Connection info for Ray distributed computing.

    Uses dataclass instead of UserDict to avoid Ray signature introspection
    issues with UserDict's __class_getitem__ in Python 3.11+
    (ValueError: no signature found for builtin type GenericAlias).
    """

    master_addr: str
    master_port: int
    world_size: int
    stateful_model: bool

    def get(self, key: str, default: Any = None) -> Any:
        """Get a connection info value by key name.

        Args:
            key (str): The attribute name to retrieve.
            default: The default value if the attribute does not exist.
                Defaults to None.

        Returns:
            The value of the attribute, or the default if not found.
        """
        return getattr(self, key, default)


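The connection info travels to the workers as a named Ray actor rather than as an actor handle. A minimal round-trip sketch, assuming ``ray.init()`` has already been called; the name ``"connection_info"`` matches the default used by the transports below:

>>> import ray
>>> info_actor = ray.remote(num_cpus=0)(ConnectionInfo).options(name="connection_info").remote(
...     master_addr="127.0.0.1", master_port=29500, world_size=3, stateful_model=True
... )
>>> worker_side = ray.get_actor("connection_info")
>>> ray.get(worker_side.get.remote("master_port"))
29500
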
class RayTransport:
    """Ray transport for communicating with a single Ray actor.

    This transport handles weight updates for ONE specific remote actor using
    torch.distributed for efficient weight transfer. Ray is used for
    signaling/coordination, while the actual weight data is transferred via
    torch.distributed send/recv operations.

    Multiple transports are created for multiple actors, following the same
    pattern as multiprocess collectors.

    Args:
        remote_actor: The Ray actor handle for the remote collector/transform.
        worker_idx (int, optional): The worker index for this remote actor.
            Defaults to 0.
        backend (str): The torch.distributed backend to use ("gloo" or "nccl").
            Defaults to "gloo".
        connection_info_name (str): Name of the Ray actor storing connection info.
            Defaults to "connection_info".
        model_id (str, optional): The model identifier for weight synchronization.
    """

    def __init__(
        self,
        *,
        remote_actor=None,
        worker_idx: int | None = None,
        backend: str = "gloo",
        connection_info_name: str = "connection_info",
        model_id: str | None = None,
    ):
        """Initialize the RayTransport.

        Args:
            remote_actor: The Ray actor handle for the remote collector/transform.
            worker_idx (int, optional): The worker index for this remote actor.
                Defaults to 0.
            backend (str): The torch.distributed backend to use ("gloo" or "nccl").
                Defaults to "gloo".
            connection_info_name (str): Name of the Ray actor storing connection info.
                Defaults to "connection_info".
            model_id (str, optional): The model identifier for weight synchronization.
        """
        try:
            import ray

            self.ray = ray
        except ImportError:
            raise ImportError("Ray is required for RayTransport")

        self._remote_actor = remote_actor
        self._worker_idx = worker_idx if worker_idx is not None else 0
        self._backend = backend
        self._connection_info_name = connection_info_name
        self._model_id = model_id

        # Distributed state
        self._dist_initialized = False
        self._weights_buffer: TensorDictBase | None = None
        self._stateful_model: bool = True

        # Async operation state
        self._pending_future = None
        self._pending_isend = None

        # Model reference (set by scheme on receiver side)
        self._model = None

    @property
    def _rank(self) -> int:
        """Get the torch.distributed rank for this worker.

        Returns:
            int: The rank (worker_idx + 1, since sender is rank 0).
        """
        return self._worker_idx + 1  # Sender is rank 0, workers are 1-indexed
    def set_model(self, model: Any) -> None:
        """Set the model for receiving weights.

        Args:
            model: The model to receive weights into.
        """
        self._model = model
    # ========================================================================
    # Sending Weights (Sender Side)
    # ========================================================================
    def send_weights(self, weights: Any) -> None:
        """Send weights to the remote actor via torch.distributed.

        This method:
        1. Signals the remote actor to start receiving via Ray remote call
        2. Sends weights via torch.distributed.isend
        3. Waits for both to complete

        Args:
            weights: The weights to send (typically a TensorDict).
        """
        if self._remote_actor is None:
            return

        # Step 1: Signal the remote actor via Ray to start receiving (async)
        future = self._remote_actor._receive_weights_scheme.remote()

        # Step 2: Send weights via torch.distributed (async)
        weights.isend(dst=self._rank)

        # Step 3: Wait for the Ray call to complete (receiver has applied weights)
        self.ray.get(future)
    def send_weights_async(self, weights: Any) -> None:
        """Send weights to Ray actor without waiting for completion.

        Use :meth:`wait_ack` to wait for completion after sending to all actors.

        Args:
            weights: The weights to send (typically a TensorDict).
        """
        if self._remote_actor is None:
            return

        # Step 1: Signal the actor via Ray to start receiving (async)
        self._pending_future = self._remote_actor._receive_weights_scheme.remote()

        # Step 2: Send weights via torch.distributed (async)
        self._pending_isend = weights.isend(dst=self._rank, return_early=True)
    def wait_ack(self) -> None:
        """Wait for Ray actor to finish applying weights.

        Raises:
            RuntimeError: If no pending future exists (i.e., :meth:`send_weights_async`
                was not called before this method).
        """
        if self._pending_future is not None:
            self.ray.get(self._pending_future)
            if self._pending_isend is not None:
                for fut in self._pending_isend:
                    fut.wait()
            self._pending_future = None
            self._pending_isend = None
        else:
            raise RuntimeError("No pending future. Did you call send_weights_async?")
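    # A minimal sender-side sketch of how these two methods are meant to be
    # paired across several transports (``transports`` and ``weights_td`` are
    # placeholder names; the scheme below constructs the transports):
    #
    #     for transport in transports:
    #         transport.send_weights_async(weights_td)  # Ray signal + isend
    #     for transport in transports:
    #         transport.wait_ack()                      # wait for apply + isend
    #
    # so every actor's Ray signal and torch.distributed isend are in flight
    # before any acknowledgement is awaited.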
    # ========================================================================
    # Receiving Weights (Receiver Side)
    # ========================================================================
    def receive_weights(
        self,
        timeout: float | None = None,
        *,
        weights: Any = None,
        model: Any = None,
        strategy: WeightStrategy | None = None,
    ) -> Any | None:
        """Receive weights from sender via torch.distributed.

        Args:
            timeout: Maximum time to wait for weights (seconds).
                If None, blocks until weights are received.
            weights: Pre-allocated weight buffer to receive into.
            model: The model to apply weights to.
            strategy: Strategy for applying weights to the model.

        Returns:
            The received weights, or None if timeout expires.
        """
        from torchrl.collectors.utils import _cast

        # Use provided weights buffer or fallback to stored one
        weights_buffer = weights if weights is not None else self._weights_buffer
        if weights_buffer is None:
            if model is None:
                raise RuntimeError("No model available to receive weights")
            if isinstance(model, torch.nn.Module):
                weights_buffer = TensorDict.from_module(model)
                weights_buffer = weights_buffer.data.apply(_cast, weights_buffer)
            else:
                weights_buffer = TensorDict(lock=True)

        # Cache the weights buffer for future use
        if self._weights_buffer is None:
            self._weights_buffer = weights_buffer

        # Receive weights from rank 0
        if timeout is None:
            # Blocking receive
            weights_buffer.irecv(src=0)
        else:
            # Non-blocking receive with timeout support
            futures = weights_buffer.irecv(src=0, return_premature=True)
            if futures:
                start_time = time.monotonic()
                while True:
                    # Check if all futures are complete
                    all_complete = all(f.is_completed() for f in futures)
                    if all_complete:
                        break
                    # Check timeout
                    elapsed = time.monotonic() - start_time
                    if elapsed >= timeout:
                        # Timeout expired before receiving all weights
                        return None
                    # Small sleep to avoid busy-waiting
                    time.sleep(0.001)

        # Apply weights to model
        if not isinstance(model, torch.nn.Module):
            if not weights_buffer.is_empty():
                raise RuntimeError(
                    f"Cannot cast weights to model type: {type(model)} with weights: {weights_buffer}."
                )
            return None
        if strategy is not None:
            strategy.apply_weights(model, weights_buffer)
        else:
            weights_buffer.to_module(model)
        return weights_buffer
    # ========================================================================
    # Connection Setup
    # ========================================================================
    def setup_connection_and_weights_on_sender(self) -> None:
        """Initialize torch.distributed on sender side for this worker's rank.

        This is called by the scheme after it has created the connection info
        Ray actor. The actual ``init_process_group`` happens in the scheme since
        it's a collective operation that needs to happen for rank 0.

        Note:
            This method exists for interface compatibility but the real work
            happens in the scheme's :meth:`_setup_distributed_connection_sender`.
        """
        # The scheme handles the collective init_process_group for rank 0.
        # This method exists for interface compatibility but the real work
        # happens in the scheme's _setup_distributed_connection_sender.
    def setup_connection_and_weights_on_receiver(
        self,
        *,
        worker_idx: int,
        strategy: WeightStrategy | None = None,
        model: Any | None = None,
        weights: Any | None = None,
    ) -> Any:
        """Join torch.distributed process group and receive initial weights.

        This method:
        1. Retrieves connection info from the shared Ray actor
        2. Initializes torch.distributed process group with rank=worker_idx+1
        3. Receives weights if model is stateful

        Args:
            worker_idx (int): The worker index for this transport.
            strategy (WeightStrategy, optional): The weight transmission strategy.
            model (nn.Module or compatible, optional): The model to receive weights for.
            weights (TensorDict, optional): Pre-allocated buffer for receiving weights.

        Returns:
            The received weights (TensorDict) if model is stateful, None otherwise.
        """
        if self._dist_initialized:
            # Already initialized, just receive weights if stateful
            if self._stateful_model:
                result = self.receive_weights(
                    weights=weights, model=model, strategy=strategy
                )
                return result[1] if result else None
            return None

        self._worker_idx = worker_idx
        rank = self._rank

        # Wait for connection info actor to be available
        i = 0
        while True:
            try:
                remote_connection_info = self.ray.get_actor(self._connection_info_name)
            except ValueError:
                i += 1
                time.sleep(0.1)
                continue
            break

        master_addr = self.ray.get(remote_connection_info.get.remote("master_addr"))
        master_port = self.ray.get(remote_connection_info.get.remote("master_port"))
        world_size = self.ray.get(remote_connection_info.get.remote("world_size"))
        stateful_model = self.ray.get(
            remote_connection_info.get.remote("stateful_model")
        )
        self._stateful_model = stateful_model

        # Set environment variables for torch.distributed
        os.environ["MASTER_ADDR"] = master_addr
        os.environ["MASTER_PORT"] = str(master_port)

        # Initialize process group on receiver
        torch.distributed.init_process_group(
            backend=self._backend,
            rank=rank,
            world_size=world_size,
        )
        self._dist_initialized = True

        # Receive initial weights if model is stateful
        if self._stateful_model:
            return self.receive_weights(model=model, weights=weights, strategy=strategy)
        return None
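A minimal receiver-side sketch of how a worker's transport joins the process group and pulls the initial weights; ``worker_idx`` and ``model`` are placeholders, and in practice the scheme below drives this call from its receiver-side setup:

>>> transport = RayTransport(worker_idx=worker_idx, backend="gloo")
>>> transport.set_model(model)
>>> initial_weights = transport.setup_connection_and_weights_on_receiver(
...     worker_idx=worker_idx, model=model
... )
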
class RayWeightSyncScheme(WeightSyncScheme):
    """Weight synchronization for Ray distributed computing.

    This scheme uses torch.distributed to synchronize weights across distributed
    workers (Ray actors). The process group is initialized during the first
    ``synchronize_weights()`` call, with the sender as rank 0 and workers as rank
    ``worker_idx + 1``.

    Each remote collector gets its own transport, following the same pattern as
    multiprocess collectors.

    Args:
        strategy (str): The weight transmission strategy ("state_dict" or
            "tensordict"). Defaults to "tensordict".
        backend (str): The torch.distributed backend to use ("gloo" or "nccl").
            Defaults to "gloo".
    """

    @property
    def connection_info_name(self) -> str:
        """Get the name of the Ray actor storing connection info.

        Returns a unique name based on model_id to avoid collisions when
        multiple schemes are used with different models.

        Returns:
            The connection info actor name.
        """
        if self._model_id is not None:
            return f"connection_info_{self._model_id}"
        return "connection_info"

    def __init__(
        self,
        strategy: Literal["tensordict", "state_dict"] = "tensordict",
        backend: str = "gloo",
    ):
        """Initialize the RayWeightSyncScheme.

        Args:
            strategy (str): The weight transmission strategy ("state_dict" or
                "tensordict"). Defaults to "tensordict".
            backend (str): The torch.distributed backend to use ("gloo" or "nccl").
                Defaults to "gloo".
        """
        super().__init__(strategy)
        self._backend = backend
        self._dist_initialized = False
        self._remote_collectors: list | None = None
        self._num_workers: int = 0

    @property
    def model(self) -> Any | None:
        """Get the model associated with this scheme.

        Returns:
            The model if set, None otherwise.
        """
        if self._model_ref is not None:
            return self._model_ref()
        if self._model_id is not None:
            model = _resolve_model(self.context, self._model_id)
            if model is None:
                if self._model_id == "policy":
                    torchrl_logger.debug("Creating policy from factory.")
                    model = self.context.policy_factory[0]()
                    self.context.policy = model
                else:
                    raise AttributeError(
                        f"Model {self._model_id} was `None` in context {self.context}"
                    )
            self._model_ref = weakref.ref(model)
            return model

    @model.setter
    def model(self, value: Any):
        """Set the model for this scheme.

        Args:
            value: The model to set. If None, the setter is a no-op.
        """
        if value is None:
            return
        self._model_ref = weakref.ref(value)
    def create_transport(
        self,
        *,
        remote_actor=None,
        worker_idx: int | None = None,
        # Legacy parameter name for backwards compatibility
        remote_collector=None,
        **kwargs,
    ) -> TransportBackend:
        """Create Ray-based transport for a specific remote actor.

        Args:
            remote_actor: The Ray actor handle for the remote collector/transform.
            worker_idx: The worker index for this remote actor.
            remote_collector: Legacy alias for remote_actor.
            **kwargs: Additional transport configuration.

        Returns:
            RayTransport configured for this specific remote actor.
        """
        # Support legacy parameter name
        if remote_actor is None:
            remote_actor = remote_collector
        return RayTransport(
            remote_actor=remote_actor,
            worker_idx=worker_idx,
            backend=self._backend,
            connection_info_name=self.connection_info_name,
            model_id=self._model_id,
        )
    def _init_on_sender_impl(
        self,
        model_id: str,
        context: Any = None,
        **kwargs,
    ) -> None:
        """Initialize on the main process (sender side).

        This method sets up the torch.distributed connection info and shares it
        with all remote collectors so they can join the process group.

        Args:
            model_id: Identifier for the model being synchronized
            context: Optional context object providing remote_collectors
            **kwargs: Alternative to context (remote_collectors, source_model, etc.)
        """
        try:
            import ray

            self.ray = ray
        except ImportError:
            raise ImportError("Ray is required for RayWeightSyncScheme")

        # Extract parameters from context or kwargs
        if context is not None:
            remote_collectors = getattr(context, "remote_collectors", None)
            num_workers = getattr(context, "num_workers", None) or getattr(
                context, "num_collectors", None
            )
        else:
            remote_collectors = kwargs.get("remote_collectors")
            num_workers = kwargs.get("num_workers") or kwargs.get("num_collectors")

        if remote_collectors is None:
            raise ValueError("remote_collectors must be provided via context or kwargs")
        if num_workers is None:
            num_workers = len(remote_collectors) if remote_collectors else 0

        # Store model_id and context on scheme
        self.model_id = model_id

        # Store remote collectors and num_workers for synchronize_weights
        self._remote_collectors = list(remote_collectors)
        self._num_workers = int(num_workers)

        # Register each Ray actor with explicit transport kwargs
        for worker_idx, remote_collector in enumerate(remote_collectors):
            transport = self.create_transport(
                remote_actor=remote_collector,
                worker_idx=worker_idx,
            )
            self._register_worker_sender(
                worker_idx=worker_idx,
                transport=transport,
            )

        # Set context with weak reference to avoid circular refs
        if context is not None:
            self.context = context

        # Store source model reference if provided for automatic weight extraction
        model = kwargs.get("model")
        if model is not None:
            self.model = model

        # Note: Distributed connection setup is deferred to synchronize_weights
        # because _receiver_schemes on workers won't exist until register_scheme_receiver is called

    def _init_on_receiver_impl(
        self,
        model_id: str,
        context: Any = None,
        **kwargs,
    ) -> None:
        """Initialize on worker process (receiver side).

        Args:
            model_id: Identifier for the model being synchronized
            context: Optional context object (typically the remote collector)
            **kwargs: Optional parameters (worker_idx, model, etc.)
        """
        try:
            import ray

            self.ray = ray
        except ImportError:
            raise ImportError("Ray is required for RayWeightSyncScheme")

        # Store model_id and context on scheme
        self.model_id = model_id
        self.context = context

        # Extract worker_idx from context or kwargs
        if context is not None:
            worker_idx = getattr(context, "worker_idx", None)
        else:
            worker_idx = kwargs.get("worker_idx")
        self._worker_idx = worker_idx

        # Resolve the target model on this worker
        model = kwargs.get("model")
        if model is not None:
            self.model = model
        # get the weights to possibly instantiate a copy of the model (policy factory with multi-collector)
        self.weights  # noqa

        # Create and register transport for receiver side
        # Note: create_transport returns TransportBackend but we know it's RayTransport
        transport = self.create_transport(
            remote_actor=None,  # Receiver doesn't need actor handle
            worker_idx=worker_idx,
        )
        if isinstance(transport, RayTransport):
            transport.set_model(model)
        self._register_transport_receiver(transport=transport)

    def _setup_distributed_connection_sender(self, timeout: float = 300.0) -> None:
        """Set up torch.distributed connection info and share with remote collectors.

        This method:
        1. Gets master address and finds an available port
        2. Stores connection info in Ray's object store as a named actor
        3. Initializes torch.distributed process group with rank=0

        Args:
            timeout: Maximum time in seconds to wait for workers to be ready.
                Default is 300 seconds (5 minutes).
        """
        if self._dist_initialized:
            return

        if self._remote_collectors is None or self._num_workers == 0:
            raise RuntimeError(
                "_setup_distributed_connection() requires remote_collectors to be set"
            )

        # Get master address (hostname/IP)
        hostname = socket.gethostname()
        try:
            master_addr = socket.gethostbyname(hostname)
        except socket.gaierror:
            master_addr = "127.0.0.1"

        # Find an available port
        master_port = self._find_free_port()
        world_size = self._num_workers + 1  # +1 for the sender (rank 0)

        try:
            self.weights
            stateful_model = True
        except (AttributeError, RuntimeError, ValueError):
            stateful_model = False
        self._stateful_model = stateful_model

        # Connection info to share with workers via named Ray actor
        RemoteConnectionInfo = self.ray.remote(num_cpus=0)(ConnectionInfo).options(
            name=self.connection_info_name
        )
        self._connection_info_actor = RemoteConnectionInfo.remote(
            master_addr=master_addr,
            master_port=master_port,
            world_size=world_size,
            stateful_model=stateful_model,
        )

        # Set environment variables for torch.distributed
        os.environ["MASTER_ADDR"] = master_addr
        os.environ["MASTER_PORT"] = str(master_port)

        # Initialize process group on sender (rank 0)
        # Note: Workers will call init_process_group in their transport's
        # setup_connection_and_weights_on_receiver. The init_process_group is
        # a collective operation, so all ranks must call it together.
        torch.distributed.init_process_group(
            backend=self._backend,
            rank=0,
            world_size=world_size,
            timeout=_DIST_TIMEOUT,
        )
        self._dist_initialized = True

    def _setup_connection_and_weights_on_sender_impl(
        self,
        *,
        worker_idx: int | None = None,
        weights: Any | None = None,
    ) -> None:
        """Set up distributed connection and send initial weights to all workers.

        This method:
        1. Sets up torch.distributed process group (waits for workers if needed)
        2. Sends initial weights to all workers via their transports

        The distributed setup is done here (not in ``init_on_sender``) because
        workers need to have ``register_scheme_receiver`` called first.

        Args:
            worker_idx (int, optional): Not used in this implementation.
            weights (optional): Not used in this implementation (weights are
                extracted from the model).
        """
        # Set up distributed connection (with wait for workers to be ready)
        if not self._dist_initialized:
            self._setup_distributed_connection_sender()

        # Send the initial weights
        if self._stateful_model:
            self._send_weights_distributed()

    def _send_weights_distributed(self) -> None:
        """Send weights to all workers via torch.distributed.

        Raises:
            RuntimeError: If no weights are available to send.
        """
        # Extract weights from model
        weights = self.weights
        if weights is None:
            raise RuntimeError("No weights available to send")

        # Send weights to each worker (ranks 1 to num_workers)
        futures = []
        for worker_idx in range(self._num_workers):
            rank = worker_idx + 1
            futures.extend(weights.isend(dst=rank, return_early=True))

        # Wait for all sends to complete
        for future in futures:
            future.wait()

    def _setup_connection_and_weights_on_receiver_impl(
        self, *, worker_idx: int | None = None
    ) -> None:
        """Join torch.distributed process group and receive initial weights.

        Delegates to the transport's
        :meth:`~RayTransport.setup_connection_and_weights_on_receiver`.

        Args:
            worker_idx (int, optional): The worker index. If None, uses the
                stored ``_worker_idx`` or defaults to 0.
        """
        if worker_idx is None:
            worker_idx = self._worker_idx
            if worker_idx is None:
                worker_idx = 0  # Default to worker 0

        transport = self.receiver_transport
        if transport is not None:
            # Transport handles joining process group and receiving weights
            transport.setup_connection_and_weights_on_receiver(
                worker_idx=worker_idx,
                model=self.model,
                weights=self.weights,
                strategy=self._strategy,
            )
            self._dist_initialized = True

    @staticmethod
    def _find_free_port() -> int:
        """Find a free port on the local machine.

        Returns:
            int: An available port number.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            s.listen(1)
            port = s.getsockname()[1]
            return port
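A rough sender-side sketch of the sequence a Ray-based collector drives through this scheme. Calling the private ``_impl`` methods directly is a simplification (the base-class entry points normally do this), and ``remote_collectors`` and ``policy`` are placeholders; note that the second call blocks in ``init_process_group`` until every worker has joined:

>>> scheme = RayWeightSyncScheme(strategy="tensordict", backend="gloo")
>>> scheme._init_on_sender_impl("policy", remote_collectors=remote_collectors, model=policy)
>>> scheme._setup_connection_and_weights_on_sender_impl()  # rank 0 joins group, broadcasts weights
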
class RayModuleTransformScheme(RayWeightSyncScheme):
    """Weight synchronization for RayModuleTransform.

    This scheme uses torch.distributed to synchronize weights between a
    trainer/collector and a RayModuleTransform actor. The sender is rank 0,
    the transform's actor is rank 1.

    This enables updating the weights of a module running inside a
    RayModuleTransform from a parent collector or training loop.

    Args:
        strategy (str): The weight transmission strategy ("state_dict" or
            "tensordict"). Default is "tensordict".
        backend (str): The torch.distributed backend to use ("gloo" or "nccl").
            Default is "gloo".

    Example:
        >>> # Create scheme and transform
        >>> scheme = RayModuleTransformScheme()
        >>> transform = RayModuleTransform(module=my_module, weight_sync_scheme=scheme)
        >>>
        >>> # Create env with transform
        >>> env = TransformedEnv(base_env, transform)
        >>>
        >>> # Pass scheme to parent collector
        >>> collector = SomeCollector(
        ...     env, policy,
        ...     weight_sync_schemes={"transform_module": scheme}
        ... )
        >>>
        >>> # Update weights
        >>> collector.update_policy_weights_(model_id="transform_module")
    """

    def __init__(
        self,
        strategy: Literal["tensordict", "state_dict"] = "tensordict",
        backend: str = "gloo",
    ):
        """Initialize the RayModuleTransformScheme.

        Args:
            strategy (str): The weight transmission strategy ("state_dict" or
                "tensordict"). Defaults to "tensordict".
            backend (str): The torch.distributed backend to use ("gloo" or "nccl").
                Defaults to "gloo".
        """
        super().__init__(strategy, backend)
        self._ray_transform = None

    def _set_transform(self, ray_transform) -> None:
        """Store reference to the RayModuleTransform.

        Called by RayModuleTransform when the scheme is passed to it.

        Args:
            ray_transform: The RayModuleTransform instance.
        """
        self._ray_transform = ray_transform

    def _init_on_sender_impl(
        self,
        model_id: str | None = None,
        context: Any = None,
        **kwargs,
    ) -> None:
        """Initialize on the main process (sender side).

        Uses the stored transform reference (set via _set_transform) to create
        transport for the transform's actor.

        Args:
            model_id: Identifier for the model being synchronized
            context: Optional context object (typically the collector)
            **kwargs: Optional parameters (ray_transform, model, etc.)
        """
        try:
            import ray

            self.ray = ray
        except ImportError:
            raise ImportError("Ray is required for RayModuleTransformScheme")

        # Get transform reference - either stored via _set_transform or from kwargs
        ray_transform = self._ray_transform
        if ray_transform is None:
            ray_transform = kwargs.get("ray_transform")
        if ray_transform is None:
            raise ValueError(
                "ray_transform must be set via _set_transform() or provided in kwargs. "
                "Pass the scheme to RayModuleTransform constructor to set it automatically."
            )

        # Store model_id
        self.model_id = model_id

        # Single worker (the transform's actor)
        self._num_workers = 1

        # Create transport for the transform's actor
        # The actor handle is ray_transform._actor
        transport = self.create_transport(
            remote_actor=ray_transform._actor,
            worker_idx=0,
        )
        self._register_worker_sender(
            worker_idx=0,
            transport=transport,
        )

        # Set context if provided
        if context is not None:
            self.context = context

        # Store source model reference if provided for automatic weight extraction
        model = kwargs.get("model")
        if model is not None:
            self.model = model

    def _init_on_receiver_impl(
        self,
        model_id: str,
        context: Any = None,
        **kwargs,
    ) -> None:
        """Initialize on the transform's actor (receiver side).

        Args:
            model_id: Identifier for the model being synchronized
            context: The ModuleTransform instance (the actor's underlying class)
            **kwargs: Optional parameters (worker_idx, model, etc.)
        """
        try:
            import ray

            self.ray = ray
        except ImportError:
            raise ImportError("Ray is required for RayModuleTransformScheme")

        # Store model_id and context
        self.model_id = model_id
        self.context = context

        # Single transform actor is always worker_idx=0
        self._worker_idx = kwargs.get("worker_idx", 0)

        # Resolve the target model from context (ModuleTransform has a .module attribute)
        model = kwargs.get("model")
        if model is None and context is not None:
            model = getattr(context, "module", None)
        if model is not None:
            self.model = model

        # Create and register transport for receiver side
        # Note: create_transport returns TransportBackend but we know it's RayTransport
        transport = self.create_transport(
            remote_actor=None,
            worker_idx=self._worker_idx,
        )
        if isinstance(transport, RayTransport):
            transport.set_model(model)
        self._register_transport_receiver(transport=transport)

    def _setup_distributed_connection_sender(self, timeout: float = 300.0) -> None:
        """Set up torch.distributed for the single transform actor.

        Overrides parent to work with a single RayModuleTransform instead of
        multiple remote collectors.

        Args:
            timeout (float): Maximum time in seconds to wait for connection setup.
                Defaults to 300.0 (5 minutes).

        Raises:
            RuntimeError: If ``ray_transform`` is not set.
        """
        if self._dist_initialized:
            return

        if self._ray_transform is None:
            raise RuntimeError(
                "_setup_distributed_connection() requires ray_transform to be set. "
                "Did you pass the scheme to RayModuleTransform?"
            )

        # Get master address (hostname/IP)
        hostname = socket.gethostname()
        try:
            master_addr = socket.gethostbyname(hostname)
        except socket.gaierror:
            master_addr = "127.0.0.1"

        # Find an available port
        master_port = self._find_free_port()
        world_size = 2  # Sender (rank 0) + Transform (rank 1)

        # Check if model has weights
        try:
            w = self.weights
            stateful_model = w is not None
        except (AttributeError, RuntimeError, ValueError):
            stateful_model = False
        self._stateful_model = stateful_model

        # Connection info to share with the transform's actor
        RemoteConnectionInfo = self.ray.remote(num_cpus=0)(ConnectionInfo).options(
            name=self.connection_info_name
        )
        self._connection_info_actor = RemoteConnectionInfo.remote(
            master_addr=master_addr,
            master_port=master_port,
            world_size=world_size,
            stateful_model=stateful_model,
        )

        # Set environment variables for torch.distributed
        os.environ["MASTER_ADDR"] = master_addr
        os.environ["MASTER_PORT"] = str(master_port)

        # Now initialize process group on sender (rank 0)
        # The receiver is concurrently joining via the Ray call above
        torch.distributed.init_process_group(
            backend=self._backend,
            rank=0,
            world_size=world_size,
            timeout=_DIST_TIMEOUT,
        )
        self._dist_initialized = True

    def _setup_connection_and_weights_on_sender_impl(
        self,
        *,
        worker_idx: int | None = None,
        weights: Any | None = None,
    ) -> None:
        """Set up distributed connection and send initial weights.

        Args:
            worker_idx (int, optional): The worker index. Not used for
                RayModuleTransformScheme as there is only one transform actor.
            weights (optional): Pre-extracted weights to send. If None, weights
                are extracted from the model.
        """
        receiver_future = self._ray_transform._actor._init_weight_sync_scheme.remote(
            scheme=self, model_id=self.model_id
        )
        if not self._dist_initialized:
            self._setup_distributed_connection_sender()
        if self._stateful_model:
            self._send_weights_distributed(weights=weights)
        self.ray.get(receiver_future)

    def _send_weights_distributed(self, weights: Any | None = None) -> None:
        """Send weights to the transform actor via torch.distributed.

        Args:
            weights (optional): Pre-extracted weights to send. If None, weights
                are extracted from the model via :attr:`weights`.

        Raises:
            RuntimeError: If no weights are available to send.
        """
        if weights is None:
            weights = self.weights
        if weights is None:
            raise RuntimeError("No weights available to send")

        # Send weights to the transform (rank 1)
        futures = weights.isend(dst=1, return_early=True)
        for future in futures:
            future.wait()
