VLLMWeightSyncScheme¶
- class torchrl.weight_update.llm.VLLMWeightSyncScheme(master_address: str | None = None, master_port: int | None = None, gpus_per_replica: int = 1, num_replicas: int = 1, strategy: Literal['tensordict', 'state_dict'] = 'tensordict', device: torch.device | str | int = 0)[source]¶
Weight synchronization scheme for vLLM engines.
This scheme uses collective communication (NCCL) to broadcast weights from a trainer to vLLM inference workers with parallelism support.
- Parameters:
master_address – Address of the master node. Defaults to “localhost”.
master_port – Port of the master node. If None, will auto-assign.
gpus_per_replica – Number of GPUs per replica (tp_size × dp_size × pp_size).
num_replicas – Number of vLLM engine replicas. Defaults to 1.
strategy – Weight extraction strategy (“tensordict” or “state_dict”).
device – Device index to use for communication. Defaults to 0. Note: When using Ray, each actor sees only its assigned GPU as device 0 due to CUDA_VISIBLE_DEVICES isolation. You should typically use 0.
Warning
Collective communication requires ALL ranks to participate simultaneously. Both the sender (trainer, rank 0) and all receivers (vLLM workers, ranks 1+) must call init_all_workers_group() at approximately the same time for the collective handshake to succeed. Do NOT wait for one init to complete before starting the other; start both and wait for both together.
Note
The world_size for NCCL will be: 1 (trainer) + num_replicas × gpus_per_replica (vLLM workers)
Example
>>> # Single replica with 2 GPUs (e.g., tp_size=2)
>>> scheme = VLLMWeightSyncScheme(
...     master_port=12345,
...     gpus_per_replica=2,
...     num_replicas=1,
...     strategy="tensordict"
... )  # world_size = 1 + 1*2 = 3
>>>
>>> # Multiple replicas with 1 GPU each
>>> scheme = VLLMWeightSyncScheme(
...     master_port=12345,
...     gpus_per_replica=1,
...     num_replicas=2,
...     strategy="tensordict"
... )  # world_size = 1 + 2*1 = 3
>>>
>>> # Multiple replicas with tp_size=2, dp_size=1, pp_size=1
>>> scheme = VLLMWeightSyncScheme(
...     master_port=12345,
...     gpus_per_replica=2,  # 2*1*1
...     num_replicas=3,
...     strategy="tensordict"
... )  # world_size = 1 + 3*2 = 7
>>>
>>> # In trainer process (rank 0)
>>> sender = VLLMWeightSender(scheme)
>>> sender.register_model(policy)
>>>
>>> # In vLLM worker process (rank 1+)
>>> receiver = VLLMWeightReceiver(scheme, vllm_engine)
>>>
>>> # IMPORTANT: Both must init simultaneously for collective handshake
>>> # With Ray:
>>> init_sender = sender_actor.init_all_workers_group.remote(metadata)
>>> init_receiver = receiver_actor.init_all_workers_group.remote(metadata)
>>> ray.get([init_sender, init_receiver])  # Wait for both together
>>>
>>> # After init, updates work normally
>>> sender.update_weights()
>>> # Weights are received automatically via collectives
- create_receiver(vllm_engine) → VLLMWeightReceiver[source]¶
Create a weight receiver for a vLLM worker process.
- Parameters:
vllm_engine – The vLLM engine instance (must implement RLvLLMEngine interface).
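A minimal usage sketch; here vllm_engine stands in for any object implementing the RLvLLMEngine interface, created elsewhere in the worker process:
>>> # In each vLLM worker process (ranks 1+)
>>> receiver = scheme.create_receiver(vllm_engine)
>>> # init_all_workers_group() must still run on all ranks before weights can flow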
- create_sender() → VLLMWeightSender[source]¶
Create a weight sender for the trainer process.
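A minimal usage sketch on the trainer side; policy stands in for the model being trained:
>>> # In the trainer process (rank 0)
>>> sender = scheme.create_sender()
>>> sender.register_model(policy)
>>> # After init_all_workers_group() has completed on all ranks:
>>> sender.update_weights()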
- create_transport(pipe_or_context: Any) → VLLMCollectiveTransport[source]¶
Create transport for collective communication.
For vLLM, this creates a transport, but the transport still requires additional setup via init_all_workers_group(). The method is required by the base class; in practice, transport creation for vLLM is more involved and is typically handled during sender/receiver initialization.
- Parameters:
pipe_or_context – Not used for vLLM (kept for API compatibility).
- Returns:
A VLLMCollectiveTransport instance (needs init_all_workers_group() to be called).
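Direct use is uncommon; the sketch below is illustrative only and assumes the transport is not being created through create_sender()/create_receiver():
>>> # pipe_or_context is ignored for vLLM, so None is passed here
>>> transport = scheme.create_transport(None)
>>> # init_all_workers_group() must still be called before any weight transfer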