Shortcuts

Source code for torchao.sparsity.sparse_api

from functools import partial
from typing import Callable, Optional

import torch
from torch.sparse import to_sparse_semi_structured

from torchao.prototype.sparsity.sparsifier.weight_norm_sparsifier import (
    WeightNormSparsifier,
)
from torchao.quantization.quant_api import (
    _get_linear_subclass_inserter,
    _is_linear,
    _replace_with_custom_fn_if_matches_filter,
)
from torchao.sparsity.blocksparse import BlockSparseTensor


# Sparsity helper functions
def apply_fake_sparsity(model, **kwargs):
    """
    Simulate 2:4 sparsity on the weights of selected layers of ``model``.

    The model is modified in place; nothing is returned.

    Args:
        model: module whose ``named_modules()`` are scanned for layers to
            sparsify.
        **kwargs: only ``filter_fn`` is read. It is called as
            ``filter_fn(module, name)`` and a truthy result selects that
            module's ``weight``. Defaults to ``_is_linear``. Any other
            keyword argument is ignored.
    """
    should_sparsify = kwargs.pop("filter_fn", _is_linear)

    # torch.ao.pruning flow: one config entry per selected module, each
    # pointing at the fully qualified name of that module's weight.
    targets = [
        {"tensor_fqn": f"{fqn}.weight"}
        for fqn, submodule in model.named_modules()
        if should_sparsify(submodule, fqn)
    ]

    # Two zeros in every 1x4 block of the weight gives the 2:4 pattern.
    pruner = WeightNormSparsifier(
        sparsity_level=1.0,
        sparse_block_shape=(1, 4),
        zeros_per_block=2,
    )
    pruner.prepare(model, targets)
    pruner.step()
    pruner.squash_mask()
def block_sparse_weight(blocksize=64):
    """
    Build a linear-weight inserter based on ``BlockSparseTensor.from_dense``.

    Args:
        blocksize: forwarded as the ``blocksize`` keyword argument of
            ``BlockSparseTensor.from_dense``. Defaults to 64.

    Returns:
        Whatever ``_get_linear_subclass_inserter`` returns for the bound
        constructor (presumably a callable suitable for ``sparsify_`` --
        confirm against ``torchao.quantization.quant_api``).
    """
    to_block_sparse = partial(BlockSparseTensor.from_dense, blocksize=blocksize)
    return _get_linear_subclass_inserter(to_block_sparse)
def semi_sparse_weight():
    """
    Convert the weight of linear modules to semi-structured (2:4) sparsity.

    Returns:
        The result of wrapping ``to_sparse_semi_structured`` with
        ``_get_linear_subclass_inserter``; it is passed to ``sparsify_`` as
        ``apply_tensor_subclass``.
    """
    inserter = _get_linear_subclass_inserter(to_sparse_semi_structured)
    return inserter
def sparsify_(
    model: torch.nn.Module,
    apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor],
    filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
) -> torch.nn.Module:
    """Convert the weight of linear modules in the model with `apply_tensor_subclass`.

    This function is essentially the same as quantize, but for sparsity subclasses.

    Currently, we support three options for sparsity:
        - semi-structured (2:4) sparsity with `semi_sparse_weight`
        - int8 dynamic quantization + 2:4 sparsity with `layout=SemiSparseLayout`
        - int4 weight-only quantization + 2:4 sparsity with `layout=SparseMarlinLayout`

    Args:
        model (torch.nn.Module): input model; it is modified in place
        apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that converts a floating point Tensor to a (sparsified) tensor subclass instance (e.g. affine quantized tensor instance)
        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on the weight of the module.
            When None, `_is_linear` is used.

    Returns:
        torch.nn.Module: the same `model` object that was passed in, so that
        `m = sparsify_(m, ...)` keeps `m` bound to the model.

    **Example:**
    ::
            import torch
            import torch.nn as nn
            from torchao.sparsity import sparsify_

            def filter_fn(module: nn.Module, fqn: str) -> bool:
                return isinstance(module, nn.Linear)

            m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))

            # for 2:4 sparsity
            from torchao.sparsity.sparse_api import semi_sparse_weight
            m = sparsify_(m, semi_sparse_weight(), filter_fn)

            # for int8 dynamic quantization + 2:4 sparsity
            from torchao.dtypes import SemiSparseLayout
            from torchao.quantization.quant_api import (
                int8_dynamic_activation_int8_weight,
                quantize_,
            )
            quantize_(m, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout), filter_fn)
    """
    _replace_with_custom_fn_if_matches_filter(
        model,
        apply_tensor_subclass,
        _is_linear if filter_fn is None else filter_fn,
    )
    # The replacement above mutates `model` in place. Return it so the
    # declared `-> torch.nn.Module` contract and the documented
    # `m = sparsify_(m, ...)` usage hold instead of yielding None.
    return model

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources