
vLLMUpdater

class torchrl.collectors.llm.vLLMUpdater(master_address: str | None = None, master_port: int | None = None, model_metadata: dict[str, tuple[torch.dtype, torch.Size]] | None = None, vllm_tp_size: int | None = None)[source]

A class that sends weights to vLLM workers.

This class handles synchronizing weights between a training policy and vLLM inference workers. It supports both local vLLM instances and remote Ray actors.

Parameters:
  • master_address (str, optional) – The master address for distributed training. Defaults to localhost.

  • master_port (int, optional) – The master port for distributed training. If None, will auto-assign.

  • model_metadata (dict[str, tuple[torch.dtype, torch.Size]], optional) – Model metadata mapping parameter names to their dtype and shape. If not provided, will be extracted from policy.

  • vllm_tp_size (int, optional) – vLLM tensor parallel size. Defaults to 1.

init()[source]

Initialize the updater with model metadata and initialize the group.

_sync_weights_with_worker()[source]

Synchronize weights with a vLLM worker.

_get_server_weights()[source]

Not used - weights must be passed directly.

_maybe_map_weights()[source]

No mapping needed.

all_worker_ids()[source]

Returns [0] since we only have one worker.

Note

This class assumes the policy is a transformers model that can be loaded by vLLM. The policy must have a state_dict() method that returns the model weights.
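The metadata the updater works with can be derived from any module that exposes `state_dict()`. A minimal sketch of that mapping using a plain `torch.nn` module as a stand-in for the policy (the module here is illustrative, not TorchRL code):

```python
import torch
import torch.nn as nn

# Stand-in for a policy: any module exposing state_dict() works.
policy = nn.Linear(4, 2)

# Metadata in the shape vLLMUpdater expects:
# parameter name -> (dtype, shape).
model_metadata = {
    name: (param.dtype, param.shape)
    for name, param in policy.state_dict().items()
}
# e.g. {"weight": (torch.float32, torch.Size([2, 4])),
#       "bias":   (torch.float32, torch.Size([2]))}
```

Passing such a dict to the constructor (or letting the updater extract it from the policy) is what allows the receiving vLLM workers to pre-allocate buffers of the right dtype and shape before any weights are transferred.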

all_worker_ids() → list[int][source]

Returns [0] since we only have one worker.

property collector: torchrl.collectors.DataCollectorBase

The collector or container of the receiver.

Returns None if the container is out-of-scope or not set.

classmethod from_policy(policy: TensorDictModuleBase) → WeightUpdaterBase | None

Optional classmethod to create a weight updater instance from a policy.

This method can be implemented by subclasses to provide custom initialization logic based on the policy. If implemented, this method will be called before falling back to the default constructor when initializing a weight updater in a collector.

Parameters:

policy (TensorDictModuleBase) – The policy to create the weight updater from.

Returns:

An instance of the weight updater, or None if the policy cannot be used to create an instance.

Return type:

WeightUpdaterBase | None

classmethod get_model_metadata(model: TensorDictModuleBase) → dict[str, tuple[torch.dtype, torch.Size]][source]

Get the model metadata from a model.

Parameters:

model (TensorDictModuleBase) – The model to get the metadata from. Must be a TransformersWrapper or equivalent.

Returns:

The model metadata.

Return type:

dict[str, tuple[torch.dtype, torch.Size]]

increment_version()

Increment the policy version.

init(model_metadata: dict[str, tuple[torch.dtype, torch.Size]]) → None[source]

Initialize the updater with model metadata and initialize the group.

Parameters:

model_metadata (dict[str, tuple[torch.dtype, torch.Size]]) – The model metadata mapping parameter names to their dtype and shape.

property post_hooks: list[Callable[[], NoneType]]

The list of post-hooks registered to the weight updater.

push_weights(policy_or_weights: TensorDictModuleBase | TensorDictBase | dict | None = None, worker_ids: torch.device | int | list[int] | list[torch.device] | None = None)

Updates the weights of the policy on the specified remote workers, or on all of them if none are specified.

Parameters:
  • policy_or_weights – The source to get weights from. Can be:
      – TensorDictModuleBase: a policy module whose weights will be extracted
      – TensorDictBase: a TensorDict containing weights
      – dict: a regular dict containing weights
      – None: will try to get the weights from the server using _get_server_weights()

  • worker_ids – An optional list of workers to update.

Returns: nothing.
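All of the accepted source types reduce to a plain name-to-tensor mapping before transfer. A hedged sketch of that normalization for the module and dict cases (the helper name `as_weight_dict` is illustrative, not part of the API; the TensorDictBase branch is omitted here to keep the example dependency-free):

```python
import torch
import torch.nn as nn

def as_weight_dict(policy_or_weights):
    """Illustrative helper: normalize a weight source to a dict of tensors."""
    if isinstance(policy_or_weights, nn.Module):
        # A policy module: extract its weights.
        return dict(policy_or_weights.state_dict())
    if isinstance(policy_or_weights, dict):
        # Already a mapping of name -> tensor.
        return dict(policy_or_weights)
    # The real API also accepts a TensorDictBase, and None (server weights).
    raise TypeError(f"Unsupported weight source: {type(policy_or_weights)}")

weights = as_weight_dict(nn.Linear(3, 3))
```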

register_collector(collector: DataCollectorBase)[source]

Register a collector in the updater.

Once registered, the updater will not accept another collector.

Parameters:

collector (DataCollectorBase) – The collector to register.

register_post_hook(hook: Callable[[], None])

Registers a post-hook to be called after weights are updated.

Parameters:

hook (Callable[[], None]) – The post-hook to register.
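The post-hook contract is simply a list of zero-argument callables invoked after each weight update. A minimal stand-in illustrating the pattern (`UpdaterStub` is hypothetical, not TorchRL code; the real updater performs the actual weight transfer before firing the hooks):

```python
from typing import Callable

class UpdaterStub:
    """Hypothetical stand-in for the post-hook bookkeeping."""

    def __init__(self) -> None:
        self._post_hooks: list[Callable[[], None]] = []

    @property
    def post_hooks(self) -> list[Callable[[], None]]:
        return self._post_hooks

    def register_post_hook(self, hook: Callable[[], None]) -> None:
        self._post_hooks.append(hook)

    def push_weights(self) -> None:
        # ... weight transfer would happen here ...
        for hook in self._post_hooks:
            hook()

calls = []
updater = UpdaterStub()
updater.register_post_hook(lambda: calls.append("synced"))
updater.push_weights()
# calls == ["synced"]
```

A typical use of such a hook is to call increment_version() or emit a log line once the new weights are live on the workers.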
