Quantization#

To quantize a PyTorch model for the Core ML backend, use the CoreMLQuantizer. Quantizers are backend specific, which means the CoreMLQuantizer is configured to quantize models to leverage the quantized operators offered by the Core ML backend.

Supported Quantization Schemes#

The Core ML delegate supports the following quantization schemes:

  • 8-bit static and weight-only quantization via the PT2E flow; dynamic quantization is not supported by Core ML.

  • 4-bit weight-only affine quantization (per-group or per-channel) via the quantize_ flow

  • 1-8 bit weight-only LUT quantization (per grouped-channel) via the quantize_ flow

8-bit Quantization using the PT2E Flow#

Quantization with the Core ML backend requires exporting the model for iOS 17 or later. To perform 8-bit quantization with the PT2E flow, follow these steps:

  1. Create a coremltools.optimize.torch.quantization.LinearQuantizerConfig and use it to create an instance of a CoreMLQuantizer.

  2. Use torch.export.export to export the model, and take its graph module (via .module()) to prepare for quantization.

  3. Call prepare_pt2e to prepare the model for quantization.

  4. Run the prepared model with representative samples to calibrate the quantized tensor activation ranges.

  5. Call convert_pt2e to quantize the model.

  6. Export and lower the model using the standard flow.

The output of convert_pt2e is a PyTorch model that can be exported and lowered using the standard flow. Because it is a regular PyTorch model, its accuracy can also be evaluated with standard PyTorch techniques before lowering.

import torch
import coremltools as ct
import torchvision.models as models
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from executorch.exir import to_edge_transform_and_lower
from executorch.backends.apple.coreml.compiler import CoreMLBackend

mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224), )

# Step 1: Define a LinearQuantizerConfig and create an instance of a CoreMLQuantizer
# Note that "linear" here does not mean only linear layers are quantized, but that linear (aka affine) quantization
# is being performed
static_8bit_config = ct.optimize.torch.quantization.LinearQuantizerConfig(
    global_config=ct.optimize.torch.quantization.ModuleLinearQuantizerConfig(
        quantization_scheme="symmetric",
        activation_dtype=torch.quint8,
        weight_dtype=torch.qint8,
        weight_per_channel=True,
    )
)
quantizer = CoreMLQuantizer(static_8bit_config)

# Step 2: Export the model and get the graph module to quantize
exported_gm = torch.export.export(mobilenet_v2, sample_inputs).module()

# Step 3: Prepare the model for quantization
prepared_model = prepare_pt2e(exported_gm, quantizer)

# Step 4: Calibrate the model on representative data
# Replace with your own calibration data
for calibration_sample in [torch.randn(1, 3, 224, 224)]:
    prepared_model(calibration_sample)

# Step 5: Convert the calibrated model to a quantized model
quantized_model = convert_pt2e(prepared_model)

# Step 6: Export the quantized model to Core ML
et_program = to_edge_transform_and_lower(
    torch.export.export(quantized_model, sample_inputs),
    partitioner=[
        CoreMLPartitioner(
            # iOS17 is required for the quantized ops in this example
            compile_specs=CoreMLBackend.generate_compile_specs(
                minimum_deployment_target=ct.target.iOS17
            )
        )
    ],
).to_executorch()

The example above performs static quantization (both activations and weights are quantized).
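
Because the converted model is still a regular PyTorch model, you can sanity-check its accuracy before lowering. Below is a minimal sketch that reuses mobilenet_v2, sample_inputs, and quantized_model from the example above; treat it as a quick smoke test, not a substitute for evaluating on your real dataset and metric.

# Quick smoke test: compare float vs. quantized outputs on one sample
with torch.no_grad():
    float_out = mobilenet_v2(*sample_inputs)
    quant_out = quantized_model(*sample_inputs)

cosine_sim = torch.nn.functional.cosine_similarity(
    float_out.flatten(), quant_out.flatten(), dim=0
)
print(f"Cosine similarity (float vs. quantized): {cosine_sim.item():.4f}")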

You can see a full description of available quantization configs in the coremltools documentation. For example, the config below will perform weight-only quantization:

weight_only_8bit_config = ct.optimize.torch.quantization.LinearQuantizerConfig(
    global_config=ct.optimize.torch.quantization.ModuleLinearQuantizerConfig(
        quantization_scheme="symmetric",
        activation_dtype=torch.float32,
        weight_dtype=torch.qint8,
        weight_per_channel=True,
    )
)
quantizer = CoreMLQuantizer(weight_only_8bit_config)

Quantizing activations requires calibrating the model on representative data. Also note that PT2E currently requires passing at least 1 calibration sample before calling convert_pt2e, even for data-free weight-only quantization.
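
For example, here is a minimal weight-only sketch that reuses mobilenet_v2 and sample_inputs from the earlier example together with the weight-only quantizer defined above; a single pass with a dummy sample satisfies the calibration requirement:

# Weight-only quantization follows the same prepare/convert steps,
# but no activation statistics are collected, so one dummy sample is enough
weight_only_gm = torch.export.export(mobilenet_v2, sample_inputs).module()
prepared_model = prepare_pt2e(weight_only_gm, quantizer)
prepared_model(*sample_inputs)  # single calibration pass
quantized_model = convert_pt2e(prepared_model)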

See PyTorch 2 Export Post Training Quantization for more information.

LLM quantization with quantize_#

The Core ML backend also supports quantizing models with the torchao quantize_ API. This is most commonly used for LLMs, which often require more advanced quantization schemes. Because quantize_ is not backend aware, it is important to use a config that is compatible with Core ML:

  • Quantize embedding/linear layers with IntxWeightOnlyConfig (with weight_dtype torch.int4 or torch.int8, using PerGroup or PerAxis granularity). Using 4-bit or PerGroup quantization requires exporting with minimum_deployment_target >= ct.target.iOS18; 8-bit per-axis quantization is supported on ct.target.iOS16+. See Core ML CompileSpec for more information on setting the deployment target.

  • Quantize embedding/linear layers with CodebookWeightOnlyConfig (with dtype torch.uint1 through torch.uint8, using various block sizes). Quantizing with CodebookWeightOnlyConfig requires exporting with minimum_deployment_target >= ct.target.iOS18, see Core ML CompileSpec for more information on setting the deployment target.

Below is an example that quantizes embeddings to 8 bits per-axis and linear layers to 4 bits per-group (group size 32) with affine quantization:

import torch

from torchao.quantization.granularity import PerGroup, PerAxis
from torchao.quantization.quant_api import (
    IntxWeightOnlyConfig,
    quantize_,
)

# Quantize embedding weights to 8 bits, per channel
# (eager_model is the eager-mode nn.Module, e.g. your LLM, to be quantized)
embedding_config = IntxWeightOnlyConfig(
    weight_dtype=torch.int8,
    granularity=PerAxis(0),
)
quantize_(
    eager_model,
    embedding_config,
    lambda m, fqn: isinstance(m, torch.nn.Embedding),
)

# Quantize linear layers with 4-bits, per-group
linear_config = IntxWeightOnlyConfig(
    weight_dtype=torch.int4,
    granularity=PerGroup(32),
)
quantize_(
    eager_model,
    linear_config,
)

Below is another example that uses codebook (LUT) quantization to quantize both embeddings and linear layers to 3 bits. In the coremltools documentation, this is called palettization:

import torch

from torchao.quantization.quant_api import (
    quantize_,
)
from torchao.prototype.quantization.codebook_coreml import CodebookWeightOnlyConfig

quant_config = CodebookWeightOnlyConfig(
    dtype=torch.uint3,
    # There is one LUT per 16 rows
    block_size=[16, -1],
)

quantize_(
    eager_model,
    quant_config,
    lambda m, fqn: isinstance(m, torch.nn.Embedding) or isinstance(m, torch.nn.Linear),
)

Both of the above examples can then be exported and lowered to Core ML with the to_edge_transform_and_lower API, as in the sketch below.
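
For reference, here is a minimal lowering sketch; eager_model and example_inputs are placeholders for your own quantized model and sample inputs, and the iOS18 deployment target covers the 4-bit and codebook (LUT) quantized ops used above:

import torch
import coremltools as ct
from executorch.exir import to_edge_transform_and_lower
from executorch.backends.apple.coreml.compiler import CoreMLBackend
from executorch.backends.apple.coreml.partition import CoreMLPartitioner

et_program = to_edge_transform_and_lower(
    torch.export.export(eager_model, example_inputs),
    partitioner=[
        CoreMLPartitioner(
            # iOS18 is required for 4-bit and codebook (LUT) quantized ops
            compile_specs=CoreMLBackend.generate_compile_specs(
                minimum_deployment_target=ct.target.iOS18
            )
        )
    ],
).to_executorch()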