SGLangCollectiveTransport¶
- class torchrl.weight_update.llm.SGLangCollectiveTransport(server_url: str, master_address: str, master_port: int, rank: int, world_size: int, device: device | str | int | None = None, timeout: float = 300.0)[source]¶
Transport for SGLang using NCCL collective communication.
This transport coordinates with SGLang servers via HTTP and performs weight transfer via NCCL broadcast.
- Parameters:
server_url – URL of the SGLang server.
master_address – Address for NCCL initialization.
master_port – Port for NCCL initialization.
rank – Rank of this process (0 for trainer).
world_size – Total number of processes.
device – Device to use for communication.
timeout – HTTP request timeout in seconds.
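A minimal construction sketch, assuming the trainer runs as rank 0 of a two-process group (trainer plus SGLang server); the server URL, rendezvous address, port, and device below are placeholders, not recommended values:

```python
from torchrl.weight_update.llm import SGLangCollectiveTransport

# Trainer side: rank 0 of a 2-process NCCL group (trainer + SGLang server).
# All addresses and ports here are illustrative placeholders.
transport = SGLangCollectiveTransport(
    server_url="http://localhost:30000",
    master_address="127.0.0.1",
    master_port=29500,
    rank=0,
    world_size=2,
    device="cuda:0",
    timeout=300.0,
)
```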
- init_all_workers_group(model_metadata: dict[str, tuple[dtype, Size]]) None[source]¶
Initialize the NCCL communication group.
For the trainer (rank 0), this:
1. Creates a torch.distributed process group via TCP rendezvous (rank 0 is the master).
2. Signals the SGLang server via HTTP to create a matching process group.
3. Both sides rendezvous via the TCP store and form an NCCL group.
The SGLang server uses init_custom_process_group internally, which creates a torch.distributed process group (not SGLang’s standalone StatelessProcessGroup + PyNcclCommunicator). The trainer must use the same mechanism so both sides join the same NCCL collective.
- Parameters:
model_metadata – Dict mapping param names to (dtype, shape) tuples.
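As a sketch of preparing this argument, the metadata can be derived from the training model’s state_dict. The model below is an illustrative stand-in, and transport is assumed to be the instance from the constructor sketch above:

```python
import torch

# `model` stands in for the trainer's policy; a tiny module keeps the sketch runnable.
model = torch.nn.Linear(8, 8)

# Build the metadata the server needs: parameter name -> (dtype, shape).
model_metadata = {
    name: (param.dtype, param.shape)
    for name, param in model.state_dict().items()
}

# Rendezvous with the SGLang server and form the NCCL group.
transport.init_all_workers_group(model_metadata)
```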
- send_weights(model_id: str, weights: dict[str, Tensor]) None[source]¶
Broadcast weights to SGLang server via NCCL.
SGLang’s /update_weights_from_distributed endpoint expects a single request with lists of all parameter names, dtypes, and shapes. The server then enters a broadcast-receive loop for each parameter in order. The trainer must broadcast each tensor in the same order, concurrently with the server receiving.
- Parameters:
model_id – Identifier for the model (for logging).
weights – Dict mapping parameter names to tensors.
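Continuing the sketches above, a hedged end-to-end call might look like this; model and transport are the illustrative objects from the earlier blocks, and the "policy" identifier is an arbitrary choice used only for logging:

```python
# After init_all_workers_group() has completed on both sides, broadcast
# the trainer's current weights to the SGLang server.
weights = dict(model.state_dict())
transport.send_weights(model_id="policy", weights=weights)
```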