Recurrent Layers#
Recurrent layers process sequential data by maintaining hidden state across time steps. They are essential for tasks involving sequences: language modeling, speech recognition, time series prediction, and more.
RNN: Basic recurrent layer (simple but prone to vanishing gradients)
LSTM: Long Short-Term Memory (gated architecture, handles long-range dependencies)
GRU: Gated Recurrent Unit (simpler than LSTM, often similar performance)
Cell variants: Single-step versions for custom loop implementations
Key parameters:
input_size: Number of features in input; hidden_size: Number of features in hidden state; num_layers: Number of stacked recurrent layers; batch_first: If true, input shape is [batch, seq, features]; bidirectional: Process sequence in both directions
RNN#
-
class RNN : public torch::nn::ModuleHolder<RNNImpl>#
A
ModuleHolder subclass for RNNImpl. See the documentation for the
RNNImpl class to learn what methods it provides, and examples of how to use RNN with torch::nn::RNNOptions. See the documentation for ModuleHolder to learn about PyTorch’s module storage semantics.
-
class RNNImpl : public torch::nn::detail::RNNImplBase<RNNImpl>#
A multi-layer Elman RNN module with Tanh or ReLU activation.
See https://pytorch.org/docs/main/generated/torch.nn.RNN.html to learn about the exact behavior of this module.
See the documentation for
torch::nn::RNNOptions class to learn what constructor arguments are supported for this module.
Example:
RNN model(RNNOptions(128, 64).num_layers(3).dropout(0.2).nonlinearity(torch::kTanh));
Public Functions
-
inline RNNImpl(int64_t input_size, int64_t hidden_size)#
-
explicit RNNImpl(const RNNOptions &options_)#
-
std::tuple<torch::nn::utils::rnn::PackedSequence, Tensor> forward_with_packed_input(const torch::nn::utils::rnn::PackedSequence &packed_input, Tensor hx = {})#
Public Members
-
RNNOptions options#
Friends
- friend struct torch::nn::AnyModuleHolder
-
inline RNNImpl(int64_t input_size, int64_t hidden_size)#
Example:
auto rnn = torch::nn::RNN(
torch::nn::RNNOptions(128, 256) // input_size, hidden_size
.num_layers(2)
.batch_first(true)
.bidirectional(false));
auto input = torch::randn({32, 10, 128}); // [batch, seq_len, input_size]
auto [output, hidden] = rnn->forward(input);
LSTM#
-
class LSTM : public torch::nn::ModuleHolder<LSTMImpl>#
A
ModuleHolder subclass for LSTMImpl. See the documentation for the
LSTMImpl class to learn what methods it provides, and examples of how to use LSTM with torch::nn::LSTMOptions. See the documentation for ModuleHolder to learn about PyTorch’s module storage semantics.
-
class LSTMImpl : public torch::nn::detail::RNNImplBase<LSTMImpl>#
A multi-layer long short-term memory (LSTM) module.
See https://pytorch.org/docs/main/generated/torch.nn.LSTM.html to learn about the exact behavior of this module.
See the documentation for
torch::nn::LSTMOptions class to learn what constructor arguments are supported for this module.
Example:
LSTM model(LSTMOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true));
Public Functions
-
inline LSTMImpl(int64_t input_size, int64_t hidden_size)#
-
explicit LSTMImpl(const LSTMOptions &options_)#
-
std::tuple<Tensor, std::tuple<Tensor, Tensor>> forward(const Tensor &input, std::optional<std::tuple<Tensor, Tensor>> hx_opt = {})#
-
std::tuple<torch::nn::utils::rnn::PackedSequence, std::tuple<Tensor, Tensor>> forward_with_packed_input(const torch::nn::utils::rnn::PackedSequence &packed_input, std::optional<std::tuple<Tensor, Tensor>> hx_opt = {})#
Public Members
-
LSTMOptions options#
Friends
- friend struct torch::nn::AnyModuleHolder
-
inline LSTMImpl(int64_t input_size, int64_t hidden_size)#
Example:
auto lstm = torch::nn::LSTM(
torch::nn::LSTMOptions(128, 256)
.num_layers(2)
.batch_first(true)
.dropout(0.1)
.bidirectional(true));
auto input = torch::randn({32, 10, 128});
auto [output, state] = lstm->forward(input);
auto [h_n, c_n] = state; // hidden state, cell state
GRU#
-
class GRU : public torch::nn::ModuleHolder<GRUImpl>#
A
ModuleHolder subclass for GRUImpl. See the documentation for the
GRUImpl class to learn what methods it provides, and examples of how to use GRU with torch::nn::GRUOptions. See the documentation for ModuleHolder to learn about PyTorch’s module storage semantics.
-
class GRUImpl : public torch::nn::detail::RNNImplBase<GRUImpl>#
A multi-layer gated recurrent unit (GRU) module.
See https://pytorch.org/docs/main/generated/torch.nn.GRU.html to learn about the exact behavior of this module.
See the documentation for
torch::nn::GRUOptions class to learn what constructor arguments are supported for this module.
Example:
GRU model(GRUOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true));
Public Functions
-
inline GRUImpl(int64_t input_size, int64_t hidden_size)#
-
explicit GRUImpl(const GRUOptions &options_)#
-
std::tuple<torch::nn::utils::rnn::PackedSequence, Tensor> forward_with_packed_input(const torch::nn::utils::rnn::PackedSequence &packed_input, Tensor hx = {})#
Public Members
-
GRUOptions options#
Friends
- friend struct torch::nn::AnyModuleHolder
-
inline GRUImpl(int64_t input_size, int64_t hidden_size)#
RNNCell#
-
class RNNCell : public torch::nn::ModuleHolder<RNNCellImpl>#
A
ModuleHolder subclass for RNNCellImpl. See the documentation for the
RNNCellImpl class to learn what methods it provides, and examples of how to use RNNCell with torch::nn::RNNCellOptions. See the documentation for ModuleHolder to learn about PyTorch’s module storage semantics.
Public Types
-
using Impl = RNNCellImpl#
-
using Impl = RNNCellImpl#
-
class RNNCellImpl : public torch::nn::detail::RNNCellImplBase<RNNCellImpl>#
An Elman RNN cell with tanh or ReLU non-linearity.
See https://pytorch.org/docs/main/nn.html#torch.nn.RNNCell to learn about the exact behavior of this module.
See the documentation for
torch::nn::RNNCellOptions class to learn what constructor arguments are supported for this module.
Example:
RNNCell model(RNNCellOptions(20, 10).bias(false).nonlinearity(torch::kReLU));
Public Functions
-
inline RNNCellImpl(int64_t input_size, int64_t hidden_size)#
-
explicit RNNCellImpl(const RNNCellOptions &options_)#
Public Members
-
RNNCellOptions options#
Friends
- friend struct torch::nn::AnyModuleHolder
-
inline RNNCellImpl(int64_t input_size, int64_t hidden_size)#
LSTMCell#
-
class LSTMCell : public torch::nn::ModuleHolder<LSTMCellImpl>#
A
ModuleHolder subclass for LSTMCellImpl. See the documentation for the
LSTMCellImpl class to learn what methods it provides, and examples of how to use LSTMCell with torch::nn::LSTMCellOptions. See the documentation for ModuleHolder to learn about PyTorch’s module storage semantics.
Public Types
-
using Impl = LSTMCellImpl#
-
using Impl = LSTMCellImpl#
-
class LSTMCellImpl : public torch::nn::detail::RNNCellImplBase<LSTMCellImpl>#
A long short-term memory (LSTM) cell.
See https://pytorch.org/docs/main/nn.html#torch.nn.LSTMCell to learn about the exact behavior of this module.
See the documentation for
torch::nn::LSTMCellOptions class to learn what constructor arguments are supported for this module.
Example:
LSTMCell model(LSTMCellOptions(20, 10).bias(false));
Public Functions
-
inline LSTMCellImpl(int64_t input_size, int64_t hidden_size)#
-
explicit LSTMCellImpl(const LSTMCellOptions &options_)#
Public Members
-
LSTMCellOptions options#
Friends
- friend struct torch::nn::AnyModuleHolder
-
inline LSTMCellImpl(int64_t input_size, int64_t hidden_size)#
GRUCell#
-
class GRUCell : public torch::nn::ModuleHolder<GRUCellImpl>#
A
ModuleHolder subclass for GRUCellImpl. See the documentation for the
GRUCellImpl class to learn what methods it provides, and examples of how to use GRUCell with torch::nn::GRUCellOptions. See the documentation for ModuleHolder to learn about PyTorch’s module storage semantics.
Public Types
-
using Impl = GRUCellImpl#
-
using Impl = GRUCellImpl#
-
class GRUCellImpl : public torch::nn::detail::RNNCellImplBase<GRUCellImpl>#
A gated recurrent unit (GRU) cell.
See https://pytorch.org/docs/main/nn.html#torch.nn.GRUCell to learn about the exact behavior of this module.
See the documentation for
torch::nn::GRUCellOptionsclass to learn what constructor arguments are supported for this module.Example:
GRUCell model(GRUCellOptions(20, 10).bias(false));
Public Functions
-
inline GRUCellImpl(int64_t input_size, int64_t hidden_size)#
-
explicit GRUCellImpl(const GRUCellOptions &options_)#
Public Members
-
GRUCellOptions options#
Friends
- friend struct torch::nn::AnyModuleHolder
-
inline GRUCellImpl(int64_t input_size, int64_t hidden_size)#