VLLMCollectiveTransport
- class torchrl.weight_update.llm.VLLMCollectiveTransport(master_address: str, master_port: int, rank: int | None, world_size: int, device: torch.device | str | int | None = None, vllm_engine: Any | None = None)[source]
Transport for vLLM using collective communication (NCCL).
COLLECTIVE LAYER ONLY - This class handles the data transfer layer. RPC coordination is handled separately by the caller (sender/receiver).
This transport uses PyTorch distributed collectives to broadcast weights from a trainer (rank 0) to vLLM workers (ranks 1+).
Separation of Concerns:
- This class: NCCL collective operations (GPU-GPU data transfer)
- Caller (sender/receiver): RPC coordination (when to start the collective)
- Parameters:
master_address – Address of the master node for distributed init.
master_port – Port of the master node for distributed init.
rank – Rank of this process (0 for trainer, 1+ for vLLM workers).
world_size – Total number of processes (1 + num_replicas * gpus_per_replica).
device – Device to use for communication (typically cuda:0).
vllm_engine – Optional vLLM engine reference (for receiver side).
Note
The RPC layer (e.g., Ray remote calls) must ensure all ranks call init_all_workers_group() simultaneously before any collective operations.
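As a rough sketch (not part of the API reference), the trainer and each vLLM worker would construct the transport in separate processes with matching connection details. The addresses, ports, and replica counts below are placeholder values, and the vLLM engine handle is assumed to be created elsewhere:

```python
import torch
from torchrl.weight_update.llm import VLLMCollectiveTransport

num_replicas, gpus_per_replica = 2, 1
world_size = 1 + num_replicas * gpus_per_replica  # trainer + vLLM worker ranks

# Trainer process (rank 0): later broadcasts weights with send_weights().
trainer_transport = VLLMCollectiveTransport(
    master_address="127.0.0.1",
    master_port=29500,
    rank=0,
    world_size=world_size,
    device="cuda:0",
)

# Worker process (ranks 1+): holds a reference to the vLLM engine, which
# applies the received weights internally.
vllm_engine = None  # placeholder for the engine handle created elsewhere
worker_transport = VLLMCollectiveTransport(
    master_address="127.0.0.1",
    master_port=29500,
    rank=1,
    world_size=world_size,
    device="cuda:0",
    vllm_engine=vllm_engine,
)
```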
- init_all_workers_group(model_metadata: dict[str, tuple[torch.dtype, torch.Size]])[source]
Initialize the collective communication group.
- Parameters:
model_metadata – Dict mapping param names to (dtype, shape) tuples.
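For illustration, the metadata can be derived from the trainer's state_dict; `model` below is a stand-in module, and the same call must be reached by every rank through the caller's RPC layer (continuing the sketch above):

```python
import torch

model = torch.nn.Linear(8, 8)  # stand-in for the trainer's policy module

# (dtype, shape) pairs that all ranks must agree on before any broadcast.
model_metadata = {
    name: (p.dtype, p.shape) for name, p in model.state_dict().items()
}

# The equivalent call must happen on every rank (trainer and all vLLM
# workers), typically triggered through the RPC layer (e.g. Ray remote calls).
trainer_transport.init_all_workers_group(model_metadata)
```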
- receive_weights(timeout: float = 1.0) → tuple[str, Any] | None[source]
Receive weights from broadcaster.
This should only be called from worker ranks (rank > 0). It is invoked by the vLLM engine internally through collective operations.
- Returns:
None – vLLM handles weight application internally via collectives.
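On a worker rank this is essentially a handshake; continuing the sketch above, the call carries no payload back because vLLM consumes the broadcast tensors through its own collective hooks:

```python
# Worker rank (rank >= 1): take part in the broadcast started by rank 0.
# Returns None; the weights are applied inside the vLLM engine.
worker_transport.receive_weights(timeout=5.0)
```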
- send_weights(model_id: str, weights: Any) → None[source]
Broadcast weights to all workers using NCCL.
This method follows AsyncVLLM’s periodic-mono pattern: for each weight, RPC → NCCL broadcast → wait for RPC completion.
This should only be called from rank 0 (trainer).
- Parameters:
model_id – ID of the model (used for logging).
weights – TensorDict or dict of weights to broadcast.
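Continuing the sketch above, the trainer-side update can pass the module's state_dict directly; a TensorDict with the same structure would also be accepted:

```python
# Rank 0 only: broadcast the current weights to every vLLM worker rank.
trainer_transport.send_weights(model_id="policy", weights=model.state_dict())
```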