vLLMUpdater¶
- class torchrl.collectors.llm.vLLMUpdater(master_address: str | None = None, master_port: int | None = None, model_metadata: dict[str, tuple[torch.dtype, torch.Size]] | None = None, vllm_tp_size: int | None = None)[source]¶
A class that sends weights to vLLM workers.
This class handles synchronizing weights between a training policy and vLLM inference workers. It supports both local vLLM instances and remote Ray actors.
- Parameters:
master_address (str, optional) – The master address for distributed training. Defaults to localhost.
master_port (int, optional) – The master port for distributed training. If None, will auto-assign.
model_metadata (dict[str, tuple[torch.dtype, torch.Size]], optional) – Model metadata mapping parameter names to their dtype and shape. If not provided, will be extracted from policy.
vllm_tp_size (int, optional) – vLLM tensor parallel size. Defaults to 1.
Note
This class assumes the policy is a transformers model that can be loaded by vLLM. The policy must have a state_dict() method that returns the model weights.
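The `model_metadata` argument maps parameter names to `(dtype, shape)` pairs. A minimal sketch of building such a dict, using a plain `torch.nn` module as a stand-in for the transformers policy (the model here is illustrative only):

```python
import torch
from torch import nn

# Illustrative policy; vLLMUpdater expects a transformers model, but any
# module with a state_dict() exposes the same (name -> tensor) structure.
policy = nn.Linear(8, 4)

# model_metadata maps parameter names to (dtype, shape) pairs.
model_metadata = {
    name: (param.dtype, param.shape)
    for name, param in policy.state_dict().items()
}

# The dict can then be passed to the constructor, e.g.:
# updater = vLLMUpdater(model_metadata=model_metadata, vllm_tp_size=1)
print(model_metadata["weight"])  # (torch.float32, torch.Size([4, 8]))
```

In practice, `vLLMUpdater.get_model_metadata(policy)` builds this dict for you from a wrapped transformers policy.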
- property collector: torchrl.collectors.DataCollectorBase¶
The collector or container of the receiver.
Returns None if the container is out-of-scope or not set.
- classmethod from_policy(policy: TensorDictModuleBase) WeightUpdaterBase | None ¶
Optional classmethod to create a weight updater instance from a policy.
This method can be implemented by subclasses to provide custom initialization logic based on the policy. If implemented, this method will be called before falling back to the default constructor when initializing a weight updater in a collector.
- Parameters:
policy (TensorDictModuleBase) – The policy to create the weight updater from.
- Returns:
An instance of the weight updater, or None if the policy cannot be used to create an instance.
- Return type:
WeightUpdaterBase | None
- classmethod get_model_metadata(model: TensorDictModuleBase) dict[str, tuple[torch.dtype, torch.Size]] [source]¶
Get the model metadata from a model.
- Parameters:
model (TensorDictModuleBase) – The model to get the metadata from. Must be a TransformersWrapper or equivalent.
- Returns:
The model metadata.
- Return type:
dict[str, tuple[torch.dtype, torch.Size]]
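The returned metadata is enough to preallocate receive buffers of the right dtype and shape on the inference side before any weights are broadcast. A sketch of that use (parameter names are illustrative, and this is not torchrl's actual internals):

```python
import torch

# Example metadata of the kind get_model_metadata returns (names made up).
metadata = {
    "model.embed_tokens.weight": (torch.bfloat16, torch.Size([32, 16])),
    "lm_head.weight": (torch.bfloat16, torch.Size([32, 16])),
}

# A receiver can allocate empty tensors matching each entry, ready to be
# filled by a subsequent weight transfer.
buffers = {
    name: torch.empty(shape, dtype=dtype)
    for name, (dtype, shape) in metadata.items()
}
```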
- increment_version()¶
Increment the policy version.
- init(model_metadata: dict[str, tuple[torch.dtype, torch.Size]]) None [source]¶
Initialize the updater with model metadata and initialize the group.
- Parameters:
model_metadata (dict[str, tuple[torch.dtype, torch.Size]]) – The model metadata mapping parameter names to their dtype and shape.
- property post_hooks: list[Callable[[], NoneType]]¶
The list of post-hooks registered to the weight updater.
- push_weights(policy_or_weights: TensorDictModuleBase | TensorDictBase | dict | None = None, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None)¶
Updates the weights of the policy on the specified remote workers, or on all of them if no worker ids are given.
- Parameters:
policy_or_weights – The source to get weights from. Can be:
- TensorDictModuleBase: a policy module whose weights will be extracted
- TensorDictBase: a TensorDict containing weights
- dict: a regular dict containing weights
- None: the updater will try to get weights from the server using _get_server_weights()
worker_ids – An optional list of workers to update.
- Returns:
Nothing.
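The accepted source types can be illustrated with a hypothetical dispatch helper (not torchrl's implementation) that normalizes each form to a flat name-to-tensor dict:

```python
import torch
from torch import nn

def _as_weight_dict(policy_or_weights):
    # Hypothetical helper mirroring the input forms push_weights accepts.
    if policy_or_weights is None:
        # push_weights would fetch weights from the server here.
        raise NotImplementedError("fetch from server via _get_server_weights()")
    if isinstance(policy_or_weights, nn.Module):
        # A policy module: extract its weights.
        return dict(policy_or_weights.state_dict())
    if isinstance(policy_or_weights, dict):
        # Already a plain dict (a TensorDict is handled analogously).
        return dict(policy_or_weights)
    raise TypeError(f"unsupported source: {type(policy_or_weights)}")

weights = _as_weight_dict(nn.Linear(2, 3))
```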
- register_collector(collector: DataCollectorBase)[source]¶
Register a collector in the updater.
Once registered, the updater will not accept another collector.
- Parameters:
collector (DataCollectorBase) – The collector to register.
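The one-collector constraint can be sketched as follows (a hypothetical guard, not torchrl's code):

```python
class _UpdaterSketch:
    """Minimal sketch of the register_collector contract: once a
    collector is attached, a second registration is rejected."""

    def __init__(self):
        self._collector = None

    def register_collector(self, collector):
        if self._collector is not None:
            raise RuntimeError("a collector is already registered")
        self._collector = collector

u = _UpdaterSketch()
u.register_collector("collector-0")  # first registration succeeds
```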
- register_post_hook(hook: Callable[[], None])¶
Registers a post-hook to be called after weights are updated.
- Parameters:
hook (Callable[[], None]) – The post-hook to register.
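Post-hooks take no arguments and return None; after each weight update the registered hooks run in registration order. A pure-Python sketch of the mechanism:

```python
calls = []

# A post-hook takes no arguments and returns None; it might, for
# instance, bump a policy-version counter after each push.
def bump_version():
    calls.append("version")

# What the post_hooks property conceptually exposes: a list of callables.
hooks = [bump_version]

# After a weight push, each registered hook is invoked in order.
for hook in hooks:
    hook()
```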