Source code for torcheval.metrics.classification.auroc
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-ignore-all-errors[16]: Undefined attribute of metric states.
from typing import Iterable, Optional, TypeVar
import torch
from torcheval.metrics.functional.classification.auroc import (
_binary_auroc_compute,
_binary_auroc_update_input_check,
_multiclass_auroc_compute,
_multiclass_auroc_param_check,
_multiclass_auroc_update_input_check,
)
from torcheval.metrics.metric import Metric
try:
import fbgemm_gpu.metrics # noqa
has_fbgemm = True
except ImportError:
has_fbgemm = False
TAUROC = TypeVar("TAUROC")
TMulticlasslAUROC = TypeVar("TMulticlassAUROC")
[docs]class BinaryAUROC(Metric[torch.Tensor]):
"""
Compute AUROC, which is the area under the ROC Curve, for binary classification.
AUROC is defined as the area under the Receiver Operating Curve, a plot with x=false positive rate y=true positive rate.
The points on the curve are sampled from the data given and the area is computed using the trapezoid method.
Multiple tasks are supported for Binary AUROC. A two-dimensional vector can given for the predicted values (inputs) and targets. This gives equivalent results to having one BinaryAUROC object for each row.
Its functional version is :func:`torcheval.metrics.functional.binary_auroc`.
See also :class:`MulticlassAUROC <MulticlassAUROC>`
Examples::
>>> import torch
>>> from torcheval.metrics import BinaryAUROC
>>> metric = BinaryAUROC()
>>> input = torch.tensor([0.1, 0.5, 0.7, 0.8])
>>> target = torch.tensor([1, 0, 1, 1])
>>> metric.update(input, target)
>>> metric.compute()
tensor([0.6667])
>>> input = torch.tensor([1, 1, 1, 0])
>>> target = torch.tensor([1, 1, 1, 0])
>>> metric.update(input, target)
>>> metric.compute()
tensor([1.0])
>>> metric = BinaryAUROC(num_tasks=2)
>>> input = torch.tensor([[1, 1, 1, 0], [0.1, 0.5, 0.7, 0.8]])
>>> target = torch.tensor([[1, 0, 1, 0], [1, 0, 1, 1]])
>>> metric.update(input, target)
>>> metric.compute()
tensor([0.7500, 0.6667])
"""
[docs] def __init__(
self: TAUROC,
*,
num_tasks: int = 1,
device: Optional[torch.device] = None,
use_fbgemm: Optional[bool] = False,
) -> None:
super().__init__(device=device)
if num_tasks < 1:
raise ValueError(
"`num_tasks` value should be greater than or equal to 1, but received {num_tasks}. "
)
if not has_fbgemm and use_fbgemm:
raise ValueError(
"`use_fbgemm` is enabled but `fbgemm_gpu` is not found. Please "
"install `fbgemm_gpu` to use this option."
)
self.num_tasks = num_tasks
self._add_state("inputs", [])
self._add_state("targets", [])
self._add_state("weights", [])
self.use_fbgemm = use_fbgemm
@torch.inference_mode()
# pyre-ignore[14]: inconsistent override on *_:Any, **__:Any
def update(
self: TAUROC,
input: torch.Tensor,
target: torch.Tensor,
weight: Optional[torch.Tensor] = None,
) -> TAUROC:
"""
Update states with the ground truth labels and predictions.
Args:
input (Tensor): Tensor of label predictions
It should be predicted label, probabilities or logits with shape of (num_tasks, n_sample) or (n_sample, ).
target (Tensor): Tensor of ground truth labels with shape of (num_tasks, n_sample) or (n_sample, ).
weight (Tensor): Optional. A manual rescaling weight to match input tensor shape (num_tasks, num_samples) or (n_sample, ).
"""
input = input.to(self.device)
target = target.to(self.device)
if weight is None:
weight = torch.ones_like(input, dtype=torch.double)
_binary_auroc_update_input_check(input, target, self.num_tasks, weight)
self.inputs.append(input)
self.targets.append(target)
self.weights.append(weight)
return self
@torch.inference_mode()
def compute(
self: TAUROC,
) -> torch.Tensor:
"""
Return AUROC. If no ``update()`` calls are made before
``compute()`` is called, return an empty tensor.
Returns:
Tensor: The return value of AUROC for each task (num_tasks,).
"""
return _binary_auroc_compute(
torch.cat(self.inputs, -1),
torch.cat(self.targets, -1),
torch.cat(self.weights, -1),
self.use_fbgemm,
)
@torch.inference_mode()
def merge_state(self: TAUROC, metrics: Iterable[TAUROC]) -> TAUROC:
for metric in metrics:
if metric.inputs:
metric_inputs = torch.cat(metric.inputs, -1).to(self.device)
metric_targets = torch.cat(metric.targets, -1).to(self.device)
metric_weights = torch.cat(metric.weights, -1).to(self.device)
self.inputs.append(metric_inputs)
self.targets.append(metric_targets)
self.weights.append(metric_weights)
return self
@torch.inference_mode()
def _prepare_for_merge_state(self: TAUROC) -> None:
if self.inputs and self.targets:
self.inputs = [torch.cat(self.inputs, -1)]
self.targets = [torch.cat(self.targets, -1)]
self.weights = [torch.cat(self.weights, -1)]
[docs]class MulticlassAUROC(Metric[torch.Tensor]):
"""
Compute AUROC, which is the area under the ROC Curve, for multiclass classification in a one vs rest fashion.
One vs. rest Multiclass AUROC is equivalent to running a BinaryAUROC with `num_classes` tasks where
1. The `input` is transposed
2. The `target` is translated from a 1 dimensional tensor of the correct classes to a 2 dimensional tensor where each row is a list containing which examples belong to that class.
See examples below for more details on the connection between Multiclass and Binary AUROC.
The functional version of this metric is :func:`torcheval.metrics.functional.multiclass_auroc`.
See also :class:`BinaryAUROC <BinaryAUROC>`
Args:
num_classes (int): Number of classes.
average (str, optional):
- ``'macro'`` [default]:
Calculate metrics for each class separately, and return their unweighted mean.
- ``None``:
Calculate the metric for each class separately, and return
the metric for every class.
Examples::
>>> import torch
>>> from torcheval.metrics import MulticlassAUROC
>>> metric = MulticlassAUROC(num_classes=4)
>>> input = torch.tensor([[0.1, 0.1, 0.1, 0.1], [0.5, 0.5, 0.5, 0.5], [0.7, 0.7, 0.7, 0.7], [0.8, 0.8, 0.8, 0.8]])
>>> target = torch.tensor([0, 1, 2, 3])
>>> metric.update(input, target)
>>> metric.compute()
tensor(0.5000)
>>> metric = MulticlassAUROC(num_classes=3, average=None)
>>> input = torch.tensor([[0.1, 0, 0], [0, 1, 0], [0.1, 0.2, 0.7], [0, 0, 1]])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric.update(input, target)
>>> metric.compute()
tensor([0.8333, 1.0000, 1.0000])
the above is equivalent to
>>> from torcheval.metrics import BinaryAUROC
>>> metric = BinaryAUROC(num_tasks=3)
>>> input = torch.tensor([[0.1, 0, 0.1, 0], [0, 1, 0.2, 0], [0, 0, 0.7, 1]])
>>> target = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1]])
>>> metric.update(input, target)
>>> metric.compute()
tensor([0.8333, 1.0000, 1.0000])
"""
[docs] def __init__(
self: TMulticlasslAUROC,
*,
num_classes: int,
average: Optional[str] = "macro",
device: Optional[torch.device] = None,
) -> None:
super().__init__(device=device)
_multiclass_auroc_param_check(num_classes, average)
self.num_classes = num_classes
self.average = average
self._add_state("inputs", [])
self._add_state("targets", [])
@torch.inference_mode()
# pyre-ignore[14]: inconsistent override on *_:Any, **__:Any
def update(
self: TMulticlasslAUROC,
input: torch.Tensor,
target: torch.Tensor,
) -> TMulticlasslAUROC:
"""
Update states with the ground truth labels and predictions.
Args:
input (Tensor): Tensor of label predictions
It should be probabilities or logits with shape of (n_sample, n_class).
target (Tensor): Tensor of ground truth labels with shape of (n_samples, ).
"""
input = input.to(self.device)
target = target.to(self.device)
_multiclass_auroc_update_input_check(input, target, self.num_classes)
self.inputs.append(input)
self.targets.append(target)
return self
@torch.inference_mode()
def compute(
self: TMulticlasslAUROC,
) -> torch.Tensor:
return _multiclass_auroc_compute(
torch.cat(self.inputs),
torch.cat(self.targets),
self.num_classes,
self.average,
)
@torch.inference_mode()
def merge_state(
self: TMulticlasslAUROC, metrics: Iterable[TMulticlasslAUROC]
) -> TMulticlasslAUROC:
for metric in metrics:
if metric.inputs:
metric_inputs = torch.cat(metric.inputs).to(self.device)
metric_targets = torch.cat(metric.targets).to(self.device)
self.inputs.append(metric_inputs)
self.targets.append(metric_targets)
return self
@torch.inference_mode()
def _prepare_for_merge_state(self: TMulticlasslAUROC) -> None:
if self.inputs and self.targets:
self.inputs = [torch.cat(self.inputs)]
self.targets = [torch.cat(self.targets)]