VLLMCollectiveTransport
- class torchrl.weight_update.llm.VLLMCollectiveTransport(master_address: str, master_port: int, rank: int | None, world_size: int, device: device | str | int | None = None, vllm_engine: Any | None = None)[source]
Transport for vLLM using vLLM’s native WeightTransferConfig API (vLLM 0.17+).
This transport uses vLLM’s built-in NCCL weight transfer engine to broadcast weights from a trainer (rank 0) to vLLM workers (ranks 1+).
- Parameters:
master_address – Address of the master node for distributed init.
master_port – Port of the master node for distributed init.
rank – Rank of this process (0 for trainer, 1+ for vLLM workers).
world_size – Total number of processes (1 + num_replicas * gpus_per_replica).
device – Device to use for communication (typically cuda:0).
vllm_engine – Optional vLLM engine reference, used on the receiver (worker) side. A usage sketch follows this list.
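As a minimal sketch of constructing the trainer-side transport (the address, port, and topology below are illustrative placeholders, not values mandated by this API):

```python
import torch
from torchrl.weight_update.llm import VLLMCollectiveTransport

# Trainer side (rank 0). With one vLLM replica spanning one GPU,
# world_size = 1 + num_replicas * gpus_per_replica = 1 + 1 * 1 = 2.
trainer_transport = VLLMCollectiveTransport(
    master_address="127.0.0.1",  # placeholder master address
    master_port=29500,           # placeholder port
    rank=0,                      # rank 0 is the trainer / broadcaster
    world_size=2,
    device="cuda:0",
)
```

On the receiver side, each worker would construct the same transport with its own rank (1+) and pass its engine handle via vllm_engine so received weights can be handed to the engine.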
- init_all_workers_group(model_metadata: dict[str, tuple[dtype, Size]], gpus_per_replica: int | None = None)[source]
Initialize the collective communication group using vLLM’s native API.
- Parameters:
model_metadata – Dict mapping param names to (dtype, shape) tuples.
gpus_per_replica – Number of GPUs per replica, used for the rank_offset calculation. Inferred if not provided. See the sketch below.
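For illustration, a hedged sketch of building model_metadata from a stand-in module (the nn.Linear here is purely hypothetical; any model whose weights are pushed to vLLM works the same way) and initializing the group, continuing from the trainer_transport constructed above:

```python
from torch import nn

# model_metadata maps each parameter name to a (dtype, shape) tuple so that
# receivers can pre-allocate buffers before the NCCL broadcast.
model = nn.Linear(16, 4)  # stand-in for the actual training model
model_metadata = {
    name: (p.dtype, p.shape) for name, p in model.named_parameters()
}
# -> {"weight": (torch.float32, torch.Size([4, 16])),
#     "bias":   (torch.float32, torch.Size([4]))}

# One replica spanning one GPU; omit gpus_per_replica to let it be inferred.
trainer_transport.init_all_workers_group(model_metadata, gpus_per_replica=1)
```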