RemotevLLMWrapper

class torchrl.modules.llm.RemotevLLMWrapper(model, max_concurrency: int = 16, validate_model: bool = True, **kwargs)[source]

A remote Ray actor wrapper for vLLMWrapper that provides a simplified interface.

This class wraps a vLLMWrapper instance as a Ray actor, allowing remote execution while providing a clean interface that doesn’t require explicit remote() and get() calls.
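
The idea of hiding the submit-and-wait steps behind ordinary method calls can be sketched without Ray. The snippet below is a minimal illustration, not the TorchRL implementation: a thread pool stands in for the Ray actor, and `EchoModel` and `BlockingProxy` are hypothetical names introduced only for this example.

```python
from concurrent.futures import ThreadPoolExecutor


class EchoModel:
    """Hypothetical stand-in for the wrapped model; not part of TorchRL."""

    def generate(self, prompt: str) -> str:
        return prompt.upper()


class BlockingProxy:
    """Forwards method calls to a worker and waits for the result.

    A thread pool plays the role of the Ray actor here: `submit` is the
    analogue of `.remote()`, and `Future.result` is the analogue of `ray.get`.
    """

    def __init__(self, target, max_workers: int = 1):
        self._target = target
        self._pool = ThreadPoolExecutor(max_workers=max_workers)

    def __getattr__(self, name):
        # Only reached for attributes that are not defined on the proxy itself.
        method = getattr(self._target, name)

        def call(*args, **kwargs):
            future = self._pool.submit(method, *args, **kwargs)
            return future.result()  # block until the worker has finished

        return call

    def shutdown(self):
        self._pool.shutdown()


proxy = BlockingProxy(EchoModel())
reply = proxy.generate("hello")  # reads like a local call
proxy.shutdown()
```

The caller never touches a future or a handle, which is the kind of interface the description above refers to.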

Parameters:
  • model (vllm.LLM | str) – The vLLM model to wrap.
      - If a string, it will be passed to vllm.LLM and downloaded on the remote worker.
      - If a vllm.LLM object, it must be a remote model with a Ray handle (not a local model). Local vLLM models are not serializable and will raise an error.

  • max_concurrency (int, optional) – Maximum number of concurrent calls to the remote actor. Defaults to 16.

  • validate_model (bool, optional) – Whether to validate the model. Defaults to True.

  • **kwargs – All other arguments are passed directly to vLLMWrapper.
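
The effect of a concurrency cap such as max_concurrency can be pictured as a bounded worker pool: at most N calls run at once and the rest wait their turn. The following is a rough sketch under that assumption, using a thread pool rather than Ray; `fake_call`, `stats`, and the value 4 are hypothetical and chosen only for illustration.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor

MAX_CONCURRENCY = 4  # illustrative value; the wrapper's documented default is 16

stats = {"active": 0, "peak": 0}
lock = threading.Lock()


def fake_call(i: int) -> int:
    """Hypothetical stand-in for one call to the remote actor."""
    with lock:
        stats["active"] += 1
        stats["peak"] = max(stats["peak"], stats["active"])
    time.sleep(0.01)  # simulate some work
    with lock:
        stats["active"] -= 1
    return i * 2


# The pool never runs more than MAX_CONCURRENCY calls at the same time;
# the remaining submissions queue until a worker is free.
with ThreadPoolExecutor(max_workers=MAX_CONCURRENCY) as pool:
    results = list(pool.map(fake_call, range(20)))
```

All 20 calls complete and their results come back in submission order, while `stats["peak"]` never exceeds the cap.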

Example

>>> import ray
>>> from torchrl.modules.llm.policies import RemotevLLMWrapper
>>>
>>> # Initialize Ray if not already done
>>> if not ray.is_initialized():
...     ray.init()
>>>
>>> # Create remote wrapper
>>> remote_wrapper = RemotevLLMWrapper(
...     model="gpt2",
...     input_mode="history",
...     generate=True,
...     generate_kwargs={"max_new_tokens": 50}
... )
>>>
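>>> # Use like a regular wrapper (no remote/get calls needed).
>>> # `tensordict_input` is assumed to be a TensorDict already prepared for
>>> # input_mode="history"; it is not defined in this snippet.
>>> result = remote_wrapper(tensordict_input)
>>> print(result["text"].response)

property batching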

Whether batching is enabled.

cleanup_batching()[source]

Clean up batching resources.

property collector

The collector associated with the module.

property device

The device used for computation.

property dist_params_keys

The keys for distribution parameters.

property dist_sample_keys

The keys for distribution samples.

forward(tensordict, **kwargs)[source]

Forward pass that automatically handles remote execution.

property generate

Whether text generation is enabled.

get_batching_state()[source]

Get the current batching state.

get_dist(tensordict, **kwargs)[source]

Get distribution from logits/log-probs with optional masking.

get_dist_with_prompt_mask(tensordict, **kwargs)[source]

Get distribution masked to only include response tokens (exclude prompt).
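
The masking idea can be illustrated on plain per-token log-probabilities. This is only a sketch of the concept: the method itself returns a distribution, whereas the hypothetical helper `response_log_prob` below simply sums log-probabilities over response positions, assuming the prompt occupies the first `prompt_len` tokens.

```python
def response_log_prob(token_log_probs, prompt_len):
    """Sum log-probs over response tokens only.

    Positions before `prompt_len` are treated as prompt and masked out.
    Returns the masked sum together with the boolean mask that was applied.
    """
    mask = [i >= prompt_len for i in range(len(token_log_probs))]
    total = sum(lp for lp, keep in zip(token_log_probs, mask) if keep)
    return total, mask


# Two prompt tokens followed by two response tokens.
log_probs = [-0.5, -1.0, -0.25, -0.75]
total, mask = response_log_prob(log_probs, prompt_len=2)
```

Only the last two entries contribute to `total`, so prompt tokens have no influence on the score.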

get_new_version(**kwargs)[source]

Get a new version of the wrapper with altered parameters.

property in_keys

The input keys.

property inplace

Whether in-place operations are used.

property layout

The layout used for output tensors.

log_prob(data, **kwargs)[source]

Compute log probabilities.

property log_prob_keys

The keys for log probabilities.

property log_probs_key

The key for log probabilities output.

property masks_key

The key for masks output.

property num_samples

The number of samples to generate.

property out_keys

The output keys.

property pad_output

Whether output sequences are padded.
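
The difference between padded and unpadded output can be pictured with plain lists. The sketch below assumes right-padding with a hypothetical pad id of 0; the padding side and value the wrapper actually uses are not stated on this page, and `pad_sequences` is a name introduced only for this example.

```python
def pad_sequences(seqs, pad_id=0):
    """Right-pad ragged token lists to a common length and build a mask."""
    width = max((len(s) for s in seqs), default=0)
    padded = [list(s) + [pad_id] * (width - len(s)) for s in seqs]
    # True marks a real token, False marks padding.
    mask = [[True] * len(s) + [False] * (width - len(s)) for s in seqs]
    return padded, mask


ragged = [[5, 6, 7], [8]]  # unpadded: one list per sequence, lengths differ
padded, mask = pad_sequences(ragged)  # padded: rectangular, with a mask
```

Padded output can be stacked into a single rectangular tensor, at the cost of carrying a mask; unpadded output keeps each sequence at its own length.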

property text_key

The key for text output.

property tokens_key

The key for tokens output.
