# Source code for torchtune.modules.rms_norm
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from torch import nn
[docs]class RMSNorm(nn.Module):
"""
Root Mean Square Normalization in fp32.
See: https://pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html
Args:
dim (int): embedding size
eps (float): small value to avoid division by zero. Default: 1e-6
"""
def __init__(self, dim: int, eps: float = 1e-6) -> None:
super().__init__()
self.normalized_shape = (dim,)
self.eps = eps
self.scale = nn.Parameter(torch.ones(dim))
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x (torch.Tensor): input tensor to normalize
Returns:
torch.Tensor: The normalized and scaled tensor having the same shape as ``x``.
"""
# computation is in fp32
return F.rms_norm(
x.float(),
normalized_shape=self.normalized_shape,
weight=self.scale,
eps=self.eps,
).to(x.dtype)