LLMCollector
- class torchrl.collectors.llm.LLMCollector(env: torchrl.envs.common.EnvBase | collections.abc.Callable[[], torchrl.envs.common.EnvBase], *, policy: collections.abc.Callable[[tensordict.base.TensorDictBase], tensordict.base.TensorDictBase] | None = None, policy_factory: collections.abc.Callable[[], collections.abc.Callable[[tensordict.base.TensorDictBase], tensordict.base.TensorDictBase]] | None = None, dialog_turns_per_batch: int | None = None, yield_only_last_steps: bool | None = None, yield_completed_trajectories: bool | None = None, postproc: collections.abc.Callable[[tensordict.base.TensorDictBase], tensordict.base.TensorDictBase] | None = None, total_dialog_turns: int = -1, async_envs: bool | None = None, replay_buffer: torchrl.data.replay_buffers.replay_buffers.ReplayBuffer | None = None, reset_at_each_iter: bool = False, flatten_data: bool | None = None, weight_updater: torchrl.collectors.weight_update.WeightUpdaterBase | collections.abc.Callable[[], torchrl.collectors.weight_update.WeightUpdaterBase] | None = None, queue: Any | None = None, track_policy_version: bool | torchrl.envs.llm.transforms.policy_version.PolicyVersion = False, verbose: bool = False)
A simplified version of Collector for LLM inference.
- Parameters:
env (EnvBase or EnvBase constructor) – the environment to be used for data collection.
- Keyword Arguments:
policy (Callable[[TensorDictBase], TensorDictBase]) – the policy to be used for data collection.
policy_factory (Callable[[], Callable], optional) –
a callable that returns a policy instance. This is exclusive with the policy argument.
Note
policy_factory comes in handy whenever the policy cannot be serialized.
dialog_turns_per_batch (int, optional) – A keyword-only argument representing the total number of elements (dialog turns) in a batch. It is always required except when yield_completed_trajectories=True.
total_dialog_turns (int) – A keyword-only argument representing the total number of dialog turns returned by the collector during its lifespan. -1 means the collector never stops (until it is shut down). Defaults to -1.
yield_completed_trajectories (bool, optional) –
whether to yield batches of rollouts with a given number of steps (yield_completed_trajectories=False, default) or single, completed trajectories (yield_completed_trajectories=True). Defaults to False unless yield_only_last_steps=True, in which case it cannot be False.
Warning
If the done state of the environment is not properly set, this may lead to a collector that never yields any data.
yield_only_last_steps (bool, optional) –
whether to yield every step of a trajectory, or only the last (done) steps. If True, a single trajectory is yielded (or written to the buffer) at a time.
Warning
If the done state of the environment is not properly set, this may lead to a collector that never yields any data.
postproc (Callable, optional) – A post-processing transform, such as a Transform or a MultiStep instance. Defaults to None.
async_envs (bool, optional) – if True, the environment will be run asynchronously. Defaults to True if the environment is an AsyncEnvPool instance.
replay_buffer (ReplayBuffer, optional) – if provided, the collector will not yield tensordicts but populate the buffer instead. Defaults to None.
reset_at_each_iter (bool, optional) – if True, the environment will be reset at each iteration. Defaults to False.
flatten_data (bool, optional) – if True, the collector will flatten the collected data before returning it. In practice, this means that if an environment of batch-size (B,) is used and run for T steps, flatten_data=True will present data of shape (B*T,), whereas flatten_data=False will present data of shape (B, T). Defaults to True when replay_buffer is provided, False otherwise.
weight_updater (WeightUpdaterBase or constructor, optional) – An instance of WeightUpdaterBase or its subclass, responsible for updating the policy weights on remote inference workers. This is typically not used in Collector as it operates in a single-process environment. Consider using a constructor if the updater needs to be serialized.
track_policy_version (bool or PolicyVersion, optional) – if True, the collector will track the version of the policy. This will be mediated by the PolicyVersion transform, which will be added to the environment. Alternatively, a PolicyVersion instance can be passed, which will be used to track the policy version. Defaults to False.
verbose (bool, optional) – if True, the collector will print progress information. Defaults to False.
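To make the flatten_data shapes concrete, here is a pure-Python sketch that models the collected data as a nested list indexed [b][t]. It only illustrates the (B, T) versus (B*T,) layouts; the helper names collect_grid and flatten are hypothetical, and the actual collector works on TensorDicts rather than lists.

```python
def collect_grid(batch_size: int, num_steps: int) -> list[list[str]]:
    """Build a (B, T) grid of placeholder step labels, indexed [b][t]."""
    return [[f"env{b}-step{t}" for t in range(num_steps)] for b in range(batch_size)]


def flatten(grid: list[list[str]]) -> list[str]:
    """Merge the batch and time dimensions: (B, T) -> (B*T,), batch-major."""
    return [step for row in grid for step in row]


grid = collect_grid(batch_size=2, num_steps=3)  # analogous to flatten_data=False
flat = flatten(grid)                            # analogous to flatten_data=True
print(len(grid), len(grid[0]))  # 2 3
print(len(flat))                # 6
print(flat[3])                  # env1-step0
```

In this sketch the flattened sequence keeps all steps of the first environment before those of the second, so per-environment ordering survives the reshape.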
Examples
>>> import vllm
>>> from torchrl.collectors.llm import LLMCollector
>>> from torchrl.modules import vLLMWrapper
>>> from torchrl.testing.mocking_classes import DummyStrDataLoader
>>> from torchrl.envs import LLMEnv
>>> llm_model = vllm.LLM("gpt2")
>>> tokenizer = llm_model.get_tokenizer()
>>> tokenizer.pad_token = tokenizer.eos_token
>>> policy = vLLMWrapper(llm_model)
>>> dataloader = DummyStrDataLoader(1)
>>> env = LLMEnv.from_dataloader(
...     dataloader=dataloader,
...     tokenizer=tokenizer,
...     from_text=True,
...     batch_size=1,
...     group_repeats=True,
... )
>>> collector = LLMCollector(
...     env=env,
...     policy_factory=lambda: policy,
...     dialog_turns_per_batch=env.batch_size[0],
...     total_dialog_turns=3,
... )
>>> for i, data in enumerate(collector):
...     if i == 2:
...         print(data)
...         break
LazyStackedTensorDict(
    fields={
        attention_mask: Tensor(shape=torch.Size([1, 1, 22]), device=cpu, dtype=torch.int64, is_shared=False),
        collector: LazyStackedTensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False)},
            exclusive_fields={
            },
            batch_size=torch.Size([1, 1]),
            device=None,
            is_shared=False,
            stack_dim=1),
        done: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        terminated: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        text: NonTensorStack(
            [['plsgqejeyd']],
            batch_size=torch.Size([1, 1]),
            device=None),
        text_response: NonTensorStack(
            [['ec.n.n.n.tjbjz3perwhz']],
            batch_size=torch.Size([1, 1]),
            device=None),
        tokens: Tensor(shape=torch.Size([1, 1, 22]), device=cpu, dtype=torch.int64, is_shared=False),
        tokens_response: Tensor(shape=torch.Size([1, 1, 16]), device=cpu, dtype=torch.int64, is_shared=False)},
    exclusive_fields={
    },
    batch_size=torch.Size([1, 1]),
    device=None,
    is_shared=False,
    stack_dim=1)
>>> del collector
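A simplified model of how dialog_turns_per_batch and total_dialog_turns interact may help: batches are produced until the running count of dialog turns reaches total_dialog_turns, and -1 removes the bound. The generator batch_schedule below is hypothetical and assumes every batch holds exactly dialog_turns_per_batch turns; it is a sketch, not the LLMCollector implementation.

```python
import itertools
from collections.abc import Iterator


def batch_schedule(dialog_turns_per_batch: int, total_dialog_turns: int = -1) -> Iterator[int]:
    """Yield the number of dialog turns in each successive batch.

    Simplified model (an assumption, not LLMCollector's code): every batch holds
    exactly dialog_turns_per_batch turns, and -1 means "no upper bound".
    """
    collected = 0
    while total_dialog_turns == -1 or collected < total_dialog_turns:
        yield dialog_turns_per_batch
        collected += dialog_turns_per_batch


# Mirrors the example above: batches of 1 turn, 3 turns in total -> 3 batches.
print(list(batch_schedule(1, total_dialog_turns=3)))  # [1, 1, 1]

# With -1 the schedule never ends, so the caller decides when to stop.
print(len(list(itertools.islice(batch_schedule(4), 5))))  # 5
```

Under this model, the example above yields exactly three batches, which is why its loop reaches i == 2 before the collector is exhausted.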
- classmethod as_remote(remote_config: dict[str, Any] | None = None)
Creates an instance of a remote Ray class.
- Parameters:
cls (Python Class) – class to be remotely instantiated.
remote_config (dict) – the Ray options for the remote class, such as the quantity of CPU cores to reserve for this class.
- Returns:
A function that creates ray remote class instances.
- async_shutdown(timeout: float | None = None, close_env: bool = True) → None
Finishes processes started by ray.init() during async execution.
- cascade_execute(attr_path: str, *args, **kwargs) → Any
Execute a method on a nested attribute of this collector.
This method allows remote callers to invoke methods on nested attributes of the collector without needing to know the full structure. It’s particularly useful for calling methods on weight sync schemes from the sender side.
- Parameters:
attr_path – Full path to the callable, e.g., "_receiver_schemes['model_id']._set_dist_connection_info"
*args – Positional arguments to pass to the method.
**kwargs – Keyword arguments to pass to the method.
- Returns:
The return value of the method call.
Examples
>>> collector.cascade_execute(
...     "_receiver_schemes['policy']._set_dist_connection_info",
...     connection_info_ref,
...     worker_idx=0,
... )
- property dialog_turns_per_batch: int
Alias for frames_per_batch.
- get_model(model_id: str)
Get model instance by ID (for weight sync schemes).
- Parameters:
model_id – Model identifier (e.g., “policy”, “value_net”)
- Returns:
The model instance
- Raises:
ValueError – If model_id is not recognized
- get_policy_model()
Get the policy model.
This method is used by RayLLMCollector to get the remote LLM instance for weight updates.
- Returns:
The policy model instance
- get_policy_version() → str | int | None
Get the current policy version.
This method exists to support remote calls in Ray actors, since properties cannot be accessed directly through Ray’s RPC mechanism.
- Returns:
The current version number (int) or UUID (str), or None if version tracking is disabled.
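The two version flavours and the method-versus-property split can be sketched with a small standalone class. PolicyVersionSketch, its version_type argument, and its increment method are hypothetical names chosen for illustration; this is not torchrl's PolicyVersion transform, only a model of the documented return values.

```python
import uuid
from typing import Optional, Union

Version = Union[int, str, None]


class PolicyVersionSketch:
    """Hypothetical version tracker; not torchrl's PolicyVersion transform."""

    def __init__(self, version_type: Optional[str] = "int") -> None:
        # "int" -> counter, "uuid" -> random string, None -> tracking disabled.
        self._version_type = version_type
        self._version: Version = None
        if version_type == "int":
            self._version = 0
        elif version_type == "uuid":
            self._version = str(uuid.uuid4())

    def increment(self) -> None:
        """Bump the version, e.g. after each weight update."""
        if self._version_type == "int":
            self._version += 1
        elif self._version_type == "uuid":
            self._version = str(uuid.uuid4())

    @property
    def policy_version(self) -> Version:
        return self._version

    def get_policy_version(self) -> Version:
        # A plain method can be invoked through an RPC layer where a property cannot.
        return self.policy_version


tracker = PolicyVersionSketch("int")
tracker.increment()
tracker.increment()
print(tracker.get_policy_version())  # 2
```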
- getattr_env(attr)
Get an attribute from the environment.
- getattr_policy(attr)
Get an attribute from the policy.
- getattr_rb(attr)
Get an attribute from the replay buffer.
- init_updater(*args, **kwargs)
Initialize the weight updater with custom arguments.
This method passes the arguments to the weight updater’s init method. If no weight updater is set, this is a no-op.
- Parameters:
*args – Positional arguments for weight updater initialization
**kwargs – Keyword arguments for weight updater initialization
- is_initialized() → bool
Check if the collector is initialized and ready.
- Returns:
True if the collector is initialized and ready to collect data.
- Return type:
bool
- iterator() → Iterator[TensorDictBase]
Iterates through the DataCollector.
Yields: TensorDictBase objects containing (chunks of) trajectories
- load_state_dict(state_dict: OrderedDict, **kwargs) → None
Loads a state_dict on the environment and policy.
- Parameters:
state_dict (OrderedDict) – ordered dictionary containing the fields "policy_state_dict" and "env_state_dict".
- pause()
Context manager that pauses the collector if it is running freely (for example, after start() has been called).
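One way to picture pause() is a lock shared between a background collection loop and the caller: while the caller holds it inside a with block, the loop cannot take another step. The FreeRunningWorker class below is a hypothetical stand-in built on threading; it sketches the idea only and is not how torchrl implements pausing.

```python
import contextlib
import threading
import time


class FreeRunningWorker:
    """Hypothetical stand-in for a collector running in a background thread."""

    def __init__(self) -> None:
        self.steps = 0
        self._lock = threading.Lock()      # held while one step is being taken
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._loop, daemon=True)

    def _loop(self) -> None:
        while not self._stop.is_set():
            with self._lock:               # blocks here while pause() holds the lock
                self.steps += 1
            time.sleep(0.001)              # lock is free between steps

    def start(self) -> None:
        self._thread.start()

    @contextlib.contextmanager
    def pause(self):
        # Holding the lock guarantees the loop takes no step until we exit.
        with self._lock:
            yield

    def shutdown(self) -> None:
        self._stop.set()
        self._thread.join()


worker = FreeRunningWorker()
worker.start()
time.sleep(0.1)
with worker.pause():
    frozen = worker.steps
    time.sleep(0.05)
    unchanged = worker.steps == frozen
print(unchanged)  # True: no progress while paused
worker.shutdown()
```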
- property policy_version: str | int | None
The current policy version.
- receive_weights(policy_or_weights: tensordict.base.TensorDictBase | tensordict.nn.common.TensorDictModuleBase | torch.nn.modules.module.Module | dict | None = None, *, weights: tensordict.base.TensorDictBase | dict | None = None, policy: tensordict.nn.common.TensorDictModuleBase | torch.nn.modules.module.Module | None = None) → None
Receive and apply weights to the collector’s policy.
This method applies weights to the local policy. When receiver schemes are registered, it delegates to those schemes. Otherwise, it directly applies the provided weights.
The method accepts weights in multiple forms for convenience:
Examples
>>> # Receive from registered schemes (distributed collectors)
>>> collector.receive_weights()
>>>
>>> # Apply weights from a policy module (positional)
>>> collector.receive_weights(trained_policy)
>>>
>>> # Apply weights from a TensorDict (positional)
>>> collector.receive_weights(weights_tensordict)
>>>
>>> # Use keyword arguments for clarity
>>> collector.receive_weights(weights=weights_td)
>>> collector.receive_weights(policy=trained_policy)
- Parameters:
policy_or_weights –
The weights to apply. Can be:
nn.Module: A policy module whose weights will be extracted and applied
TensorDictModuleBase: A TensorDict module whose weights will be extracted
TensorDictBase: A TensorDict containing weights
dict: A regular dict containing weights
None: Receive from registered schemes or mirror from original policy
- Keyword Arguments:
weights – Alternative to positional argument. A TensorDict or dict containing weights to apply. Cannot be used together with policy_or_weights or policy.
policy – Alternative to positional argument. An nn.Module or TensorDictModuleBase whose weights will be extracted. Cannot be used together with policy_or_weights or weights.
- Raises:
ValueError – If conflicting parameters are provided or if arguments are passed when receiver schemes are registered.
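The argument rules above (at most one weight source, none when receiver schemes are registered, and None meaning "use the default path") can be expressed as a small validation function. pick_weight_source and its schemes_registered flag are hypothetical; the sketch only checks arguments and does not apply any weights.

```python
from typing import Any


def pick_weight_source(
    policy_or_weights: Any = None,
    *,
    weights: Any = None,
    policy: Any = None,
    schemes_registered: bool = False,
) -> tuple[str, Any]:
    """Return (argument_name, value) for the single weight source provided.

    Hypothetical helper mirroring the documented rules; ("default", None)
    stands for "receive from registered schemes or mirror the original policy".
    """
    candidates = {"policy_or_weights": policy_or_weights, "weights": weights, "policy": policy}
    provided = {name: value for name, value in candidates.items() if value is not None}
    if len(provided) > 1:
        raise ValueError(f"conflicting arguments: {sorted(provided)}")
    if provided and schemes_registered:
        raise ValueError("arguments cannot be passed when receiver schemes are registered")
    if not provided:
        return "default", None
    return next(iter(provided.items()))


print(pick_weight_source())                    # ('default', None)
print(pick_weight_source(weights={"w": 1.0}))  # ('weights', {'w': 1.0})
try:
    pick_weight_source({"w": 1.0}, weights={"w": 2.0})
except ValueError as err:
    print(err)  # conflicting arguments: ['policy_or_weights', 'weights']
```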
- register_scheme_receiver(weight_recv_schemes: dict[str, torchrl.weight_update.weight_sync_schemes.WeightSyncScheme], *, synchronize_weights: bool = True)
Set up receiver schemes for this collector to receive weights from parent collectors.
This method initializes receiver schemes and stores them in _receiver_schemes for later use by _receive_weights_scheme() and receive_weights().
Receiver schemes enable cascading weight updates across collector hierarchies:
- The parent collector sends weights via its weight_sync_schemes (senders).
- The child collector receives weights via its weight_recv_schemes (receivers).
- If the child is also a parent (an intermediate node), it can propagate the weights to its own children.
- Parameters:
weight_recv_schemes (dict[str, WeightSyncScheme]) – Dictionary of {model_id: WeightSyncScheme} to set up as receivers. These schemes will receive weights from parent collectors.
- Keyword Arguments:
synchronize_weights (bool, optional) – If True, synchronize weights immediately after registering the schemes. Defaults to True.
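The cascading pattern can be sketched as a tree in which each node first acts as a receiver (applying weights locally) and then as a sender (forwarding them to its children). CollectorNode is a hypothetical dataclass, not a torchrl class, and weights are plain dicts keyed by parameter name.

```python
from dataclasses import dataclass, field


@dataclass
class CollectorNode:
    """Hypothetical node in a collector hierarchy (not a torchrl class)."""

    name: str
    children: list["CollectorNode"] = field(default_factory=list)
    models: dict[str, dict[str, float]] = field(default_factory=dict)

    def receive(self, model_id: str, weights: dict[str, float]) -> None:
        # Receiver role: apply the incoming weights locally...
        self.models[model_id] = dict(weights)
        # ...then sender role: forward them to every child collector.
        for child in self.children:
            child.receive(model_id, weights)


leaf_a = CollectorNode("leaf_a")
leaf_b = CollectorNode("leaf_b")
intermediate = CollectorNode("intermediate", children=[leaf_a, leaf_b])
root = CollectorNode("root", children=[intermediate])

root.receive("policy", {"w": 0.5})
print(leaf_b.models["policy"])  # {'w': 0.5}
```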
- reset(index=None, **kwargs) → None
Resets the environments to a new initial state.
- property rollout: Callable[[], TensorDictBase]
Computes a rollout in the environment using the provided policy.
- Returns:
TensorDictBase containing the computed rollout.
- set_seed(seed: int, static_seed: bool = False) → int
Sets the seeds of the environments stored in the DataCollector.
- Parameters:
seed (int) – integer representing the seed to be used for the environment.
static_seed (bool, optional) – if True, the seed is not incremented. Defaults to False.
- Returns:
Output seed. This is useful when more than one environment is contained in the DataCollector, as the seed will be incremented for each of these. The resulting seed is the seed of the last environment.
Examples
>>> from torchrl.collectors import Collector
>>> from torchrl.envs import ParallelEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> env_fn = lambda: GymEnv("Pendulum-v1")
>>> env_fn_parallel = ParallelEnv(6, env_fn)
>>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
>>> collector = Collector(env_fn_parallel, policy, total_frames=300, frames_per_batch=100)
>>> out_seed = collector.set_seed(1)  # out_seed = 6
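The documented return value can be modelled with a simple incrementing scheme: each environment receives the previous seed plus one, and the output is the seed of the last environment. The function set_seeds is hypothetical and this plus-one rule is a simplifying assumption that happens to reproduce the out_seed = 6 comment above; real environments may derive their seeds differently.

```python
def set_seeds(seed: int, num_envs: int, static_seed: bool = False) -> tuple[list[int], int]:
    """Return (per-environment seeds, output seed) under a plus-one scheme.

    Simplifying assumption: each environment gets the previous seed plus one,
    and the output seed is that of the last environment.
    """
    if static_seed:
        # static_seed=True: the seed is not incremented.
        return [seed] * num_envs, seed
    seeds = [seed + i for i in range(num_envs)]
    return seeds, seeds[-1]


print(set_seeds(1, num_envs=6))                    # ([1, 2, 3, 4, 5, 6], 6)
print(set_seeds(1, num_envs=6, static_seed=True))  # ([1, 1, 1, 1, 1, 1], 1)
```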
- shutdown(timeout: float | None = None, close_env: bool = True, raise_on_error: bool = True) → None
Shuts down all workers and/or closes the local environment.
- Parameters:
timeout (float, optional) – The timeout for closing pipes between workers. No effect for this class.
close_env (bool, optional) – Whether to close the environment. Defaults to True.
raise_on_error (bool, optional) – Whether to raise an error if the shutdown fails. Defaults to True.
- start()
Starts the collector in a separate thread for asynchronous data collection.
The collected data is stored in the provided replay buffer. This method is useful when you want to decouple data collection from training, allowing your training loop to run independently of the data collection process.
- Raises:
RuntimeError – If no replay buffer is defined during the collector’s initialization.
Example
>>> from torchrl.modules import RandomPolicy
>>>
>>> import time
>>> from functools import partial
>>>
>>> import tqdm
>>>
>>> from torchrl.collectors import Collector
>>> from torchrl.data import LazyTensorStorage, ReplayBuffer
>>> from torchrl.envs import GymEnv, set_gym_backend
>>> import ale_py
>>>
>>> # Set the gym backend to gymnasium
>>> set_gym_backend("gymnasium").set()
>>>
>>> if __name__ == "__main__":
...     # Create a random policy for the Pong environment
...     env = GymEnv("ALE/Pong-v5")
...     policy = RandomPolicy(env.action_spec)
...
...     # Initialize a shared replay buffer
...     rb = ReplayBuffer(storage=LazyTensorStorage(1000), shared=True)
...
...     # Create a synchronous data collector
...     collector = Collector(
...         env,
...         policy=policy,
...         replay_buffer=rb,
...         frames_per_batch=256,
...         total_frames=-1,
...     )
...
...     # Progress bar to track the number of collected frames
...     pbar = tqdm.tqdm(total=100_000)
...
...     # Start the collector asynchronously
...     collector.start()
...
...     # Track the write count of the replay buffer
...     prec_wc = 0
...     while True:
...         wc = rb.write_count
...         c = wc - prec_wc
...         prec_wc = wc
...
...         # Update the progress bar
...         pbar.update(c)
...         pbar.set_description(f"Write Count: {rb.write_count}")
...
...         # Check the write count every 0.5 seconds
...         time.sleep(0.5)
...
...         # Stop when the desired number of frames is reached
...         if rb.write_count >= 100_000:
...             break
...
...     # Shut down the collector
...     collector.async_shutdown()
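The producer/consumer pattern behind start() can be sketched with the standard library alone: a background thread writes batches into a bounded buffer while the main thread polls write_count. TinyBuffer and BackgroundCollector are hypothetical stand-ins for ReplayBuffer and the collector, and the "frames" are plain integers.

```python
from __future__ import annotations

import threading
import time
from collections import deque


class TinyBuffer:
    """Hypothetical thread-safe FIFO buffer with a write counter."""

    def __init__(self, capacity: int) -> None:
        self._data: deque[int] = deque(maxlen=capacity)  # oldest items are evicted
        self._lock = threading.Lock()
        self.write_count = 0

    def extend(self, items: list[int]) -> None:
        with self._lock:
            self._data.extend(items)
            self.write_count += len(items)

    def __len__(self) -> int:
        with self._lock:
            return len(self._data)


class BackgroundCollector:
    """Hypothetical collector that fills a buffer from a background thread."""

    def __init__(self, buffer: TinyBuffer | None, frames_per_batch: int) -> None:
        self.buffer = buffer
        self.frames_per_batch = frames_per_batch
        self._stop = threading.Event()
        self._thread: threading.Thread | None = None

    def start(self) -> None:
        # Mirrors the documented contract: a buffer is required.
        if self.buffer is None:
            raise RuntimeError("start() requires a replay buffer")
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def _run(self) -> None:
        frame = 0
        while not self._stop.is_set():
            self.buffer.extend(list(range(frame, frame + self.frames_per_batch)))
            frame += self.frames_per_batch
            time.sleep(0.001)

    def shutdown(self) -> None:
        self._stop.set()
        if self._thread is not None:
            self._thread.join()


rb = TinyBuffer(capacity=1000)
collector = BackgroundCollector(rb, frames_per_batch=256)
collector.start()
while rb.write_count < 10_000:  # a training loop would sample from rb here
    time.sleep(0.01)
collector.shutdown()
print(len(rb))  # 1000: the buffer keeps only the newest `capacity` frames
```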
- state_dict() → OrderedDict
Returns the local state_dict of the data collector (environment and policy).
- Returns:
an ordered dictionary with fields "policy_state_dict" and "env_state_dict".
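The two-field layout shared by state_dict() and load_state_dict() supports a simple save/restore round trip. Stateful and SketchCollector are hypothetical stand-ins for an environment/policy and a collector; the sketch only mirrors the documented "env_state_dict" and "policy_state_dict" fields.

```python
from collections import OrderedDict
from typing import Any


class Stateful:
    """Hypothetical component standing in for an environment or a policy."""

    def __init__(self, value: int = 0) -> None:
        self.value = value

    def state_dict(self) -> OrderedDict:
        return OrderedDict(value=self.value)

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.value = state_dict["value"]


class SketchCollector:
    """Hypothetical collector showing the two-field state_dict layout."""

    def __init__(self, env: Stateful, policy: Stateful) -> None:
        self.env = env
        self.policy = policy

    def state_dict(self) -> OrderedDict:
        return OrderedDict(
            env_state_dict=self.env.state_dict(),
            policy_state_dict=self.policy.state_dict(),
        )

    def load_state_dict(self, state_dict: OrderedDict) -> None:
        self.env.load_state_dict(state_dict["env_state_dict"])
        self.policy.load_state_dict(state_dict["policy_state_dict"])


source = SketchCollector(Stateful(3), Stateful(7))
target = SketchCollector(Stateful(), Stateful())
target.load_state_dict(source.state_dict())
print(target.env.value, target.policy.value)  # 3 7
```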
- update_policy_weights_(policy_or_weights: tensordict.base.TensorDictBase | tensordict.nn.common.TensorDictModuleBase | dict | None = None, *, worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, **kwargs) → None
Update policy weights for the data collector.
This method synchronizes the policy weights used by the collector with the latest trained weights. It supports both local and remote weight updates, depending on the collector configuration.
The method accepts weights in multiple forms for convenience:
Examples
>>> # Pass policy module as positional argument
>>> collector.update_policy_weights_(policy_module)
>>>
>>> # Pass TensorDict weights as positional argument
>>> collector.update_policy_weights_(weights_tensordict)
>>>
>>> # Use keyword arguments for clarity
>>> collector.update_policy_weights_(weights=weights_td, model_id="actor")
>>> collector.update_policy_weights_(policy=actor_module, model_id="actor")
>>>
>>> # Update multiple models atomically
>>> collector.update_policy_weights_(weights_dict={
...     "actor": actor_weights,
...     "critic": critic_weights,
... })
- Parameters:
policy_or_weights –
The weights to update with. Can be:
nn.Module: A policy module whose weights will be extracted
TensorDictModuleBase: A TensorDict module whose weights will be extracted
TensorDictBase: A TensorDict containing weights
dict: A regular dict containing weights
None: Will try to get weights from server using _get_server_weights()
- Keyword Arguments:
weights – Alternative to positional argument. A TensorDict or dict containing weights to update. Cannot be used together with policy_or_weights or policy.
policy – Alternative to positional argument. An nn.Module or TensorDictModuleBase whose weights will be extracted. Cannot be used together with policy_or_weights or weights.
worker_ids – Identifiers for the workers to update. Relevant when the collector has multiple workers. Can be an int, a list of ints, a device, or a list of devices.
model_id – The model identifier to update (default: "policy"). Cannot be used together with weights_dict.
weights_dict – Dictionary mapping model_id to weights for updating multiple models atomically. Keys should match model_ids registered in weight_sync_schemes. Cannot be used together with model_id, policy_or_weights, weights, or policy.
- Raises:
TypeError – If worker_ids is provided but no weight_updater is configured.
ValueError – If conflicting parameters are provided.
Note
Users should extend the WeightUpdaterBase classes to customize the weight update logic for specific use cases.
See also
LocalWeightsUpdaterBase and RemoteWeightsUpdaterBase.
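What "atomically" buys for weights_dict can be illustrated by validating every model_id before writing anything, so a bad key leaves all models untouched. apply_weights and the dict-based registry are hypothetical; unlike the real method, this sketch has no server fallback and raises when no weights are given.

```python
from typing import Optional

Weights = dict[str, float]


def apply_weights(
    registry: dict[str, Weights],
    weights: Optional[Weights] = None,
    *,
    model_id: Optional[str] = None,
    weights_dict: Optional[dict[str, Weights]] = None,
) -> None:
    """Update one model (model_id, default "policy") or several (weights_dict).

    Hypothetical helper: every model_id is validated before anything is
    written, so a bad key leaves the registry untouched.
    """
    if weights_dict is not None:
        if model_id is not None or weights is not None:
            raise ValueError("weights_dict cannot be combined with model_id or weights")
        updates = weights_dict
    else:
        if weights is None:
            raise ValueError("this sketch needs explicit weights")
        updates = {model_id or "policy": weights}
    unknown = sorted(set(updates) - set(registry))
    if unknown:
        raise ValueError(f"unknown model_id(s): {unknown}")
    for key, value in updates.items():  # only reached once all keys are valid
        registry[key] = dict(value)


registry = {"actor": {"w": 0.0}, "critic": {"w": 0.0}}
apply_weights(registry, weights_dict={"actor": {"w": 1.0}, "critic": {"w": 2.0}})
try:
    apply_weights(registry, weights_dict={"actor": {"w": 9.0}, "ghost": {"w": 9.0}})
except ValueError as err:
    print(err)  # unknown model_id(s): ['ghost']
print(registry)  # {'actor': {'w': 1.0}, 'critic': {'w': 2.0}}
```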
- property worker_idx: int | None
Get the worker index for this collector.
- Returns:
The worker index (0-indexed).
- Raises:
RuntimeError – If worker_idx has not been set.