Finetuning Llama2 with LoRA

This guide will teach you about LoRA, a parameter-efficient finetuning technique, and show you how you can use TorchTune to finetune a Llama2 model with LoRA. If you already know what LoRA is and want to get straight to running your own LoRA finetune in TorchTune, you can jump to the LoRA finetuning recipe in TorchTune.

What you will learn
  • What LoRA is and how it saves memory during finetuning

  • An overview of LoRA components in TorchTune

  • How to run a LoRA finetune using TorchTune

  • How to experiment with different LoRA configurations

Prerequisites

  • Make sure you have installed TorchTune

  • Make sure you have downloaded the Llama2 model weights and tokenizer

What is LoRA?

LoRA is an adapter-based method for parameter-efficient finetuning that adds trainable low-rank decomposition matrices to different layers of a neural network, then freezes the network’s remaining parameters. LoRA is most commonly applied to transformer models, in which case it is common to add the low-rank matrices to some of the linear projections in each transformer layer’s self-attention.

Note

If you’re unfamiliar, check out these references for the definition of rank and a discussion of low-rank approximations.

By finetuning with LoRA (as opposed to finetuning all model parameters), you can expect to see memory savings due to a substantial reduction in the number of parameters with gradients. When using an optimizer with momentum, like AdamW, you can also expect further memory savings from the optimizer state.
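
To make the optimizer-state point concrete, here is a minimal sketch in plain PyTorch (a toy model, not TorchTune code): if we construct AdamW over only the trainable parameters, the frozen weights never receive gradient buffers or AdamW moment tensors.

import torch
from torch import nn

# Toy example: freeze one layer to stand in for the pretrained base weights.
model = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 512))
for p in model[0].parameters():
  p.requires_grad = False

# Build AdamW over only the trainable parameters. The frozen weights carry no
# gradients and no first/second moment tensors, which is where the savings come from.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=3e-4)
print(sum(p.numel() for p in trainable), "params carry optimizer state")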

Note

LoRA memory savings come primarily from gradient and optimizer states, so if your model’s peak memory usage occurs in its forward(), LoRA may not reduce peak memory.

How does LoRA work?

LoRA replaces weight update matrices with a low-rank approximation. In general, weight updates for an arbitrary nn.Linear(in_dim,out_dim) layer could have rank as high as min(in_dim,out_dim). The LoRA paper (and related work such as Aghajanyan et al.) hypothesizes that the intrinsic dimension of these updates during LLM finetuning can in fact be much lower. To take advantage of this property, LoRA finetuning freezes the original model, then adds a trainable weight update from a low-rank projection. More explicitly, LoRA trains two matrices, A and B. A projects the inputs down to a much smaller rank (often four or eight in practice), and B projects back up to the dimension output by the original linear layer.
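
In equation form: with a frozen pretrained weight matrix \(W_0\) (shape out_dim x in_dim), LoRA learns \(A\) (shape r x in_dim) and \(B\) (shape out_dim x r), and the adapted layer computes \(y = W_0 x + \frac{\alpha}{r} B A x\). Here \(\alpha\) is a scaling hyperparameter, which you can see applied in the code below.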

The image below gives a simplified representation of a single weight update step from a full finetune (on the left) compared to a weight update step with LoRA (on the right). The LoRA matrices A and B serve as an approximation to the full rank weight update in blue.

../_images/lora_diagram.png

Although LoRA introduces a few extra parameters in the model forward(), only the A and B matrices are trainable. This means that with a rank r LoRA decomposition, the number of gradients we need to store reduces from in_dim*out_dim to r*(in_dim+out_dim). (Remember that in general r is much smaller than in_dim and out_dim.)

For example, in Llama2 7B’s self-attention, in_dim=out_dim=4096 for the Q, K, and V projections. This means a LoRA decomposition of rank r=8 will reduce the number of trainable parameters for a given projection from \(4096 * 4096 \approx 16.8M\) to \(8 * 8192 \approx 65K\), a reduction of over 99%.
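
We can check this arithmetic with a couple of lines of Python:

# Parameter counts for a single 4096 x 4096 projection with LoRA rank r=8
in_dim = out_dim = 4096
r = 8

full_rank_params = in_dim * out_dim   # 16,777,216 (~16.8M)
lora_params = r * (in_dim + out_dim)  # 65,536 (~65K)
print(f"reduction: {100 * (1 - lora_params / full_rank_params):.2f}%")  # ~99.61%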

Let’s take a look at a minimal implementation of LoRA in native PyTorch.

from torch import nn, Tensor

class LoRALinear(nn.Module):
  def __init__(
    self,
    in_dim: int,
    out_dim: int,
    rank: int,
    alpha: float,
    dropout: float
  ):
    super().__init__()
    # These are the weights from the original pretrained model
    self.linear = nn.Linear(in_dim, out_dim, bias=False)

    # These are the new LoRA params. In general rank << in_dim, out_dim
    self.lora_a = nn.Linear(in_dim, rank, bias=False)
    self.lora_b = nn.Linear(rank, out_dim, bias=False)

    # Rank and alpha are commonly-tuned hyperparameters
    self.rank = rank
    self.alpha = alpha

    # Most implementations also include some dropout
    self.dropout = nn.Dropout(p=dropout)

    # The original params are frozen, and only LoRA params are trainable.
    self.linear.weight.requires_grad = False
    self.lora_a.weight.requires_grad = True
    self.lora_b.weight.requires_grad = True

  def forward(self, x: Tensor) -> Tensor:
    # This would be the output of the original model
    frozen_out = self.linear(x)

    # lora_a projects inputs down to the much smaller self.rank,
    # then lora_b projects back up to the output dimension
    lora_out = self.lora_b(self.lora_a(self.dropout(x)))

    # Finally, scale by the alpha parameter (normalized by rank)
    # and add to the original model's outputs
    return frozen_out + (self.alpha / self.rank) * lora_out
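
As a quick sanity check, we can instantiate this module (using Llama2 7B’s 4096 embedding dimension for illustration) and confirm that only the LoRA matrices are trainable:

import torch

# Wrap a 4096 -> 4096 projection with a rank-8 LoRA decomposition
lora_linear = LoRALinear(in_dim=4096, out_dim=4096, rank=8, alpha=16, dropout=0.0)

# Forward pass on a dummy (batch, seq_len, embed_dim) input
x = torch.randn(2, 128, 4096)
print(lora_linear(x).shape)  # torch.Size([2, 128, 4096])

# Only the LoRA matrices require gradients
print([n for n, p in lora_linear.named_parameters() if p.requires_grad])
# ['lora_a.weight', 'lora_b.weight']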

There are some other details around initialization which we omit here, but if you’d like to know more you can see our implementation in LoRALinear. Now that we understand what LoRA is doing, let’s look at how we can apply it to our favorite models.

Applying LoRA to Llama2 models

With TorchTune, we can easily apply LoRA to Llama2 with a variety of different configurations. Let’s take a look at how to construct Llama2 models in TorchTune with and without LoRA.

from torchtune.models.llama2 import llama2_7b, lora_llama2_7b

# Build Llama2 without any LoRA layers
base_model = llama2_7b()

# The default settings for lora_llama2_7b will match those for llama2_7b
# We just need to define which layers we want LoRA applied to.
# We can choose from ["q_proj", "k_proj", "v_proj", "output_proj"]
lora_model = lora_llama2_7b(lora_attn_modules=["q_proj", "v_proj"])
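
The builder also exposes the rank and scaling hyperparameters. As a sketch (assuming lora_llama2_7b accepts lora_rank and lora_alpha keyword arguments, mirroring the fields shown in the config later in this tutorial), a more aggressive configuration looks like this:

# Apply LoRA to all four attention projections with a larger rank and alpha.
lora_model_r16 = lora_llama2_7b(
  lora_attn_modules=["q_proj", "k_proj", "v_proj", "output_proj"],
  lora_rank=16,
  lora_alpha=32,
)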

Note

Calling lora_llama2_7b alone will not handle the definition of which parameters are trainable. See below for how to do this.

Let’s inspect each of these models a bit more closely.

# Print the first layer's self-attention in the usual Llama2 model
print(base_model.layers[0].attn)

CausalSelfAttention(
  (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (output_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (pos_embeddings): RotaryPositionalEmbeddings()
)

# Print the same for Llama2 with LoRA weights
print(lora_model.layers[0].attn)

CausalSelfAttention(
  (q_proj): LoRALinear(
    (dropout): Dropout(p=0.0, inplace=False)
    (lora_a): Linear(in_features=4096, out_features=8, bias=False)
    (lora_b): Linear(in_features=8, out_features=4096, bias=False)
  )
  (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (v_proj): LoRALinear(
    (dropout): Dropout(p=0.0, inplace=False)
    (lora_a): Linear(in_features=4096, out_features=8, bias=False)
    (lora_b): Linear(in_features=8, out_features=4096, bias=False)
  )
  (output_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (pos_embeddings): RotaryPositionalEmbeddings()
)

Notice that our LoRA model’s layer contains additional weights in the Q and V projections, as expected. Additionally, inspecting the type of lora_model and base_model would show that they are both instances of the same TransformerDecoder. (Feel free to verify this for yourself.)

Why does this matter? TorchTune makes it easy to load Llama2 checkpoints directly into our LoRA model, without any wrappers or custom checkpoint conversion logic.

# Assuming that base_model already has the pretrained Llama2 weights,
# this will directly load them into your LoRA model without any conversion necessary.
lora_model.load_state_dict(base_model.state_dict(), strict=False)

Note

Whenever loading weights with strict=False, you should verify that any missing or extra keys in the loaded state_dict are as expected. TorchTune’s LoRA recipe does this by default via torchtune.modules.peft.validate_state_dict_for_lora().
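
If you are wiring this up yourself rather than using the recipe, a minimal version of that check can use the key lists returned by load_state_dict (standard PyTorch behavior, not a TorchTune API):

# With strict=False, load_state_dict returns the missing and unexpected keys.
missing, unexpected = lora_model.load_state_dict(base_model.state_dict(), strict=False)

# The base weights should load cleanly, so nothing should be unexpected,
# and the only missing keys should be the newly added LoRA parameters.
assert not unexpected
assert all("lora_a" in k or "lora_b" in k for k in missing)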

Once we’ve loaded the base model weights, we also want to set only LoRA parameters to trainable.

from torchtune.modules.peft.peft_utils import get_adapter_params, set_trainable_params

# Fetch all params from the model that are associated with LoRA.
lora_params = get_adapter_params(lora_model)

# Set requires_grad=True on lora_params, and requires_grad=False on all others.
set_trainable_params(lora_model, lora_params)

# Print the total number of parameters
total_params = sum([p.numel() for p in lora_model.parameters()])
trainable_params = sum([p.numel() for p in lora_model.parameters() if p.requires_grad])
print(
  f"""
  {total_params} total params,
  {trainable_params} trainable params,
  {(100.0 * trainable_params / total_params):.2f}% of all params are trainable.
  """
)

6742609920 total params,
4194304 trainable params,
0.06% of all params are trainable.

Note

If you are directly using the LoRA recipe (as detailed here), you need only pass the relevant checkpoint path. Loading model weights and setting trainable parameters will be taken care of in the recipe.

LoRA finetuning recipe in TorchTune

Finally, we can put it all together and finetune a model using TorchTune’s LoRA recipe. Make sure that you have first downloaded the Llama2 weights and tokenizer by following these instructions. You can then run the following command to perform a LoRA finetune of Llama2-7B using the Alpaca dataset with two GPUs (each with at least 23 GB of VRAM):

tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama2/7B_lora

Note

Make sure to point to the location of your Llama2 weights and tokenizer. This can be done either by appending checkpointer.checkpoint_files=[my_model_checkpoint_path] tokenizer_checkpoint=my_tokenizer_checkpoint_path to the command above, or by directly modifying the 7B_lora.yaml file. See our Configs Deep-Dive for more details on how you can easily clone and modify TorchTune configs.

Note

You can modify the value of nproc_per_node depending on (a) the number of GPUs you have available, and (b) the memory constraints of your hardware. See this table for peak memory of LoRA finetuning in a couple of common hardware setups.

The preceding command will run a LoRA finetune with TorchTune’s factory settings, but we may want to experiment a bit. Let’s take a closer look at some of the settings in the 7B_lora config.

# Model Arguments
model:
  _component_: lora_llama2_7b
  lora_attn_modules: ['q_proj', 'v_proj']
  lora_rank: 8
  lora_alpha: 16
...

We see that the default is to apply LoRA to Q and V projections with a rank of 8. Some experiments with LoRA have found that it can be beneficial to apply LoRA to all linear layers in the self-attention, and to increase the rank to 16 or 32. Note that this is likely to increase our max memory, but as long as we keep rank<<embed_dim, the impact should be relatively minor.
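
To get a feel for the numbers, here is the same back-of-the-envelope count as before, now applied to all four projections across Llama2 7B’s 32 transformer layers (illustrative arithmetic only):

embed_dim = 4096
num_layers = 32
num_projections = 4  # q_proj, k_proj, v_proj, output_proj
rank = 32

lora_params = num_layers * num_projections * rank * (embed_dim + embed_dim)
print(f"{lora_params:,} trainable LoRA params")  # 33,554,432 (~33.5M)
# Still only about 0.5% of the ~6.7B total parameters.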

Let’s run this experiment. We can also increase alpha (in general it is good practice to scale alpha and rank together).

tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama2/7B_lora \
model.lora_attn_modules='[q_proj, k_proj, v_proj, output_proj]' \
model.lora_rank=32 model.lora_alpha=64 output_dir=./lora_experiment_1

A comparison of the (smoothed) loss curves between this run and our baseline over the first 500 steps can be seen below.

../_images/lora_experiment_loss_curves.png

Note

The above figure was generated with W&B. You can use TorchTune’s WandBLogger to generate similar loss curves, but you will need to install W&B and set up an account separately.

As an exercise, you can also try running some evaluation tasks or manually inspecting generations output by your saved checkpoints (which can be found in output_dir). You may want to train the model for longer first, as here we only looked at 500 steps (which corresponds to about 2% of one epoch of the Alpaca dataset).
