VLLMWeightSender¶
- class torchrl.weight_update.llm.VLLMWeightSender(scheme: VLLMWeightSyncScheme)[source]¶
Sends weights to vLLM workers using collective communication.
RPC + Collective Implementation
This class implements both layers, as sketched below:

RPC Layer: Currently uses Ray remote calls (implicit in the test setup).
- Can be extended to other RPC backends (torch.distributed.rpc, gRPC).
- In the tests, Ray actors provide the RPC mechanism.

Collective Layer: Uses VLLMCollectiveTransport for the NCCL broadcast.
- Broadcasts weights from the trainer (rank 0) to the workers (ranks 1+).
- High-bandwidth GPU-to-GPU transfer.
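A minimal end-to-end sketch of the trainer-side flow. Here model, engine, and the scheme's constructor arguments are placeholders rather than part of this API, and the exact weights format accepted by update_weights (a name-to-tensor mapping) is an assumption:

    from torchrl.weight_update.llm import VLLMWeightSyncScheme, VLLMWeightSender

    model = ...   # trainer-side torch.nn.Module holding the weights (placeholder)
    engine = ...  # vLLM engine handle, e.g. a Ray actor (placeholder)

    scheme = VLLMWeightSyncScheme(...)  # constructor arguments omitted here
    sender = VLLMWeightSender(scheme)

    # RPC layer: coordinate collective-group setup with the vLLM engine.
    model_metadata = ...  # name -> (dtype, shape); see init_all_workers_group below
    sender.init_all_workers_group(model_metadata, vllm_engine=engine)

    # Collective layer: NCCL-broadcast current weights from rank 0 to ranks 1+.
    sender.update_weights(model.state_dict())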
Extending RPC Backends
To use a different RPC backend, subclass and override the coordination step:

    import torch.distributed.rpc as rpc

    def prepare_receive():
        # Placeholder run on each worker: get ready to join the broadcast.
        ...

    class TorchRPCVLLMSender(VLLMWeightSender):
        def update_weights(self, weights=None):
            # Custom RPC: signal workers to prepare
            for worker in self.workers:
                rpc.rpc_async(worker, prepare_receive)
            # Then do the collective (unchanged)
            super().update_weights(weights)
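Note that only the RPC coordination step changes in this pattern: the NCCL broadcast performed by super().update_weights(weights) is reused unchanged, so swapping the RPC backend never touches the collective layer.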
- init_all_workers_group(model_metadata: dict[str, tuple[torch.dtype, torch.Size]], vllm_engine: Any | None = None)[source]¶
Initialize the collective communication group.
- Parameters:
model_metadata – Dict mapping param names to (dtype, shape) tuples.
vllm_engine – vLLM engine handle used for RPC coordination. Optional in the signature, but required when weights are sent via NCCL broadcast.
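In practice the metadata can be derived directly from the trainer-side module; a minimal sketch, with model standing in for the module whose weights are synced:

    model_metadata = {
        name: (param.dtype, param.shape)
        for name, param in model.named_parameters()
    }
    # Each entry pairs a parameter name with its (torch.dtype, torch.Size),
    # matching the format init_all_workers_group expects.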