
SharedMemWeightSyncScheme

class torchrl.weight_update.SharedMemWeightSyncScheme(strategy: str = 'tensordict', sync: bool = True)[source]

Weight synchronization using shared memory.

This scheme uses shared memory for in-place weight updates. Workers automatically see weight updates without explicit message passing.

A background thread on the receiver side listens for “receive” instructions from the sender. When an instruction arrives, the thread applies the current shared memory weights to the model and sends an acknowledgment.

Parameters:
  • strategy – The weight transmission strategy (default: “tensordict”).

  • sync – If True (default), send() blocks until receiver acknowledges. If False, send() returns immediately (use send_async/wait_async).

Example

>>> # Basic usage
>>> scheme = SharedMemWeightSyncScheme()
>>> # Weights are initialized via init_on_sender()
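A more complete usage sketch is shown below, assuming a TorchRL multiprocess collector that accepts a weight_sync_schemes mapping; the env_maker and policy placeholders, the constructor arguments, and the weight_sync_schemes keyword are assumptions and may differ across versions.

>>> from torchrl.collectors import MultiSyncDataCollector
>>> from torchrl.weight_update import SharedMemWeightSyncScheme
>>> scheme = SharedMemWeightSyncScheme(strategy="tensordict", sync=True)
>>> # The collector is expected to call init_on_sender()/init_on_receiver()
>>> # and connect() internally once its worker processes are up.
>>> collector = MultiSyncDataCollector(
...     [env_maker, env_maker],                   # placeholder env constructors
...     policy,                                   # placeholder policy module
...     frames_per_batch=64,
...     total_frames=1024,
...     weight_sync_schemes={"policy": scheme},   # assumed keyword argument
... )
>>> for batch in collector:
...     ...                                       # optimize the policy on `batch`
...     collector.update_policy_weights_()        # pushes new weights through the scheme
>>> collector.shutdown()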
apply_weights(weights: TensorDictBase, inplace: bool = True) → None

Apply weights to the model.

Parameters:
  • weights – The weights to apply.

  • inplace – Whether to apply weights in place. Default is True.
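
A minimal sketch of applying a weight TensorDict on the receiver side, assuming the scheme has already been initialized with the target model; the nn.Linear module and zeroed weights are purely illustrative.

>>> from tensordict import TensorDict
>>> from torch import nn
>>> model = nn.Linear(4, 2)
>>> new_weights = TensorDict.from_module(model).clone().zero_()
>>> # Assuming `scheme` was initialized with this model (e.g. via init_on_receiver):
>>> scheme.apply_weights(new_weights, inplace=True)  # writes the zeros into the model's parameters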

connect(*, worker_idx: int | None = None, weights: Any | None = None) → None

Method to be called once the workers have started.

Triggers a rendezvous for the workers to receive their copy of the weights.

Dispatches to _setup_connection_and_weights_on_sender_impl() or _setup_connection_and_weights_on_receiver_impl() based on which initialization was performed.
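
A hedged sketch of the rendezvous; in normal use the collector issues these calls rather than user code, and the exact call sites are an assumption.

>>> # Main process (sender side), once all workers are running:
>>> scheme.connect()
>>> # Worker process i (receiver side):
>>> scheme.connect(worker_idx=i)  # receives the shared-memory copy of the weights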

property context: Any | None

Get the context object (e.g., collector), if available.

Returns:

The context object if available, None otherwise.

create_transport(**kwargs) → TransportBackend[source]

Create shared memory transport.

Returns the shared transport instance that all workers will use. Since this is shared memory, there’s only one transport shared by all workers.

Note

This is used internally by init_on_sender/init_on_receiver.

init_on_receiver(*, model_id: str, context: Any = None, **kwargs) → None

Initialize on worker process (receiver side).

This method is called once in each worker’s initialization.

Parameters:
  • model_id – Identifier for the model being synchronized

  • context – Optional context object (e.g., inner collector)

  • **kwargs – Alternative to context (model, etc.)
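
Illustrative worker-side initialization; the model keyword is one plausible spelling of the **kwargs alternative mentioned above, and worker_policy / inner_collector are placeholders.

>>> # Inside the worker process, during its initialization:
>>> scheme.init_on_receiver(model_id="policy", model=worker_policy)
>>> # or, when an inner collector is available as context:
>>> scheme.init_on_receiver(model_id="policy", context=inner_collector)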

init_on_sender(*args, **kwargs) → None

Initialize on the main process (sender side).

This method is called once in the collector’s _run_processes() method, after workers have been started and are ready to receive messages.
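
Sketch of the sender-side initialization as a collector would perform it; the keyword names mirror init_on_receiver and are assumptions.

>>> # Main process, after the workers have been started:
>>> scheme.init_on_sender(model_id="policy", context=collector)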

property model: Any | None

Get the model object, if available.

Returns:

The model object if available, None otherwise.

property model_id: str | None

Get the model ID for this scheme.

Returns:

The model ID if set, None otherwise.

prepare_weights(weights: Any, model_id: str, strategy: WeightStrategy, context: Any = None) → Any[source]

Prepare weights for SharedMemWeightSyncScheme.

When weights=None, we extract fresh weights from the model and update the shared memory buffer in-place so workers see the change.

Parameters:
  • weights – Raw weights input

  • model_id – The model identifier

  • strategy – WeightStrategy for extracting/converting weights

  • context – Optional context (e.g., collector) for cache lookup

Returns:

Shared memory weights ready to send
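
A hedged sketch of the call that send() makes internally; calling it directly is rarely needed, and strategy / collector are placeholders for the scheme's WeightStrategy instance and its context.

>>> shared_weights = scheme.prepare_weights(
...     weights=None,        # None: extract fresh weights from the model
...     model_id="policy",
...     strategy=strategy,   # placeholder WeightStrategy instance
...     context=collector,   # optional, used for cache lookup
... )
>>> # The shared-memory buffer has been refreshed in place and can now be sent.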

receive(timeout: float | None = None) → tensordict.base.TensorDictBase | None

Check for and apply new weights (non-blocking).

This method is called in the worker’s main loop to check if new weights have been sent. If weights are available, they are applied to the registered model immediately, and the update is cascaded to any sub-collectors via context.update_policy_weights_().

Parameters:
  • timeout – Maximum time to wait for weights (seconds). None means no timeout (blocking). Some transports may not support timeout and will raise ValueError if specified.

Returns:

The received weights if available, None otherwise.

Note: For SharedMemWeightSyncScheme, this always returns None since workers automatically see updates via shared memory.
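
Sketch of the worker-side poll; for this scheme the call returns None because the shared buffer is already current.

>>> updated = scheme.receive()  # check for and apply new weights
>>> updated is None             # always True for SharedMemWeightSyncScheme
True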

property receiver_transport: torchrl.weight_update.weight_sync_schemes.TransportBackend | None

Get the receiver transport.

Returns:

The receiver transport.

send(weights: Any = None, worker_ids: int | list[int] | None = None) → None[source]

Send weights via shared memory (in-place update).

For SharedMemWeightSyncScheme:
  1. prepare_weights() updates the shared memory buffer in-place.
  2. _send_instruction() tells workers to apply the new weights.
  3. If sync=True, waits for acknowledgments from all workers.

Parameters:
  • weights – Weights to send (can be None to extract from model).

  • worker_ids – Which workers to notify (None = all workers).
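
Two typical sender-side calls following the steps above; new_weights is a placeholder weight TensorDict.

>>> # Extract fresh weights from the model, refresh shared memory, notify all workers:
>>> scheme.send()
>>> # Push an explicit weight TensorDict to a subset of workers:
>>> scheme.send(weights=new_weights, worker_ids=[0, 2])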

property sender_transports: dict[int, torchrl.weight_update.weight_sync_schemes.TransportBackend]

Get the sender transports.

Returns:

The sender transports.

property shared_transport: torchrl.weight_update.weight_sync_schemes.TransportBackend | None

Get the shared transport.

Returns:

The shared transport.

shutdown() → None[source]

Stop the background receiver thread and clean up.
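
Typically paired with collector teardown, for example:

>>> try:
...     scheme.send()
... finally:
...     scheme.shutdown()  # stop the background receiver thread and clean up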

property weights: Any | None

Get the current weights from shared memory.

For SharedMemWeightSyncScheme:
  • On the sender side, weights are in the transport's _unique_weights.
  • On the receiver side, weights are in _receiver_shared_weights (stored during connect()).

Returns:

The weights TensorDict if available, None otherwise.
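
Illustrative read of the shared buffer; once connect() has run, both sides see the same TensorDict.

>>> shared = scheme.weights
>>> if shared is not None:
...     print(shared.keys())  # parameter names, backed by shared memory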

property worker_idx: int | None

Get the worker index for this scheme.

Returns:

The worker index if set, None otherwise.
