TorchAOBaseTensor
- class torchao.utils.TorchAOBaseTensor(*args, **kwargs)
A utility tensor subclass that provides commonly used functions. New tensor subclasses can inherit from it to get all of these utility functions.
Attributes (defined by subclasses of TorchAOBaseTensor):
tensor_data_names (List[str]): list of names of all required tensor data; the order should match the __init__ argument list of the tensor subclass. Defining it is optional for using TorchAOBaseTensor, but required to get the utility functions (see the Note section).
tensor_attribute_names (List[str]): list of names of non-Tensor attributes; the order should match the __init__ argument list of the tensor subclass, following all the tensor_data_names arguments. Defining it is optional for using TorchAOBaseTensor, but required to get the utility functions (see the Note section).
optional_tensor_data_names (List[str]): not required to get the additional utility functions, but needed if the subclass has optional Tensor data attributes; when defined, it is a list of names of Tensors that can be optional.
optional_tensor_attribute_names (List[str]): not required to get the additional utility functions, but needed if the subclass has optional non-Tensor attributes; when defined, it is a list of names of attributes that can be optional.
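The ordering requirement these name lists impose can be checked mechanically. The following is a minimal sketch, not a torchao API: expected_arg_order and check_arg_order are hypothetical helpers, and MyTensorSpec is a plain stand-in class (no torch dependency) that reuses the names a through h from the class example further down. It concatenates the name lists in the documented order and compares the result against the parameters of __new__ and __init__.

```python
import inspect


def expected_arg_order(cls):
    """Hypothetical helper: the constructor order implied by the name lists."""
    names = list(cls.tensor_data_names) + list(cls.tensor_attribute_names)
    # Optional lists are appended only when the subclass defines them.
    names += list(getattr(cls, "optional_tensor_data_names", []))
    names += list(getattr(cls, "optional_tensor_attribute_names", []))
    return names


def check_arg_order(cls):
    """Hypothetical helper: raise if __new__ or __init__ disagrees with the name lists."""
    expected = expected_arg_order(cls)
    for method_name in ("__new__", "__init__"):
        # Drop the first parameter (cls or self) before comparing.
        params = list(inspect.signature(getattr(cls, method_name)).parameters)[1:]
        if params != expected:
            raise TypeError(f"{cls.__name__}.{method_name} takes {params}, expected {expected}")
    return expected


class MyTensorSpec:
    """Plain stand-in that only carries the name lists and the signatures."""

    tensor_data_names = ["a", "b"]
    tensor_attribute_names = ["c", "d"]
    optional_tensor_data_names = ["e", "f"]
    optional_tensor_attribute_names = ["g", "h"]

    def __new__(cls, a, b, c, d, e=None, f=None, g=None, h=None):
        return super().__new__(cls)

    def __init__(self, a, b, c, d, e=None, f=None, g=None, h=None):
        pass


print(check_arg_order(MyTensorSpec))  # ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
```

A check like this could run in a unit test for each new subclass, so that a reordered constructor argument is caught before it silently breaks flatten/unflatten.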
Note
Argument order in __init__ and __new__ of a subclass of TorchAOBaseTensor should match exactly with tensor_data_names + tensor_attribute_names + optional_tensor_data_names (if present) + optional_tensor_attribute_names (if present).
Note
If tensor_data_names (torch.Tensor data attribute names) and tensor_attribute_names (non-torch.Tensor attribute names) are defined, some additional utility functions are added. These include:
__tensor_flatten__: flattens a subclassed tensor instance and returns a tuple; the first element is the list of tensor data names for valid tensor data, and the second element is a dict from attribute name to non-Tensor attribute.
__tensor_unflatten__: takes a tensor_data_dict (a map from tensor name to Tensor) and a list of non-Tensor attributes, and returns a new instance of the subclassed tensor.
_apply_fn_to_data: takes a function (Tensor -> Tensor), applies it to all tensor data, and creates a new subclassed tensor from the transformed tensor data.
__repr__: the string representation of the subclassed tensor instance.
_same_metadata: returns whether the metadata is the same between two instances of cls.
__setstate__: when loading a serialized tensor subclass checkpoint, sets any new optional tensor data or tensor attributes that are missing from the old checkpoint to None. This maintains backward compatibility (BC) with old checkpoints when new optional tensor data or attributes are added to the tensor subclass.
torch functions supported (__torch_function__): torch.Tensor.contiguous
aten ops supported (__torch_dispatch__): aten.detach.default, aten.clone.default, aten.alias.default, aten.contiguous.default, aten.copy_.default, aten._to_copy.default (enables t.to)
Examples:
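```python
from typing import Optional

from torch import Tensor
from torchao.utils import TorchAOBaseTensor


class MyTensor(TorchAOBaseTensor):
    tensor_data_names = ["a", "b"]
    tensor_attribute_names = ["c", "d"]
    optional_tensor_data_names = ["e", "f"]
    optional_tensor_attribute_names = ["g", "h"]

    # Arguments follow tensor_data_names + tensor_attribute_names
    # + optional_tensor_data_names + optional_tensor_attribute_names
    def __new__(
        cls,
        a: Tensor,
        b: Tensor,
        c: int,
        d: str,
        e: Optional[Tensor] = None,
        f: Optional[Tensor] = None,
        g: Optional[int] = None,
        h: Optional[int] = None,
    ):
        pass  # body omitted in this example

    def __init__(
        self,
        a: Tensor,
        b: Tensor,
        c: int,
        d: str,
        e: Optional[Tensor] = None,
        f: Optional[Tensor] = None,
        g: Optional[int] = None,
        h: Optional[int] = None,
    ):
        pass  # body omitted in this example
```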
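To make the flatten/unflatten round trip concrete, here is a simplified pure-Python model of how the name lists could drive __tensor_flatten__, __tensor_unflatten__ and _apply_fn_to_data. It is an illustration under stated assumptions, not torchao's implementation: ToyTensorSubclass and its flatten, unflatten and apply_fn_to_data methods are hypothetical, floats stand in for tensors, the outer size and stride that PyTorch also passes to __tensor_unflatten__ are omitted, "valid" tensor data is taken to mean optional tensor data that is not None, and the non-Tensor attributes are kept as a list in tensor_attribute_names order so they can be passed positionally to the constructor.

```python
class ToyTensorSubclass:
    """Plain-Python stand-in: floats play the role of tensor data."""

    tensor_data_names = ["a", "b"]
    tensor_attribute_names = ["c"]
    optional_tensor_data_names = ["e"]

    def __init__(self, a, b, c, e=None):
        self.a, self.b, self.c, self.e = a, b, c, e

    def flatten(self):
        """Mirrors __tensor_flatten__: (tensor data names, non-tensor attributes)."""
        names = list(self.tensor_data_names)
        # Optional tensor data is listed only when it is present (not None).
        names += [n for n in self.optional_tensor_data_names if getattr(self, n) is not None]
        attrs = [getattr(self, n) for n in self.tensor_attribute_names]
        return names, attrs

    @classmethod
    def unflatten(cls, tensor_data_dict, attrs):
        """Mirrors __tensor_unflatten__: rebuild an instance in constructor order."""
        required = [tensor_data_dict[n] for n in cls.tensor_data_names]
        optional = {n: tensor_data_dict.get(n) for n in cls.optional_tensor_data_names}
        return cls(*required, *attrs, **optional)

    def apply_fn_to_data(self, fn):
        """Mirrors _apply_fn_to_data: transform every tensor, keep the metadata."""
        names, attrs = self.flatten()
        return self.unflatten({n: fn(getattr(self, n)) for n in names}, attrs)


t = ToyTensorSubclass(1.0, 2.0, "meta")
print(t.flatten())  # (['a', 'b'], ['meta'])
doubled = t.apply_fn_to_data(lambda x: x * 2)
print(doubled.a, doubled.b, doubled.c, doubled.e)  # 2.0 4.0 meta None
```

Because unflatten rebuilds the object by calling the constructor positionally, the whole round trip depends on the argument-order rule from the first Note.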
- get_layout()
Deprecated since version 0.15.1: This method is part of the older tensor subclass development stack. For information about the new development stack, see https://docs.pytorch.org/ao/main/quantization_overview.html and https://docs.pytorch.org/ao/main/contributor_guide.html
- classmethod get_tensor_impl_constructor(layout_class: Callable) → Callable
Deprecated since version 0.15.1: This method is part of the older tensor subclass development stack. For information about the new development stack, see https://docs.pytorch.org/ao/main/quantization_overview.html and https://docs.pytorch.org/ao/main/contributor_guide.html
Get the TensorImpl class constructor (TensorImplClass.from_plain) for tensor_class based on layout_class, where layout_class is the class type of a subclass of Layout, e.g. PlainLayout.
- Parameters
tensor_class – Tensor subclass type
layout_class – the class type of subclass of Layout, e.g. PlainLayout
- Returns
tensor impl subclass constructor for the layout_class
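The lookup this deprecated classmethod performs amounts to a per-tensor-class table mapping a layout class to a from_plain constructor. A minimal sketch, assuming a dict-based table; the _constructor_table attribute, the stand-in classes MyTensor, PlainLayout and PlainTensorImpl, and the error message are hypothetical and do not reproduce torchao's internals.

```python
class PlainLayout:
    """Stand-in for a Layout subclass such as PlainLayout."""


class PlainTensorImpl:
    """Hypothetical TensorImpl class with a from_plain constructor."""

    def __init__(self, data, layout):
        self.data, self.layout = data, layout

    @classmethod
    def from_plain(cls, data, layout):
        return cls(data, layout)


class MyTensor:
    # Hypothetical per-class table: layout class -> TensorImpl constructor.
    _constructor_table = {PlainLayout: PlainTensorImpl.from_plain}

    @classmethod
    def get_tensor_impl_constructor(cls, layout_class):
        # Return the constructor registered for this layout; fail loudly otherwise.
        if layout_class not in cls._constructor_table:
            raise ValueError(f"{layout_class.__name__} is not supported for {cls.__name__}")
        return cls._constructor_table[layout_class]


layout = PlainLayout()
ctor = MyTensor.get_tensor_impl_constructor(type(layout))
impl = ctor([1, 2, 3], layout)
print(type(impl).__name__)  # PlainTensorImpl
```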
- classmethod implements(aten_ops)
Decorator for implementing aten ops such as torch.ops.aten.linear.default for the tensor subclass; the implemented functions are called in the __torch_dispatch__ callback for torch.Tensor subclasses.
Examples:
```python
import torch

# MyTensor: the TorchAOBaseTensor subclass from the class example above
implements = MyTensor.implements


@implements(torch.ops.aten.linear.default)
def _(func, types, args, kwargs):
    ...
```
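Conceptually, implements fills a per-class table from op to handler, and the __torch_dispatch__ callback routes each incoming op through that table. The sketch below models this in plain Python and is not torchao's code: op names are strings standing in for torch.ops.aten overloads, dispatch stands in for __torch_dispatch__, _OP_TABLE and ToyBase are hypothetical, raising NotImplementedError for an unregistered op is a choice made for the sketch, and accepting a list of ops is an assumption suggested by the plural parameter name aten_ops.

```python
from collections import defaultdict

# Hypothetical module-level table: tensor class -> {op: handler}.
_OP_TABLE = defaultdict(dict)


class ToyBase:
    @classmethod
    def implements(cls, aten_ops):
        # Accept a single op or a list/tuple of ops.
        if not isinstance(aten_ops, (list, tuple)):
            aten_ops = [aten_ops]

        def decorator(func):
            for op in aten_ops:
                _OP_TABLE[cls][op] = func
            return func

        return decorator

    @classmethod
    def dispatch(cls, func, types, args=(), kwargs=None):
        # Stand-in for __torch_dispatch__: route the op to its registered handler.
        handler = _OP_TABLE[cls].get(func)
        if handler is None:
            raise NotImplementedError(f"{cls.__name__} does not implement {func}")
        return handler(func, types, args, kwargs or {})


class MyTensor(ToyBase):
    pass


implements = MyTensor.implements


@implements(["aten.linear.default", "aten.addmm.default"])
def _(func, types, args, kwargs):
    return f"{func} handled with {len(args)} args"


print(MyTensor.dispatch("aten.linear.default", (MyTensor,), (1, 2, 3)))
# aten.linear.default handled with 3 args
```

Keying the table by class keeps handlers registered for one subclass from leaking into another.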
- classmethod implements_torch_function(torch_fns)
Decorator for implementing torch functions / ops such as torch.nn.functional.linear and torch.Tensor.t for the tensor subclass; the implemented functions are called in the __torch_function__ callback for torch.Tensor subclasses.
Examples:
```python
import torch

# MyTensor: the TorchAOBaseTensor subclass from the class example above
implements_torch_function = MyTensor.implements_torch_function


@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
    ...
```
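implements_torch_function follows the same table pattern one level higher: the keys are Python-level callables such as torch.nn.functional.linear, and the handler receives the original function, so it can call it again on unwrapped data. In this plain-Python sketch, linear is a scalar stand-in for torch.nn.functional.linear, Scaled is a hypothetical wrapper holding a value and a scale, torch_function stands in for __torch_function__, and the fallback for unregistered functions (undo the scaling, then call the function) is an assumption of the sketch; a real subclass would normally continue along PyTorch's default path instead.

```python
import math

# Hypothetical table: Python-level function -> handler(func, types, args, kwargs).
_TORCH_FN_TABLE = {}


def implements_torch_function(torch_fns):
    """Register one handler for one or more functions."""
    if not isinstance(torch_fns, (list, tuple)):
        torch_fns = [torch_fns]

    def decorator(handler):
        for fn in torch_fns:
            _TORCH_FN_TABLE[fn] = handler
        return handler

    return decorator


def linear(x, w):
    """Scalar stand-in for torch.nn.functional.linear (bias omitted)."""
    return x * w


class Scaled:
    """Toy wrapper: represents value * scale without materializing it."""

    def __init__(self, value, scale):
        self.value, self.scale = value, scale

    @classmethod
    def torch_function(cls, func, types, args=(), kwargs=None):
        # Stand-in for __torch_function__: use a registered override if there is one.
        kwargs = kwargs or {}
        if func in _TORCH_FN_TABLE:
            return _TORCH_FN_TABLE[func](func, types, args, kwargs)
        # Assumed fallback for this sketch: materialize Scaled args, call func.
        plain = [a.value * a.scale if isinstance(a, cls) else a for a in args]
        return func(*plain, **kwargs)


@implements_torch_function(linear)
def _(func, types, args, kwargs):
    x, w, *rest = args
    # Compute on the stored value and keep the scale, so the result stays wrapped.
    return Scaled(func(x, w.value, *rest, **kwargs), w.scale)


out = Scaled.torch_function(linear, (Scaled,), (2.0, Scaled(1.5, 2.0)))
print(out.value, out.scale)  # 3.0 2.0
print(Scaled.torch_function(math.sqrt, (Scaled,), (Scaled(8.0, 2.0),)))  # 4.0
```

Here the registered override keeps the result wrapped in Scaled, while the unregistered math.sqrt call falls back to plain values.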
- classmethod register_layout(layout_class: Callable)
Deprecated since version 0.15.1: This method is part of the older tensor subclass development stack. For information about the new development stack, see https://docs.pytorch.org/ao/main/quantization_overview.html and https://docs.pytorch.org/ao/main/contributor_guide.html
Helper function for layout registration; it is used to implement the register_layout decorator for each tensor subclass. See aqt.py for example usage.
- Parameters
tensor_class – Tensor subclass type
layout_class – the class type of subclass of Layout, e.g. PlainLayout
- Returns
a decorator that registers the tensor impl constructor in the table
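The registration side of the deprecated layout mechanism can be pictured as an ordinary class decorator: it records the TensorImpl class's from_plain constructor under the layout class and returns the class unchanged. A minimal sketch, assuming a dict-based table; MyTensor, PlainLayout, PlainTensorImpl and _constructor_table are hypothetical stand-ins, not torchao's internals.

```python
class PlainLayout:
    """Stand-in for a Layout subclass."""


class MyTensor:
    # Hypothetical per-class table filled in by register_layout.
    _constructor_table = {}

    @classmethod
    def register_layout(cls, layout_class):
        def decorator(tensor_impl_class):
            # Store the impl's from_plain constructor under the layout class.
            cls._constructor_table[layout_class] = tensor_impl_class.from_plain
            return tensor_impl_class  # the decorated class itself is unchanged

        return decorator


register_layout = MyTensor.register_layout


@register_layout(PlainLayout)
class PlainTensorImpl:
    def __init__(self, data, layout):
        self.data, self.layout = data, layout

    @classmethod
    def from_plain(cls, data, layout):
        return cls(data, layout)


print(MyTensor._constructor_table[PlainLayout]([1, 2], PlainLayout()).data)  # [1, 2]
```

Because _constructor_table is a class attribute here, subclasses of MyTensor would share it; a real implementation might prefer a separate table per class.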