Source code for torchrl.collectors.llm.base

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from collections import deque
from typing import Any, Callable

import torch

from tensordict import lazy_stack, TensorDictBase

from torchrl._utils import as_remote, logger as torchrl_logger

from torchrl.collectors import SyncDataCollector
from torchrl.collectors.llm.utils import _QueueAsRB
from torchrl.collectors.weight_update import WeightUpdaterBase
from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer
from torchrl.envs import AsyncEnvPool
from torchrl.envs.common import EnvBase
from torchrl.envs.llm.transforms.policy_version import PolicyVersion


class LLMCollector(SyncDataCollector):
    """A simplified version of SyncDataCollector for LLM inference.

    Args:
        env (EnvBase or EnvBase constructor): the environment to be used for data collection.

    Keyword Args:
        policy (Callable[[TensorDictBase], TensorDictBase]): the policy to be used for data collection.
        policy_factory (Callable[[], Callable], optional): a callable that returns a policy instance.
            This is exclusive with the `policy` argument.

            .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.

        dialog_turns_per_batch (int): A keyword-only argument representing the total
            number of elements in a batch; -1 is never ending (until shutdown).
        total_dialog_turns (int): A keyword-only argument representing the total
            number of steps returned by the collector during its lifespan.
        yield_completed_trajectories (bool, optional): whether to yield batches of rollouts
            with a given number of steps (`yield_completed_trajectories=False`, default) or
            single, completed trajectories (`yield_completed_trajectories=True`).
            Defaults to `False` unless `yield_only_last_steps=True`, in which case it cannot be `False`.

            .. warning:: If the `done` state of the environment is not properly set,
                this may lead to a collector that never yields any data.

        yield_only_last_steps (bool, optional): whether to yield every step of a trajectory,
            or only the last (done) steps.
            If `True`, a single trajectory is yielded (or written in the buffer) at a time.

            .. warning:: If the `done` state of the environment is not properly set,
                this may lead to a collector that never yields any data.

        postproc (Callable, optional): A post-processing transform, such as a
            :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
            instance. Defaults to ``None``.
        async_envs (bool, optional): if ``True``, the environment will be run asynchronously.
            Defaults to `True` if the environment is a :class:`~torchrl.envs.AsyncEnvPool` instance.
        replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts
            but populate the buffer instead. Defaults to ``None``.
        queue (optional): a queue to which the collected data will be written, as an alternative to
            `replay_buffer`. Exclusive with `replay_buffer`.
        reset_at_each_iter (bool, optional): if ``True``, the environment will be reset at each iteration.
        flatten_data (bool, optional): if ``True``, the collector will flatten the collected data
            before returning it. In practice, this means that if an environment of batch-size `(B,)`
            is used and run for `T` steps, `flatten_data=True` will present data of shape `(B*T,)`,
            whereas `flatten_data=False` will present data of shape `(B, T)`.
            Defaults to `True` when `replay_buffer` is provided, `False` otherwise.
        weight_updater (WeightUpdaterBase or constructor, optional): An instance of
            :class:`~torchrl.collectors.WeightUpdaterBase` or its subclass, responsible for updating
            the policy weights on remote inference workers. This is typically not used in
            :class:`~torchrl.collectors.SyncDataCollector` as it operates in a single-process environment.
            Consider using a constructor if the updater needs to be serialized.
        track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the
            version of the policy. This will be mediated by the
            :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be
            added to the environment. Alternatively, a
            :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed,
            which will be used to track the policy version. Defaults to `False`.
        verbose (bool, optional): if ``True``, the collector will print progress information.
            Defaults to `False`.

    Examples:
        >>> import vllm
        >>> from torchrl.modules import vLLMWrapper
        >>> from pytorch.rl.test.mocking_classes import DummyStrDataLoader
        >>> from torchrl.envs import LLMEnv
        >>> llm_model = vllm.LLM("gpt2")
        >>> tokenizer = llm_model.get_tokenizer()
        >>> tokenizer.pad_token = tokenizer.eos_token
        >>> policy = vLLMWrapper(llm_model)
        >>> dataloader = DummyStrDataLoader(1)
        >>> env = LLMEnv.from_dataloader(
        ...     dataloader=dataloader,
        ...     tokenizer=tokenizer,
        ...     from_text=True,
        ...     batch_size=1,
        ...     group_repeats=True,
        ... )
        >>> collector = LLMCollector(
        ...     env=env,
        ...     policy_factory=lambda: policy,
        ...     dialog_turns_per_batch=env.batch_size[0],
        ...     total_dialog_turns=3,
        ... )
        >>> for i, data in enumerate(collector):
        ...     if i == 2:
        ...         print(data)
        ...         break
        LazyStackedTensorDict(
            fields={
                attention_mask: Tensor(shape=torch.Size([1, 1, 22]), device=cpu, dtype=torch.int64, is_shared=False),
                collector: LazyStackedTensorDict(
                    fields={
                        traj_ids: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False)},
                    exclusive_fields={
                    },
                    batch_size=torch.Size([1, 1]),
                    device=None,
                    is_shared=False,
                    stack_dim=1),
                done: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                terminated: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                text: NonTensorStack(
                    [['plsgqejeyd']],
                    batch_size=torch.Size([1, 1]),
                    device=None),
                text_response: NonTensorStack(
                    [['ec.n.n.n.tjbjz3perwhz']],
                    batch_size=torch.Size([1, 1]),
                    device=None),
                tokens: Tensor(shape=torch.Size([1, 1, 22]), device=cpu, dtype=torch.int64, is_shared=False),
                tokens_response: Tensor(shape=torch.Size([1, 1, 16]), device=cpu, dtype=torch.int64, is_shared=False)},
            exclusive_fields={
            },
            batch_size=torch.Size([1, 1]),
            device=None,
            is_shared=False,
            stack_dim=1)
        >>> del collector
    """

    def __init__(
        self,
        env: EnvBase | Callable[[], EnvBase],
        *,
        policy: Callable[[TensorDictBase], TensorDictBase] | None = None,
        policy_factory: Callable[[], Callable[[TensorDictBase], TensorDictBase]]
        | None = None,
        dialog_turns_per_batch: int,
        yield_only_last_steps: bool | None = None,
        yield_completed_trajectories: bool | None = None,
        postproc: Callable[[TensorDictBase], TensorDictBase] | None = None,
        total_dialog_turns: int = -1,
        async_envs: bool | None = None,
        replay_buffer: ReplayBuffer | None = None,
        reset_at_each_iter: bool = False,
        flatten_data: bool | None = None,
        weight_updater: WeightUpdaterBase
        | Callable[[], WeightUpdaterBase]
        | None = None,
        queue: Any | None = None,
        track_policy_version: bool | PolicyVersion = False,
        verbose: bool = False,
    ):
        if queue is not None and replay_buffer is not None:
            raise RuntimeError(
                "Handling both a buffer and a queue is not possible at the moment."
            )
        elif queue is not None:
            # disguise the queue as a replay buffer
            replay_buffer = _QueueAsRB(queue)
        super().__init__(
            create_env_fn=env,
            policy=policy,
            policy_factory=policy_factory,
            frames_per_batch=dialog_turns_per_batch,
            replay_buffer=replay_buffer,
            total_frames=total_dialog_turns,
            weight_updater=weight_updater,
            reset_at_each_iter=reset_at_each_iter,
            trust_policy=True,
            use_buffers=False,
            no_cuda_sync=True,
            extend_buffer=True,
            postproc=postproc,
        )
        if yield_only_last_steps is None:
            yield_only_last_steps = False

        if yield_completed_trajectories is None:
            yield_completed_trajectories = yield_only_last_steps
        elif yield_only_last_steps and not yield_completed_trajectories:
            raise TypeError(
                "yield_only_last_steps=True requires yield_completed_trajectories=True (or None)"
            )

        if yield_only_last_steps:
            if flatten_data is not None:
                raise TypeError(
                    "`yield_only_last_steps` cannot be `True` when `flatten_data` is passed."
                )
            if self.reset_at_each_iter:
                raise TypeError(
                    "`yield_only_last_steps` cannot be `True` when `reset_at_each_iter=True`."
                )
        if flatten_data is None:
            flatten_data = replay_buffer is not None
        self.flatten_data = flatten_data
        self.yield_completed_trajectories = yield_completed_trajectories
        self.yield_only_last_steps = yield_only_last_steps
        self.verbose = verbose
        if self.yield_completed_trajectories:
            if len(self.env.batch_size) != 1:
                raise ValueError(
                    "`yield_only_last_steps` only works with envs that have a single batch dimension. Got "
                    f"env.batch_size={self.env.batch_size}."
                )
            self._yield_queues = [deque() for _ in range(self.env.batch_size[0])]
            self._trajectory_queue = deque()
        self.async_envs = bool(async_envs) | isinstance(self.env, AsyncEnvPool)
        if self.async_envs and not isinstance(self.env, AsyncEnvPool):
            # This basically means that `async_envs` is set automatically and passing it is useless as of today,
            # except for the following error.
            raise RuntimeError(
                "async_envs requires the environment to be an AsyncEnvPool instance."
            )
        self.policy_version_tracker = track_policy_version
        if isinstance(track_policy_version, bool) and track_policy_version:
            if isinstance(self.env, AsyncEnvPool):
                raise RuntimeError(
                    "AsyncEnvPool is not supported for policy version tracking. Please add the PolicyVersion transform to the environment manually, "
                    "and pass that transform to the collector."
                )
            self.policy_version_tracker = PolicyVersion()
            self.env = self.env.append_transform(self.policy_version_tracker)  # type: ignore
        elif isinstance(track_policy_version, PolicyVersion):
            self.policy_version_tracker = track_policy_version
            self.env = self.env.append_transform(self.policy_version_tracker)  # type: ignore
        else:
            self.policy_version_tracker = None

    def set_postproc(self, postproc: Callable[[TensorDictBase], TensorDictBase]):
        if self.postproc is not None:
            raise RuntimeError("Postproc already set")
        self.postproc = postproc
    def increment_version(self):
        """Increment the policy version."""
        if self.policy_version_tracker is not None:
            if not isinstance(self.policy_version_tracker, PolicyVersion):
                raise RuntimeError(
                    "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector."
                )
            self.policy_version_tracker.increment_version()
    @property
    def policy_version(self) -> str | int | None:
        """The current policy version."""
        if not isinstance(self.policy_version_tracker, PolicyVersion):
            return None
        return self.policy_version_tracker.version
    def get_policy_version(self) -> str | int | None:
        """Get the current policy version.

        This method exists to support remote calls in Ray actors, since properties
        cannot be accessed directly through Ray's RPC mechanism.

        Returns:
            The current version number (int) or UUID (str), or None if version tracking is disabled.
        """
        return self.policy_version
    @property
    def total_dialog_turns(self):
        return self.total_frames

    @property
    def dialog_turns_per_batch(self) -> int:
        """Alias to `frames_per_batch`."""
        return self.requested_frames_per_batch

    @property
    def rollout(self) -> Callable[[], TensorDictBase]:
        if self.yield_completed_trajectories:
            if self.async_envs:
                return self._rollout_yield_trajs_async
            else:
                return self._rollout_yield_trajs
        else:
            return self._rollout_all

    def _rollout_all(self) -> TensorDictBase:  # A simplified version of rollout
        if self.reset_at_each_iter or self._shuttle is None:
            self._shuttle = self.env.reset()

        trajectory = []
        collected_steps = 0
        policy_input = self._shuttle
        while collected_steps < self.dialog_turns_per_batch:
            if self.verbose:
                torchrl_logger.info(
                    f"LLMCollector: Collected {collected_steps} steps over {self.dialog_turns_per_batch} requested."
                )
            env_input = self.policy(policy_input)
            env_output, env_next_output = self.env.step_and_maybe_reset(env_input)

            # carry over collector data without messing up devices
            collector_data = env_output.get("collector").copy()
            env_next_output.set("collector", collector_data)
            self._update_traj_ids(env_output)

            trajectory.append(env_output.clone())
            collected_steps += env_output.numel()
            policy_input = self._shuttle = env_next_output

        trajectory = lazy_stack(trajectory, -1)
        if self.flatten_data:
            return trajectory.view(-1)
        return trajectory

    def _rollout_yield_trajs(self) -> TensorDictBase:  # A simplified version of rollout
        if self._shuttle is None:
            raise RuntimeError("Data shuttle not found")
            # next_output = self.env.reset()
        else:
            next_output = self._shuttle

        collected_steps = 0
        dones = torch.zeros(self.env.batch_size, dtype=torch.bool)
        while True:
            if self._trajectory_queue:
                break
            env_input = self.policy(next_output)
            cur_output, next_output = self.env.step_and_maybe_reset(env_input)
            # for i in range(cur_output.numel()):
            #     print(len(cur_output[i]["text"]) < len(cur_output[i]["next", "text"]))

            # carry over collector data without messing up devices
            self._update_traj_ids(cur_output)
            collector_data = cur_output.get("collector").copy()
            next_output.set("collector", collector_data)
            # if the loop is interrupted
            self._shuttle = next_output
            collected_steps += next_output.numel()
            for i, (_data, queue) in enumerate(
                zip(cur_output.unbind(0), self._yield_queues)
            ):
                queue.append(_data)
                dones[i] = _data["next", "done"].any()
            if dones.any():
                for idx in dones.nonzero(as_tuple=True)[0].tolist():
                    if not self.yield_only_last_steps:
                        self._trajectory_queue.append(
                            lazy_stack(self._yield_queues[idx], -1)
                        )
                    else:
                        # FIXME: We need to increment the step count here because iterator() won't
                        #  see the extra steps
                        # We use lazy-stack because unsqueeze doesn't nest the strings in lists
                        self._trajectory_queue.append(
                            lazy_stack([self._yield_queues[idx][-1]])
                        )
                    self._yield_queues[idx].clear()
        result = self._trajectory_queue.popleft()
        if self.verbose:
            torchrl_logger.info(
                f"LLMCollector: Yielding completed trajectory with shape {result.shape}."
            )
        return result

    started = False

    def _rollout_yield_trajs_async(
        self,
    ) -> TensorDictBase:  # A simplified version of rollout
        if not self.started:
            next_output = self._shuttle
            env_input = self.policy(next_output)
            self.env.async_step_and_maybe_reset_send(env_input)
            self.started = True

        collected_steps = 0
        dones = torch.zeros(self.env.batch_size, dtype=torch.bool)
        while True:
            if self._trajectory_queue:
                break
            cur_output, next_output = self.env.async_step_and_maybe_reset_recv()

            # Get the env ids
            env_ids = cur_output.get(self.env._env_idx_key).tolist()

            # carry over collector data without messing up devices
            self._update_traj_ids(cur_output)
            collector_data = cur_output.get("collector").copy()
            next_output.set("collector", collector_data)
            collected_steps += next_output.numel()

            dones.fill_(False)
            for i, _data in zip(env_ids, cur_output.unbind(0)):
                queue = self._yield_queues[i]
                queue.append(_data)
                dones[i] = _data["next", "done"].any()
            if dones.any():
                for idx in dones.nonzero(as_tuple=True)[0].tolist():
                    if not self.yield_only_last_steps:
                        self._trajectory_queue.append(
                            lazy_stack(self._yield_queues[idx], -1)
                        )
                    else:
                        # FIXME: We need to increment the step count here because iterator() won't
                        #  see the extra steps
                        # We use lazy-stack because unsqueeze doesn't nest the strings in lists
                        self._trajectory_queue.append(
                            lazy_stack([self._yield_queues[idx][-1]])
                        )
                    self._yield_queues[idx].clear()

            # Launch the next batch:
            # FIXME: Add a condition RE number of frames here
            if True:
                env_input = self.policy(next_output)
                self.env.async_step_and_maybe_reset_send(env_input)

        result = self._trajectory_queue.popleft()
        if self.verbose:
            torchrl_logger.info(
                f"LLMCollector: Yielding completed trajectory with shape {result.shape}."
            )
        return result

    as_remote = as_remote
    def get_policy_model(self):
        """Get the policy model.

        This method is used by RayLLMCollector to get the remote LLM instance for weight updates.

        Returns:
            The policy model instance
        """
        return self.policy.model
    def is_initialized(self) -> bool:
        """Check if the collector is initialized and ready.

        Returns:
            bool: True if the collector is initialized and ready to collect data.
        """
        # The collector is initialized if it has a valid environment and policy
        return hasattr(self, "_env") and hasattr(self, "_policy")
    def set_weight_updater(self, weight_updater: WeightUpdaterBase):
        self.weight_updater = weight_updater
        return True
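
Usage note: when a `replay_buffer` is passed, the constructor above routes collected steps into the buffer instead of yielding them, and `flatten_data` defaults to `True`. The following is a minimal sketch of that mode; it assumes `env` and `policy` are built as in the class docstring example (vLLM model wrapped with `vLLMWrapper`, `LLMEnv.from_dataloader`), and the `ListStorage` choice is purely illustrative.

    # Hedged sketch, not the canonical recipe: buffer-backed collection.
    # Assumes `env` and `policy` exist as in the docstring example above.
    from torchrl.data import ListStorage, ReplayBuffer

    rb = ReplayBuffer(storage=ListStorage(max_size=1_000))
    collector = LLMCollector(
        env=env,
        policy_factory=lambda: policy,
        dialog_turns_per_batch=4,
        total_dialog_turns=16,
        replay_buffer=rb,  # steps are written to `rb` instead of being yielded
    )
    for _ in collector:
        # the iterator drives collection; data lands in the buffer
        if len(rb) >= 4:
            batch = rb.sample(4)  # sample a training batch from the buffer
            break
    collector.shutdown()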
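
Usage note: policy version tracking, implemented by `increment_version`, `policy_version` and `get_policy_version` above, can be exercised as sketched below. This again assumes the `env` and `policy` from the docstring example; `track_policy_version=True` appends a `PolicyVersion` transform to the environment.

    # Hedged sketch: tracking the policy version across weight updates.
    collector = LLMCollector(
        env=env,
        policy_factory=lambda: policy,
        dialog_turns_per_batch=1,
        total_dialog_turns=-1,       # never-ending, until shutdown
        track_policy_version=True,   # adds a PolicyVersion transform to env
    )
    print(collector.policy_version)  # current version (int or UUID str)
    # ... after pushing fresh weights to the inference engine ...
    collector.increment_version()
    print(collector.get_policy_version())  # method form, usable over Ray RPC
    collector.shutdown()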
