VLLMWeightSyncScheme¶
- class torchrl.weight_update.llm.VLLMWeightSyncScheme(master_address: str | None = None, master_port: int | None = None, gpus_per_replica: int = 1, num_replicas: int = 1, strategy: Literal['tensordict', 'state_dict'] = 'tensordict', device: torch.device | str | int = 0)[source]¶
Weight synchronization scheme for vLLM engines.
This scheme uses collective communication (NCCL) to broadcast weights from a trainer to vLLM inference workers with parallelism support.
- Parameters:
master_address – Address of the master node. Defaults to “localhost”.
master_port – Port of the master node. If None, a port will be auto-assigned.
gpus_per_replica – Number of GPUs per replica (tp_size × dp_size × pp_size).
num_replicas – Number of vLLM engine replicas. Defaults to 1.
strategy – Weight extraction strategy (“tensordict” or “state_dict”).
device – Device index to use for communication. Defaults to 0. Note: When using Ray, each actor sees only its assigned GPU as device 0 due to CUDA_VISIBLE_DEVICES isolation. You should typically use 0.
Warning
Collective communication requires ALL ranks to participate simultaneously. Both the sender (trainer, rank 0) and all receivers (vLLM workers, ranks 1+) must call init_all_workers_group() at approximately the same time for the collective handshake to succeed. Do NOT wait for one init to complete before starting the other: start both, then wait for both together.
Note
The world_size for NCCL will be: 1 (trainer) + num_replicas × gpus_per_replica (vLLM workers).
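The world_size formula in the note above can be sketched as a tiny helper. Note that `nccl_world_size` is a hypothetical illustrative function, not part of the torchrl API:

```python
# Illustrative helper (not a torchrl function): computes the NCCL world
# size used by VLLMWeightSyncScheme from the formula in the note above.
def nccl_world_size(num_replicas: int, gpus_per_replica: int) -> int:
    # 1 trainer rank plus one rank per GPU across all vLLM replicas
    return 1 + num_replicas * gpus_per_replica

# Matches the configurations in the Example section:
print(nccl_world_size(num_replicas=1, gpus_per_replica=2))  # 3
print(nccl_world_size(num_replicas=2, gpus_per_replica=1))  # 3
print(nccl_world_size(num_replicas=3, gpus_per_replica=2))  # 7
```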
Example
>>> # Single replica with 2 GPUs (e.g., tp_size=2)
>>> scheme = VLLMWeightSyncScheme(
...     master_port=12345,
...     gpus_per_replica=2,
...     num_replicas=1,
...     strategy="tensordict"
... )  # world_size = 1 + 1*2 = 3
>>>
>>> # Multiple replicas with 1 GPU each
>>> scheme = VLLMWeightSyncScheme(
...     master_port=12345,
...     gpus_per_replica=1,
...     num_replicas=2,
...     strategy="tensordict"
... )  # world_size = 1 + 2*1 = 3
>>>
>>> # Multiple replicas with tp_size=2, dp_size=1, pp_size=1
>>> scheme = VLLMWeightSyncScheme(
...     master_port=12345,
...     gpus_per_replica=2,  # 2*1*1
...     num_replicas=3,
...     strategy="tensordict"
... )  # world_size = 1 + 3*2 = 7
>>>
>>> # In trainer process (rank 0)
>>> sender = VLLMWeightSender(scheme)
>>> sender.register_model(policy)
>>>
>>> # In vLLM worker process (rank 1+)
>>> receiver = VLLMWeightReceiver(scheme, vllm_engine)
>>>
>>> # IMPORTANT: Both must init simultaneously for collective handshake
>>> # With Ray:
>>> init_sender = sender_actor.init_all_workers_group.remote(metadata)
>>> init_receiver = receiver_actor.init_all_workers_group.remote(metadata)
>>> ray.get([init_sender, init_receiver])  # Wait for both together
>>>
>>> # After init, updates work normally
>>> sender.update_weights()
>>> # Weights are received automatically via collectives
- create_receiver(vllm_engine) → VLLMWeightReceiver[source]¶
Create a weight receiver for a vLLM worker process.
- Parameters:
vllm_engine – The vLLM engine instance (must implement RLvLLMEngine interface).
- create_sender() → VLLMWeightSender[source]¶
Create a weight sender for the trainer process.
- create_transport(pipe_or_context: Any) → VLLMCollectiveTransport[source]¶
Create transport for collective communication.
For vLLM, this creates a transport, but additional setup via init_all_workers_group() is required before it can be used. This method is required by the base class; in practice, transport creation for vLLM is more involved and is typically handled during sender/receiver initialization.
- Parameters:
pipe_or_context – Not used for vLLM (kept for API compatibility).
- Returns:
A VLLMCollectiveTransport instance (needs init_all_workers_group() to be called).
- get_receiver() → WeightReceiver¶
Get the receiver instance.
- Returns:
Receiver instance for receiving weights in this worker
- Raises:
RuntimeError – If init_on_worker() hasn’t been called yet
- get_sender() → WeightSender¶
Get the sender instance.
- Returns:
Sender instance for sending weights to workers
- Raises:
RuntimeError – If init_on_sender() hasn’t been called yet
- init_on_sender(model_id: str, context: Any = None, **kwargs) → None¶
Initialize on the main process (sender side).
This method is called once in the collector’s _run_processes() method, after workers have been started and are ready to receive messages.
- Parameters:
model_id – Identifier for the model being synchronized
context – Optional context object (e.g., collector) providing:
- .pipes: list[mp.Connection]
- .get_model(model_id: str) -> nn.Module
- .get_cached_weights(model_id: str) -> TensorDict | None
- .num_workers: int
**kwargs – Alternative to context (pipes, num_workers, model, cached_weights, etc.)
- init_on_worker(model_id: str, context: Any = None, **kwargs) → None¶
Initialize on worker process (receiver side).
This method is called once in each worker’s initialization.
- Parameters:
model_id – Identifier for the model being synchronized
context – Optional context object (e.g., inner collector) providing:
- .pipe: mp.Connection
- .get_model(model_id: str) -> nn.Module
**kwargs – Alternative to context (pipe, model, etc.)
- prepare_weights(weights: Any, model_id: str, strategy: WeightStrategy, context: Any = None) → Any¶
Prepare weights for sending.
This method handles weight extraction, conversion, and any scheme-specific preparation (e.g., cache lookups for SharedMemWeightSyncScheme).
- Parameters:
weights – Raw weights input (can be None, nn.Module, TensorDict, dict, str reference, etc.)
model_id – The model identifier (e.g., “policy”)
strategy – WeightStrategy for extracting/converting weights
context – Optional context (e.g., collector) for model resolution
- Returns:
Prepared weights ready to send via transport
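The resolution logic described above can be sketched in plain Python. This is a hypothetical stand-in, not the torchrl implementation: `ExtractStrategy` and the dict-based context below are illustrative placeholders for WeightStrategy and the collector context.

```python
# Hypothetical sketch of the dispatch prepare_weights performs.
# ExtractStrategy and the dict-based context are stand-ins, not torchrl classes.
class ExtractStrategy:
    """Mimics a WeightStrategy that extracts weights as a flat mapping."""
    def extract(self, model) -> dict:
        # A real strategy would call model.state_dict() or build a TensorDict
        return dict(model)

def prepare_weights(weights, model_id, strategy, context=None):
    if weights is None and context is not None:
        # Resolve the model from the context by its identifier
        weights = context[model_id]
    if not isinstance(weights, dict):
        # Raw model objects are converted via the strategy
        weights = strategy.extract(weights)
    return weights

context = {"policy": {"layer.weight": [1.0, 2.0]}}
out = prepare_weights(None, "policy", ExtractStrategy(), context)
print(out)  # {'layer.weight': [1.0, 2.0]}
```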