SharedMemWeightSyncScheme¶
- class torchrl.weight_update.SharedMemWeightSyncScheme(policy_weights: dict[str, tensordict.base.TensorDictBase] | None = None, strategy: str = 'tensordict', auto_register: bool = True)[source]¶
Weight synchronization using shared memory.
This scheme mimics the old WeightUpdater behavior by using shared memory for in-place weight updates. Workers automatically see weight updates without explicit message passing.
By default, this scheme uses lazy registration: models are automatically registered on the first weight send. This makes it seamless to use with configuration systems like Hydra where schemes are created before models are available.
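The central property described above, that workers see in-place updates to a shared buffer without any explicit message passing, can be illustrated with the standard library alone. The sketch below is conceptual and uses `multiprocessing.Array` in place of a shared `TensorDict`; it is not torchrl code.

```python
import multiprocessing as mp


def wait_and_read(shared, updated, out):
    # The worker blocks until the parent signals an update, then reads
    # the shared buffer directly -- no weight value travels over a pipe.
    updated.wait()
    out.put(shared[0])


def main():
    shared = mp.Array("d", [0.0])   # stands in for shared policy weights
    updated = mp.Event()
    out = mp.Queue()
    p = mp.Process(target=wait_and_read, args=(shared, updated, out))
    p.start()
    shared[0] = 1.5                 # in-place update: the "weight send"
    updated.set()
    p.join()
    return out.get()


if __name__ == "__main__":
    print(main())  # 1.5
```

The worker observes `1.5` even though the parent never sent the value itself, which is the behavior this scheme relies on for weight updates.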
- Parameters:
policy_weights – Dictionary mapping model_id to shared TensorDict weights. Can be empty if using lazy registration (auto_register=True).
strategy – The weight transmission strategy (default: “tensordict”).
auto_register – Whether to automatically register models on first weight send. Default is True. Set to False to require explicit registration via register_shared_weights().
Example
>>> # With auto-registration (default) - works with Hydra configs
>>> scheme = SharedMemWeightSyncScheme()
>>> # Models are auto-registered on first weight send
>>> # With explicit registration
>>> scheme = SharedMemWeightSyncScheme(auto_register=False)
>>> shared_weights = TensorDict.from_module(model).share_memory_()
>>> scheme.register_shared_weights("policy", shared_weights)
- create_receiver() → WeightReceiver¶
Create a receiver for this scheme (legacy).
- Returns:
WeightReceiver instance configured for this scheme.
- create_sender() → WeightSender¶
Create a sender for this scheme (legacy).
- Returns:
WeightSender instance configured for this scheme.
- create_transport(pipe_or_context: Any) → TransportBackend[source]¶
Create shared memory transport and register pipe for lazy buffer distribution (legacy).
For lazy registration to work, we register each worker’s pipe with the transport. On first weight send, the transport will send buffer references via these pipes.
Returns the shared transport instance that all workers will use. Since this is shared memory, there’s only one transport shared by all workers.
- get_receiver() → WeightReceiver¶
Get the receiver instance.
- Returns:
Receiver instance for receiving weights in this worker
- Raises:
RuntimeError – If init_on_worker() hasn’t been called yet
- get_sender() → WeightSender¶
Get the sender instance.
- Returns:
Sender instance for sending weights to workers
- Raises:
RuntimeError – If init_on_sender() hasn’t been called yet
- init_on_sender(model_id: str, context: Any = None, **kwargs) → None[source]¶
Initialize on the main process (sender side).
For SharedMemWeightSyncScheme, this handles:
1. Getting cached shared memory weights from context
2. Pre-registering the weights with the transport
3. Distributing buffer references to all workers (avoiding later deadlock)
- Parameters:
model_id – Identifier for the model being synchronized
context – Optional context object providing pipes, cached_weights
**kwargs – Alternative to context (pipes, cached_weights, etc.)
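The deadlock-avoidance step amounts to pushing the buffer reference down every worker pipe eagerly, before collection starts, so no worker ever has to request it while the parent may itself be blocked. A rough stdlib sketch of that pattern, sending a `shared_memory` segment name as the "buffer reference" (illustrative only, not the actual transport code):

```python
import struct
from multiprocessing import Pipe, Process, Queue, shared_memory


def worker(conn, out):
    # The worker receives the buffer reference (here: the segment name)
    # exactly once, up front, so it never needs to ask for it later.
    name = conn.recv()
    shm = shared_memory.SharedMemory(name=name)
    out.put(struct.unpack("d", bytes(shm.buf[:8]))[0])
    shm.close()


def main():
    shm = shared_memory.SharedMemory(create=True, size=8)
    shm.buf[:8] = struct.pack("d", 2.0)   # the shared "weights"
    out = Queue()
    pipes = [Pipe() for _ in range(2)]
    procs = [Process(target=worker, args=(child, out)) for _, child in pipes]
    for p in procs:
        p.start()
    # init_on_sender-style step: eagerly distribute the reference
    # down every worker pipe before any collection happens.
    for parent, _ in pipes:
        parent.send(shm.name)
    for p in procs:
        p.join()
    values = [out.get() for _ in procs]
    shm.close()
    shm.unlink()
    return values


if __name__ == "__main__":
    print(main())
```

Both workers read the same shared value without the parent ever sending the weights themselves; only the lightweight reference crosses the pipes.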
- init_on_worker(model_id: str, context: Any = None, **kwargs) → None[source]¶
Initialize on worker process (receiver side).
- Parameters:
model_id – Identifier for the model being synchronized
context – Optional context object providing pipe and model
**kwargs – Alternative to context (pipe, model, etc.)
- prepare_weights(weights: Any, model_id: str, strategy: WeightStrategy, context: Any = None) → Any[source]¶
Prepare weights for SharedMemWeightSyncScheme.
For SharedMemWeightSyncScheme, we prioritize using cached shared memory weights from the context (collector) to avoid extracting fresh (non-shared) weights.
- Parameters:
weights – Raw weights input
model_id – The model identifier
strategy – WeightStrategy for extracting/converting weights
context – Optional context (e.g., collector) for cache lookup
- Returns:
Shared memory weights ready to send
- register_shared_weights(model_id: str, weights: TensorDictBase) → None[source]¶
Register shared memory weights for a model.
This method allows explicit registration of shared weights. It’s optional when auto_register=True (the default), but required when auto_register=False.
- Parameters:
model_id – Identifier for the model.
weights – Shared memory TensorDict containing the model’s weights.