RMSNorm
- class torch.nn.modules.normalization.RMSNorm(normalized_shape, eps=None, elementwise_affine=True, device=None, dtype=None)
Applies Root Mean Square Layer Normalization over a mini-batch of inputs.
This layer implements the operation as described in the paper Root Mean Square Layer Normalization.
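As a sketch of that operation, assuming the stability term eps (written \epsilon) is added inside the square root and \gamma denotes the learnable per-element weight, normalizing over the n elements of the normalized dimensions can be written as:

```latex
y_i = \frac{x_i}{\mathrm{RMS}(x)} \cdot \gamma_i,
\qquad
\mathrm{RMS}(x) = \sqrt{\epsilon + \frac{1}{n} \sum_{j=1}^{n} x_j^2}
```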
The root mean square norm is taken over the last D dimensions, where D is the number of dimensions in normalized_shape. For example, if normalized_shape is (3, 5) (a 2-dimensional shape), the RMS norm is computed over the last 2 dimensions of the input.
- Parameters
normalized_shape (int or list or torch.Size) –
input shape from an expected input of size (*, normalized_shape[0], normalized_shape[1], ..., normalized_shape[-1]).
If a single integer is used, it is treated as a singleton list, and this module will normalize over the last dimension which is expected to be of that specific size.
eps (Optional[float]) – a value added to the denominator for numerical stability. Default: torch.finfo(x.dtype).eps
elementwise_affine (bool) – a boolean value that, when set to True, gives this module a learnable per-element scale parameter (weight) initialized to ones. Default: True.
- Shape:
Input: (N, *)
Output: (N, *) (same shape as input)
Examples:
>>> import torch
>>> from torch import nn
>>> rms_norm = nn.RMSNorm([2, 3])
>>> input = torch.randn(2, 2, 3)
>>> rms_norm(input)
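To make the computation in the example concrete, here is a minimal reference sketch written with plain tensor operations. The helper rms_norm_reference is hypothetical (it is not a PyTorch function), and it assumes eps is added to the mean of squares before the square root and defaults to the machine epsilon of the input dtype.

```python
import torch


def rms_norm_reference(x, normalized_shape, weight=None, eps=None):
    """Hypothetical helper: RMS-normalize x over its last len(normalized_shape) dims."""
    # Reduce over the trailing dimensions, e.g. (-2, -1) for a 2-D normalized_shape.
    dims = tuple(range(-len(normalized_shape), 0))
    if eps is None:
        # Assumed default: machine epsilon of the input dtype.
        eps = torch.finfo(x.dtype).eps
    # Mean of squares over the normalized dimensions, kept for broadcasting.
    mean_sq = x.pow(2).mean(dim=dims, keepdim=True)
    y = x * torch.rsqrt(mean_sq + eps)
    if weight is not None:
        # Optional per-element scale with shape normalized_shape.
        y = y * weight
    return y


x = torch.randn(2, 2, 3)
y = rms_norm_reference(x, (2, 3))

# After normalization, each leading slice has a root mean square close to 1.
rms_after = y.pow(2).mean(dim=(-2, -1)).sqrt()
```

On a PyTorch version that provides nn.RMSNorm, the result of nn.RMSNorm([2, 3])(x) with its freshly initialized weight should be close to y; torch.allclose is one way to compare the two.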