RayWeightSyncScheme
- class torchrl.weight_update.RayWeightSyncScheme(strategy: Literal['state_dict', 'tensordict'] = 'state_dict')
Weight synchronization for Ray distributed computing.
This scheme uses Ray’s object store and remote calls to synchronize weights across distributed workers (Ray actors).
Each remote collector gets its own transport, following the same pattern as multiprocess collectors.
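The `strategy` argument in the class signature selects the weight format: `'state_dict'` (the flat mapping returned by `nn.Module.state_dict()`, with dot-separated keys such as `"0.weight"`) or `'tensordict'` (a TensorDict, which groups parameters by submodule in a nested structure). The sketch below uses plain dicts instead of tensors to show how the two layouts relate. The helpers `unflatten` and `flatten` are hypothetical illustrations, not TorchRL functions.

```python
from typing import Any, Dict


def unflatten(flat: Dict[str, Any], sep: str = ".") -> Dict[str, Any]:
    """Hypothetical helper: turn dot-separated keys (state_dict style) into a nested dict."""
    nested: Dict[str, Any] = {}
    for key, value in flat.items():
        node = nested
        *parents, leaf = key.split(sep)
        for part in parents:
            # Create intermediate levels (one per submodule) on demand.
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


def flatten(nested: Dict[str, Any], sep: str = ".", prefix: str = "") -> Dict[str, Any]:
    """Hypothetical inverse of unflatten: nested dict back to dot-separated keys."""
    flat: Dict[str, Any] = {}
    for key, value in nested.items():
        full = f"{prefix}{sep}{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten(value, sep, full))
        else:
            flat[full] = value
    return flat


# Plain lists stand in for tensors of a two-layer nn.Sequential.
state_dict_like = {"0.weight": [[0.5]], "0.bias": [0.0], "2.weight": [[1.0]]}
nested_like = unflatten(state_dict_like)
```

Both layouts carry the same values; they differ only in how keys are organized.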
- create_receiver() → WeightReceiver
Create a receiver for this scheme (legacy).
- Returns:
WeightReceiver instance configured for this scheme.
- create_sender() → WeightSender
Create a sender for this scheme (legacy).
- Returns:
WeightSender instance configured for this scheme.
- create_transport(pipe_or_context: Any) → TransportBackend
Create Ray-based transport for a specific remote collector.
- Parameters:
pipe_or_context – The Ray actor handle for the remote collector.
- Returns:
RayTransport configured for this specific remote collector.
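A minimal sketch of the "one transport per remote collector" idea, assuming the transport simply wraps the actor handle and forwards weights through a remote method call. Ray invokes actor methods as `handle.method.remote(...)`; here `FakeActorHandle`, its `set_weights` method, and `SketchRayTransport` are hypothetical stand-ins that mimic that call shape synchronously so the example runs without Ray. This is not TorchRL's `RayTransport`.

```python
from typing import Any, Callable


class _RemoteMethod:
    """Mimics Ray's `actor.method.remote(...)` call shape, but runs synchronously."""

    def __init__(self, fn: Callable[..., Any]) -> None:
        self._fn = fn

    def remote(self, *args: Any, **kwargs: Any) -> Any:
        # Real Ray returns an ObjectRef here; this fake returns the result directly.
        return self._fn(*args, **kwargs)


class FakeActorHandle:
    """Hypothetical stand-in for the Ray actor handle of one remote collector."""

    def __init__(self) -> None:
        self.weights: Any = None
        self.set_weights = _RemoteMethod(self._set_weights)

    def _set_weights(self, weights: Any) -> bool:
        self.weights = weights
        return True


class SketchRayTransport:
    """Hypothetical transport bound to exactly one actor handle."""

    def __init__(self, actor_handle: FakeActorHandle) -> None:
        self.actor_handle = actor_handle

    def send_weights(self, weights: Any) -> Any:
        # With real Ray you would typically pass this result to ray.get() to wait for it.
        return self.actor_handle.set_weights.remote(weights)


def create_transport(pipe_or_context: FakeActorHandle) -> SketchRayTransport:
    # The argument is the actor handle of a single remote collector.
    return SketchRayTransport(pipe_or_context)


handle = FakeActorHandle()
transport = create_transport(handle)
ack = transport.send_weights({"w": 1.0})
```

Because each transport holds a single handle, the sender can keep one transport per collector and address each worker independently.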
- get_receiver() → WeightReceiver
Get the receiver instance.
- Returns:
Receiver instance for receiving weights in this worker.
- Raises:
RuntimeError – If init_on_worker() hasn’t been called yet.
- get_sender() → WeightSender
Get the sender instance.
- Returns:
Sender instance for sending weights to workers.
- Raises:
RuntimeError – If init_on_sender() hasn’t been called yet.
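The `get_receiver()` and `get_sender()` accessors follow an "initialize before use" contract: each raises `RuntimeError` until the matching `init_on_worker()` or `init_on_sender()` call has run. The class below is a hypothetical sketch of that contract only; the string placeholders stand in for real `WeightSender` and `WeightReceiver` objects.

```python
from typing import Optional


class SketchScheme:
    """Hypothetical scheme illustrating the init-before-get contract (not TorchRL code)."""

    def __init__(self) -> None:
        self._sender: Optional[str] = None
        self._receiver: Optional[str] = None

    def init_on_sender(self, model_id: str) -> None:
        self._sender = f"sender:{model_id}"  # placeholder for a WeightSender

    def init_on_worker(self, model_id: str) -> None:
        self._receiver = f"receiver:{model_id}"  # placeholder for a WeightReceiver

    def get_sender(self) -> str:
        if self._sender is None:
            raise RuntimeError("init_on_sender() must be called before get_sender()")
        return self._sender

    def get_receiver(self) -> str:
        if self._receiver is None:
            raise RuntimeError("init_on_worker() must be called before get_receiver()")
        return self._receiver


scheme = SketchScheme()
error_message = ""
try:
    scheme.get_sender()  # too early: init_on_sender() has not run yet
except RuntimeError as err:
    error_message = str(err)

scheme.init_on_sender("policy")
sender = scheme.get_sender()
```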
- init_on_sender(model_id: str, context: Any = None, **kwargs) → None
Initialize on the main process (sender side).
- Parameters:
model_id – Identifier for the model being synchronized.
context – Optional context object providing remote_collectors.
**kwargs – Alternative to context for passing the same information directly (remote_collectors, source_model, etc.).
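`init_on_sender()` can obtain the remote collectors either from a context object exposing a `remote_collectors` attribute or from keyword arguments. The helper below is a hypothetical sketch of that lookup; in particular, letting an explicit `remote_collectors` keyword take precedence over the context is an assumption made for illustration, not documented behavior.

```python
from types import SimpleNamespace
from typing import Any, List


def resolve_remote_collectors(context: Any = None, **kwargs: Any) -> List[Any]:
    """Hypothetical helper: find the collectors to send weights to."""
    # Assumption: an explicit keyword argument wins over the context.
    if "remote_collectors" in kwargs:
        return list(kwargs["remote_collectors"])
    if context is not None and hasattr(context, "remote_collectors"):
        return list(context.remote_collectors)
    raise ValueError("pass remote_collectors via context or as a keyword argument")


# A SimpleNamespace stands in for a context object; strings stand in for actor handles.
ctx = SimpleNamespace(remote_collectors=["actor_a", "actor_b"])
from_context = resolve_remote_collectors(ctx)
from_kwargs = resolve_remote_collectors(remote_collectors=["actor_c"])
```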
- init_on_worker(model_id: str, context: Any = None, **kwargs) → None
Initialize on the worker process (receiver side).
For Ray workers, weight updates arrive through remote method calls, so this is typically a no-op: the receiver is created but needs no special initialization.
- Parameters:
model_id – Identifier for the model being synchronized.
context – Optional context object (typically the remote collector).
**kwargs – Optional parameters (pipe, model, etc.).
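Why a no-op suffices on Ray workers: the weights are passed as arguments of a method that the sender invokes remotely, so the worker has no channel to open or listen on beforehand. The hypothetical `SketchWorker` below shows that shape, with plain floats standing in for tensors. Its strict key check loosely mirrors the default behavior of `nn.Module.load_state_dict`; that check is an illustrative choice, not something this scheme is documented to do.

```python
from typing import Any, Dict


class SketchWorker:
    """Hypothetical remote-collector stand-in whose receiver needs no setup."""

    def __init__(self, params: Dict[str, float]) -> None:
        self.params = dict(params)

    def init_on_worker(self, model_id: str, context: Any = None, **kwargs: Any) -> None:
        # Nothing to prepare: weights arrive as arguments of a remote method call.
        return None

    def update_weights(self, new_params: Dict[str, float]) -> None:
        # Invoked by the sender (via actor.update_weights.remote(...) in a Ray setting).
        missing = set(self.params) - set(new_params)
        unexpected = set(new_params) - set(self.params)
        if missing or unexpected:
            raise KeyError(f"missing={sorted(missing)}, unexpected={sorted(unexpected)}")
        self.params.update(new_params)


worker = SketchWorker({"w": 0.0, "b": 0.0})
worker.init_on_worker("policy")
worker.update_weights({"w": 1.5, "b": -0.5})
```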
- prepare_weights(weights: Any, model_id: str, strategy: WeightStrategy, context: Any = None) → Any
Prepare weights for sending.
This method handles weight extraction, conversion, and any scheme-specific preparation (e.g., cache lookups for SharedMemWeightSyncScheme).
- Parameters:
weights – Raw weights input (can be None, nn.Module, TensorDict, dict, str reference, etc.).
model_id – The model identifier (e.g., “policy”).
strategy – WeightStrategy used to extract and convert the weights.
context – Optional context (e.g., a collector) used for model resolution.
- Returns:
Prepared weights, ready to send via the transport.
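Because `weights` accepts several kinds of input, `prepare_weights()` effectively dispatches on its type. The function below is a hypothetical sketch of such a dispatch. It assumes that `None` means "look up the model registered under `model_id` in the context", that a string is a reference resolved through the context, and that an `extract` callable stands in for the `WeightStrategy`. None of these details are taken from TorchRL's implementation.

```python
from typing import Any, Callable, Dict, Optional


def prepare_weights_sketch(
    weights: Any,
    model_id: str,
    extract: Callable[[Any], Dict[str, Any]],
    context: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Hypothetical dispatch over the weight input kinds listed above."""
    if weights is None:
        # Assumption: fall back to the model registered under model_id.
        if context is None or model_id not in context:
            raise ValueError(f"no weights given and no model {model_id!r} in context")
        weights = context[model_id]
    elif isinstance(weights, str):
        # Assumption: a string is a reference resolved through the context.
        if context is None or weights not in context:
            raise ValueError(f"unknown model reference {weights!r}")
        weights = context[weights]
    if isinstance(weights, dict):
        return dict(weights)  # already a mapping: send a shallow copy
    # Anything else (e.g. a module-like object) goes through the extractor.
    return extract(weights)


class TinyModel:
    """Stand-in for an nn.Module with a single parameter."""

    def __init__(self) -> None:
        self.w = 2.0


ctx = {"policy": TinyModel()}
from_none = prepare_weights_sketch(None, "policy", lambda m: dict(vars(m)), ctx)
from_ref = prepare_weights_sketch("policy", "policy", lambda m: dict(vars(m)), ctx)
from_dict = prepare_weights_sketch({"w": 3.0}, "policy", lambda m: dict(vars(m)))
```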