SGLangWeightSender

class torchrl.weight_update.llm.SGLangWeightSender(scheme: SGLangWeightSyncScheme)

Sends weights to SGLang workers using NCCL broadcast.

Parameters:

scheme – The SGLangWeightSyncScheme configuration.

flush_cache() → bool

Flush the SGLang server’s radix cache after a weight update.

Returns:

True if the cache was flushed successfully.

Return type:

bool
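Why flush at all: entries in SGLang's radix (prefix) cache hold KV states that were computed with the previous weights. Serving requests from those entries after an update would mix old and new parameters. The following is a minimal sketch of a flush done as a plain HTTP call. It assumes the server is reachable at a base URL and exposes SGLang's /flush_cache endpoint. flush_radix_cache is a hypothetical helper and is not necessarily how this method is implemented.

```python
import urllib.request


def flush_radix_cache(base_url: str, timeout: float = 5.0) -> bool:
    """Hypothetical helper: ask an SGLang server to drop its radix cache.

    Assumes the server exposes a /flush_cache endpoint that accepts POST.
    Returns True only on an HTTP 200 response.
    """
    req = urllib.request.Request(base_url.rstrip("/") + "/flush_cache", method="POST")
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            return resp.status == 200
    except OSError:
        # Covers connection refused, timeouts, and HTTP errors
        # (urllib.error.HTTPError/URLError are OSError subclasses).
        return False
```

Returning False on any transport or HTTP error mirrors the documented bool return type. A caller can then decide whether a failed flush should abort the update.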

init_all_workers_group(model_metadata: dict[str, tuple[dtype, Size]]) → None

Initialize the NCCL communication group.

Parameters:

model_metadata – Dict mapping parameter names to (dtype, shape) tuples.
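The metadata can be derived from the model itself. The sketch below assumes the names and tensors come from named_parameters(), as on a torch.nn.Module. Both build_model_metadata and the stub model are hypothetical; the stub exists only so the snippet runs without PyTorch.

```python
from types import SimpleNamespace
from typing import Any


def build_model_metadata(model: Any) -> dict:
    """Hypothetical helper: map each parameter name to a (dtype, shape) tuple.

    Works with any object whose named_parameters() yields (name, tensor)
    pairs where the tensor has .dtype and .shape, e.g. a torch.nn.Module.
    """
    return {name: (p.dtype, p.shape) for name, p in model.named_parameters()}


class _StubModel:
    """Stand-in model so the sketch runs without PyTorch."""

    def named_parameters(self):
        yield "linear.weight", SimpleNamespace(dtype="bfloat16", shape=(4, 8))
        yield "linear.bias", SimpleNamespace(dtype="bfloat16", shape=(4,))


metadata = build_model_metadata(_StubModel())
# metadata == {"linear.weight": ("bfloat16", (4, 8)), "linear.bias": ("bfloat16", (4,))}
```

On a real torch.nn.Module, the values are torch.dtype and torch.Size objects, which matches the annotated type dict[str, tuple[dtype, Size]].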

register_collector(collector) → None

Register a collector for automatic policy version increment.

After each update_weights() call, collector.increment_version() is called automatically.
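The automatic version bump amounts to an observer hook that runs after the broadcast. The following sketch covers only that bookkeeping. It assumes a single registered collector that exposes increment_version(). SenderSketch and CountingCollector are hypothetical stand-ins, and the broadcast itself is elided.

```python
from typing import Optional, Protocol


class VersionedCollector(Protocol):
    def increment_version(self) -> None: ...


class SenderSketch:
    """Hypothetical stand-in that mirrors only the collector bookkeeping."""

    def __init__(self) -> None:
        self._collector: Optional[VersionedCollector] = None

    def register_collector(self, collector: VersionedCollector) -> None:
        # Assumption: registering again replaces the previous collector.
        self._collector = collector

    def update_weights(self, weights=None) -> None:
        # ... NCCL broadcast of the weights would happen here (elided) ...
        # Once the broadcast returns, bump the collector's policy version.
        if self._collector is not None:
            self._collector.increment_version()


class CountingCollector:
    """Hypothetical collector that just counts version bumps."""

    def __init__(self) -> None:
        self.version = 0

    def increment_version(self) -> None:
        self.version += 1


sender = SenderSketch()
collector = CountingCollector()
sender.register_collector(collector)
sender.update_weights()
sender.update_weights()
print(collector.version)  # 2
```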

register_model(model: Any) → None

Register the model for weight extraction.

Parameters:

model – The PyTorch model to sync weights from.

shutdown() → None

Release resources held by the sender.
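The sender holds communication resources, so it is worth guaranteeing that shutdown() runs even when a training step raises. The sketch below does this with contextlib. weight_sender is a hypothetical helper, and _StubSender stands in for a real SGLangWeightSender.

```python
from contextlib import contextmanager
from typing import Any, Iterator


@contextmanager
def weight_sender(sender: Any) -> Iterator[Any]:
    """Yield sender and call sender.shutdown() on exit, even after an exception."""
    try:
        yield sender
    finally:
        sender.shutdown()


class _StubSender:
    """Hypothetical stand-in that records which methods were called."""

    def __init__(self) -> None:
        self.calls: list = []

    def update_weights(self, weights=None) -> None:
        self.calls.append("update_weights")

    def shutdown(self) -> None:
        self.calls.append("shutdown")


stub = _StubSender()
try:
    with weight_sender(stub) as s:
        s.update_weights()
        raise RuntimeError("simulated failure in a training step")
except RuntimeError:
    pass  # the error still propagates; shutdown() has already run
print(stub.calls)  # ['update_weights', 'shutdown']
```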

update_weights(weights: dict[str, Tensor] | None = None) → None

Broadcast weights to SGLang workers.

Parameters:

weights – Optional dict mapping parameter names to tensors. If None, the weights are extracted from the registered model.
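The weights argument could be resolved and validated before broadcasting. With an NCCL broadcast, receivers need pre-allocated tensors of matching dtype and shape, and every rank must issue the collectives in the same order. This is presumably why the metadata is exchanged once in init_all_workers_group. The sketch below assumes that reading. resolve_weights is hypothetical, and it returns the tensors ordered as in the metadata.

```python
from types import SimpleNamespace
from typing import Any, Mapping, Optional


def resolve_weights(
    weights: Optional[Mapping[str, Any]],
    model: Any,
    metadata: Mapping[str, tuple],
) -> dict:
    """Hypothetical helper: choose the tensors to broadcast and validate them.

    Uses weights when given, otherwise the model's named_parameters().
    Every name in metadata must be present with the same dtype and shape,
    because receivers allocated their buffers from that metadata.
    Names that are not in metadata are ignored.
    """
    if weights is None:
        if model is None:
            raise RuntimeError("no weights given and no model registered")
        weights = dict(model.named_parameters())
    missing = sorted(set(metadata) - set(weights))
    if missing:
        raise KeyError(f"missing weights: {missing}")
    for name, (dtype, shape) in metadata.items():
        tensor = weights[name]
        if tensor.dtype != dtype or tuple(tensor.shape) != tuple(shape):
            raise ValueError(
                f"{name}: expected {dtype}{tuple(shape)}, "
                f"got {tensor.dtype}{tuple(tensor.shape)}"
            )
    # Return in metadata order so every rank broadcasts tensors in the same sequence.
    return {name: weights[name] for name in metadata}


class _StubModel:
    """Stand-in for a registered model; yields parameters in a different order."""

    def named_parameters(self):
        yield "b", SimpleNamespace(dtype="float32", shape=(2,))
        yield "w", SimpleNamespace(dtype="float32", shape=(2, 3))


meta = {"w": ("float32", (2, 3)), "b": ("float32", (2,))}
out = resolve_weights(None, _StubModel(), meta)
print(list(out))  # ['w', 'b']
```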