SGLangCollectiveTransport

class torchrl.weight_update.llm.SGLangCollectiveTransport(server_url: str, master_address: str, master_port: int, rank: int, world_size: int, device: device | str | int | None = None, timeout: float = 300.0)

Transport for SGLang using NCCL collective communication.

This transport coordinates with an SGLang server over HTTP and transfers weights to it with NCCL broadcasts.

Parameters:
  • server_url – URL of the SGLang server.

  • master_address – Address for NCCL initialization.

  • master_port – Port for NCCL initialization.

  • rank – Rank of this process (0 for the trainer).

  • world_size – Total number of processes.

  • device – Device to use for communication.

  • timeout – HTTP request timeout in seconds.
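As a hedged illustration of how rank and world_size fit together, the helper below builds constructor keyword arguments for the trainer side. The helper name transport_kwargs and its server_tp_size parameter are hypothetical, and the rule world_size = 1 + server_tp_size is an assumption (one trainer process plus one rank per SGLang tensor-parallel worker); check it against your deployment.

```python
def transport_kwargs(
    server_url: str,
    master_address: str,
    master_port: int,
    server_tp_size: int,
    timeout: float = 300.0,
) -> dict:
    """Hypothetical helper: keyword arguments for a trainer-side transport.

    Assumes a single trainer process (rank 0) plus one NCCL rank per
    SGLang tensor-parallel worker; adjust world_size if your setup differs.
    """
    if server_tp_size < 1:
        raise ValueError("server_tp_size must be at least 1")
    return {
        "server_url": server_url,
        "master_address": master_address,
        "master_port": master_port,
        "rank": 0,  # the trainer is rank 0 and hosts the TCP rendezvous
        "world_size": 1 + server_tp_size,  # assumption: trainer + server ranks
        "timeout": timeout,
    }


kwargs = transport_kwargs("http://localhost:30000", "localhost", 29500, server_tp_size=2)
```

With TorchRL installed, the resulting dict could be passed as SGLangCollectiveTransport(**kwargs); device is left at its default of None here.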

check_connection() → bool

Check whether the NCCL communication group has been initialized.

init_all_workers_group(model_metadata: dict[str, tuple[dtype, Size]]) → None

Initialize the NCCL communication group.

For the trainer (rank 0), this:
  1. Creates a torch.distributed process group via TCP rendezvous (rank 0 is the master).
  2. Signals the SGLang server via HTTP to create a matching process group.
  3. Waits for both sides to rendezvous via the TCP store and form an NCCL group.
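Assuming process-group creation blocks until every rank has joined (as a torch.distributed TCP rendezvous normally does), steps 1 and 2 cannot run strictly one after the other on the trainer: the server has to be signalled before, or concurrently with, the trainer's own blocking join. The thread-based simulation below illustrates that ordering constraint using only the standard library; threading.Barrier is a stand-in for the TCP-store rendezvous, and rendezvous_sketch is a hypothetical name, not part of TorchRL.

```python
import threading


def rendezvous_sketch(world_size: int, timeout: float = 5.0) -> list[int]:
    """Simulate the three init steps with threads instead of real processes."""
    # A Barrier stands in for the TCP-store rendezvous: every party blocks in
    # wait() until all world_size parties arrive, much like process-group init.
    barrier = threading.Barrier(world_size, timeout=timeout)
    joined: list[int] = []
    lock = threading.Lock()

    def join_group(rank: int) -> None:
        barrier.wait()  # blocks until every rank has reached the rendezvous
        with lock:
            joined.append(rank)

    # Step 2 stand-in: "signalling the server" starts its ranks (1..world_size-1)
    # in the background, the way an HTTP request would trigger the server's init.
    server_ranks = [
        threading.Thread(target=join_group, args=(rank,))
        for rank in range(1, world_size)
    ]
    for t in server_ranks:
        t.start()

    # Steps 1 and 3 stand-in: the trainer joins as rank 0. Had it blocked here
    # before the server side was started, the barrier would time out.
    join_group(0)
    for t in server_ranks:
        t.join()
    return sorted(joined)


result = rendezvous_sketch(3)
```

Every simulated rank is released only once all three have arrived, which is why the trainer's blocking join is issued after the server side has been started.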

Internally, the SGLang server calls init_custom_process_group, which creates a torch.distributed process group rather than SGLang’s standalone StatelessProcessGroup + PyNcclCommunicator pair. The trainer must therefore use the same torch.distributed mechanism so that both sides join the same NCCL collective.
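The HTTP signal in step 2 can be sketched as a JSON POST to the server. The example below only builds the request with the standard library and does not send it. The endpoint path /init_weights_update_group and the field names (master_address, master_port, rank_offset, world_size, group_name, backend) reflect our understanding of SGLang's HTTP API and are assumptions to verify against your SGLang version; build_init_request is a hypothetical helper, not TorchRL's implementation.

```python
import json
import urllib.request


def build_init_request(
    server_url: str,
    master_address: str,
    master_port: int,
    world_size: int,
    group_name: str = "weight_update_group",
) -> urllib.request.Request:
    # Assumed payload shape for SGLang's /init_weights_update_group endpoint.
    payload = {
        "master_address": master_address,
        "master_port": master_port,
        "rank_offset": 1,  # server ranks start right after the trainer's rank 0
        "world_size": world_size,
        "group_name": group_name,
        "backend": "nccl",
    }
    return urllib.request.Request(
        url=server_url.rstrip("/") + "/init_weights_update_group",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


req = build_init_request("http://localhost:30000/", "localhost", 29500, world_size=3)
```

In a real trainer, such a request would be sent (for example with urllib.request.urlopen) alongside the trainer's own torch.distributed group creation, not strictly before or after it.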

Parameters:
  • model_metadata – Dict mapping parameter names to (dtype, shape) tuples.
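A model_metadata dict can be derived from any iterable of (name, tensor) pairs, such as model.named_parameters() or model.state_dict().items() in PyTorch. The helper below is a hypothetical sketch; FakeTensor is a stand-in so it runs without PyTorch, and with real tensors the values would be (torch.dtype, torch.Size) pairs as in the signature above.

```python
from collections import namedtuple
from typing import Any, Iterable

# Stand-in for torch.Tensor so the sketch runs without PyTorch; only the
# .dtype and .shape attributes are read.
FakeTensor = namedtuple("FakeTensor", ["dtype", "shape"])


def build_model_metadata(
    named_tensors: Iterable[tuple[str, Any]],
) -> dict[str, tuple[Any, Any]]:
    """Map each parameter name to its (dtype, shape) pair, preserving order."""
    return {name: (tensor.dtype, tensor.shape) for name, tensor in named_tensors}


params = [
    ("embed.weight", FakeTensor("bfloat16", (32000, 4096))),
    ("lm_head.weight", FakeTensor("bfloat16", (32000, 4096))),
]
metadata = build_model_metadata(params)
```

Because Python dicts preserve insertion order, the resulting metadata lists parameters in the order in which they were supplied.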

send_weights(model_id: str, weights: dict[str, Tensor]) → None

Broadcast weights to the SGLang server via NCCL.

SGLang’s /update_weights_from_distributed endpoint expects a single request containing the lists of all parameter names, dtypes, and shapes. The server then enters a loop that receives one broadcast per parameter, in list order. The trainer must broadcast each tensor in that same order, concurrently with the server’s receive loop.
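The ordering contract can be illustrated without GPUs or a server. In the simulation below, a thread plays the server's receive loop and a bounded queue.Queue stands in for the NCCL broadcasts. simulate_send_weights and the dict-based weight records are hypothetical stand-ins, and the request keys names, dtypes, and shapes mirror the lists described above rather than a verified SGLang payload.

```python
import queue
import threading
from typing import Any


def simulate_send_weights(weights: dict[str, dict[str, Any]]) -> tuple[dict, list[str]]:
    """Thread-based simulation of the single-request, ordered-broadcast protocol."""
    items = list(weights.items())
    # One request carries every name, dtype, and shape, built from a single
    # pass over the weights so the three lists stay aligned.
    request = {
        "names": [name for name, _ in items],
        "dtypes": [w["dtype"] for _, w in items],
        "shapes": [list(w["shape"]) for _, w in items],
    }
    # maxsize=1 makes each send wait for the receiver, loosely like a collective.
    channel: queue.Queue = queue.Queue(maxsize=1)
    received: list[str] = []

    def server_loop() -> None:
        # Stand-in for the server: one receive per announced name, in list order.
        for _ in request["names"]:
            name, _payload = channel.get()
            received.append(name)

    server = threading.Thread(target=server_loop)
    server.start()  # the receive loop runs concurrently with the sends below
    for name, w in items:
        channel.put((name, w["data"]))  # stand-in for one broadcast per tensor
    server.join()
    return request, received


weights = {
    "layer0.weight": {"dtype": "bfloat16", "shape": (4, 4), "data": b"\x00" * 32},
    "layer0.bias": {"dtype": "bfloat16", "shape": (4,), "data": b"\x00" * 8},
}
request, received = simulate_send_weights(weights)
```

Because the announced names and the sends both come from the same items list, the simulated server receives tensors in exactly the announced order; sending from a differently ordered source would pair tensors with the wrong names on the receiving side.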

Parameters:
  • model_id – Identifier for the model (for logging).

  • weights – Dict mapping parameter names to tensors.

shutdown() → None

Release trainer-side resources used for weight synchronization.
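The four methods compose into a simple lifecycle: initialize the group once, send weights as often as needed, and release resources at the end. The sketch below is a hedged example of that pattern, not TorchRL's own training loop. run_weight_sync, WeightTransport, and RecordingTransport are hypothetical names, and the sketch relies only on the method names documented on this page.

```python
from typing import Any, Iterable, Protocol


class WeightTransport(Protocol):
    # Only the four methods documented for SGLangCollectiveTransport are used.
    def check_connection(self) -> bool: ...
    def init_all_workers_group(self, model_metadata: dict) -> None: ...
    def send_weights(self, model_id: str, weights: dict) -> None: ...
    def shutdown(self) -> None: ...


def run_weight_sync(
    transport: WeightTransport,
    model_id: str,
    model_metadata: dict,
    weight_snapshots: Iterable[dict[str, Any]],
) -> int:
    """Initialize once, push every snapshot, and always release resources."""
    sent = 0
    try:
        if not transport.check_connection():
            transport.init_all_workers_group(model_metadata)
        for weights in weight_snapshots:
            transport.send_weights(model_id, weights)
            sent += 1
    finally:
        transport.shutdown()  # runs even if a send raises
    return sent


class RecordingTransport:
    """In-memory stand-in used only to exercise the sketch."""

    def __init__(self) -> None:
        self.calls: list[str] = []
        self.connected = False

    def check_connection(self) -> bool:
        return self.connected

    def init_all_workers_group(self, model_metadata: dict) -> None:
        self.calls.append("init")
        self.connected = True

    def send_weights(self, model_id: str, weights: dict) -> None:
        self.calls.append("send")

    def shutdown(self) -> None:
        self.calls.append("shutdown")


fake = RecordingTransport()
count = run_weight_sync(fake, "policy", {}, [{"w": 1}, {"w": 2}])
```

Guarding init_all_workers_group with check_connection keeps the sketch from re-forming the group when the transport is already connected, and the finally clause ensures shutdown runs even if a send fails.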