InferenceServer¶
- class torchrl.modules.inference_server.InferenceServer(model: Module, transport: InferenceTransport, *, max_batch_size: int = 64, min_batch_size: int = 1, timeout: float = 0.01, collate_fn: Callable | None = None, device: device | str | None = None, weight_sync=None, weight_sync_model_id: str = 'policy')[source]¶
Auto-batching inference server.
Actors submit individual TensorDicts via the transport and receive results asynchronously. A background worker drains the transport queue, batches inputs, runs the model, and fans results back to the callers.
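The batching loop described above can be sketched in simplified form. This is a hypothetical illustration using only the standard library (`queue`, `threading`, `concurrent.futures`), not the actual TorchRL implementation: the `MiniBatchServer` class, its `submit` method, and the list-based "model" are all invented for this sketch, whereas the real server works on TensorDicts through an `InferenceTransport`.

```python
import queue
import threading
from concurrent.futures import Future


class MiniBatchServer:
    """Hypothetical sketch of an auto-batching inference loop."""

    def __init__(self, model, max_batch_size=64, timeout=0.01):
        self.model = model
        self.max_batch_size = max_batch_size
        self.timeout = timeout
        self._q = queue.Queue()
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._loop, daemon=True)

    def start(self):
        self._thread.start()
        return self

    def submit(self, item):
        # Each caller gets a Future that is resolved by the worker thread.
        fut = Future()
        self._q.put((item, fut))
        return fut

    def _loop(self):
        while not self._stop.is_set():
            try:
                first = self._q.get(timeout=0.1)
            except queue.Empty:
                continue
            batch = [first]
            # Keep draining until the batch is full or the timeout expires.
            while len(batch) < self.max_batch_size:
                try:
                    batch.append(self._q.get(timeout=self.timeout))
                except queue.Empty:
                    break
            items, futs = zip(*batch)
            # One forward pass for the whole batch, then fan results back out.
            outputs = self.model(list(items))
            for fut, out in zip(futs, outputs):
                fut.set_result(out)

    def shutdown(self):
        self._stop.set()
        self._thread.join()


server = MiniBatchServer(lambda xs: [x * 2 for x in xs], max_batch_size=4).start()
futures = [server.submit(i) for i in range(5)]
results = [f.result() for f in futures]
server.shutdown()
print(results)
```

Note how `min_batch_size=1` semantics fall out naturally here: the worker dispatches as soon as draining times out, trading a little latency (`timeout`) for larger, more efficient batches.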
- Parameters:
  - model (nn.Module or Callable) – a callable that maps a batched TensorDictBase to a batched TensorDictBase (e.g. a TensorDictModule).
  - transport (InferenceTransport) – the communication backend.
- Keyword Arguments:
  - max_batch_size (int, optional) – upper bound on the number of requests processed in a single forward pass. Default: 64.
  - min_batch_size (int, optional) – minimum number of requests to accumulate before dispatching a batch. After the first request arrives, the server keeps draining for up to timeout seconds until at least this many items are collected. 1 (default) dispatches immediately.
  - timeout (float, optional) – seconds to wait for new work before dispatching a partial batch. Default: 0.01.
  - collate_fn (Callable, optional) – function used to stack a list of TensorDicts into a batch. Default: lazy_stack().
  - device (torch.device or str, optional) – device to move batches to before calling the model. None means no device transfer.
  - weight_sync – an optional WeightSyncScheme used to receive updated model weights from a trainer. When set, the server polls for new weights between inference batches.
  - weight_sync_model_id (str, optional) – the model identifier used when initialising the weight sync scheme on the receiver side. Default: "policy".
Example
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.modules.inference_server import (
...     InferenceServer,
...     ThreadingTransport,
... )
>>> import torch.nn as nn
>>> policy = TensorDictModule(
...     nn.Linear(4, 2), in_keys=["obs"], out_keys=["act"]
... )
>>> transport = ThreadingTransport()
>>> server = InferenceServer(policy, transport, max_batch_size=8)
>>> server.start()
>>> client = transport.client()
>>> # client(td) can now be called from any thread
>>> server.shutdown()
- property is_alive: bool¶
Whether the background worker thread is running.
- shutdown(timeout: float | None = 5.0) → None[source]¶
Signal the background worker to stop and wait for it to finish.
- Parameters:
  - timeout (float or None) – seconds to wait for the worker thread to join. None waits indefinitely.
- start() → InferenceServer[source]¶
Start the background inference loop.
- Returns:
self, for fluent chaining.