LSTMCell
- class torch.ao.nn.quantizable.modules.rnn.LSTMCell(input_dim, hidden_dim, bias=True, device=None, dtype=None, *, split_gates=False)
A quantizable long short-term memory (LSTM) cell.
For the description and the argument types, please refer to torch.nn.LSTMCell.

split_gates: specify True to compute the input/forget/cell/output gates separately, so that no intermediate tensor needs to be allocated and subsequently chunked. This optimization can be beneficial for on-device inference latency. The flag is cascaded down from the parent classes. A sketch of passing this flag follows the example below.
Examples:
>>> import torch
>>> import torch.ao.nn.quantizable as nnqa
>>> rnn = nnqa.LSTMCell(10, 20)  # input_dim=10, hidden_dim=20
>>> input = torch.randn(6, 3, 10)  # (time_steps, batch, input_dim)
>>> hx = torch.randn(3, 20)
>>> cx = torch.randn(3, 20)
>>> output = []
>>> for i in range(6):
...     hx, cx = rnn(input[i], (hx, cx))
...     output.append(hx)
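As a minimal sketch of the split_gates optimization described above (reusing the toy dimensions and tensors from the example; the forward signature is unchanged), the flag is simply passed at construction:

>>> rnn_split = nnqa.LSTMCell(10, 20, split_gates=True)  # gates computed separately
>>> hx, cx = rnn_split(input[0], (torch.randn(3, 20), torch.randn(3, 20)))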
- classmethod from_params(wi, wh, bi=None, bh=None, split_gates=False)
Uses the weights and biases to create a new LSTM cell.
- Parameters:
wi – Weights for the input-to-hidden connection, of shape (4 * hidden_dim, input_dim)
wh – Weights for the hidden-to-hidden connection, of shape (4 * hidden_dim, hidden_dim)
bi – Bias for the input-to-hidden connection
bh – Bias for the hidden-to-hidden connection
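A minimal sketch of using from_params, assuming the four parameters are lifted out of an existing float torch.nn.LSTMCell (whose weight_ih, weight_hh, bias_ih, and bias_hh attributes follow the shapes listed above):

>>> import torch.nn as nn
>>> float_cell = nn.LSTMCell(10, 20)
>>> cell = nnqa.LSTMCell.from_params(
...     float_cell.weight_ih,  # input-to-hidden weights, shape (4 * 20, 10)
...     float_cell.weight_hh,  # hidden-to-hidden weights, shape (4 * 20, 20)
...     float_cell.bias_ih,
...     float_cell.bias_hh,
... )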