Quantization

Introduction to Quantization

Quantization refers to techniques for performing computations and storing tensors at lower bit widths than floating point precision. A quantized model executes some or all of its operations on tensors with integers rather than floating point values. This allows for a more compact model representation and the use of high performance vectorized operations on many hardware platforms. Compared to typical FP32 models, INT8 quantization in PyTorch allows for a 4x reduction in the model size and a 4x reduction in memory bandwidth requirements. Hardware support for INT8 computations is typically 2 to 4 times faster than FP32 compute. Quantization is primarily a technique to speed up inference, and only the forward pass is supported for quantized operators.

PyTorch supports multiple approaches to quantizing a deep learning model. In most cases the model is trained in FP32 and then converted to INT8. In addition, PyTorch supports quantization aware training, which models quantization errors in both the forward and backward passes using fake-quantization modules. Note that during quantization aware training the entire computation is carried out in floating point. At the end of quantization aware training, PyTorch provides conversion functions to convert the trained model into lower precision.

At a lower level, PyTorch provides a way to represent quantized tensors and perform operations with them. They can be used to directly construct models that perform all or part of the computation in lower precision. Higher-level APIs are provided that incorporate typical workflows of converting an FP32 model to lower precision with minimal accuracy loss.

Today, PyTorch supports the following backends for running quantized operators efficiently:

  • x86 CPUs with AVX2 support or higher (without AVX2 some operations have inefficient implementations)

  • ARM CPUs (typically found in mobile/embedded devices)

The corresponding implementation is chosen automatically based on the PyTorch build mode.

Note

PyTorch 1.3 doesn’t provide quantized operator implementations on CUDA yet; this is a direction of future work. Move the model to CPU in order to test the quantized functionality.

Quantization-aware training (through FakeQuantize) supports both CPU and CUDA.

Quantized Tensors

PyTorch supports both per tensor and per channel asymmetric linear quantization. Per tensor means that all the values within the tensor are scaled the same way. Per channel means that along one dimension, typically the channel dimension of a tensor, each slice of the tensor is scaled and offset by a different value (effectively the scale and offset become vectors). This reduces the error in converting tensors to quantized values.

The mapping is performed by converting the floating point tensors using

Q(x, scale, zero_point) = round(x / scale + zero_point)

Note that we ensure that zero in floating point is represented with no error after quantization, thereby ensuring that operations like padding do not cause additional quantization error.
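As a rough illustration of the affine mapping, the sketch below quantizes a few floats to unsigned 8-bit codes and maps them back. It assumes the mapping has the form round(x / scale + zero_point) followed by clamping to the integer range, the same expression that appears in the FakeQuantize formula later in this document. The helper names quantize and dequantize and the chosen scale and zero_point are hypothetical and not part of the PyTorch API.

```python
# Minimal pure-Python sketch of per-tensor affine quantization (not PyTorch code).

def quantize(x, scale, zero_point, qmin=0, qmax=255):
    """Map a float to an integer code: round(x / scale + zero_point), then clamp."""
    q = round(x / scale + zero_point)
    return max(qmin, min(qmax, q))


def dequantize(q, scale, zero_point):
    """Map an integer code back to the float it represents."""
    return (q - zero_point) * scale


# Hypothetical parameters chosen so the arithmetic is easy to follow.
scale, zero_point = 0.5, 2

values = [-1.0, 0.0, 0.8, 2.0]
codes = [quantize(v, scale, zero_point) for v in values]
restored = [dequantize(q, scale, zero_point) for q in codes]

print(codes)     # [0, 2, 4, 6]
print(restored)  # [-1.0, 0.0, 1.0, 2.0]
```

In this sketch 0.0 maps to the code zero_point and back to exactly 0.0, while 0.8 comes back as 1.0: values that do not sit on the quantization grid pick up a rounding error.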

In order to do quantization in PyTorch, we need to be able to represent quantized data in Tensors. A Quantized Tensor allows for storing quantized data (represented as int8/uint8/int32) along with quantization parameters like scale and zero_point. Quantized Tensors allow for many useful operations making quantized arithmetic easy, in addition to allowing for serialization of data in a quantized format.
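A short sketch of creating and inspecting Quantized Tensors with functions listed in this document. The tensor values, scales and zero points are arbitrary illustrative choices.

```python
import torch

x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])

# Per tensor: one scale and zero_point shared by every element.
q = torch.quantize_per_tensor(x, scale=0.5, zero_point=2, dtype=torch.quint8)
print(q.q_scale(), q.q_zero_point())  # quantization parameters stored with the data
print(q.int_repr())                   # underlying uint8 values
print(q.dequantize())                 # converted back to a float tensor

# Per channel: one scale and zero_point per slice along `axis` (here, per row).
scales = torch.tensor([0.5, 0.25])
zero_points = torch.tensor([2, 0])
qc = torch.quantize_per_channel(x, scales, zero_points, axis=0, dtype=torch.quint8)
print(qc.int_repr())
print(qc.q_per_channel_scales(), qc.q_per_channel_axis())
```

The integer representation differs between the two tensors for the second row because that row uses its own scale and zero point in the per channel case.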

Operation coverage

Quantized Tensors support a limited subset of the data manipulation methods of regular full-precision tensors (see the list below).

For NN operators included in PyTorch, we restrict support to:

  1. 8 bit weights (data_type = qint8)

  2. 8 bit activations (data_type = quint8)

Note that operator implementations currently only support per channel quantization for weights of the conv and linear operators. Furthermore, the minimum and the maximum of the input data are mapped linearly to the minimum and the maximum of the quantized data type such that zero is represented with no quantization error.

Additional data types and quantization schemes can be implemented through the custom operator mechanism.

Many operations for quantized tensors are available under the same API as the full float versions in torch or torch.nn. Quantized versions of NN modules that perform re-quantization are available in torch.nn.quantized. Those operations explicitly take output quantization parameters (scale and zero_point) in the operation signature.

In addition, we also support fused versions corresponding to common fusion patterns that impact quantization in torch.nn.intrinsic.quantized.

For quantization aware training, we support modules prepared for quantization aware training in torch.nn.qat and torch.nn.intrinsic.qat.

The current quantized operation list is sufficient to cover typical CNN and RNN models:

Quantized torch.Tensor operations

Operations that are available from the torch namespace or as methods on Tensor for quantized tensors:

  • quantize_per_tensor() - Convert float tensor to quantized tensor with per-tensor scale and zero point

  • quantize_per_channel() - Convert float tensor to quantized tensor with per-channel scale and zero point

  • View-based operations like view(), as_strided(), expand(), flatten(), slice(), python-style indexing, etc - work as on regular tensor (if quantization is not per-channel)

  • Comparators
  • copy_() — Copies src to self in-place

  • clone() — Returns a deep copy of the passed-in tensor

  • dequantize() — Convert quantized tensor to float tensor

  • equal() — Compares two tensors, returns true if quantization parameters and all integer elements are the same

  • int_repr() — Returns the underlying integer representation of the quantized tensor

  • max() — Returns the maximum value of the tensor (reduction only)

  • mean() — Mean function. Supported variants: reduction, dim, out

  • min() — Returns the minimum value of the tensor (reduction only)

  • q_scale() — Returns the scale of the per-tensor quantized tensor

  • q_zero_point() — Returns the zero_point of the per-tensor quantized tensor

  • q_per_channel_scales() — Returns the scales of the per-channel quantized tensor

  • q_per_channel_zero_points() — Returns the zero points of the per-channel quantized tensor

  • q_per_channel_axis() — Returns the channel axis of the per-channel quantized tensor

  • relu() — Rectified linear unit (copy)

  • relu_() — Rectified linear unit (inplace)

  • resize_() — In-place resize

  • sort() — Sorts the tensor

  • topk() — Returns k largest values of a tensor

torch.nn.intrinsic

Fused modules are provided for common patterns in CNNs. Combining several operations together (like convolution and relu) allows for better quantization accuracy.

  • torch.nn.intrinsic — float versions of the modules, can be swapped with quantized version 1 to 1
  • torch.nn.intrinsic.qat — versions of layers for quantization-aware training
  • torch.nn.intrinsic.quantized — quantized version of fused layers for inference (no BatchNorm variants as it’s usually folded into convolution for inference)

torch.nn.qat

Layers for the quantization-aware training

  • Linear — Linear (fully-connected) layer

  • Conv2d — 2D convolution

torch.quantization

  • Functions for quantization
    • add_observer_() — Adds observer for the leaf modules (if quantization configuration is provided)

    • add_quant_dequant() — Wraps the leaf child module using QuantWrapper

    • convert() — Converts float module with observers into its quantized counterpart. Must have quantization configuration

    • get_observer_dict() — Traverses the module children and collects all observers into a dict

    • prepare() — Prepares a copy of a model for quantization

    • prepare_qat() — Prepares a copy of a model for quantization aware training

    • propagate_qconfig_() — Propagates quantization configurations through the module hierarchy and assigns them to each leaf module

    • quantize() — Converts a float module to quantized version

    • quantize_dynamic() — Converts a float module to dynamically quantized version

    • quantize_qat() — Converts a float module to quantized version used in quantization aware training

    • swap_module() — Swaps the module with its quantized counterpart (if quantizable and if it has an observer)

  • default_eval_fn() — Default evaluation function used by the torch.quantization.quantize()

  • fuse_modules()

  • FakeQuantize — Module for simulating the quantization/dequantization at training time

  • Default Observers. The rest of the observers are available from torch.quantization.observer
    • default_observer — Same as MinMaxObserver.with_args(reduce_range=True)

    • default_weight_observer — Same as MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)

    • Observer — Abstract base class for observers

  • Quantization configurations
    • QConfig — Quantization configuration class

    • default_qconfig — Same as QConfig(activation=default_observer, weight=default_weight_observer) (See QConfig)

    • default_qat_qconfig — Same as QConfig(activation=default_fake_quant, weight=default_weight_fake_quant) (See QConfig)

    • default_dynamic_qconfig — Same as QConfigDynamic(weight=default_weight_observer) (See QConfigDynamic)

    • float16_dynamic_qconfig — Same as QConfigDynamic(weight=NoopObserver.with_args(dtype=torch.float16)) (See QConfigDynamic)

  • Stubs

Observers for computing the quantization parameters

  • MinMaxObserver — Derives the quantization parameters from the running minimum and maximum of the observed tensor inputs (per tensor variant)

  • MovingAverageMinMaxObserver — Derives the quantization parameters from the running averages of the minimums and maximums of the observed tensor inputs (per tensor variant)

  • PerChannelMinMaxObserver — Derives the quantization parameters from the running minimum and maximum of the observed tensor inputs (per channel variant)

  • MovingAveragePerChannelMinMaxObserver — Derives the quantization parameters from the running averages of the minimums and maximums of the observed tensor inputs (per channel variant)

  • HistogramObserver — Derives the quantization parameters by creating a histogram of running minimums and maximums.

  • Observers that do not compute the quantization parameters:
    • RecordingObserver — Records all incoming tensors. Used for debugging only.

    • NoopObserver — Pass-through observer. Used for situations where there are no quantization parameters (e.g. quantization to float16)
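To make the observer idea concrete, here is a pure-Python sketch of a min/max observer: it tracks the running minimum and maximum and turns them into a scale and zero_point for an unsigned 8-bit range. The class name MinMaxSketch and its observe method are hypothetical, and the formulas are one common way of deriving affine parameters (with the range extended to include zero so that 0.0 is exactly representable), not a transcription of the PyTorch implementation.

```python
class MinMaxSketch:
    """Toy per-tensor observer: running min/max -> (scale, zero_point)."""

    def __init__(self, qmin=0, qmax=255):
        self.qmin, self.qmax = qmin, qmax
        self.min_val = None
        self.max_val = None

    def observe(self, values):
        """Update the running minimum and maximum with one batch of floats."""
        lo, hi = min(values), max(values)
        self.min_val = lo if self.min_val is None else min(self.min_val, lo)
        self.max_val = hi if self.max_val is None else max(self.max_val, hi)

    def calculate_qparams(self):
        """Derive affine quantization parameters from the collected statistics."""
        # Extend the range to contain 0.0 so that zero has an exact integer code.
        lo = min(self.min_val, 0.0)
        hi = max(self.max_val, 0.0)
        scale = (hi - lo) / (self.qmax - self.qmin)
        if scale == 0.0:  # every observed value was zero
            scale = 1.0
        zero_point = self.qmin - round(lo / scale)
        zero_point = max(self.qmin, min(self.qmax, zero_point))
        return scale, zero_point


obs = MinMaxSketch()
obs.observe([-1.0, 0.5])  # first calibration batch
obs.observe([0.2, 3.0])   # second calibration batch
scale, zero_point = obs.calculate_qparams()
print(scale, zero_point)
```

In this sketch a moving-average variant would differ only in observe, blending each new minimum and maximum into the stored values instead of keeping the extremes.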

torch.nn.quantized

Quantized version of standard NN layers.

  • Quantize — Quantization layer, used to automatically replace QuantStub

  • DeQuantize — Dequantization layer, used to replace DeQuantStub

  • FloatFunctional — Wrapper class to make stateless float operations stateful so that they can be replaced with quantized versions

  • QFunctional — Wrapper class for quantized versions of stateless operations like torch.add

  • Conv2d — 2D convolution

  • Linear — Linear (fully-connected) layer

  • MaxPool2d — 2D max pooling

  • ReLU — Rectified linear unit

  • ReLU6 — Rectified linear unit with cut-off at quantized representation of 6

torch.nn.quantized.dynamic

Layers used in dynamically quantized models (i.e. quantized only on weights)

  • Linear — Linear (fully-connected) layer

  • LSTM — Long Short-Term Memory RNN module

torch.nn.quantized.functional

Functional versions of quantized NN layers (many of them accept explicit quantization output parameters)

  • adaptive_avg_pool2d() — 2D adaptive average pooling

  • avg_pool2d() — 2D average pooling

  • conv2d() — 2D convolution

  • interpolate() — Down-/up- sampler

  • linear() — Linear (fully-connected) op

  • max_pool2d() — 2D max pooling

  • relu() — Rectified linear unit

  • upsample() — Upsampler. Will be deprecated in favor of interpolate()

  • upsample_bilinear() — Bilinear upsampler. Will be deprecated in favor of interpolate()

  • upsample_nearest() — Nearest neighbor upsampler. Will be deprecated in favor of interpolate()

Quantized dtypes and quantization schemes

  • torch.qscheme — Type to describe the quantization scheme of a tensor. Supported types:
    • torch.per_tensor_affine — per tensor, asymmetric

    • torch.per_channel_affine — per channel, asymmetric

    • torch.per_tensor_symmetric — per tensor, symmetric

    • torch.per_channel_symmetric — per channel, symmetric

  • torch.dtype — Type to describe the data. Supported types:
    • torch.quint8 — 8-bit unsigned integer

    • torch.qint8 — 8-bit signed integer

    • torch.qint32 — 32-bit signed integer

Quantization Workflows

PyTorch provides three approaches to quantize models.

  1. Post Training Dynamic Quantization: This is the simplest form of quantization to apply, where the weights are quantized ahead of time but the activations are dynamically quantized during inference. This is used for situations where the model execution time is dominated by loading weights from memory rather than computing the matrix multiplications. This is true for LSTM and Transformer type models with small batch size. Applying dynamic quantization to a whole model can be done with a single call to torch.quantization.quantize_dynamic(). See the quantization tutorials.

  2. Post Training Static Quantization: This is the most commonly used form of quantization, where the weights are quantized ahead of time and the scale factor and bias for the activation tensors are pre-computed based on observing the behavior of the model during a calibration process. Post Training Quantization is typically used when both memory bandwidth and compute savings are important, with CNNs being a typical use case. The general process for doing post training quantization is:

    1. Prepare the model:

      a. Specify where the activations are quantized and dequantized explicitly by adding QuantStub and DeQuantStub modules.

      b. Ensure that modules are not reused.

      c. Convert any operations that require requantization into modules.

    2. Fuse operations like conv + relu or conv + batchnorm + relu together to improve both model accuracy and performance.

    3. Specify the configuration of the quantization methods, such as selecting symmetric or asymmetric quantization and MinMax or L2Norm calibration techniques.

    4. Use torch.quantization.prepare() to insert modules that will observe activation tensors during calibration.

    5. Calibrate the model by running inference against a calibration dataset.

    6. Finally, convert the model itself with the torch.quantization.convert() method. This does several things: it quantizes the weights, computes and stores the scale and bias value to be used with each activation tensor, and replaces key operators with quantized implementations.

    See the quantization tutorials

  3. Quantization Aware Training: In the rare cases where post training quantization does not provide adequate accuracy, training can be done with simulated quantization using torch.quantization.FakeQuantize. Computations will take place in FP32 but with values clamped and rounded to simulate the effects of INT8 quantization. The sequence of steps is very similar.

    1. Steps (1) and (2) of post training static quantization apply unchanged.

    2. Specify the configuration of the fake quantization methods, such as selecting symmetric or asymmetric quantization and MinMax or Moving Average or L2Norm calibration techniques.

    3. Use torch.quantization.prepare_qat() to insert modules that will simulate quantization during training.

    4. Train or fine tune the model.

    5. Identical to step (6) for post training quantization.

    See the quantization tutorials
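The post training static quantization steps listed above can be sketched end to end on a toy model. SmallNet is a hypothetical module invented for this illustration, random tensors stand in for a real calibration dataset, and default_qconfig is used only because it is the simplest configuration mentioned in this document. Treat it as a sketch of the call sequence rather than a recommended recipe.

```python
import torch
import torch.nn as nn


class SmallNet(nn.Module):
    """Hypothetical toy model used only to illustrate the steps."""

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # float -> quantized
        self.conv = nn.Conv2d(1, 4, kernel_size=3)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()  # quantized -> float

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.conv(x))
        return self.dequant(x)


# Step 1: the model already contains QuantStub/DeQuantStub; use eval mode.
model = SmallNet().eval()

# Step 2: fuse conv + relu into a single module.
model = torch.quantization.fuse_modules(model, [["conv", "relu"]])

# Step 3: choose a quantization configuration for the whole model.
model.qconfig = torch.quantization.default_qconfig

# Step 4: insert observers.
prepared = torch.quantization.prepare(model)

# Step 5: calibrate (random tensors stand in for a calibration dataset).
with torch.no_grad():
    for _ in range(4):
        prepared(torch.randn(2, 1, 8, 8))

# Step 6: convert to a quantized model and run it.
quantized = torch.quantization.convert(prepared)
out = quantized(torch.randn(2, 1, 8, 8))
print(type(quantized.conv).__name__, out.shape)
```

The input and output of the converted model remain float tensors because the QuantStub and DeQuantStub mark where quantization starts and ends.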

While default implementations of observers to select the scale factor and bias based on observed tensor data are provided, developers can provide their own quantization functions. Quantization can be applied selectively to different parts of the model or configured differently for different parts of the model.

We also provide support for per channel quantization for conv2d() and linear().

Quantization workflows work by adding (e.g. adding observers as .observer submodule) or replacing (e.g. converting nn.Conv2d to nn.quantized.Conv2d) submodules in the model’s module hierarchy. It means that the model stays a regular nn.Module-based instance throughout the process and thus can work with the rest of PyTorch APIs.

Model Preparation for Quantization

It is currently necessary to make some modifications to the model definition prior to quantization. This is because quantization currently works on a module by module basis. Specifically, for all quantization techniques, the user needs to:

  1. Convert any operations that require output requantization (and thus have additional parameters) from functionals to module form.

  2. Specify which parts of the model need to be quantized, either by assigning .qconfig attributes on submodules or by specifying qconfig_dict.

For static quantization techniques which quantize activations, the user needs to do the following in addition:

  1. Specify where activations are quantized and de-quantized. This is done using QuantStub and DeQuantStub modules.

  2. Use torch.nn.quantized.FloatFunctional to wrap tensor operations that require special handling for quantization into modules. Examples are operations like add and cat which require special handling to determine output quantization parameters.

  3. Fuse modules: combine operations/modules into a single module to obtain higher accuracy and performance. This is done using the torch.quantization.fuse_modules() API, which takes in lists of modules to be fused. We currently support the following fusions: [Conv, Relu], [Conv, BatchNorm], [Conv, BatchNorm, Relu], [Linear, Relu]
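As an illustration of the second point, the sketch below replaces a plain `x + y` in a skip connection with a FloatFunctional module so that the addition gets its own output quantization parameters. ResidualBlock is a hypothetical module made up for this example, and random tensors stand in for calibration data.

```python
import torch
import torch.nn as nn
from torch.nn.quantized import FloatFunctional


class ResidualBlock(nn.Module):
    """Hypothetical block with a skip connection, for illustration only."""

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.fc = nn.Linear(8, 8)
        # Module form of `x + y`, so the add can carry its own observer.
        self.skip_add = FloatFunctional()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        y = self.skip_add.add(x, self.fc(x))  # instead of `x + self.fc(x)`
        return self.dequant(y)


model = ResidualBlock().eval()
model.qconfig = torch.quantization.default_qconfig  # quantize the whole block

prepared = torch.quantization.prepare(model)
with torch.no_grad():
    for _ in range(4):                    # stand-in calibration batches
        prepared(torch.randn(4, 8))

quantized = torch.quantization.convert(prepared)
out = quantized(torch.randn(4, 8))
print(type(quantized.skip_add).__name__, out.shape)
```

After conversion the FloatFunctional instance is expected to be replaced by its quantized counterpart, which is why the addition has to live in a module rather than in a bare operator.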

torch.quantization

This module implements the functions you call directly to convert your model from FP32 to quantized form. For example, prepare() is used in post training quantization to prepare your model for the calibration step, and convert() actually converts the weights to int8 and replaces the operations with their quantized counterparts. There are other helper functions for things like quantizing the input to your model and performing critical fusions like conv+relu.

Top-level quantization APIs

torch.quantization.quantize(model, run_fn, run_args, mapping=None, inplace=False)

Converts a float model to a quantized model.

First it prepares the model for calibration or training, then it calls run_fn, which runs the calibration step or training step; after that it calls convert, which converts the model to a quantized model.

Parameters
  • model – input model

  • run_fn – a function for evaluating the prepared model, can be a function that simply runs the prepared model or a training loop

  • run_args – positional arguments for run_fn

  • inplace – carry out model transformations in-place, the original module is mutated

  • mapping – correspondence between original module types and quantized counterparts

Returns

Quantized model.

torch.quantization.quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8, mapping=None, inplace=False)

Converts a float model to a dynamic (i.e. weights-only) quantized model.

Replaces specified modules with dynamic weight-only quantized versions and outputs the quantized model.

For the simplest usage, provide the dtype argument, which can be float16 or qint8. Weight-only quantization by default is performed for layers with large weights size, i.e. Linear and RNN variants.

Fine grained control is possible with qconfig_spec and mapping, which act similarly to quantize(). If a qconfig is provided, the dtype argument is ignored.

Parameters
  • model – input model

  • qconfig_spec – Either:

    • A dictionary that maps from name or type of submodule to quantization configuration. A qconfig applies to all submodules of a given module unless a qconfig for the submodules is specified (when the submodule already has a qconfig attribute). Entries in the dictionary need to be QConfigDynamic instances.

    • A set of types and/or submodule names to apply dynamic quantization to, in which case the dtype argument is used to specify the bit-width

  • inplace – carry out model transformations in-place, the original module is mutated

  • mapping – maps type of a submodule to a type of corresponding dynamically quantized version with which the submodule needs to be replaced
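A minimal sketch of the simplest usage described above: pass a set of module types as qconfig_spec and let dtype select the weight precision. The small Sequential model is a hypothetical stand-in for a real network.

```python
import torch
import torch.nn as nn

# Hypothetical float model; any module containing nn.Linear layers would do.
model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4)).eval()

# Quantize the weights of every nn.Linear to qint8; activations are
# quantized on the fly at inference time.
qmodel = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

x = torch.randn(2, 16)
out = qmodel(x)  # inputs and outputs stay in floating point
print(qmodel)
print(out.shape)
```

Because inplace defaults to False, the original float model is left untouched and can be used to compare outputs against the quantized copy.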

torch.quantization.quantize_qat(model, run_fn, run_args, inplace=False)

Does quantization aware training and outputs a quantized model.

Parameters
  • model – input model

  • run_fn – a function for evaluating the prepared model, can be a function that simply runs the prepared model or a training loop

  • run_args – positional arguments for run_fn

Returns

Quantized model.

torch.quantization.prepare(model, qconfig_dict=None, inplace=False)

Prepares a copy of the model for quantization calibration or quantization-aware training.

Quantization configuration can be passed as a qconfig_dict or assigned preemptively to individual submodules in the .qconfig attribute.

Observer or fake quant modules will be attached to the model, and qconfig will be propagated.

Parameters
  • model – input model to be modified in-place

  • qconfig_dict – dictionary that maps from name or type of submodule to quantization configuration, qconfig applies to all submodules of a given module unless qconfig for the submodules are specified (when the submodule already has qconfig attribute)

  • inplace – carry out model transformations in-place, the original module is mutated

torch.quantization.prepare_qat(model, mapping={<class 'torch.nn.modules.linear.Linear'>: <class 'torch.nn.qat.modules.linear.Linear'>, <class 'torch.nn.modules.conv.Conv2d'>: <class 'torch.nn.qat.modules.conv.Conv2d'>, <class 'torch.nn.intrinsic.modules.fused.ConvBn2d'>: <class 'torch.nn.intrinsic.qat.modules.conv_fused.ConvBn2d'>, <class 'torch.nn.intrinsic.modules.fused.ConvBnReLU2d'>: <class 'torch.nn.intrinsic.qat.modules.conv_fused.ConvBnReLU2d'>, <class 'torch.nn.intrinsic.modules.fused.ConvReLU2d'>: <class 'torch.nn.intrinsic.qat.modules.conv_fused.ConvReLU2d'>, <class 'torch.nn.intrinsic.modules.fused.LinearReLU'>: <class 'torch.nn.intrinsic.qat.modules.linear_relu.LinearReLU'>}, inplace=False)

torch.quantization.convert(module, mapping=None, inplace=False)

Converts the float module with observers (where we can get quantization parameters) to a quantized module.

Parameters
  • module – calibrated module with observers

  • mapping – a dictionary that maps from float module type to quantized module type, can be overwritten to allow swapping user defined Modules

  • inplace – carry out model transformations in-place, the original module is mutated

class torch.quantization.QConfig

Describes how to quantize a layer or a part of the network by providing settings (observer classes) for activations and weights respectively.

Note that QConfig needs to contain observer classes (like MinMaxObserver) or a callable that returns instances on invocation, not the concrete observer instances themselves. Quantization preparation function will instantiate observers multiple times for each of the layers.

Observer classes usually have reasonable default arguments, but they can be overridden with the with_args method (which behaves like functools.partial):

my_qconfig = QConfig(activation=MinMaxObserver.with_args(dtype=torch.qint8),
                     weight=default_observer.with_args(dtype=torch.qint8))

class torch.quantization.QConfigDynamic

Describes how to dynamically quantize a layer or a part of the network by providing settings (observer classes) for weights.

It’s like QConfig, but for dynamic quantization.

Note that QConfigDynamic needs to contain observer classes (like MinMaxObserver) or a callable that returns instances on invocation, not the concrete observer instances themselves. Quantization function will instantiate observers multiple times for each of the layers.

Observer classes usually have reasonable default arguments, but they can be overridden with the with_args method (which behaves like functools.partial):

my_qconfig = QConfigDynamic(weight=default_observer.with_args(dtype=torch.qint8))

Preparing model for quantization

torch.quantization.fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=<function fuse_known_modules>)

Fuses a list of modules into a single module.

Fuses only the following sequences of modules:

  • conv, bn
  • conv, bn, relu
  • conv, relu
  • linear, relu

All other sequences are left unchanged. For these sequences, replaces the first item in the list with the fused module, replacing the rest of the modules with identity.

Parameters
  • model – Model containing the modules to be fused

  • modules_to_fuse – list of list of module names to fuse. Can also be a list of strings if there is only a single list of modules to fuse.

  • inplace – bool specifying if fusion happens in place on the model, by default a new model is returned

  • fuser_func – Function that takes in a list of modules and outputs a list of fused modules of the same length. For example, fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()]. Defaults to torch.quantization.fuse_known_modules

Returns

model with fused modules. A new copy is created if inplace=False.

Examples:

>>> m = myModel()
>>> # m is a module containing  the sub-modules below
>>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
>>> fused_m = torch.quantization.fuse_modules(m, modules_to_fuse)
>>> output = fused_m(input)

>>> m = myModel()
>>> # Alternately provide a single list of modules to fuse
>>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
>>> fused_m = torch.quantization.fuse_modules(m, modules_to_fuse)
>>> output = fused_m(input)

class torch.quantization.QuantStub(qconfig=None)

Quantize stub module. Before calibration, this is the same as an observer; it will be swapped for nnq.Quantize in convert.

Parameters

qconfig – quantization configuration for the tensor, if qconfig is not provided, we will get qconfig from parent modules

class torch.quantization.DeQuantStub

Dequantize stub module. Before calibration, this is the same as identity; it will be swapped for nnq.DeQuantize in convert.

class torch.quantization.QuantWrapper(module)

A wrapper class that wraps the input module, adds QuantStub and DeQuantStub, and surrounds the call to module with calls to the quant and dequant modules.

This is used by the quantization utility functions to add the quant and dequant modules. Before convert, QuantStub is just an observer that observes the input tensor; after convert, QuantStub is swapped for nnq.Quantize, which does the actual quantization. The same applies to DeQuantStub.

torch.quantization.add_quant_dequant(module)

Wraps the leaf child module in QuantWrapper if it has a valid qconfig. Note that this function will modify the children of module inplace, and it can return a new module which wraps the input module as well.

Parameters

module – input module with qconfig attributes for all the leaf modules that we want to quantize

Returns

Either the inplace modified module with submodules wrapped in QuantWrapper based on qconfig or a new QuantWrapper module which wraps the input module, the latter case only happens when the input module is a leaf module and we want to quantize it.

Utility functions

torch.quantization.add_observer_(module)

Add observer for the leaf child of the module.

This function inserts an observer module into every leaf child module that has a valid qconfig attribute.

Parameters

module – input module with qconfig attributes for all the leaf modules that we want to quantize

Returns

None, module is modified inplace with added observer modules and forward_hooks

torch.quantization.swap_module(mod, mapping)

Swaps the module if it has a quantized counterpart and it has an observer attached.

Parameters
  • mod – input module

  • mapping – a dictionary that maps from nn module to nnq module

Returns

The corresponding quantized module of mod

torch.quantization.propagate_qconfig_(module, qconfig_dict=None)

Propagate qconfig through the module hierarchy and assign the qconfig attribute on each leaf module.

Parameters
  • module – input module

  • qconfig_dict – dictionary that maps from name or type of submodule to quantization configuration, qconfig applies to all submodules of a given module unless qconfig for the submodules are specified (when the submodule already has qconfig attribute)

Returns

None, module is modified inplace with qconfig attached

torch.quantization.default_eval_fn(model, calib_data)

Default evaluation function. Takes a torch.utils.data.Dataset or a list of input Tensors and runs the model on the dataset.

Observers

class torch.quantization.Observer(dtype)

Observer base Module. Any observer implementation should derive from this class.

Concrete observers should follow the same API. In forward, they will update the statistics of the observed Tensor. And they should provide a calculate_qparams function that computes the quantization parameters given the collected statistics.

classmethod with_args(**kwargs)

Wrapper around functools.partial that allows chaining.

Often you want to assign it to a class as a class method:

Foo.with_args = classmethod(_with_args)
Foo.with_args(x=1).with_args(y=2)
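One way such a chainable wrapper around functools.partial could be written is sketched below. The names _PartialWrapper and Foo and the body of _with_args are assumptions made for this illustration; they are not claimed to match the PyTorch source.

```python
import functools


class _PartialWrapper:
    """Callable that holds a functools.partial and can be extended again."""

    def __init__(self, p):
        self.p = p

    def __call__(self, *args, **kwargs):
        return self.p(*args, **kwargs)

    def with_args(self, **kwargs):
        return _with_args(self, **kwargs)


def _with_args(cls_or_self, **kwargs):
    """Bind keyword arguments now, create instances later."""
    return _PartialWrapper(functools.partial(cls_or_self, **kwargs))


class Foo:
    """Hypothetical stand-in for an observer class."""

    def __init__(self, x=0, y=0):
        self.x, self.y = x, y


Foo.with_args = classmethod(_with_args)

factory = Foo.with_args(x=1).with_args(y=2)  # nothing is instantiated yet
a, b = factory(), factory()                  # two independent instances
print(a.x, a.y, a is b)                      # 1 2 False
```

This is why a QConfig can hold the result of with_args: the preparation step can call the factory once per layer and get a fresh observer each time.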

class torch.quantization.MinMaxObserver(**kwargs)

Default Observer Module. A default implementation of the observer module; only works for the per_tensor_affine quantization scheme. The module records the running minimum and maximum of the observed Tensor, and calculate_qparams calculates scale and zero_point.

class torch.quantization.PerChannelMinMaxObserver(ch_axis=0, **kwargs)

Per Channel Observer Module. The module records the running minimum and maximum for each channel of the observed Tensor, and calculate_qparams calculates scales and zero_points for each channel.

class torch.quantization.MovingAveragePerChannelMinMaxObserver(averaging_constant=0.01, **kwargs)

Per Channel Observer Module. The module records the running average of the max and min value for each channel of the observed Tensor, and calculate_qparams calculates scales and zero_points for each channel.

class torch.quantization.HistogramObserver(bins=2048, **kwargs)

The module records the running histogram of tensor values along with min/max values. calculate_qparams will calculate scale and zero_point.

class torch.quantization.FakeQuantize(observer=<class 'torch.quantization.observer.MovingAverageMinMaxObserver'>, quant_min=0, quant_max=255, **observer_kwargs)

Simulate the quantize and dequantize operations at training time. The output of this module is given by

x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max)-zero_point)*scale

  • scale defines the scale factor used for quantization.

  • zero_point specifies the quantized value to which 0 in floating point maps

  • quant_min specifies the minimum allowable quantized value.

  • quant_max specifies the maximum allowable quantized value.

  • fake_quant_enable controls the application of fake quantization on tensors, note that statistics can still be updated.

  • observer_enable controls statistics collection on tensors

  • dtype specifies the quantized dtype that is being emulated with fake-quantization; allowable values are torch.qint8 and torch.quint8. The values of quant_min and quant_max should be chosen to be consistent with the dtype

Parameters
  • observer (module) – Module for observing statistics on input tensors and calculating scale and zero-point.

  • quant_min (int) – The minimum allowable quantized value.

  • quant_max (int) – The maximum allowable quantized value.

  • observer_kwargs (optional) – Arguments for the observer module

Variables

~FakeQuantize.observer (Module) – User provided module that collects statistics on the input tensor and provides a method to calculate scale and zero-point.

class torch.quantization.NoopObserver(dtype=torch.float16)[source]

Observer that doesn’t do anything and just passes its configuration to the quantized module’s .from_float().

Primarily used for quantization to float16 which doesn’t require determining ranges.

Debugging utilities

torch.quantization.get_observer_dict(mod, target_dict, prefix='')[source]

Traverse the modules and save all observers into a dict. This is mainly used for debugging quantization accuracy.

Parameters
  • mod – the top module whose observers are saved

  • target_dict – the dictionary used to save all the observers

  • prefix – the prefix for the current module

class torch.quantization.RecordingObserver(**kwargs)[source]

The module is mainly for debug and records the tensor values during runtime

torch.nn.intrinsic

This module implements the combined (fused) modules, such as conv + relu, which can then be quantized.

ConvBn2d

class torch.nn.intrinsic.ConvBn2d(conv, bn)[source]

ConvBnReLU2d

class torch.nn.intrinsic.ConvBnReLU2d(conv, bn, relu)[source]

ConvReLU2d

class torch.nn.intrinsic.ConvReLU2d(conv, relu)[source]

LinearReLU

class torch.nn.intrinsic.LinearReLU(linear, relu)[source]

torch.nn.intrinsic.qat

This module implements the versions of those fused operations needed for quantization aware training.

ConvBn2d

class torch.nn.intrinsic.qat.ConvBn2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', eps=1e-05, momentum=0.1, freeze_bn=False, qconfig=None)[source]

A ConvBn2d module is a module fused from Conv2d and BatchNorm2d, attached with FakeQuantize modules for both output activation and weight, used in quantization aware training.

We combined the interface of torch.nn.Conv2d and torch.nn.BatchNorm2d.

Implementation details: https://arxiv.org/pdf/1806.08342.pdf section 3.2.2

Similar to torch.nn.Conv2d, with FakeQuantize modules initialized to default.
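For orientation, the batch-norm folding that motivates fusing Conv2d with BatchNorm2d can be written out for a single output channel. This is the standard inference-time folding identity in plain Python, offered as a sketch; it is not the module's training-time implementation, and the helper names fold_bn and dot are hypothetical.

```python
import math


def dot(weights, inputs):
    return sum(w * x for w, x in zip(weights, inputs))


def fold_bn(weights, bias, gamma, beta, running_mean, running_var, eps=1e-5):
    """Fold batch-norm statistics of one output channel into conv weights and bias."""
    factor = gamma / math.sqrt(running_var + eps)
    folded_weights = [w * factor for w in weights]
    folded_bias = beta + (bias - running_mean) * factor
    return folded_weights, folded_bias


weights, bias = [1.0, -2.0], 0.5
gamma, beta, mean, var, eps = 4.0, 1.0, 3.0, 4.0, 0.0
inputs = [1.0, 1.0]

# Unfused: convolution (here a dot product) followed by batch norm.
conv_out = dot(weights, inputs) + bias
unfused = gamma * (conv_out - mean) / math.sqrt(var + eps) + beta

# Fused: a single dot product with the folded parameters.
folded_weights, folded_bias = fold_bn(weights, bias, gamma, beta, mean, var, eps)
fused = dot(folded_weights, inputs) + folded_bias
```

Because the two paths agree, fake quantization can be applied to the folded weights, which are the weights that will actually be quantized after conversion.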

Variables
  • ~ConvBn2d.freeze_bn

  • ~ConvBn2d.observer – fake quant module for output activation, it’s called observer to align with post training flow

  • ~ConvBn2d.weight_fake_quant – fake quant module for weight

classmethod from_float(mod, qconfig=None)[source]

Create a qat module from a float module or qparams_dict.

Parameters

mod (Module) – a float module, either produced by torch.quantization utilities or provided directly by the user

ConvBnReLU2d

class torch.nn.intrinsic.qat.ConvBnReLU2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', eps=1e-05, momentum=0.1, freeze_bn=False, qconfig=None)[source]

A ConvBnReLU2d module is a module fused from Conv2d, BatchNorm2d and ReLU, attached with FakeQuantize modules for both output activation and weight, used in quantization aware training.

We combined the interface of torch.nn.Conv2d and torch.nn.BatchNorm2d and torch.nn.ReLU.

Implementation details: https://arxiv.org/pdf/1806.08342.pdf

Similar to torch.nn.Conv2d, with FakeQuantize modules initialized to default.

Variables
  • ~ConvBnReLU2d.observer – fake quant module for output activation, it’s called observer to align with post training flow

  • ~ConvBnReLU2d.weight_fake_quant – fake quant module for weight

ConvReLU2d

class torch.nn.intrinsic.qat.ConvReLU2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', qconfig=None)[source]

A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with FakeQuantize modules for both output activation and weight for quantization aware training.

We combined the interface of torch.nn.Conv2d and torch.nn.ReLU.

Variables
  • ~ConvReLU2d.observer – fake quant module for output activation, it’s called observer to align with post training flow

  • ~ConvReLU2d.weight_fake_quant – fake quant module for weight

LinearReLU

class torch.nn.intrinsic.qat.LinearReLU(in_features, out_features, bias=True, qconfig=None)[source]

A LinearReLU module fused from Linear and ReLU modules, attached with FakeQuantize modules for output activation and weight, used in quantization aware training.

We adopt the same interface as torch.nn.Linear.

Similar to torch.nn.intrinsic.LinearReLU, with FakeQuantize modules initialized to default.

Variables
  • ~LinearReLU.observer – fake quant module for output activation, it’s called observer to align with post training flow

  • ~LinearReLU.weight_fake_quant – fake quant module for weight

Examples:

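>>> import torch.nn.intrinsic.qat as nniqat
>>> # a qconfig is required for QAT modules
>>> qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
>>> m = nniqat.LinearReLU(20, 30, qconfig=qconfig)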
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])

torch.nn.intrinsic.quantized

This module implements the quantized implementations of fused operations like conv + relu.

ConvReLU2d

class torch.nn.intrinsic.quantized.ConvReLU2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')[source]

A ConvReLU2d module is a fused module of Conv2d and ReLU

We adopt the same interface as torch.nn.quantized.Conv2d.

Variables

Same as torch.nn.quantized.Conv2d

LinearReLU

class torch.nn.intrinsic.quantized.LinearReLU(in_features, out_features, bias=True)[source]

A LinearReLU module fused from Linear and ReLU modules

We adopt the same interface as torch.nn.quantized.Linear.

Variables

Same as torch.nn.quantized.Linear

Examples:

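>>> m = nn.intrinsic.quantized.LinearReLU(20, 30)
>>> input = torch.randn(128, 20)
>>> # the quantized module expects a quantized input tensor
>>> input = torch.quantize_per_tensor(input, 1.0, 0, torch.quint8)
>>> output = m(input)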
>>> print(output.size())
torch.Size([128, 30])

torch.nn.qat

This module implements versions of the key nn modules Conv2d() and Linear() which run in FP32 but with rounding applied to simulate the effect of INT8 quantization.

Conv2d

class torch.nn.qat.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', qconfig=None)[source]

A Conv2d module attached with FakeQuantize modules for both output activation and weight, used for quantization aware training.

We adopt the same interface as torch.nn.Conv2d, please see https://pytorch.org/docs/stable/nn.html?highlight=conv2d#torch.nn.Conv2d for documentation.

Similar to torch.nn.Conv2d, with FakeQuantize modules initialized to default.

Variables
  • ~Conv2d.observer – fake quant module for output activation, it’s called observer to align with post training flow

  • ~Conv2d.weight_fake_quant – fake quant module for weight

classmethod from_float(mod, qconfig=None)[source]

Create a qat module from a float module or qparams_dict.

Parameters

mod (Module) – a float module, either produced by torch.quantization utilities or provided directly by the user

Linear

class torch.nn.qat.Linear(in_features, out_features, bias=True, qconfig=None)[source]

A linear module attached with FakeQuantize modules for both output activation and weight, used for quantization aware training.

We adopt the same interface as torch.nn.Linear, please see https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.

Similar to torch.nn.Linear, with FakeQuantize modules initialized to default.

Variables
  • ~Linear.observer – fake quant module for output activation, it’s called observer to align with post training flow

  • ~Linear.weight_fake_quant – fake quant module for weight

classmethod from_float(mod, qconfig=None)[source]

Create a qat module from a float module or qparams_dict.

Parameters

mod (Module) – a float module, either produced by torch.quantization utilities or provided directly by the user

torch.nn.quantized

This module implements the quantized versions of the nn layers such as Conv2d and ReLU.

Functional interface

Functional interface (quantized).

torch.nn.quantized.functional.relu(input, inplace=False) → Tensor[source]

Applies the rectified linear unit function element-wise. See ReLU for more details.

Parameters
  • input – quantized input

  • inplace – perform the computation inplace

torch.nn.quantized.functional.linear(input, weight, bias=None, scale=None, zero_point=None)[source]

Applies a linear transformation to the incoming quantized data: y = xA^T + b. See Linear.

Note

The current implementation packs weights on every call, which incurs a performance penalty. If you want to avoid the overhead, use Linear.

Parameters
  • input (Tensor) – Quantized input of type torch.quint8

  • weight (Tensor) – Quantized weight of type torch.qint8

  • bias (Tensor) – None or fp32 bias of type torch.float

  • scale (double) – output scale. If None, derived from the input scale

  • zero_point (long) – output zero point. If None, derived from the input zero_point

Shape:
  • Input: (N, *, in_features) where * means any number of additional dimensions

  • Weight: (out_features, in_features)

  • Bias: (out_features)

  • Output: (N, *, out_features)
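The arithmetic that a quantized linear layer approximates can be sketched in plain Python by dequantizing the operands, applying y = xA^T + b in floating point, and requantizing the result with the output scale and zero point. This is a reference-style sketch under that assumption, not the packed-weight kernel the function actually calls; the helper names are hypothetical, and explicit output parameters are passed rather than derived from the input.

```python
def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale


def quantize(x, scale, zero_point, qmin=0, qmax=255):
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))


def reference_qlinear(q_x, x_scale, x_zp, q_w, w_scale, w_zp, bias, out_scale, out_zp):
    """q_x: (N, in_features) ints, q_w: (out_features, in_features) ints, bias: floats."""
    out = []
    for row in q_x:
        x = [dequantize(v, x_scale, x_zp) for v in row]
        out_row = []
        for j, w_row in enumerate(q_w):
            w = [dequantize(v, w_scale, w_zp) for v in w_row]
            acc = sum(a * b for a, b in zip(x, w)) + bias[j]  # y = x A^T + b
            out_row.append(quantize(acc, out_scale, out_zp))
        out.append(out_row)
    return out


q_x = [[2, 4]]             # represents [[1.0, 2.0]] with scale 0.5
q_w = [[1, 1], [2, 0]]     # represents the same values with scale 1.0
bias = [0.5, -1.0]         # bias stays in floating point
q_y = reference_qlinear(q_x, 0.5, 0, q_w, 1.0, 0, bias, out_scale=0.5, out_zp=0)
```

Here q_y holds the integer representation of [[3.5, 1.0]] at the output scale of 0.5.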

torch.nn.quantized.functional.conv2d(input, weight, bias, stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', scale=1.0, zero_point=0, dtype=torch.quint8)[source]

Applies a 2D convolution over a quantized 2D input composed of several input planes.

See Conv2d for details and output shape.

Parameters
  • input – quantized input tensor of shape (minibatch, in_channels, iH, iW)

  • weight – quantized filters of shape (out_channels, in_channels / groups, kH, kW)

  • bias – non-quantized bias tensor of shape (out_channels). The tensor type must be torch.float.

  • stride – the stride of the convolving kernel. Can be a single number or a tuple (sH, sW). Default: 1

  • padding – implicit paddings on both sides of the input. Can be a single number or a tuple (padH, padW). Default: 0

  • dilation – the spacing between kernel elements. Can be a single number or a tuple (dH, dW). Default: 1

  • groups – split input into groups, in_channels should be divisible by the number of groups. Default: 1

  • padding_mode – the padding mode to use. Only “zeros” is supported for quantized convolution at the moment. Default: “zeros”

  • scale – quantization scale for the output. Default: 1.0

  • zero_point – quantization zero_point for the output. Default: 0

  • dtype – quantization data type to use. Default: torch.quint8

Examples:

>>> import torch
>>> from torch.nn.quantized import functional as qF
>>> filters = torch.randn(8, 4, 3, 3, dtype=torch.float)
>>> inputs = torch.randn(1, 4, 5, 5, dtype=torch.float)
>>> bias = torch.randn(8, dtype=torch.float)  # one bias value per output channel
>>>
>>> scale, zero_point = 1.0, 0
>>>
>>> # weights are quantized to qint8, activations to quint8
>>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, torch.qint8)
>>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, torch.quint8)
>>> qF.conv2d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point)

torch.nn.quantized.functional.max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False)[source]

Applies a 2D max pooling over a quantized input signal composed of several quantized input planes.

Note

The input quantization parameters are propagated to the output.

See MaxPool2d for details.

ReLU

class torch.nn.quantized.ReLU(inplace=False)[source]

Applies quantized rectified linear unit function element-wise:

ReLU(x) = max(x_0, x), where x_0 is the zero point.

Please see https://pytorch.org/docs/stable/nn.html#torch.nn.ReLU for more documentation on ReLU.
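Clamping at the zero point rather than at 0 works because the zero point is the integer that represents the real value 0. The plain-Python sketch below applies the rule to integer representations and dequantizes the result to show the correspondence; the helper names and the scale and zero point values are illustrative assumptions.

```python
def quantized_relu(q_values, zero_point):
    """Clamp integer representations from below at the zero point."""
    return [max(zero_point, q) for q in q_values]


def dequantize(q_values, scale, zero_point):
    return [(q - zero_point) * scale for q in q_values]


scale, zero_point = 0.5, 10
q_in = [0, 5, 10, 20]                            # represents [-5.0, -2.5, 0.0, 5.0]
q_out = quantized_relu(q_in, zero_point)         # [10, 10, 10, 20]
real_out = dequantize(q_out, scale, zero_point)  # [0.0, 0.0, 0.0, 5.0]
```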

Parameters

inplace – (Currently not supported) can optionally do the operation in-place.

Shape:
  • Input: (N, *) where * means any number of additional dimensions

  • Output: (N, *), same shape as the input

Examples:

>>> m = nn.quantized.ReLU()
>>> input = torch.randn(2)
>>> input = torch.quantize_per_tensor(input, 1.0, 0, dtype=torch.qint32)
>>> output = m(input)

ReLU6

class torch.nn.quantized.ReLU6(inplace=False)[source]

Applies the element-wise function:

ReLU6(x) = min(max(x_0, x), q(6)), where x_0 is the zero_point, and q(6) is the quantized representation of the number 6.
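The upper bound q(6) depends on the quantization parameters of the tensor. The following plain-Python sketch computes it with an assumed affine mapping and applies both clamps to integer representations; the function names and the example parameters are hypothetical.

```python
def quantize_scalar(x, scale, zero_point):
    """Assumed affine mapping from a real value to its integer representation."""
    return round(x / scale) + zero_point


def quantized_relu6(q_values, scale, zero_point):
    q_six = quantize_scalar(6.0, scale, zero_point)  # q(6) in the formula
    return [min(max(zero_point, q), q_six) for q in q_values]


scale, zero_point = 0.5, 10
q_in = [0, 10, 16, 22, 40]  # represents [-5.0, 0.0, 3.0, 6.0, 15.0]
q_out = quantized_relu6(q_in, scale, zero_point)
```

With these parameters q(6) is 22, so every representation above 22 saturates there and every representation below the zero point is raised to it.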

Parameters

inplace – can optionally do the operation in-place. Default: False

Shape:
  • Input: (N, *) where * means any number of additional dimensions

  • Output: (N, *), same shape as the input

Examples:

>>> m = nn.quantized.ReLU6()
>>> input = torch.randn(2)
>>> input = torch.quantize_per_tensor(input, 1.0, 0, dtype=torch.qint32)
>>> output = m(input)

Conv2d

class torch.nn.quantized.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')[source]

Applies a 2D convolution over a quantized input signal composed of several quantized input planes.

For details on input arguments, parameters, and implementation see Conv2d.

Note

Only zeros is supported for the padding_mode argument.

Note

Only torch.quint8 is supported for the input data type.

Variables
  • ~Conv2d.weight (Tensor) – packed tensor derived from the learnable weight parameter.

  • ~Conv2d.scale (Tensor) – scalar for the output scale

  • ~Conv2d.zero_point (Tensor) – scalar for the output zero point

See Conv2d for other attributes.

Examples:

>>> # With square kernels and equal stride
>>> m = nn.quantized.Conv2d(16, 33, 3, stride=2)
>>> # non-square kernels and unequal stride and with padding
>>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
>>> # non-square kernels and unequal stride and with padding and dilation
>>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
>>> input = torch.randn(20, 16, 50, 100)
>>> # quantize input to quint8
>>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
>>> output = m(q_input)

classmethod from_float(mod)[source]

Creates a quantized module from a float module or qparams_dict.

Parameters

mod (Module) – a float module, either produced by torch.quantization utilities or provided by the user

FloatFunctional

class torch.nn.quantized.FloatFunctional[source]

State collector class for float operations.

An instance of this class can be used instead of the torch. prefix for some operations, for example by calling add on the instance in place of torch.add.

Note

This class does not provide a forward hook. Instead, you must use one of the underlying functions (e.g. add).

Valid operation names:
  • add

  • cat

  • mul

  • add_relu

  • add_scalar

  • mul_scalar

QFunctional

class torch.nn.quantized.QFunctional[source]

Wrapper class for quantized operations.

An instance of this class can be used instead of the torch.ops.quantized prefix, for example by calling add on the instance in place of torch.ops.quantized.add.

Note

This class does not provide a forward hook. Instead, you must use one of the underlying functions (e.g. add).

Valid operation names:
  • add

  • cat

  • mul

  • add_relu

  • add_scalar

  • mul_scalar

Quantize

class torch.nn.quantized.Quantize(scale, zero_point, dtype)[source]

Quantizes an incoming tensor.

Parameters
  • scale – scale of the output Quantized Tensor

  • zero_point – zero_point of the output Quantized Tensor

  • dtype – data type of the output Quantized Tensor

Variables

scale, zero_point, dtype

Examples:

>>> t = torch.tensor([[1., -1.], [1., -1.]])
>>> scale, zero_point, dtype = 1.0, 2, torch.qint8
>>> qm = nn.quantized.Quantize(scale, zero_point, dtype)
>>> qt = qm(t)
>>> print(qt)
tensor([[ 1., -1.],
        [ 1., -1.]], size=(2, 2), dtype=torch.qint8, scale=1.0, zero_point=2)

DeQuantize

class torch.nn.quantized.DeQuantize[source]

Dequantizes an incoming tensor

Examples:

>>> input = torch.tensor([[1., -1.], [1., -1.]])
>>> scale, zero_point, dtype = 1.0, 2, torch.qint8
>>> qm = nn.quantized.Quantize(scale, zero_point, dtype)
>>> quantized_input = qm(input)
>>> dqm = nn.quantized.DeQuantize()
>>> dequantized = dqm(quantized_input)
>>> print(dequantized)
tensor([[ 1., -1.],
        [ 1., -1.]], dtype=torch.float32)
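To see what happens to the values in the example above, the round trip can be traced in plain Python. The affine formulas below are the usual ones and are assumed here for illustration; the helper names are hypothetical and no PyTorch calls are involved.

```python
def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    """Affine quantization to the qint8 range (assumed formula)."""
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))


def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale


scale, zero_point = 1.0, 2
t = [[1.0, -1.0], [1.0, -1.0]]

# Stored integers: each value is divided by the scale and shifted by the zero point.
int_repr = [[quantize(v, scale, zero_point) for v in row] for row in t]
# Dequantizing undoes the shift and the scaling.
restored = [[dequantize(q, scale, zero_point) for q in row] for row in int_repr]
# Values outside the representable range saturate instead of wrapping around.
saturated = quantize(1000.0, scale, zero_point)
```

The stored integers are [[3, 1], [3, 1]], while the printed quantized tensor shows the dequantized values, which is why the example output reads 1 and -1.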

Linear

class torch.nn.quantized.Linear(in_features, out_features, bias_=True)[source]

A quantized linear module with quantized tensor as inputs and outputs. We adopt the same interface as torch.nn.Linear, please see https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.

Similar to Linear, attributes will be randomly initialized at module creation time and will be overwritten later

Variables
  • ~Linear.weight (Tensor) – the non-learnable quantized weights of the module of shape (out_features, in_features).

  • ~Linear.bias (Tensor) – the non-learnable bias of the module of shape (out_features). If bias is True, the values are initialized to zero.

  • ~Linear.scale – scale parameter of output Quantized Tensor, type: double

  • ~Linear.zero_point – zero_point parameter for output Quantized Tensor, type: long

Examples:

>>> m = nn.quantized.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> input = torch.quantize_per_tensor(input, 1.0, 0, torch.quint8)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])

classmethod from_float(mod)[source]

Create a quantized module from a float module or qparams_dict

Parameters

mod (Module) – a float module, either produced by torch.quantization utilities or provided by the user

torch.nn.quantized.dynamic

Linear

class torch.nn.quantized.dynamic.Linear(in_features, out_features, bias_=True)[source]

A dynamic quantized linear module with floating point tensors as inputs and outputs. We adopt the same interface as torch.nn.Linear, please see https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.

Similar to torch.nn.Linear, attributes will be randomly initialized at module creation time and will be overwritten later

Variables
  • ~Linear.weight (Tensor) – the non-learnable quantized weights of the module which are of shape (out_features, in_features).

  • ~Linear.bias (Tensor) – the non-learnable bias of the module of shape (out_features). If bias is True, the values are initialized to zero.

  • ~Linear.scale – scale parameter of weight Quantized Tensor, type: double

  • ~Linear.zero_point – zero_point parameter for weight Quantized Tensor, type: long

Examples:

>>> m = nn.quantized.dynamic.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])

classmethod from_float(mod)[source]

Create a dynamic quantized module from a float module or qparams_dict

Parameters

mod (Module) – a float module, either produced by torch.quantization utilities or provided by the user

LSTM

class torch.nn.quantized.dynamic.LSTM(*args, **kwargs)[source]
