SharedMemWeightSyncScheme

class torchrl.weight_update.SharedMemWeightSyncScheme(policy_weights: dict[str, tensordict.base.TensorDictBase] | None = None, strategy: str = 'tensordict', auto_register: bool = True)[source]

Weight synchronization using shared memory.

This scheme mimics the old WeightUpdater behavior by using shared memory for in-place weight updates. Workers automatically see weight updates without explicit message passing.

By default, this scheme uses lazy registration: models are automatically registered on the first weight send. This makes it seamless to use with configuration systems like Hydra where schemes are created before models are available.

Parameters:
  • policy_weights – Dictionary mapping model_id to shared TensorDict weights. Can be empty if using lazy registration (auto_register=True).

  • strategy – The weight transmission strategy (default: “tensordict”).

  • auto_register – Whether to automatically register models on first weight send. Default is True. Set to False to require explicit registration via register_shared_weights().

Example

>>> from tensordict import TensorDict
>>> from torchrl.weight_update import SharedMemWeightSyncScheme
>>> # With auto-registration (default) - works with Hydra configs
>>> scheme = SharedMemWeightSyncScheme()
>>> # Models are auto-registered on the first weight send
>>> # With explicit registration (assumes `model` is an existing torch.nn.Module)
>>> scheme = SharedMemWeightSyncScheme(auto_register=False)
>>> shared_weights = TensorDict.from_module(model).share_memory_()
>>> scheme.register_shared_weights("policy", shared_weights)

create_receiver() → WeightReceiver

Create a receiver for this scheme (legacy).

Returns:

WeightReceiver instance configured for this scheme.

create_sender() → WeightSender

Create a sender for this scheme (legacy).

Returns:

WeightSender instance configured for this scheme.
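
A minimal sketch of the legacy pairing, assuming a scheme constructed as in the class example above; the sender lives on the main process and each worker holds a receiver:

>>> sender = scheme.create_sender()      # main-process side
>>> receiver = scheme.create_receiver()  # called inside a worker process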

create_transport(pipe_or_context: Any) → TransportBackend[source]

Create shared memory transport and register pipe for lazy buffer distribution (legacy).

For lazy registration to work, we register each worker’s pipe with the transport. On first weight send, the transport will send buffer references via these pipes.

Returns the shared transport instance. Because this scheme uses shared memory, a single transport instance is shared by all workers.
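
A minimal sketch, assuming a torch.multiprocessing pipe end is what gets passed as pipe_or_context (the exact context object may differ in your setup):

>>> import torch.multiprocessing as mp
>>> parent_end, child_end = mp.Pipe()
>>> # Register this worker's pipe so buffer references can be sent on first use
>>> transport = scheme.create_transport(parent_end)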

get_receiver() → WeightReceiver

Get the receiver instance.

Returns:

Receiver instance for receiving weights in this worker.

Raises:

RuntimeError – If init_on_worker() hasn’t been called yet.

get_sender() → WeightSender

Get the sender instance.

Returns:

Sender instance for sending weights to workers.

Raises:

RuntimeError – If init_on_sender() hasn’t been called yet.

init_on_sender(model_id: str, context: Any = None, **kwargs) → None[source]

Initialize on the main process (sender side).

For SharedMemWeightSyncScheme, this handles:
  1. Getting cached shared memory weights from the context
  2. Pre-registering the weights with the transport
  3. Distributing buffer references to all workers (avoiding a later deadlock)

Parameters:
  • model_id – Identifier for the model being synchronized

  • context – Optional context object providing pipes, cached_weights

  • **kwargs – Alternative to context (pipes, cached_weights, etc.)
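
A hedged sketch of the keyword-argument form, assuming the pipes and cached_weights names from the parameter list above, plus the shared_weights TensorDict from the class example:

>>> scheme.init_on_sender("policy", pipes=[parent_end], cached_weights=shared_weights)
>>> sender = scheme.get_sender()  # valid only after init_on_sender()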

init_on_worker(model_id: str, context: Any = None, **kwargs) → None[source]

Initialize on worker process (receiver side).

Parameters:
  • model_id – Identifier for the model being synchronized

  • context – Optional context object providing pipe and model

  • **kwargs – Alternative to context (pipe, model, etc.)
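
The worker-side counterpart, a sketch assuming the pipe and model keyword arguments described above:

>>> scheme.init_on_worker("policy", pipe=child_end, model=model)
>>> receiver = scheme.get_receiver()  # valid only after init_on_worker()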

prepare_weights(weights: Any, model_id: str, strategy: WeightStrategy, context: Any = None) → Any[source]

Prepare weights for SharedMemWeightSyncScheme.

For SharedMemWeightSyncScheme, we prioritize using cached shared memory weights from the context (collector) to avoid extracting fresh (non-shared) weights.

Parameters:
  • weights – Raw weights input

  • model_id – The model identifier

  • strategy – WeightStrategy for extracting/converting weights

  • context – Optional context (e.g., collector) for cache lookup

Returns:

Shared memory weights ready to send.
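
A hypothetical call, assuming a collector holding cached shared-memory weights is available as context and a WeightStrategy instance is at hand (both names are illustrative):

>>> prepared = scheme.prepare_weights(
...     weights=None,          # assumption: fall back to the cached weights
...     model_id="policy",
...     strategy=strategy,     # an existing WeightStrategy instance (assumed)
...     context=collector,     # e.g. the collector owning the cached weights
... )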

register_shared_weights(model_id: str, weights: TensorDictBase) → None[source]

Register shared memory weights for a model.

This method allows explicit registration of shared weights. It’s optional when auto_register=True (the default), but required when auto_register=False.

Parameters:
  • model_id – Identifier for the model.

  • weights – Shared memory TensorDict containing the model’s weights.
