Quantization
The Vulkan backend currently supports execution of quantized linear layers, where weights are symmetrically quantized to 8 bits or 4 bits with per-output-channel or per-group quantization scales.
Support for additional quantized operators and quantization schemes (e.g. statically and dynamically quantized convolution, statically quantized linear) is under active development and will be added soon.
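To make the supported scheme concrete, here is a minimal sketch of what symmetric per-group quantization computes for a linear layer's weight matrix: every group of group_size consecutive values along the input dimension of each output channel shares one scale, chosen so that the group's largest absolute value maps to the top of the signed integer range. This is an illustrative reimplementation in plain PyTorch, not the code path used by torchao or the Vulkan backend.

import torch

def quantize_symmetric_per_group(w: torch.Tensor, group_size: int, bits: int = 4):
    # w: (out_channels, in_channels); in_channels must be divisible by group_size
    qmax = 2 ** (bits - 1) - 1  # 7 for 4-bit, 127 for 8-bit
    out_c, in_c = w.shape
    groups = w.reshape(out_c, in_c // group_size, group_size)
    # One scale per (output channel, group); symmetric quantization has no zero point
    scales = (groups.abs().amax(dim=-1, keepdim=True) / qmax).clamp_min(1e-8)
    # Quantized values are stored in an int8 container for this sketch
    q = torch.clamp(torch.round(groups / scales), -qmax - 1, qmax).to(torch.int8)
    return q.reshape(out_c, in_c), scales.squeeze(-1)

# Per-output-channel quantization is the special case where group_size == in_channels
w = torch.randn(64, 128)
q_weight, w_scales = quantize_symmetric_per_group(w, group_size=32, bits=4)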
4-bit quantization with torchao quantize_
The quantize_ API from torchao allows for more advanced quantization schemes, and is the quantization workflow needed to access 4-bit quantization. 4-bit quantization is commonly used for LLMs.
Two options are available to execute linear layers with 4-bit quantization:

- Dynamically quantized activations via Int8DynamicActivationIntxWeightConfig
- Weight-only quantization via IntxWeightOnlyConfig
Dynamically quantized activations can provide a significant latency improvement compared to weight-only quantization, since they allow the GPU to leverage accelerated integer dot product instructions when computing matrix multiplications.
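As a rough illustration of why this helps (a schematic reimplementation in plain PyTorch, not the kernel the backend actually runs): activation scales are computed at runtime, the activations are rounded to 8-bit integers, and the matrix multiplication can then accumulate integer products in int32 before rescaling back to floating point.

import torch

def dynamic_int8_linear(x, q_weight, w_scales):
    # x: (M, K) float activations; q_weight: (N, K) int8 weights; w_scales: (N,) weight scales
    # Per-row symmetric activation scales, computed on the fly ("dynamic" quantization)
    x_scales = (x.abs().amax(dim=-1, keepdim=True) / 127.0).clamp_min(1e-8)
    q_x = torch.clamp(torch.round(x / x_scales), -127, 127).to(torch.int8)
    # Integer multiply-accumulate in int32; on the GPU this inner loop is where
    # accelerated integer dot product instructions can be used
    acc = (q_x.to(torch.int32).unsqueeze(1) * q_weight.to(torch.int32).unsqueeze(0)).sum(dim=-1)
    # Rescale the int32 accumulator back to floating point
    return acc.to(torch.float32) * x_scales * w_scales.view(1, -1)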
Below is a simple example of quantizing a sequence of linear layers using the quantize_ API.
import torch
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
from executorch.exir import to_edge_transform_and_lower
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    IntxWeightOnlyConfig,
    quantize_,
)
from torchao.utils import unwrap_tensor_subclass
class LinearSequenceModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(128, 64, bias=False)
        self.linear2 = torch.nn.Linear(64, 32, bias=False)
        self.linear3 = torch.nn.Linear(32, 16, bias=False)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        return x
linear_sequence_module = LinearSequenceModule()

M = 32
sample_inputs = (torch.randn(M, 128),)

group_size = 32

# Option 1: dynamically quantized activations with 4-bit per-group weights
q_config_8da4w = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4, weight_granularity=PerGroup(group_size)
)
# Option 2: weight-only 4-bit per-group quantization
q_config_4w = IntxWeightOnlyConfig(
    weight_dtype=torch.int4, granularity=PerGroup(group_size)
)

# Quantize the module in-place with the chosen config
quantize_(linear_sequence_module, q_config_8da4w)
# Unwrap tensor subclasses so the module can be traced with torch.export
unwrap_tensor_subclass(linear_sequence_module)
# Regular export path from here
exported_program = torch.export.export(linear_sequence_module, sample_inputs)
etvk_program = to_edge_transform_and_lower(
    exported_program,
    partitioner=[VulkanPartitioner()],
).to_executorch()
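At this point the lowered program can be serialized to a .pte file for the ExecuTorch runtime; a minimal sketch, assuming the standard ExecuTorch serialization API and an arbitrary file name:

# Write the serialized program to disk
with open("linear_sequence_vulkan.pte", "wb") as f:
    f.write(etvk_program.buffer)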
8-bit quantization with PT2E quantization
For 8-bit quantized linear layers, the only quantization scheme currently supported is weight-only quantization, where weights are symmetrically quantized to 8 bits with per-output-channel quantization scales.
To access this quantization mode, the PT2E quantization flow must be used. At a high level, the steps to quantize a model are:
1. Create an instance of the VulkanQuantizer class and specify the desired quantization behaviour.
2. Use torch.export.export to prepare for quantization.
3. Call prepare_pt2e to prepare the exported graph for quantization.
4. Execute the prepared model with representative samples to calibrate the quantized tensor activation ranges.
5. Call convert_pt2e to quantize the model.
6. Export and lower the model using the standard flow.
For example:
import torch
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
    get_symmetric_quantization_config,
    VulkanQuantizer,
)
from executorch.exir import to_edge_transform_and_lower
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
class LinearSequenceModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(128, 64, bias=False)
        self.linear2 = torch.nn.Linear(64, 32, bias=False)
        self.linear3 = torch.nn.Linear(32, 16, bias=False)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        return x
linear_sequence_module = LinearSequenceModule()
M = 32
# Create sample inputs
sample_inputs = (torch.randn(M, 128),)
# Setup quantizer
quantizer = VulkanQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_dynamic=False, weight_bits=8))
# Export the model
exported_program = torch.export.export(linear_sequence_module, sample_inputs)
graph_module = exported_program.module()
# Quantize the exported program with PT2E quantization flow
quantized_module = prepare_pt2e(graph_module, quantizer)
# Calibrate. In practice, this would be done by iterating over a real dataset
quantized_module(*sample_inputs)
quantized_module = convert_pt2e(quantized_module)
# Export once more
exported_program = torch.export.export(quantized_module, sample_inputs)
# Lower to vulkan
etvk_program = to_edge_transform_and_lower(
    exported_program,
    partitioner=[VulkanPartitioner()],
).to_executorch()
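As with the 4-bit example, the lowered program can be serialized to a .pte file. Assuming the ExecuTorch Python runtime bindings are available and were built with the Vulkan backend enabled, it can also be loaded and run directly from Python; the file name below is arbitrary and the snippet is only a sketch.

# Serialize the lowered program
with open("linear_sequence_vulkan_w8.pte", "wb") as f:
    f.write(etvk_program.buffer)

# Load and run with the ExecuTorch Python runtime
from executorch.runtime import Runtime

runtime = Runtime.get()
program = runtime.load_program("linear_sequence_vulkan_w8.pte")
method = program.load_method("forward")
outputs = method.execute([torch.randn(M, 128)])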