TensorDict in distributed settings¶
TensorDict integrates with torch.distributed to send, receive,
broadcast, reduce, gather and scatter collections of tensors with a single
method call. It also supports shared-storage workflows where nodes access
the same memory-mapped files on a shared filesystem.
Prerequisites¶
All collective and point-to-point APIs below require an initialized process group
(torch.distributed.init_process_group()). Refer to the
PyTorch distributed documentation
for details.
import torch.distributed as dist
dist.init_process_group("nccl", rank=rank, world_size=world_size)
consolidate() and fast serialization¶
Several distributed primitives (broadcast, all_gather, scatter,
init_remote / from_remote_init, and consolidated send / recv) rely on
consolidate(), which packs every leaf tensor
into a single contiguous storage buffer. This turns an arbitrarily nested
TensorDict into one message on the wire:
td = TensorDict(a=torch.randn(3), b={"c": torch.randn(3)}, batch_size=[3])
td_c = td.consolidate(metadata=True)
# td_c._consolidated["storage"] – a single flat uint8 tensor
# td_c._consolidated["metadata"] – a dict describing how to reconstruct the tree
Pass metadata=True so that the receiving side can reconstruct the full
tree structure.
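The packing step can be illustrated with plain torch. The pack / unpack helpers below are illustrative stand-ins, not part of the tensordict API: they flatten same-dtype leaves into one uint8 buffer plus an offset table, which is roughly what the consolidated storage and metadata hold (the real consolidate() additionally handles nesting, mixed dtypes and alignment):

```python
import torch

def pack(tensors: dict) -> tuple[torch.Tensor, dict]:
    # One flat uint8 buffer large enough for every leaf.
    total = sum(t.numel() * t.element_size() for t in tensors.values())
    storage = torch.empty(total, dtype=torch.uint8)
    metadata, offset = {}, 0
    for name, t in tensors.items():
        n = t.numel() * t.element_size()
        storage[offset:offset + n] = t.contiguous().view(torch.uint8).view(-1)
        metadata[name] = {"dtype": t.dtype, "shape": tuple(t.shape),
                          "start": offset, "stop": offset + n}
        offset += n
    return storage, metadata

def unpack(storage: torch.Tensor, metadata: dict) -> dict:
    # Each leaf is rebuilt as a view into the shared flat buffer,
    # so sending `storage` + `metadata` is enough to rebuild the tree.
    return {name: storage[m["start"]:m["stop"]].view(m["dtype"]).view(m["shape"])
            for name, m in metadata.items()}

leaves = {"a": torch.randn(3), "b.c": torch.randn(3)}
storage, metadata = pack(leaves)
rebuilt = unpack(storage, metadata)
```

Because every rebuilt leaf is a view into the same flat buffer, a single message carrying that buffer (plus the small metadata dict) transports the whole structure.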
Point-to-point: send / recv¶
The simplest pattern sends a TensorDict from one rank to another. The receiver must already hold a TensorDict with matching structure so that incoming tensors can be written in-place:
# rank 0
td = TensorDict(a=torch.randn(4), b=torch.randn(4, 3), batch_size=[4])
td.send(dst=1)
# rank 1
td = TensorDict(a=torch.zeros(4), b=torch.zeros(4, 3), batch_size=[4])
td.recv(src=0)
assert (td != 0).all()
Consolidated mode sends the entire TensorDict as a single tensor message instead of one message per leaf. This is more efficient when the TensorDict has many small tensors:
# rank 0 – sender
td.send(dst=1, consolidated=True)
# rank 1 – receiver (must already be consolidated)
td.recv(src=0, consolidated=True)
Async point-to-point: isend / irecv¶
Non-blocking variants let you overlap communication with computation:
# rank 0
td.isend(dst=1)
# rank 1
futures = td.irecv(src=0, return_premature=True)
# ... do other work ...
for f in futures:
    f.wait()
Cold-start initialization: init_remote / from_remote_init¶
When the receiving rank does not know the structure of the TensorDict in
advance, init_remote sends both the metadata and the content so that
the receiver can reconstruct the full object from scratch:
# rank 0 – sender
td = TensorDict(
    {"obs": torch.randn(2, 84), "reward": torch.randn(2, 1)},
    batch_size=[2],
)
td.init_remote(dst=1)
# rank 1 – receiver
td = TensorDict.from_remote_init(src=0)
print(td) # fully reconstructed, including nested structure
This is handy during setup / warm-up: the first call uses init_remote
/ from_remote_init to establish the schema, and subsequent transfers
use the faster send(consolidated=True) / recv(consolidated=True)
path.
Two transport modes are available:
Point-to-point (default): only sender and receiver participate.
Broadcast (use_broadcast=True): delegates to broadcast(). All ranks in the group must participate.
# all ranks
if rank == 0:
    td.init_remote(use_broadcast=True)
else:
    td = TensorDict.from_remote_init(src=0, use_broadcast=True)
@tensorclass objects are preserved across the wire: the receiving side
will automatically reconstruct the original tensorclass type.
See init_remote() and
from_remote_init() for the full API.
Broadcast¶
broadcast() sends a consolidated
TensorDict from one rank to all ranks:
if rank == src:
    td = TensorDict(a=torch.randn(3), batch_size=[3])
else:
    td = TensorDict()
td = td.broadcast(src=src)
# all ranks now hold the same td
Internally this broadcasts the metadata via
dist.broadcast_object_list followed by the storage via
dist.broadcast.
See broadcast().
All-reduce¶
all_reduce() reduces every leaf tensor
across all ranks in-place:
import torch.distributed as dist
td = TensorDict(a=torch.ones(3) * rank, batch_size=[3])
td.all_reduce(op=dist.ReduceOp.SUM)
# td["a"] is now the sum across all ranks
Supports async_op=True to return a list of futures.
See all_reduce().
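Conceptually, td.all_reduce applies dist.all_reduce to every leaf in turn. A self-contained sketch of that per-leaf loop, using a plain dict of tensors and a single-rank gloo group (world size 1, so the SUM reduction leaves values unchanged; the file:// rendezvous avoids any network setup):

```python
import os
import tempfile
import torch
import torch.distributed as dist

# Single-process gloo group via a file:// rendezvous (no network needed).
init_file = os.path.join(tempfile.mkdtemp(), "rendezvous")
dist.init_process_group("gloo", rank=0, world_size=1,
                        init_method=f"file://{init_file}")

# A flat stand-in for a TensorDict's leaves.
leaves = {"a": torch.ones(3), "b.c": torch.arange(3.0)}

# What td.all_reduce(op=dist.ReduceOp.SUM) does per leaf, conceptually:
# an in-place collective reduction on each tensor.
for t in leaves.values():
    dist.all_reduce(t, op=dist.ReduceOp.SUM)

# With world_size == 1 the sum is a no-op, so every tensor is unchanged.
dist.destroy_process_group()
```

With more ranks, each leaf would end up holding the element-wise sum across all ranks, which is the basis of patterns such as gradient averaging.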
Reduce¶
reduce() is like all_reduce but only
the destination rank receives the result:
td.reduce(dst=0, op=dist.ReduceOp.SUM)
See reduce().
All-gather¶
all_gather() gathers TensorDict instances
from every rank and returns a list (one per rank):
local_td = TensorDict(a=torch.randn(3) + rank, batch_size=[3])
all_tds = local_td.all_gather()
# len(all_tds) == world_size
Uses consolidated transport: each rank consolidates its TensorDict,
metadata is gathered with dist.all_gather_object, and storage buffers
are gathered with dist.all_gather.
See all_gather().
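The three steps above can be sketched with plain collectives on a single-rank gloo group (world size 1, so each gathered list has one entry). The flat-buffer "consolidation" here is a simple byte view, not the real tensordict layout:

```python
import os
import tempfile
import torch
import torch.distributed as dist

init_file = os.path.join(tempfile.mkdtemp(), "rendezvous")
dist.init_process_group("gloo", rank=0, world_size=1,
                        init_method=f"file://{init_file}")

# 1. "Consolidate": pack the local leaves into one flat uint8 buffer.
local = {"a": torch.randn(3)}
storage = local["a"].view(torch.uint8).view(-1).clone()
meta = {"a": {"dtype": torch.float32, "shape": (3,), "nbytes": storage.numel()}}

# 2. Gather every rank's metadata as Python objects.
all_meta = [None] * dist.get_world_size()
dist.all_gather_object(all_meta, meta)

# 3. Gather every rank's storage buffer.
all_storage = [torch.empty(m["a"]["nbytes"], dtype=torch.uint8) for m in all_meta]
dist.all_gather(all_storage, storage)

# Rebuild: one dict per rank, leaves viewed out of the gathered buffers.
gathered = [{"a": buf.view(m["a"]["dtype"]).view(m["a"]["shape"])}
            for buf, m in zip(all_storage, all_meta)]
dist.destroy_process_group()
```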
Scatter¶
scatter() distributes different
TensorDict instances from one rank to all other ranks:
if rank == src:
    tds = [TensorDict(a=torch.randn(3) + i, batch_size=[3]) for i in range(world_size)]
else:
    tds = None
td = TensorDict().scatter(src=src, tensordicts=tds)
# each rank receives its own td
See scatter().
Quick reference¶
| Method | Transport | Participants | Use case |
|---|---|---|---|
| send() / recv() | point-to-point | 2 ranks | Steady-state data transfer (supports consolidated=True) |
| isend() / irecv() | point-to-point (async) | 2 ranks | Overlapping communication and computation |
| init_remote() / from_remote_init() | point-to-point or broadcast | 2 ranks (or all) | Cold-start: receiver doesn't know the schema |
| broadcast() | collective | all ranks | One-to-all (same data) |
| scatter() | collective | all ranks | One-to-all (different data per rank) |
| all_gather() | collective | all ranks | All-to-all gather |
| all_reduce() | collective | all ranks | In-place reduce (e.g. gradient averaging) |
| reduce() | collective | all ranks | Reduce to single rank |
High-performance streaming with UCXX¶
The torch.distributed primitives above operate within a process group
and are ideal for collective communication patterns. For persistent
point-to-point streaming — such as a data-producer on one node
continuously feeding batches to a trainer on another —
TensorDictPipe provides a higher-performance
alternative built on UCXX, the
Python bindings for UCX.
Key advantages over torch.distributed send/recv:
No process group required — connect directly by IP and port.
Zero-allocation steady-state — after the first send establishes the schema, subsequent sends transmit only the raw storage buffer with no metadata, no allocation, and no parsing on the receiver.
InfiniBand / RDMA — UCX auto-selects the fastest available transport (TCP, InfiniBand verbs, RoCE).
GPUDirect RDMA — when tensors live on CUDA, the consolidated buffer is sent directly from GPU memory over RDMA, matching NCCL throughput without CPU staging.
Async-native — built on asyncio, with synchronous wrappers for convenience.
Installing UCXX¶
UCXX requires a Linux host with UCX system libraries. Install via pip or conda:
# pip (CUDA 12)
pip install ucxx-cu12
# pip (CUDA 13)
pip install ucxx-cu13
# conda
conda install -c conda-forge -c rapidsai ucxx
For InfiniBand support the host needs rdma-core (or MOFED 5.0+) and
a Linux kernel ≥ 5.6. For GPUDirect RDMA, the nvidia_peermem kernel
module must be loaded. Verify with:
# check IB devices
ibstat
# check GPUDirect RDMA
lsmod | grep nvidia_peermem
UCX picks up transports automatically. You can control transport selection with environment variables:
# let UCX auto-select (recommended)
export UCX_TLS=all
export UCX_NET_DEVICES=all
# restrict to TCP only (useful for debugging)
export UCX_TLS=tcp
TensorDictPipe: the two-phase protocol¶
TensorDictPipe wraps a UCXX endpoint into a
typed channel that understands TensorDict’s consolidated layout. It
implements a two-phase protocol:
Handshake (first send) — the sender consolidates the TensorDict, transmits the consolidation metadata (keys, dtypes, shapes, byte offsets — serialized as JSON) followed by the raw storage buffer. The receiver allocates a single flat tensor and rebuilds a TensorDict whose leaves are views into that buffer.
Steady-state (subsequent sends with same schema) — the sender sends a one-byte “same schema” flag followed by just the raw buffer. The receiver overwrites its pre-allocated storage in-place. No allocation, no metadata parsing, no deserialization.
When the schema changes (e.g. different keys or shapes), the pipe automatically falls back to the handshake phase.
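The two phases can be sketched as a simple framing scheme. This is illustrative only, not the actual UCXX wire format: a handshake frame carries a flag byte, JSON metadata and the raw buffer, while a steady-state frame carries a flag byte plus the raw buffer alone (the sketch hardcodes float32 for brevity):

```python
import json
import torch

SCHEMA, SAME = b"\x00", b"\x01"  # illustrative 1-byte frame flags

def encode(t: torch.Tensor, schema_known: bool) -> bytes:
    buf = bytes(t.contiguous().view(torch.uint8).view(-1).tolist())
    if schema_known:
        return SAME + buf  # steady-state: flag + raw storage only
    meta = json.dumps({"dtype": "float32", "shape": list(t.shape)}).encode()
    return SCHEMA + len(meta).to_bytes(4, "little") + meta + buf

def decode(frame: bytes, prealloc=None) -> torch.Tensor:
    if frame[:1] == SAME:
        # Steady-state: overwrite the pre-allocated storage in place,
        # no allocation and no metadata parsing.
        raw = torch.tensor(list(frame[1:]), dtype=torch.uint8)
        prealloc.view(torch.uint8).view(-1).copy_(raw)
        return prealloc
    # Handshake: parse metadata, allocate a flat buffer, view it as the tensor.
    n = int.from_bytes(frame[1:5], "little")
    meta = json.loads(frame[5:5 + n].decode())
    raw = torch.tensor(list(frame[5 + n:]), dtype=torch.uint8)
    return raw.view(getattr(torch, meta["dtype"])).view(meta["shape"])

t = torch.randn(4)
rx = decode(encode(t, schema_known=False))          # handshake: allocates
rx2 = decode(encode(2 * t, schema_known=True), rx)  # steady-state: reuses rx
```

Note that the steady-state frame for a 4-element float32 tensor is just 17 bytes (1 flag + 16 payload), which is why repeated sends with a stable schema are so cheap.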
Establishing a connection¶
One side listens, the other connects:
import asyncio
from tensordict._ucxx import TensorDictPipe
# --- node A (receiver) ---
async def receiver():
    pipe = await TensorDictPipe.listen(port=13337)
    # first receive: handshake — allocates buffer, builds TensorDict
    td = await pipe.arecv()
    print(td)  # TensorDict with the sender's keys/shapes/dtypes
    # steady-state: zero-alloc receive into the same buffer
    for _ in range(100):
        td = await pipe.arecv(td)
        train(td)
    await pipe.aclose()
# --- node B (sender) ---
async def sender():
    pipe = await TensorDictPipe.connect("10.0.1.1", port=13337)
    for batch in dataloader:
        td = TensorDict({"obs": batch[0], "label": batch[1]}, batch_size=[B])
        await pipe.asend(td)
    await pipe.aclose()
The pipe also supports synchronous usage and context managers:
# synchronous
pipe.send(td)
td = pipe.recv()
# context manager
async with await TensorDictPipe.listen(13337) as pipe:
    td = await pipe.arecv()
Async iteration¶
TensorDictPipe implements the async iterator protocol, so you can
consume incoming TensorDicts with async for:
pipe = await TensorDictPipe.listen(13337)
async for td in pipe:
    train(td)
The iterator yields the received TensorDict on each iteration, automatically reusing the buffer after the first handshake.
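The iterator behavior can be mimicked in a few lines: __aiter__ returns the object itself and __anext__ awaits the next receive. The ToyPipe below is a stand-in that replays a fixed queue rather than reading from a network endpoint:

```python
import asyncio

class ToyPipe:
    """Minimal async-iterable pipe that replays a queue of items."""

    def __init__(self, items):
        self._queue = asyncio.Queue()
        for it in items:
            self._queue.put_nowait(it)

    async def arecv(self):
        # Stand-in for the real network receive.
        return await self._queue.get()

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self._queue.empty():
            raise StopAsyncIteration
        return await self.arecv()

async def main():
    received = []
    async for item in ToyPipe(["batch0", "batch1"]):
        received.append(item)
    return received

out = asyncio.run(main())
```

The real pipe differs in one important way: instead of yielding fresh objects, it reuses the buffer allocated by the first handshake for every subsequent item.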
CPU and CUDA transfers¶
The pipe is device-aware. When the sender’s TensorDict contains CUDA tensors, the consolidated buffer lives on GPU and is transferred directly over RDMA (GPUDirect) without touching host memory. The receiver automatically allocates its buffer on the same device:
# sender — data on GPU
td = TensorDict({"x": torch.randn(1024, 1024, device="cuda")}, batch_size=[1024])
await pipe.asend(td) # sends from GPU memory
# receiver — automatically receives into GPU memory
td = await pipe.arecv()
print(td.device) # cuda:0
You can override the receiver’s device by passing device=:
td = await pipe.arecv(device="cpu") # force receive onto CPU
Unified API with send / recv¶
The send() and
recv() methods accept a
TensorDictPipe in place of a rank integer:
# works with torch.distributed
td.send(dst=1)
td.recv(src=0)
# works with UCXX — same API, different transport
td.send(dst=pipe)
td.recv(src=pipe)
# async variants
await td.asend(pipe)
await td.arecv(pipe)
This makes it easy to switch between torch.distributed and UCXX
without changing application logic.
One-to-many: TensorDictServer¶
TensorDictServer accepts multiple incoming
connections, yielding a separate TensorDictPipe for each client:
from tensordict._ucxx import TensorDictServer
async def handle_client(pipe):
    async for td in pipe:
        result = model(td)
        await pipe.asend(result)

async with TensorDictServer(port=13337) as server:
    await server.serve(handle_client)
Or using the async iterator interface:
async with TensorDictServer(port=13337) as server:
    async for pipe in server:
        asyncio.ensure_future(handle_client(pipe))
Benchmarks¶
The numbers below were measured on a two-node cluster with NVIDIA H200
GPUs and 400 Gb/s InfiniBand (nvidia_peermem loaded for GPUDirect
RDMA). Each transfer is a TensorDict whose leaves are packed into a
single contiguous float32 buffer of the indicated size. Latencies
are averages over 50 steady-state iterations after warm-up (UCXX “first
send” includes the one-off metadata handshake). The benchmark script is
at benchmarks/distributed/bench_transport.py.
CPU — gloo (TCP) vs UCXX (InfiniBand)
| Size | Method | Latency | Throughput |
|---|---|---|---|
| 1 KB | gloo | 900 µs | |
| 1 KB | gloo | 212 µs | |
| 1 KB | UCXX pipe (steady-state) | 105 µs | |
| 1 MB | gloo | 963 µs | 2.18 GB/s |
| 1 MB | gloo | 903 µs | 2.32 GB/s |
| 1 MB | UCXX pipe (steady-state) | 682 µs | 3.08 GB/s |
| 100 MB | gloo | 87 ms | 2.40 GB/s |
| 100 MB | gloo | 87 ms | 2.42 GB/s |
| 100 MB | gloo | 159 ms | 1.32 GB/s |
| 100 MB | UCXX pipe (steady-state) | 37 ms | 5.71 GB/s |
UCXX over InfiniBand is 2.4× faster than the best gloo method for large CPU transfers.
CUDA — NCCL vs UCXX (GPUDirect RDMA)
| Size | Method | Latency | Throughput |
|---|---|---|---|
| 1 KB | NCCL | 231 µs | |
| 1 KB | UCXX pipe (steady-state) | 138 µs | |
| 1 MB | NCCL | 238 µs | 8.81 GB/s |
| 1 MB | UCXX pipe (steady-state) | 231 µs | 9.08 GB/s |
| 100 MB | NCCL | 7.50 ms | 27.95 GB/s |
| 100 MB | UCXX pipe (steady-state) | 7.47 ms | 28.08 GB/s |
UCXX matches NCCL throughput for large CUDA transfers and is 40% faster for small messages (138 µs vs 231 µs at 1 KB) due to lower per-message overhead. Both backends use GPUDirect RDMA — the consolidated storage stays on GPU and goes directly over InfiniBand without CPU staging.