# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import copy
import warnings
from typing import Any, Callable, Iterator

import torch
from tensordict import TensorDictBase
from tensordict.nn import TensorDictModuleBase

from torchrl.collectors.llm import LLMCollector
from torchrl.collectors.weight_update import WeightUpdaterBase
from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer
from torchrl.envs import EnvBase
from torchrl.envs.llm.transforms.policy_version import PolicyVersion

RAY_ERR = None
try:
    import ray

    _has_ray = True
except ImportError as err:
    _has_ray = False
    RAY_ERR = err


class RayLLMCollector(LLMCollector):
    """A lightweight Ray implementation of the LLM Collector that can be extended and sampled remotely.

    Args:
        env (EnvBase or EnvBase constructor): the environment to be used for data collection.

    Keyword Args:
        policy (Callable[[TensorDictBase], TensorDictBase]): the policy to be used for data collection.
        policy_factory (Callable[[], Callable], optional): a callable that returns
            a policy instance. This is exclusive with the `policy` argument.
        dialog_turns_per_batch (int): A keyword-only argument representing the total
            number of elements in a batch.
        total_dialog_turns (int): A keyword-only argument representing the total
            number of dialog turns returned by the collector during its lifespan.
        yield_only_last_steps (bool, optional): whether to yield every step of a trajectory, or only the
            last (done) steps.
        yield_completed_trajectories (bool, optional): whether to yield batches of rollouts with a given number of steps
            or single, completed trajectories.
        postproc (Callable, optional): A post-processing transform.
        async_envs (bool, optional): if ``True``, the environment will be run asynchronously.
        replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts
            but populate the buffer instead.
        reset_at_each_iter (bool, optional): if ``True``, the environment will be reset at each iteration.
        flatten_data (bool, optional): if ``True``, the collector will flatten the collected data
            before returning it.
        weight_updater (WeightUpdaterBase or constructor, optional): An instance of WeightUpdaterBase
            or its subclass, responsible for updating the policy weights on remote inference workers.
        ray_init_config (dict[str, Any], optional): keyword arguments to pass to `ray.init()`.
        remote_config (dict[str, Any], optional): keyword arguments to pass to `cls.as_remote()`.
        track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the
            version of the policy through a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion`
            transform added to the environment. Alternatively, a pre-built ``PolicyVersion`` instance can
            be passed. Defaults to `False`.
        sync_iter (bool, optional): if ``True``, items yielded by the collector will be synced to the local process.
            If ``False``, the collector will collect the next batch of data in between yielding.
            This has no effect when data is collected through the :meth:`start` method.
            For example:

                >>> collector = RayLLMCollector(..., sync_iter=True)
                >>> for data in collector:  # blocking
                ...     # expensive operation - collector is idle
                >>> collector = RayLLMCollector(..., sync_iter=False)
                >>> for data in collector:  # non-blocking
                ...     # expensive operation - collector is collecting data

            This is somewhat equivalent to using :class:`~torchrl.collectors.MultiSyncDataCollector` (`sync_iter=True`) or
            :class:`~torchrl.collectors.MultiAsyncDataCollector` (`sync_iter=False`).
            Defaults to `True`.
        verbose (bool, optional): if ``True``, the collector will print progress information.
            Defaults to `False`.
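
    Examples:
        A minimal, illustrative sketch rather than a runnable recipe: ``make_env`` and
        ``policy_factory`` stand in for user-provided constructors, and the buffer is
        assumed to be shared with the remote worker (e.g. a Ray-actor-backed buffer such
        as :class:`~torchrl.data.replay_buffers.ray_buffer.RayReplayBuffer`).

        >>> from torchrl.data.replay_buffers.ray_buffer import RayReplayBuffer
        >>> rb = RayReplayBuffer(batch_size=4)
        >>> collector = RayLLMCollector(
        ...     env=make_env,  # hypothetical environment constructor
        ...     policy_factory=policy_factory,  # hypothetical policy factory
        ...     dialog_turns_per_batch=4,
        ...     total_dialog_turns=16,
        ...     replay_buffer=rb,
        ...     sync_iter=False,
        ... )
        >>> for _ in collector:  # yields None: the data lands in ``rb``
        ...     batch = rb.sample()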
"""

    def __init__(
        self,
        env: EnvBase | Callable[[], EnvBase],
        *,
        policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
        policy_factory: Callable[[], Callable[[TensorDictBase], TensorDictBase]]
        | None = None,
        dialog_turns_per_batch: int,
        total_dialog_turns: int = -1,
        yield_only_last_steps: bool | None = None,
        yield_completed_trajectories: bool | None = None,
        postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
        async_envs: bool | None = None,
        replay_buffer: ReplayBuffer | None = None,
        reset_at_each_iter: bool = False,
        flatten_data: bool | None = None,
        weight_updater: WeightUpdaterBase
        | Callable[[], WeightUpdaterBase]
        | None = None,
        ray_init_config: dict[str, Any] | None = None,
        remote_config: dict[str, Any] | None = None,
        track_policy_version: bool | PolicyVersion = False,
        sync_iter: bool = True,
        verbose: bool = False,
    ) -> None:
        if not _has_ray:
            raise RuntimeError(
                "ray library not found, unable to create a RayLLMCollector."
            ) from RAY_ERR
        if not ray.is_initialized():
            if ray_init_config is None:
                from torchrl.collectors.distributed.ray import DEFAULT_RAY_INIT_CONFIG

                ray_init_config = DEFAULT_RAY_INIT_CONFIG
            ray.init(**ray_init_config)
        if not sync_iter:
            # Copy to avoid mutating the caller's config; `remote_config` may be None,
            # in which case we start from an empty dict. `max_concurrency=2` lets a
            # remote `next` call run while the previous result is being consumed.
            remote_config = copy.copy(remote_config) if remote_config is not None else {}
            remote_config.setdefault("max_concurrency", 2)
        remote_cls = LLMCollector.as_remote(remote_config).remote
        self.sync_iter = sync_iter
        self._collector = remote_cls(
            env=env,
            policy=policy,
            policy_factory=policy_factory,
            dialog_turns_per_batch=dialog_turns_per_batch,
            total_dialog_turns=total_dialog_turns,
            yield_only_last_steps=yield_only_last_steps,
            yield_completed_trajectories=yield_completed_trajectories,
            postproc=postproc,
            async_envs=async_envs,
            replay_buffer=replay_buffer,
            reset_at_each_iter=reset_at_each_iter,
            flatten_data=flatten_data,
            weight_updater=weight_updater,
            track_policy_version=track_policy_version,
            verbose=verbose,
        )

    def set_postproc(self, postproc: Callable[[TensorDictBase], TensorDictBase]):
        return ray.get(self._collector.set_postproc.remote(postproc))

    def _next_remote(self) -> ray.ObjectRef:
        # Launch a remote `next` call and return the future without blocking.
        return self._collector.next.remote()

    def next(self) -> None:
        """Get the next batch of data from the collector.

        Returns:
            None as the data is written directly to the replay buffer.
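
        Examples:
            A minimal sketch, reusing the hypothetical shared buffer ``rb``
            from the class-level example:

            >>> collector.next()  # blocks until one batch has been collected
            >>> batch = rb.sample()  # the freshly collected turns are in the buffer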
"""
return ray.get(self._next_remote())

    def __iter__(self) -> Iterator[None]:
        """Returns an iterator that yields None as the collector writes directly to the replay buffer."""
        if not self.sync_iter:
            # Launch the first collection ahead of time so that data is gathered
            # while the consumer processes the previous batch.
            future = self._next_remote()
        else:
            future = None
        while True:
            try:
                if self.sync_iter:
                    # Blocking: the remote collector stays idle while the consumer works.
                    yield self.next()
                else:
                    # Overlapped: wait for the in-flight batch, queue the next one
                    # right away, then hand the result to the consumer.
                    result = ray.get(future)
                    future = self._next_remote()
                    yield result
            except StopIteration:
                break

    def start(self):
        """Starts the collector in a background thread.
        """
        pending_task = self._collector.start.remote()
        return ray.get(pending_task)

    def shutdown(self):
        """Shuts down the collector."""
        pending_task = self._collector.shutdown.remote()
        return ray.get(pending_task)

    def async_shutdown(self, timeout=None):
        """Shuts down the collector asynchronously."""
        pending_task = self._collector.async_shutdown.remote(timeout=timeout)
        return ray.get(pending_task)

    def update_policy_weights_(
        self,
        policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
        *,
        worker_ids: torch.device | int | list[int] | list[torch.device] | None = None,
        **kwargs,
    ):
        """Updates the policy weights on remote workers.

        Args:
            policy_or_weights: The weights to update with. Can be:

                - TensorDictModuleBase: A policy module whose weights will be extracted
                - TensorDictBase: A TensorDict containing weights
                - dict: A regular dict containing weights
                - None: Will try to get weights from server using `_get_server_weights()`

            worker_ids: The workers to update. If None, updates all workers.
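
        Examples:
            A minimal sketch, where ``train_policy`` stands in for a hypothetical
            TensorDictModule holding the freshest training weights:

            >>> from tensordict import TensorDict
            >>> weights = TensorDict.from_module(train_policy)
            >>> collector.update_policy_weights_(weights)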
"""
if "policy_weights" in kwargs:
warnings.warn(
"`policy_weights` is deprecated. Use `policy_or_weights` instead.",
DeprecationWarning,
)
policy_or_weights = kwargs.pop("policy_weights")
pending_task = self._collector.update_policy_weights_.remote(
policy_or_weights=policy_or_weights, worker_ids=worker_ids
)
return ray.get(pending_task)

    @property
    def total_dialog_turns(self):
        """Total number of dialog turns to collect."""
        return ray.get(self._collector.total_dialog_turns.remote())

    @property
    def dialog_turns_per_batch(self) -> int:
        """Number of dialog turns per batch."""
        return ray.get(self._collector.dialog_turns_per_batch.remote())

    @property
    def rollout(self) -> TensorDictBase:
        """Triggers a rollout on the remote collector and returns the collected data."""
        return ray.get(self._collector.rollout.remote())

    def init_updater(self, *args, **kwargs):
        """Initialize the weight updater with custom arguments.

        This method calls `init_updater` on the remote collector.

        Args:
            *args: Positional arguments for weight updater initialization
            **kwargs: Keyword arguments for weight updater initialization
        """
        ray.get(self._collector.init_updater.remote(*args, **kwargs))

    @property
    def policy_version(self) -> str | int | None:
        """The current version of the policy.

        Returns:
            The current version number (int) or UUID (str), or None if version tracking is disabled.
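
        Examples:
            A minimal sketch, assuming the collector was built with
            ``track_policy_version=True``:

            >>> collector.increment_version()
            >>> collector.policy_version  # e.g. 1, or a UUID string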
"""
return ray.get(self._collector.get_policy_version.remote())

    @property
    def weight_updater(self) -> WeightUpdaterBase:
        """The weight updater instance.

        We can pass the weight updater because it's stateless, hence serializable.
        """
        return ray.get(self._collector.weight_updater.remote())

    @weight_updater.setter
    def weight_updater(self, weight_updater: WeightUpdaterBase):
        """Set the weight updater instance."""
        ray.get(self._collector.set_weight_updater.remote(weight_updater))
        weight_updater.register_collector(self)

    def increment_version(self):
        """Increment the policy version."""
        return ray.get(self._collector.increment_version.remote())