
Source code for torchtune.modules.peft.dora

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import List

import torch
import torch.nn.functional as F

from torch import nn

from torchao.dtypes.nf4tensor import linear_nf4, to_nf4
from torchtune.modules.low_precision import _register_nf4_dispatch_ops  # noqa: F401
from torchtune.modules.peft import AdapterModule


class DoRALinear(nn.Module, AdapterModule):
    """DoRA linear layer as introduced in
    `DoRA: Weight-Decomposed Low-Rank Adaptation of Large Language Models <https://arxiv.org/abs/2402.09353>`_.

    DoRA (Weight-Decomposed Low-Rank Adaptation) fine-tunes a layer by decomposing the
    pre-trained weights into two components: magnitude and direction. The magnitude
    component is a learnable scalar vector that scales each output channel, while the
    direction component, modified via LoRA, adjusts the orientation of the weights. By
    scaling the LoRA update component :math:`BAx` with the ``magnitude`` vector, DoRA
    allows the model to apply distinct scaling adjustments across different output
    dimensions.

    Args:
        in_dim (int): input dimension
        out_dim (int): output dimension
        rank (int): rank of the low-rank approximation
        alpha (float): scaling factor for the low-rank approximation
        dropout (float): dropout probability. Default: 0.0
        use_bias (bool): whether to include bias in the original linear layer.
            Default: False
        quantize_base (bool): Whether to quantize base linear weight or not.
            Default: False
        **quantization_kwargs: Keyword arguments to pass to ``to_nf4`` when quantizing
            the base linear weight. Examples of valid arguments are ``block_size`` and
            ``scaler_block_size``, which control the granularity of weight quantization
            and scaler quantization respectively. Only used if ``quantize_base`` is
            True. Default None

    Raises:
        ValueError: If ``quantize_base`` is False, but quantization kwargs are provided.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        rank: int,
        alpha: float,
        dropout: float = 0.0,
        use_bias: bool = False,
        quantize_base: bool = False,
        **quantization_kwargs,
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.scaling = alpha / rank
        self.use_bias = use_bias
        self._quantize_base = quantize_base

        if not self._quantize_base and quantization_kwargs:
            raise ValueError(
                f"``quantize_base`` is False, but received the following quantization arguments: {quantization_kwargs}"
            )

        # Setup weight and bias
        linear = nn.Linear(in_features=in_dim, out_features=out_dim, bias=self.use_bias)
        weight = (
            linear.weight
            if not self._quantize_base
            else to_nf4(linear.weight, **quantization_kwargs)
        )
        bias = linear.bias if self.use_bias else None

        # 'self.disabled' is a flag for turning off the DoRA adapters. This can be
        # used in DPO to treat the DoRA adapters as the policy model and, when
        # disabled, the base model as the reference model.
        self.disabled = False
        self.register_parameter("weight", nn.Parameter(weight))
        self.register_parameter(
            "bias", nn.Parameter(bias) if bias is not None else None
        )
        self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
        self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
        self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)
        self.magnitude = nn.Parameter(torch.empty(out_dim))
        self.initialize_parameters()

    def initialize_parameters(self):
        # Initialize as in
        # https://github.com/microsoft/LoRA/blob/4c0333854cb905966f8cc4e9a74068c1e507c7b7/loralib/layers.py#L119
        _lora_a_init_params(self.lora_a)
        _lora_b_init_params(self.lora_b)
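    # For reference: DoRA reparameterizes the layer's merged weight as
    #
    #   W' = m * (W_0 + (alpha / rank) * B @ A) / ||W_0 + (alpha / rank) * B @ A||_c
    #
    # where ||.||_c is the per-output-channel L2 norm (``dim=1`` in
    # ``_get_weight_norm`` below), m is ``self.magnitude`` of shape (out_dim,),
    # A is ``lora_a.weight`` of shape (rank, in_dim), and B is ``lora_b.weight``
    # of shape (out_dim, rank).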
    @torch.no_grad()
    def initialize_dora_magnitude(self):
        """
        DoRA initializes the magnitude vector such that its outputs are
        initially identical to standard LoRA's outputs.
        """
        base_weight = self.weight.to(self.lora_a.weight.dtype)
        lora_weight = self.lora_b.weight @ self.lora_a.weight
        weight_norm = self._get_weight_norm(base_weight, lora_weight)
        self.magnitude.copy_(weight_norm)
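    # Since ``lora_b.weight`` is zero-initialized, ``lora_weight`` above is zero
    # when this is called right after construction: the magnitude starts out as
    # the per-channel norm of the base weight alone, ``mag_norm_scale`` in
    # ``forward`` is exactly 1, and the DoRA output reduces to the standard
    # LoRA output ``base_out + self.scaling * lora_out``.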
    def _get_weight_norm(self, weight, lora_weight):
        # Per-output-channel L2 norm of the adapted weight W_0 + scaling * B @ A
        weight = weight + self.scaling * lora_weight
        weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype)
        return weight_norm
    def adapter_params(self) -> List[str]:
        """
        Return ``lora_a.weight``, ``lora_b.weight``, and ``magnitude`` as the
        adapter parameters. The frozen base ``weight`` (and ``bias``, if
        enabled) are excluded.
        """
        adapter_params = ["lora_a.weight", "lora_b.weight", "magnitude"]
        return adapter_params
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): input tensor with shape ``(..., in_dim)``

        Returns:
            torch.Tensor: output tensor with shape ``(..., out_dim)``
        """
        if self._quantize_base:
            base_out = linear_nf4(input=x, weight=self.weight)
            if self.use_bias:
                base_out = base_out + self.bias
        else:
            base_out = F.linear(x, self.weight, self.bias)
        if self.disabled:
            return base_out

        x = self.dropout(x)
        lora_out = self.lora_b(self.lora_a(x))

        # Can't use a raw matmul to materialize B @ A since FSDP hooks are
        # attached to __call__. Instead, feed an identity matrix through the
        # adapter, following https://github.com/huggingface/peft/pull/1806
        x_eye = torch.eye(
            self.lora_a.weight.shape[1], device=self.lora_a.weight.device, dtype=x.dtype
        )
        lora_weight = self.lora_b(self.lora_a(x_eye)).T

        magnitude = self.magnitude
        weight = self.weight.to(x.dtype)
        weight_norm = self._get_weight_norm(weight, lora_weight.detach())
        # Detach the norm so it is treated as a constant in the backward pass,
        # per the gradient trick in the DoRA paper (Sec. 4.3).
        weight_norm = weight_norm.detach()
        mag_norm_scale = (magnitude / weight_norm).view(1, -1)
        dora_out = (
            mag_norm_scale - 1
        ) * base_out + mag_norm_scale * lora_out * self.scaling

        return dora_out + base_out
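# Worked check of the ``forward`` algebra above, ignoring dropout and bias
# (both off by default): with s = mag_norm_scale, the returned value
#   (s - 1) * base_out + s * scaling * lora_out + base_out
# simplifies to
#   s * (base_out + scaling * lora_out),
# i.e. the base-plus-LoRA output rescaled per output channel, which matches
# applying the merged weight
#   W' = m * (W_0 + scaling * B @ A) / ||W_0 + scaling * B @ A||_c
# directly to x.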
def _lora_a_init_params(x: nn.Linear) -> None:
    """
    Initialize LoRA A weight to Kaiming uniform.
    """
    nn.init.kaiming_uniform_(x.weight, a=math.sqrt(5))


def _lora_b_init_params(x: nn.Linear) -> None:
    """
    Initialize LoRA B weight to zeros.
    """
    nn.init.zeros_(x.weight)
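Below is a minimal usage sketch, not part of the module source above; the dimensions and hyperparameters are illustrative, and it assumes ``torchtune`` (with ``torchao``) is installed:

import torch
import torch.nn.functional as F
from torchtune.modules.peft import DoRALinear

layer = DoRALinear(in_dim=64, out_dim=128, rank=4, alpha=8.0)
layer.initialize_dora_magnitude()  # magnitude <- per-channel norm of the base weight

x = torch.randn(2, 16, 64)  # (batch, seq_len, in_dim)
out = layer(x)              # (batch, seq_len, out_dim)

# lora_b is zero-initialized, so before any training the adapted layer
# matches the frozen base linear exactly.
torch.testing.assert_close(out, F.linear(x, layer.weight, layer.bias))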
