Inference Server¶
The inference server provides auto-batching model serving for RL actors. Multiple actors submit individual TensorDicts; the server transparently batches them, runs a single model forward pass, and routes results back.
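The batching mechanism can be sketched with the standard library alone. The ToyInferenceServer below is a hypothetical illustration (not torchrl's implementation): a serving thread drains a request queue, groups pending items into one batched "forward pass", and resolves each caller's Future with its own slice of the result.

```python
import queue
import threading
from concurrent.futures import Future

class ToyInferenceServer:
    """Toy auto-batching server: one worker thread, one request queue."""

    def __init__(self, model_fn, max_batch_size=32):
        self.model_fn = model_fn            # maps a list of inputs to a list of outputs
        self.max_batch_size = max_batch_size
        self._requests = queue.Queue()
        self._stop = object()               # sentinel that ends the serving loop

    def start(self):
        self._thread = threading.Thread(target=self._serve, daemon=True)
        self._thread.start()

    def submit(self, x):
        """Actor side: enqueue one input, get a Future for its output."""
        fut = Future()
        self._requests.put((x, fut))
        return fut

    def shutdown(self):
        self._requests.put(self._stop)
        self._thread.join()

    def _serve(self):
        stopped = False
        while not stopped:
            item = self._requests.get()
            if item is self._stop:
                return
            batch = [item]
            # Greedily drain whatever else is already queued, up to max_batch_size.
            while len(batch) < self.max_batch_size:
                try:
                    nxt = self._requests.get_nowait()
                except queue.Empty:
                    break
                if nxt is self._stop:
                    stopped = True
                    break
                batch.append(nxt)
            outputs = self.model_fn([x for x, _ in batch])  # one batched "forward pass"
            for (_, fut), out in zip(batch, outputs):
                fut.set_result(out)         # route each result back to its caller
```

Each caller blocks only on its own Future, so actors never see each other's data even though the model ran once per batch.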
Core API¶
- InferenceServer: auto-batching inference server.
- Actor-side client handle for an InferenceServer, obtained via transport.client().
- Abstract base class for inference server transport backends.
Transport Backends¶
- ThreadingTransport: in-process transport for actors that are threads.
- Lock-free, in-process transport using per-env slots.
- Cross-process transport.
- Transport using Ray queues for distributed inference.
- Transport using Monarch for distributed inference on GPU clusters.
Usage¶
The simplest setup uses ThreadingTransport for actors that are
threads in the same process:
from tensordict.nn import TensorDictModule
from torchrl.modules.inference_server import (
    InferenceServer,
    ThreadingTransport,
)
import torch.nn as nn
import concurrent.futures

policy = TensorDictModule(
    nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 4)),
    in_keys=["observation"],
    out_keys=["action"],
)

transport = ThreadingTransport()
server = InferenceServer(policy, transport, max_batch_size=32)
server.start()

client = transport.client()

# actor threads call client(td) -- batched automatically
with concurrent.futures.ThreadPoolExecutor(16) as pool:
    ...

server.shutdown()
Weight Synchronisation¶
The server integrates with WeightSyncScheme
to receive updated model weights from a trainer between inference batches:
from torchrl.weight_update import SharedMemWeightSyncScheme

weight_sync = SharedMemWeightSyncScheme()

# Initialise on the trainer (sender) side first
weight_sync.init_on_sender(model=training_model, ...)

server = InferenceServer(
    model=inference_model,
    transport=ThreadingTransport(),
    weight_sync=weight_sync,
)
server.start()

# Training loop
for batch in dataloader:
    optimizer.zero_grad()
    loss = loss_fn(training_model(batch))
    loss.backward()
    optimizer.step()
    weight_sync.send(model=training_model)  # pushed to server
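The internals of SharedMemWeightSyncScheme are out of scope here, but the hand-off can be sketched in pure Python. ToyWeightSync below is a hypothetical stand-in: the trainer publishes a weight snapshot under a lock, and the server polls for a newer version between inference batches.

```python
import threading

class ToyWeightSync:
    """Toy versioned weight hand-off between a trainer and a server."""

    def __init__(self):
        self._lock = threading.Lock()
        self._weights = None
        self._version = 0

    def send(self, weights):
        """Trainer side: publish a new snapshot."""
        with self._lock:
            self._weights = dict(weights)  # copy so the trainer can keep mutating
            self._version += 1

    def receive(self, current_version):
        """Server side: fetch the snapshot only if it changed since current_version."""
        with self._lock:
            if self._version == current_version:
                return None, current_version   # nothing new; keep serving old weights
            return self._weights, self._version
```

Because the server only checks between batches, a forward pass never observes a half-updated model.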
Integration with Collectors¶
The easiest way to use the inference server with RL data collection is
through AsyncBatchedCollector, which
creates the server, transport, and env pool automatically:
from torchrl.collectors import AsyncBatchedCollector
from torchrl.envs import GymEnv

collector = AsyncBatchedCollector(
    create_env_fn=[lambda: GymEnv("CartPole-v1")] * 8,
    policy=my_policy,
    frames_per_batch=200,
    total_frames=10_000,
    max_batch_size=8,
)

for data in collector:
    # train on data ...
    pass

collector.shutdown()
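For the configuration above, the batch accounting works out as follows (the numbers come from the snippet; the even-split-across-envs assumption is mine):

```python
num_envs = 8
frames_per_batch = 200
total_frames = 10_000

iterations = total_frames // frames_per_batch   # the for-loop runs 50 times
frames_per_env = frames_per_batch // num_envs   # 25 frames contributed per env per batch
print(iterations, frames_per_env)               # 50 25
```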