
AffineQuantizedTensor

class torchao.dtypes.AffineQuantizedTensor(tensor_impl: AQTTensorImpl, block_size: Tuple[int, ...], shape: Size, quant_min: Optional[Union[int, float]] = None, quant_max: Optional[Union[int, float]] = None, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, dtype=None, strides=None)

Affine quantized tensor subclass. Affine quantization means we quantize the floating point tensor with an affine transformation: quantized_tensor = float_tensor / scale + zero_point

To see what happens during choose_qparams, quantization, and dequantization for affine quantization, check out https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py and look at the three quant primitive ops: choose_qparams_affine, quantize_affine, and dequantize_affine.
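
The affine mapping above can be illustrated with a minimal pure-Python sketch for a single per-tensor scale and an integer-domain zero point, including the rounding and clamping to [quant_min, quant_max] that the one-line formula leaves implicit. The helper names and the hand-picked qparams are hypothetical; torchao's quantize_affine and dequantize_affine operate on tensors and block sizes and may differ in detail.

```python
# Illustrative sketch only (not torchao code): per-tensor affine quantization with
# an integer-domain zero point, written with plain Python floats and ints.
from typing import List


def quantize_affine_sketch(values: List[float], scale: float, zero_point: int,
                           quant_min: int, quant_max: int) -> List[int]:
    # q = clamp(round(x / scale) + zero_point, quant_min, quant_max)
    return [max(quant_min, min(quant_max, round(x / scale) + zero_point))
            for x in values]


def dequantize_affine_sketch(codes: List[int], scale: float,
                             zero_point: int) -> List[float]:
    # x_hat = (q - zero_point) * scale undoes the mapping up to rounding error
    return [(q - zero_point) * scale for q in codes]


# Hand-picked qparams covering roughly [-1, 1] with a uint8-style range [0, 255].
x = [-1.0, -0.25, 0.0, 0.5, 1.0]
scale, zero_point = 2.0 / 255, 128
q = quantize_affine_sketch(x, scale, zero_point, 0, 255)
x_hat = dequantize_affine_sketch(q, scale, zero_point)
print(q)
print(x_hat)
```

With these qparams every reconstructed value is within scale / 2 of the original, and 0.0 maps exactly to zero_point and back.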

The shape and dtype of the tensor subclass represent how the tensor subclass looks externally, regardless of the internal representation’s type or orientation.

fields:
  • tensor_impl (AQTTensorImpl): tensor that serves as general storage for the quantized data,

    e.g. plain tensors (int_data, scale, zero_point) or a packed format, depending on the device and operator/kernel

  • block_size (Tuple[int, …]): granularity of quantization, i.e. the size of the block of tensor elements that share the same qparams

    e.g. when block_size is the same as the shape of the input tensor, we are using per-tensor quantization

  • shape (torch.Size): the shape for the original high precision Tensor

  • quant_min (Optional[int]): minimum quantized value for the Tensor; if not specified, it is derived from the dtype of int_data

  • quant_max (Optional[int]): maximum quantized value for the Tensor; if not specified, it is derived from the dtype of int_data

  • zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float

    if zero_point is in the integer domain, zero point is added to the quantized integer value during quantization; if zero_point is in the floating point domain, zero point is subtracted from the floating point (unquantized) value during quantization. Default is ZeroPointDomain.INT.

  • dtype: dtype for original high precision tensor, e.g. torch.float32
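
To make the block_size field concrete, the sketch below counts how many (scale, zero_point) pairs a block size implies for a given shape: each block spans block_size[i] consecutive elements along dimension i and shares one set of qparams. The helper is hypothetical (not part of torchao) and assumes every dimension is evenly divisible by its block size.

```python
# Hypothetical helper: number of qparam groups implied by a block_size.
from math import prod
from typing import Tuple


def num_qparam_groups(shape: Tuple[int, ...], block_size: Tuple[int, ...]) -> int:
    if len(shape) != len(block_size):
        raise ValueError("block_size must have one entry per tensor dimension")
    for dim, block in zip(shape, block_size):
        # Assumption of this sketch: blocks tile each dimension exactly.
        if block <= 0 or dim % block != 0:
            raise ValueError(f"block size {block} does not evenly divide dim {dim}")
    return prod(dim // block for dim, block in zip(shape, block_size))


shape = (4, 8)  # e.g. a 4x8 weight matrix
print(num_qparam_groups(shape, (4, 8)))  # per-tensor
print(num_qparam_groups(shape, (1, 8)))  # per-row
print(num_qparam_groups(shape, (1, 2)))  # groups of 2 elements within each row
```

For the 4x8 example these give 1, 4, and 16 groups respectively: the smaller the block, the finer the granularity and the more qparams need to be stored.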

dequantize() → Tensor

Given a quantized Tensor, dequantize it and return the dequantized float Tensor.
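
How dequantization inverts quantization depends on zero_point_domain. The simplified sketch below contrasts the two placements of zero_point described in the fields list; the strings "INT" and "FLOAT" stand in for ZeroPointDomain.INT and ZeroPointDomain.FLOAT, and the functions are hypothetical. torchao's actual kernels may apply additional offsets, so treat this only as an illustration of where zero_point enters the computation.

```python
# Simplified sketch (not torchao's implementation) of the two zero-point domains.


def quantize(x: float, scale: float, zero_point: float, qmin: int, qmax: int,
             domain: str) -> int:
    if domain == "INT":
        # integer zero point is added to the scaled, rounded value
        q = round(x / scale) + int(zero_point)
    elif domain == "FLOAT":
        # float zero point is subtracted from the unquantized value first
        q = round((x - zero_point) / scale)
    else:
        raise ValueError(f"unknown domain {domain!r}")
    return max(qmin, min(qmax, q))


def dequantize(q: int, scale: float, zero_point: float, domain: str) -> float:
    if domain == "INT":
        return (q - zero_point) * scale
    if domain == "FLOAT":
        return q * scale + zero_point
    raise ValueError(f"unknown domain {domain!r}")


# 4-bit unsigned codes [0, 15] covering [-1, 1]; both domains use the same scale.
scale = 2.0 / 15
for x in [-1.0, -0.5, 0.0, 0.3, 1.0]:
    q_int = quantize(x, scale, 8, 0, 15, "INT")        # zero_point = 8 (an integer)
    q_flt = quantize(x, scale, -1.0, 0, 15, "FLOAT")   # zero_point = -1.0 (a float)
    print(x, dequantize(q_int, scale, 8, "INT"), dequantize(q_flt, scale, -1.0, "FLOAT"))
```

Both variants reconstruct each input to within scale / 2 here. With these particular qparams the integer-domain version reproduces 0.0 exactly, while the float-domain version only reproduces it up to rounding error.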

classmethod from_hp_to_floatx(input_float: Tensor, block_size: Tuple[int, ...], target_dtype: dtype, _layout: Layout, scale_dtype: Optional[dtype] = None)

Convert a high precision tensor to a float8 quantized tensor.
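
Conceptually, per-tensor float8 quantization divides the input by a scale chosen so that the largest magnitude lands on the format's maximum, then rounds to float8. The pure-Python sketch below emulates rounding to the e4m3fn layout used by torch.float8_e4m3fn (3 mantissa bits, largest finite value 448). The amax / 448 scale recipe, the saturation behavior, and the helper names are assumptions for illustration, not a description of from_hp_to_floatx internals.

```python
import math
from typing import List, Tuple

# Largest finite magnitude of the float8 e4m3fn format (torch.float8_e4m3fn).
E4M3_MAX = 448.0


def round_to_e4m3(v: float) -> float:
    """Round v to the nearest float8 e4m3fn value (ties to even), saturating at +/-448."""
    if v == 0.0:
        return 0.0
    sign = -1.0 if v < 0 else 1.0
    a = min(abs(v), E4M3_MAX)
    _, e = math.frexp(a)              # a == m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, -6)              # unbiased exponent; -6 is the smallest normal one
    step = 2.0 ** (exp - 3)           # 3 mantissa bits -> 8 values per binade
    return sign * min(round(a / step) * step, E4M3_MAX)


def quantize_fp8_per_tensor(values: List[float]) -> Tuple[List[float], float]:
    # Assumed recipe: one scale per tensor so that max |x| maps to E4M3_MAX.
    amax = max(abs(v) for v in values)
    scale = amax / E4M3_MAX if amax > 0 else 1.0
    return [round_to_e4m3(v / scale) for v in values], scale


def dequantize_fp8(codes: List[float], scale: float) -> List[float]:
    return [c * scale for c in codes]


x = [0.1, -2.5, 3.0, 0.007]
codes, scale = quantize_fp8_per_tensor(x)
print(codes, scale)
print(dequantize_fp8(codes, scale))
```

Unlike integer quantization, the error here is relative: for values that land in the normal range after scaling, the reconstruction is within 1/16 of the original magnitude.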

classmethod from_hp_to_floatx_static(input_float: Tensor, scale: Tensor, block_size: Tuple[int, ...], target_dtype: dtype, _layout: Layout)

Create a float8 AffineQuantizedTensor from a high precision tensor using static parameters.

classmethod from_hp_to_fpx(input_float: Tensor, _layout: Layout)

Create a floatx AffineQuantizedTensor from a high precision tensor. Floatx is described by its number of exponent bits (ebits) and mantissa bits (mbits), and can represent float1 through float7.
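
To see what a given (ebits, mbits) pair means, the sketch below enumerates the non-negative values of a format with one sign bit, ebits exponent bits, and mbits mantissa bits. It assumes an IEEE-style exponent bias, subnormals when the exponent field is zero, and no bit patterns reserved for inf or NaN; these conventions match common FP4 (e2m1) and FP6 (e3m2) layouts, but they are assumptions of this sketch rather than documented torchao behavior, and the helper name is hypothetical.

```python
from typing import List


def fpx_nonnegative_values(ebits: int, mbits: int) -> List[float]:
    """List the non-negative values of a float format with 1 sign, ebits, and mbits bits.

    Assumptions of this sketch: exponent bias 2**(ebits - 1) - 1, subnormals when
    the exponent field is 0, and no bit patterns reserved for inf or NaN.
    """
    if ebits < 1 or mbits < 0:
        raise ValueError("need ebits >= 1 and mbits >= 0")
    bias = 2 ** (ebits - 1) - 1
    values = []
    for e_field in range(2 ** ebits):
        for m_field in range(2 ** mbits):
            frac = m_field / 2 ** mbits
            if e_field == 0:
                values.append(frac * 2.0 ** (1 - bias))          # subnormal, no implicit 1
            else:
                values.append((1.0 + frac) * 2.0 ** (e_field - bias))
    return values


print(fpx_nonnegative_values(2, 1))       # a 4-bit e2m1 layout
print(max(fpx_nonnegative_values(3, 2)))  # largest value of a 6-bit e3m2 layout
```

Under these assumptions, e2m1 yields 0, 0.5, 1, 1.5, 2, 3, 4, 6 and e3m2 tops out at 28, which shows how few distinct levels such formats offer compared with int8 or float8.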

classmethod from_hp_to_intx(input_float: Tensor, mapping_type: MappingType, block_size: Tuple[int, ...], target_dtype: dtype, quant_min: Optional[int] = None, quant_max: Optional[int] = None, eps: Optional[float] = None, scale_dtype: Optional[dtype] = None, zero_point_dtype: Optional[dtype] = None, preserve_zero: bool = True, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, _layout: Layout = PlainLayout(), use_hqq: bool = False)

Convert a high precision tensor to an integer affine quantized tensor.
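
Unlike the static variant, this conversion derives scale and zero_point from the input itself. The sketch below shows, for a single per-tensor group, how symmetric and asymmetric mappings commonly choose these qparams, with eps as a lower bound on the scale. The string mapping names stand in for MappingType, and the formulas are a common textbook recipe assumed here; torchao's choose_qparams_affine supports more options (block_size, scale_dtype, zero_point_domain, use_hqq) and may differ in detail.

```python
from typing import List, Tuple


def choose_qparams_sketch(
    values: List[float],
    mapping: str,          # "SYMMETRIC" or "ASYMMETRIC" (stand-ins for MappingType)
    quant_min: int,
    quant_max: int,
    eps: float = 1e-9,
) -> Tuple[float, int]:
    """Pick one (scale, zero_point) pair for all of `values` (per-tensor)."""
    # Extend the range to include 0.0 so that zero stays exactly representable,
    # similar in spirit to preserve_zero=True.
    min_val = min(min(values), 0.0)
    max_val = max(max(values), 0.0)
    if mapping == "SYMMETRIC":
        # Range centered on zero: the larger magnitude decides the scale.
        amax = max(-min_val, max_val)
        scale = max(amax / ((quant_max - quant_min) / 2), eps)
        zero_point = (quant_max + quant_min + 1) // 2
    elif mapping == "ASYMMETRIC":
        # Spend the whole integer range on [min_val, max_val].
        scale = max((max_val - min_val) / (quant_max - quant_min), eps)
        zero_point = quant_min - round(min_val / scale)
        zero_point = max(quant_min, min(quant_max, zero_point))
    else:
        raise ValueError(f"unknown mapping {mapping!r}")
    return scale, zero_point


vals = [-0.5, 0.25, 2.0]
print(choose_qparams_sketch(vals, "SYMMETRIC", -128, 127))
print(choose_qparams_sketch(vals, "ASYMMETRIC", -128, 127))
```

For this skewed input, the asymmetric mapping yields a smaller scale (finer resolution) because it does not reserve codes for negative values below -0.5, at the cost of a non-zero zero_point.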

classmethod from_hp_to_intx_static(input_float: Tensor, scale: Tensor, zero_point: Optional[Tensor], block_size: Tuple[int, ...], target_dtype: dtype, quant_min: Optional[int] = None, quant_max: Optional[int] = None, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, _layout: Layout = PlainLayout())

Create an integer AffineQuantizedTensor from a high precision tensor using static parameters.
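
With static parameters, scale and zero_point are supplied by the caller (for example, from an earlier calibration pass) instead of being computed from input_float. As with any clamped affine quantization, inputs outside the range those qparams cover saturate at quant_min or quant_max. The pure-Python sketch below illustrates that effect under an assumed int8 range and hand-picked qparams; the helper names are hypothetical.

```python
from typing import List


def quantize_static(values: List[float], scale: float, zero_point: int,
                    quant_min: int = -128, quant_max: int = 127) -> List[int]:
    """Quantize with precomputed (static) qparams; nothing is derived from `values`."""
    return [max(quant_min, min(quant_max, round(v / scale) + zero_point))
            for v in values]


def dequantize(codes: List[int], scale: float, zero_point: int) -> List[float]:
    return [(c - zero_point) * scale for c in codes]


# qparams fixed ahead of time, e.g. from a calibration pass over data in [-1, 1].
scale, zero_point = 1.0 / 127, 0

in_range = quantize_static([0.5, -1.0], scale, zero_point)
out_of_range = quantize_static([3.0], scale, zero_point)  # clamped to quant_max
print(in_range)
print(dequantize(out_of_range, scale, zero_point))  # about 1.0, not 3.0
```

The trade-off is that static qparams avoid a per-call range computation, but they are only as good as the data they were calibrated on.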

to(*args, **kwargs) → Tensor

Performs Tensor dtype and/or device conversion. A torch.dtype and torch.device are inferred from the arguments of self.to(*args, **kwargs).

Note

If the self Tensor already has the correct torch.dtype and torch.device, then self is returned. Otherwise, the returned tensor is a copy of self with the desired torch.dtype and torch.device.

Here are the ways to call to:

to(dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format) → Tensor

Returns a Tensor with the specified dtype.

Args:

memory_format (torch.memory_format, optional): the desired memory format of returned Tensor. Default: torch.preserve_format.

to(device=None, dtype=None, non_blocking=False, copy=False, memory_format=torch.preserve_format) → Tensor

Returns a Tensor with the specified device and (optional) dtype. If dtype is None it is inferred to be self.dtype. When non_blocking is set to True, the function attempts to perform the conversion asynchronously with respect to the host, if possible. This asynchronous behavior applies to both pinned and pageable memory. However, caution is advised when using this feature. For more information, refer to the tutorial on good usage of non_blocking and pin_memory. When copy is set, a new Tensor is created even when the Tensor already matches the desired conversion.

Args:

memory_format (torch.memory_format, optional): the desired memory format of returned Tensor. Default: torch.preserve_format.

to(other, non_blocking=False, copy=False) → Tensor

Returns a Tensor with same torch.dtype and torch.device as the Tensor other. When non_blocking is set to True, the function attempts to perform the conversion asynchronously with respect to the host, if possible. This asynchronous behavior applies to both pinned and pageable memory. However, caution is advised when using this feature. For more information, refer to the tutorial on good usage of non_blocking and pin_memory. When copy is set, a new Tensor is created even when the Tensor already matches the desired conversion.

Example:

>>> tensor = torch.randn(2, 2)  # Initially dtype=float32, device=cpu
>>> tensor.to(torch.float64)
tensor([[-0.5044,  0.0005],
        [ 0.3310, -0.0584]], dtype=torch.float64)

>>> cuda0 = torch.device('cuda:0')
>>> tensor.to(cuda0)
tensor([[-0.5044,  0.0005],
        [ 0.3310, -0.0584]], device='cuda:0')

>>> tensor.to(cuda0, dtype=torch.float64)
tensor([[-0.5044,  0.0005],
        [ 0.3310, -0.0584]], dtype=torch.float64, device='cuda:0')

>>> other = torch.randn((), dtype=torch.float64, device=cuda0)
>>> tensor.to(other, non_blocking=True)
tensor([[-0.5044,  0.0005],
        [ 0.3310, -0.0584]], dtype=torch.float64, device='cuda:0')
