Source code for torchrl.weight_update._mp
from __future__ import annotations
from collections.abc import Callable
from typing import Any
import torch
from tensordict import TensorDictBase
from torch import multiprocessing as mp, nn
from torchrl.weight_update._shared import SharedMemWeightSyncScheme
from torchrl.weight_update.utils import _resolve_model
from torchrl.weight_update.weight_sync_schemes import TransportBackend
[docs]class MultiProcessWeightSyncScheme(SharedMemWeightSyncScheme):
"""Weight synchronization for multiprocess operations using queues.
This scheme creates transports that communicate via multiprocessing queues.
Unlike the parent SharedMemWeightSyncScheme which uses shared memory for in-place
updates, this scheme sends actual weight copies through queues to workers.
A background thread on the receiver side listens for "receive" instructions
from the sender. When an instruction arrives, the thread receives the weights
from the weight queue and applies them to the model.
It follows the same two-phase pattern as SharedMemWeightSyncScheme:
1. **init_on_sender()**: Stores the recipe for creating device-specific weights
(model reference, devices, mapping functions) without creating actual copies
2. **synchronize_weights()**: Creates device-specific weight copies on-demand,
sends them sequentially to workers via queues, allowing garbage collection
between workers to minimize memory usage
This approach avoids holding multiple weight copies in memory simultaneously,
which is especially beneficial for large models with many workers.
Synchronization flow:
- **init_on_sender()**: Store configuration and register worker queues
- **synchronize_weights()**: Create and send initial weights on-demand
- **init_on_receiver()**: Create receiver that reads from queue
- **send()**: Extract and send weight updates, wait for acknowledgments
Args:
strategy: The weight transmission strategy (default: "tensordict").
Can be "tensordict" or "state_dict".
sync: If True (default), send() blocks until receiver acknowledges.
If False, send() returns immediately (use send_async/wait_async).
Example:
>>> # Basic usage with collector
>>> scheme = MultiProcessWeightSyncScheme()
>>> collector = MultiSyncCollector(
... create_env_fn=[lambda: GymEnv("CartPole-v1")] * 3,
... policy=policy,
... frames_per_batch=100,
... total_frames=1000,
... weight_sync_schemes={"policy": scheme},
... )
>>> # scheme.collect() is called automatically by collector
>>> # Weights are created on-demand and sent to workers efficiently
Note:
The on-demand weight creation means that synchronize_weights() will be
slower than if weights were pre-computed, but memory usage is significantly
reduced, especially when workers use different devices or when the model
is large.
"""
def __init__(self, strategy: str = "tensordict", sync: bool = True):
"""Initialize the MultiProcessWeightSyncScheme.
Args:
strategy: The weight transmission strategy (default: "tensordict").
sync: If True (default), send() blocks until receiver acknowledges.
"""
super().__init__(strategy, sync=sync)
# Override parent's shared transport - we don't use shared memory
self._shared_transport = None
def _init_on_sender_impl(
self,
*,
model_id: str | None = None,
context: Any = None,
weights: TensorDictBase | None = None,
model: nn.Module | None = None,
params_map: dict[int, TensorDictBase] | None = None,
devices: list[torch.device] | None = None,
device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None,
num_workers: int | None = None,
ctx: Any = None,
**kwargs,
) -> None:
"""Initialize on the main process (sender side).
This method stores the configuration needed to create device-specific weight
copies during synchronization. Weight copies are created on-demand during
`synchronize_weights()` to reduce memory usage.
Similar to `SharedMemWeightSyncScheme`, this follows a two-phase pattern:
1. `init_on_sender()`: Store the recipe for creating weights
2. `synchronize_weights()`: Create and send weights on-demand
Args:
model_id: Identifier for the model being synchronized (e.g., "policy").
Required when using context.
context: Optional context object (e.g., collector) providing:
- num_workers: Number of worker processes
- policy_device: List of devices for each worker
When provided, model_id is used to resolve the model from context.
weights: Pre-extracted weights as TensorDict. Mutually exclusive with
model and context. Used when weights are already available.
model: Model to extract weights from. Mutually exclusive with weights
and context.
params_map: Pre-computed mapping of worker_idx to device-specific weights.
Most explicit option. When provided, all other parameters must be None.
devices: List of devices for each worker. Used with weights or model to
automatically create device-specific copies. Length must equal num_workers.
device_map_fn: Custom function (worker_idx, weights) -> device_weights.
Allows full control over device mapping. Requires num_workers.
num_workers: Number of workers. Required with device_map_fn, inferred
from devices length otherwise.
ctx: The multiprocessing context to use. Defaults to `multiprocessing.get_context()`.
**kwargs: Reserved for future use.
Examples:
Simple usage with collector context (most common):
>>> scheme = MultiProcessWeightSyncScheme()
>>> collector = MultiSyncCollector(
... create_env_fn=[lambda: GymEnv("CartPole-v1")] * 3,
... policy=policy,
... frames_per_batch=100,
... weight_sync_schemes={"policy": scheme},
... )
>>> # scheme.init_on_sender() is called automatically by collector
Direct initialization with explicit devices:
>>> scheme = MultiProcessWeightSyncScheme()
>>> weights = TensorDict.from_module(policy)
>>> scheme.init_on_sender(
... weights=weights,
... devices=[torch.device("cpu"), torch.device("cuda:0")],
... num_workers=2,
... )
Advanced: Pre-computed params_map:
>>> weights_cpu = TensorDict.from_module(policy)
>>> weights_cuda = weights_cpu.to("cuda")
>>> scheme.init_on_sender(
... params_map={0: weights_cpu, 1: weights_cuda, 2: weights_cuda},
... num_workers=3,
... )
"""
# Get params_map from parent class logic
params_map_result = self._get_params_map(
context=context,
model_id=model_id,
weights=weights,
model=model,
params_map=params_map,
devices=devices,
device_map_fn=device_map_fn,
num_workers=num_workers,
)
# Store the mapping recipe for later use in synchronize_weights
# Don't store params_map directly to save memory - we'll recompute on demand
# Note: We don't store context directly to avoid pickle issues -
# it's available via _context_ref
self._device_mapping_info = {
"model_id": model_id,
"weights": weights,
"model": model,
"params_map": params_map,
"devices": devices,
"device_map_fn": device_map_fn,
"num_workers": num_workers
if num_workers is not None
else len(params_map_result),
}
# Create per-worker queues for weight distribution
# Each worker gets its own queue for receiving weights
all_workers = list(params_map_result.keys())
if not hasattr(self, "_weight_init_queues"):
self._weight_init_queues = {}
if ctx is None:
ctx = mp.get_context()
for worker_idx in all_workers:
if worker_idx not in self._weight_init_queues:
self._weight_init_queues[worker_idx] = ctx.Queue()
# Create instruction queues for background receiver
if worker_idx not in self._instruction_queues:
self._instruction_queues[worker_idx] = ctx.Queue()
# Create ack queues for synchronous mode
if worker_idx not in self._ack_queues:
self._ack_queues[worker_idx] = ctx.Queue()
# Store model_id and context on scheme
self.model_id = model_id
if context is not None:
self.context = context
# Register workers with their queues
for worker_idx in all_workers:
queue = self._weight_init_queues[worker_idx]
ack_queue = self._ack_queues[worker_idx]
# Create MPTransport for this worker with ack queue
transport = MPTransport(weight_queue=queue, ack_queue=ack_queue)
self._register_worker_sender(worker_idx=worker_idx, transport=transport)
def _init_on_receiver_impl(
self,
*,
model_id: str,
context: Any = None,
**kwargs,
) -> None:
"""Initialize on worker process (receiver side).
Args:
model_id: Identifier for the model being synchronized
context: Optional context object providing worker_idx and model
**kwargs: Alternative to context (worker_idx, model, etc.)
"""
# Extract parameters from context or kwargs
if context is not None:
worker_idx = getattr(context, "worker_idx", None)
if hasattr(context, "get_model"):
model = context.get_model(model_id)
else:
model = _resolve_model(context, model_id)
else:
worker_idx = kwargs.get("worker_idx")
model = kwargs.get("model")
if worker_idx is None:
raise ValueError("worker_idx must be provided via context or kwargs")
# Get the queue for this worker
if worker_idx not in self._weight_init_queues:
raise ValueError(
f"Worker {worker_idx} not registered. init_on_sender() must be called first."
)
queue = self._weight_init_queues[worker_idx]
ack_queue = self._ack_queues.get(worker_idx)
# Store on scheme directly
self.model_id = model_id
if context is not None:
self.context = context
# Store instruction and ack queue references for this worker
if worker_idx in self._instruction_queues:
self._receiver_instruction_queue = self._instruction_queues[worker_idx]
if worker_idx in self._ack_queues:
self._receiver_ack_queue = self._ack_queues[worker_idx]
# Create transport with the worker's queue and ack queue
transport = MPTransport(weight_queue=queue, ack_queue=ack_queue)
self._register_transport_receiver(transport=transport)
if model is not None:
self.model = model
# Store worker_idx for synchronize_weights
self.worker_idx = worker_idx
[docs] def send(
self,
weights: Any = None,
worker_ids: int | list[int] | None = None,
) -> None:
"""Send weights synchronously to workers.
This method:
1. Prepares weights (extracts from model if weights=None)
2. Sends weights to the weight queue
3. Sends "receive" instruction to workers' background threads
4. If sync=True, waits for acknowledgments from those workers
Args:
weights: Weights to send. Can be:
- None: Extract from model via context.get_model(model_id)
- nn.Module: Extract weights from module
- TensorDict: Use directly
- dict: Convert to TensorDict
worker_ids: Which workers to send to:
- None: Send to all workers (default)
- int: Send to single worker
- list[int]: Send to specific workers
Note: If sync=True (default), this is a blocking call that ensures
specified workers are updated before returning.
"""
if not self.initialized_on_sender:
raise RuntimeError("Must be initialized on sender before sending weights")
if not self.synchronized_on_sender:
raise RuntimeError("Must be synchronized on sender before sending weights")
model_id = self.model_id
context = self.context
# Let the scheme prepare the weights
prepared_weights = self.prepare_weights(
weights=weights,
model_id=model_id,
strategy=self._strategy,
context=context,
)
transports = list(self._iterate_transports(worker_ids))
# Send weights to all workers first via queue (non-blocking)
for transport in transports:
if hasattr(transport, "send_weights_async"):
# For MPTransport, pass model_id; other transports don't need it
transport.send_weights_async(prepared_weights, model_id=model_id)
else:
# Fallback for transports that don't support async send
transport.send_weights(prepared_weights)
# Send instruction to workers' background threads to receive the weights
self._send_instruction(instruction="receive", worker_ids=worker_ids)
# Wait for all acknowledgments if in synchronous mode
if self.sync:
self._wait_for_ack(worker_ids=worker_ids)
def _setup_connection_and_weights_on_sender_impl(
self,
*,
worker_idx: int | None = None,
weights: Any | None = None,
) -> None:
"""Synchronize weights with workers before collection starts.
Computes device-specific weight copies on-demand and sends them to workers
sequentially via queues. This is called once after workers are initialized
but before they start collecting data.
Unlike send(), this does not wait for acknowledgments since workers are still
in their initialization phase.
This approach creates weight copies on-demand and sends them sequentially,
allowing garbage collection between workers to reduce memory usage.
Raises:
RuntimeError: If init_on_sender() was not called first.
"""
# Get the device mapping info stored during init_on_sender
if not hasattr(self, "_device_mapping_info"):
raise RuntimeError(
"synchronize_weights() requires init_on_sender() to be called first"
)
mapping_info = self._device_mapping_info
# Get context from weakref
context = self.context
# Compute params_map on-demand
# Extract with explicit type casting for type checker
model_id = mapping_info["model_id"]
weights = mapping_info["weights"]
model = mapping_info["model"]
params_map_arg = mapping_info["params_map"]
devices = mapping_info["devices"]
device_map_fn = mapping_info["device_map_fn"]
num_workers = mapping_info["num_workers"]
params_map = self._get_params_map(
context=context,
model_id=model_id,
weights=weights,
model=model,
params_map=params_map_arg,
devices=devices,
device_map_fn=device_map_fn,
num_workers=num_workers,
)
# Send to workers sequentially via queues (no ACK - workers are still initializing)
# This allows GC to clean up each worker's weights before creating the next
for i, transport in enumerate(self._iterate_transports()):
if worker_idx is not None and i != worker_idx:
continue
worker_weights = params_map[i]
if hasattr(transport, "send_weights_async"):
transport.send_weights_async(worker_weights, model_id=self._model_id)
else:
raise RuntimeError(
f"Transport {type(transport)} does not support async send for synchronization"
)
# Clean up the mapping info after synchronization
delattr(self, "_device_mapping_info")
def _setup_connection_and_weights_on_receiver_impl(
self, *, worker_idx: int | None = None
) -> None:
"""Receive initial weights and start background receiver thread.
This method:
1. Receives initial weights from the sender via queue
2. Applies them to the model
3. Starts a background thread that listens for "receive" instructions
Args:
worker_idx: The worker index.
"""
# Use stored worker_idx if not provided
if worker_idx is None:
worker_idx = self._worker_idx
if worker_idx is None:
raise RuntimeError(
"worker_idx must be provided for _setup_connection_and_weights_on_receiver_impl."
)
# Receive initial weights from queue via transport
if self._receiver_transport is None:
raise RuntimeError("Receiver transport not set.")
weights = self._receiver_transport.setup_connection_and_weights_on_receiver(
worker_idx=worker_idx,
weights=self.weights,
model=self.model,
strategy=self._strategy,
)
# Store received weights for later use
if weights is not None:
self._receiver_weights = weights
# Apply weights to model
if weights is not None and self.model is not None:
self._strategy.apply_weights(self.model, weights, inplace=False)
# Start background receiver thread
self._start_background_receiver()
def _background_receive_loop(self):
"""Background thread loop that waits for instructions and receives weights.
This loop:
1. Waits for a "receive" instruction from the sender
2. Receives weights from the weight queue
3. Applies them to the model
4. Sends an acknowledgment back to the sender
5. Repeats until stop event is set or "stop" instruction received
"""
from torchrl._utils import logger as torchrl_logger
while not self._stop_event.is_set():
try:
instruction = self._wait_for_instruction()
if instruction is None:
# Stop event was set or timeout
continue
if instruction == "receive":
# Receive weights from transport (blocking)
if self._receiver_transport is not None:
weights = self._receiver_transport.receive_weights(
model=self.model,
strategy=self._strategy,
)
if weights is not None:
# Cascade weight update to sub-collectors if context supports it
model_id = self._model_id or "policy"
if self.context is not None and hasattr(
self.context, "update_policy_weights_"
):
self.context.update_policy_weights_(
model_id=model_id, policy_or_weights=weights
)
# Send acknowledgment
self._send_ack("updated")
elif instruction == "stop":
break
else:
torchrl_logger.warning(
f"MultiProcessWeightSyncScheme: Unknown instruction: {instruction}"
)
except Exception as e:
if not self._stop_event.is_set():
torchrl_logger.warning(
f"MultiProcessWeightSyncScheme: Background receiver error: {e}"
)
[docs] def create_transport(self, **kwargs) -> TransportBackend:
"""Create an MPTransport using the provided queue.
Note:
This is used internally by init_on_sender/init_on_receiver.
"""
queue = kwargs.get("queue")
return MPTransport(weight_queue=queue, ack_queue=None)
[docs]class MPTransport:
"""Multiprocessing transport using queues.
This transport uses queues for weight distribution and synchronization.
Similar to SharedMemTransport's queue-based approach, MPTransport uses
queues to send initial weights to workers during synchronization.
Initialization flow:
- synchronize_weights() extracts weights and sends to all workers via queues
- Workers receive the initial weights via setup_connection_and_weights_on_receiver()
- Subsequent updates use send_weights_async() followed by acknowledgments
Args:
weight_queue (mp.Queue): The queue to use for sending weights.
ack_queue (mp.Queue): The queue to use for receiving acknowledgments.
timeout (float): The timeout for waiting for acknowledgment. Default is 10 seconds.
"""
def __init__(self, weight_queue, ack_queue=None, timeout: float = 10.0):
self.timeout = timeout
self.weight_queue = weight_queue
self.ack_queue = ack_queue
[docs] def send_weights_async(self, weights: Any, model_id: str = "policy") -> None:
"""Send weights through the queue without waiting for acknowledgment.
Use wait_ack() to wait for acknowledgment after sending to all workers.
"""
# Send in format expected by worker loop: ((model_id, weights), "update_weights")
self.weight_queue.put(((model_id, weights), "update_weights"))
[docs] def receive_weights(
self,
timeout: float | None = None,
*,
weights: Any = None,
model: Any = None,
strategy: Any = None,
) -> Any | None:
"""Receive weights from the queue (used in worker process).
This method only handles weight update messages. Other messages
(like "close", "continue", etc.) are ignored and should be handled
by the main worker loop.
Args:
timeout: Maximum time to wait for weights (seconds).
None means use the transport's default timeout.
weights: Ignored (weights come from queue).
model: The model to apply weights to.
strategy: Strategy for applying weights to the model.
Returns:
The received weights, or None if no data available.
"""
# Use transport's default timeout if not specified
if timeout is None:
timeout = self.timeout
data_in, msg = self.weight_queue.get(timeout=timeout)
if msg == "update_weights":
# data_in is (model_id, weights) - we ignore model_id, scheme knows it
_model_id, received_weights = data_in
# Apply weights to model if provided
if model is not None and strategy is not None:
strategy.apply_weights(model, received_weights)
return received_weights
else:
raise ValueError(f"Expected 'update_weights' but got {msg}")
[docs] def setup_connection_and_weights_on_sender(self) -> None:
"""No-op for MPTransport - weights are sent via scheme's synchronize_weights().
The actual sending happens in MultiProcessWeightSyncScheme._setup_connection_and_weights_on_sender_impl(), which:
1. Extracts weights from the context (e.g., collector.policy)
2. Calls send_weights_async() on all worker transports
3. Sends initial weights through queues to all workers
This is similar to SharedMemTransport.setup_connection_and_weights_on_sender() which
sends shared memory buffer references via queues.
"""
[docs] def setup_connection_and_weights_on_receiver(
self,
*,
worker_idx: int,
weights: Any = None,
model: Any = None,
strategy: Any = None,
) -> Any:
"""Receive initial weights from sender during worker initialization.
This method blocks waiting for the initial weights to be sent from the main process
via queue. Similar to SharedMemTransport.setup_connection_and_weights_on_receiver() which receives
shared memory buffer references via queues, this receives the actual weights via queues.
The received weights are then applied to the worker's model by the scheme's synchronize_weights().
Args:
worker_idx: The worker index (used for logging/debugging).
weights: Ignored (weights come from queue).
model: Ignored.
strategy: Ignored.
Returns:
The received weights if available, None otherwise (weights will come later via receive()).
"""
# Wait for initial weights (blocking)
data_in, msg = self.weight_queue.get(timeout=self.timeout)
if msg == "update_weights":
# data_in is (model_id, weights), extract just the weights
_, received_weights = data_in
return received_weights
else:
raise ValueError(f"Expected 'update_weights' but got {msg}")