WeightSyncScheme
- class torchrl.weight_update.WeightSyncScheme(strategy: Literal['state_dict', 'tensordict'] = 'state_dict')
Configuration for how to synchronize a single model across workers.
Each scheme manages the synchronization of exactly ONE model, so the collector maintains a dict of {model_id: scheme} pairs with one entry per synchronized model.
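The one-scheme-per-model relationship can be pictured as a plain registry keyed by model_id. This is a minimal sketch: `ToyScheme` is a hypothetical stand-in that only stores the `strategy` argument, and the `"value_net"` key is an illustrative model_id; real code would register instances of a concrete scheme class instead.

```python
from dataclasses import dataclass
from typing import Literal


@dataclass
class ToyScheme:
    """Hypothetical stand-in for a WeightSyncScheme subclass; it only stores the strategy."""

    strategy: Literal["state_dict", "tensordict"] = "state_dict"


# Collector-side registry: exactly one scheme per synchronized model.
weight_sync_schemes = {
    "policy": ToyScheme(strategy="tensordict"),
    "value_net": ToyScheme(),  # falls back to the default "state_dict" strategy
}

for model_id, scheme in weight_sync_schemes.items():
    print(f"{model_id}: {scheme.strategy}")
```

Keeping one scheme per model lets each model choose its own strategy independently.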
- create_receiver() -> WeightReceiver
Create a receiver for this scheme (legacy).
- Returns:
WeightReceiver instance configured for this scheme.
- create_sender() -> WeightSender
Create a sender for this scheme (legacy).
- Returns:
WeightSender instance configured for this scheme.
- abstract create_transport(pipe_or_context: Any) -> TransportBackend
Create transport for communication.
- Parameters:
pipe_or_context – Either a pipe connection or a context object from which the pipe is extracted.
- Returns:
A transport backend instance.
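Because create_transport() accepts either a raw pipe or a context object, a subclass has to resolve which one it received. The sketch below assumes the context exposes its connection as `.pipe` (as the worker-side context described for init_on_worker() does); `ToyTransport`, `ToySchemeBase`, `ToyPipeScheme`, `WorkerContext`, and the `send_weights`/`receive_weights` method names are all hypothetical and are not the TransportBackend API.

```python
from abc import ABC, abstractmethod
from multiprocessing import Pipe
from typing import Any


class ToyTransport:
    """Hypothetical transport backend that moves weights over one pipe connection."""

    def __init__(self, conn):
        self.conn = conn

    def send_weights(self, weights):
        self.conn.send(weights)

    def receive_weights(self):
        return self.conn.recv()


class ToySchemeBase(ABC):
    @abstractmethod
    def create_transport(self, pipe_or_context: Any) -> ToyTransport:
        """Concrete schemes decide how the transport is built."""


class ToyPipeScheme(ToySchemeBase):
    def create_transport(self, pipe_or_context: Any) -> ToyTransport:
        # Use `.pipe` when a context object is passed; otherwise assume a raw connection.
        conn = getattr(pipe_or_context, "pipe", pipe_or_context)
        return ToyTransport(conn)


class WorkerContext:
    """Hypothetical context object exposing its connection as `.pipe`."""

    def __init__(self, pipe):
        self.pipe = pipe


parent_conn, child_conn = Pipe()
scheme = ToyPipeScheme()
sender_transport = scheme.create_transport(parent_conn)  # raw connection
receiver_transport = scheme.create_transport(WorkerContext(child_conn))  # context object

sender_transport.send_weights({"w": [1.0, 2.0]})
received = receiver_transport.receive_weights()
print(received)  # {'w': [1.0, 2.0]}
```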
- get_receiver() -> WeightReceiver
Get the receiver instance.
- Returns:
Receiver instance that receives weights in this worker.
- Raises:
RuntimeError – If init_on_worker() has not been called yet.
- get_sender() -> WeightSender
Get the sender instance.
- Returns:
Sender instance that sends weights to the workers.
- Raises:
RuntimeError – If init_on_sender() has not been called yet.
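The getters enforce an init-before-use contract: the sender exists only after init_on_sender() and the receiver only after init_on_worker(). A minimal sketch of that guard, in which `ToyScheme` is hypothetical and plain strings stand in for real WeightSender and WeightReceiver objects:

```python
class ToyScheme:
    """Hypothetical scheme illustrating the init-before-get contract."""

    def __init__(self):
        self._sender = None
        self._receiver = None

    def init_on_sender(self, model_id, context=None, **kwargs):
        self._sender = f"sender[{model_id}]"  # placeholder for a WeightSender

    def init_on_worker(self, model_id, context=None, **kwargs):
        self._receiver = f"receiver[{model_id}]"  # placeholder for a WeightReceiver

    def get_sender(self):
        if self._sender is None:
            raise RuntimeError("init_on_sender() must be called before get_sender()")
        return self._sender

    def get_receiver(self):
        if self._receiver is None:
            raise RuntimeError("init_on_worker() must be called before get_receiver()")
        return self._receiver


scheme = ToyScheme()
error_message = ""
try:
    scheme.get_sender()  # too early: raises RuntimeError
except RuntimeError as err:
    error_message = str(err)

scheme.init_on_sender("policy")
sender = scheme.get_sender()  # now succeeds
```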
- init_on_sender(model_id: str, context: Any = None, **kwargs) -> None
Initialize on the main process (sender side).
This method is called once in the collector’s _run_processes() method, after workers have been started and are ready to receive messages.
- Parameters:
model_id – Identifier for the model being synchronized.
context – Optional context object (e.g., the collector) providing:
  - .pipes: list[mp.Connection]
  - .get_model(model_id: str) -> nn.Module
  - .get_cached_weights(model_id: str) -> TensorDict | None
  - .num_workers: int
**kwargs – Alternative to context (pipes, num_workers, model, cached_weights, etc.).
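The context-or-kwargs convention means the sender side can be initialized either from a collector-like object or from explicit keyword arguments. The sketch below reads the four context members listed above; `ToySenderScheme` is hypothetical, the strings stand in for real pipe connections and modules, and the one-pipe-per-worker check is an assumption this sketch adds, not documented behavior.

```python
from types import SimpleNamespace


class ToySenderScheme:
    """Hypothetical sender-side init: read from `context`, else from kwargs."""

    def init_on_sender(self, model_id, context=None, **kwargs):
        if context is not None:
            pipes = context.pipes
            num_workers = context.num_workers
            model = context.get_model(model_id)
            cached = context.get_cached_weights(model_id)
        else:
            pipes = kwargs["pipes"]
            num_workers = kwargs.get("num_workers", len(pipes))
            model = kwargs.get("model")
            cached = kwargs.get("cached_weights")
        # Sanity check added by this sketch: one pipe per worker.
        if len(pipes) != num_workers:
            raise ValueError("expected one pipe per worker")
        self.model_id = model_id
        self.pipes = pipes
        self.model = model
        self.cached_weights = cached


# Collector-like context; strings stand in for connections and modules.
collector = SimpleNamespace(
    pipes=["pipe0", "pipe1"],
    num_workers=2,
    get_model=lambda model_id: f"<{model_id} module>",
    get_cached_weights=lambda model_id: None,
)

from_context = ToySenderScheme()
from_context.init_on_sender("policy", context=collector)

from_kwargs = ToySenderScheme()
from_kwargs.init_on_sender("policy", pipes=["pipe0", "pipe1"], model="<policy module>")
```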
- init_on_worker(model_id: str, context: Any = None, **kwargs) -> None
Initialize on worker process (receiver side).
This method is called once during each worker's initialization.
- Parameters:
model_id – Identifier for the model being synchronized.
context – Optional context object (e.g., the inner collector) providing:
  - .pipe: mp.Connection
  - .get_model(model_id: str) -> nn.Module
**kwargs – Alternative to context (pipe, model, etc.).
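On the worker side, init_on_worker() only needs the pipe and the local model, both available through the context's `.pipe` and `.get_model()`. The round trip below runs in a single process for simplicity. `ToyModel`, `ToyWorkerScheme`, and its `receive_and_apply()` helper are hypothetical, and a plain dict stands in for a real state_dict.

```python
from multiprocessing import Pipe
from types import SimpleNamespace


class ToyModel:
    """Hypothetical model holding its parameters in a plain dict."""

    def __init__(self):
        self.params = {"w": 0.0}

    def load_state_dict(self, state):
        self.params.update(state)


class ToyWorkerScheme:
    """Hypothetical receiver-side init: resolve the pipe and model, then apply updates."""

    def init_on_worker(self, model_id, context=None, **kwargs):
        if context is not None:
            self.pipe = context.pipe
            self.model = context.get_model(model_id)
        else:
            self.pipe = kwargs["pipe"]
            self.model = kwargs["model"]

    def receive_and_apply(self):
        # Hypothetical helper: block until weights arrive, then load them into the model.
        weights = self.pipe.recv()
        self.model.load_state_dict(weights)


parent_conn, child_conn = Pipe()
model = ToyModel()
inner_collector = SimpleNamespace(pipe=child_conn, get_model=lambda model_id: model)

scheme = ToyWorkerScheme()
scheme.init_on_worker("policy", context=inner_collector)

parent_conn.send({"w": 3.5})  # the main process pushes new weights
scheme.receive_and_apply()
print(model.params)  # {'w': 3.5}
```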
- prepare_weights(weights: Any, model_id: str, strategy: WeightStrategy, context: Any = None) -> Any
Prepare weights for sending.
This method handles weight extraction, conversion, and any scheme-specific preparation (e.g., cache lookups for SharedMemWeightSyncScheme).
- Parameters:
weights – Raw weights input (can be None, nn.Module, TensorDict, dict, str reference, etc.)
model_id – The model identifier (e.g., “policy”)
strategy – WeightStrategy used to extract and convert the weights
context – Optional context (e.g., collector) for model resolution
- Returns:
Prepared weights, ready to be sent via the transport
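Because prepare_weights() accepts several input kinds, it is essentially a dispatch on the type of `weights` followed by extraction through the strategy. The sketch below covers only four cases: None, a string reference, a module-like object with `state_dict()`, and a plain dict. It assumes that None and a string both resolve to a model through `context.get_model()`. `ToyWeightStrategy`, `toy_prepare_weights`, and `ToyModule` are hypothetical. TensorDict handling and scheme-specific steps, such as the cache lookups used by SharedMemWeightSyncScheme, are omitted.

```python
from types import SimpleNamespace
from typing import Any


class ToyWeightStrategy:
    """Hypothetical stand-in for WeightStrategy: extracts a flat dict of weights."""

    def extract_weights(self, source: Any) -> dict:
        if hasattr(source, "state_dict"):
            return dict(source.state_dict())
        return dict(source)


def toy_prepare_weights(weights, model_id, strategy, context=None):
    # None: fall back to the model registered under model_id (assumed resolution).
    if weights is None:
        weights = context.get_model(model_id)
    # str: treat it as a reference to a model known to the context (assumed resolution).
    elif isinstance(weights, str):
        weights = context.get_model(weights)
    return strategy.extract_weights(weights)


class ToyModule:
    """Hypothetical module-like object exposing state_dict()."""

    def state_dict(self):
        return {"weight": [0.1, 0.2], "bias": [0.0]}


ctx = SimpleNamespace(get_model=lambda model_id: ToyModule())
strategy = ToyWeightStrategy()

from_none = toy_prepare_weights(None, "policy", strategy, context=ctx)
from_ref = toy_prepare_weights("policy", "policy", strategy, context=ctx)
from_dict = toy_prepare_weights({"weight": [1.0]}, "policy", strategy)
```

Resolving references before extraction keeps the strategy's job narrow: it only ever sees a concrete object to convert.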