TorchAOBaseTensor

class torchao.utils.TorchAOBaseTensor(*args, **kwargs)

A utility tensor subclass that provides commonly used functions. New tensor subclasses can inherit from it to get all of these utility functions.

Attributes (defined by subclasses of TorchAOBaseTensor):

  • tensor_data_names (List[str]): names of all required Tensor data attributes. The order should match the __init__ argument list of the tensor subclass. Defining it is optional for using TorchAOBaseTensor, but required to get the additional util functions (see the Note section).

  • tensor_attribute_names (List[str]): names of the non-Tensor attributes. The order should match the __init__ argument list of the tensor subclass, following all the tensor_data_names arguments. Defining it is optional for using TorchAOBaseTensor, but required to get the additional util functions (see the Note section).

  • optional_tensor_data_names (List[str]): names of Tensor data attributes that can be optional. It is not required to get the additional util functions, but it is needed if the subclass has optional Tensor data attributes.

  • optional_tensor_attribute_names (List[str]): names of non-Tensor attributes that can be optional. It is not required to get the additional util functions, but it is needed if the subclass has optional non-Tensor attributes.

Note

The argument order in __init__ and __new__ of a TorchAOBaseTensor subclass should match exactly: tensor_data_names + tensor_attribute_names + optional_tensor_data_names (if present) + optional_tensor_attribute_names (if present)
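The ordering rule above can be checked mechanically. The sketch below is a pure-Python illustration that does not import torch or torchao: the helpers expected_arg_order and init_matches_declared_order, and the classes Good and Bad, are hypothetical names introduced here and are not part of the torchao API.

```python
import inspect
from typing import List

# Hypothetical helper (not a torchao function): concatenates the four name
# lists in the order the note above prescribes, skipping undefined ones.
def expected_arg_order(cls) -> List[str]:
    names: List[str] = []
    for field in (
        "tensor_data_names",
        "tensor_attribute_names",
        "optional_tensor_data_names",
        "optional_tensor_attribute_names",
    ):
        names += list(getattr(cls, field, []))
    return names


# Hypothetical helper: compares the __init__ parameter names (minus self)
# with the declared order.
def init_matches_declared_order(cls) -> bool:
    params = list(inspect.signature(cls.__init__).parameters)[1:]
    return params == expected_arg_order(cls)


# Plain classes stand in for tensor subclasses so the check runs anywhere.
class Good:
    tensor_data_names = ["a", "b"]
    tensor_attribute_names = ["c"]
    optional_tensor_data_names = ["e"]

    def __init__(self, a, b, c, e=None):
        pass


class Bad:
    tensor_data_names = ["a", "b"]
    tensor_attribute_names = ["c"]

    def __init__(self, a, c, b):  # "c" comes before "b": order mismatch
        pass


good_ok = init_matches_declared_order(Good)
bad_ok = init_matches_declared_order(Bad)
```

A check like this could run in a unit test for a new subclass; the same comparison would apply to __new__.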

Note

If tensor_data_names (torch.Tensor data attribute names) and tensor_attribute_names (non-torch.Tensor attribute names) are defined, the following additional util functions are added:

__tensor_flatten__: flattens a subclassed tensor instance and returns a tuple. The first element is the list of tensor data names for valid tensor data; the second is a dict from attribute name to non-Tensor attribute.

__tensor_unflatten__: takes a tensor_data_dict (a map from tensor name to Tensor) and a list of non-Tensor attributes, and returns a new instance of the subclassed tensor.

_apply_fn_to_data: takes a function (Tensor -> Tensor), applies it to all tensor data, and creates a new subclassed tensor with the transformed tensor data.

__repr__: the string representation of the subclassed tensor instance.

_same_metadata: returns whether the metadata of two instances of cls is the same.

__setstate__: when loading a serialized tensor subclass checkpoint, sets any new optional tensor data or attributes that are missing from the old checkpoint to None. This maintains BC for old checkpoints when new optional tensor data or attributes are added to the tensor subclass.

torch function supported (__torch_function__): torch.Tensor.contiguous

aten ops supported (__torch_dispatch__): aten.detach.default, aten.clone.default, aten.alias.default, aten.contiguous.default, aten.copy_.default, aten._to_copy.default (enables t.to)
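To make the flatten / unflatten round trip concrete, here is a simplified pure-Python model of the behaviour described above. It is a sketch, not torchao's implementation: ModelTensor and its field names (qdata, scale, block_size, zero_point) are hypothetical, plain lists stand in for torch.Tensor, the methods are named flatten / unflatten / apply_fn_to_data rather than the real dunder names, and unflatten is assumed to receive the attribute dict that flatten returned.

```python
from typing import Any, Callable, Dict, List, Tuple


class ModelTensor:
    """Hypothetical stand-in for a tensor subclass; lists play the tensors."""

    tensor_data_names = ["qdata", "scale"]
    tensor_attribute_names = ["block_size"]
    optional_tensor_data_names = ["zero_point"]

    def __init__(self, qdata, scale, block_size, zero_point=None):
        self.qdata = qdata
        self.scale = scale
        self.block_size = block_size
        self.zero_point = zero_point

    def flatten(self) -> Tuple[List[str], Dict[str, Any]]:
        # Required names always count; optional ones only when they are set.
        names = list(self.tensor_data_names)
        names += [
            n for n in self.optional_tensor_data_names
            if getattr(self, n) is not None
        ]
        attrs = {n: getattr(self, n) for n in self.tensor_attribute_names}
        return names, attrs

    @classmethod
    def unflatten(cls, tensor_data_dict: Dict[str, Any], attrs: Dict[str, Any]):
        # Rebuild constructor arguments in the documented order.
        required = [tensor_data_dict[n] for n in cls.tensor_data_names]
        non_tensor = [attrs[n] for n in cls.tensor_attribute_names]
        optional = [tensor_data_dict.get(n) for n in cls.optional_tensor_data_names]
        return cls(*required, *non_tensor, *optional)

    def apply_fn_to_data(self, fn: Callable[[Any], Any]):
        # Transform every valid piece of tensor data, keep metadata unchanged.
        names, attrs = self.flatten()
        return self.unflatten({n: fn(getattr(self, n)) for n in names}, attrs)

    def __setstate__(self, state: Dict[str, Any]) -> None:
        # A checkpoint saved before an optional field existed lacks that key,
        # so default it to None.
        self.__dict__.update(state)
        for n in self.optional_tensor_data_names:
            self.__dict__.setdefault(n, None)


t = ModelTensor([1, 2], [0.5], block_size=2)
names, attrs = t.flatten()
doubled = t.apply_fn_to_data(lambda xs: [x * 2 for x in xs])

# Simulate loading state written before "zero_point" was introduced.
old = ModelTensor.__new__(ModelTensor)
old.__setstate__({"qdata": [1], "scale": [1.0], "block_size": 1})
```

Because zero_point is unset on t, flatten reports only the two required names, and the rebuilt instance leaves zero_point as None.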

Examples:

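from typing import Optional

from torch import Tensor
from torchao.utils import TorchAOBaseTensor


class MyTensor(TorchAOBaseTensor):
    # Argument order of __new__ and __init__ follows
    # tensor_data_names + tensor_attribute_names
    # + optional_tensor_data_names + optional_tensor_attribute_names.
    tensor_data_names = ["a", "b"]
    tensor_attribute_names = ["c", "d"]
    optional_tensor_data_names = ["e", "f"]
    optional_tensor_attribute_names = ["g", "h"]

    def __new__(
        cls,
        a: Tensor,
        b: Tensor,
        c: int,
        d: str,
        e: Optional[Tensor] = None,
        f: Optional[Tensor] = None,
        g: Optional[int] = None,
        h: Optional[int] = None,
    ):
        pass

    def __init__(
        self,
        a: Tensor,
        b: Tensor,
        c: int,
        d: str,
        e: Optional[Tensor] = None,
        f: Optional[Tensor] = None,
        g: Optional[int] = None,
        h: Optional[int] = None,
    ):
        pass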

get_layout()

Deprecated since version 0.15.1: This method is part of the older tensor subclass development stack. For information about the new dev stack, see https://docs.pytorch.org/ao/main/quantization_overview.html and https://docs.pytorch.org/ao/main/contributor_guide.html

classmethod get_tensor_impl_constructor(layout_class: Callable) → Callable

Deprecated since version 0.15.1: This method is part of the older tensor subclass development stack. For information about the new dev stack, see https://docs.pytorch.org/ao/main/quantization_overview.html and https://docs.pytorch.org/ao/main/contributor_guide.html

Get the TensorImpl class constructor (TensorImplClass.from_plain) for tensor_class based on layout_class. layout_class is the class type of a subclass of Layout, e.g. PlainLayout.

Parameters
  • tensor_class – Tensor subclass type

  • layout_class – the class type of subclass of Layout, e.g. PlainLayout

Returns

tensor impl subclass constructor for the layout_class

classmethod implements(aten_ops)

Decorator for implementing aten ops, such as torch.ops.aten.linear.default, for a tensor subclass. The implemented functions are called in the __torch_dispatch__ callback for torch.Tensor subclasses.

Examples:

implements = MyTensor.implements

# Register a handler for the aten op; it is invoked from __torch_dispatch__.
@implements(torch.ops.aten.linear.default)
def _(func, types, args, kwargs):
    ...
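
The registration pattern behind a decorator like this can be modelled as a dispatch table keyed by op. The following is a minimal pure-Python sketch under stated assumptions: OpRegistry, _OP_TABLE, and dispatch are hypothetical names, a string stands in for a real aten op, and none of this is torchao's actual implementation.

```python
from typing import Any, Callable, Dict


class OpRegistry:
    """Hypothetical stand-in for a tensor subclass that owns an op table."""

    _OP_TABLE: Dict[Any, Callable] = {}

    @classmethod
    def implements(cls, ops):
        # Accept a single op or a list of ops, as a convenience.
        if not isinstance(ops, (list, tuple)):
            ops = [ops]

        def decorator(fn: Callable) -> Callable:
            for op in ops:
                cls._OP_TABLE[op] = fn
            return fn

        return decorator

    @classmethod
    def dispatch(cls, func, types, args=(), kwargs=None):
        # Plays the role of the __torch_dispatch__ callback: look up the
        # handler registered for func and call it with the same signature.
        kwargs = kwargs or {}
        if func in cls._OP_TABLE:
            return cls._OP_TABLE[func](func, types, args, kwargs)
        raise NotImplementedError(f"{cls.__name__} has no handler for {func!r}")


implements = OpRegistry.implements


@implements("toy.dot")  # a string key stands in for a real aten op
def _(func, types, args, kwargs):
    x, w = args
    return sum(a * b for a, b in zip(x, w))


result = OpRegistry.dispatch("toy.dot", (OpRegistry,), ([1, 2, 3], [4, 5, 6]))
```

Handlers share the (func, types, args, kwargs) signature shown in the example above, so one function can serve several ops and branch on func if needed.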

classmethod implements_torch_function(torch_fns)

Decorator for implementing torch functions / ops, such as torch.nn.functional.linear or torch.Tensor.t, for a tensor subclass. The implemented functions are called in the __torch_function__ callback for torch.Tensor subclasses.

Examples:

implements_torch_function = MyTensor.implements_torch_function

# Register a handler for the torch function; it is invoked from __torch_function__.
@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
    ...

classmethod register_layout(layout_class: Callable)

Deprecated since version 0.15.1: This method is part of the older tensor subclass development stack. For information about the new dev stack, see https://docs.pytorch.org/ao/main/quantization_overview.html and https://docs.pytorch.org/ao/main/contributor_guide.html

Helper function for layout registration. It is used to implement the register_layout decorator for each tensor subclass; see aqt.py for example usage.

Parameters
  • tensor_class – Tensor subclass type

  • layout_class – the class type of subclass of Layout, e.g. PlainLayout

Returns

a decorator that registers the tensor impl constructor in the table