Source code for torchtune.modules.peft.lora

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import List

import torch.nn.functional as F

from torch import nn, Tensor

from torchao.dtypes.nf4tensor import linear_nf4, to_nf4
from torchtune.modules.peft.peft_utils import AdapterModule
from torchtune.utils import _register_nf4_dispatch_ops  # noqa: F401


class LoRALinear(nn.Module, AdapterModule):
    """LoRA linear layer as introduced in `LoRA: Low-Rank Adaptation of Large Language Models <https://arxiv.org/abs/2106.09685>`_.

    LoRA perturbs a given layer via a low-rank approximation where only
    the rank decomposition matrices are trainable. In a linear layer instead of
    :math:`x \\mapsto W_0x` a LoRALinear layer is defined as
    :math:`x \\mapsto W_0x + (\\alpha / r)BAx`, where :math:`r` is the rank of
    the matrices :math:`A` and :math:`B` and :math:`\\alpha` is a scaling factor.
    As in the original implementation, we support dropout before multiplication
    by the low-rank matrices.

    Args:
        in_dim (int): input dimension
        out_dim (int): output dimension
        rank (int): rank of the low-rank approximation
        alpha (float): scaling factor for the low-rank approximation
        dropout (float): dropout probability. Default: 0.0
        use_bias (bool): whether to include bias in the original linear layer.
            Default: False
        quantize_base (bool): Whether to quantize base linear weight or not.
            Default: False
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        rank: int,
        alpha: float,
        dropout: float = 0.0,
        use_bias: bool = False,
        quantize_base: bool = False,
    ):
        super().__init__()
        self.in_dim = in_dim
        self.rank = rank
        self.alpha = alpha
        self.out_dim = out_dim
        self.use_bias = use_bias
        self._quantize_base = quantize_base
        weight, bias = self._create_weight_and_bias()
        # 'self.disabled' is a flag showing whether to turn off LoRA adapters,
        # this can be used in DPO for treating the lora adapters as the policy model
        # and disabling it to treat the base model as the reference model
        self.disabled = False
        self.register_parameter("weight", nn.Parameter(weight))
        self.register_parameter(
            "bias", nn.Parameter(bias) if bias is not None else None
        )
        self.dropout = nn.Dropout(p=dropout)
        self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
        self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)
        self.merged = False
        # Note: FSDP's meta device initialization contract assumes that a module's
        # reset_parameters method only initializes its own parameters (i.e. no child
        # params are initialized, as is done in initialize_parameters below).
        # For that reason, we patch reset_parameters directly on lora_a and lora_b submodules
        # when using meta device. This is done in
        # torchtune.utils.prepare_model_for_fsdp_with_meta_device.
        # See this issue for more details: https://github.com/pytorch/pytorch/issues/104187.
        # Without meta device, we only need the following:
        self.initialize_parameters()

    def initialize_parameters(self):
        # Initialize as in
        # https://github.com/microsoft/LoRA/blob/4c0333854cb905966f8cc4e9a74068c1e507c7b7/loralib/layers.py#L119
        _lora_a_init_params(self.lora_a)
        _lora_b_init_params(self.lora_b)

    def _create_weight_and_bias(self):
        """
        Creates a linear weight and bias tensor, using NF4 dtype if we're quantizing
        (indicated via quantize_base=True).
        """
        in_dim, out_dim, use_bias = self.in_dim, self.out_dim, self.use_bias
        linear = nn.Linear(in_features=in_dim, out_features=out_dim, bias=use_bias)
        weight = linear.weight if not self._quantize_base else to_nf4(linear.weight)
        bias = None
        if self.use_bias:
            if self._quantize_base:
                raise NotImplementedError(
                    "Quantized LoRALinear does not support bias at the moment."
                )
            bias = linear.bias
        return weight, bias
    def adapter_params(self) -> List[str]:
        """
        Return lora_a.weight and lora_b.weight as adapter params.
        If bias is enabled, also return lora_a.bias and lora_b.bias.
        """
        # NOTE: this function has to be updated if the names of "lora_a" and "lora_b"
        # in this module change.
        adapter_params = ["lora_a.weight", "lora_b.weight"]
        return adapter_params
    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): input tensor with shape ``(..., in_dim)``

        Returns:
            Tensor: output tensor with shape ``(..., out_dim)``
        """
        if self._quantize_base:
            out = linear_nf4(input=x, weight=self.weight)
        else:
            out = F.linear(x, self.weight, self.bias)
        if self.disabled:
            return out
        lora_out = self.lora_a(self.dropout(x))
        lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)
        return out + lora_out
def _lora_a_init_params(x: nn.Linear) -> None:
    """
    Initialize LoRA A weight to Kaiming uniform.
    """
    nn.init.kaiming_uniform_(x.weight, a=math.sqrt(5))


def _lora_b_init_params(x: nn.Linear) -> None:
    """
    Initialize LoRA B weight to zeros.
    """
    nn.init.zeros_(x.weight)
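
Example usage (not part of the torchtune source above): a minimal sketch that constructs a LoRALinear layer, runs a forward pass, and checks that at initialization the output equals the frozen base projection, since lora_b is zero-initialized and the adapter term (alpha / r) * BAx therefore starts at zero. The dimensions, rank, and alpha below are arbitrary illustration values.

import torch
from torchtune.modules.peft.lora import LoRALinear

# Arbitrary example sizes; any in_dim/out_dim/rank/alpha would do.
layer = LoRALinear(in_dim=512, out_dim=512, rank=8, alpha=16.0)

x = torch.randn(2, 4, 512)  # (batch, seq_len, in_dim)
out = layer(x)
print(out.shape)  # torch.Size([2, 4, 512])

# Because _lora_b_init_params zero-initializes lora_b, the low-rank path
# contributes nothing at init, so the output matches the base linear layer.
base_out = torch.nn.functional.linear(x, layer.weight, layer.bias)
torch.testing.assert_close(out, base_out)

# Only the low-rank matrices are reported as adapter (i.e. trainable) params.
print(layer.adapter_params())  # ['lora_a.weight', 'lora_b.weight']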
