Quantization

The Vulkan backend currently supports execution of quantized linear layers, where weights are symmetrically quantized to 8-bit or 4-bit with per-output-channel or per-group quantization scales.
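
To illustrate what this scheme means, below is a minimal sketch (not the backend's actual implementation) of symmetric 4-bit per-group quantization of a linear weight, assuming a hypothetical group size of 32 along the input channel dimension. Per-output-channel quantization is the special case where the group spans the entire input dimension.

import torch

# Illustrative sketch only: compute one symmetric scale per (output channel, group)
def quantize_weight_per_group(weight: torch.Tensor, group_size: int = 32):
    qmin, qmax = -8, 7  # signed 4-bit integer range
    out_ch, in_ch = weight.shape
    grouped = weight.reshape(out_ch, in_ch // group_size, group_size)
    # Symmetric quantization: scale only, no zero point
    scales = grouped.abs().amax(dim=-1, keepdim=True) / qmax
    qweight = torch.clamp(torch.round(grouped / scales), qmin, qmax).to(torch.int8)
    return qweight.reshape(out_ch, in_ch), scales.squeeze(-1)

qweight, scales = quantize_weight_per_group(torch.randn(64, 128))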

Support for additional quantized operators and quantization schemes (i.e. static/dynamic quantized convolution, support for statically quantized linear) is under active development and will be added soon.

4-bit quantization with torchao quantize_

The quantize_ API from torchao allows for more advanced quantization schemes, and is the workflow needed to access 4-bit quantization, which is commonly used for LLMs.

Two options are available to execute linear layers with 4-bit quantization:

  1. Dynamically quantized activations via Int8DynamicActivationIntxWeightConfig

  2. Weight only quantization via IntxWeightOnlyConfig

Dynamically quantized activations can provide a significant boost in latency compared to weight only quantization, since it allows GPUs to leverage accelerated integer dot product instructions when computing matrix multiplication.

Below is a simple example of quantizing a sequence of linear layers using the quantize_ API.

import torch

from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner

from executorch.exir import to_edge_transform_and_lower
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    IntxWeightOnlyConfig,
    quantize_,
)
from torchao.utils import unwrap_tensor_subclass


class LinearSequenceModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(128, 64, bias=False)
        self.linear2 = torch.nn.Linear(64, 32, bias=False)
        self.linear3 = torch.nn.Linear(32, 16, bias=False)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        return x


linear_sequence_module = LinearSequenceModule()

M = 32
sample_inputs = (torch.randn(M, 128),)

# Number of weight elements along the input dimension that share a quantization scale
group_size = 32

# Config for dynamically quantized activations with 4-bit weights (8da4w)
q_config_8da4w = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4, weight_granularity=PerGroup(group_size)
)

# Config for 4-bit weight only quantization
q_config_4w = IntxWeightOnlyConfig(
    weight_dtype=torch.int4, granularity=PerGroup(group_size)
)

# Quantize the module in-place with the dynamic activation + 4-bit weight config
quantize_(linear_sequence_module, q_config_8da4w)
# Unwrap tensor subclasses so the quantized module can be exported
unwrap_tensor_subclass(linear_sequence_module)

# Regular export path from here
exported_program = torch.export.export(linear_sequence_module, sample_inputs)

etvk_program = to_edge_transform_and_lower(
    exported_program,
    partitioner=[VulkanPartitioner()],
).to_executorch()
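
The example above applies the dynamically quantized activation config (q_config_8da4w). To use weight only quantization instead, apply q_config_4w to a fresh module before exporting; the export and lowering steps are unchanged. A minimal sketch (the wo_module name is just for illustration):

# Weight only 4-bit quantization using the q_config_4w config defined above
wo_module = LinearSequenceModule()
quantize_(wo_module, q_config_4w)
unwrap_tensor_subclass(wo_module)
# Export and lower exactly as shown above
exported_program = torch.export.export(wo_module, sample_inputs)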

8-bit quantization with PT2E quantization

For 8-bit quantized linear layers, the only quantization scheme currently supported is weight only quantization, where weights are symmetrically quantized to 8 bits with per-output-channel quantization scales.

To access this quantization mode, the PT2E quantization flow must be used. At a high level, the steps to quantize a model are:

  1. Create an instance of the VulkanQuantizer class and specify the desired quantization behavior.

  2. Use torch.export.export to prepare for quantization.

  3. Call prepare_pt2e to prepare the exported graph for quantization.

  4. Execute the prepared model with representative samples to calibrate the quantized activation ranges.

  5. Call convert_pt2e to quantize the model.

  6. Export and lower the model using the standard flow.

For example:

import torch

from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner

from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
    get_symmetric_quantization_config,
    VulkanQuantizer,
)

from executorch.exir import to_edge_transform_and_lower

from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


class LinearSequenceModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(128, 64, bias=False)
        self.linear2 = torch.nn.Linear(64, 32, bias=False)
        self.linear3 = torch.nn.Linear(32, 16, bias=False)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        return x


linear_sequence_module = LinearSequenceModule()

M = 32
# Create sample inputs
sample_inputs = (torch.randn(M, 128),)

# Setup quantizer
quantizer = VulkanQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_dynamic=False, weight_bits=8))

# Export the model
exported_program = torch.export.export(linear_sequence_module, sample_inputs)
graph_module = exported_program.module()

# Quantize the exported program with PT2E quantization flow
quantized_module = prepare_pt2e(graph_module, quantizer)
# Calibrate. In practice, this would be done by iterating over a real dataset
quantized_module(*sample_inputs)
quantized_module = convert_pt2e(quantized_module)

# Export once more
exported_program = torch.export.export(quantized_module, sample_inputs)

# Lower to Vulkan
etvk_program = to_edge_transform_and_lower(
    exported_program,
    partitioner=[VulkanPartitioner()],
).to_executorch()
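
Once lowered, the program can be serialized to a .pte file and executed with the ExecuTorch runtime. A minimal sketch, assuming the buffer property of the returned ExecutorchProgramManager and a hypothetical file name:

# Serialize the lowered program for the ExecuTorch runtime
with open("linear_sequence_vulkan.pte", "wb") as f:
    f.write(etvk_program.buffer)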