Shortcuts

Source code for torchtune.modules.tanh_gate

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from torch import nn


class TanhGate(nn.Module):
    """Implements a basic learnable gate to scale layer outputs.

    The gate holds a single learnable scalar ``scale``. The input is multiplied
    by ``tanh(scale)``, so the effective multiplier always lies in ``(-1, 1)``.
    Because ``scale`` starts at zero and ``tanh(0) == 0``, a freshly constructed
    gate outputs all zeros until ``scale`` is trained away from zero.
    """

    def __init__(self) -> None:
        super().__init__()
        # One-element parameter, zero-initialised (see class docstring).
        self.scale = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): input tensor to gate

        Returns:
            torch.Tensor: The output tensor after gating. Has the same shape as ``x``.
        """
        # The one-element gate broadcasts across every element of ``x``.
        gate = torch.tanh(self.scale)
        return x * gate

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources