Shortcuts

Source code for torchrl.modules.inference_server._server

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import threading
import time
from collections.abc import Callable
from concurrent.futures import Future

import torch
from tensordict import lazy_stack
from tensordict.base import TensorDictBase
from torch import nn

from torchrl.modules.inference_server._transport import InferenceTransport


class InferenceServer:
    """Auto-batching inference server.

    Actors submit individual TensorDicts via the *transport* and receive
    results asynchronously. A background worker drains the transport queue,
    batches inputs, runs the model, and fans results back to the callers.

    Args:
        model (nn.Module or Callable): a callable that maps a batched
            TensorDictBase to a batched TensorDictBase (e.g. a
            :class:`~tensordict.nn.TensorDictModule`).
        transport (InferenceTransport): the communication backend.

    Keyword Args:
        max_batch_size (int, optional): upper bound on the number of requests
            processed in a single forward pass. Default: ``64``.
        min_batch_size (int, optional): minimum number of requests to
            accumulate before dispatching a batch. After the first request
            arrives the server keeps draining for up to ``timeout`` seconds
            until at least this many items are collected. ``1`` (default)
            dispatches immediately.
        timeout (float, optional): seconds to wait for new work before
            dispatching a partial batch. Default: ``0.01``.
        collate_fn (Callable, optional): function used to stack a list of
            TensorDicts into a batch. Default: :func:`~tensordict.lazy_stack`.
        device (torch.device or str, optional): device to move batches to
            before calling the model. ``None`` means no device transfer.
        weight_sync: an optional
            :class:`~torchrl.weight_update.WeightSyncScheme` used to receive
            updated model weights from a trainer. When set, the server polls
            for new weights between inference batches.
        weight_sync_model_id (str, optional): the model identifier used when
            initialising the weight sync scheme on the receiver side.
            Default: ``"policy"``.

    Example:
        >>> from tensordict.nn import TensorDictModule
        >>> from torchrl.modules.inference_server import (
        ...     InferenceServer,
        ...     ThreadingTransport,
        ... )
        >>> import torch.nn as nn
        >>> policy = TensorDictModule(
        ...     nn.Linear(4, 2), in_keys=["obs"], out_keys=["act"]
        ... )
        >>> transport = ThreadingTransport()
        >>> server = InferenceServer(policy, transport, max_batch_size=8)
        >>> server.start()
        >>> client = transport.client()
        >>> # client(td) can now be called from any thread
        >>> server.shutdown()
    """

    def __init__(
        self,
        model: nn.Module,
        transport: InferenceTransport,
        *,
        max_batch_size: int = 64,
        min_batch_size: int = 1,
        timeout: float = 0.01,
        collate_fn: Callable | None = None,
        device: torch.device | str | None = None,
        weight_sync=None,
        weight_sync_model_id: str = "policy",
    ):
        self.model = model
        self.transport = transport
        self.max_batch_size = max_batch_size
        self.min_batch_size = min_batch_size
        self.timeout = timeout
        self.collate_fn = collate_fn if collate_fn is not None else lazy_stack
        self.device = torch.device(device) if device is not None else None
        self.weight_sync = weight_sync
        self._weight_sync_model_id = weight_sync_model_id
        self._shutdown_event = threading.Event()
        self._worker: threading.Thread | None = None
        # Protects model access during weight updates
        self._model_lock = threading.Lock()

    # -- lifecycle ------------------------------------------------------------

    def start(self) -> InferenceServer:
        """Start the background inference loop.

        Returns:
            self, for fluent chaining.

        Raises:
            RuntimeError: if the worker thread is already running.
        """
        if self._worker is not None and self._worker.is_alive():
            raise RuntimeError("Server is already running.")
        self._shutdown_event.clear()
        self._worker = threading.Thread(
            target=self._run, daemon=True, name="InferenceServer-worker"
        )
        self._worker.start()
        return self

    def shutdown(self, timeout: float | None = 5.0) -> None:
        """Signal the background worker to stop and wait for it to finish.

        Args:
            timeout (float or None): seconds to wait for the worker thread
                to join. ``None`` waits indefinitely.
        """
        self._shutdown_event.set()
        worker = self._worker
        if worker is not None:
            worker.join(timeout=timeout)
            # Only forget the thread once it has actually stopped. Clearing
            # the reference after a timed-out join would make ``is_alive``
            # report False for a still-running worker and let ``start``
            # spawn a second, concurrent worker thread.
            if not worker.is_alive():
                self._worker = None

    @property
    def is_alive(self) -> bool:
        """Whether the background worker thread is running."""
        return self._worker is not None and self._worker.is_alive()

    # -- background loop ------------------------------------------------------

    def _init_weight_sync(self) -> None:
        """Initialise the weight sync scheme on the receiver (server) side."""
        ws = self.weight_sync
        if ws is None:
            return
        if not ws.initialized_on_receiver:
            ws.init_on_receiver(
                model_id=self._weight_sync_model_id,
                model=self.model,
                worker_idx=0,
            )
        if not ws.synchronized_on_receiver:
            ws.connect(worker_idx=0)

    def _poll_weight_update(self) -> None:
        """Non-blocking check for fresh weights from the trainer."""
        ws = self.weight_sync
        if ws is None:
            return
        # Lock so a weight swap never races with an in-flight forward pass.
        with self._model_lock:
            ws.receive(timeout=0.0)

    @torch.no_grad()
    def _run(self) -> None:
        """Worker loop: drain requests, batch, run the model, fan out results."""
        self._init_weight_sync()
        try:
            while not self._shutdown_event.is_set():
                self._poll_weight_update()
                self.transport.wait_for_work(timeout=self.timeout)
                items, callbacks = self.transport.drain(self.max_batch_size)
                if not items:
                    continue
                # Accumulate up to min_batch_size (or until timeout expires)
                if len(items) < self.min_batch_size:
                    deadline = time.monotonic() + self.timeout
                    while len(items) < self.min_batch_size:
                        remaining = deadline - time.monotonic()
                        if remaining <= 0:
                            break
                        self.transport.wait_for_work(timeout=remaining)
                        more_items, more_cbs = self.transport.drain(
                            self.max_batch_size - len(items)
                        )
                        items.extend(more_items)
                        callbacks.extend(more_cbs)
                try:
                    # Collation and device transfer happen INSIDE the try so
                    # that a failure there is reported to the callers rather
                    # than escaping the loop, killing the worker, and leaving
                    # the already-drained callbacks unresolved forever.
                    batch = self.collate_fn(items)
                    if self.device is not None:
                        batch = batch.to(self.device)
                    with self._model_lock:
                        results = self.model(batch).unbind(0)
                    if len(results) != len(callbacks):
                        raise RuntimeError(
                            f"Model returned {len(results)} results for a "
                            f"batch of {len(callbacks)} inputs."
                        )
                    for cb, res in zip(callbacks, results):
                        self.transport.resolve(cb, res)
                except Exception as exc:
                    # Fan the failure out to every caller in this batch.
                    for cb in callbacks:
                        self.transport.resolve_exception(cb, exc)
        finally:
            self._drain_pending_on_shutdown()

    def _drain_pending_on_shutdown(self) -> None:
        """Resolve all pending requests with an error during shutdown."""
        shutdown_exc = RuntimeError("InferenceServer is shutting down.")
        while True:
            items, callbacks = self.transport.drain(self.max_batch_size)
            if not items:
                break
            for cb in callbacks:
                self.transport.resolve_exception(cb, shutdown_exc)

    # -- context manager ------------------------------------------------------

    def __enter__(self) -> InferenceServer:
        return self.start()

    def __exit__(self, *exc_info) -> None:
        self.shutdown()

    def __del__(self) -> None:
        # Best-effort cleanup if the user forgot to call shutdown().
        if self._worker is not None and self._worker.is_alive():
            self.shutdown(timeout=1.0)
class InferenceClient:
    """Actor-side handle for an :class:`InferenceServer`.

    Wraps a transport's :meth:`~InferenceTransport.submit` so that calling
    ``client(td)`` looks like a regular synchronous policy call, while the
    actual computation is batched on the server.

    Args:
        transport (InferenceTransport): the transport shared with the server.

    Example:
        >>> client = transport.client()
        >>> td_out = client(td_in)  # blocking
        >>> future = client.submit(td_in)  # non-blocking
        >>> td_out = future.result()
    """

    def __init__(self, transport: InferenceTransport):
        self._transport = transport

    def __call__(self, td: TensorDictBase) -> TensorDictBase:
        """Submit a request and block until the result is ready."""
        # Synchronous facade over the async submit path: enqueue, then wait.
        pending = self.submit(td)
        return pending.result()

    def submit(self, td: TensorDictBase) -> Future[TensorDictBase]:
        """Submit a request and return a Future immediately."""
        return self._transport.submit(td)

Docs

Lorem ipsum dolor sit amet, consectetur

View Docs

Tutorials

Lorem ipsum dolor sit amet, consectetur

View Tutorials

Resources

Lorem ipsum dolor sit amet, consectetur

View Resources