AsyncSGLang

class torchrl.modules.llm.AsyncSGLang(server_url: str | None = None, model_path: str | None = None, tp_size: int = 1, dp_size: int = 1, timeout: float = 300.0, **server_kwargs: Any)[source]

Server-based SGLang inference service for TorchRL.

AsyncSGLang provides a unified interface for text generation using SGLang servers, supporting both managed (subprocess) and external server modes. It integrates seamlessly with TorchRL’s RL training workflows through NCCL-based weight synchronization.

Key Features:
  • HTTP-based generation via SGLang’s native /generate API

  • Cache-aware load balancing through SGLang Router

  • NCCL-based weight synchronization for RL training

  • Support for both managed and external server modes

  • Compatible interface with vLLM backends for easy migration

Parameters:
  • server_url – URL of an external SGLang server (e.g., “http://localhost:30000”). If None, a managed server will be launched.

  • model_path – Path or name of the model to load (for managed mode).

  • tp_size – Tensor parallel size (default: 1).

  • dp_size – Data parallel size (default: 1).

  • timeout – Request timeout in seconds (default: 300).

  • **server_kwargs – Additional arguments passed to SGLang server launch.

Examples

>>> # Connect to an existing SGLang server
>>> service = AsyncSGLang.connect("http://localhost:30000")
>>> result = service.generate("Hello, world!")
>>>
>>> # Launch a managed SGLang server
>>> service = AsyncSGLang.from_pretrained("Qwen/Qwen2.5-3B")
>>> result = service.generate("Hello, world!")
>>>
>>> # With custom parameters
>>> service = AsyncSGLang.from_pretrained(
...     "Qwen/Qwen2.5-7B",
...     tp_size=2,
...     max_model_len=4096
... )

Note

For RL training with weight updates, use the weight synchronization methods after initializing the NCCL communication group.

classmethod connect(server_url: str) AsyncSGLang[source]

Connect to an existing SGLang server.

Parameters:

server_url – URL of the SGLang server (e.g., “http://localhost:30000”)

Returns:

Connected service instance

Return type:

AsyncSGLang

Raises:

ConnectionError – If the server is not reachable

flush_cache() bool[source]

Flush the radix cache on the server.

This is automatically triggered when weights are updated.

Returns:

True if cache was flushed successfully

Return type:

bool

classmethod from_pretrained(model_name: str, tp_size: int = 1, dp_size: int = 1, **kwargs: Any) AsyncSGLang[source]

Create an AsyncSGLang instance by launching a managed server.

Parameters:
  • model_name – Model name or path to load

  • tp_size – Tensor parallel size

  • dp_size – Data parallel size

  • **kwargs – Additional server arguments

Returns:

Service with managed server

Return type:

AsyncSGLang

Example

>>> service = AsyncSGLang.from_pretrained(
...     "Qwen/Qwen2.5-3B",
...     tp_size=2,
...     max_model_len=4096
... )

generate(prompts: str | list[str] | None = None, sampling_params: dict[str, Any] | None = None, *, input_ids: list[int] | list[list[int]] | None = None, return_logprobs: bool = False, return_text: bool = True, timeout: float | None = None, **kwargs: Any) dict[str, Any] | list[dict[str, Any]][source]

Generate text completions from text prompts or token IDs.

You can provide either prompts (text) OR input_ids (tokens), but not both.

Parameters:
  • prompts – Input text prompt(s) for generation. Mutually exclusive with input_ids.

  • sampling_params – Sampling parameters (temperature, top_p, max_tokens, etc.)

  • input_ids – Input token ID(s) for generation. Can be a single list of ints or a list of lists for batch generation. Mutually exclusive with prompts.

  • return_logprobs – Whether to return log probabilities

  • return_text – Whether to return generated text

  • timeout – Request timeout in seconds

  • **kwargs – Additional sampling parameters (temperature, max_new_tokens, etc.) These are merged into sampling_params for convenience.

Returns:

Generation results with ‘text’, ‘output_ids’, ‘meta_info’

Return type:

dict or list[dict]

Example

>>> # Generate from text
>>> result = service.generate(
...     "What is the capital of France?",
...     {"temperature": 0.7, "max_tokens": 100}
... )
>>> print(result["text"])
>>> # Generate from token IDs
>>> result = service.generate(
...     input_ids=[1, 2, 3, 4],
...     sampling_params={"max_tokens": 50}
... )
>>> print(result["output_ids"])
>>> # Using kwargs for sampling params
>>> result = service.generate("Hello", max_new_tokens=50, temperature=0.7)
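The kwargs-merge behavior described for generate() can be mimicked with a small helper. This is a sketch of the documented semantics, not the library's internal code; giving the keyword arguments precedence on key conflicts is an assumption of this sketch.

```python
def merge_sampling_params(sampling_params=None, **kwargs):
    """Merge convenience kwargs into a sampling_params dict.

    Mirrors the documented behavior of generate(); kwargs winning
    on key conflicts is an assumption here.
    """
    merged = dict(sampling_params or {})
    merged.update(kwargs)
    return merged


params = merge_sampling_params({"temperature": 0.7}, max_new_tokens=50)
print(params)  # {'temperature': 0.7, 'max_new_tokens': 50}
```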

generate_batch(prompts: list[str], sampling_params: dict[str, Any] | None = None, **kwargs: Any) list[dict[str, Any]][source]

Generate text completions for a batch of prompts.

This is an alias for generate() with a list of prompts.

Parameters:
  • prompts – List of input prompts

  • sampling_params – Sampling parameters

  • **kwargs – Additional arguments passed to generate()

Returns:

List of generation results

Return type:

list[dict]
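For very large prompt lists, one option is to feed generate_batch() in fixed-size pieces. The chunking helper below is hypothetical, not part of the AsyncSGLang API:

```python
def chunked(prompts, batch_size):
    """Split a prompt list into fixed-size batches.

    A hypothetical helper for feeding large prompt lists to
    generate_batch() in pieces; not part of AsyncSGLang.
    """
    for i in range(0, len(prompts), batch_size):
        yield prompts[i:i + batch_size]


batches = list(chunked(["a", "b", "c", "d", "e"], 2))
print(batches)  # [['a', 'b'], ['c', 'd'], ['e']]
```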

get_dp_size() int[source]

Get the data parallel size.

get_master_address() str[source]

Get the master address for weight synchronization.

get_master_port() int[source]

Get the master port for weight synchronization.

get_model_metadata() dict[str, tuple[dtype, Size]][source]

Get model parameter metadata.

Note: This requires fetching metadata from the server. For now, this returns an empty dict; the metadata is expected to be provided externally (e.g., built on the trainer side).
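Since get_model_metadata() currently returns an empty dict, the metadata map can be built externally on the trainer side. The sketch below works with anything exposing `.dtype` and `.shape` (a namedtuple stands in for a tensor here; with a real model you would iterate `model.named_parameters()`):

```python
from collections import namedtuple

# Stand-in for a tensor: anything exposing .dtype and .shape works.
FakeTensor = namedtuple("FakeTensor", ["dtype", "shape"])


def build_metadata(named_params):
    """Map parameter name -> (dtype, shape), the layout get_model_metadata()
    is documented to return.

    With a real model, `named_params` would be model.named_parameters().
    """
    return {name: (t.dtype, tuple(t.shape)) for name, t in named_params}


meta = build_metadata(
    [("lm_head.weight", FakeTensor("bfloat16", (32000, 4096)))]
)
print(meta)  # {'lm_head.weight': ('bfloat16', (32000, 4096))}
```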

get_tp_size() int[source]

Get the tensor parallel size.

init_weight_update_group(master_address: str | None = None, master_port: int | None = None) None[source]

Initialize the NCCL weight update group via SGLang’s HTTP API.

This calls the SGLang server’s /init_weights_update_group endpoint to set up NCCL communication for weight synchronization.

Parameters:
  • master_address – Master address for NCCL (default: “localhost”)

  • master_port – Master port for NCCL (auto-assigned if None)

property server_url: str

Get the server URL.

shutdown() None[source]

Shutdown the managed SGLang server if running.

update_weights(weights: Iterator[tuple[str, Tensor]]) None[source]

Update model weights via NCCL broadcast.

This method coordinates with the SGLang server to broadcast weights from the trainer (rank 0) to all workers.

Parameters:

weights – Iterator yielding (parameter_name, tensor) tuples

update_weights_from_distributed(name: str, dtype: dtype, shape: tuple[int, ...]) None[source]

Signal the server to receive a weight update via NCCL broadcast.

This calls SGLang’s /update_weights_from_distributed endpoint to coordinate weight reception.

Parameters:
  • name – Name of the parameter to update

  • dtype – Data type of the tensor

  • shape – Shape of the tensor
