VLLMDoubleBufferWeightSender
- class torchrl.weight_update.llm.VLLMDoubleBufferWeightSender(scheme: VLLMDoubleBufferSyncScheme)
Sends weights to vLLM workers using double-buffered storage.
This sender extracts weights from a training model and writes them to a shared directory using TensorDict.memmap.
Example
>>> sender = scheme.create_sender()
>>> sender.register_model(policy_model)
>>>
>>> # During training loop
>>> sender.update_weights()  # Writes current weights to shared storage
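The description above says "double-buffered storage" but does not describe the layout. The sketch below shows the general double-buffering idea using only the standard library. The writer fills the inactive one of two slots, then atomically repoints readers at it, so a reader never sees a half-written slot. The slot names, the ACTIVE pointer file and the JSON format are all hypothetical. The real sender writes tensors with TensorDict.memmap, and its on-disk layout may differ.

```python
import json
import os
import tempfile

# Hypothetical layout, not TorchRL's actual format: two slot directories plus
# a pointer file naming the slot that readers should load from.
SLOTS = ("buffer_0", "buffer_1")
POINTER = "ACTIVE"


def active_slot(root):
    """Return the currently published slot name, or None if nothing is published."""
    path = os.path.join(root, POINTER)
    if not os.path.exists(path):
        return None
    with open(path) as f:
        return f.read().strip()


def write_weights(root, weights):
    """Write into the inactive slot, then switch readers to it."""
    target = SLOTS[1] if active_slot(root) == SLOTS[0] else SLOTS[0]
    slot_dir = os.path.join(root, target)
    os.makedirs(slot_dir, exist_ok=True)
    with open(os.path.join(slot_dir, "weights.json"), "w") as f:
        json.dump(weights, f)
    # Publish by renaming a temp file over the pointer; os.replace is atomic
    # on POSIX, so readers see either the old pointer or the new one.
    tmp = os.path.join(root, POINTER + ".tmp")
    with open(tmp, "w") as f:
        f.write(target)
    os.replace(tmp, os.path.join(root, POINTER))
    return target


def read_weights(root):
    """Load weights from whichever slot is currently published."""
    slot = active_slot(root)
    if slot is None:
        return None
    with open(os.path.join(root, slot, "weights.json")) as f:
        return json.load(f)


with tempfile.TemporaryDirectory() as root:
    first = write_weights(root, {"w": [1.0, 2.0]})
    second = write_weights(root, {"w": [3.0, 4.0]})
    latest = read_weights(root)

print(first, second, latest)  # buffer_0 buffer_1 {'w': [3.0, 4.0]}
```

With only two slots, a slow reader could still be loading the previous slot when the next write starts overwriting it. This toy version does not coordinate with readers, and a real implementation would have to handle that case.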
- register_model(model: Any) → None
Register the model to extract weights from.
- Parameters:
model – The model to extract weights from (e.g., TransformersWrapper).
- send(weights: Any = None, worker_ids: int | list[int] | None = None) → None
Send weights synchronously to workers.
This method:
1. Prepares the weights (extracting them from the model if weights=None).
2. Sends them to the specified workers (or to all workers if worker_ids=None).
3. Waits for acknowledgments from those workers.
4. Returns once those workers have applied the weights.
- Parameters:
weights – Weights to send. Can be:
  - None: extract from the model via context.get_model(model_id)
  - nn.Module: extract weights from the module
  - TensorDict: use directly
  - dict: convert to a TensorDict
worker_ids – Which workers to send to:
  - None: send to all workers (default)
  - int: send to a single worker
  - list[int]: send to the specified workers
Note: This is a blocking call; it returns only after the specified workers have been updated.
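The four steps of send(), together with the worker_ids rules, can be illustrated with in-process threads standing in for vLLM workers. ToyWorker, resolve_worker_ids, send_blocking and the version counter used to match acknowledgments are all hypothetical. The real sender talks to remote vLLM workers, and how it matches acknowledgments is not described here. In this sketch, step 1 (preparing the weights) is left to the caller.

```python
import queue
import threading

# Hypothetical in-process stand-ins for vLLM workers (not TorchRL classes):
# each worker applies whatever arrives in its inbox, then posts
# (worker_id, version) to a shared acknowledgment queue.


class ToyWorker(threading.Thread):
    def __init__(self, worker_id, acks):
        super().__init__(daemon=True)
        self.worker_id = worker_id
        self.inbox = queue.Queue()
        self.acks = acks
        self.weights = None

    def run(self):
        while True:
            version, weights = self.inbox.get()
            self.weights = weights  # "apply" the update
            self.acks.put((self.worker_id, version))


def resolve_worker_ids(worker_ids, num_workers):
    """None -> every worker, int -> [int], list[int] -> that list."""
    if worker_ids is None:
        return list(range(num_workers))
    if isinstance(worker_ids, int):
        return [worker_ids]
    return list(worker_ids)


def send_blocking(workers, acks, weights, version, worker_ids=None, timeout=5.0):
    # Step 1 (preparing weights) is assumed to be done by the caller.
    targets = resolve_worker_ids(worker_ids, len(workers))
    # Step 2: dispatch to the selected workers only.
    for i in targets:
        workers[i].inbox.put((version, weights))
    # Step 3: wait until every selected worker has acknowledged this version.
    pending = set(targets)
    while pending:
        worker_id, acked_version = acks.get(timeout=timeout)  # queue.Empty on timeout
        if acked_version == version:
            pending.discard(worker_id)
    # Step 4: returning here means every target has applied the weights.


acks = queue.Queue()
workers = [ToyWorker(i, acks) for i in range(3)]
for w in workers:
    w.start()

send_blocking(workers, acks, {"w": 1.0}, version=1)                # all workers
send_blocking(workers, acks, {"w": 2.0}, version=2, worker_ids=1)  # worker 1 only
print([w.weights for w in workers])  # [{'w': 1.0}, {'w': 2.0}, {'w': 1.0}]
```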
- send_async(weights: Any = None, worker_ids: int | list[int] | None = None) → None
Send weights asynchronously to workers (non-blocking).
This initiates the send but returns immediately without waiting for workers to acknowledge. You must call wait_async() before the next send_async() or send() call.
- Parameters:
weights – Same as send()
worker_ids – Same as send()
- Raises:
RuntimeError – If a previous send_async() is still pending
- update_weights(weights: Any | None = None) → None
Extract and write weights to shared storage.
- Parameters:
weights – Optional weights to send. If None, weights are extracted from the registered model.
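The update_weights() contract can be sketched with stand-in classes: explicit weights take precedence, and otherwise the registered model is read. ToyModel, ToySender and the written list are hypothetical and only mimic the documented behaviour. Raising when no model is registered is an assumption; the documentation does not specify that case. The sketch keeps a reference to the model in register_model(), so each later call picks up the current weights rather than a snapshot taken at registration time.

```python
class ToyModel:
    """Hypothetical stand-in for a trainable model exposing state_dict()."""

    def __init__(self):
        self.w = 0.0

    def state_dict(self):
        return {"w": self.w}


class ToySender:
    """Hypothetical sender mimicking register_model()/update_weights()."""

    def __init__(self):
        self._model = None
        self.written = []  # stands in for the shared storage

    def register_model(self, model):
        self._model = model  # keep a reference, not a snapshot

    def update_weights(self, weights=None):
        if weights is None:
            if self._model is None:
                # Assumed behaviour; the documented API does not specify this case.
                raise RuntimeError("no model registered and no weights given")
            weights = self._model.state_dict()  # extracted at call time
        self.written.append(dict(weights))


model = ToyModel()
sender = ToySender()
sender.register_model(model)
for step in range(3):
    model.w += 1.0  # pretend optimizer step
    sender.update_weights()
sender.update_weights({"w": -1.0})  # explicit weights bypass the model
print(sender.written)  # [{'w': 1.0}, {'w': 2.0}, {'w': 3.0}, {'w': -1.0}]
```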
- wait_async() → None
Wait for a pending async send to complete.
Blocks until all workers have acknowledged the previous send_async(). This must be called after send_async() before any subsequent sends.
- Raises:
RuntimeError – If no async send is pending
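The pairing rule between send_async() and wait_async() works like a one-slot state machine: at most one send may be in flight, and wait_async() is valid only when one is pending. The sketch below models this with concurrent.futures. AsyncSendGuard and fake_send are hypothetical names. Running the blocking send on a background thread is an assumption about how a non-blocking path could work, not a description of TorchRL's implementation.

```python
import time
from concurrent.futures import ThreadPoolExecutor


class AsyncSendGuard:
    """Hypothetical sketch of the send_async()/wait_async() contract."""

    def __init__(self, send_fn):
        self._send_fn = send_fn  # a blocking send: dispatch + wait for acks
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._pending = None  # Future of the in-flight send, if any

    def send_async(self, weights=None, worker_ids=None):
        if self._pending is not None:
            raise RuntimeError("previous send_async() is still pending")
        # Returns immediately; the blocking send runs on the background thread.
        self._pending = self._pool.submit(self._send_fn, weights, worker_ids)

    def wait_async(self):
        if self._pending is None:
            raise RuntimeError("no async send is pending")
        future, self._pending = self._pending, None
        future.result()  # blocks until the send finishes; re-raises its errors


log = []


def fake_send(weights, worker_ids):
    time.sleep(0.01)  # pretend to wait for worker acknowledgments
    log.append(weights)


guard = AsyncSendGuard(fake_send)
guard.send_async({"w": 1.0})
# ... training work can overlap with the send here ...
try:
    guard.send_async({"w": 2.0})  # rejected: the first send was never waited on
    second_accepted = True
except RuntimeError:
    second_accepted = False
guard.wait_async()
print(second_accepted, log)  # False [{'w': 1.0}]
```

Clearing the pending slot before calling result() means a failed send does not leave the guard stuck. The documentation does not say whether TorchRL behaves the same way when a send fails.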