
Contributor Guide

General Guide on Extending torchao

Please start by reading our quantization overview page first.

To contribute to the existing code base:

Adding New Tensor Subclasses

torchao Tensor subclasses are structured by derived dtype and packing format; please check out the quantization overview page to understand these concepts. If a new tensor subclass is needed for your use case, i.e. a new dtype or a new packing format that does not already exist, you can define a new Tensor subclass.

To understand how to use tensor subclass in the context of quantization, please also check Writing Your Own Quantized Tensor.

We provide a utility base class, torchao.utils.TorchAOBaseTensor, that can define common utility functions and methods for you if you specify the names of the Tensor and non-Tensor attributes of the tensor subclass, for example:

class MyTensor(TorchAOBaseTensor):
    tensor_data_names = ["qdata", "scale"]        # Tensor attributes
    tensor_attribute_names = ["device", "dtype"]  # non-Tensor attributes

With the above, multiple methods and functions become available for this Tensor; for more details please check the docs for TorchAOBaseTensor.

Note

Many of the existing use cases in torchao still use AffineQuantizedTensor, but we plan to move away from it to reduce the abstractions and make it easier for people to contribute to torchao.

Adding Efficient Kernels

Custom Triton kernels

Custom Triton kernels can be implemented and registered in torchao/kernel.

You may need to define your own autotuner as well.
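As a rough illustration, here is a minimal sketch of what a custom Triton kernel and its Python wrapper could look like; the element-wise scale kernel and the names below are illustrative, not an existing torchao kernel, and the input is assumed to be a contiguous CUDA tensor.

import torch
import triton
import triton.language as tl

@triton.jit
def _scale_kernel(x_ptr, out_ptr, scale, n_elements, BLOCK_SIZE: tl.constexpr):
    # each program instance handles one BLOCK_SIZE chunk of the flattened tensor
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * scale, mask=mask)

def scale_tensor(x: torch.Tensor, scale: float) -> torch.Tensor:
    # assumes x is a contiguous CUDA tensor
    out = torch.empty_like(x)
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _scale_kernel[grid](x, out, scale, n_elements, BLOCK_SIZE=1024)
    return out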

Custom hand-written kernels

Custom kernels (implementations) for CPU/CUDA/MPS can be implemented in torchao/csrc (e.g. the int4 CUDA kernel) and are accessible through torch.ops.my_custom_op.
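If you want to prototype an op from Python before writing the C++/CUDA registration in torchao/csrc, torch.library offers a Python-level alternative; the sketch below uses a purely illustrative op name (mylib::scale) and is not part of torchao.

import torch

# register a custom op from Python (available in recent PyTorch versions)
@torch.library.custom_op("mylib::scale", mutates_args=())
def scale(x: torch.Tensor, s: float) -> torch.Tensor:
    return x * s

# a fake (meta) implementation so the op can be traced under torch.compile
@scale.register_fake
def _(x, s):
    return torch.empty_like(x)

out = torch.ops.mylib.scale(torch.randn(4), 2.0)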

Using hand-written kernels in Tensor Subclasses

To call optimized kernels, we register implementations through the implements decorator of the tensor subclass. For example, if we want to call into a new custom op torch.ops.torchao.my_mm_for_mps:

aten = torch.ops.aten

class Float8Tensor(TorchAOBaseTensor):
    ...

implements = Float8Tensor.implements

@implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
    # unpack the linear args: (input, weight, bias)
    input_tensor, weight_tensor, bias = args[0], args[1], args[2] if len(args) > 2 else None
    ...
    # call into the custom op with the quantized data and scales
    res = torch.ops.torchao.my_mm_for_mps(
        input_tensor.qdata, weight_tensor.qdata, input_tensor.scale, weight_tensor.scale
    )
    return res

KernelPreference

For some tensor subclasses, there can be multiple kernel choices for quantize, mm, etc. The recommended way to handle this in torchao tensor subclasses is through KernelPreference, which represents which group of kernels to use for quantize, mm, grouped mm, etc. KernelPreference.AUTO is the recommended default: it lets developers pick whatever they consider fastest under the given conditions, so users don't need to worry about the details, while more specific kernel options remain available for debugging purposes.

Float8Tensor, for example, has the following options (a usage sketch follows the list):

  • KernelPreference.AUTO chooses the most performant quantize and mm kernels based on hardware (H100, SM89 or SM90+), library availability (whether fbgemm_gpu_genai is installed), and granularity (per row or per tensor)

  • KernelPreference.TORCH uses the torchao quantize ops (_choose_scale_float8 and _quantize_affine_float8) and _scaled_mm

  • KernelPreference.FBGEMM uses the fbgemm quantize and mm ops (torch.ops.fbgemm.f8f8bf16_rowwise)
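Below is a hedged sketch of selecting a kernel preference when configuring float8 quantization; the exact config class, argument name, and import path for KernelPreference can differ across torchao versions, so treat them as assumptions and check the current API.

import torch
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, PerRow
# import path for KernelPreference is an assumption; check your torchao version
from torchao.quantization.quantize_.common import KernelPreference

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda"))
config = Float8DynamicActivationFloat8WeightConfig(
    granularity=PerRow(),
    # AUTO lets torchao pick the fastest kernels for the current hardware and libraries
    kernel_preference=KernelPreference.AUTO,
)
quantize_(model, config)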

Flow

For the model level API, you can reuse torchao.quantize_, which applies a tensor subclass conversion to the weight of linear modules and accepts a filter function to choose which modules the conversion should be applied to.
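For example, a hedged sketch of applying a config with a filter function; Int8WeightOnlyConfig is just an illustrative config, and the filter shown is an arbitrary choice.

import torch
from torchao.quantization import quantize_, Int8WeightOnlyConfig

model = torch.nn.Sequential(
    torch.nn.Linear(128, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 128),
)

def filter_fn(module, fqn):
    # only convert Linear weights; fqn lets you skip specific modules by name
    return isinstance(module, torch.nn.Linear)

quantize_(model, Int8WeightOnlyConfig(), filter_fn=filter_fn)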

See Quantization Algorithms/Flows section for examples of weight only/dynamic quant and other types of model level APIs.

Using torch.compile for Performance

To be compatible with torch.compile and to get the best performance, run the model through torch.compile with fullgraph=True first and remove any unnecessary graph breaks. You can set TORCH_LOGS="output_code" when you run the script to see the Inductor-generated code, e.g. TORCH_LOGS="output_code" python example.py:

model = torch.compile(model, mode="max-autotune", fullgraph=True)

Serialization

To enable support for serialization (torch.save and torch.load with tensor subclasses as weights), we need to add the tensor subclass and the relevant objects to safe globals (available in torch 2.5 and later), e.g.:

torch.serialization.add_safe_globals([Float8Tensor, QuantizeTensorToFloat8Kwargs])

Please check out the serialization doc for more details.
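For reference, a minimal sketch of a save/load round trip once the relevant classes are registered as safe globals; the file name is arbitrary and model is assumed to be a module whose weights are already tensor subclasses.

import torch

# after torch.serialization.add_safe_globals([...]) as above
torch.save(model.state_dict(), "quantized_checkpoint.pt")

state_dict = torch.load("quantized_checkpoint.pt", weights_only=True)
model.load_state_dict(state_dict, assign=True)  # assign=True keeps the tensor subclass instances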

Note

We are integrated with Hugging Face transformers and support serialization and deserialization through the save_pretrained, push_to_hub, and from_pretrained APIs. We also have serialization examples with diffusers models.

Other Feature Support

The above only covers basic feature support. We also provide examples of how to add support for training, tensor parallel, and FSDP by extending MyDTypeTensor, and we'll put more examples in the developer_api_guide folder covering these use cases.

Tensor Subclass Functionality/Composability Testing

We are also working on test suites to test the functionality of tensor subclasses and their composability with different systems like torch.compile, DTensor, etc. For now, we recommend copying the existing tests and adapting them to test your own tensor subclass.
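As a starting point, here is a hedged sketch of the kind of smoke test you might write, checking that a quantized module gives matching results under torch.compile; the config is illustrative and a CUDA device is assumed.

import torch
from torchao.quantization import quantize_, Int8WeightOnlyConfig

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
example_input = torch.randn(4, 128, dtype=torch.bfloat16, device="cuda")

quantize_(linear, Int8WeightOnlyConfig())
eager_out = linear(example_input)

compiled = torch.compile(linear, fullgraph=True)
compiled_out = compiled(example_input)

torch.testing.assert_close(compiled_out, eager_out)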

Kernel Microbenchmarks

Before testing performance on models, we can also run microbenchmarks on a single linear operator (or other compute-intensive or memory-intensive operators) with different input dimensions to get a sense of the speedup. For a specific kernel that you'd like to benchmark, you can create a benchmark file like benchmarks/benchmark_aq.py and run the benchmark with the shapes that matter for the target model. A quick way to get the relevant shapes for linear and other ops is by running the example below.

Change the model to the one you are interested in optimizing, and run the following:

python tutorials/developer_api_guide/print_op_and_shapes.py

Example output:

TORCH_FUNC=<built-in function linear> (M, K, N): 10 10 10
TORCH_FUNC=<method 'add' of 'torch._C.TensorBase' objects> args[0] shape: torch.Size([10, 10])

all linear shapes (M, K, N): [(10, 10, 10)]

The list of all linear shapes can be copy-pasted into a microbenchmarking script under benchmarks/benchmark_your_kernel.py for benchmarking.

For benchmark helper functions, we currently have 1 and 2; feel free to use either one for now, but we'll probably keep only one in the future.
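If you prefer not to use those helpers, a minimal sketch with torch.utils.benchmark for a single (M, K, N) shape looks like the following; the shape and dtype are placeholders, so substitute the values printed for your model and apply your quantization config to the linear module before timing.

import torch
from torch.utils import benchmark

M, K, N = 10, 10, 10  # replace with shapes printed by print_op_and_shapes.py
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
linear = torch.nn.Linear(K, N, bias=False, device="cuda", dtype=torch.bfloat16)
# quantize_(linear, <your config>)  # apply the flow you want to benchmark

timer = benchmark.Timer(stmt="linear(x)", globals={"linear": linear, "x": x})
print(timer.timeit(100))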

Model Benchmarks and Eval

After you have the quantization flow implemented, you can run benchmarks and eval on llama (llama2/llama3) or sam models that have already been modified to be torch.compile friendly, and compare with existing techniques in torchao.

Note: llama (llama2/llama3) is our representative memory-bound model and sam is our representative compute-bound model.

Please check the --help option for each of the scripts to understand the supported options, e.g. you can use --profile=profile_path to get a Chrome trace of the run for detailed analysis.

Please let us know if there are any new important models that make sense to be added to the torchao model benchmark/eval folder.

Please also check out Benchmarking User Guide and Benchmarking API Guide to understand how to use our benchmarking framework.
