(Part 2) Fine-tuning with QAT, QLoRA, and float8
------------------------------------------------
TorchAO provides an end-to-end pre-training, fine-tuning, and serving
model optimization flow by leveraging our quantization and sparsity
techniques integrated into our partner frameworks. This is part 2 of 3
such tutorials showcasing this end-to-end flow, focusing on the
fine-tuning step.
.. image:: ../static/e2e_flow_part2.png
Fine-tuning is an important step for adapting your pre-trained model
to more domain-specific data. In this tutorial, we demonstrate 3 model
optimization techniques that can be applied to your model during fine-tuning:
1. **Quantization-Aware Training (QAT)**, for adapting your model to
quantization numerics during fine-tuning, with the goal of mitigating
quantization degradations in your fine-tuned model when it is quantized
eventually, e.g. in the serving step. Check out `our blog `__
and `README `__ for more details!
2. **Quantized Low-Rank Adaptation (QLoRA)**, for reducing the resource
requirement of fine-tuning by introducing small, trainable low-rank
matrices and freezing the original pre-trained checkpoint, a type of
Parameter-Efficient Fine-Tuning (PEFT). Please refer to the `original
paper `__ for more details.
3. **Float8 Quantized Fine-tuning**, for speeding up fine-tuning by
dynamically quantizing high precision weights and activations to float8,
similar to `pre-training in float8 `__.
Quantization-Aware Training (QAT)
##################################
The goal of Quantization-Aware Training is to adapt the model to
quantization numerics during training or fine-tuning, so as to mitigate
the inevitable quantization degradation when the model is actually
quantized eventually, presumably during the serving step after fine-tuning.
TorchAO's QAT support has been used successfully for the recent release of
the `Llama-3.2 quantized 1B/3B `__
and the `LlamaGuard-3-8B `__ models to improve the quality of the quantized models.
TorchAO's QAT support involves two separate steps: prepare and convert.
The prepare step "fake" quantizes activations and/or weights during
training, which means, the high precision values (e.g. bf16) are mapped
to their corresponding quantized values *without* actually casting them
to the target lower precision dtype (e.g. int4). The convert step,
applied after training, replaces "fake" quantization operations in the
model with "real" quantization that does perform the dtype casting:
.. image:: ../../torchao/quantization/qat/images/qat_diagram.png
There are multiple options for using TorchAO's QAT for fine-tuning:
1. Use our integration with `TorchTune `__
2. Use our integration with `Axolotl `__
3. Directly use our QAT APIs with your own training loop
Option 1: TorchTune QAT Integration
===================================
TorchAO's QAT support is integrated into TorchTune's distributed fine-tuning recipe.
Instead of the following command, which applies full distributed fine-tuning without QAT:
.. code::
# Regular fine-tuning without QAT
tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config llama3_2/3B_full batch_size=16
Users can run the following equivalent command instead. Note that specifying the quantizer
is optional:
.. code::
# Fine-tuning with QAT, by default:
# activations are fake quantized to asymmetric per token int8
# weights are fake quantized to symmetric per group int4
# configurable through "quantizer._component_" in the command
tune run --nnodes 1 --nproc_per_node 4 qat_distributed --config llama3_2/3B_qat_full batch_size=16
After fine-tuning, users can quantize and evaluate the resulting model as follows.
This is the same whether or not QAT was used during the fine-tuning process:
.. code::
# Quantize model weights to int4
tune run quantize --config quantization \
model._component_=torchtune.models.llama3_2.llama3_2_3b \
checkpointer._component_=torchtune.training.FullModelHFCheckpointer \
'checkpointer.checkpoint_files=[model-00001-of-00002.safetensors,model-00002-of-00002.safetensors]' \
checkpointer.model_type=LLAMA3 \
quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
quantizer.groupsize=32
# Evaluate the int4 model on hellaswag and wikitext
tune run eleuther_eval --config eleuther_evaluation \
batch_size=1 \
'tasks=[hellaswag, wikitext]' \
model._component_=torchtune.models.llama3_2.llama3_2_3b \
checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
'checkpointer.checkpoint_files=[model-00001-of-00002-8da4w.ckpt]' \
checkpointer.model_type=LLAMA3 \
tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \
tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model \
quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
quantizer.groupsize=32
This should print the following after fine-tuning:
.. code::
| Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
|---------|------:|------|------|--------|---|-----:|---|-----:|
|hellaswag| 1|none |None |acc |↑ |0.5021|± |0.0050|
| | |none |None |acc_norm|↑ |0.6797|± |0.0047|
| Tasks |Version|Filter|n-shot| Metric | | Value | |Stderr|
|--------|------:|------|------|---------------|---|------:|---|------|
|wikitext| 2|none |None |bits_per_byte |↓ | 0.6965|± | N/A|
| | |none |None |byte_perplexity|↓ | 1.6206|± | N/A|
| | |none |None |word_perplexity|↓ |13.2199|± | N/A|
You can compare these values with and without QAT to see how much QAT helped mitigate quantization degradation!
For example, when fine-tuning Llama-3.2-3B on the
`OpenAssistant Conversations (OASST1) `__
dataset, we find that the quantized model achieved 3.4% higher accuracy
with QAT than without, recovering 69.8% of the overall accuracy degradation
from quantization:
.. image:: ../static/qat_eval.png
In addition to vanilla QAT as in the above example, TorchAO's QAT can also be composed with LoRA to yield a `1.89x training speedup `__ and lower memory usage by 36.1%. This is implemented in TorchTune's `QAT + LoRA fine-tuning recipe `__, which can be run using the following command:
.. code::
# Fine-tuning with QAT + LoRA
tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config llama3_2/3B_qat_lora batch_size=16
For more details about how QAT is set up in TorchTune, please refer to `this tutorial `__.
Option 2: Axolotl QAT Integration
=================================
Axolotl also recently added a QAT fine-tuning recipe that leverages TorchAO's QAT support.
To get started, try fine-tuning Llama-3.2-3B with QAT using the following command:
.. code::
axolotl train examples/llama-3/3b-qat-fsdp2.yaml
# once training is complete, perform the quantization step
axolotl quantize examples/llama-3/3b-qat-fsdp2.yaml
# you should now have a quantized model saved in ./outputs/qat_out/quatized
Please refer to the `Axolotl QAT documentation `__ for full details.
Option 3: TorchAO QAT API
=========================
If you prefer to use a different training framework or your own custom training loop,
you can call TorchAO's QAT APIs directly to transform the model before fine-tuning.
These APIs are what the TorchTune and Axolotl QAT integrations call under the hood.
In this example, we will fine-tune a mini version of Llama3 on a single GPU:
.. code:: py
import torch
from torchtune.models.llama3 import llama3
# Set up a smaller version of llama3 to fit in a single A100 GPU
# For smaller GPUs, adjust the model attributes accordingly
def get_model():
return llama3(
vocab_size=4096,
num_layers=16,
num_heads=16,
num_kv_heads=4,
embed_dim=2048,
max_seq_len=2048,
).cuda()
# Example training loop
def train_loop(m: torch.nn.Module):
optimizer = torch.optim.SGD(m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
loss_fn = torch.nn.CrossEntropyLoss()
for i in range(10):
example = torch.randint(0, 4096, (2, 16)).cuda()
target = torch.randn((2, 16, 4096)).cuda()
output = m(example)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
optimizer.zero_grad()
Next, run the prepare step, which fake quantizes the model. In this example,
we use int8 per token dynamic activations and int4 symmetric per group weights
as our quantization scheme. Note that although we are targeting lower integer
precisions, training still performs arithmetic in higher float precision (float32)
because we are not actually casting the fake quantized values.
.. code:: py
from torchao.quantization import (
quantize_,
)
from torchao.quantization.qat import (
FakeQuantizeConfig,
IntXQuantizationAwareTrainingConfig,
)
model = get_model()
# prepare: insert fake quantization ops
# swaps `torch.nn.Linear` with `FakeQuantizedLinear`
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
quantize_(model, qat_config)
# fine-tune
train_loop(model)
After fine-tuning, we end up with a model in the original high precision.
This fine-tuned model has the exact same structure as the original model.
The only difference is the QAT fine-tuned model has weights that are more
attuned to quantization, which will be beneficial later during inference.
The next step is to actually quantize the model:
.. code:: py
from torchao.quantization import (
Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.qat import (
FromIntXQuantizationAwareTrainingConfig,
)
# convert: transform fake quantization ops into actual quantized ops
# swap `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts
# quantized activation and weight tensor subclasses
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))
Now our model is ready for serving, and will typically have higher quantized
accuracy than if we did not apply the prepare step (fake quantization) during
fine-tuning.
For full details of using TorchAO's QAT API, please refer to the `QAT README `__.
.. raw:: html
Alternative Legacy API
The above `quantize_` API is the recommended flow for using TorchAO QAT.
We also offer an alternative legacy "quantizer" API for specific quantization
schemes, but these are not customizable unlike the above example.
.. code::
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
qat_quantizer = Int8DynActInt4WeightQATQuantizer(group_size=32)
# prepare: insert fake quantization ops
# swaps `torch.nn.Linear` with `Int8DynActInt4WeightQATLinear`
model = qat_quantizer.prepare(model)
# train
train_loop(model)
# convert: transform fake quantization ops into actual quantized ops
# swaps `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`
model = qat_quantizer.convert(model)
.. raw:: html
Quantized Low-Rank Adaptation (QLoRA)
#####################################
(Coming soon!)
Float8 Quantized Fine-tuning
############################
(Coming soon!)