Recurrent modules#
TorchRL recurrent modules wrap PyTorch RNNs in TensorDict-aware modules.
LSTMModule and GRUModule read observations and recurrent
state entries from a TensorDict, write features and
("next", ...) recurrent states back to it, and use the is_init key to
reset hidden states at trajectory boundaries.
The recurrent modules are designed for the contiguous trajectory layout
described in "Data layout: contiguous trajectories". Replay buffers and samplers can return flat
1-D slices, and the modules recover the sequence boundaries from is_init
instead of requiring padded [B, T] tensors and masks.
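As a rough illustration of how boundaries can be recovered from is_init alone, the sketch below turns a flat list of flags into half-open (start, end) index pairs. The helper name segment_bounds is hypothetical, plain Python lists stand in for tensors, and the sketch assumes that index 0 always opens a sequence. It is not TorchRL's implementation.

```python
from typing import List, Tuple


def segment_bounds(is_init: List[bool]) -> List[Tuple[int, int]]:
    """Return half-open (start, end) index pairs, one per sequence."""
    if not is_init:
        return []
    # Assumption: index 0 always opens a sequence, even if its flag is False.
    starts = [0] + [t for t in range(1, len(is_init)) if is_init[t]]
    # Each sequence ends where the next one starts; the last ends at the batch end.
    ends = starts[1:] + [len(is_init)]
    return list(zip(starts, ends))


# Three trajectory slices of lengths 3, 2 and 4 stored back to back.
flat_is_init = [True, False, False, True, False, True, False, False, False]
bounds = segment_bounds(flat_is_init)
print(bounds)  # [(0, 3), (3, 5), (5, 9)]
```

No padding or mask is needed to describe the three sequences: the flags carry the same information as a padded [3, 4] layout would.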
Execution modes#
By default, recurrent modules run in single-step mode. This is the mode used
during environment interaction: the input TensorDict contains one step per
environment, the previous recurrent state is read from the TensorDict, and the
next recurrent state is written under ("next", ...).
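The single-step handoff can be pictured with plain dictionaries standing in for TensorDict. Everything in this sketch is hypothetical: toy_recurrent_step uses a made-up update rule instead of a GRU or LSTM, advance imitates the step that promotes ("next", ...) entries to the root before the following call, and is_init handling is left out to keep the data flow visible.

```python
def toy_recurrent_step(td: dict) -> dict:
    """Hypothetical single-step module: read the previous state, write the next one."""
    h_prev = td["recurrent_state"]
    h_next = 0.5 * h_prev + td["observation"]  # toy cell, not a real GRU/LSTM
    td["features"] = h_next
    td[("next", "recurrent_state")] = h_next  # tuple key mimics a nested entry
    return td


def advance(td: dict, next_observation: float) -> dict:
    """Hypothetical env-step handoff: the ("next", ...) state becomes the root state."""
    return {
        "observation": next_observation,
        "recurrent_state": td[("next", "recurrent_state")],
    }


td = {"observation": 1.0, "recurrent_state": 0.0}
features = []
for next_obs in [2.0, 4.0]:
    td = toy_recurrent_step(td)
    features.append(td["features"])
    td = advance(td, next_obs)
td = toy_recurrent_step(td)
features.append(td["features"])
print(features)  # [1.0, 2.5, 5.25]
```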
During training, call the recurrent policy inside a
set_recurrent_mode context to process complete rollouts or replay-buffer
slices:
from torchrl.modules import GRUModule, set_recurrent_mode
gru = GRUModule(
    input_size=4,
    hidden_size=64,
    in_keys=["observation", "recurrent_state", "is_init"],
    out_keys=["features", ("next", "recurrent_state")],
)
# `batch` is a TensorDict holding a complete rollout or a replay-buffer slice.
with set_recurrent_mode("recurrent"):
    batch = gru(batch)
In recurrent mode, every is_init=True entry resets the hidden state to the
state stored at that same position. This lets a flat batch of concatenated
trajectory slices behave like independent sequences without materializing
padding.
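The reset rule can be sketched with a toy recurrence over Python lists. The names toy_cell and scan_with_resets are hypothetical, the update rule is a stand-in for a real GRU or LSTM cell, and the sketch assumes the first entry of the batch always opens a sequence. It only illustrates the semantics described above, not how any backend is implemented.

```python
def toy_cell(h: float, x: float) -> float:
    """Stand-in for a recurrent cell update (not a real GRU/LSTM)."""
    return 0.5 * h + x


def scan_with_resets(xs, stored_states, is_init):
    """One pass over a flat batch, reloading the state wherever is_init is True."""
    outputs = []
    h = stored_states[0]  # assumption: the first entry always opens a sequence
    for t, x in enumerate(xs):
        if is_init[t]:
            h = stored_states[t]  # reset to the state stored at this position
        h = toy_cell(h, x)
        outputs.append(h)
    return outputs


xs = [1.0, 2.0, 3.0, 4.0, 5.0]
stored = [0.0, 9.0, 9.0, 10.0, 9.0]  # only the values at is_init positions are read
is_init = [True, False, False, True, False]

flat = scan_with_resets(xs, stored, is_init)
# Running the two slices independently gives the same outputs.
separate = (
    scan_with_resets(xs[:3], stored[:3], is_init[:3])
    + scan_with_resets(xs[3:], stored[3:], is_init[3:])
)
assert flat == separate
print(flat)  # [1.0, 2.5, 4.25, 9.0, 9.5]
```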
Backend selection#
The recurrent_backend constructor argument controls how recurrent-mode
calls handle resets inside a batch.
"pad"
    Splits trajectories on is_init, pads them to a common length, and uses PyTorch’s cuDNN-backed torch.nn.LSTM or torch.nn.GRU. This is the default and the broadest compatibility path.
"scan"
    Uses a scan over the time dimension and avoids padded trajectory chunks. This is friendlier to torch.compile() for reset-heavy RL batches. Supports unidirectional GRU/LSTM without dropout, and (for LSTM) without projections. Unsupported configurations raise an error when the recurrent path is executed.
"triton"
    Uses TorchRL’s fused Triton kernels for reset-aware GRU/LSTM recurrence. This backend is CUDA-only and requires a recent Triton installation. It is intended for reset-heavy recurrent RL training where split/pad overhead is significant. Multilayer unidirectional modules (including dropout between layers) are handled directly; unsupported variants (bidirectional modules and LSTM projections) silently fall back to the pad semantics.
"auto"
    Uses "pad" in eager mode and "scan" when called under torch.compile().
For long-running experiments, prefer choosing a backend explicitly once the
model shape and deployment target are known. "pad" is the safest baseline,
"scan" is the compile-friendly baseline, and "triton" is the
performance-oriented CUDA backend.
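To get a feel for the split/pad overhead that motivates the "scan" and "triton" backends, the sketch below splits a flat sequence on is_init, pads the chunks to a common length, and reports the fraction of padded entries. split_and_pad is a hypothetical helper operating on Python lists, and it assumes the first entry always opens a chunk; the real "pad" backend works on tensors and may differ in detail.

```python
def split_and_pad(values, is_init, pad_value=0.0):
    """Split a flat sequence on is_init and pad every chunk to the longest one."""
    chunks = []
    for value, init in zip(values, is_init):
        if init or not chunks:  # assumption: the first entry always opens a chunk
            chunks.append([])
        chunks[-1].append(value)
    max_len = max(len(chunk) for chunk in chunks)
    padded = [chunk + [pad_value] * (max_len - len(chunk)) for chunk in chunks]
    mask = [[True] * len(chunk) + [False] * (max_len - len(chunk)) for chunk in chunks]
    return padded, mask


values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
is_init = [True, True, False, False, False, True, False]
padded, mask = split_and_pad(values, is_init)

n_real = sum(sum(row) for row in mask)  # entries that hold real data
n_total = len(mask) * len(mask[0])  # entries the padded layout allocates
print(padded)
print(f"padding fraction: {1 - n_real / n_total:.2f}")  # padding fraction: 0.42
```

With chunk lengths 1, 4 and 2, five of the twelve padded slots carry no data. Batches with frequent resets and uneven chunk lengths push this fraction higher, which is the case the reset-aware backends target.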
Triton precision controls#
The Triton backend performs hidden-to-hidden recurrent matrix multiplications
inside Triton kernels and input-to-hidden projections through PyTorch/cuBLAS.
The recurrent_matmul_precision argument keeps the numerical precision of those two paths aligned.
Supported values are:
"auto"
    Defer to the process-wide TorchRL setting, and fall back to torch.get_float32_matmul_precision() if the global is itself "auto". The TORCHRL_RNN_PRECISION environment variable seeds the process-wide setting at import time. It is not consulted at every kernel call; call set_recurrent_matmul_precision() with "auto" or None to re-read it after import.
"ieee"
    Use IEEE FP32 matmuls (~23 bits of mantissa, CUDA cores, no tensor cores). This is the most conservative setting and is useful for numerical comparisons with the scan backend.
"tf32"
    Use TF32 tensor cores on Ampere or newer NVIDIA GPUs (~10 bits of mantissa, highest throughput).
"tf32x3"
    Use Triton’s three-product TF32 decomposition for the recurrent matmul (~22 bits of mantissa on tensor cores). cuBLAS has no tf32x3 mode, so the input-to-hidden projection stays IEEE FP32. Useful when long rollouts make recurrent precision drift visible.
"fast" and "high-prec"
    GPU-aware presets. On TF32-capable NVIDIA GPUs, "fast" resolves to "tf32" and "high-prec" resolves to "tf32x3". On devices without TF32 tensor cores, both resolve to "ieee".
The process-wide default can be changed with
set_recurrent_matmul_precision():
from torchrl.modules import GRUModule, set_recurrent_matmul_precision
set_recurrent_matmul_precision("high-prec")
gru = GRUModule(
    input_size=4,
    hidden_size=64,
    recurrent_backend="triton",
    recurrent_matmul_precision="auto",  # defer to the process-wide setting
    in_keys=["observation", "recurrent_state", "is_init"],
    out_keys=["features", ("next", "recurrent_state")],
)
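The "auto" chain and the GPU-aware presets can be summarized as a small resolver. This is a hypothetical sketch, not TorchRL code: resolve_precision and its parameters are invented for illustration, has_tf32 stands for "the device has TF32 tensor cores", and torch_fallback is assumed to be the concrete mode already derived from torch.get_float32_matmul_precision(), since that mapping is not spelled out in this section.

```python
CONCRETE_MODES = {"ieee", "tf32", "tf32x3"}


def resolve_precision(module_value, process_value, torch_fallback, has_tf32):
    """Hypothetical resolver for the precedence and presets described in this section."""
    value = module_value
    if value == "auto":  # the module defers to the process-wide setting
        value = process_value
    if value == "auto":  # the process-wide setting defers to PyTorch
        value = torch_fallback  # assumption: already mapped to a concrete mode
    if value == "fast":
        value = "tf32" if has_tf32 else "ieee"
    elif value == "high-prec":
        value = "tf32x3" if has_tf32 else "ieee"
    if value not in CONCRETE_MODES:
        raise ValueError(f"unknown precision mode: {value!r}")
    return value


# Module "auto" with a process-wide "high-prec" preset, as in the snippet above.
print(resolve_precision("auto", "high-prec", "ieee", has_tf32=True))  # tf32x3
print(resolve_precision("auto", "high-prec", "ieee", has_tf32=False))  # ieee
# An explicit module-level value wins over the process-wide setting.
print(resolve_precision("ieee", "high-prec", "tf32", has_tf32=True))  # ieee
```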
A module-level recurrent_matmul_precision=... value takes precedence over
the process-wide setting. Use get_recurrent_matmul_precision() to inspect
the resolved concrete mode for the current device.
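The mantissa figures quoted for "tf32" and "tf32x3" can be checked on a single scalar product. The sketch below emulates TF32 inputs by keeping 10 explicit mantissa bits of a float32 value, then compares one product of truncated inputs with the three-product decomposition a_hi*b_hi + a_hi*b_lo + a_lo*b_hi. It illustrates the general technique only. The helpers are hypothetical, inputs are truncated rather than rounded, and accumulation runs in Python floats, so it does not reproduce Triton's kernels bit for bit.

```python
import struct


def to_f32(x: float) -> float:
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_tf32(x: float) -> float:
    """Emulate a TF32 input: keep sign, exponent and the top 10 mantissa bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits &= 0xFFFFE000  # zero the 13 low mantissa bits of the float32 pattern
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def split_tf32(x: float):
    """Split x into a TF32 head and a TF32 tail holding the truncated remainder."""
    hi = to_tf32(x)
    lo = to_tf32(x - hi)
    return hi, lo


a, b = to_f32(1.2345678), to_f32(3.1415927)
a_hi, a_lo = split_tf32(a)
b_hi, b_lo = split_tf32(b)

exact = a * b
one_product = a_hi * b_hi  # plain "tf32"
three_products = a_hi * b_hi + a_hi * b_lo + a_lo * b_hi  # "tf32x3" idea

err_tf32 = abs(exact - one_product) / exact
err_tf32x3 = abs(exact - three_products) / exact
print(f"tf32   relative error: {err_tf32:.1e}")
print(f"tf32x3 relative error: {err_tf32x3:.1e}")
```

The second error is several orders of magnitude smaller than the first, which is the gap between roughly 10 and roughly 22 bits of mantissa. Over long rollouts the recurrent matmul is applied at every step, so this per-product difference is what accumulates into the precision drift mentioned above.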
Choosing a layout and backend#
For most recurrent RL pipelines:
Use InitTracker or pass the policy to the env/collector so that TorchRL adds the is_init key and recurrent-state primers automatically.
Store replay data in the flat contiguous layout and sample with SliceSampler.
Run collection in single-step mode and training under set_recurrent_mode.
Start with recurrent_backend="pad" for correctness, then benchmark "scan" or "triton" for the target hardware.
See also#
Data layout: contiguous trajectories for the contiguous trajectory layout and replay-buffer handoff.
LSTMModule and GRUModule for constructor arguments and examples.
set_recurrent_mode for switching between single-step and recurrent execution.
set_recurrent_matmul_precision() and get_recurrent_matmul_precision() for Triton precision control.