Note
Go to the end to download the full example code.
Recurrent training on sequence batches#
Author: Achintya Paningapalli
How TorchRL auto-wires
InitTrackerand the recurrent-state primer when a recurrent policy is detectedHow to sample multi-step slices from a replay buffer with
SliceSamplerHow to train a recurrent policy in full-sequence mode using
set_recurrent_modeHow
set_recurrent_modeis used in collectors (sequential,False) vs. loss modules (recurrent,True), and how to control it across processes and workersWhy
is_initat slice starts is what keeps hidden state from leaking across episode boundaries inside a replay batch
PyTorch v2.0.0
gymnasium[classic_control]
Familiarity with Recurrent DQN is helpful but not required — this tutorial is its multi-step complement
from __future__ import annotations
Overview#
There are two ways to train a recurrent policy in TorchRL, and choosing between them is the first design decision in any recurrent project:
Sequential mode (one step per forward call). The policy reads the previous step’s hidden state from the TensorDict, runs the LSTM for one step, and writes the new hidden state under the
("next", ...)keys. This is the natural mode during collection and is what Recurrent DQN covers in depth.Recurrent mode (full
[B, T]sequence per forward call). The LSTM processes a whole time dimension at once. This is used inside loss / advantage code over replayed trajectory slices, where you want to backprop through several timesteps in a single batched call. Rectangular[B, T]tensors are one option, but TorchRL also accepts the flat, ragged slices returned bySliceSampler: the flat dimension is treated as a packed time axis andis_initpartitions it into independent chunks.
This tutorial focuses on (2): how to collect data, sample multi-step slices, and run a recurrent policy in full-sequence mode without leaking hidden state across episode boundaries.
The key building blocks — most of which the collector now auto-wires for you — are:
LSTMModule(the recurrent policy core),InitTracker(writesis_init=Trueat the start of every trajectory),A
TensorDictPrimerthat seeds the initial recurrent state on reset,SliceSamplerfor trajectory-aware replay,set_recurrent_modefor switching the LSTM between single-step and full-sequence execution.
See Recurrent state lifecycle for the full reference on how hidden state flows through the pipeline.
If you are running this in Google Colab, install the dependencies first:
!pip3 install torchrl
!pip3 install gymnasium
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
from torch import nn
from torchrl.collectors import Collector
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.envs import GymEnv
from torchrl.modules import LSTMModule, set_recurrent_mode
torch.manual_seed(0)
device = torch.device("cpu")
Environment and policy#
We use CartPole-v1 for a small, fast, fully-observed env. A recurrent
policy is overkill for CartPole, but that is precisely the point: it lets
us focus on the sequence-batching machinery without the env being the
bottleneck. The same pattern scales unchanged to partially-observed
environments where memory actually matters.
The policy is a tiny LSTMModule followed by a
linear head that maps the LSTM output to the action logits. We pick
hidden_size=16 so the tutorial finishes quickly on CPU.
OBS_DIM = 4 # CartPole observation: cart pos, cart vel, pole angle, pole vel
N_ACTIONS = 2 # left or right
HIDDEN = 16
def make_policy() -> Seq:
"""Construct the recurrent policy: LSTMModule + linear logits head."""
lstm = LSTMModule(
input_size=OBS_DIM,
hidden_size=HIDDEN,
in_keys=["observation", "rs_h", "rs_c"],
out_keys=["features", ("next", "rs_h"), ("next", "rs_c")],
python_based=True, # avoids cuDNN for vmap / torch.compile compatibility
)
head = Mod(
nn.Linear(HIDDEN, N_ACTIONS),
in_keys=["features"],
out_keys=["logits"],
)
# Deterministic argmax action selection — enough to demonstrate the
# collection / replay / sequence-training plumbing.
chooser = Mod(
lambda logits: logits.argmax(-1, keepdim=False),
in_keys=["logits"],
out_keys=["action"],
)
return Seq(lstm, head, chooser)
policy = make_policy()
policy.eval()
TensorDictSequential(
module=ModuleList(
(0): LSTMModule()
(1): TensorDictModule(
module=Linear(in_features=16, out_features=2, bias=True),
device=cpu,
in_keys=['features'],
out_keys=['logits'])
(2): TensorDictModule(
module=<function make_policy.<locals>.<lambda> at 0x7f2cb5520220>,
device=cpu,
in_keys=['logits'],
out_keys=['action'])
),
device=cpu,
in_keys=['observation', 'rs_h', 'rs_c', 'is_init'],
out_keys=['features', ('next', 'rs_h'), ('next', 'rs_c'), 'logits', 'action'])
Auto-wiring recurrent env transforms#
A recurrent policy needs two env-side transforms to behave correctly:
InitTracker, which writesis_init=Trueat the first step after every reset.A
TensorDictPrimerthat zero-fills the initial("rs_h", "rs_c")recurrent state slots on reset.
Forgetting either of these is the most common source of silent
hidden-state bugs. The collector now detects recurrent submodules in the
policy and auto-appends both transforms when auto_register_policy_transforms=True
is passed. As of v0.15 this becomes the default; for earlier versions you
need to opt in explicitly.
Inspecting a single batch#
Run one rollout and look at what comes back. Three things to notice:
The batch shape is
[T]whereT == frames_per_batch(single env).is_initisTrueat the very first step and immediately after eachdone. These mark trajectory boundaries.("next", "rs_h")/("next", "rs_c")carry the LSTM’s next-step hidden / cell state across timesteps.
data = next(iter(collector))
print("Batch shape:", data.shape)
print("Available keys:", sorted(k for k in data.keys()))
print("is_init shape:", data["is_init"].shape)
print("# trajectory boundaries in batch:", int(data["is_init"].sum().item()))
print(
"Next-step hidden shape:",
data["next", "rs_h"].shape,
"(batch, num_layers, hidden_size)",
)
Batch shape: torch.Size([64])
Available keys: ['action', 'collector', 'done', 'features', 'is_init', 'logits', 'next', 'observation', 'rs_c', 'rs_h', 'terminated', 'truncated']
is_init shape: torch.Size([64, 1])
# trajectory boundaries in batch: 7
Next-step hidden shape: torch.Size([64, 1, 16]) (batch, num_layers, hidden_size)
Replay with SliceSampler#
We store the rollout in a replay buffer and sample fixed-length slices of
trajectories from it. SliceSampler
is trajectory-aware: it uses the boundary information already in the
batch (("next", "done") and the collector-written ("collector",
"traj_ids")) to draw whole sub-trajectories.
Two design choices in this section:
Slices remain *ragged* by default.
SliceSamplerreturns a flat batch of concatenated variable-length slices, not a rectangular[B, T]tensor. The trajectory boundaries inside that flat batch are marked byis_init=True. This avoids padding waste and matches how TorchRL’s recurrent modules already consume input.``pad_output=True`` is available but discouraged. It pads short slices to
slice_lenand writes a("collector", "mask")key for consumers (e.g. mask-aware losses), but the docstring on SliceSampler explicitly steers users toward the ragged path when possible.
SLICE_LEN = 8
NUM_SLICES = 4
rb = TensorDictReplayBuffer(
storage=LazyTensorStorage(max_size=1024, device=device),
sampler=SliceSampler(slice_len=SLICE_LEN, strict_length=False),
batch_size=SLICE_LEN * NUM_SLICES,
)
rb.extend(data)
sample = rb.sample()
is_init_positions = sample["is_init"].squeeze(-1).nonzero().squeeze(-1).tolist()
print("Sampled batch shape:", sample.shape)
print("is_init True positions (slice starts):", is_init_positions)
Sampled batch shape: torch.Size([32])
is_init True positions (slice starts): [0, 8, 16, 24]
What is_init marks#
Each True value in the sampled batch’s is_init flags the first
timestep of a slice (or a trajectory boundary that fell inside a slice
because the underlying episode ended mid-slice). Downstream code — the
recurrent module’s full-sequence forward and GAE / advantage computation —
reads these markers to know where to reset hidden state. Padding masks are a
separate signal: when pad_output=True is used, SliceSampler writes a
("collector", "mask") key for consumers that need to ignore padded
timesteps.
Sequential mode vs. recurrent mode#
set_recurrent_mode has two modes:
set_recurrent_mode(False)(sequential): the LSTM processes one timestep at a time, readingrs_h/rs_cfrom the TensorDict and writing the updated state into("next", "rs_h")/("next", "rs_c"). This is the natural mode during collection: the collector calls the policy once per environment step.set_recurrent_mode(True)(recurrent): the LSTM processes a whole time batch in a singlenn.LSTMcall, usingis_initto reset the hidden state at trajectory boundaries. This is the mode used during training over replayed slices.
How set_recurrent_mode is managed in collectors and loss modules:
During collection, the collector calls the policy inside sequential mode. The default
default_recurrent_mode=FalseonLSTMModulemeans no explicit context manager is required — the LSTM automatically runs step-by-step. If you need to override this (e.g. in a custom collector), wrap the policy call withset_recurrent_mode(False).All built-in TorchRL loss modules wrap their
forwardwithset_recurrent_mode(True)automatically, so you never need to set the mode manually when calling a loss. The same applies to advantage estimators and value estimators.In multi-process / distributed settings (e.g.
MultiaSyncDataCollector), each worker process carries its own thread-local recurrent mode. The mode set in the main process does not propagate to worker processes. Workers always start with the default (sequential,False); the training process sets recurrent mode independently when computing the loss.
Equivalence of sequential and recurrent mode:
Running the LSTM step-by-step under set_recurrent_mode(False) and
running it in a single batch under set_recurrent_mode(True) produce
numerically identical outputs. We verify this now on a short sequence.
# Build a minimal hand-crafted 4-step trajectory so the equivalence is
# unambiguous: one trajectory, is_init only at t=0, zero initial hidden.
T_equiv = 4
demo = TensorDict(
{
"observation": torch.randn(T_equiv, OBS_DIM),
"rs_h": torch.zeros(T_equiv, 1, HIDDEN),
"rs_c": torch.zeros(T_equiv, 1, HIDDEN),
"is_init": torch.tensor([True, False, False, False]),
},
batch_size=[T_equiv],
)
# --- set_recurrent_mode(False): one call per step ---
# The LSTM reads rs_h / rs_c, processes a single timestep, and writes the
# updated state into ("next", "rs_h") / ("next", "rs_c"). We thread the
# hidden state from one step to the next by hand.
outs_seq = []
h = torch.zeros(1, 1, HIDDEN)
c = torch.zeros(1, 1, HIDDEN)
for t in range(T_equiv):
step = demo[t : t + 1].clone()
step["rs_h"] = h
step["rs_c"] = c
with set_recurrent_mode(False):
step_out = policy(step)
outs_seq.append(step_out["features"])
h = step_out["next", "rs_h"]
c = step_out["next", "rs_c"]
features_seq = torch.cat(outs_seq, dim=0)
# --- set_recurrent_mode(True): one batched call ---
with set_recurrent_mode(True):
out_rec = policy(demo.clone())
# They must be numerically identical.
torch.testing.assert_close(features_seq, out_rec["features"], rtol=1e-4, atol=1e-5)
print(
"set_recurrent_mode(False) step-by-step == set_recurrent_mode(True) batched: verified."
)
# Now run the full replay sample in recurrent mode (the training-time pattern).
with set_recurrent_mode(True):
out = policy(sample.clone())
print("Recurrent-mode output shape:", out["features"].shape, "(total_steps, hidden)")
print(
"Recurrent-mode logits shape:",
out["logits"].shape,
"(total_steps, n_actions)",
)
set_recurrent_mode(False) step-by-step == set_recurrent_mode(True) batched: verified.
Recurrent-mode output shape: torch.Size([32, 16]) (total_steps, hidden)
Recurrent-mode logits shape: torch.Size([32, 2]) (total_steps, n_actions)
Why the boundary handling matters: a controlled check#
The forward above already ran the interior-split path on real sampled data. To make the no-leakage guarantee provable rather than merely exercised, we now build a small two-trajectory packed batch by hand, seed the first trajectory with non-zero noise in its incoming hidden, and check that the second trajectory’s outputs match a standalone forward over just that second trajectory.
If hidden state leaked across the boundary, the noisy first half would pollute the second half’s outputs and the comparison would fail.
# Two adjacent slices of length 4 each, packed end-to-end. is_init=True at
# index 0 (first traj start) and index 4 (second traj start).
T_A = T_B = 4
T = T_A + T_B
is_init_packed = torch.zeros(1, T, dtype=torch.bool)
is_init_packed[0, 0] = True
is_init_packed[0, T_A] = True
obs = torch.randn(1, T, OBS_DIM)
# Seed the incoming hidden with noise. The recurrent forward should
# *override* this at every is_init=True position.
noisy_h = torch.randn(1, T, 1, HIDDEN)
noisy_c = torch.randn(1, T, 1, HIDDEN)
packed = TensorDict(
{
"observation": obs,
"rs_h": noisy_h,
"rs_c": noisy_c,
"is_init": is_init_packed,
},
batch_size=[1, T],
)
# Isolated forward over only the second trajectory, with is_init=True at
# step 0 (its real start). If the packed run handles boundaries correctly,
# the packed batch's second-half output must match this isolated output.
is_init_b = torch.zeros(1, T_B, dtype=torch.bool)
is_init_b[0, 0] = True
b_only = TensorDict(
{
"observation": obs[:, T_A:].clone(),
"rs_h": noisy_h[:, T_A:].clone(), # same noise — must be overridden
"rs_c": noisy_c[:, T_A:].clone(),
"is_init": is_init_b,
},
batch_size=[1, T_B],
)
with set_recurrent_mode(True):
packed_out = policy(packed)
b_out = policy(b_only)
# Trajectory B's features inside the packed batch == trajectory B alone.
torch.testing.assert_close(
packed_out["features"][:, T_A:], b_out["features"], rtol=1e-5, atol=1e-6
)
print("Hidden-state isolation across is_init boundary: verified.")
Hidden-state isolation across is_init boundary: verified.
A tiny gradient-path smoke test#
We close with a minimal supervised update on top of the same
infrastructure. We use the simplest possible objective — match the
action logits to a constant target — purely to exercise the
set_recurrent_mode(True) + replay-buffer + LSTM gradient path. In a real
recurrent training job, this block is where you would compute returns or
advantages and call the objective module, still under the same recurrent
context manager.
trainable_policy = make_policy()
optimizer = torch.optim.Adam(trainable_policy.parameters(), lr=3e-4)
target_logits = torch.zeros(N_ACTIONS)
target_logits[0] = 1.0 # arbitrary supervised target
losses = []
for _step in range(4):
sample = rb.sample()
with set_recurrent_mode(True):
out = trainable_policy(sample.clone())
loss = (out["logits"] - target_logits.expand_as(out["logits"])).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses.append(loss.item())
collector.shutdown()
print("Training loss trajectory:", [round(v, 4) for v in losses])
Training loss trajectory: [0.4146, 0.4089, 0.4061, 0.4059]
Conclusion#
You have built a recurrent training pipeline that:
Lets the collector auto-wire
InitTrackerand the recurrent-state primer for you, removing two manual steps that used to be a frequent source of silent bugs.Samples multi-step trajectory slices from a replay buffer with
SliceSampler.Runs the LSTM in full-sequence mode under
set_recurrent_mode(True)and verified by construction that hidden state does not leak across trajectory boundaries inside a sampled batch.
This is the canonical TorchRL pattern for recurrent / sequence-based RL. It composes cleanly with every loss module, every advantage estimator, and every replay-buffer extension in the library.
Further reading#
Recurrent DQN — the single-step / collection-time complement to this tutorial.
examples/replay-buffers/recurrent_slice_sampler_pipeline.py— a minimal, runnable end-to-end version of this pipeline (GRU policy, the collector writing directly into the buffer in a background thread, andSliceSamplerauto-detecting the trajectory key).Recurrent state lifecycle — full reference on how hidden state flows from collection through replay to the loss, and what
is_initcontrols along the way.Collector Internals — the per-step rollout flow,
_carriersemantics, and how the device-cast flags interact with recurrent state.Glossary — short definitions of
is_init,TensorDictPrimer,recurrent mode,set_keys, and other shorthand that appears throughout the recurrent code paths.
Total running time of the script: (0 minutes 0.168 seconds)