
VLLMCollectiveTransport

class torchrl.weight_update.llm.VLLMCollectiveTransport(master_address: str, master_port: int, rank: int | None, world_size: int, device: torch.device | str | int | None = None, vllm_engine: Any | None = None)[source]

Transport for vLLM using collective communication (NCCL).

COLLECTIVE LAYER ONLY - This class handles the data transfer layer. RPC coordination is handled separately by the caller (sender/receiver).

This transport uses PyTorch distributed collectives to broadcast weights from a trainer (rank 0) to vLLM workers (ranks 1+).

Separation of Concerns:
  • This class: NCCL collective operations (GPU-GPU data transfer)

  • Caller (sender/receiver): RPC coordination (when to start the collective)

Parameters:
  • master_address – Address of the master node for distributed init.

  • master_port – Port of the master node for distributed init.

  • rank – Rank of this process (0 for trainer, 1+ for vLLM workers).

  • world_size – Total number of processes (1 + num_replicas * gpus_per_replica).

  • device – Device to use for communication (typically cuda:0).

  • vllm_engine – Optional vLLM engine reference (for receiver side).

Note

The RPC layer (e.g., Ray remote calls) must ensure all ranks call init_all_workers_group() simultaneously before any collective operations.
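
Example (a minimal construction sketch; the address, port, and replica counts are placeholder values, and the worker-side construction assumes the RPC layer supplies the vLLM engine reference):

    from torchrl.weight_update.llm import VLLMCollectiveTransport

    num_replicas = 2        # hypothetical number of vLLM replicas
    gpus_per_replica = 1    # hypothetical GPUs per replica
    world_size = 1 + num_replicas * gpus_per_replica

    # Trainer side (rank 0): the process that broadcasts weights.
    trainer_transport = VLLMCollectiveTransport(
        master_address="127.0.0.1",
        master_port=29500,
        rank=0,
        world_size=world_size,
        device="cuda:0",
    )

    # Worker side (rank >= 1): one transport per vLLM worker GPU, constructed
    # with the engine reference so received weights can be applied. `engine`
    # is a placeholder for the vLLM engine object.
    # worker_transport = VLLMCollectiveTransport(
    #     master_address="127.0.0.1",
    #     master_port=29500,
    #     rank=1,
    #     world_size=world_size,
    #     device="cuda:0",
    #     vllm_engine=engine,
    # )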

check_connection() bool[source]

Check if the communication group is initialized.

init_all_workers_group(model_metadata: dict[str, tuple[torch.dtype, torch.Size]])[source]

Initialize the collective communication group.

Parameters:

model_metadata – Dict mapping param names to (dtype, shape) tuples.
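
Example (a sketch of building the metadata from a model's state_dict; the nn.Linear stand-in and the trainer_transport name come from the construction sketch above):

    from torch import nn

    # Stand-in for the trainer's model; in practice this is the policy being trained.
    model = nn.Linear(8, 8)

    # Map each parameter name to its (dtype, shape), as expected by the transport.
    model_metadata = {
        name: (tensor.dtype, tensor.shape)
        for name, tensor in model.state_dict().items()
    }

    # Every rank (trainer and all vLLM workers) must make this call at the same
    # time, typically coordinated through the RPC layer (see the note above).
    trainer_transport.init_all_workers_group(model_metadata)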

receive_weights(timeout: float = 1.0) tuple[str, Any] | None[source]

Receive weights from broadcaster.

This should only be called from worker ranks (rank > 0). It is invoked internally by the vLLM engine through collective operations.

Returns:

None - vLLM handles weight application internally via collectives.

send_weights(model_id: str, weights: Any) None[source]

Broadcast weights to all workers using NCCL.

This method follows AsyncVLLM’s periodic-mono pattern: for each weight, an RPC is issued, the NCCL broadcast runs, and the call waits for the RPC to complete. A trainer-side sketch follows the parameter list below.

This should only be called from rank 0 (trainer).

Parameters:
  • model_id – ID of the model (used for logging).

  • weights – TensorDict or dict of weights to broadcast.
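
Example (a trainer-side sketch reusing the objects from the earlier snippets; the model_id string is arbitrary, and whether a plain dict of tensors or a TensorDict is passed is up to the caller):

    # Rank 0 only: broadcast the current parameters to all vLLM workers.
    weights = model.state_dict()

    if trainer_transport.check_connection():
        trainer_transport.send_weights(model_id="policy", weights=weights)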
