
Source code for torchrl.modules.llm.backends.vllm

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Recipes for vLLM instantiation.

From https://docs.vllm.ai/en/v0.7.0/getting_started/examples/rlhf.html
"""


from __future__ import annotations

import os

from contextlib import nullcontext

import torch
from torchrl._utils import logger as torchrl_logger

from torchrl.modules.llm.utils import _cuda_visible_devices

try:
    from vllm import LLM
    from vllm.utils import get_open_port
    from vllm.worker.worker import Worker
except ImportError:
    # vLLM is an optional dependency: define placeholder symbols so that this
    # module can still be imported (and documented) without it.

    class LLM:  # noqa
        ...

    class Worker:  # noqa
        ...

    class get_open_port:  # noqa
        ...


def stateless_init_process_group(
    master_address: str | None,
    master_port: str | None,
    rank,
    world_size,
    device,
):
    """Initializes a stateless process group for distributed communication.

    Creates a `StatelessProcessGroup` instance without relying on the global
    process group in `torch.distributed`. This approach is recommended for
    initializing data-plane communication (NCCL) between external processes
    (e.g., training processes) and vLLM workers.

    Args:
        master_address (str | None): The address of the master node. Defaults
            to "localhost" if not specified.
        master_port (str | None): The port used by the master node.
            Automatically assigns an open port if not specified.
        rank (int): The rank of the current process.
        world_size (int): The total number of processes in the distributed
            group.
        device: The device to use for communication.

    Returns:
        PyNcclCommunicator: A PyNcclCommunicator instance initialized with the
            created StatelessProcessGroup.
    """
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup

    if master_address is None:
        master_address = "localhost"  # get_ip()
    if master_port is None:
        master_port = get_open_port()

    pg = StatelessProcessGroup.create(
        host=master_address, port=master_port, rank=rank, world_size=world_size
    )
    pynccl = PyNcclCommunicator(pg, device=device)
    return pynccl
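
In an RLHF loop, the trainer (or parameter server) joins this group as rank 0 before broadcasting weights. A minimal trainer-side sketch, assuming one trainer plus two tensor-parallel vLLM workers; the same address and port must be handed to the workers:

# Illustrative only; not part of the module.
port = get_open_port()
trainer_comm = stateless_init_process_group(
    master_address="localhost",
    master_port=port,
    rank=0,  # the trainer always takes rank 0
    world_size=3,  # 1 trainer + 2 tensor-parallel vLLM workers
    device=torch.device("cuda:0"),
)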
class vLLMWorker(Worker):
    """vLLM worker for Ray.

    vLLMParameterServer will always take rank 0 in the stateless process group
    initialized by this worker, and the tensor-parallel ranks associated with
    the LLM class will be in the range [1, tp_size].
    """

    def __init__(self, *args, **kwargs):
        torchrl_logger.info(f"=> in {type(self).__name__}.__init__")
        torchrl_logger.info(f"visible devices {os.getenv('CUDA_VISIBLE_DEVICES')}")
        torchrl_logger.info(f"device count {torch.cuda.device_count()}")
        super().__init__(*args, **kwargs)

    def init_weight_update_group(
        self, master_address, master_port, rank_offset, world_size
    ):
        from vllm.distributed.parallel_state import get_world_group

        torchrl_logger.info(f"=> in {type(self).__name__}.init_weight_update_group")

        # Get the local rank within the tensor parallel group
        tp_group = get_world_group()
        local_rank = tp_group.rank
        torchrl_logger.info(f"Local rank in tensor parallel group: {local_rank}")

        # Compute the global rank for the weight update group;
        # rank_offset is 1, so the ranks will be [1, 2] for tp_size=2
        rank = local_rank + rank_offset
        torchrl_logger.info(
            f"Initializing {type(self).__name__} weight update group with "
            f"{master_address=}, {master_port=}, {rank=}, {world_size=}, "
            f"device={self.device}"
        )
        self.model_update_group = stateless_init_process_group(
            master_address,
            master_port,
            rank,
            world_size,
            self.device,
        )
        torchrl_logger.info(f"{type(self).__name__}.init_weight_update_group success")

    def update_weight_broadcast(self, name, dtype, shape):
        # Receive a weight broadcast from rank 0 (the trainer) and load it
        # into the model held by this worker.
        weight = torch.empty(shape, dtype=dtype, device="cuda")
        self.model_update_group.broadcast(
            weight, src=0, stream=torch.cuda.current_stream()
        )
        self.model_runner.model.load_weights(weights=[(name, weight)])
        del weight

    def update_weight(self, name, weight):
        self.model_runner.model.load_weights(weights=[(name, weight)])
        del weight
    def check_weights_changed(self):
        """Check if the weights are updated to 0."""
        # TODO: This is a test and should be treated as such
        weights_updated = True
        for p in self.model_runner.model.parameters():
            weights_updated = weights_updated and torch.allclose(
                p, torch.zeros_like(p)
            )
        return weights_updated
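
Putting these hooks together, a trainer drives one weight-synchronization round roughly as follows. This is an illustrative sketch, not part of the module: it assumes a Ray actor built by make_vllm_worker (below) with two tensor-parallel workers, uses vLLM's collective_rpc to reach every vLLMWorker (as in the vLLM v0.7 RLHF recipe), and treats trainer_model as a placeholder for the training-side module.

def example_weight_sync(worker, trainer_model):
    """Hypothetical helper; names and world size are assumptions."""
    import ray

    master_address, master_port = "localhost", get_open_port()
    world_size = 3  # trainer (rank 0) + 2 tensor-parallel workers (ranks 1-2)

    # Ask every vLLMWorker to join the stateless group at ranks [1, 2] ...
    handle = worker.collective_rpc.remote(
        "init_weight_update_group",
        args=(master_address, master_port, 1, world_size),
    )
    # ... while the trainer joins as rank 0.
    comm = stateless_init_process_group(
        master_address, master_port, 0, world_size, torch.device("cuda:0")
    )
    ray.get(handle)

    # Broadcast each parameter; workers receive it in update_weight_broadcast.
    for name, p in trainer_model.named_parameters():
        handle = worker.collective_rpc.remote(
            "update_weight_broadcast", args=(name, p.dtype, p.shape)
        )
        comm.broadcast(p.data, src=0, stream=torch.cuda.current_stream())
        ray.get(handle)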
class LLMOnDevice(LLM):
    """A thin wrapper around `vllm.LLM` to control its placement devices."""

    def __init__(self, *args, bundle_indices: list | None = None, **kwargs):
        # Stop Ray from manipulating CUDA_VISIBLE_DEVICES at the top level
        os.environ.pop("CUDA_VISIBLE_DEVICES", None)

        # Configure GPU utilization for Ray workers
        if bundle_indices is not None:
            # Allow multiple workers per GPU
            os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4"
            os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices))
            torchrl_logger.info(
                f"Initializing LLM with bundle_indices={bundle_indices}"
            )
        self.args = args
        self.kwargs = kwargs

    def initialize(self):
        # Deferred construction: let vLLM handle device placement
        super().__init__(*self.args, **self.kwargs)
        return True
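
The deferred initialize() call is what lets Ray place the actor before vLLM claims its GPUs. A hedged sketch of manual construction on an already-initialized Ray cluster (make_vllm_worker below wraps exactly this pattern); the model name is a placeholder:

# Illustrative only; not part of the module.
import ray

actor_cls = ray.remote(num_cpus=4, num_gpus=0)(LLMOnDevice)
actor = actor_cls.remote(
    model="Qwen/Qwen2.5-3B",
    bundle_indices=[0, 1],  # place the engine on GPUs 0 and 1
    tensor_parallel_size=2,
    distributed_executor_backend="ray",
)
ray.get(actor.initialize.remote())  # runs vllm.LLM.__init__ inside the actor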
def make_vllm_worker(
    *,
    model_name: str,
    devices: list[torch.device | int] | None = None,
    num_devices: int | None = None,
    make_ray_worker: bool = True,
    enforce_eager: bool = False,
    **kwargs,
) -> LLM | ray.actor.ActorClass:  # noqa
    """Creates a vLLM inference engine with tensor parallelism support.

    Args:
        model_name (str): The model name to pass to vllm.LLM.
        devices (list[torch.device | int], optional): List of devices to use.
            Exclusive with num_devices.
        num_devices (int, optional): Number of devices to use. Exclusive with
            devices.
        make_ray_worker (bool, optional): Whether to create a Ray actor.
            Defaults to True.
        enforce_eager (bool, optional): Whether to enforce eager execution.
            Defaults to `False`.
        **kwargs: Additional arguments passed to vllm.LLM.__init__.

    Returns:
        LLM | ray.actor.ActorClass: Either a local vLLM LLM instance or a Ray
            actor handle.

    Example:
        >>> # Create a 2-GPU tensor parallel worker with Ray
        >>> worker = make_vllm_worker(model_name="Qwen/Qwen2.5-3B", num_devices=2)
        >>> # Create a local LLM instance on GPU 1
        >>> llm = make_vllm_worker(model_name="Qwen/Qwen2.5-3B", devices=[1], make_ray_worker=False)
    """
    # Handle device specification
    if num_devices is not None and devices is not None:
        raise ValueError("Cannot specify both num_devices and devices")
    if num_devices is not None:
        devices = None
    elif devices is None:
        devices = [0]  # Default to first GPU
        num_devices = 1
    else:
        # Normalize user-provided devices to integer indices
        devices = [
            torch.device(device).index if not isinstance(device, int) else device
            for device in devices
        ]
        num_devices = len(devices)

    # Validate devices
    if devices is not None:
        for d in devices:
            if not isinstance(d, int) or d < 0 or d >= torch.cuda.device_count():
                raise ValueError(f"Invalid device index: {d}")

    if make_ray_worker:
        import ray

        if not ray.is_initialized():
            raise RuntimeError("Ray is not initialized")
        torchrl_logger.info(
            f"Creating vLLM Ray worker with tensor_parallel_size={num_devices}"
        )

        # Configure the Ray remote class with minimal resources and let vLLM
        # handle GPU allocation through environment variables
        worker_cls = ray.remote(
            num_cpus=4,  # Minimal CPU request
            num_gpus=0,  # Let vLLM handle GPU allocation
        )(LLMOnDevice)

        # Create the worker with the tensor parallelism config
        worker = worker_cls.remote(
            model=model_name,
            bundle_indices=devices,  # Pass device indices to LLMOnDevice
            tensor_parallel_size=num_devices,
            distributed_executor_backend="ray",
            enforce_eager=enforce_eager,
            worker_cls="torchrl.modules.llm.backends.vllm.vLLMWorker",
            **kwargs,
        )
        ray.get(worker.initialize.remote())
        return worker
    else:
        # Local non-Ray mode: use LLM directly
        with _cuda_visible_devices(devices) if devices is not None else nullcontext():
            torchrl_logger.info(
                f"Creating local vLLM LLM with tensor_parallel_size={num_devices}, "
                f"devices={devices}"
            )
            return LLM(
                model=model_name,
                tensor_parallel_size=num_devices,
                enforce_eager=enforce_eager,
                **kwargs,
            )
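
An end-to-end sketch of the Ray path, from worker creation to generation. SamplingParams is vLLM's standard sampling configuration; the model name and sampling values are illustrative:

# Illustrative only; not part of the module.
import ray
from vllm import SamplingParams

ray.init()
worker = make_vllm_worker(model_name="Qwen/Qwen2.5-3B", num_devices=2)
params = SamplingParams(temperature=0.7, max_tokens=64)
outputs = ray.get(worker.generate.remote("Hello, world!", params))
print(outputs[0].outputs[0].text)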
