Source code for torchao.kernel.intmm

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os

import torch
from torch._dynamo import is_compiling as dynamo_is_compiling
from torch._higher_order_ops.out_dtype import out_dtype

from torchao.utils import check_cpu_version

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

try:
    from torchao.kernel import intmm_triton
except ImportError:
    logger.warning(
        "Warning: Detected no triton, on systems without Triton certain kernels will not work"
    )
    # On CPU-only builds intmm_triton might not be available.
    intmm_triton = None

AUTOTUNER_ENABLE = bool(int(os.getenv("TORCHAO_AUTOTUNER_ENABLE", 0)))


def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
    """
    Performs a safe integer matrix multiplication, considering different paths for
    torch.compile, cublas, and fallback cases.

    Args:
        input (torch.Tensor): The input tensor of shape [i, j].
        mat2 (torch.Tensor): The matrix to multiply with, of shape [j, k].

    Returns:
        torch.Tensor: The result of the matrix multiplication.

    Raises:
        AssertionError: If the tensors are not on the same device.
    """
    # torch.compile path
    if dynamo_is_compiling() or "FakeTensor" in input.__repr__():
        if input.device.type == "cpu":
            # Matmul in int32 is slow on CPU and not supported well by Inductor cpp backend
            return out_dtype(
                torch.ops.aten.mm.default, torch.int32, input.float(), mat2.float()
            )
        return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)

    # error checking for cublas path
    assert mat2.device == input.device, (
        f"need both tensors to be on the same device but got {mat2.device} and {input.device}"
    )
    device_cpu = "cpu" in [mat2.device.type, input.device.type]
    # with input.shape = [i,j] and mat2.shape = [j,k]
    j_is_nonzero_multiple_of_8 = (input.shape[1] % 8 == 0) and (input.shape[1] > 0)
    k_is_nonzero_multiple_of_8 = (mat2.shape[1] % 8 == 0) and (mat2.shape[1] > 0)
    bad_dimensions_for_cublas = not (
        j_is_nonzero_multiple_of_8 and k_is_nonzero_multiple_of_8
    )

    if device_cpu or bad_dimensions_for_cublas:
        # fallback path
        return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to(
            input.device.type
        )

    # cublas paths
    if not mat2.is_contiguous():  # silently gives incorrect result without this
        mat2 = mat2.contiguous()
    if (not input.is_contiguous()) and (
        input.shape[0] % 8 != 0
    ):  # gives cryptic error without this
        input = (
            input.contiguous()
        )  # (it seems the transpose makes cublas check the above j constraint on i)

    try:
        return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
    except Exception:
        # fallback path, would run on H100 for float8 dtypes
        # Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn'
        return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(
            torch.int32
        )
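
A minimal usage sketch (illustrative only, not part of the module source above). It assumes a CUDA device and int8 operands whose inner and output dimensions are nonzero multiples of 8, so the cublas path is eligible; the result is int32 on any path.

    a = torch.randint(-128, 127, (32, 64), dtype=torch.int8, device="cuda")
    b = torch.randint(-128, 127, (64, 16), dtype=torch.int8, device="cuda")
    c = safe_int_mm(a, b)  # shape [32, 16], dtype torch.int32
    assert c.dtype == torch.int32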
def int_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Performs integer matrix multiplication using intmm_triton if available and
    autotuner is enabled, otherwise falls back to safe_int_mm.

    Args:
        a (torch.Tensor): The first matrix to multiply.
        b (torch.Tensor): The second matrix to multiply.

    Returns:
        torch.Tensor: The result of the matrix multiplication.
    """
    if intmm_triton is not None and AUTOTUNER_ENABLE:
        return torch.ops.torchao.int_matmul(a, b)
    return safe_int_mm(a, b)
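
A usage sketch for int_matmul (illustrative, assuming a CUDA device and int8 operands): with Triton present and TORCHAO_AUTOTUNER_ENABLE=1 in the environment it dispatches to the autotuned torchao.int_matmul op, otherwise it falls back to safe_int_mm.

    a = torch.randint(-128, 127, (128, 256), dtype=torch.int8, device="cuda")
    b = torch.randint(-128, 127, (256, 64), dtype=torch.int8, device="cuda")
    c = int_matmul(a, b)  # int32 result of shape [128, 64]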
def int_scaled_matmul(
    a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor
) -> torch.Tensor:
    """
    Performs scaled integer matrix multiplication.

    Args:
        a (torch.Tensor): The first matrix to multiply.
        b (torch.Tensor): The second matrix to multiply.
        scales1 (torch.Tensor): The scaling factors for the rows of the result.

    Returns:
        torch.Tensor: The result of the scaled matrix multiplication.

    Raises:
        AssertionError: If the dimensions of the input tensors do not match the expected shapes.
    """
    M, K = a.shape
    K, N = b.shape
    assert M == scales1.size(0) or scales1.numel() == 1
    assert 1 == scales1.size(1)
    assert scales1.is_contiguous()

    scales1 = scales1.expand((M, N))
    assert scales1.dim() == 2

    if check_cpu_version(scales1.device):
        # CPU prefers decomposed version of int_scaled_matmul
        # to leverage the fusion capability of Inductor
        c = torch._int_mm(a, b)
        return c.to(scales1.dtype) * scales1

    if intmm_triton is not None and AUTOTUNER_ENABLE:
        return torch.ops.torchao.int_scaled_matmul(a, b, scales1)

    c = safe_int_mm(a, b)
    return c * scales1
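
A usage sketch for int_scaled_matmul (illustrative, assuming a CUDA device and int8 operands): scales1 must be a contiguous [M, 1] column of per-row scales (or a single element), which is broadcast over the int32 matmul result.

    a = torch.randint(-128, 127, (128, 256), dtype=torch.int8, device="cuda")
    b = torch.randint(-128, 127, (256, 64), dtype=torch.int8, device="cuda")
    scales1 = torch.rand(128, 1, device="cuda")  # one scale per output row
    c = int_scaled_matmul(a, b, scales1)  # scaled result of shape [128, 64]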
