Inference Server

The inference server provides auto-batching model serving for RL actors. Multiple actors submit individual TensorDicts; the server transparently batches them, runs a single model forward pass, and routes results back.
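The mechanism can be illustrated with a minimal, framework-free sketch (plain Python, not the torchrl implementation): actors put (request_id, input) pairs on a shared queue, the server drains up to max_batch_size of them, evaluates the model once on the whole batch, and routes each result back to its caller by id.

```python
import queue
import threading

def serve(model, requests, responses, max_batch_size, stop):
    """Drain up to max_batch_size pending requests, run one batched
    model call, and route each result back to its caller by id."""
    while not stop.is_set():
        batch = []
        try:
            batch.append(requests.get(timeout=0.05))
        except queue.Empty:
            continue
        while len(batch) < max_batch_size:
            try:
                batch.append(requests.get_nowait())
            except queue.Empty:
                break
        ids, inputs = zip(*batch)
        outputs = model(list(inputs))  # single forward pass for the batch
        for rid, out in zip(ids, outputs):
            responses[rid].put(out)

# toy "model": one batched call that doubles every input
model = lambda xs: [2 * x for x in xs]

requests = queue.Queue()
responses = {i: queue.Queue() for i in range(4)}
stop = threading.Event()
server = threading.Thread(
    target=serve, args=(model, requests, responses, 8, stop)
)
server.start()

# four "actors" submit individual inputs; the server batches them
for i in range(4):
    requests.put((i, i + 1))
results = [responses[i].get() for i in range(4)]
stop.set()
server.join()
print(results)  # [2, 4, 6, 8]
```

The real server applies the same pattern to TensorDicts (stacking them into one batched TensorDict before the forward pass) and the transports below replace the plain queues with thread-, process-, or cluster-appropriate channels.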

Core API

InferenceServer(model, transport, *[, ...])

Auto-batching inference server.

InferenceClient(transport)

Actor-side handle for an InferenceServer.

InferenceTransport()

Abstract base class for inference server transport backends.

Transport Backends

ThreadingTransport()

In-process transport for actors that are threads.

SlotTransport(num_slots, *[, preallocate])

Lock-free, in-process transport using per-env slots.

MPTransport([ctx])

Cross-process transport using multiprocessing queues.

RayTransport(*[, max_queue_size])

Transport using Ray queues for distributed inference.

MonarchTransport(*[, max_queue_size])

Transport using Monarch for distributed inference on GPU clusters.
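As an illustration of the per-slot idea behind SlotTransport (a conceptual sketch, not the library's implementation): each env owns a fixed slot index, writes its request into a preallocated cell at that index, and raises a flag; the server scans flagged slots, runs one batched call, and writes each result back into the same slot, so no queue or lock is shared between envs.

```python
import threading

NUM_SLOTS = 4
requests = [None] * NUM_SLOTS  # one preallocated cell per env
results = [None] * NUM_SLOTS
ready = [threading.Event() for _ in range(NUM_SLOTS)]  # env -> server
done = [threading.Event() for _ in range(NUM_SLOTS)]   # server -> env

def env_step(slot, obs):
    """Each env writes only to its own slot, so envs never contend."""
    requests[slot] = obs
    ready[slot].set()
    done[slot].wait()
    done[slot].clear()
    return results[slot]

def server_pass(model):
    """Gather every flagged slot, run one batched call, scatter results."""
    slots = [i for i in range(NUM_SLOTS) if ready[i].is_set()]
    outs = model([requests[i] for i in slots])
    for i, out in zip(slots, outs):
        results[i] = out
        ready[i].clear()
        done[i].set()

model = lambda xs: [x + 10 for x in xs]  # toy batched policy

returned = [None] * NUM_SLOTS
def run_env(slot):
    returned[slot] = env_step(slot, obs=slot)

threads = [threading.Thread(target=run_env, args=(i,)) for i in range(NUM_SLOTS)]
for t in threads:
    t.start()
while any(t.is_alive() for t in threads):
    server_pass(model)
for t in threads:
    t.join()
print(returned)  # [10, 11, 12, 13]
```

Because slot indices are fixed and preallocated, the server can scatter results without any routing table; the queue-based transports trade that rigidity for dynamic actor counts.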

Usage

The simplest setup uses ThreadingTransport for actors that are threads in the same process:

import concurrent.futures

import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.modules.inference_server import (
    InferenceServer,
    ThreadingTransport,
)

policy = TensorDictModule(
    nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 4)),
    in_keys=["observation"],
    out_keys=["action"],
)

transport = ThreadingTransport()
server = InferenceServer(policy, transport, max_batch_size=32)
server.start()
client = transport.client()

# actor threads call client(td) -- batched automatically
def act(_):
    td = TensorDict({"observation": torch.randn(8)}, batch_size=[])
    return client(td)  # blocks until the batched result is routed back

with concurrent.futures.ThreadPoolExecutor(16) as pool:
    results = list(pool.map(act, range(64)))

server.shutdown()

Weight Synchronisation

The server integrates with WeightSyncScheme to receive updated model weights from a trainer between inference batches:

from torchrl.weight_update import SharedMemWeightSyncScheme

weight_sync = SharedMemWeightSyncScheme()
# Initialise on the trainer (sender) side first
weight_sync.init_on_sender(model=training_model, ...)

server = InferenceServer(
    model=inference_model,
    transport=ThreadingTransport(),
    weight_sync=weight_sync,
)
server.start()

# Training loop
for batch in dataloader:
    optimizer.zero_grad()
    loss = loss_fn(training_model(batch))
    loss.backward()
    optimizer.step()
    weight_sync.send(model=training_model)  # pushed to server
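The contract can be sketched without torchrl (a plain-Python illustration, not SharedMemWeightSyncScheme itself): the trainer publishes weights together with a version counter, and the server polls the counter between inference batches, swapping in new weights only at batch boundaries so a forward pass never sees a half-updated model.

```python
import threading

class WeightMailbox:
    """Single-slot mailbox: the trainer overwrites, the server polls a version."""
    def __init__(self, weights):
        self._lock = threading.Lock()
        self._weights = weights
        self._version = 0

    def send(self, weights):  # trainer side, after each optimizer step
        with self._lock:
            self._weights = weights
            self._version += 1

    def receive(self, known_version):  # server side, between batches
        with self._lock:
            if self._version == known_version:
                return None, known_version  # nothing new
            return self._weights, self._version

mailbox = WeightMailbox(weights={"w": 0.0})

# trainer pushes two updates before the server next polls
mailbox.send({"w": 0.5})
mailbox.send({"w": 0.9})

# the server only ever sees the latest published weights
weights, version = mailbox.receive(known_version=0)
print(weights, version)  # {'w': 0.9} 2
```

Note that intermediate versions are dropped: the server always jumps to the newest weights, which is the desired behavior when inference lags behind training.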

Integration with Collectors

The easiest way to use the inference server with RL data collection is through AsyncBatchedCollector, which creates the server, transport, and env pool automatically:

from torchrl.collectors import AsyncBatchedCollector
from torchrl.envs import GymEnv

collector = AsyncBatchedCollector(
    create_env_fn=[lambda: GymEnv("CartPole-v1")] * 8,
    policy=my_policy,
    frames_per_batch=200,
    total_frames=10_000,
    max_batch_size=8,
)

for data in collector:
    # train on data ...
    pass

collector.shutdown()
