DistributedTransport

class torchrl.weight_update.DistributedTransport(store=None, rank=None, sync=True)[source]

torch.distributed transport for communicating with a single distributed worker.

This transport handles weight updates for ONE specific distributed worker via torch.distributed send/recv. Multiple transports are created for multiple workers, following the same pattern as multiprocess collectors.
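Below is a minimal sketch of how a trainer-side process might create one transport per remote worker, assuming a torch.distributed process group (and, where needed, a store) has already been initialized elsewhere; the ranks and worker count are illustrative only.

```python
import torch.distributed as dist
from torchrl.weight_update import DistributedTransport

# Assumes the process group was initialized elsewhere, e.g. via
# dist.init_process_group("gloo", rank=0, world_size=num_workers + 1).
# One transport is created per remote worker, mirroring the
# multiprocess-collector pattern described above.
num_workers = 2  # illustrative value
transports = [
    DistributedTransport(rank=worker_rank, sync=True)
    for worker_rank in range(1, num_workers + 1)
]
```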

check_connection() → bool[source]

Check if torch.distributed is initialized.

receive_weights(timeout: float = 1.0) → tuple[str, Any] | None[source]

Distributed workers receive weights through torch.distributed primitives.
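On the worker side, a sketch of a receive step, assuming receive_weights() returns a (model_id, weights) tuple or None on timeout as the signature above suggests; apply_weights_to_policy and the rank argument are hypothetical placeholders for however the worker loads weights into its local policy.

```python
# Worker-side sketch: poll once for incoming weights and apply them locally.
transport = DistributedTransport(rank=0, sync=True)  # peer rank is illustrative

result = transport.receive_weights(timeout=1.0)
if result is not None:
    model_id, weights = result
    apply_weights_to_policy(model_id, weights)  # hypothetical helper
```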

send_weights(model_id: str, weights: Any) → None[source]

Send weights to the distributed worker.

Note: We don’t pass model_id to the remote collector because remote collectors don’t have weight senders - they apply weights directly to their local policy.
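A hedged sketch of a synchronous update: the sender checks that torch.distributed is initialized before pushing the current policy weights to a single worker. Snapshotting the weights with TensorDict.from_module is an assumption about how the weights are obtained, not something this class requires.

```python
from tensordict import TensorDict

# Sketch: push the current policy weights to one worker, synchronously.
# `policy` is assumed to be an nn.Module living on the trainer process.
weights = TensorDict.from_module(policy)  # assumed way of snapshotting weights

if transports[0].check_connection():  # torch.distributed must be initialized
    transports[0].send_weights(model_id="policy", weights=weights)
```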

send_weights_async(model_id: str, weights: Any) → None[source]

Send weights to distributed worker without waiting for acknowledgment.

Use wait_ack() to wait for acknowledgment after sending to all workers.

wait_ack() → None[source]

Wait for acknowledgment from distributed worker.
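Putting the two methods together, a sketch of the fan-out pattern the docstrings describe: issue an asynchronous send to every worker first, then block on each acknowledgment, so all workers are updated roughly in parallel rather than one at a time. The transports list and weights object are the illustrative ones from the sketches above.

```python
# Fan-out sketch: send to all workers without blocking, then collect acks.
for transport in transports:
    transport.send_weights_async(model_id="policy", weights=weights)

# Only after all sends have been issued do we block on the acknowledgments.
for transport in transports:
    transport.wait_ack()
```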
