DistributedTransport¶
- class torchrl.weight_update.DistributedTransport(store=None, rank=None, sync=True)[source]¶
torch.distributed transport for communicating with a single distributed worker.
This transport handles weight updates for one specific distributed worker via torch.distributed send/recv. When several workers are used, one transport is created per worker, following the same pattern as the multiprocess collectors.
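A minimal sketch of setting up one transport per worker, assuming the TCPStore is shared between the trainer and the workers; the host, port, world size, and ranks below are placeholders, not values prescribed by torchrl:

```python
import torch.distributed as dist

from torchrl.weight_update import DistributedTransport

# Shared TCPStore used for signaling between the trainer and the workers.
# wait_for_workers=False keeps this sketch from blocking while the workers start.
store = dist.TCPStore(
    "127.0.0.1", 29500, world_size=3, is_master=True, wait_for_workers=False
)

# One transport per remote worker, mirroring the multiprocess-collector pattern.
transports = {rank: DistributedTransport(store=store, rank=rank) for rank in (1, 2)}
```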
- receive_weights(timeout: float = 1.0) tuple[str, Any] | None[source]¶
Receive weights via torch.distributed, using TCPStore for signaling.
This implements the RPC-like pattern:

1. Check the TCPStore for a signal (non-blocking).
2. If a signal is present, receive the weights via torch.distributed.
3. Clean up the signal and send an acknowledgment.

A worker-side usage sketch is given after this entry.
- Parameters:
timeout – Timeout for receiving weights (currently not used for the TCPStore check)
- Returns:
Tuple of (model_id, weights) if weights were received, None otherwise.
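As an illustration of the receive path, here is a minimal worker-side polling loop. The names transport and policy are placeholders, and applying the weights with load_state_dict is an assumption of this sketch rather than part of the transport's API:

```python
import time

def poll_for_weights(transport, policy, poll_interval: float = 0.1):
    """Illustrative polling loop around DistributedTransport.receive_weights."""
    while True:
        result = transport.receive_weights()
        if result is not None:
            model_id, weights = result
            # Applying the weights with load_state_dict is an assumption of this
            # sketch; how the caller applies them is up to the caller.
            policy.load_state_dict(weights)
        # Per the docstring above, receive_weights already cleans up the signal
        # and sends the acknowledgment, so send_ack is not called here.
        time.sleep(poll_interval)
```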
- send_ack(message: str = 'updated') None[source]¶
Send acknowledgment back to sender via TCPStore.
- Parameters:
message – Acknowledgment message to send (default: “updated”)
- send_weights(model_id: str, weights: Any) None[source]¶
Send weights to the distributed worker.
Note: model_id is not passed to the remote collector because remote collectors don't have weight senders; they apply the weights directly to their local policy.
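For the sending side, a minimal sketch of broadcasting the trainer's weights through one transport per worker; transports and policy are placeholder names carried over from the sketch above:

```python
def broadcast_weights(transports, policy, model_id: str = "policy"):
    """Illustrative trainer-side broadcast: one send per worker transport."""
    state_dict = policy.state_dict()
    for rank, transport in transports.items():
        # Each transport targets exactly one worker; as noted above, model_id
        # is only meaningful on the sender side.
        transport.send_weights(model_id, state_dict)
```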