
VLLMWeightSyncScheme

class torchrl.weight_update.llm.VLLMWeightSyncScheme(master_address: str | None = None, master_port: int | None = None, gpus_per_replica: int = 1, num_replicas: int = 1, strategy: Literal['tensordict', 'state_dict'] = 'tensordict', device: torch.device | str | int = 0)[source]

Weight synchronization scheme for vLLM engines.

This scheme uses collective communication (NCCL) to broadcast weights from a trainer to vLLM inference workers with parallelism support.

Parameters:
  • master_address – Address of the master node. Defaults to “localhost”.

  • master_port – Port of the master node. If None, will auto-assign.

  • gpus_per_replica – Number of GPUs per replica (tp_size × dp_size × pp_size).

  • num_replicas – Number of vLLM engine replicas. Defaults to 1.

  • strategy – Weight extraction strategy (“tensordict” or “state_dict”).

  • device – Device index to use for communication. Defaults to 0. Note: When using Ray, each actor sees only its assigned GPU as device 0 due to CUDA_VISIBLE_DEVICES isolation. You should typically use 0.

Warning

Collective communication requires ALL ranks to participate simultaneously. Both the sender (trainer, rank 0) and all receivers (vLLM workers, ranks 1+) must call init_all_workers_group() at approximately the same time for the collective handshake to succeed. Do NOT wait for one init to complete before starting the other: start both, then wait for both together.
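The ordering constraint above can be illustrated generically. The sketch below uses threads and a barrier as stand-ins for the real multi-process NCCL handshake (trainer_init and worker_init are placeholders, not actual TorchRL or vLLM calls): each side blocks inside its init until the other arrives, so calling them sequentially would deadlock, while launching both and then waiting for both succeeds.

```python
import threading
from concurrent.futures import ThreadPoolExecutor

# Stand-in for the collective handshake: every rank must arrive
# before any rank can proceed, mimicking NCCL group initialization.
handshake = threading.Barrier(2)

def trainer_init():
    # Placeholder for sender.init_all_workers_group(metadata) on rank 0.
    handshake.wait(timeout=5)
    return "trainer ready"

def worker_init():
    # Placeholder for receiver.init_all_workers_group(metadata) on rank 1+.
    handshake.wait(timeout=5)
    return "worker ready"

# Correct pattern: start BOTH inits, then wait for both together,
# mirroring ray.get([init_sender, init_receiver]) in the Ray example below.
with ThreadPoolExecutor(max_workers=2) as pool:
    futures = [pool.submit(trainer_init), pool.submit(worker_init)]
    results = [f.result() for f in futures]
```

Waiting on trainer_init() before submitting worker_init() would block forever at the barrier, which is exactly the failure mode the warning describes.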

Note

The world_size for NCCL will be: 1 (trainer) + num_replicas × gpus_per_replica (vLLM workers)
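The formula above can be spelled out as a small helper. This is a hypothetical convenience function for illustration only, not part of the TorchRL API:

```python
def nccl_world_size(num_replicas: int, tp_size: int = 1,
                    dp_size: int = 1, pp_size: int = 1) -> int:
    """Hypothetical helper mirroring the formula above:
    1 trainer rank + num_replicas * gpus_per_replica worker ranks."""
    gpus_per_replica = tp_size * dp_size * pp_size
    return 1 + num_replicas * gpus_per_replica

# 3 replicas with tp_size=2 each -> 1 + 3*2 = 7 ranks
print(nccl_world_size(num_replicas=3, tp_size=2))
```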

Example

>>> # Single replica with 2 GPUs (e.g., tp_size=2)
>>> scheme = VLLMWeightSyncScheme(
...     master_port=12345,
...     gpus_per_replica=2,
...     num_replicas=1,
...     strategy="tensordict"
... )  # world_size = 1 + 1*2 = 3
>>>
>>> # Multiple replicas with 1 GPU each
>>> scheme = VLLMWeightSyncScheme(
...     master_port=12345,
...     gpus_per_replica=1,
...     num_replicas=2,
...     strategy="tensordict"
... )  # world_size = 1 + 2*1 = 3
>>>
>>> # Multiple replicas with tp_size=2, dp_size=1, pp_size=1
>>> scheme = VLLMWeightSyncScheme(
...     master_port=12345,
...     gpus_per_replica=2,  # 2*1*1
...     num_replicas=3,
...     strategy="tensordict"
... )  # world_size = 1 + 3*2 = 7
>>>
>>> # In trainer process (rank 0)
>>> sender = VLLMWeightSender(scheme)
>>> sender.register_model(policy)
>>>
>>> # In vLLM worker process (rank 1+)
>>> receiver = VLLMWeightReceiver(scheme, vllm_engine)
>>>
>>> # IMPORTANT: Both must init simultaneously for collective handshake
>>> # With Ray:
>>> init_sender = sender_actor.init_all_workers_group.remote(metadata)
>>> init_receiver = receiver_actor.init_all_workers_group.remote(metadata)
>>> ray.get([init_sender, init_receiver])  # Wait for both together
>>>
>>> # After init, updates work normally
>>> sender.update_weights()
>>> # Weights are received automatically via collectives
create_receiver(vllm_engine) → VLLMWeightReceiver[source]

Create a weight receiver for a vLLM worker process.

Parameters:

vllm_engine – The vLLM engine instance (must implement RLvLLMEngine interface).

create_sender() → VLLMWeightSender[source]

Create a weight sender for the trainer process.

create_transport(pipe_or_context: Any) → VLLMCollectiveTransport[source]

Create transport for collective communication.

For vLLM, this creates a transport that still requires additional setup via init_all_workers_group(). The method is required by the base class, but transport creation for vLLM is more involved and is typically handled during sender/receiver initialization.

Parameters:

pipe_or_context – Not used for vLLM (kept for API compatibility).

Returns:

A VLLMCollectiveTransport instance; init_all_workers_group() must still be called before it can be used.
