
Source code for torcheval.metrics.classification.precision_recall_curve

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-ignore-all-errors[16]: Undefined attribute of metric states.

from typing import Iterable, List, Optional, Tuple, TypeVar

import torch

from torcheval.metrics.functional.classification.precision_recall_curve import (
    _binary_precision_recall_curve_compute,
    _binary_precision_recall_curve_update,
    _multiclass_precision_recall_curve_compute,
    _multiclass_precision_recall_curve_update,
    _multilabel_precision_recall_curve_compute,
    _multilabel_precision_recall_curve_update,
)
from torcheval.metrics.metric import Metric

"""
This file contains BinaryPrecisionRecallCurve, MulticlassPrecisionRecallCurve and MultilabelPrecisionRecallCurve classes.
"""

TBinaryPrecisionRecallCurve = TypeVar("TBinaryPrecisionRecallCurve")
TMulticlassPrecisionRecallCurve = TypeVar("TMulticlassPrecisionRecallCurve")
TMultilabelPrecisionRecallCurve = TypeVar("TMultilabelPrecisionRecallCurve")


class BinaryPrecisionRecallCurve(
    Metric[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
):
    """
    Returns precision-recall pairs and their corresponding thresholds for
    binary classification tasks. If a class is missing from the target tensor,
    its recall values are set to 1.0.

    Its functional version is :func:`torcheval.metrics.functional.binary_precision_recall_curve`.
    See also :class:`MulticlassPrecisionRecallCurve <MulticlassPrecisionRecallCurve>`,
    :class:`MultilabelPrecisionRecallCurve <MultilabelPrecisionRecallCurve>`

    Examples::

        >>> import torch
        >>> from torcheval.metrics import BinaryPrecisionRecallCurve
        >>> metric = BinaryPrecisionRecallCurve()
        >>> input = torch.tensor([0.1, 0.5, 0.7, 0.8])
        >>> target = torch.tensor([0, 0, 1, 1])
        >>> metric.update(input, target)
        >>> metric.compute()
        (tensor([1., 1., 1.]), tensor([1.0000, 0.5000, 0.0000]), tensor([0.7000, 0.8000]))
    """

    def __init__(
        self: TBinaryPrecisionRecallCurve,
        *,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__(device=device)
        self._add_state("inputs", [])
        self._add_state("targets", [])

    @torch.inference_mode()
    # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any
    def update(
        self: TBinaryPrecisionRecallCurve,
        input: torch.Tensor,
        target: torch.Tensor,
    ) -> TBinaryPrecisionRecallCurve:
        """
        Update states with the ground truth labels and predictions.

        Args:
            input (Tensor): Tensor of label predictions.
                It should be probabilities or logits with shape of (n_samples, ).
            target (Tensor): Tensor of ground truth labels with shape of (n_samples, ).
        """
        input = input.to(self.device)
        target = target.to(self.device)
        _binary_precision_recall_curve_update(input, target)
        self.inputs.append(input)
        self.targets.append(target)
        return self

    @torch.inference_mode()
    def compute(
        self: TBinaryPrecisionRecallCurve,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns:
            Tuple:
                - precision (Tensor): Tensor of precision result. Its shape is (n_thresholds + 1, ).
                - recall (Tensor): Tensor of recall result. Its shape is (n_thresholds + 1, ).
                - thresholds (Tensor): Tensor of thresholds. Its shape is (n_thresholds, ).
        """
        return _binary_precision_recall_curve_compute(
            torch.cat(self.inputs), torch.cat(self.targets)
        )

    @torch.inference_mode()
    def merge_state(
        self: TBinaryPrecisionRecallCurve,
        metrics: Iterable[TBinaryPrecisionRecallCurve],
    ) -> TBinaryPrecisionRecallCurve:
        for metric in metrics:
            if metric.inputs:
                metric_inputs = torch.cat(metric.inputs).to(self.device)
                metric_targets = torch.cat(metric.targets).to(self.device)
                self.inputs.append(metric_inputs)
                self.targets.append(metric_targets)
        return self

    @torch.inference_mode()
    def _prepare_for_merge_state(self: TBinaryPrecisionRecallCurve) -> None:
        if self.inputs and self.targets:
            self.inputs = [torch.cat(self.inputs)]
            self.targets = [torch.cat(self.targets)]
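
# The ``merge_state`` method above is not covered by the docstring example. A minimal,
# hedged sketch of merging two independently updated instances (hypothetical values;
# in practice merging is usually driven by a synchronization utility rather than
# called by hand):
#
#     >>> metric_a = BinaryPrecisionRecallCurve()
#     >>> metric_a.update(torch.tensor([0.2, 0.9]), torch.tensor([0, 1]))
#     >>> metric_b = BinaryPrecisionRecallCurve()
#     >>> metric_b.update(torch.tensor([0.4, 0.6]), torch.tensor([1, 0]))
#     >>> metric_a.merge_state([metric_b])
#     >>> precision, recall, thresholds = metric_a.compute()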

class MulticlassPrecisionRecallCurve(
    Metric[Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]
):
    """
    Returns precision-recall pairs and their corresponding thresholds for
    multi-class classification tasks. If a class is missing from the target
    tensor, its recall values are set to 1.0.

    Its functional version is :func:`torcheval.metrics.functional.multiclass_precision_recall_curve`.
    See also :class:`BinaryPrecisionRecallCurve <BinaryPrecisionRecallCurve>`,
    :class:`MultilabelPrecisionRecallCurve <MultilabelPrecisionRecallCurve>`

    Args:
        num_classes (int, Optional): Number of classes. Set to the second dimension
            of the input if num_classes is None.

    Examples::

        >>> import torch
        >>> from torcheval.metrics import MulticlassPrecisionRecallCurve
        >>> metric = MulticlassPrecisionRecallCurve(num_classes=4)
        >>> input = torch.tensor([[0.1, 0.1, 0.1, 0.1], [0.5, 0.5, 0.5, 0.5], [0.7, 0.7, 0.7, 0.7], [0.8, 0.8, 0.8, 0.8]])
        >>> target = torch.tensor([0, 1, 2, 3])
        >>> metric.update(input, target)
        >>> metric.compute()
        ([tensor([0.2500, 0.0000, 0.0000, 0.0000, 1.0000]),
          tensor([0.3333, 0.0000, 0.0000, 1.0000]),
          tensor([0.5000, 0.0000, 1.0000]),
          tensor([1., 1.])],
         [tensor([1., 0., 0., 0., 0.]),
          tensor([1., 0., 0., 0.]),
          tensor([1., 0., 0.]),
          tensor([1., 0.])],
         [tensor([0.1000, 0.5000, 0.7000, 0.8000]),
          tensor([0.5000, 0.7000, 0.8000]),
          tensor([0.7000, 0.8000]),
          tensor([0.8000])])
    """

    def __init__(
        self: TMulticlassPrecisionRecallCurve,
        *,
        num_classes: Optional[int] = None,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__(device=device)
        self.num_classes = num_classes
        self._add_state("inputs", [])
        self._add_state("targets", [])

    @torch.inference_mode()
    # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any
    def update(
        self: TMulticlassPrecisionRecallCurve,
        input: torch.Tensor,
        target: torch.Tensor,
    ) -> TMulticlassPrecisionRecallCurve:
        """
        Update states with the ground truth labels and predictions.

        Args:
            input (Tensor): Tensor of label predictions.
                It should be probabilities or logits with shape of (n_samples, n_classes).
            target (Tensor): Tensor of ground truth labels with shape of (n_samples, ).
        """
        input = input.to(self.device)
        target = target.to(self.device)
        _multiclass_precision_recall_curve_update(input, target, self.num_classes)
        self.inputs.append(input)
        self.targets.append(target)
        return self

    @torch.inference_mode()
    def compute(
        self: TMulticlassPrecisionRecallCurve,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        """
        Returns:
            - precision: List of precision results. Each index indicates the result of a class.
            - recall: List of recall results. Each index indicates the result of a class.
            - thresholds: List of thresholds. Each index indicates the result of a class.
        """
        return _multiclass_precision_recall_curve_compute(
            torch.cat(self.inputs), torch.cat(self.targets), self.num_classes
        )

    @torch.inference_mode()
    def merge_state(
        self: TMulticlassPrecisionRecallCurve,
        metrics: Iterable[TMulticlassPrecisionRecallCurve],
    ) -> TMulticlassPrecisionRecallCurve:
        for metric in metrics:
            if metric.inputs:
                metric_inputs = torch.cat(metric.inputs).to(self.device)
                metric_targets = torch.cat(metric.targets).to(self.device)
                self.inputs.append(metric_inputs)
                self.targets.append(metric_targets)
        return self

    @torch.inference_mode()
    def _prepare_for_merge_state(self: TMulticlassPrecisionRecallCurve) -> None:
        if self.inputs and self.targets:
            self.inputs = [torch.cat(self.inputs)]
            self.targets = [torch.cat(self.targets)]
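
# Per the class docstring above, ``num_classes`` may be left as None and is then taken
# from the second dimension of the input, and ``update`` accepts raw logits as well as
# probabilities. A minimal sketch under those assumptions (hypothetical logits; output
# omitted):
#
#     >>> metric = MulticlassPrecisionRecallCurve()
#     >>> logits = torch.tensor([[2.0, -1.0, 0.5], [0.1, 1.5, -0.2], [-0.3, 0.2, 1.8]])
#     >>> target = torch.tensor([0, 1, 2])
#     >>> metric.update(logits, target)
#     >>> precision, recall, thresholds = metric.compute()  # lists of per-class tensors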

class MultilabelPrecisionRecallCurve(
    Metric[Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]
):
    """
    Returns precision-recall pairs and their corresponding thresholds for
    multi-label classification tasks. If there are no samples for a label in
    the target tensor, its recall values are set to 1.0.

    Its functional version is :func:`torcheval.metrics.functional.multilabel_precision_recall_curve`.
    See also :class:`BinaryPrecisionRecallCurve <BinaryPrecisionRecallCurve>`,
    :class:`MulticlassPrecisionRecallCurve <MulticlassPrecisionRecallCurve>`

    Args:
        num_labels (int): Number of labels.

    Examples::

        >>> import torch
        >>> from torcheval.metrics import MultilabelPrecisionRecallCurve
        >>> metric = MultilabelPrecisionRecallCurve(num_labels=3)
        >>> input = torch.tensor([[0.75, 0.05, 0.35], [0.45, 0.75, 0.05], [0.05, 0.55, 0.75], [0.05, 0.65, 0.05]])
        >>> target = torch.tensor([[1, 0, 1], [0, 0, 0], [0, 1, 1], [1, 1, 1]])
        >>> metric.update(input, target)
        >>> metric.compute()
        ([tensor([0.5, 0.5, 1.0, 1.0]),
          tensor([0.5, 0.66666667, 0.5, 0.0, 1.0]),
          tensor([0.75, 1.0, 1.0, 1.0])],
         [tensor([1.0, 0.5, 0.5, 0.0]),
          tensor([1.0, 1.0, 0.5, 0.0, 0.0]),
          tensor([1.0, 0.66666667, 0.33333333, 0.0])],
         [tensor([0.05, 0.45, 0.75]),
          tensor([0.05, 0.55, 0.65, 0.75]),
          tensor([0.05, 0.35, 0.75])])
    """

    def __init__(
        self: TMultilabelPrecisionRecallCurve,
        *,
        num_labels: int,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__(device=device)
        self.num_labels = num_labels
        self._add_state("inputs", [])
        self._add_state("targets", [])

    @torch.inference_mode()
    # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any
    def update(
        self: TMultilabelPrecisionRecallCurve,
        input: torch.Tensor,
        target: torch.Tensor,
    ) -> TMultilabelPrecisionRecallCurve:
        """
        Update states with the ground truth labels and predictions.

        Args:
            input (Tensor): Tensor of label predictions.
                It should be probabilities with shape of (n_samples, n_labels).
            target (Tensor): Tensor of ground truth labels with shape of (n_samples, n_labels).
        """
        input = input.to(self.device)
        target = target.to(self.device)
        _multilabel_precision_recall_curve_update(input, target, self.num_labels)
        self.inputs.append(input)
        self.targets.append(target)
        return self

    @torch.inference_mode()
    def compute(
        self: TMultilabelPrecisionRecallCurve,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        """
        Returns:
            - precision: List of precision results. Each index indicates the result of a label.
            - recall: List of recall results. Each index indicates the result of a label.
            - thresholds: List of thresholds. Each index indicates the result of a label.
        """
        return _multilabel_precision_recall_curve_compute(
            torch.cat(self.inputs), torch.cat(self.targets), self.num_labels
        )

    @torch.inference_mode()
    def merge_state(
        self: TMultilabelPrecisionRecallCurve,
        metrics: Iterable[TMultilabelPrecisionRecallCurve],
    ) -> TMultilabelPrecisionRecallCurve:
        for metric in metrics:
            if metric.inputs:
                metric_inputs = torch.cat(metric.inputs).to(self.device)
                metric_targets = torch.cat(metric.targets).to(self.device)
                self.inputs.append(metric_inputs)
                self.targets.append(metric_targets)
        return self

    @torch.inference_mode()
    def _prepare_for_merge_state(self: TMultilabelPrecisionRecallCurve) -> None:
        if self.inputs and self.targets:
            self.inputs = [torch.cat(self.inputs)]
            self.targets = [torch.cat(self.targets)]
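
# ``update`` may be called once per batch: each call appends to the ``inputs`` and
# ``targets`` state lists, and ``compute`` concatenates them before evaluating the
# curve. A minimal sketch with hypothetical batches (output omitted):
#
#     >>> metric = MultilabelPrecisionRecallCurve(num_labels=2)
#     >>> metric.update(torch.tensor([[0.1, 0.9]]), torch.tensor([[0, 1]]))
#     >>> metric.update(torch.tensor([[0.8, 0.2]]), torch.tensor([[1, 0]]))
#     >>> precision, recall, thresholds = metric.compute()  # lists of per-label tensors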
