
safe_int_mm

torchao.quantization.safe_int_mm(input: Tensor, mat2: Tensor) → Tensor

Performs a safe integer matrix multiplication, choosing among separate code paths for torch.compile, cuBLAS, and a general fallback.
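The kind of dispatch described above can be sketched roughly as follows. This is a hypothetical illustration, not torchao's implementation: the function name `int_mm_with_fallback` is made up, the torch.compile path is omitted, and the shape conditions assumed for the cuBLAS-backed `torch._int_mm` kernel (int8 inputs, int32 output) vary across PyTorch versions.

```python
import torch


def int_mm_with_fallback(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: int8 x int8 -> int32 matmul with a fallback path."""
    # Mirror the documented contract: both operands must share a device.
    assert input.device == mat2.device, (
        f"need both tensors on the same device, got {input.device} and {mat2.device}"
    )
    i, j = input.shape
    k = mat2.shape[1]
    # Assumed shape constraints for torch._int_mm on CUDA; the exact rules
    # depend on the PyTorch version, hence the try/except below.
    cublas_friendly = i > 16 and j > 0 and j % 8 == 0 and k > 0 and k % 8 == 0
    if input.is_cuda and cublas_friendly:
        try:
            # int8 inputs, int32 output, computed on the GPU.
            return torch._int_mm(input, mat2)
        except RuntimeError:
            pass  # fall through to the generic path below
    # Fallback: widen to int32 and multiply on the CPU, where integer matmul
    # is supported, then move the result back to the original device.
    out = torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32))
    return out.to(input.device)


a = torch.tensor([[1, -2], [3, 4]], dtype=torch.int8)
b = torch.tensor([[5, 6], [7, -8]], dtype=torch.int8)
print(int_mm_with_fallback(a, b).tolist())  # [[-9, 22], [43, -14]]
```

Wrapping the GPU call in a `try`/`except` is what keeps a sketch like this "safe": if the kernel rejects a shape the simple check let through, the result still comes from the fallback path, at the cost of a device round trip.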

Parameters
  • input (torch.Tensor) – The input tensor of shape [i, j].

  • mat2 (torch.Tensor) – The second (right-hand) operand, of shape [j, k].

Returns

The result of the matrix multiplication, of shape [i, k].

Return type

torch.Tensor

Raises

AssertionError – If input and mat2 are not on the same device.