AsyncSGLang¶
- class torchrl.modules.llm.AsyncSGLang(server_url: str | None = None, model_path: str | None = None, tp_size: int = 1, dp_size: int = 1, timeout: float = 300.0, **server_kwargs: Any)[source]¶
Server-based SGLang inference service for TorchRL.
AsyncSGLang provides a unified interface for text generation using SGLang servers, supporting both managed (subprocess) and external server modes. It integrates seamlessly with TorchRL’s RL training workflows through NCCL-based weight synchronization.
- Key Features:
HTTP-based generation via SGLang’s native /generate API
Cache-aware load balancing through SGLang Router
NCCL-based weight synchronization for RL training
Support for both managed and external server modes
Compatible interface with vLLM backends for easy migration
- Parameters:
server_url – URL of an external SGLang server (e.g., “http://localhost:30000”). If None, a managed server will be launched.
model_path – Path or name of the model to load (for managed mode).
tp_size – Tensor parallel size (default: 1).
dp_size – Data parallel size (default: 1).
timeout – Request timeout in seconds (default: 300).
**server_kwargs – Additional arguments passed to SGLang server launch.
Examples
>>> # Connect to an existing SGLang server
>>> service = AsyncSGLang.connect("http://localhost:30000")
>>> result = service.generate("Hello, world!")
>>>
>>> # Launch a managed SGLang server
>>> service = AsyncSGLang.from_pretrained("Qwen/Qwen2.5-3B")
>>> result = service.generate("Hello, world!")
>>>
>>> # With custom parameters
>>> service = AsyncSGLang.from_pretrained(
...     "Qwen/Qwen2.5-7B",
...     tp_size=2,
...     max_model_len=4096
... )
Note
For RL training with weight updates, use the weight synchronization methods after initializing the NCCL communication group.
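The note above implies an ordering: set up the NCCL group once, then push updated weights after each optimization step (the cache flush happens automatically on update). A minimal sketch of that loop, assuming a connected service and a `policy` module whose parameters mirror the served model; `policy`, `optimizer`, and the rollout/loss steps are placeholders for your own objects:

```python
# Hedged sketch of the RL weight-sync loop; not the definitive TorchRL recipe.
def sync_weights_after_step(service, policy):
    """Push the trainer's current parameters to the SGLang server."""
    # Iterator of (parameter_name, tensor) tuples, as update_weights() expects.
    weights = ((name, param.data) for name, param in policy.named_parameters())
    service.update_weights(weights)

def training_loop(service, policy, optimizer, steps):
    service.init_weight_update_group()  # set up NCCL once, before any update
    for _ in range(steps):
        # ... collect rollouts via service.generate(), compute loss, then:
        optimizer.step()
        sync_weights_after_step(service, policy)
```

The key invariant is that `init_weight_update_group()` runs exactly once before the first `update_weights()` call.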
- classmethod connect(server_url: str) AsyncSGLang[source]¶
Connect to an existing SGLang server.
- Parameters:
server_url – URL of the SGLang server (e.g., “http://localhost:30000”)
- Returns:
Connected service instance
- Return type:
AsyncSGLang
- Raises:
ConnectionError – If the server is not reachable
- flush_cache() bool[source]¶
Flush the radix cache on the server.
This is automatically triggered when weights are updated.
- Returns:
True if cache was flushed successfully
- Return type:
bool
- classmethod from_pretrained(model_name: str, tp_size: int = 1, dp_size: int = 1, **kwargs: Any) AsyncSGLang[source]¶
Create an AsyncSGLang instance by launching a managed server.
- Parameters:
model_name – Model name or path to load
tp_size – Tensor parallel size
dp_size – Data parallel size
**kwargs – Additional server arguments
- Returns:
Service with managed server
- Return type:
AsyncSGLang
Example
>>> service = AsyncSGLang.from_pretrained(
...     "Qwen/Qwen2.5-3B",
...     tp_size=2,
...     max_model_len=4096
... )
- generate(prompts: str | list[str] | None = None, sampling_params: dict[str, Any] | None = None, *, input_ids: list[int] | list[list[int]] | None = None, return_logprobs: bool = False, return_text: bool = True, timeout: float | None = None, **kwargs: Any) dict[str, Any] | list[dict[str, Any]][source]¶
Generate text completions from text prompts or token IDs.
You can provide either prompts (text) OR input_ids (tokens), but not both.
- Parameters:
prompts – Input text prompt(s) for generation. Mutually exclusive with input_ids.
sampling_params – Sampling parameters (temperature, top_p, max_tokens, etc.)
input_ids – Input token ID(s) for generation. Can be a single list of ints or a list of lists for batch generation. Mutually exclusive with prompts.
return_logprobs – Whether to return log probabilities
return_text – Whether to return generated text
timeout – Request timeout in seconds
**kwargs – Additional sampling parameters (temperature, max_new_tokens, etc.) These are merged into sampling_params for convenience.
- Returns:
Generation results containing the keys ‘text’, ‘output_ids’, and ‘meta_info’
- Return type:
dict or list[dict]
Example
>>> # Generate from text
>>> result = service.generate(
...     "What is the capital of France?",
...     {"temperature": 0.7, "max_tokens": 100}
... )
>>> print(result["text"])
>>> # Generate from token IDs
>>> result = service.generate(
...     input_ids=[1, 2, 3, 4],
...     sampling_params={"max_tokens": 50}
... )
>>> print(result["output_ids"])
>>> # Using kwargs for sampling params
>>> result = service.generate("Hello", max_new_tokens=50, temperature=0.7)
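The kwargs-merge behavior described above can be sketched as follows. The precedence on conflict (an explicit `sampling_params` entry winning over the same key passed via `**kwargs`) is an assumption for illustration, not documented behavior; check the source for the actual rule:

```python
def merge_sampling_params(sampling_params=None, **kwargs):
    """Sketch of how extra keyword arguments fold into sampling_params.

    Assumes explicit sampling_params entries take precedence on conflict;
    this precedence is a guess, not taken from the AsyncSGLang source.
    """
    merged = dict(kwargs)            # start from the convenience kwargs
    merged.update(sampling_params or {})  # explicit dict wins on conflict
    return merged
```

With this, `generate("Hello", max_new_tokens=50, temperature=0.7)` and `generate("Hello", {"max_new_tokens": 50, "temperature": 0.7})` produce the same effective sampling parameters.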
- generate_batch(prompts: list[str], sampling_params: dict[str, Any] | None = None, **kwargs: Any) list[dict[str, Any]][source]¶
Generate text completions for a batch of prompts.
This is an alias for generate() with a list of prompts.
- Parameters:
prompts – List of input prompts
sampling_params – Sampling parameters
**kwargs – Additional arguments passed to generate()
- Returns:
List of generation results
- Return type:
list[dict]
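Since `generate_batch()` forwards a list of prompts to `generate()` and returns one result dict per prompt, in prompt order, a small usage sketch can pair prompts with their outputs. The helper name `answer_all` is hypothetical; only `generate_batch()` and the `'text'` result key come from the API above:

```python
def answer_all(service, prompts, sampling_params=None):
    """Map each prompt to its generated text via generate_batch().

    Relies on the documented contract: one result dict (with a 'text' key)
    per prompt, in the same order as the input list.
    """
    results = service.generate_batch(prompts, sampling_params)
    return {prompt: result["text"] for prompt, result in zip(prompts, results)}
```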
- get_model_metadata() dict[str, tuple[dtype, Size]][source]¶
Get model parameter metadata.
Note: Retrieving metadata requires a round-trip to the server. Currently this method returns an empty dict; metadata is expected to be provided externally.
- init_weight_update_group(master_address: str | None = None, master_port: int | None = None) None[source]¶
Initialize the NCCL weight update group via SGLang’s HTTP API.
This calls the SGLang server’s /init_weights_update_group endpoint to set up NCCL communication for weight synchronization.
- Parameters:
master_address – Master address for NCCL (default: “localhost”)
master_port – Master port for NCCL (auto-assigned if None)
- property server_url: str¶
Get the server URL.
- update_weights(weights: Iterator[tuple[str, Tensor]]) None[source]¶
Update model weights via NCCL broadcast.
This method coordinates with the SGLang server to broadcast weights from the trainer (rank 0) to all workers.
- Parameters:
weights – Iterator yielding (parameter_name, tensor) tuples
- update_weights_from_distributed(name: str, dtype: dtype, shape: tuple[int, ...]) None[source]¶
Signal the server to receive a weight update via NCCL broadcast.
This calls SGLang’s /update_weights_from_distributed endpoint to coordinate weight reception.
- Parameters:
name – Name of the parameter to update
dtype – Data type of the tensor
shape – Shape of the tensor
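Taken together with update_weights(), this describes a per-parameter handshake: the trainer announces name/dtype/shape so the server can post a matching NCCL receive, then broadcasts the tensor from rank 0. A hedged trainer-side sketch; `broadcast` is a placeholder for your NCCL broadcast callable, not part of this API:

```python
def broadcast_all_weights(service, named_tensors, broadcast):
    """Trainer-side sketch of the per-parameter NCCL handshake.

    For each (name, tensor): first signal the server so it posts a matching
    receive, then broadcast the tensor from rank 0. `broadcast` is a stand-in
    for your NCCL group's broadcast call.
    """
    for name, tensor in named_tensors:
        # Announce the incoming tensor so the server allocates a matching buffer.
        service.update_weights_from_distributed(name, tensor.dtype, tuple(tensor.shape))
        broadcast(tensor)  # rank 0 sends; server workers receive
```

The ordering matters: each signal must precede its broadcast, or the server has no receive posted for the incoming tensor.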