Shortcuts

Source code for torcheval.metrics.functional.classification.confusion_matrix

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch

# Short alias for the functional normalize op; the ``*_compute`` helpers in
# this module call it with ``p=1`` to turn raw counts into proportions along
# one dimension of the confusion matrix.
norm = torch.nn.functional.normalize


@torch.inference_mode()
def binary_confusion_matrix(
    input: torch.Tensor,
    target: torch.Tensor,
    *,
    threshold: float = 0.5,
    normalize: Optional[str] = None,
) -> torch.Tensor:
    """
    Compute binary confusion matrix, a 2 by 2 tensor whose entry ``(i, j)`` counts the
    samples with true label ``i`` that were predicted as label ``j``, i.e. with counts
    ( (true negative, false positive) , (false negative, true positive) )
    See also :func:`multiclass_confusion_matrix <torcheval.metrics.functional.multiclass_confusion_matrix>`

    Args:
        input (Tensor): Tensor of label predictions with shape of (n_sample,).
            ``torch.where(input < threshold, 0, 1)`` will be applied to the input.
        target (Tensor): Tensor of ground truth labels with shape of (n_sample,).
        threshold (float, default 0.5): Threshold for converting input into predicted labels for each sample.
            ``torch.where(input < threshold, 0, 1)`` will be applied to the ``input``.
        normalize:
            - ``None`` [default]:
                Give raw counts ('none' also defaults to this)
            - ``'pred'``:
                Normalize across the prediction, i.e. such that each column
                (one predicted label) adds to one.
            - ``'true'``:
                Normalize across the condition positive, i.e. such that each row
                (one true label) adds to one.
            - ``'all'``:
                Normalize across all examples, i.e. such that all matrix entries add to one.

    Returns:
        Tensor: The 2 by 2 confusion matrix; integer counts when no normalization
        is requested, float proportions otherwise.

    Raises:
        ValueError: If ``normalize`` is not one of ``None``, ``'none'``, ``'pred'``,
            ``'true'`` or ``'all'``, if ``input`` or ``target`` is not one-dimensional,
            or if their shapes differ.

    Examples::

        >>> import torch
        >>> from torcheval.metrics.functional import binary_confusion_matrix
        >>> input = torch.tensor([0, 1, 0.7, 0.6])
        >>> target = torch.tensor([0, 1, 1, 0])
        >>> binary_confusion_matrix(input, target)
        tensor([[1, 1],
                [0, 2]])

        >>> input = torch.tensor([1, 1, 0, 0])
        >>> target = torch.tensor([0, 1, 1, 1])
        >>> binary_confusion_matrix(input, target, threshold=1)
        tensor([[0, 1],
                [2, 1]])

        >>> input = torch.tensor([1, 1, 0, 0])
        >>> target = torch.tensor([0, 1, 1, 1])
        >>> binary_confusion_matrix(input, target, normalize="true")
        tensor([[0.0000, 1.0000],
                [0.6667, 0.3333]])
    """
    # The binary case is a confusion matrix over exactly two classes.
    _confusion_matrix_param_check(2, normalize)
    matrix = _binary_confusion_matrix_update(input, target, threshold)
    return _confusion_matrix_compute(matrix, normalize)
@torch.inference_mode()
def multiclass_confusion_matrix(
    input: torch.Tensor,
    target: torch.Tensor,
    num_classes: int,
    *,
    normalize: Optional[str] = None,
) -> torch.Tensor:
    """
    Compute multi-class confusion matrix, a matrix of dimension num_classes x num_classes where each element at position `(i,j)` is the number of examples with true class `i` that were predicted to be class `j`.
    See also :func:`binary_confusion_matrix <torcheval.metrics.functional.binary_confusion_matrix>`

    Args:
        input (Tensor): Tensor of label predictions.
            It could be the predicted labels, with shape of (n_sample, ).
            It could also be probabilities or logits with shape of (n_sample, n_class).
            ``torch.argmax`` will be used to convert input into predicted labels.
        target (Tensor): Tensor of ground truth labels with shape of (n_sample, ).
        num_classes (int):
            Number of classes.
        normalize:
            - ``None`` [default]:
                Give raw counts ('none' also defaults to this)
            - ``'pred'``:
                Normalize across the prediction class, i.e. such that each column
                (one predicted class) adds to one.
            - ``'true'``:
                Normalize across the condition positive, i.e. such that each row
                (one true class) adds to one.
            - ``'all'``:
                Normalize across all examples, i.e. such that all matrix entries add to one.

    Returns:
        Tensor: The num_classes x num_classes confusion matrix; integer counts when
        no normalization is requested, float proportions otherwise.

    Raises:
        ValueError: If ``num_classes`` is smaller than two, if ``normalize`` is not one
            of ``None``, ``'none'``, ``'pred'``, ``'true'`` or ``'all'``, if the shapes of
            ``input`` and ``target`` are incompatible, or if a predicted or target
            class is not strictly smaller than ``num_classes``.

    Examples::

        >>> import torch
        >>> from torcheval.metrics.functional import multiclass_confusion_matrix
        >>> input = torch.tensor([0, 2, 1, 3])
        >>> target = torch.tensor([0, 1, 2, 3])
        >>> multiclass_confusion_matrix(input, target, 4)
        tensor([[1, 0, 0, 0],
                [0, 0, 1, 0],
                [0, 1, 0, 0],
                [0, 0, 0, 1]])

        >>> input = torch.tensor([0, 0, 1, 1, 1])
        >>> target = torch.tensor([0, 0, 0, 0, 1])
        >>> multiclass_confusion_matrix(input, target, 2)
        tensor([[2, 2],
                [0, 1]])

        >>> input = torch.tensor([0, 0, 1, 1, 1, 2, 1, 2])
        >>> target = torch.tensor([2, 0, 2, 0, 1, 2, 1, 0])
        >>> multiclass_confusion_matrix(input, target, 3)
        tensor([[1, 1, 1],
                [0, 2, 0],
                [1, 1, 1]])

        >>> input = torch.tensor([0, 0, 1, 1, 1, 2, 1, 2])
        >>> target = torch.tensor([2, 0, 2, 0, 1, 2, 1, 0])
        >>> multiclass_confusion_matrix(input, target, 3, normalize="pred")
        tensor([[0.5000, 0.2500, 0.5000],
                [0.0000, 0.5000, 0.0000],
                [0.5000, 0.2500, 0.5000]])

        >>> input = torch.tensor([0, 0, 1, 1, 1])
        >>> target = torch.tensor([0, 0, 0, 0, 1])
        >>> multiclass_confusion_matrix(input, target, 4)
        tensor([[2, 2, 0, 0],
                [0, 1, 0, 0],
                [0, 0, 0, 0],
                [0, 0, 0, 0]])

        >>> input = torch.tensor([[0.9, 0.1, 0, 0], [0.1, 0.2, 0.4, 0.3], [0, 1.0, 0, 0], [0, 0, 0.2, 0.8]])
        >>> target = torch.tensor([0, 1, 2, 3])
        >>> multiclass_confusion_matrix(input, target, 4)
        tensor([[1, 0, 0, 0],
                [0, 0, 1, 0],
                [0, 1, 0, 0],
                [0, 0, 0, 1]])
    """
    _confusion_matrix_param_check(num_classes, normalize)
    # The update helper already returns a dense (true, pred) count matrix.
    matrix = _confusion_matrix_update(input, target, num_classes)
    return _confusion_matrix_compute(matrix, normalize=normalize)
def _binary_confusion_matrix_compute(
    cm: torch.Tensor, normalize: Optional[str]
) -> torch.Tensor:
    """
    Apply the requested normalization to a 2 by 2 confusion matrix.

    The binary matrix is built by the same ``_update`` routine as the multi-class
    one, so it has the same (true label, predicted label) layout. This helper
    therefore delegates to ``_confusion_matrix_compute`` so that ``'pred'`` and
    ``'true'`` normalize along the same axes in both variants.

    Args:
        cm: Confusion matrix of raw counts.
        normalize: ``'pred'``, ``'true'``, ``'all'``, or anything else for raw counts.

    Returns:
        The normalized matrix (float), or ``cm`` itself when no normalization applies.
    """
    return _confusion_matrix_compute(cm, normalize)


def _binary_confusion_matrix_update(
    input: torch.Tensor,
    target: torch.Tensor,
    threshold: float = 0.5,
) -> torch.Tensor:
    """
    Validate binary inputs, threshold the predictions and count them.

    Args:
        input: One-dimensional tensor of predictions.
        target: One-dimensional tensor of ground truth labels.
        threshold: Values of ``input`` below this become label 0, all others label 1.

    Returns:
        A dense 2 by 2 tensor of counts indexed as (true label, predicted label).

    Raises:
        ValueError: If the inputs fail ``_binary_confusion_matrix_update_input_check``.
    """
    _binary_confusion_matrix_update_input_check(input, target)
    predicted = torch.where(input < threshold, 0, 1)
    return _update(predicted, target, 2)


def _binary_confusion_matrix_update_input_check(
    input: torch.Tensor,
    target: torch.Tensor,
) -> None:
    """
    Check that ``input`` and ``target`` are one-dimensional tensors of equal shape.

    Raises:
        ValueError: If either tensor is not one-dimensional or the shapes differ.
    """
    if input.ndim != 1:
        raise ValueError(
            f"input should be a one-dimensional tensor for binary confusion matrix, got shape {input.shape}."
        )

    if target.ndim != 1:
        raise ValueError(
            f"target should be a one-dimensional tensor for binary confusion matrix, got shape {target.shape}."
        )

    if input.shape != target.shape:
        raise ValueError(
            "The `input` and `target` should have the same dimensions, "
            f"got shapes {input.shape} and {target.shape}."
        )


def _confusion_matrix_compute(
    confusion_matrix: torch.Tensor,
    normalize: Optional[str],
) -> torch.Tensor:
    """
    Apply the requested normalization to a (true, pred) confusion matrix.

    Args:
        confusion_matrix: Matrix of raw counts, rows indexed by the true class and
            columns by the predicted class.
        normalize:
            - ``'pred'``: L1-normalize along dim 0, so each column adds to one.
            - ``'true'``: L1-normalize along dim 1, so each row adds to one.
            - ``'all'``: divide by the total count, so all entries add to one.
            - anything else (e.g. ``None`` or ``'none'``): return the counts unchanged.

    Returns:
        A float tensor when normalization is applied, otherwise ``confusion_matrix`` itself.
    """
    if normalize == "pred":
        return norm(confusion_matrix.to(torch.float), p=1, dim=0)
    elif normalize == "true":
        return norm(confusion_matrix.to(torch.float), p=1, dim=1)
    elif normalize == "all":
        return confusion_matrix.to(torch.float) / torch.sum(confusion_matrix)
    else:
        return confusion_matrix


def _confusion_matrix_update(
    input: torch.Tensor, target: torch.Tensor, num_classes: int
) -> torch.Tensor:
    """
    Validate multi-class inputs and count them into a confusion matrix.

    Returns:
        A dense num_classes x num_classes tensor of counts indexed as (true, pred).

    Raises:
        ValueError: If the inputs fail ``_confusion_matrix_update_input_check``.
    """
    _confusion_matrix_update_input_check(input, target, num_classes)
    return _update(input, target, num_classes)


# Build the dense (true, pred) count matrix shared by the binary and multi-class
# paths. A two-dimensional ``input`` is first reduced to predicted labels with
# argmax over dim 1. (Documented with comments rather than a docstring to keep
# the scripted body minimal.)
@torch.jit.script
def _update(
    input: torch.Tensor,
    target: torch.Tensor,
    num_classes: int,
) -> torch.Tensor:
    if input.ndim == 2:
        input = torch.argmax(input, dim=1)

    coordinates = torch.vstack((target, input))
    cm_shape = torch.Size([num_classes, num_classes])

    # Each prediction creates an entry at the position (true, pred); converting
    # the sparse tensor to dense yields the per-cell counts.
    sparse_cm = torch.sparse_coo_tensor(coordinates, torch.ones_like(target), cm_shape)
    return sparse_cm.to_dense()


def _confusion_matrix_param_check(
    num_classes: int,
    normalize: Optional[str],
) -> None:
    """
    Validate the ``num_classes`` and ``normalize`` arguments.

    Raises:
        ValueError: If ``num_classes`` is smaller than two, or ``normalize`` is
            neither ``None`` nor one of ``'all'``, ``'pred'``, ``'true'``, ``'none'``.
    """
    if num_classes < 2:
        raise ValueError("Must be at least two classes for confusion matrix")
    if (normalize is not None) and (normalize not in ("all", "pred", "true", "none")):
        raise ValueError("normalize must be one of 'all', 'pred', 'true', or 'none'.")


def _confusion_matrix_update_input_check(
    input: torch.Tensor,
    target: torch.Tensor,
    num_classes: Optional[int],
) -> None:
    """
    Validate the shapes and class ranges of multi-class inputs.

    Args:
        input: Predicted labels of shape (num_sample,) or scores of shape
            (num_sample, num_classes).
        target: Ground truth labels of shape (num_sample,).
        num_classes: Number of classes; every label must be strictly smaller.

    Raises:
        ValueError: If the first dimensions differ, ``target`` is not
            one-dimensional, ``input`` has an unsupported shape, or a predicted
            or target class is greater than or equal to ``num_classes``.
    """
    if input.size(0) != target.size(0):
        raise ValueError(
            "The `input` and `target` should have the same first dimension, "
            f"got shapes {input.shape} and {target.shape}."
        )

    if target.ndim != 1:
        raise ValueError(
            "target should be a one-dimensional tensor, " f"got shape {target.shape}."
        )

    if input.ndim == 1:
        # Label predictions: check num_classes is high enough to cover them.
        if torch.max(input) >= num_classes:
            raise ValueError(
                "Got `input` prediction class which is too large for the number of classes, "
                f"num_classes: {num_classes} must be strictly greater than max class predicted: {torch.max(input)}."
            )
    elif not (input.ndim == 2 and (input.shape[1] == num_classes)):
        # Score predictions must provide exactly one column per class.
        raise ValueError(
            f"input should have shape of (num_sample,) or (num_sample, num_classes), "
            f"got {input.shape}."
        )

    # check if num classes is high enough to cover targets.
    if torch.max(target) >= num_classes:
        raise ValueError(
            "Got `target` class which is larger than the number of classes, "
            f"num_classes: {num_classes} must be strictly greater than max target: {torch.max(target)}."
        )

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources