RayModuleTransformScheme
- class torchrl.weight_update.RayModuleTransformScheme(strategy: str = 'tensordict')
Weight synchronization for RayModuleTransform actors.
This scheme is designed specifically for updating models hosted within Ray actors, such as RayModuleTransform instances. It creates a transport that directly calls the actor’s weight update methods.
- Parameters:
strategy (str) – The weight transmission strategy (“state_dict” or “tensordict”). Default is “tensordict”.
- create_receiver() → RayModuleTransformReceiver
Create a specialized receiver for Ray actor communication.
- create_sender() → RayModuleTransformSender
Create a specialized sender for Ray actor communication.
- create_transport(pipe_or_context: Any) → TransportBackend
Create a RayActorTransport for the given actor.
- Parameters:
pipe_or_context – Either a Ray actor reference or a context object from which to extract the actor reference.
- Returns:
RayActorTransport configured with the actor reference.
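The "actor reference or context object" input accepted by create_transport() can be pictured with a small resolver. This is only a sketch of the idea, not torchrl's implementation: `FakeActorHandle`, `FakeContext`, `resolve_actor`, and the `.actor` attribute name are all hypothetical stand-ins, and no Ray call is made.

```python
from typing import Any


class FakeActorHandle:
    """Hypothetical stand-in for a Ray actor handle (not a Ray class)."""

    def __init__(self, name: str) -> None:
        self.name = name


class FakeContext:
    """Hypothetical context object that carries an actor reference."""

    def __init__(self, actor: FakeActorHandle) -> None:
        self.actor = actor


def resolve_actor(pipe_or_context: Any) -> Any:
    """Return the actor reference, unwrapping a context object if needed.

    Assumption: a context exposes the reference as `.actor`; the
    attribute actually used by torchrl may differ.
    """
    if hasattr(pipe_or_context, "actor"):
        return pipe_or_context.actor
    return pipe_or_context


handle = FakeActorHandle("policy_actor")
direct = resolve_actor(handle)                    # reference passed as-is
via_context = resolve_actor(FakeContext(handle))  # reference extracted from context
```

Both calls resolve to the same handle, which is the property a transport factory needs before it can bind to the actor.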
- get_receiver() → WeightReceiver
Get the receiver instance.
- Returns:
Receiver instance for receiving weights in this worker
- Raises:
RuntimeError – If init_on_worker() hasn’t been called yet
- get_sender() → WeightSender
Get the sender instance.
- Returns:
Sender instance for sending weights to workers
- Raises:
RuntimeError – If init_on_sender() hasn’t been called yet
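The initialize-then-get contract documented for get_sender() (and, symmetrically, get_receiver()) can be illustrated with a toy class. This is a minimal sketch under stated assumptions: `ToyScheme` and `ToySender` are hypothetical stand-ins written for this example and do not reproduce torchrl's internals or its error message.

```python
from typing import Any, Optional


class ToySender:
    """Hypothetical sender; records which model it was bound to."""

    def __init__(self, model_id: str) -> None:
        self.model_id = model_id


class ToyScheme:
    """Toy stand-in mirroring the documented init-then-get contract."""

    def __init__(self) -> None:
        self._sender: Optional[ToySender] = None

    def init_on_sender(self, model_id: str, context: Any = None, **kwargs: Any) -> None:
        # Build the sender once the scheme knows which model it serves.
        self._sender = ToySender(model_id)

    def get_sender(self) -> ToySender:
        # Refuse to hand out a sender that was never initialized.
        if self._sender is None:
            raise RuntimeError("init_on_sender() must be called before get_sender()")
        return self._sender


scheme = ToyScheme()

# Calling get_sender() too early raises RuntimeError.
try:
    scheme.get_sender()
    raised = False
except RuntimeError:
    raised = True

# After initialization the sender is available.
scheme.init_on_sender(model_id="policy")
sender = scheme.get_sender()
```

Failing fast with a RuntimeError keeps a misordered setup from silently sending weights through an unconfigured transport.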
- init_on_sender(model_id: str, context: Any = None, **kwargs) → None
Initialize on the main process (sender side).
- Parameters:
model_id – Identifier for the model being synchronized
context – Optional context object providing actor references
**kwargs – Alternative to context (actors, actor_refs, source_model, etc.)
- init_on_worker(model_id: str, context: Any = None, **kwargs) → None
Initialize on the worker process (receiver side).
- Parameters:
model_id – Identifier for the model being synchronized
context – Optional context object (typically the actor itself)
**kwargs – Optional parameters (actor_ref, model, etc.)
- prepare_weights(weights: Any, model_id: str, strategy: WeightStrategy, context: Any = None) → Any
Prepare weights for sending.
This method handles weight extraction, conversion, and any scheme-specific preparation (e.g., cache lookups for SharedMemWeightSyncScheme).
- Parameters:
weights – Raw weights input (can be None, nn.Module, TensorDict, dict, str reference, etc.)
model_id – The model identifier (e.g., “policy”)
strategy – WeightStrategy for extracting/converting weights
context – Optional context (e.g., collector) for model resolution
- Returns:
Prepared weights ready to send via transport
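Because prepare_weights() accepts several input forms, it may help to see how such inputs could be normalized into one representation. The sketch below is an assumption-laden illustration, not the library's logic: `toy_prepare_weights` and the `models` registry (playing the role of the context used for model resolution) are hypothetical, plain dicts of floats stand in for tensors, and the fallback behavior for `None` is a guess at one reasonable design.

```python
from typing import Any, Dict

Weights = Dict[str, float]


def toy_prepare_weights(weights: Any, model_id: str, models: Dict[str, Weights]) -> Weights:
    """Normalize several input forms into a plain dict of weights.

    Assumptions (not taken from torchrl): None falls back to the model
    registered under model_id, and a string names a registered model.
    """
    if weights is None:
        # No explicit weights: resolve them from the registry via model_id.
        return dict(models[model_id])
    if isinstance(weights, str):
        # A string is treated as a reference to a registered model.
        return dict(models[weights])
    if isinstance(weights, dict):
        # Already a mapping: copy so the caller's dict is not shared.
        return dict(weights)
    raise TypeError(f"Unsupported weights type: {type(weights).__name__}")


registry = {"policy": {"w": 1.0, "b": 0.5}}

from_none = toy_prepare_weights(None, "policy", registry)
from_ref = toy_prepare_weights("policy", "policy", registry)
from_dict = toy_prepare_weights({"w": 2.0}, "policy", registry)
```

Normalizing up front means the transport only ever handles one payload shape, whatever the caller passed in.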