Source code for torchrl.modules.llm.backends.vllm.vllm_async
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Async vLLM engine implementation for efficient batching and inference.
This module provides an async vLLM engine that leverages native vLLM batching
for better performance and memory efficiency compared to the explicit batching
approach used in the legacy vLLM backend.
"""
from __future__ import annotations
import asyncio
import os
import random
import uuid
from collections.abc import Iterator, Sequence
from typing import Any, Literal, TYPE_CHECKING
import torch
from torchrl._utils import logger as torchrl_logger
# Import RLvLLMEngine and shared utilities
from .base import RLvLLMEngine
from .vllm_utils import stateless_init_process_group
try:
import ray
from ray.util.placement_group import placement_group, remove_placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
except ImportError:
ray = None
def placement_group(*args, **kwargs):
"""Placement group is not available when ray is not installed."""
raise ImportError(
"ray is not installed. Please install it with `pip install ray`."
)
def remove_placement_group(*args, **kwargs):
"""Remove placement group is not available when ray is not installed."""
raise ImportError(
"ray is not installed. Please install it with `pip install ray`."
)
class PlacementGroupSchedulingStrategy:
"""Placement group scheduling strategy is not available when ray is not installed."""
def __init__(self, *args, **kwargs):
raise ImportError(
"ray is not installed. Please install it with `pip install ray`."
)
if TYPE_CHECKING:
    from vllm import AsyncEngineArgs, RequestOutput, SamplingParams
TIMEOUT_SECONDS = int(os.getenv("TORCHRL_VLLM_TIMEOUT_SECONDS", "300"))
try:
import vllm
_has_vllm = True
except ImportError:
vllm = None
_has_vllm = False
if not _has_vllm:
class Worker:
"""Placeholder for Worker class when vLLM is not installed."""
def __init__(self, *args, **kwargs):
raise ImportError(
"vllm is not installed. Please install it with `pip install vllm`."
)
else:
from vllm.worker.worker import Worker
class _AsyncvLLMWorker(Worker):
"""Async vLLM worker for Ray with weight update capabilities.
This worker extends the base vLLM Worker to support async operations
and weight updates via NCCL communication groups.
"""
def __init__(self, *args, **kwargs):
torchrl_logger.info(f"=> in {type(self).__name__}.__init__")
torchrl_logger.info(f"visible devices {os.getenv('CUDA_VISIBLE_DEVICES')}")
torchrl_logger.info(f"device count {torch.cuda.device_count()}")
super().__init__(*args, **kwargs)
self.model_update_group = None
def init_weight_update_group(
self,
master_address: str,
master_port: str,
rank_offset: int,
world_size: int,
):
"""Initialize weight update group for this worker.
Args:
master_address (str): The master address for distributed training.
master_port (str): The master port for distributed training.
rank_offset (int): Rank offset for this worker in the global weight update group.
world_size (int): Total number of processes in the weight update group.
"""
from vllm.distributed.parallel_state import get_world_group
torchrl_logger.info(f"=> in {type(self).__name__}.init_weight_update_group")
if self.model_update_group is not None:
torchrl_logger.info("Model update group already initialized")
return
# Get the local rank within the tensor parallel group
tp_group = get_world_group()
local_rank = tp_group.rank
torchrl_logger.info(f"Local rank in tensor parallel group: {local_rank}")
# Calculate the global rank for weight update group
rank = local_rank + rank_offset
torchrl_logger.info(
f"Initializing {type(self).__name__} weight update group with "
f"{master_address=}, {master_port=}, {rank=}, {world_size=}, device={self.device}"
)
# Import synchronous version for workers too
from .vllm_utils import stateless_init_process_group
self.model_update_group = stateless_init_process_group(
master_address, master_port, rank, world_size, self.device
)
torchrl_logger.info(f"{type(self).__name__}.init_weight_update_group success")
def update_weight(self, name: str, dtype_name: str, shape: tuple[int, ...]):
"""Update weight via broadcast from master (rank 0) - periodic-mono pattern.
Args:
name (str): Parameter name.
dtype_name (str): Parameter dtype name (e.g., 'bfloat16').
shape (tuple[int, ...]): Parameter shape.
"""
if self.model_update_group is None:
raise RuntimeError("Weight update group not initialized")
# Convert dtype name to dtype (like periodic-mono)
dtype = getattr(torch, dtype_name)
# Workers receive broadcast from master (rank 0)
weight = torch.empty(shape, dtype=dtype, device="cuda")
self.model_update_group.broadcast(
weight, src=0, stream=torch.cuda.current_stream()
)
self.model_runner.model.load_weights(weights=[(name, weight)])
del weight
def check_nccl_group_ready(self):
"""Check if NCCL group is ready for communication."""
ready = self.model_update_group is not None
torchrl_logger.info(f"Worker NCCL group ready: {ready}")
return ready
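# Weight-update handshake implemented by _AsyncvLLMWorker (summary):
#   1. the trainer issues a collective RPC announcing (name, dtype_name, shape);
#   2. the trainer broadcasts the tensor as rank 0 of the stateless NCCL group;
#   3. each worker receives it into an empty CUDA buffer and calls
#      model_runner.model.load_weights(weights=[(name, weight)]).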
class _AsyncLLMEngine:
"""Extended AsyncLLMEngine with TorchRL-specific features.
This class wraps vLLM's AsyncLLMEngine and adds functionality needed
for TorchRL integration, including weight updates and batch management.
This is a private class and should not be used directly. Use :class:`AsyncVLLM`, which manages the Ray remote actor replicas, instead.
Keyword Args:
engine_args (AsyncEngineArgs): Arguments for creating the AsyncLLMEngine instances.
bundle_indices (list[int], optional): Bundle indices for the engine.
enable_prefix_caching (bool, optional): Whether to enable prefix caching.
.. warning::
enable_prefix_caching is set to False by default, which is recommended if prompt log probs are needed.
Set it to True if prompt log probs are not needed.
See `this issue <https://github.com/vllm-project/vllm/issues/8268>`_ for more details.
"""
def __init__(
self,
*,
engine_args: AsyncEngineArgs,
bundle_indices: list[int] | None = None,
enable_prefix_caching: bool = False,
):
if not _has_vllm:
raise ImportError(
"vllm is not installed. Please install it with `pip install vllm`."
)
from vllm import AsyncLLMEngine
worker_cls = "torchrl.modules.llm.backends.vllm.vllm_async._AsyncvLLMWorker"
if engine_args.worker_cls != "auto":
old_worker_cls = engine_args.worker_cls
torchrl_logger.warning(
f"Overriding worker_cls from {old_worker_cls} to {worker_cls}"
)
if bundle_indices is not None:
os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices))
engine_args.worker_cls = worker_cls
engine_args.enable_prefix_caching = enable_prefix_caching
# Create the engine directly. This call is where the blocking ray.get issue
# originates, so multiple replicas are created sequentially by AsyncVLLM._launch.
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
self.bundle_indices = bundle_indices
def ready(self) -> bool:
"""Check if engine is ready for inference."""
return True
async def generate(
self,
prompts: Any = None,
sampling_params: SamplingParams | None = None,
*,
prompt_token_ids: list[int] | list[list[int]] | None = None,
use_tqdm: bool = True,
lora_request: Any = None,
prompt_adapter_request: Any = None,
guided_options_request: Any = None,
timeout_seconds: float | None = None,
) -> RequestOutput | list[RequestOutput]:
"""Generate text with the same interface as vLLM.LLM.generate.
This method mirrors the interface of vLLM.LLM.generate to provide seamless
compatibility between sync and async engines.
Args:
prompts: String, TokensPrompt, or list of these. Input prompts for generation.
sampling_params: SamplingParams object for controlling generation behavior.
prompt_token_ids: Alternative to prompts - token IDs for generation.
use_tqdm: Whether to show progress bar (not used in async engine).
lora_request: LoRA request for adapter-based generation.
prompt_adapter_request: Prompt adapter request.
guided_options_request: Guided decoding options.
timeout_seconds: Timeout for generation in seconds.
Returns:
RequestOutput or list of RequestOutput: Generated outputs from vLLM.
"""
if not _has_vllm:
raise ImportError(
"vllm is not installed. Please install it with `pip install vllm`."
)
from vllm import RequestOutput, SamplingParams, TokensPrompt
# Track whether input was originally a single prompt
single_prompt_input = False
# Handle prompt_token_ids if provided
if prompt_token_ids is not None:
if prompts is not None:
raise ValueError("Cannot specify both prompts and prompt_token_ids")
# Convert token IDs to TokensPrompt objects
if not prompt_token_ids:
raise ValueError("prompt_token_ids cannot be empty")
# Check if it's a list of lists or a single list
if prompt_token_ids and isinstance(prompt_token_ids[0], list):
# List of token ID lists
prompts = [
TokensPrompt(prompt_token_ids=tokens) for tokens in prompt_token_ids
]
else:
# Single token ID list - cast to ensure type compatibility
token_list = list(prompt_token_ids) if prompt_token_ids else []
prompts = TokensPrompt(prompt_token_ids=token_list)
single_prompt_input = True
elif prompts is None:
raise ValueError("Must specify either prompts or prompt_token_ids")
else:
# prompts was provided directly
if not isinstance(prompts, (list, tuple)):
single_prompt_input = True
# Default sampling params if not provided
if sampling_params is None:
sampling_params = SamplingParams()
async def _gen_one(prompt) -> RequestOutput:
request_id = str(uuid.uuid4())
final = None
# Build kwargs for engine.generate
gen_kwargs = {
"prompt": prompt,
"sampling_params": sampling_params,
"request_id": request_id,
}
# Add optional parameters if provided
if lora_request is not None:
gen_kwargs["lora_request"] = lora_request
if prompt_adapter_request is not None:
gen_kwargs["prompt_adapter_request"] = prompt_adapter_request
if guided_options_request is not None:
gen_kwargs["guided_options_request"] = guided_options_request
async for output in self.engine.generate(**gen_kwargs):
if output.finished:
final = output
assert final is not None
return final
async def _run_generation():
if single_prompt_input:
return await _gen_one(prompts)
# List of prompts: run concurrently
tasks = [asyncio.create_task(_gen_one(p)) for p in prompts]
results = await asyncio.gather(*tasks)
return results
try:
if timeout_seconds is not None and timeout_seconds > 0:
return await asyncio.wait_for(
_run_generation(), timeout=timeout_seconds
)
else:
return await _run_generation()
except asyncio.TimeoutError:
# Best-effort cleanup
try:
abort_fn = getattr(self.engine, "abort", None)
if callable(abort_fn):
# We can't easily track all request IDs, so this is best-effort
pass
except Exception:
pass
raise TimeoutError(
f"vLLM generation timed out after {timeout_seconds} seconds"
)
async def get_tokenizer(self):
"""Get the tokenizer from the engine."""
return await self.engine.get_tokenizer()
async def collective_rpc_v1(
self,
method: str,
timeout: float | None = None,
args: tuple = (),
kwargs: dict | None = None,
):
"""Perform a collective RPC call to the given method (vLLM V1).
Args:
method (str): Method name to call.
timeout (float | None): Timeout for the RPC call.
args (tuple): Arguments to pass to the method.
kwargs (dict | None): Keyword arguments to pass to the method.
"""
from vllm import envs
if envs and envs.VLLM_USE_V1:
return await self.engine.collective_rpc(method, timeout, args, kwargs)
else:
return self.engine.engine.collective_rpc(method, timeout, args, kwargs)
def collective_rpc_v0(
self,
method: str,
timeout: float | None = None,
args: tuple = (),
kwargs: dict | None = None,
):
"""Perform a collective RPC call to the given method (vLLM V0).
Args:
method (str): Method name to call.
timeout (float | None): Timeout for the RPC call.
args (tuple): Arguments to pass to the method.
kwargs (dict | None): Keyword arguments to pass to the method.
"""
return self.engine.engine.collective_rpc(method, timeout, args, kwargs)
def get_num_unfinished_requests(self) -> int:
"""Get the number of unfinished requests in the engine.
Returns:
int: Number of unfinished requests.
"""
try:
# Try to access the method directly if available
if hasattr(self.engine, "get_num_unfinished_requests"):
return self.engine.get_num_unfinished_requests()
# Fallback to accessing through engine.engine for v0
elif hasattr(self.engine, "engine") and hasattr(
self.engine.engine, "get_num_unfinished_requests"
):
return self.engine.engine.get_num_unfinished_requests()
else:
# If method not available, return 0 as fallback
torchrl_logger.warning(
"get_num_unfinished_requests not available, returning 0"
)
return 0
except Exception as e:
torchrl_logger.warning(f"Error getting unfinished requests count: {e}")
return 0
def get_cache_usage(self) -> float:
"""Get the KV cache usage as a fraction between 0 and 1.
Returns:
float: Cache usage fraction (0.0 = empty, 1.0 = full).
"""
try:
# Try to get cache usage from the engine
if hasattr(self.engine, "engine") and hasattr(
self.engine.engine, "cache_config"
):
# Access the LLM engine's cache information
cache_config = self.engine.engine.cache_config
if hasattr(cache_config, "cache_usage"):
return cache_config.cache_usage
elif hasattr(self.engine.engine, "scheduler"):
# Try to get usage from the scheduler
scheduler = self.engine.engine.scheduler
if hasattr(scheduler, "get_num_free_gpu_blocks") and hasattr(
scheduler, "get_num_total_gpu_blocks"
):
free_blocks = scheduler.get_num_free_gpu_blocks()
total_blocks = scheduler.get_num_total_gpu_blocks()
if total_blocks > 0:
return 1.0 - (free_blocks / total_blocks)
# Fallback: return a random value for now (this should be replaced with actual metrics)
torchrl_logger.warning(
"Cache usage metrics not available, returning random value"
)
return (
random.random() * 0.5
) # Return a value between 0 and 0.5 to simulate partial usage
except Exception as e:
torchrl_logger.warning(f"Error getting cache usage: {e}")
return 0.0
def _gpus_per_replica(engine_args: AsyncEngineArgs) -> int:
"""Get the number of GPUs per replica for the given engine args."""
return (
engine_args.tensor_parallel_size
* getattr(engine_args, "data_parallel_size", 1) # Default to 1 if not present
* getattr(
engine_args, "pipeline_parallel_size", 1
) # Default to 1 if not present
)
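# Example (hedged): with tensor_parallel_size=2, data_parallel_size=1 and
# pipeline_parallel_size=1 each replica reserves 2 * 1 * 1 = 2 GPUs, so a
# two-replica service occupies 4 GPUs and the weight-sync world size becomes
# 2 * 2 + 1 = 5 (the extra rank being the trainer/master process).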
def _get_bundle_indices(placement_group, index: int, length: int) -> list[int]:
"""Get bundle indices for a placement group.
Address https://github.com/ray-project/ray/issues/51117
This function is used to get the bundle indices of a placement group
and ensure that the bundles placed on the same node are grouped together.
Args:
placement_group: Ray placement group.
index (int): Index of the current replica.
length (int): Number of bundles per replica.
Returns:
list[int]: Bundle indices for this replica.
"""
if ray is None:
raise ImportError(
"ray is not installed. Please install it with `pip install ray`."
)
pg_infos = ray.util.placement_group_table(placement_group)
node_id_to_bundles = {}
for bundle, node_id in pg_infos["bundles_to_node_id"].items():
node_id_to_bundles.setdefault(node_id, []).append(bundle)
sorted_bundle_indices = sum(node_id_to_bundles.values(), [])
return sorted_bundle_indices[index * length : (index + 1) * length]
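# Example (hedged): if the Ray placement-group table maps bundles {0, 2} to
# node A and {1, 3} to node B, grouping by node yields the order [0, 2, 1, 3];
# replica index=1 with length=2 is then assigned bundles [1, 3], so its
# tensor-parallel workers stay co-located on a single node.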
# Create Ray remote versions
if ray is not None and _has_vllm:
_AsyncLLMEngineActor = ray.remote(num_cpus=0, num_gpus=0)(_AsyncLLMEngine)
else:
_AsyncLLMEngineActor = None
class AsyncVLLM(RLvLLMEngine):
"""A service that manages multiple async vLLM engine actors for distributed inference.
This is the main entry point for async vLLM inference in TorchRL. It manages multiple
vLLM engine replicas running as Ray actors, providing load balancing, weight updates,
and a unified interface for text generation.
The service automatically handles Ray actor lifecycle management, GPU allocation through
placement groups, and provides both synchronous and asynchronous generation interfaces
that are compatible with the standard vLLM API.
Args:
engine_args (AsyncEngineArgs): Configuration for the vLLM engines.
num_replicas (int, optional): Number of engine replicas to create. Defaults to 1.
actor_class (optional): Custom Ray actor class. Defaults to the internal actor implementation.
enable_prefix_caching (bool, optional): Whether to enable prefix caching. Defaults to False.
.. warning::
enable_prefix_caching is set to False by default, which is recommended if prompt log probs are needed.
Set it to True if prompt log probs are not needed.
See `this issue <https://github.com/vllm-project/vllm/issues/8268>`_ for more details.
Example:
>>> from torchrl.modules.llm.backends.vllm.vllm_async import AsyncVLLM
>>> from vllm import SamplingParams
>>>
>>> # Simple usage - single GPU, single replica
>>> service = AsyncVLLM.from_pretrained("Qwen/Qwen2.5-3B")
>>>
>>> # Advanced usage - multi-GPU tensor parallel with multiple replicas
>>> service = AsyncVLLM.from_pretrained(
... "Qwen/Qwen2.5-7B",
... num_devices=2, # Use 2 GPUs for tensor parallelism
... num_replicas=2, # Create 2 replicas for higher throughput
... max_model_len=4096
... )
>>>
>>> # Generate text
>>> sampling_params = SamplingParams(temperature=0.7, max_tokens=100)
>>> result = service.generate("Hello, world!", sampling_params)
>>> print(result.outputs[0].text)
>>>
>>> # Alternative: using AsyncEngineArgs directly for advanced configuration
>>> from vllm import AsyncEngineArgs
>>> engine_args = AsyncEngineArgs(
... model="Qwen/Qwen2.5-3B",
... tensor_parallel_size=2
... )
>>> service = AsyncVLLM.launch(engine_args, num_replicas=2)
.. note::
**Architecture and Design**
The AsyncVLLM service implements a distributed inference architecture with the following key components:
1. **Ray Actor Management**: Each replica runs as a separate Ray actor with dedicated GPU resources.
The service creates a placement group to ensure optimal GPU allocation and co-location of
tensor-parallel workers on the same node when possible.
2. **Load Balancing**: Generation requests are distributed across replicas using random selection
by default, or can target specific replicas using the `actor_index` parameter.
3. **Weight Synchronization**: The service supports weight updates across all replicas through
NCCL communication groups, enabling integration with distributed training workflows.
4. **Resource Management**: Automatic GPU allocation and cleanup through Ray placement groups,
with proper shutdown procedures to prevent resource leaks.
5. **API Compatibility**: Provides the same interface as vLLM's synchronous `LLM.generate()`
method, making it a drop-in replacement for async workloads.
**Ray Integration**
The service leverages Ray's actor model for distributed execution. Each replica is an independent
Ray actor that can be scheduled on different nodes. The service handles actor lifecycle,
monitors readiness, and provides centralized access to all replicas.
**Performance Considerations**
- Prefix caching is disabled by default (recommended when prompt log probs are needed); enable it for better performance with repeated prompts
- Tensor parallelism is supported for large models that don't fit on single GPUs
- Multiple replicas allow concurrent processing of different requests
- Native vLLM batching is used within each replica for optimal throughput
**Error Handling**
The service includes timeout support, graceful shutdown procedures, and best-effort
request cleanup on failures. Ray's fault tolerance mechanisms provide additional
resilience for long-running inference workloads.
"""
def __init__(
self,
engine_args: AsyncEngineArgs,
num_replicas: int = 1,
actor_class=None,
enable_prefix_caching: bool = False,
):
if not _has_vllm:
raise ImportError(
"vllm is not installed. Please install it with `pip install vllm`."
)
if ray is None:
raise ImportError(
"ray is not installed. Please install it with `pip install ray`."
)
# Prefix caching is disabled by default; pass enable_prefix_caching=True to opt in
engine_args.enable_prefix_caching = enable_prefix_caching
self.engine_args = engine_args
self.num_replicas = num_replicas
self.actor_class = actor_class or _AsyncLLMEngineActor
self.actors: list = []
self._launched = False
self._service_id = uuid.uuid4().hex[
:8
] # Unique suffix to avoid name collisions
self._placement_group = None
self._load_balancer = None
def _launch(self):
"""Launch all actor replicas."""
if self._launched:
torchrl_logger.warning("AsyncVLLMEngineService already launched")
return
# Check if CUDA is available since vLLM requires GPU
if not torch.cuda.is_available():
raise RuntimeError(
"AsyncVLLM requires CUDA but no GPU devices are available. "
"Please run on a machine with GPU support."
)
torchrl_logger.info(
f"Launching {self.num_replicas} async vLLM engine actors..."
)
# Create placement groups - one per replica to avoid conflicts
self._placement_groups = []
# Create actor replicas sequentially to avoid race conditions
for i in range(self.num_replicas):
torchrl_logger.info(
f"Creating async actor replica {i + 1}/{self.num_replicas} ..."
)
# Create individual placement group for this replica
bundles = [
{"GPU": 1.0, "CPU": 1.0}
for _ in range(self.engine_args.tensor_parallel_size)
]
torchrl_logger.info(
f"Creating placement group for replica {i + 1} with {len(bundles)} bundles"
)
placement_group_name = f"vllm-replica-{self._service_id}-{i}"
pg = placement_group(bundles, strategy="PACK", name=placement_group_name)
self._placement_groups.append(pg)
torchrl_logger.info(f"Placement group {placement_group_name} created: {pg}")
# Wait for placement group to be ready
ray.get(pg.ready(), timeout=180)
torchrl_logger.info(f"Placement group {placement_group_name} ready")
# Calculate bundle indices for tensor parallelism
bundle_indices = None
if self.engine_args.tensor_parallel_size > 1:
bundle_indices = list(range(self.engine_args.tensor_parallel_size))
bundle_index = 0 # Always use first bundle since each replica has its own placement group
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=pg,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_index,
)
actor = self.actor_class.options(
name=f"async-vllm-replica-{self._service_id}-{i}",
namespace="torchrl_vllm",
scheduling_strategy=scheduling_strategy,
num_gpus=0,
num_cpus=0,
).remote(
engine_args=self.engine_args,
bundle_indices=bundle_indices,
enable_prefix_caching=self.engine_args.enable_prefix_caching,
)
self.actors.append(actor)
torchrl_logger.info("Waiting for actors to be ready")
# Wait for this actor to be ready before creating the next one
ready_futures = [actor.ready.remote() for actor in self.actors]
try:
ray.get(
ready_futures, timeout=TIMEOUT_SECONDS
)  # engine initialization timeout (default 300 s, via TORCHRL_VLLM_TIMEOUT_SECONDS)
torchrl_logger.info("✅ Actors are ready")
except Exception as e:
torchrl_logger.error(
f"❌ Failed to initialize actors within {TIMEOUT_SECONDS} seconds: {e}. You can increase the timeout by setting the TORCHRL_VLLM_TIMEOUT_SECONDS environment variable."
)
raise
# Store the first placement group for backward compatibility
self._placement_group = (
self._placement_groups[0] if self._placement_groups else None
)
self._launched = True
torchrl_logger.info(
f"✅ Successfully launched {len(self.actors)} async vLLM engine actors"
)
@classmethod
def launch(
cls,
engine_args: AsyncEngineArgs,
num_replicas: int = 1,
) -> AsyncVLLM:
"""Launch a new AsyncVLLMEngineService.
Args:
engine_args (AsyncEngineArgs): Arguments for creating the AsyncLLMEngine instances.
num_replicas (int): Number of actor replicas to create.
Returns:
AsyncVLLM: The launched service.
"""
service = cls(engine_args, num_replicas)
service._launch()
# create a default load balancer with smart routing
service.create_load_balancer()
return service
@classmethod
def from_pretrained(
cls,
model_name: str,
num_devices: int | None = None,
num_replicas: int = 1,
verbose: bool = True,
compile: bool = True,
**kwargs,
) -> AsyncVLLM:
"""Create an AsyncVLLM instance from a pretrained model.
This is a convenience method that combines model loading and service launching
in a single call, similar to how other ML libraries work.
Args:
model_name (str): The model name to pass to vLLM.
num_devices (int, optional): Number of devices to use, per replica.
num_replicas (int): Number of engine replicas to create.
verbose (bool, optional): Whether to enable verbose logging with throughput statistics. Defaults to True.
compile (bool, optional): Whether to enable model compilation for better performance. Defaults to True.
**kwargs: Additional arguments passed to AsyncEngineArgs.
Returns:
AsyncVLLM: The launched async vLLM service.
Example:
>>> # Simple usage with defaults
>>> service = AsyncVLLM.from_pretrained("Qwen/Qwen2.5-3B")
>>>
>>> # Multi-GPU tensor parallel with multiple replicas
>>> service = AsyncVLLM.from_pretrained(
... "Qwen/Qwen2.5-7B",
... num_devices=2,
... num_replicas=2,
... max_model_len=4096
... )
>>>
>>> # Generate text
>>> from vllm import SamplingParams
>>> result = service.generate("Hello, world!", SamplingParams(max_tokens=50))
"""
return make_async_vllm_engine(
model_name=model_name,
num_devices=num_devices,
num_replicas=num_replicas,
verbose=verbose,
compile=compile,
**kwargs,
)
def _is_batch(
self, prompts: Any, prompt_token_ids: list[int] | list[list[int]] | None = None
) -> bool:
"""Check if the input represents a batch of prompts.
Args:
prompts: Input prompts that can be string, TokensPrompt, or list of these
prompt_token_ids: Alternative token IDs input
Returns:
bool: True if this represents multiple prompts, False for single prompt
"""
# If prompts is a list, we need to determine if it's a batch or a single prompt
if isinstance(prompts, list):
# Empty list is not a batch
if len(prompts) == 0:
return False
# If all elements are integers, it's a single prompt represented as token IDs
# We trust that if one is an int, then all are ints.
if any(isinstance(item, int) for item in prompts):
return False
# If it contains strings, TokensPrompt objects, or other non-integer types,
# it's a batch of prompts
return True
# If prompt_token_ids is provided and is a list of lists, it's a batch
if prompt_token_ids is not None and isinstance(prompt_token_ids, list):
if len(prompt_token_ids) > 0 and isinstance(prompt_token_ids[0], list):
return True
return False
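# Examples (hedged): _is_batch("hi") -> False; _is_batch([1, 2, 3]) -> False
# (a single tokenized prompt); _is_batch(["a", "b"]) -> True;
# _is_batch(None, prompt_token_ids=[[1, 2], [3]]) -> True.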
def _iterate(
self, prompts: Any, prompt_token_ids: list[int] | list[list[int]] | None = None
):
"""Iterate over individual prompts in a batch.
Args:
prompts: Input prompts that can be string, TokensPrompt, or list of these
prompt_token_ids: Alternative token IDs input
Yields:
tuple: (individual_prompt, individual_prompt_token_ids) for each item
"""
if isinstance(prompts, list):
# Check if this is actually a single prompt represented as token IDs
if all(isinstance(item, int) for item in prompts):
# This is a single prompt as token IDs, not a batch
yield prompts, prompt_token_ids
return
# Handle list of prompts (actual batch)
if prompt_token_ids is None:
for prompt in prompts:
yield prompt, None
elif (
isinstance(prompt_token_ids, list)
and len(prompt_token_ids) > 0
and isinstance(prompt_token_ids[0], list)
):
# Both prompts and prompt_token_ids are lists
for prompt, token_ids in zip(prompts, prompt_token_ids):
yield prompt, token_ids
else:
# prompts is list, but prompt_token_ids is single list - replicate it
for prompt in prompts:
yield prompt, prompt_token_ids
else:
# Single prompt case
if (
prompt_token_ids is not None
and isinstance(prompt_token_ids, list)
and len(prompt_token_ids) > 0
and isinstance(prompt_token_ids[0], list)
):
# Single prompt but multiple token_ids - replicate prompt
for token_ids in prompt_token_ids:
yield prompts, token_ids
else:
# Single prompt, single (or no) token_ids
yield prompts, prompt_token_ids
def _generate_impl(
self,
prompt: Any,
sampling_params: SamplingParams | None = None,
*,
prompt_token_ids: list[int] | None = None,
use_tqdm: bool = True,
lora_request: Any = None,
prompt_adapter_request: Any = None,
guided_options_request: Any = None,
timeout_seconds: float | None = None,
actor_index: int | None = None,
):
"""Generate text for a single prompt and return a Ray future.
This is the internal implementation that returns a future instead of the result.
Used for batched generation to enable parallel execution.
Args:
prompt: Single prompt (string, TokensPrompt, etc.)
sampling_params: SamplingParams object for controlling generation behavior
prompt_token_ids: Token IDs for a single prompt
use_tqdm: Whether to show progress bar (not used in async engine)
lora_request: LoRA request for adapter-based generation
prompt_adapter_request: Prompt adapter request
guided_options_request: Guided decoding options
timeout_seconds: Timeout for generation in seconds
actor_index: Specific actor to use (random if None)
Returns:
Ray ObjectRef: Future that will resolve to RequestOutput
"""
if actor_index is None:
if len(self.actors) == 1:
actor = self.actors[0]
else:
if self._load_balancer is None:
raise RuntimeError(
"LoadBalancer is not created. Create a LoadBalancer using AsyncVLLM.create_load_balancer before calling generate."
)
# Extract single prompt for prefix-aware routing
single_prompt = self._extract_single_prompt_for_routing(
prompt, prompt_token_ids
)
actor_index = self._load_balancer.select_actor(prompt=single_prompt)
actor = self.actors[actor_index]
else:
actor = self.actors[actor_index]
return actor.generate.remote(
prompt,
sampling_params,
prompt_token_ids=prompt_token_ids,
use_tqdm=use_tqdm,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
guided_options_request=guided_options_request,
timeout_seconds=timeout_seconds,
)
def generate(
self,
prompts: Any = None,
sampling_params: SamplingParams | None = None,
*,
prompt_token_ids: list[int] | list[list[int]] | None = None,
use_tqdm: bool = True,
lora_request: Any = None,
prompt_adapter_request: Any = None,
guided_options_request: Any = None,
timeout_seconds: float | None = None,
actor_index: int | None = None,
) -> RequestOutput | list[RequestOutput]:
"""Generate text using one of the actors with vLLM.LLM.generate interface.
This method provides the same interface as vLLM.LLM.generate for seamless
compatibility between sync and async engines. It can be used to generate text
within multiple threads / actors. If `actor_index` is not provided, the load balancer
will be used to select the actor.
`generate` is a blocking method, so it will wait for the generation to complete.
Args:
prompts (String, TokensPrompt, or list of these): Input prompts for generation.
sampling_params (SamplingParams): SamplingParams object for controlling generation behavior.
prompt_token_ids (list[int] | list[list[int]]): Alternative to prompts - token IDs for generation.
use_tqdm (bool): Whether to show progress bar (not used in async engine).
lora_request (Any): LoRA request for adapter-based generation.
prompt_adapter_request (Any): Prompt adapter request.
guided_options_request (Any): Guided decoding options.
timeout_seconds (float | None): Timeout for generation in seconds.
actor_index (int | None): Specific actor to use (random if None).
Returns:
RequestOutput | list[RequestOutput]: Generated outputs from vLLM.
"""
# Check if this is a batch request
if self._is_batch(prompts, prompt_token_ids):
# Handle batched input by unbinding and sending individual requests
futures = []
for prompt, prompt_token_ids_i in self._iterate(prompts, prompt_token_ids):
future = self._generate_impl(
prompt,
sampling_params,
prompt_token_ids=prompt_token_ids_i,
use_tqdm=use_tqdm,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
guided_options_request=guided_options_request,
timeout_seconds=timeout_seconds,
actor_index=actor_index,
)
futures.append(future)
# Collect all results
results = ray.get(futures)
return results
else:
# Single prompt case - call _generate_impl and get result directly
future = self._generate_impl(
prompts,
sampling_params,
prompt_token_ids=prompt_token_ids,
use_tqdm=use_tqdm,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
guided_options_request=guided_options_request,
timeout_seconds=timeout_seconds,
actor_index=actor_index,
)
result = ray.get(future)
return result
def get_random_actor_index(self) -> int:
"""Get a random actor index."""
return random.randint(0, len(self.actors) - 1)
def _init_weight_update_group_internal(self, master_address: str, master_port: str):
"""Initialize NCCL weight update group across all actors.
Args:
master_address (str): Master address for distributed training.
master_port (str): Master port for distributed training.
Returns:
list: Ray futures for initialization calls.
"""
gpus_per_replica = _gpus_per_replica(self.engine_args)
weight_sync_world_size = self.num_replicas * gpus_per_replica + 1
torchrl_logger.info(
f"AsyncVLLMEngineService requests weight update group for {self.num_replicas} actors "
f"with {gpus_per_replica} GPUs per replica and {weight_sync_world_size} world size"
)
from vllm import envs
refs = []
for i, actor in enumerate(self.actors):
rank_offset = 1 + i * gpus_per_replica
if envs and envs.VLLM_USE_V1:
actor_collective_rpc = actor.collective_rpc_v1
else:
actor_collective_rpc = actor.collective_rpc_v0
refs.append(
actor_collective_rpc.remote(
"init_weight_update_group",
args=(
master_address,
master_port,
rank_offset,
weight_sync_world_size,
),
)
)
torchrl_logger.info(
f"AsyncVLLMEngineService args: {master_address=}, {master_port=}, "
f"{rank_offset=}, {weight_sync_world_size=}"
)
torchrl_logger.info(
f"AsyncVLLMEngineService requests weight update group for actor {i} "
f"with rank_offset {rank_offset}"
)
return refs
def collective_rpc(
self,
method: str,
timeout: float | None = None,
args: tuple = (),
kwargs: dict | None = None,
) -> list[Any]:
"""Forward an RPC to all actors.
Args:
method (str): Method name to call.
timeout (float | None): Timeout for the RPC call.
args (tuple): Arguments to pass to the method.
kwargs (dict | None): Keyword arguments to pass to the method.
Returns:
list[Any]: Ray futures for all RPC calls.
"""
from vllm import envs
futures = []
for actor in self.actors:
if envs and envs.VLLM_USE_V1:
actor_collective_rpc = actor.collective_rpc_v1
else:
actor_collective_rpc = actor.collective_rpc_v0
futures.append(actor_collective_rpc.remote(method, timeout, args, kwargs))
return futures
def shutdown(self):
"""Shutdown all actors and clean up resources."""
torchrl_logger.info(
f"Shutting down {len(self.actors)} async vLLM engine actors..."
)
# Kill all actors
for i, actor in enumerate(self.actors):
try:
ray.kill(actor)
torchrl_logger.info(f"Shutdown async actor {i + 1}/{len(self.actors)}")
except Exception as e:
torchrl_logger.warning(f"Error shutting down async actor {i + 1}: {e}")
# Clear the actors list
self.actors.clear()
# Remove placement groups if any
if hasattr(self, "_placement_groups") and self._placement_groups:
for i, pg in enumerate(self._placement_groups):
try:
remove_placement_group(pg)
torchrl_logger.info(
f"Removed placement group {i + 1}/{len(self._placement_groups)}"
)
except Exception as e:
torchrl_logger.warning(
f"Error removing placement group {i + 1}: {e}"
)
self._placement_groups = []
# Remove legacy single placement group if any
if self._placement_group is not None:
remove_placement_group(self._placement_group)
self._placement_group = None
self._launched = False
torchrl_logger.info("AsyncVLLMEngineService shutdown complete")
# RLvLLMEngine interface implementation
def get_tp_size(self) -> int:
"""Get the tensor parallel size."""
return self.engine_args.tensor_parallel_size
def get_model_metadata(self) -> dict[str, tuple[torch.dtype, torch.Size]]:
"""Get model parameter metadata.
Note: This requires the model to be loaded. For now, we return an empty dict
and expect the metadata to be provided externally during weight updates.
"""
# TODO: Implement metadata extraction from loaded model
# This would require accessing the model from one of the actors
torchrl_logger.warning(
"AsyncVLLM.get_model_metadata() not yet implemented - returning empty dict"
)
return {}
def get_master_address(self) -> str:
"""Get the master address for weight synchronization."""
return "localhost" # Default for now
def get_master_port(self) -> int:
"""Get the master port for weight synchronization."""
# Cache the port like V1 does to ensure consistency
if not hasattr(self, "_cached_master_port"):
if _has_vllm:
try:
from vllm.utils import get_open_port
self._cached_master_port = get_open_port()
except ImportError:
self._cached_master_port = 29500 # Default port if import fails
else:
self._cached_master_port = 29500 # Default port
return self._cached_master_port
def init_weight_update_group(self) -> None:
"""Initialize the weight update communication group (RLvLLMEngine interface)."""
if not self._launched:
raise RuntimeError(
"AsyncVLLM service must be launched before initializing weight update group"
)
master_address = self.get_master_address()
master_port = self.get_master_port()
# Call the internal method with the auto-detected parameters (like V1)
refs = self._init_weight_update_group_internal(master_address, master_port)
# CRITICAL: Initialize master NCCL group immediately (like V1) - don't wait for workers
torchrl_logger.info("Setting up master NCCL group (rank 0)...")
self._setup_nccl_master_group()
# Now wait for workers to complete (like V1 does)
if ray is not None:
ray.get(refs)
torchrl_logger.info("AsyncVLLM weight update group initialized")
def update_weights(self, weights: Iterator[tuple[str, torch.Tensor]]) -> None:
"""Update model weights across all replicas using NCCL broadcast.
Args:
weights: Iterator yielding (parameter_name, tensor) tuples
"""
if not self._launched:
raise RuntimeError(
"AsyncVLLM service must be launched before updating weights"
)
# Convert iterator to dict for easier handling
weights_dict = dict(weights)
if not weights_dict:
torchrl_logger.warning("No weights provided for update")
return
torchrl_logger.info(
f"Updating {len(weights_dict)} parameters across {len(self.actors)} replicas using NCCL broadcast"
)
self._update_weights_with_nccl_broadcast_simple(weights_dict)
torchrl_logger.info("AsyncVLLM NCCL weight update completed")
def _update_weights_with_nccl_broadcast_simple(
self, weights_dict: dict[str, torch.Tensor]
) -> None:
"""Update weights using simple NCCL broadcast like V1.
This approach follows the V1 pattern:
1. Training process (master) broadcasts as rank 0
2. All vLLM workers receive as ranks 1, 2, 3...
3. Simple and reliable like the working V1 implementation
Args:
weights_dict: Dictionary of parameter names to weight tensors
"""
import time
if not hasattr(self, "_nccl_master_group") or self._nccl_master_group is None:
raise RuntimeError(
"NCCL master group not initialized. This is a bug in the setup process."
)
t0 = time.time()
# Move all weights to cuda:0 (matching NCCL communicator device)
gpu_weights = {}
for name, weight in weights_dict.items():
# Ensure weight is on cuda:0 (matching NCCL communicator)
if weight.device != torch.device("cuda:0"):
gpu_weights[name] = weight.to("cuda:0", non_blocking=True)
else:
gpu_weights[name] = weight
# Use periodic-mono pattern: individual weight updates with immediate RPC->NCCL
torchrl_logger.info(
f"Updating {len(gpu_weights)} weights using periodic-mono pattern..."
)
updated_weights = 0
with torch.cuda.device(0): # Ensure we're on the correct CUDA device
for name, weight in gpu_weights.items():
# Convert dtype to string name (like periodic-mono)
dtype_name = str(weight.dtype).split(".")[
-1
] # "torch.bfloat16" -> "bfloat16"
# Step 1: Send RPC to workers for this weight
futures = self.collective_rpc(
"update_weight", args=(name, dtype_name, tuple(weight.shape))
)
# Step 2: Immediately broadcast this weight (like periodic-mono)
self._nccl_master_group.broadcast(
weight, src=0, stream=torch.cuda.current_stream()
)
# Step 3: Wait for workers to complete this weight
ray.get(futures)
updated_weights += 1
torch.cuda.synchronize()
t2 = time.time()
torchrl_logger.info(
f"Successfully updated {updated_weights}/{len(gpu_weights)} weights in {t2 - t0:.3f}s"
)
def _setup_nccl_master_group(self) -> None:
"""Set up NCCL communication group for the master node (rank 0)."""
# Calculate world size (should match what workers use)
gpus_per_replica = _gpus_per_replica(self.engine_args)
weight_sync_world_size = self.num_replicas * gpus_per_replica + 1
master_address = self.get_master_address()
master_port = self.get_master_port()
torchrl_logger.info(
f"Setting up NCCL master group: rank=0, world_size={weight_sync_world_size}, "
f"address={master_address}:{master_port}"
)
# Ensure CUDA is available and initialized
if not torch.cuda.is_available():
raise RuntimeError("CUDA not available for NCCL communication")
# Set CUDA device before initializing NCCL
torch.cuda.set_device(0)
# Initialize master as rank 0 in the NCCL group (use synchronous version)
self._nccl_master_group = stateless_init_process_group(
master_address=master_address,
master_port=str(master_port),
rank=0, # Master is always rank 0
world_size=weight_sync_world_size,
device=torch.device("cuda:0"),
)
torchrl_logger.info("NCCL master group initialized successfully")
def get_num_unfinished_requests(
self, actor_index: int | None = None
) -> int | list[int]:
"""Get the number of unfinished requests for one or all actors.
Args:
actor_index (int | None): Index of specific actor, or None for all actors.
Returns:
int | list[int]: Number of unfinished requests for the specified actor,
or list of counts for all actors if actor_index is None.
"""
if not self._launched:
raise RuntimeError(
"AsyncVLLM service must be launched before getting request counts"
)
if actor_index is not None:
if not (0 <= actor_index < len(self.actors)):
raise IndexError(
f"Actor index {actor_index} out of range [0, {len(self.actors)})"
)
actor = self.actors[actor_index]
return ray.get(actor.get_num_unfinished_requests.remote())
else:
# Get counts from all actors
futures = [
actor.get_num_unfinished_requests.remote() for actor in self.actors
]
return ray.get(futures)
def get_cache_usage(self, actor_index: int | None = None) -> float | list[float]:
"""Get the KV cache usage for one or all actors.
Args:
actor_index (int | None): Index of specific actor, or None for all actors.
Returns:
float | list[float]: Cache usage fraction for the specified actor,
or list of usage fractions for all actors if actor_index is None.
"""
if not self._launched:
raise RuntimeError(
"AsyncVLLM service must be launched before getting cache usage"
)
if actor_index is not None:
if not (0 <= actor_index < len(self.actors)):
raise IndexError(
f"Actor index {actor_index} out of range [0, {len(self.actors)})"
)
actor = self.actors[actor_index]
return ray.get(actor.get_cache_usage.remote())
else:
# Get usage from all actors
futures = [actor.get_cache_usage.remote() for actor in self.actors]
return ray.get(futures)
def create_load_balancer(
self,
strategy: Literal["requests", "kv-cache"]
| Sequence[Literal["prefix-aware", "requests", "kv-cache", "round-robin"]]
| None = None,
**kwargs,
) -> LoadBalancer:
"""Create a load balancer for this AsyncVLLM service.
Args:
strategy: Load balancing strategy or sequence of strategies in fallback order.
Default: ["prefix-aware", "requests"] - tries cache-aware routing first,
then load balancing. Single strategies: "requests", "kv-cache"
Strategy sequences: ["prefix-aware", "requests", "round-robin"]
**kwargs: Additional arguments passed to LoadBalancer constructor.
Returns:
LoadBalancer: Configured load balancer instance. This is stored in the AsyncVLLM instance.
Examples:
>>> service = AsyncVLLM.from_pretrained("Qwen/Qwen2.5-3B", num_replicas=3)
>>> # Use smart defaults (prefix-aware -> requests)
>>> lb = service.create_load_balancer()
>>> selected_actor_index = lb.select_actor(prompt="Hello world")
>>> # Simple single strategy
>>> lb = service.create_load_balancer("requests")
>>> selected_actor_index = lb.select_actor()
>>> # Custom strategy hierarchy
>>> lb = service.create_load_balancer(
... ["prefix-aware", "kv-cache", "round-robin"],
... prefix_length=16,
... overload_threshold=2.0
... )
>>> selected_actor_index = lb.select_actor(prompt="Hello world")
"""
if not self._launched:
raise RuntimeError(
"AsyncVLLM service must be launched before creating load balancer"
)
load_balancer = LoadBalancer(self, strategy, **kwargs)
self._load_balancer = load_balancer
return load_balancer
def _extract_single_prompt_for_routing(
self,
prompts: Any = None,
prompt_token_ids: list[int] | list[list[int]] | None = None,
) -> str | list[int] | None:
"""Extract a single prompt for load balancer routing, if possible.
Args:
prompts: The prompts argument passed to generate().
prompt_token_ids: The prompt_token_ids argument passed to generate().
Returns:
str | list[int] | None: Single prompt for routing, or None if multiple prompts.
"""
try:
# Handle prompt_token_ids first (takes precedence over prompts)
if prompt_token_ids is not None:
if isinstance(prompt_token_ids, list):
if len(prompt_token_ids) == 0:
return None # Empty list
elif len(prompt_token_ids) == 1:
# Single prompt case - could be tokens directly or nested list
if isinstance(prompt_token_ids[0], int):
# Single token sequence: [token1, token2, ...]
return prompt_token_ids
elif isinstance(prompt_token_ids[0], list):
# Nested list with single prompt: [[token1, token2, ...]]
return prompt_token_ids[0]
else:
return None
else:
# Multiple prompts: [[tokens1...], [tokens2...], ...]
return None
else:
# Not a list, invalid format
return None
# Handle prompts argument
if prompts is None:
return None
# Single string prompt
if isinstance(prompts, str):
return prompts
# TokensPrompt object
elif hasattr(prompts, "prompt_token_ids"): # TokensPrompt-like object
return prompts.prompt_token_ids
# TextPrompt object
elif hasattr(prompts, "prompt"): # TextPrompt-like object
return prompts.prompt
# List of prompts
elif isinstance(prompts, (list, tuple)):
if len(prompts) == 0:
return None # Empty list
elif len(prompts) == 1:
# Single prompt in list - recursively extract
return self._extract_single_prompt_for_routing(prompts[0], None)
else:
# Multiple prompts - cannot do prefix routing
return None
# Other types (shouldn't happen in normal usage)
else:
torchrl_logger.debug(
f"Unknown prompt type for routing: {type(prompts)}"
)
return None
except Exception as e:
torchrl_logger.debug(f"Error extracting single prompt for routing: {e}")
return None
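# ---------------------------------------------------------------------------
# Illustrative sketch (hedged, not part of the public API): how a training
# process typically drives the NCCL weight-update path described in
# AsyncVLLM.init_weight_update_group / update_weights. ``policy`` is a
# hypothetical ``torch.nn.Module`` standing in for the trainer's model.
def _example_weight_sync_sketch(service: AsyncVLLM, policy: torch.nn.Module) -> None:
    """Minimal example of pushing trainer weights to every vLLM replica."""
    # The trainer joins the stateless NCCL group as rank 0; each vLLM worker
    # joins with rank offset 1 + replica_index * gpus_per_replica.
    service.init_weight_update_group()
    # Each (name, tensor) pair is announced via collective RPC and then
    # broadcast from rank 0; workers load it through Worker.update_weight.
    service.update_weights(policy.named_parameters())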
class LoadBalancer:
"""Load balancer for distributing requests across AsyncVLLM actors with strategy hierarchy.
This class implements sophisticated load balancing with multiple strategies and intelligent
fallback mechanisms. Strategies are tried in order until one succeeds, providing robust
request routing even when some strategies fail.
Args:
actors: Either a single AsyncVLLM instance or a list of Ray actors.
strategy: Single strategy or sequence of strategies in fallback order.
Available strategies:
- "prefix-aware": Route based on prompt prefix for cache locality
- "requests": Select actor with fewest pending requests
- "kv-cache": Select actor with lowest KV cache utilization
- "round-robin": Simple round-robin distribution
Default: ["prefix-aware", "requests"]
prefix_length: Number of tokens/words to use for prefix routing (default: 8).
overload_threshold: Multiplier for average load to consider actor overloaded (default: 1.5).
Examples:
>>> service = AsyncVLLM.from_pretrained("Qwen/Qwen2.5-3B", num_replicas=3)
>>> # Simple strategy
>>> lb = LoadBalancer(service, "requests")
>>> actor_idx = lb.select_actor()
>>> # Strategy hierarchy: try prefix-aware first, fall back to requests, then round-robin
>>> lb = LoadBalancer(service, ["prefix-aware", "requests", "round-robin"])
>>> actor_idx = lb.select_actor(prompt="Hello world") # Uses prefix routing
>>> actor_idx = lb.select_actor() # Falls back to requests (no prompt)
>>> # Custom configuration
>>> lb = LoadBalancer(
... service,
... ["prefix-aware", "kv-cache"],
... prefix_length=16,
... overload_threshold=2.0
... )
"""
def __init__(
self,
actors: list[Any] | AsyncVLLM,
strategy: Literal["requests", "kv-cache"]
| Sequence[Literal["prefix-aware", "requests", "kv-cache", "round-robin"]]
| None = None,
prefix_length: int = 8,
overload_threshold: float = 1.5,
):
if strategy is None:
strategy = ["prefix-aware", "requests"]
# Handle both AsyncVLLM instances and direct actor lists
if hasattr(actors, "actors"): # AsyncVLLM instance
self.actors = actors.actors
self.async_vllm = actors
elif isinstance(actors, list): # Direct list of actors
self.actors = actors
self.async_vllm = None
else:
raise ValueError(
"actors must be either an AsyncVLLM instance or a list of actors"
)
if not self.actors:
raise ValueError("No actors provided")
# Handle both single strategy and strategy hierarchy
if isinstance(strategy, str):
self.strategies = [strategy]
else:
self.strategies = list(strategy)
# Validate strategies
valid_strategies = {"prefix-aware", "requests", "kv-cache", "round-robin"}
for s in self.strategies:
if s not in valid_strategies:
raise ValueError(
f"Invalid strategy '{s}'. Must be one of {valid_strategies}"
)
if not self.strategies:
raise ValueError("At least one strategy must be provided")
self.strategy = self.strategies[
0
] # Primary strategy for backward compatibility
self.prefix_length = prefix_length
self.overload_threshold = overload_threshold
self._round_robin_index = 0 # For round-robin fallback
def select_actor(
self,
prompt: str | list[int] | None = None,
request_context: dict[str, Any] | None = None,
) -> int:
"""Select the optimal actor index based on the configured strategy hierarchy.
Args:
prompt: The input prompt (string or token list) for prefix-aware routing.
request_context: Additional context for routing decisions.
Returns:
int: Index of the selected actor in the actors list.
Raises:
RuntimeError: If unable to gather metrics from actors.
ValueError: If no actors are available.
"""
if not self.actors:
raise ValueError("No actors available for selection")
# Try each strategy in order until one succeeds
for i, strategy in enumerate(self.strategies):
try:
torchrl_logger.debug(
f"Trying strategy {i+1}/{len(self.strategies)}: {strategy}"
)
if strategy == "prefix-aware":
if prompt is not None:
return self._select_by_prefix_aware(prompt)
else:
torchrl_logger.debug(
"No prompt provided for prefix-aware routing, trying next strategy"
)
continue
elif strategy == "requests":
return self._select_by_requests()
elif strategy == "kv-cache":
return self._select_by_cache_usage()
elif strategy == "round-robin":
return self._select_round_robin()
else:
torchrl_logger.warning(
f"Unknown strategy: {strategy}, trying next strategy"
)
continue
except Exception as e:
torchrl_logger.warning(
f"Strategy '{strategy}' failed with error: {e}. "
f"Trying next strategy..."
)
continue
# All strategies failed, final fallback to random
torchrl_logger.warning(
f"All strategies {self.strategies} failed. Falling back to random selection."
)
return random.randint(0, len(self.actors) - 1)
def _select_by_requests(self) -> int:
"""Select actor with fewest pending requests."""
if self.async_vllm is not None:
# Use AsyncVLLM's built-in method to get request counts
request_counts = self.async_vllm.get_num_unfinished_requests()
else:
# Query actors directly
futures = [
actor.get_num_unfinished_requests.remote() for actor in self.actors
]
request_counts = ray.get(futures)
# Find the actor with minimum pending requests
min_requests = min(request_counts)
min_indices = [
i for i, count in enumerate(request_counts) if count == min_requests
]
# If multiple actors have the same minimum count, choose randomly among them
selected_index = random.choice(min_indices)
torchrl_logger.debug(
f"LoadBalancer (requests): Selected actor {selected_index} "
f"with {min_requests} pending requests. "
f"Request counts: {request_counts}"
)
return selected_index
def _select_by_cache_usage(self) -> int:
"""Select actor with lowest KV cache utilization."""
if self.async_vllm is not None:
# Use AsyncVLLM's built-in method to get cache usage
cache_usages = self.async_vllm.get_cache_usage()
else:
# Query actors directly
futures = [actor.get_cache_usage.remote() for actor in self.actors]
cache_usages = ray.get(futures)
# Find the actor with minimum cache usage
min_usage = min(cache_usages)
min_indices = [
i for i, usage in enumerate(cache_usages) if abs(usage - min_usage) < 1e-6
]
# If multiple actors have similar cache usage, choose randomly among them
selected_index = random.choice(min_indices)
torchrl_logger.debug(
f"LoadBalancer (kv-cache): Selected actor {selected_index} "
f"with {min_usage:.3f} cache usage. "
f"Cache usages: {[f'{u:.3f}' for u in cache_usages]}"
)
return selected_index
def _select_by_prefix_aware(self, prompt: str | list[int]) -> int:
"""Select actor based on prompt prefix for cache locality.
Args:
prompt: Input prompt as string or token list.
Returns:
int: Selected actor index.
Raises:
ValueError: If prefix cannot be extracted.
"""
try:
# Extract prefix tokens
prefix_tokens = self._extract_prefix_tokens(prompt)
if not prefix_tokens:
raise ValueError("Could not extract meaningful prefix tokens")
# Create consistent hash from prefix
prefix_hash = hash(tuple(prefix_tokens))
preferred_actor = prefix_hash % len(self.actors)
# Check if preferred actor is overloaded
if self._is_actor_overloaded(preferred_actor):
torchrl_logger.debug(
f"Preferred actor {preferred_actor} is overloaded "
f"(threshold: {self.overload_threshold}), falling back to load-based selection"
)
# Fall back to requests-based selection
return self._select_by_requests()
torchrl_logger.debug(
f"LoadBalancer (prefix-aware): Selected actor {preferred_actor} "
f"for prefix hash {prefix_hash} (tokens: {prefix_tokens[:4]}...)"
)
return preferred_actor
except Exception as e:
torchrl_logger.warning(f"Prefix-aware routing failed: {e}")
raise
def _select_round_robin(self) -> int:
"""Select actor using round-robin strategy."""
selected = self._round_robin_index % len(self.actors)
self._round_robin_index = (self._round_robin_index + 1) % len(self.actors)
torchrl_logger.debug(f"LoadBalancer (round-robin): Selected actor {selected}")
return selected
def _extract_prefix_tokens(self, prompt: str | list[int]) -> list[int]:
"""Extract prefix tokens from prompt (string or token list).
Args:
prompt: Input prompt.
Returns:
list[int]: Prefix tokens (up to self.prefix_length).
Raises:
ValueError: If tokenization fails or prompt is invalid.
"""
if isinstance(prompt, list):
# Already tokenized
if not prompt:
raise ValueError("Empty token list provided")
return prompt[: self.prefix_length]
elif isinstance(prompt, str):
# Need to tokenize - this requires access to tokenizer
if not prompt.strip():
raise ValueError("Empty or whitespace-only string provided")
# Try to get tokenizer from AsyncVLLM instance
if self.async_vllm is not None:
try:
# This is a simplistic approach - in practice you'd want to cache the tokenizer
# For now, use a simple heuristic based on string content
return self._simple_string_hash(prompt)
except Exception as e:
torchrl_logger.warning(f"Could not tokenize string: {e}")
return self._simple_string_hash(prompt)
else:
# Fall back to simple string hashing
return self._simple_string_hash(prompt)
else:
raise ValueError(f"Unsupported prompt type: {type(prompt)}")
def _simple_string_hash(self, text: str) -> list[int]:
"""Create pseudo-tokens from string for prefix routing.
This is a fallback when proper tokenization isn't available.
"""
# Use words as pseudo-tokens, limited to prefix_length
words = text.strip().split()[: self.prefix_length]
if not words:
raise ValueError("No words found in text")
# Convert words to integers using hash
pseudo_tokens = [
abs(hash(word)) % 50000 for word in words
] # Simulate vocab size
return pseudo_tokens
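# Example (hedged): "hello world again" with prefix_length=2 becomes the two
# pseudo-tokens [abs(hash("hello")) % 50000, abs(hash("world")) % 50000]; note
# that string hashes are process-specific unless PYTHONHASHSEED is fixed.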
def _is_actor_overloaded(self, actor_index: int) -> bool:
"""Check if an actor is overloaded compared to average load.
Args:
actor_index: Index of actor to check.
Returns:
bool: True if actor is overloaded.
"""
try:
if self.async_vllm is not None:
request_counts = self.async_vllm.get_num_unfinished_requests()
else:
futures = [
actor.get_num_unfinished_requests.remote() for actor in self.actors
]
request_counts = ray.get(futures)
if not request_counts:
return False
avg_requests = sum(request_counts) / len(request_counts)
actor_requests = request_counts[actor_index]
is_overloaded = actor_requests > avg_requests * self.overload_threshold
torchrl_logger.debug(
f"Actor {actor_index}: {actor_requests} requests, "
f"avg: {avg_requests:.1f}, threshold: {avg_requests * self.overload_threshold:.1f}, "
f"overloaded: {is_overloaded}"
)
return is_overloaded
except Exception as e:
torchrl_logger.warning(f"Could not check actor load: {e}")
return False # Assume not overloaded if we can't check
def get_stats(self) -> dict[str, Any]:
"""Get current load balancing statistics for all actors.
Returns:
dict: Statistics including request counts and cache usage for all actors.
"""
stats = {
"strategies": self.strategies,
"primary_strategy": self.strategy, # For backward compatibility
"num_actors": len(self.actors),
"prefix_length": self.prefix_length,
"overload_threshold": self.overload_threshold,
"round_robin_index": self._round_robin_index,
"actor_stats": [],
}
try:
if self.async_vllm is not None:
request_counts = self.async_vllm.get_num_unfinished_requests()
cache_usages = self.async_vllm.get_cache_usage()
else:
request_futures = [
actor.get_num_unfinished_requests.remote() for actor in self.actors
]
cache_futures = [
actor.get_cache_usage.remote() for actor in self.actors
]
request_counts = ray.get(request_futures)
cache_usages = ray.get(cache_futures)
for i, (requests, cache_usage) in enumerate(
zip(request_counts, cache_usages)
):
stats["actor_stats"].append(
{
"actor_index": i,
"pending_requests": requests,
"cache_usage": cache_usage,
}
)
except Exception as e:
torchrl_logger.warning(f"Error gathering load balancer stats: {e}")
stats["error"] = str(e)
return stats
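# ---------------------------------------------------------------------------
# Illustrative sketch (hedged, not part of the public API): the core of the
# prefix-aware routing used by LoadBalancer._select_by_prefix_aware. A stable
# hash of the first few tokens keeps identical prefixes on the same replica so
# that repeated prompts can reuse its KV cache; the overload fallback is
# omitted here for brevity.
def _example_prefix_routing_sketch(prefix_tokens: list[int], num_actors: int) -> int:
    """Return the replica index a token prefix would be routed to."""
    return hash(tuple(prefix_tokens)) % num_actors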
def make_async_vllm_engine(
model_name: str,
num_devices: int | None = None,
num_replicas: int = 1,
verbose: bool = True,
compile: bool = True,
**kwargs,
) -> AsyncVLLM:
"""Create an async vLLM engine service.
Args:
model_name (str): The model name to pass to vLLM.
num_devices (int, optional): Number of devices to use, per replica.
num_replicas (int): Number of engine replicas to create.
verbose (bool, optional): Whether to enable verbose logging with throughput statistics. Defaults to True.
compile (bool, optional): Whether to enable model compilation for better performance. Defaults to True.
**kwargs: Additional arguments passed to AsyncEngineArgs.
Returns:
AsyncVLLM: The launched engine service.
Raises:
RuntimeError: If no CUDA devices are available.
ValueError: If invalid device configuration is provided.
Example:
>>> # Create a single-GPU async engine
>>> service = make_async_vllm_engine("Qwen/Qwen2.5-3B")
>>>
>>> # Create a 2-GPU tensor parallel async engine with 2 replicas
>>> service = make_async_vllm_engine("Qwen/Qwen2.5-3B", num_devices=2, num_replicas=2)
>>> # Generate text
>>> result = service.generate("Hello, world!", sampling_params)
"""
if not _has_vllm:
raise ImportError(
"vllm is not installed. Please install it with `pip install vllm`."
)
from vllm import AsyncEngineArgs
# Check if CUDA is available since vLLM requires GPU
if not torch.cuda.is_available():
raise RuntimeError(
"AsyncVLLM requires CUDA but no GPU devices are available. "
"Please run on a machine with GPU support."
)
# Handle device specification
if num_devices is None:
num_devices = 1
# Configure verbose logging if requested
if verbose:
import logging
# Enable vLLM's throughput logging by setting the appropriate log level
logging.getLogger("vllm.engine.metrics").setLevel(logging.INFO)
logging.getLogger("vllm").setLevel(logging.INFO)
# vLLM logs throughput stats at INFO level every few seconds
# The stats include: prompt throughput, generation throughput, running/pending requests, GPU KV cache usage
torchrl_logger.info(
"Enabled verbose vLLM logging - throughput statistics will be displayed"
)
# Create engine args
kwargs.setdefault("distributed_executor_backend", "ray")
# Default enable_prefix_caching here unless the caller overrides it via kwargs
# (note that AsyncVLLM.__init__ re-applies its own enable_prefix_caching setting)
kwargs.setdefault("enable_prefix_caching", True)
# Set the compilation flag - this controls whether vLLM compiles the model for
# better performance; pass compile=False if compilation causes issues during training (e.g. in GRPO)
if "compilation_config" not in kwargs:
if compile:
kwargs["compilation_config"] = {"enabled": True}
else:
kwargs["compilation_config"] = {"enabled": False}
engine_args = AsyncEngineArgs(
model=model_name,
tensor_parallel_size=num_devices,
worker_cls="torchrl.modules.llm.backends.vllm.vllm_async._AsyncvLLMWorker",
**kwargs,
)
return AsyncVLLM.launch(engine_args, num_replicas)
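# ---------------------------------------------------------------------------
# Minimal end-to-end sketch (hedged): mirrors the documented usage above and
# only runs when this module is executed directly. It assumes a CUDA machine
# with ray and vllm installed; the model name is just an example.
if __name__ == "__main__":
    from vllm import SamplingParams

    _service = make_async_vllm_engine("Qwen/Qwen2.5-3B", num_devices=1, num_replicas=1)
    try:
        _result = _service.generate("Hello, world!", SamplingParams(max_tokens=32))
        torchrl_logger.info(_result.outputs[0].text)
    finally:
        _service.shutdown()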