
TorchAOBaseTensor

class torchao.utils.TorchAOBaseTensor(*args, **kwargs)

A utility tensor subclass that provides commonly used functions; new tensor subclasses can inherit from it to get all of the utility functions.

Attributes (defined by subclass of TorchAOBaseTensor):

  • tensor_data_names (List[str]): names of all required tensor data attributes; the order should match the __init__ argument list of the tensor subclass. Optional for subclassing TorchAOBaseTensor, but required to get the predefined util functions (see the Note section).

  • tensor_attribute_names (List[str]): names of non-Tensor attributes; the order should match the __init__ argument list of the tensor subclass, following all of the tensor_data_names arguments. Optional for subclassing TorchAOBaseTensor, but required to get the predefined util functions (see the Note section).

  • optional_tensor_data_names (List[str]): not required to get the additional util functions, but needed if the subclass has optional Tensor data attributes. When defined, this is a list of names of Tensor attributes that are optional.

  • optional_tensor_attribute_names (List[str]): not required to get the additional util functions, but needed if the subclass has optional non-Tensor attributes. When defined, this is a list of names of attributes that are optional.

Note

Argument Order

The argument order in __init__ and __new__ of a subclass of TorchAOBaseTensor must exactly match tensor_data_names + tensor_attribute_names + optional_tensor_data_names (if present) + optional_tensor_attribute_names (if present).
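
The ordering rule can be checked mechanically. Below is a minimal stdlib sketch, assuming a hypothetical helper check_argument_order (not part of torchao) that concatenates the four name lists and compares the result with the parameters of __init__ via inspect.signature. A plain Python class stands in for a real tensor subclass so the snippet runs without torch.

```python
import inspect


def expected_argument_order(cls):
    """Concatenate the name lists in the order TorchAOBaseTensor expects."""
    return (
        list(getattr(cls, "tensor_data_names", []))
        + list(getattr(cls, "tensor_attribute_names", []))
        + list(getattr(cls, "optional_tensor_data_names", []))
        + list(getattr(cls, "optional_tensor_attribute_names", []))
    )


def check_argument_order(cls):
    """Hypothetical helper: True if __init__'s parameters (after self) match."""
    params = list(inspect.signature(cls.__init__).parameters)[1:]  # drop "self"
    return params == expected_argument_order(cls)


# Plain stand-in class (not a real tensor subclass) using the same name
# lists as the MyTensor example on this page.
class Example:
    tensor_data_names = ["a", "b"]
    tensor_attribute_names = ["c", "d"]
    optional_tensor_data_names = ["e", "f"]
    optional_tensor_attribute_names = ["g", "h"]

    def __init__(self, a, b, c, d, e=None, f=None, g=None, h=None):
        pass


class BadOrder(Example):
    # Attributes listed before tensor data: violates the required order.
    def __init__(self, c, d, a, b, e=None, f=None, g=None, h=None):
        pass


print(check_argument_order(Example))   # True
print(check_argument_order(BadOrder))  # False
```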

Note

How to Get Predefined Util Functions

If tensor_data_names (names of torch.Tensor data attributes) and tensor_attribute_names (names of non-torch.Tensor attributes) are defined, the following additional util functions are added:

__tensor_flatten__: flattens a subclassed tensor instance and returns a tuple; the first element is the list of tensor data names for valid tensor data, and the second element is a dict from attribute_name to non-Tensor attributes

__tensor_unflatten__: takes a tensor_data_dict (a map from tensor name to Tensor) and a list of non-Tensor attributes, and returns a new instance of the subclassed tensor

_apply_fn_to_data: takes a function (Tensor -> Tensor), applies it to all tensor data, and creates a new subclassed Tensor with the transformed tensor data

__repr__: the string representation of the subclassed tensor instance

_same_metadata: returns whether the metadata is the same between two instances of cls

__setstate__: when loading serialized tensor subclass checkpoints, sets any optional tensor data or tensor attributes that are missing from an old checkpoint (because they were added to the tensor subclass after it was saved) to None, maintaining backward compatibility (BC) with old checkpoints when new optional tensor data or attributes are added to the tensor subclass

torch function supported (__torch_function__): torch.Tensor.contiguous

aten ops supported (__torch_dispatch__): aten.detach.default, aten.clone.default, aten.alias.default, aten.contiguous.default, aten.copy_.default, aten._to_copy.default (enables t.to)
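
A rough pure-Python model can make the contract of _apply_fn_to_data and the backward-compatibility step of __setstate__ concrete. This is a sketch under stated assumptions, not torchao's implementation: ToyTensor and its field names are hypothetical, plain lists stand in for torch.Tensor data, and __setstate__ is reduced to filling optional names that are missing from an old saved state with None.

```python
class ToyTensor:
    # Hypothetical stand-in; plain lists play the role of torch.Tensor data.
    tensor_data_names = ["qdata", "scale"]
    tensor_attribute_names = ["block_size"]
    optional_tensor_data_names = ["zero_point"]
    optional_tensor_attribute_names = ["dtype_name"]

    def __init__(self, qdata, scale, block_size, zero_point=None, dtype_name=None):
        self.qdata = qdata
        self.scale = scale
        self.block_size = block_size
        self.zero_point = zero_point
        self.dtype_name = dtype_name

    def _apply_fn_to_data(self, fn):
        # Transform tensor data (skipping optional data that is None), keep
        # attributes unchanged, and rebuild in the documented argument order.
        data = [fn(getattr(self, n)) for n in self.tensor_data_names]
        attrs = [getattr(self, n) for n in self.tensor_attribute_names]
        opt_data = [
            None if getattr(self, n) is None else fn(getattr(self, n))
            for n in self.optional_tensor_data_names
        ]
        opt_attrs = [getattr(self, n) for n in self.optional_tensor_attribute_names]
        return type(self)(*data, *attrs, *opt_data, *opt_attrs)

    def __setstate__(self, state):
        # An old checkpoint may predate some optional fields; default them to None.
        for name in self.optional_tensor_data_names + self.optional_tensor_attribute_names:
            state.setdefault(name, None)
        self.__dict__.update(state)


t = ToyTensor([1, 2], [0.5], 32)
doubled = t._apply_fn_to_data(lambda xs: [2 * x for x in xs])
print(doubled.qdata, doubled.scale, doubled.zero_point)  # [2, 4] [1.0] None

# Simulate unpickling a state saved before "zero_point"/"dtype_name" existed.
old = ToyTensor.__new__(ToyTensor)
old.__setstate__({"qdata": [1], "scale": [1.0], "block_size": 32})
print(old.zero_point, old.dtype_name)  # None None
```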

Note

Subclassing and Op Inheritance

Subclasses of TorchAOBaseTensor automatically inherit aten op (__torch_dispatch__) and torch function (__torch_function__) implementations from their parent classes. Each subclass gets its own independent dispatch tables, so registering a new op on a child does not affect the parent, and vice versa.

A child class can override an inherited op by registering its own implementation with @ChildClass.implements(...). If no override is provided, the parent’s implementation is used automatically.

For multiple inheritance (e.g., class C(B, A)), ops are inherited from all parents following Python’s MRO (Method Resolution Order), with later bases taking priority.

Example:

import torch
from torchao.utils import TorchAOBaseTensor

class Parent(TorchAOBaseTensor):
    tensor_data_names = ["qdata"]
    tensor_attribute_names = ["attr"]

    def __new__(cls, qdata, attr):
        r = torch.Tensor._make_wrapper_subclass(cls, qdata.shape)
        r.qdata = qdata
        r.attr = attr
        return r

    def __init__(self, qdata, attr):
        pass

@Parent.implements([torch.ops.aten.cat.default])
def parent_cat(func, types, args, kwargs):
    # parent implementation
    ...

# Child inherits Parent's aten.cat implementation automatically
class Child(Parent):
    tensor_data_names = ["qdata"]
    tensor_attribute_names = ["attr"]

# Optionally override an inherited op:
@Child.implements([torch.ops.aten.cat.default])
def child_cat(func, types, args, kwargs):
    # child-specific implementation
    ...
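
The lookup rule for single inheritance (each class owns a table, a child's registration shadows the parent's, and unregistered ops fall back to the parent) can be modeled in a few lines of plain Python. ToyDispatchBase, its dispatch method, and the string op names are hypothetical stand-ins rather than torchao internals; the sketch resolves ops by walking the MRO front to back, so it does not model the multiple-inheritance priority rule stated above.

```python
class ToyDispatchBase:
    _impls = {}  # op name -> implementation, owned by this class only

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        cls._impls = {}  # every subclass gets an independent table

    @classmethod
    def implements(cls, ops):
        if not isinstance(ops, (list, tuple)):
            ops = [ops]

        def decorator(fn):
            for op in ops:
                cls._impls[op] = fn
            return fn

        return decorator

    @classmethod
    def dispatch(cls, op, *args):
        # Walk the MRO: the first class that registered `op` wins.
        for klass in cls.__mro__:
            table = klass.__dict__.get("_impls", {})
            if op in table:
                return table[op](*args)
        raise NotImplementedError(f"{cls.__name__} has no implementation for {op!r}")


class Parent(ToyDispatchBase):
    pass


@Parent.implements(["cat", "clone"])
def parent_impl(*args):
    return "parent"


class Child(Parent):
    pass


@Child.implements("cat")
def child_cat(*args):
    return "child"


print(Child.dispatch("cat"))    # child  (override)
print(Child.dispatch("clone"))  # parent (inherited)
print(Parent.dispatch("cat"))   # parent (child registration does not leak up)
```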

Note

Safetensors Support

TorchAOBaseTensor subclasses can be serialized to and loaded from the safetensors format. Since safetensors only stores plain torch.Tensor objects, the serialization layer (in torchao.prototype.safetensors) decomposes each tensor subclass into its constituent plain tensors (from tensor_data_names/optional_tensor_data_names) plus JSON metadata (from tensor_attribute_names/optional_tensor_attribute_names), and reconstructs the subclass on load.

To add safetensors support for a new TorchAOBaseTensor subclass:

  1. Define tensor_data_names, tensor_attribute_names (and optionally optional_tensor_data_names, optional_tensor_attribute_names) on the subclass, which is already required for the other utility functions above.

  2. Register the subclass in torchao/prototype/safetensors/safetensors_utils.py:

    • Add the class name string to ALLOWED_TENSORS_SUBCLASSES.

    • Add a "ClassName": ClassName entry to ALLOWED_CLASSES.

    • If the subclass has non-Tensor attributes with custom types (dataclasses, enums, named tuples), add those types to ALLOWED_CLASSES as well.

Once registered, Hugging Face Transformers users can use save_pretrained and push_to_hub with the default safe_serialization=True option. See https://docs.pytorch.org/ao/main/eager_tutorials/torchao_hf_integration.html#saving-the-model for a full end-to-end example.
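
To illustrate the decomposition described above, here is a stdlib-only sketch. It is not the torchao.prototype.safetensors implementation: the flatten_for_safetensors and unflatten_from_safetensors names, the "<key>:<field>" key scheme, the metadata layout, and the ALLOWED dictionary (playing the role of ALLOWED_CLASSES) are all hypothetical, and plain lists stand in for tensors.

```python
import json


class ToyQuantTensor:
    # Hypothetical subclass stand-in with the usual name lists.
    tensor_data_names = ["qdata", "scale"]
    tensor_attribute_names = ["block_size"]

    def __init__(self, qdata, scale, block_size):
        self.qdata, self.scale, self.block_size = qdata, scale, block_size


# Hypothetical allowlist, playing the role of ALLOWED_CLASSES.
ALLOWED = {"ToyQuantTensor": ToyQuantTensor}


def flatten_for_safetensors(key, t):
    # Tensor data goes into the tensor dict; attributes become JSON metadata.
    tensors = {f"{key}:{n}": getattr(t, n) for n in t.tensor_data_names}
    meta = {
        "class": type(t).__name__,
        "attrs": {n: getattr(t, n) for n in t.tensor_attribute_names},
    }
    return tensors, {key: json.dumps(meta)}


def unflatten_from_safetensors(key, tensors, metadata):
    meta = json.loads(metadata[key])
    cls = ALLOWED[meta["class"]]  # KeyError for classes that are not registered
    data = [tensors[f"{key}:{n}"] for n in cls.tensor_data_names]
    attrs = [meta["attrs"][n] for n in cls.tensor_attribute_names]
    return cls(*data, *attrs)


w = ToyQuantTensor([3, -1, 7], [0.25], 128)
tensors, metadata = flatten_for_safetensors("linear.weight", w)
print(sorted(tensors))  # ['linear.weight:qdata', 'linear.weight:scale']
restored = unflatten_from_safetensors("linear.weight", tensors, metadata)
print(restored.block_size)  # 128
```

The allowlist lookup is the point of the registration step: reconstruction only instantiates classes that were explicitly registered.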

Examples:

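from typing import Optional

import torch
from torch import Tensor
from torchao.utils import TorchAOBaseTensor

class MyTensor(TorchAOBaseTensor):
    tensor_data_names = ["a", "b"]
    tensor_attribute_names = ["c", "d"]
    optional_tensor_data_names = ["e", "f"]
    optional_tensor_attribute_names = ["g", "h"]

    def __new__(
        cls,
        a: Tensor,
        b: Tensor,
        c: int,
        d: str,
        e: Optional[Tensor] = None,
        f: Optional[Tensor] = None,
        g: Optional[int] = None,
        h: Optional[int] = None,
    ):
        # create the wrapper tensor (shape taken from `a` here for illustration)
        return torch.Tensor._make_wrapper_subclass(cls, a.shape)

    def __init__(
        self,
        a: Tensor,
        b: Tensor,
        c: int,
        d: str,
        e: Optional[Tensor] = None,
        f: Optional[Tensor] = None,
        g: Optional[int] = None,
        h: Optional[int] = None,
    ):
        # store each argument under the name declared in the *_names lists
        self.a, self.b = a, b
        self.c, self.d = c, d
        self.e, self.f = e, f
        self.g, self.h = g, h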

get_layout()

Deprecated since version 0.15.1: This method is part of the older tensor subclass development stack. For information about the new development stack, see https://docs.pytorch.org/ao/main/quantization_overview.html and https://docs.pytorch.org/ao/main/contributor_guide.html

classmethod get_tensor_impl_constructor(layout_class: Callable) → Callable

Deprecated since version 0.15.1: This method is part of the older tensor subclass development stack. For information about the new development stack, see https://docs.pytorch.org/ao/main/quantization_overview.html and https://docs.pytorch.org/ao/main/contributor_guide.html

Get the TensorImpl class constructor (TensorImplClass.from_plain) for tensor_class based on layout_class, where layout_class is the class type of a subclass of Layout, e.g. PlainLayout.

Parameters:
  • tensor_class – Tensor subclass type

  • layout_class – the class type of subclass of Layout, e.g. PlainLayout

Returns:

tensor impl subclass constructor for the layout_class

classmethod implements(aten_ops)

Decorator for implementing aten ops, such as torch.ops.aten.linear.default, for the tensor subclass; the implemented functions are called in the __torch_dispatch__ callback for torch.Tensor subclasses.

Examples:

implements = MyTensor.implements

@implements(torch.ops.aten.linear.default)
def _(func, types, args, kwargs):
    ...

classmethod implements_torch_function(torch_fns)

Decorator for implementing torch functions / ops, such as torch.nn.functional.linear and torch.Tensor.t, for the tensor subclass; the implemented functions are called in the __torch_function__ callback for torch.Tensor subclasses.

Examples:

implements_torch_function = MyTensor.implements_torch_function

@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
    ...

classmethod register_layout(layout_class: Callable)

Deprecated since version 0.15.1: This method is part of the older tensor subclass development stack. For information about the new development stack, see https://docs.pytorch.org/ao/main/quantization_overview.html and https://docs.pytorch.org/ao/main/contributor_guide.html

Helper function for layout registrations; it is used to implement the register_layout decorator for each tensor subclass (see aqt.py for example usage).

Parameters:
  • tensor_class – Tensor subclass type

  • layout_class – the class type of subclass of Layout, e.g. PlainLayout

Returns:

a decorator that registers the tensor impl constructor in the table
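
For readers maintaining code on the older stack, register_layout and get_tensor_impl_constructor together amount to a per-class table from a layout class to a from_plain constructor. The sketch below models that with plain Python; every name in it (ToyAQT, ToyLayout, ToyPlainLayout, ToyPlainImpl, _LAYOUT_TABLE) is hypothetical and only approximates the deprecated machinery rather than reproducing it.

```python
class ToyLayout:
    pass


class ToyPlainLayout(ToyLayout):
    pass


class ToyAQT:
    _LAYOUT_TABLE = {}  # layout class -> TensorImpl constructor (from_plain)

    @classmethod
    def register_layout(cls, layout_class):
        # Returns a decorator that records impl_class.from_plain in the table.
        def decorator(impl_class):
            cls._LAYOUT_TABLE[layout_class] = impl_class.from_plain
            return impl_class

        return decorator

    @classmethod
    def get_tensor_impl_constructor(cls, layout_class):
        if layout_class not in cls._LAYOUT_TABLE:
            raise ValueError(f"layout {layout_class.__name__} is not registered")
        return cls._LAYOUT_TABLE[layout_class]


@ToyAQT.register_layout(ToyPlainLayout)
class ToyPlainImpl:
    def __init__(self, data, layout):
        self.data, self.layout = data, layout

    @classmethod
    def from_plain(cls, data, layout):
        return cls(data, layout)


ctor = ToyAQT.get_tensor_impl_constructor(ToyPlainLayout)
impl = ctor([1, 2, 3], ToyPlainLayout())
print(type(impl).__name__, impl.data)  # ToyPlainImpl [1, 2, 3]
```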