Source code for torchrl.collectors.llm.weight_update.vllm_v2
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import time
from collections.abc import Iterator
import torch
from tensordict import TensorDictBase
from torchrl._utils import logger as torchrl_logger
from torchrl.collectors import WeightUpdaterBase
from torchrl.modules.llm.backends.vllm import RLvLLMEngine
try:
    import transformers  # noqa: F401  # availability check only

    _has_transformers = True
except ImportError:
    _has_transformers = False
class vLLMUpdaterV2(WeightUpdaterBase):
    """Simplified vLLM weight updater using the RLvLLMEngine interface.

    This updater works with any vLLM engine that implements the RLvLLMEngine
    interface, automatically extracting configuration and handling weight updates
    through the engine's own methods.

    Args:
        vllm_engine: A vLLM engine implementing the RLvLLMEngine interface.

    .. note:: This class can be created through :class:`torchrl.collectors.llm.vLLMUpdater` with `v2=True`.
    """
    def __init__(self, vllm_engine: RLvLLMEngine):
        # Check that vllm_engine implements the RLvLLMEngine interface
        if not isinstance(vllm_engine, RLvLLMEngine):
            raise TypeError(
                f"vllm_engine must implement RLvLLMEngine interface, got {type(vllm_engine)}"
            )
        torchrl_logger.info(f"=> in {type(self).__name__}.__init__")
        self.vllm_engine = vllm_engine
        self.initialized_group = None
        # Extract configuration from engine
        self.vllm_tp_size = vllm_engine.get_tp_size()
        self.master_address = vllm_engine.get_master_address()
        self.master_port = vllm_engine.get_master_port()
        self.model_metadata = vllm_engine.get_model_metadata()
        torchrl_logger.info(
            f"Initialized vLLMUpdaterV2 with tp_size={self.vllm_tp_size}"
        )
    def init(
        self, model_metadata: dict[str, tuple[torch.dtype, torch.Size]] | None = None
    ) -> None:
        """Initialize the weight updater.

        Args:
            model_metadata: Optional model metadata. If not provided, the engine's metadata is used.
        """
        if model_metadata is not None:
            self.model_metadata = model_metadata
        # Initialize the engine's weight update group
        self.vllm_engine.init_weight_update_group()
        self.initialized_group = True
        torchrl_logger.info("Weight update group initialized")
    def push_weights(
        self, weights: Iterator[tuple[str, torch.Tensor]] | TensorDictBase
    ):
        """Push weights to the vLLM engine.

        Args:
            weights: Either an iterator of (name, tensor) pairs or a TensorDictBase.
        """
        if isinstance(weights, TensorDictBase):
            weights = iter(weights.flatten_keys(".").items())
        if self.initialized_group is None:
            raise RuntimeError("Weight updater not initialized. Call init() first.")
        # Delegate to the engine's update_weights method
        self.vllm_engine.update_weights(weights)
        torchrl_logger.info("Weight update completed")
        # Call post-hooks to increment policy version
        torchrl_logger.info("Calling post-hooks...")
        self._call_post_hooks()
        torchrl_logger.info("Post-hooks completed")
    def push_weights_from_transformers(self, transformers_model):
        """Push weights from a transformers model.

        Args:
            transformers_model: A transformers PreTrainedModel or TorchRL wrapper.
        """
        if not _has_transformers:
            raise ImportError("transformers not available")
        t0 = time.time()
        # Extract state dict from model, handling LoRA models properly
        if hasattr(transformers_model, "model") and hasattr(
            transformers_model.model, "state_dict"
        ):
            # TorchRL wrapper (e.g., TransformersWrapper)
            model = transformers_model.model
            # Check if it's a LoRA model
            if hasattr(model, "merge_and_unload"):
                state_dict = model.merge_and_unload().state_dict()
            else:
                state_dict = model.state_dict()
        elif hasattr(transformers_model, "state_dict"):
            # Direct transformers model
            # Check if it's a LoRA model
            if hasattr(transformers_model, "merge_and_unload"):
                state_dict = transformers_model.merge_and_unload().state_dict()
            else:
                state_dict = transformers_model.state_dict()
        else:
            raise TypeError(
                f"Cannot extract state_dict from {type(transformers_model)}"
            )
        t1 = time.time()
        torchrl_logger.info(f"Time to extract state_dict: {t1 - t0}")
        # Convert to iterator for memory efficiency
        weights_iter = iter(state_dict.items())
        self.push_weights(weights_iter)
        torchrl_logger.info(f"Time to push weights: {time.time() - t1}")
    def push_weights_from_transformers_optimized(
        self, transformers_model, batch_size=50
    ):
        """Optimized version of push_weights_from_transformers with GPU pre-loading.

        This method provides several optimizations:

        1. Pre-loads all weights to GPU before transfer
        2. Optionally batches weights for better memory management
        3. Uses non-blocking transfers when possible

        Args:
            transformers_model: A transformers PreTrainedModel or TorchRL wrapper.
            batch_size: Number of weights to transfer in each batch (0 = no batching).
        """
        if not _has_transformers:
            raise ImportError("transformers not available")
        t0 = time.time()
        # Extract state dict from model, handling LoRA models properly
        if hasattr(transformers_model, "model") and hasattr(
            transformers_model.model, "state_dict"
        ):
            # TorchRL wrapper (e.g., TransformersWrapper)
            model = transformers_model.model
            if hasattr(model, "merge_and_unload"):
                state_dict = model.merge_and_unload().state_dict()
            else:
                state_dict = model.state_dict()
        elif hasattr(transformers_model, "state_dict"):
            # Direct transformers model
            if hasattr(transformers_model, "merge_and_unload"):
                state_dict = transformers_model.merge_and_unload().state_dict()
            else:
                state_dict = transformers_model.state_dict()
        else:
            raise TypeError(
                f"Cannot extract state_dict from {type(transformers_model)}"
            )
        t1 = time.time()
        torchrl_logger.info(f"Time to extract state_dict: {t1 - t0:.3f}s")
        # Pre-load all weights to GPU for faster transfer
        gpu_weights = {}
        with torch.device("cuda:0"):  # Ensure we're using the right GPU
            for name, weight in state_dict.items():
                if not weight.is_cuda:
                    gpu_weights[name] = weight.cuda(non_blocking=True)
                else:
                    gpu_weights[name] = weight
        # Synchronize to ensure all transfers are complete
        torch.cuda.synchronize()
        t2 = time.time()
        torchrl_logger.info(f"Time to move weights to GPU: {t2 - t1:.3f}s")
        # Transfer weights (optionally in batches)
        if batch_size > 0:
            weight_items = list(gpu_weights.items())
            for i in range(0, len(weight_items), batch_size):
                batch = weight_items[i : i + batch_size]
                self.push_weights(iter(batch))
                torchrl_logger.info(
                    f"Transferred batch {i//batch_size + 1}/{(len(weight_items) + batch_size - 1)//batch_size}"
                )
        else:
            # Transfer all at once
            self.push_weights(iter(gpu_weights.items()))
        t3 = time.time()
        torchrl_logger.info(
            f"Time to push weights: {t3 - t2:.3f}s, total time: {t3 - t0:.3f}s"
        )
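    # Batched-transfer sketch (``updater`` and ``policy`` are placeholder names):
    #
    #     updater.push_weights_from_transformers_optimized(policy, batch_size=100)  # 100 weights per push
    #     updater.push_weights_from_transformers_optimized(policy, batch_size=0)    # single un-batched push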
    # Required WeightUpdaterBase methods
    def _sync_weights_with_worker(self, *, worker_id=None, server_weights=None):
        """Sync weights with worker (delegates to push_weights)."""
        if server_weights is None:
            raise ValueError("server_weights cannot be None")
        if hasattr(server_weights, "items"):
            # Dict-like object
            self.push_weights(iter(server_weights.items()))
        else:
            # Assume it's a model with state_dict
            self.push_weights_from_transformers(server_weights)
    def _get_server_weights(self):
        """Not used - weights must be passed directly."""
        return None
    def _maybe_map_weights(self, server_weights):
        """Map weights to expected format."""
        return server_weights  # No mapping needed, handled in push_weights methods
    def register_collector(self, collector):  # noqa: F821
        """Register a collector and set up the policy version increment post-hook.

        Args:
            collector: The collector to register (DataCollectorBase).
        """
        result = super().register_collector(collector)
        # Only register the increment_version post-hook once for the first collector
        # This avoids N^2 complexity where each weight update calls increment_version
        # on all collectors N times (once per registered collector)
        if len(self.post_hooks) == 0:
            torchrl_logger.info("Registering policy version increment post-hook")
            self.register_post_hook(self._increment_all_collector_versions)
        return result
    def _increment_all_collector_versions(self):
        """Increment version for all registered collectors efficiently."""
        torchrl_logger.info(
            f"Incrementing policy version for {len(self.collectors)} collectors..."
        )
        for i, collector in enumerate(self.collectors):
            try:
                collector.increment_version()
                torchrl_logger.debug(
                    f"Incremented version for collector {i+1}/{len(self.collectors)}"
                )
            except Exception as e:
                torchrl_logger.warning(
                    f"Failed to increment version for collector {i+1}: {e}"
                )
        torchrl_logger.info("All collector versions incremented")
    @classmethod
    def get_model_metadata(cls, model) -> dict[str, tuple[torch.dtype, torch.Size]]:
        """Get model metadata from a model.

        Args:
            model: A model with a state_dict() method (e.g., TransformersWrapper).

        Returns:
            dict: Mapping of parameter names to (dtype, shape) tuples.
        """
        if hasattr(model, "model") and hasattr(model.model, "state_dict"):
            # TorchRL wrapper (e.g., TransformersWrapper)
            model_obj = model.model
            # Check if it's a LoRA model
            if hasattr(model_obj, "merge_and_unload"):
                sd = model_obj.merge_and_unload().state_dict()
            else:
                sd = model_obj.state_dict()
        elif hasattr(model, "state_dict"):
            # Direct model
            # Check if it's a LoRA model
            if hasattr(model, "merge_and_unload"):
                sd = model.merge_and_unload().state_dict()
            else:
                sd = model.state_dict()
        else:
            raise TypeError(f"Cannot extract state_dict from {type(model)}")
        return {k: (v.dtype, v.shape) for k, v in sd.items()}
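    # The returned mapping has the form below (dtype and shape values are illustrative,
    # not taken from a real model):
    #
    #     {
    #         "model.embed_tokens.weight": (torch.bfloat16, torch.Size([151936, 4096])),
    #         "model.layers.0.self_attn.q_proj.weight": (torch.bfloat16, torch.Size([4096, 4096])),
    #     }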
    # Remove the weakrefs from the updater for serialization
    def __getstate__(self):
        state = self.__dict__.copy()
        state["_collector_wrs"] = None
        return state