TorchAOBaseTensor
- class torchao.utils.TorchAOBaseTensor
- A utility tensor subclass that provides commonly used functions. A new tensor subclass can inherit from it to get all of the utility functions:
    class MyTensor(TorchAOBaseTensor):
        pass
- This includes:
- _get_to_kwargs: gets the keyword arguments for a to(...) call
- implements:
    implements = MyTensor.implements

    @implements(torch.nn.functional.linear)
    def _(func, types, args, kwargs):
        ...
- register_layout:
    register_layout = MyTensor.register_layout

    @register_layout(PlainLayout)
    class PlainAQTTensorImpl(...):
        ...
- get_tensor_impl_constructor:
    get_tensor_impl_constructor = MyTensor.get_tensor_impl_constructor

    # in the constructor of MyTensor:
    tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
    tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
- Class variables to define to simplify the implementation of tensor subclasses:
- tensor_data_names (List[str]): list of the names of all required tensor data; the order should match the __init__ argument list of the tensor subclass
- optional_tensor_data_names (List[str]): defining this field is not required for the additional boilerplate functions to be implemented for you, but it is needed if the tensor subclass has optional Tensor attributes. When defined, it is the list of names of the Tensor attributes that are optional (may be None).
- tensor_attribute_names (List[str]): list of the names of non-Tensor attributes; the order should match the __init__ argument list of the tensor subclass, following all of the tensor_data_names and optional_tensor_data_names arguments
If tensor_data_names and tensor_attribute_names are defined, the following additional functions are added:
- __tensor_flatten__: flattens a subclassed tensor instance and returns a tuple; the first element is the list of names of the valid tensor data, the second element is the list of non-Tensor attributes
- __tensor_unflatten__: takes a tensor_data_dict (a map from tensor name to Tensor) and the list of non-Tensor attributes, and returns a new instance of the subclassed tensor
- _apply_fn_to_data: takes a function (Tensor -> Tensor), applies it to all tensor data, and recreates a new subclassed Tensor with the transformed tensor data
- __repr__: the string representation of the subclassed tensor instance
- torch ops: torch.Tensor.contiguous
- aten ops: aten.detach.default, aten.clone.default, aten.alias.default, aten.contiguous.default, aten.copy_.default, aten._to_copy.default (enables t.to)
Example
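    from typing import Optional
    from torch import Tensor
    from torchao.utils import TorchAOBaseTensor

    class MyTensor(TorchAOBaseTensor):
        tensor_data_names = ["a", "b"]
        optional_tensor_data_names = ["c", "d"]
        tensor_attribute_names = ["e", "f"]

        def __new__(
            cls, a: Tensor, b: Tensor, c: Optional[Tensor], d: Optional[Tensor], e: int, f: str
        ):
            pass  # placeholder: create and return the wrapper subclass instance here

        def __init__(
            self, a: Tensor, b: Tensor, c: Optional[Tensor], d: Optional[Tensor], e: int, f: str
        ):
            pass  # placeholder: store a, b, c, d, e, f as attributes under the same names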
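To make the generated methods concrete, here is a minimal, torch-free sketch of what a base class can derive from tensor_data_names, optional_tensor_data_names and tensor_attribute_names. It is an illustrative model under stated assumptions, not torchao's implementation: FieldSpecMixin is a hypothetical name, plain Python lists stand in for Tensors, and the outer_size and outer_stride arguments that PyTorch passes to __tensor_unflatten__ are omitted.

```python
from typing import Any, Callable, Dict, List, Tuple


class FieldSpecMixin:
    # Hypothetical stand-in for the boilerplate generated from the three
    # class variables; plain Python values play the role of Tensors.
    tensor_data_names: List[str] = []
    optional_tensor_data_names: List[str] = []
    tensor_attribute_names: List[str] = []

    def __tensor_flatten__(self) -> Tuple[List[str], List[Any]]:
        # Required data is always listed; optional data only when it is not None.
        names = list(self.tensor_data_names)
        names += [n for n in self.optional_tensor_data_names if getattr(self, n) is not None]
        attrs = [getattr(self, n) for n in self.tensor_attribute_names]
        return names, attrs

    @classmethod
    def __tensor_unflatten__(cls, tensor_data_dict: Dict[str, Any], tensor_attributes: List[Any]):
        # __init__ order: required data, then optional data (None if absent), then attributes.
        required = [tensor_data_dict[n] for n in cls.tensor_data_names]
        optional = [tensor_data_dict.get(n) for n in cls.optional_tensor_data_names]
        return cls(*required, *optional, *tensor_attributes)

    def _apply_fn_to_data(self, fn: Callable[[Any], Any]):
        # Transform every present data field, keep attributes as-is, rebuild the object.
        names, attrs = self.__tensor_flatten__()
        data = {n: fn(getattr(self, n)) for n in names}
        return type(self).__tensor_unflatten__(data, attrs)


class MyTensor(FieldSpecMixin):
    tensor_data_names = ["a", "b"]
    optional_tensor_data_names = ["c", "d"]
    tensor_attribute_names = ["e", "f"]

    def __init__(self, a, b, c, d, e: int, f: str):
        self.a, self.b, self.c, self.d, self.e, self.f = a, b, c, d, e, f


t = MyTensor([1, 2], [3], None, [4], 8, "sym")
names, attrs = t.__tensor_flatten__()  # (["a", "b", "d"], [8, "sym"]): c is None, so it is skipped
doubled = t._apply_fn_to_data(lambda x: [2 * v for v in x])
# doubled.a == [2, 4], doubled.b == [6], doubled.c is None, doubled.d == [8], doubled.e == 8
```

The key design point the sketch illustrates is that the argument order of __init__ must line up with the concatenation of the three name lists, because unflattening rebuilds the object positionally.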
- classmethod get_tensor_impl_constructor(layout_class: Callable) → Callable
Get the TensorImpl class constructor (TensorImplClass.from_plain) for tensor_class based on layout_class, where layout_class is the class type of a subclass of Layout, e.g. PlainLayout.
- Parameters:
tensor_class – Tensor subclass type (the class this classmethod is called on, i.e. cls; not passed explicitly)
layout_class – the class type of a subclass of Layout, e.g. PlainLayout
- Returns:
tensor impl subclass constructor for the layout_class
- classmethod implements(aten_ops_or_torch_fns)
Use this decorator to implement a function for aten ops in __torch_dispatch__ (if the user passes in a list of ops) or for a torch function in __torch_function__ (if the user passes in a single object).
    class MyTensor(torch.Tensor):
        ...
        implements = classmethod(_implements)

    implements = MyTensor.implements

    @implements(torch.nn.functional.linear)
    def _(func, types, args, kwargs):
        ...
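As a rough mental model of what implements records, the sketch below keeps a per-class table from op to handler and looks handlers up on demand. It is a hypothetical, torch-free simplification: DispatchTableMixin and dispatch are invented names, and strings stand in for aten ops and torch functions. In the real class, the lookup happens inside __torch_dispatch__ and __torch_function__.

```python
from typing import Any, Callable, Iterable, Union


class DispatchTableMixin:
    # Hypothetical model: each subclass keeps its own {op: handler} table.

    @classmethod
    def implements(cls, ops: Union[Any, Iterable[Any]]) -> Callable[[Callable], Callable]:
        if "_OP_TABLE" not in cls.__dict__:
            cls._OP_TABLE = {}
        # Normalize a single op into a one-element list.
        if not isinstance(ops, (list, tuple)):
            ops = [ops]

        def decorator(fn: Callable) -> Callable:
            for op in ops:
                cls._OP_TABLE[op] = fn
            return fn

        return decorator

    @classmethod
    def dispatch(cls, op, types, args, kwargs=None):
        # Stand-in for the lookup that __torch_dispatch__ / __torch_function__ perform.
        table = cls.__dict__.get("_OP_TABLE", {})
        if op not in table:
            raise NotImplementedError(f"{cls.__name__} has no handler for {op!r}")
        return table[op](op, types, args, kwargs or {})


class MyTensor(DispatchTableMixin):
    pass


implements = MyTensor.implements


@implements("linear")  # a single op
def _(func, types, args, kwargs):
    x, w = args
    return sum(a * b for a, b in zip(x, w))


@implements(["detach", "clone"])  # a list of ops sharing one handler
def _(func, types, args, kwargs):
    return list(args[0])


print(MyTensor.dispatch("linear", (), ([1, 2], [3, 4])))  # 11
print(MyTensor.dispatch("clone", (), ([5, 6],)))  # [5, 6]
```

Reusing the name _ for every handler is harmless here: the decorator has already stored each function in the table before the name is rebound.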
- classmethod register_layout(layout_class: Callable)
Helper function for layout registration. It is used to implement the register_layout decorator for each tensor subclass; see aqt.py for example usage.
- Parameters:
tensor_class – Tensor subclass type (the class this classmethod is called on, i.e. cls; not passed explicitly)
layout_class – the class type of a subclass of Layout, e.g. PlainLayout
- Returns:
a decorator that registers the tensor impl constructor in the table