SharedMemTransport
- class torchrl.weight_update.SharedMemTransport
Shared memory transport for in-place weight updates.
This transport distributes shared memory buffers through queues during initialization, then updates those shared memory tensors in place for all subsequent weight updates. Because workers hold references to the same shared memory, they see each update without any further communication.
Initialization flow:
- Shared memory buffers are created and sent to workers via per-worker queues.
- Workers receive the buffer reference and apply the weights to their models.
- Subsequent updates are pure in-place shared memory writes (zero-copy).
Both CPU and CUDA tensors maintain shared references when sent through mp.Queue.
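The flow above can be sketched in plain PyTorch within a single process. This is an illustration of the pattern, not torchrl's implementation: `queue.Queue` stands in for the per-worker `mp.Queue` objects, so here each worker simply gets back the same dict object. Across real processes, `torch.multiprocessing` queues send a shared tensor as a handle to the same shared storage, which is what lets separate workers observe the in-place write.

```python
import queue

import torch

# Sender: move the initial weights into shared memory once.
initial = {"weight": torch.zeros(2, 2), "bias": torch.zeros(2)}
shared_buffer = {k: v.clone().share_memory_() for k, v in initial.items()}

# One queue per worker index; queue.Queue stands in for mp.Queue here.
num_workers = 2
init_queues = {i: queue.Queue() for i in range(num_workers)}
for q in init_queues.values():
    q.put(shared_buffer)

# Workers: each one takes the buffer reference from its own queue.
worker_views = {i: init_queues[i].get(timeout=1.0) for i in range(num_workers)}

# Later update: write new values into the existing shared tensors in place.
ptr_before = shared_buffer["weight"].data_ptr()
new_weights = {"weight": torch.ones(2, 2), "bias": torch.full((2,), 0.5)}
for k, v in new_weights.items():
    shared_buffer[k].copy_(v)

# No further message was sent, yet every worker's reference holds the new values,
# and the shared storage was not reallocated.
print(torch.equal(worker_views[1]["weight"], torch.ones(2, 2)))  # True
print(shared_buffer["weight"].data_ptr() == ptr_before)  # True
```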
- receive_weights(timeout: float | None = None, *, weights: Any = None, model: Any = None, strategy: Any = None) → Any | None
Apply shared memory weights to the model.
For shared memory, the weights are already available (passed via the weights argument). This method applies them to the model, matching the pattern of the other transports.
- Parameters:
timeout – Ignored (shared memory access is instant).
weights – The shared memory buffer containing current weights.
model – The model to apply weights to.
strategy – Strategy for applying weights.
- Returns:
The applied weights, or None if not applied.
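One way to "apply" shared weights so that later in-place writes reach the model without another copy is to wrap each shared tensor in `nn.Parameter`, which aliases the tensor's storage instead of copying it. The helper `apply_shared_weights` below is hypothetical and only mirrors the documented contract (return the applied weights, or None when nothing is applied); it does not model the `strategy` argument or torchrl's actual mechanism.

```python
import torch
from torch import nn


def apply_shared_weights(weights, model):
    """Hypothetical helper: point ``model``'s parameters at shared-memory tensors.

    Returns the applied weights, or None when there is nothing to apply.
    """
    if weights is None or model is None:
        return None
    for name, tensor in weights.items():
        prefix, _, attr = name.rpartition(".")
        owner = model.get_submodule(prefix)  # "" resolves to the model itself
        # nn.Parameter wraps the tensor without copying, so the model now
        # reads directly from the shared storage.
        setattr(owner, attr, nn.Parameter(tensor, requires_grad=False))
    return weights


model = nn.Sequential(nn.Linear(2, 1))
# Shared buffer built from the model's current parameters (names like "0.weight").
shared = {n: p.detach().clone().share_memory_() for n, p in model.named_parameters()}
applied = apply_shared_weights(shared, model)

# A later in-place update on the shared buffer is visible to the model immediately.
shared["0.weight"].copy_(torch.ones(1, 2))
shared["0.bias"].zero_()
print(model(torch.tensor([[1.0, 2.0]])).item())  # 3.0
```

Copying values into the model's existing parameters would also work, but the copy would then have to be repeated after every update, which gives up the zero-copy property described above.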
- register_weights(params_map: dict[int, multiprocessing.context.BaseContext.Queue], init_queues: dict[int, multiprocessing.context.BaseContext.Queue]) → None
Initialize per-worker queues for shared memory buffer distribution.
- send_ack(message: str = 'updated') → None
No-op for shared memory; no acknowledgment is needed.
- send_weights(weights: Any) → None
Update weights in-place in shared memory.
- Parameters:
weights – New weights to send. Can be a TensorDictBase or dict.
- Raises:
ValueError – If the weights type is unsupported.
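The in-place update and its type check can be sketched with plain dicts of shared tensors. `send_weights_inplace` is a hypothetical helper, not torchrl's code; it leaves out `TensorDictBase` inputs, which the real method also accepts, and simply raises `ValueError` for anything that is not a dict.

```python
import torch


def send_weights_inplace(shared_buffer, weights):
    """Hypothetical sketch: copy ``weights`` into an existing shared buffer in place."""
    if not isinstance(weights, dict):
        # The documented method also accepts TensorDictBase; this sketch only handles dicts.
        raise ValueError(f"Unsupported weights type: {type(weights).__name__}")
    for key, value in weights.items():
        # copy_ writes into the existing storage, so anything holding a
        # reference to shared_buffer[key] sees the new values.
        shared_buffer[key].copy_(value)


shared_buffer = {"w": torch.zeros(3).share_memory_()}
ptr = shared_buffer["w"].data_ptr()
send_weights_inplace(shared_buffer, {"w": torch.arange(3.0)})
print(shared_buffer["w"].tolist(), shared_buffer["w"].data_ptr() == ptr)  # [0.0, 1.0, 2.0] True
```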
- setup_connection_and_weights_on_receiver(*, worker_idx: int | None = None, weights: Any = None, model: Any = None, strategy: Any = None, timeout: float = 10.0) → TensorDictBase
Receive the shared memory buffer reference from the sender via this worker's dedicated queue.
Each worker reads from its own dedicated queue to avoid race conditions.
- Parameters:
worker_idx – The worker index.
weights – Ignored (weights come from queue).
model – Ignored.
strategy – Ignored.
timeout – Timeout for reading from the queue.
- Returns:
The shared memory weights TensorDict.
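The receiver side reduces to a blocking `get` on the queue indexed by `worker_idx`. With a single shared queue, whichever worker called `get` first would take the next message, possibly one meant for another worker; indexing by worker avoids that. `receive_shared_buffer` is a hypothetical helper and `queue.Queue` stands in for `mp.Queue`; both raise `queue.Empty` when the timeout expires.

```python
import queue

import torch


def receive_shared_buffer(init_queues, worker_idx, timeout=10.0):
    """Hypothetical sketch: fetch the shared buffer from this worker's own queue."""
    # Reading only from init_queues[worker_idx] means workers never compete
    # for the same message. Raises queue.Empty if nothing arrives in time.
    return init_queues[worker_idx].get(timeout=timeout)


init_queues = {0: queue.Queue(), 1: queue.Queue()}
buffer = {"w": torch.zeros(2).share_memory_()}
init_queues[1].put(buffer)

received = receive_shared_buffer(init_queues, worker_idx=1)
print(received is buffer)  # True (same process, so the same object comes back)
try:
    receive_shared_buffer(init_queues, worker_idx=0, timeout=0.1)
except queue.Empty:
    print("worker 0: no buffer yet")
```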
- setup_connection_and_weights_on_sender() → None
Send the shared memory buffer reference to each worker via its dedicated queue.
Both CPU and CUDA tensors keep their shared references when sent through queues. Each worker reads from its own dedicated queue to avoid race conditions.
- property unique_weights: list[tensordict.base.TensorDictBase]
Get the unique weights.
- Returns:
The unique weights.