# Source code for torchtune.modules.tanh_gate
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch import nn
class TanhGate(nn.Module):
    """Learnable scalar gate that scales its input by ``tanh(scale)``.

    The module owns one learnable one-element parameter, ``scale``, which is
    initialized to zero. Because ``tanh(0) == 0``, a freshly constructed gate
    maps every finite input to zero. Training can then open the gate
    gradually. The gate factor ``tanh(scale)`` always lies in ``[-1, 1]``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Registered under the key "scale" so state_dict layout is unchanged;
        # a zero start means the gate begins fully closed.
        self.register_parameter("scale", nn.Parameter(torch.zeros(1)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): input tensor to gate

        Returns:
            torch.Tensor: ``x`` multiplied elementwise by ``tanh(scale)``
                (the one-element gate broadcasts over ``x``). For inputs with
                at least one dimension, the result has the same shape as
                ``x``. A 0-d input broadcasts to shape ``(1,)``.
        """
        gate = torch.tanh(self.scale)
        return x * gate