
TorchAOBaseTensor

class torchao.utils.TorchAOBaseTensor
A utility tensor subclass that provides commonly used functions.

A new tensor subclass can inherit from it to get all of the utility functions:

class MyTensor(TorchAOBaseTensor):
    pass

This includes:

_get_to_kwargs, which gets the keyword arguments for to():

class MyTensor(TorchAOBaseTensor):
    def to(self, *args, **kwargs):
        kwargs = _get_to_kwargs(*args, **kwargs)
        …

implements:

implements = MyTensor.implements

@implements(torch.nn.functional.linear)
def _(func, types, args, kwargs):
    ...

register_layout:

register_layout = MyTensor.register_layout

@register_layout(PlainLayout)
class PlainAQTTensorImpl(…):
    ...

get_tensor_impl_constructor:

get_tensor_impl_constructor = MyTensor.get_tensor_impl_constructor

# in the constructor of MyTensor:
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
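The _get_to_kwargs helper listed above can be pictured with the following pure-Python sketch. It assumes the helper's job is to fold the positional and keyword forms of a to(...) call into one explicit device/dtype dict, keeping the tensor's current values for anything left unspecified. Device, DType and get_to_kwargs_sketch are hypothetical stand-ins so the sketch runs without PyTorch; they are not torchao or PyTorch APIs.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass(frozen=True)
class Device:
    """Hypothetical stand-in for torch.device."""
    name: str


@dataclass(frozen=True)
class DType:
    """Hypothetical stand-in for torch.dtype."""
    name: str


def get_to_kwargs_sketch(
    current_device: Device, current_dtype: DType, *args: Any, **kwargs: Any
) -> Dict[str, Any]:
    """Normalize .to(...) arguments into explicit device/dtype keyword arguments."""
    device: Optional[Device] = kwargs.get("device")
    dtype: Optional[DType] = kwargs.get("dtype")
    # Positional arguments may be a device, a dtype, or both, in any order.
    for arg in args:
        if isinstance(arg, Device):
            device = arg
        elif isinstance(arg, DType):
            dtype = arg
        else:
            raise TypeError(f"unsupported positional argument: {arg!r}")
    # Whatever the caller did not specify keeps the tensor's current value.
    return {
        "device": current_device if device is None else device,
        "dtype": current_dtype if dtype is None else dtype,
    }


cpu, cuda = Device("cpu"), Device("cuda")
bf16, fp32 = DType("bfloat16"), DType("float32")
print(get_to_kwargs_sketch(cpu, bf16, cuda))
# {'device': Device(name='cuda'), 'dtype': DType(name='bfloat16')}
```

Returning a fully populated dict means a subclass's to() can forward the same keyword arguments to every inner tensor without re-parsing the call.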

Class variables to define to simplify the implementation of tensor subclasses:
  • tensor_data_names (List[str]): list of the names of all required tensor data; the order should match the __init__ argument list of the tensor subclass

  • optional_tensor_data_names (List[str]): list of the names of Tensors that can be optional. Defining this field is optional for getting the additional boilerplate functions implemented for you, but it is required if the subclass has optional Tensor attributes

  • tensor_attribute_names (List[str]): list of the names of non-Tensor attributes; the order should match the __init__ argument list of the tensor subclass, following all of the tensor_data_names and optional_tensor_data_names arguments

If tensor_data_names and tensor_attribute_names are defined, some additional functions are added, including:
  • __tensor_flatten__: flattens a subclassed tensor instance and returns a tuple whose first element is the list of names of the valid (non-None) tensor data and whose second element is the list of non-Tensor attributes

  • __tensor_unflatten__: takes a tensor_data_dict (a map from tensor name to Tensor) and a list of non-Tensor attributes, and returns a new instance of the subclassed tensor

  • _apply_fn_to_data: takes a function (Tensor -> Tensor), applies it to all tensor data, and creates a new subclassed Tensor with the transformed tensor data

  • __repr__: the string representation of the subclassed tensor instance

  • torch ops: torch.Tensor.contiguous

  • aten ops: aten.detach.default, aten.clone.default, aten.alias.default, aten.contiguous.default, aten.copy_.default, aten._to_copy.default (enables t.to)
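To make the list above concrete, here is a simplified pure-Python model of how __tensor_flatten__, __tensor_unflatten__ and _apply_fn_to_data can be derived from the three class variables. It is a sketch under stated assumptions, not torchao's implementation: BoilerplateSketch and MyTensorSketch are hypothetical, plain lists stand in for Tensors, and the outer_size/outer_stride parameters of PyTorch's __tensor_unflatten__ protocol are accepted but ignored.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple


class BoilerplateSketch:
    """Hypothetical model of the functions derived from the three class variables.

    Plain Python lists stand in for Tensors so this runs without PyTorch.
    """

    tensor_data_names: List[str] = []
    optional_tensor_data_names: List[str] = []
    tensor_attribute_names: List[str] = []

    def __tensor_flatten__(self) -> Tuple[List[str], List[Any]]:
        # Required tensor data is always listed; optional data only when it is not None.
        names = list(self.tensor_data_names)
        names += [n for n in self.optional_tensor_data_names if getattr(self, n) is not None]
        attrs = [getattr(self, n) for n in self.tensor_attribute_names]
        return names, attrs

    @classmethod
    def __tensor_unflatten__(
        cls,
        tensor_data_dict: Dict[str, Any],
        attrs: List[Any],
        outer_size: Optional[Any] = None,  # part of PyTorch's protocol; unused here
        outer_stride: Optional[Any] = None,  # part of PyTorch's protocol; unused here
    ):
        required = [tensor_data_dict[n] for n in cls.tensor_data_names]
        # Optional tensors that were dropped during flattening come back as None.
        optional = [tensor_data_dict.get(n) for n in cls.optional_tensor_data_names]
        # Positional order: required data, optional data, then non-Tensor attributes.
        return cls(*required, *optional, *attrs)

    def _apply_fn_to_data(self, fn: Callable[[Any], Any]):
        required = [fn(getattr(self, n)) for n in self.tensor_data_names]
        optional = [
            None if getattr(self, n) is None else fn(getattr(self, n))
            for n in self.optional_tensor_data_names
        ]
        attrs = [getattr(self, n) for n in self.tensor_attribute_names]
        return type(self)(*required, *optional, *attrs)


class MyTensorSketch(BoilerplateSketch):
    tensor_data_names = ["a", "b"]
    optional_tensor_data_names = ["c", "d"]
    tensor_attribute_names = ["e", "f"]

    def __init__(self, a, b, c, d, e, f):
        self.a, self.b, self.c, self.d, self.e, self.f = a, b, c, d, e, f


t = MyTensorSketch([1, 2], [3], None, [4], 8, "int8")
names, attrs = t.__tensor_flatten__()  # names == ["a", "b", "d"], attrs == [8, "int8"]
restored = MyTensorSketch.__tensor_unflatten__({n: getattr(t, n) for n in names}, attrs)
doubled = t._apply_fn_to_data(lambda x: [2 * v for v in x])
```

The design point this illustrates is that a fixed positional constructor order (required data, optional data, attributes) is what lets the same three name lists drive all three functions.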

Example

from typing import Optional

from torch import Tensor
from torchao.utils import TorchAOBaseTensor

class MyTensor(TorchAOBaseTensor):
    tensor_data_names = ["a", "b"]
    optional_tensor_data_names = ["c", "d"]
    tensor_attribute_names = ["e", "f"]

    def __new__(
        cls, a: Tensor, b: Tensor, c: Optional[Tensor], d: Optional[Tensor], e: int, f: str
    ):
        pass

    def __init__(
        self, a: Tensor, b: Tensor, c: Optional[Tensor], d: Optional[Tensor], e: int, f: str
    ):
        pass
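If instances are rebuilt by passing arguments positionally in the declared order, a mismatch between the name lists and the __init__ signature would hand values to the wrong parameters. The hypothetical check_field_order helper below (not part of torchao) sketches one way to catch that early with Python's inspect module; Good mirrors the Example above without the PyTorch dependency, and Bad deliberately breaks the documented order.

```python
import inspect
from typing import List


def check_field_order(cls) -> List[str]:
    """Hypothetical helper: check that the declared name lists follow __init__'s order."""
    expected = (
        list(cls.tensor_data_names)
        + list(getattr(cls, "optional_tensor_data_names", []))
        + list(cls.tensor_attribute_names)
    )
    # Skip `self`; the remaining parameters are compared in declaration order.
    params = list(inspect.signature(cls.__init__).parameters)[1:]
    if params != expected:
        raise TypeError(f"{cls.__name__}: __init__ order {params} != declared order {expected}")
    return expected


class Good:
    tensor_data_names = ["a", "b"]
    optional_tensor_data_names = ["c", "d"]
    tensor_attribute_names = ["e", "f"]

    def __init__(self, a, b, c, d, e: int, f: str):
        pass


class Bad(Good):
    # Attributes placed before the optional tensor data: violates the documented order.
    def __init__(self, a, b, e: int, f: str, c, d):
        pass


print(check_field_order(Good))  # ['a', 'b', 'c', 'd', 'e', 'f']
```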

classmethod get_tensor_impl_constructor(layout_class: Callable) → Callable

Get the TensorImpl class constructor (TensorImplClass.from_plain) for tensor_class based on layout_class, where layout_class is the class type of a subclass of Layout, e.g. PlainLayout.

Parameters:
  • tensor_class – Tensor subclass type (passed implicitly as cls when called as a classmethod)

  • layout_class – the class type of a subclass of Layout, e.g. PlainLayout

Returns:

tensor impl subclass constructor for the layout_class

classmethod implements(aten_ops_or_torch_fns)

Use this decorator to implement a function for aten ops in __torch_dispatch__ (if the user passes in a list of ops) or for a torch function in __torch_function__ (if the user passes in a single object).

class MyTensor(torch.Tensor):
    …
    implements = classmethod(_implements)

implements = MyTensor.implements

@implements(torch.nn.functional.linear)
def _(func, types, args, kwargs):
    ...
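The registration pattern above can be modeled in plain Python as a per-class dictionary from op (or function) to handler, filled by the decorator and consulted at dispatch time. This is a sketch, not torchao's code: DispatchSketch, dispatch and my_linear are hypothetical, and a real subclass would perform the lookup inside __torch_function__ or __torch_dispatch__ rather than in a separate method.

```python
from typing import Any, Callable, Dict, Optional, Tuple


class DispatchSketch:
    """Hypothetical model of a per-class table filled by an `implements` decorator."""

    @classmethod
    def implements(cls, ops_or_fns):
        # Accept a single op/function or a list/tuple of them.
        if not isinstance(ops_or_fns, (list, tuple)):
            ops_or_fns = [ops_or_fns]

        def decorator(handler: Callable) -> Callable:
            # Each class gets its own table so registrations do not leak between subclasses.
            if "_op_table" not in cls.__dict__:
                cls._op_table = {}
            for op in ops_or_fns:
                cls._op_table[op] = handler
            return handler

        return decorator

    @classmethod
    def dispatch(
        cls,
        func: Callable,
        types: Tuple[type, ...],
        args: tuple,
        kwargs: Optional[Dict[str, Any]] = None,
    ):
        # Stand-in for the lookup a __torch_function__/__torch_dispatch__ override performs.
        handler = cls.__dict__.get("_op_table", {}).get(func)
        if handler is None:
            raise NotImplementedError(f"{cls.__name__} has no handler for {func.__name__}")
        return handler(func, types, args, kwargs or {})


class MyTensor(DispatchSketch):
    pass


implements = MyTensor.implements


def my_linear(x, w):
    """Hypothetical stand-in for torch.nn.functional.linear: a plain dot product."""
    return sum(xi * wi for xi, wi in zip(x, w))


@implements(my_linear)
def _(func, types, args, kwargs):
    x, w = args
    # A real handler would unwrap the subclass's inner tensors before calling func.
    return func(x, w)


print(MyTensor.dispatch(my_linear, (MyTensor,), ([1, 2], [3, 4])))  # 11
```

Keeping the table in each class's own __dict__ is a design choice of this sketch, so two subclasses can register different handlers for the same op.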

classmethod register_layout(layout_class: Callable)

Helper function for layout registrations. It is used to implement the register_layout decorator for each tensor subclass; see aqt.py for example usage.

Parameters:
  • tensor_class – Tensor subclass type (passed implicitly as cls when called as a classmethod)

  • layout_class – the class type of a subclass of Layout, e.g. PlainLayout

Returns:

a decorator that registers the tensor impl constructor in the table
