Shortcuts

Source code for torcheval.metrics.aggregation.auc

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-ignore-all-errors[16]: Undefined attribute of metric states.

from typing import Iterable, Optional, TypeVar

import torch

from torcheval.metrics.functional.aggregation.auc import (
    _auc_compute,
    _auc_update_input_check,
)
from torcheval.metrics.metric import Metric


TAUC = TypeVar("TAUC")


class AUC(Metric[torch.Tensor]):
    r"""
    Computes Area Under the Curve (AUC) using the trapezoidal rule.

    Supports ``x`` and ``y`` being two-dimensional tensors. In that case each
    row is treated as its own list of x- and y-coordinates, and a
    one-dimensional tensor holding the AUC of every row is returned.

    Args:
        reorder (bool): Reorder the input tensor for AUC computation.
            Default value is True.
        n_tasks (int): Number of tasks that need AUC calculation. Should be
            equal to the first dimension of ``x`` and ``y`` passed to
            ``update``. Default value is 1.
        device (torch.device, optional): Device handed to the base ``Metric``;
            ``merge_state`` moves merged states onto this device, and an empty
            ``compute`` result is created on it. Default value is None.

    Examples::

        >>> from torcheval.metrics.aggregation.auc import AUC
        >>> metric = AUC()
        >>> metric.update(torch.tensor([0,.2,.3,.1]), torch.tensor([1,1,1,1]))
        >>> metric.compute()
        tensor([0.3000])
        >>> metric.reset()
        >>> metric.update(torch.tensor([0,.1,.13,.2]), torch.tensor([1,1,2,4]))
        >>> metric.update(torch.tensor([1.,2.,.1, 3.]), torch.tensor([1,2,3,2]))
        >>> metric.compute()
        tensor([5.8850])
        >>> metric = AUC(n_tasks=2)  # n_tasks should be equal to first dimension of x, y in update()
        >>> x = torch.tensor([[0.3941, 0.2980, 0.3080], [0.1448, 0.6090, 0.2462]])
        >>> y = torch.tensor([[1, 0, 4], [0, 4, 2]])
        >>> metric.update(x, y)
        >>> x1 = torch.tensor([[0.4562, 0.1200, 0.4238], [0.4076, 0.4448, 0.1476]])
        >>> y1 = torch.tensor([[3, 4, 3], [2, 0, 4]])
        >>> metric.update(x1, y1)
        >>> metric.compute()
        tensor([0.7479, 0.9898])
    """

    def __init__(
        self: TAUC,
        *,
        reorder: bool = True,
        n_tasks: int = 1,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__(device=device)
        # Lists of buffered coordinate tensors. ``update`` stores every entry
        # as 2-D (1-D inputs are unsqueezed), so they can be concatenated
        # along dim=1 in ``compute``.
        self._add_state("x", [])
        self._add_state("y", [])
        self.n_tasks = n_tasks
        self.reorder = reorder

    @torch.inference_mode()
    # pyre-ignore[14]: inconsistent override on *_:Any, **__:Any
    def update(self: TAUC, x: torch.Tensor, y: torch.Tensor) -> TAUC:
        r"""
        Updates and returns variables required to compute area under the curve.

        Args:
            x: x-coordinates; either 1-D, or 2-D with one row per task.
            y: y-coordinates corresponding to ``x``.

        Returns:
            ``self``, so calls can be chained.
        """
        _auc_update_input_check(x, y, n_tasks=self.n_tasks)
        # Promote 1-D inputs to a single-row batch so every stored state is
        # 2-D and all of them can be concatenated along dim=1 later.
        if x.ndim == 1:
            x = x.unsqueeze(0)
        if y.ndim == 1:
            y = y.unsqueeze(0)
        self.x.append(x)
        self.y.append(y)
        return self

    @torch.inference_mode()
    def compute(self: TAUC) -> torch.Tensor:
        """
        Computes AUC based on inputs passed in to ``update`` previously.

        Returns:
            The AUC for each task. If ``update`` has never been called, an
            empty tensor placed on the metric's device is returned.
        """
        if not self.x or not self.y:
            # Create the empty result on the metric's device, consistent with
            # where merged states are placed in ``merge_state``.
            return torch.tensor([], device=self.device)
        return _auc_compute(
            torch.cat(self.x, dim=1), torch.cat(self.y, dim=1), reorder=self.reorder
        )

    @torch.inference_mode()
    def merge_state(self: TAUC, metrics: Iterable[TAUC]) -> TAUC:
        """
        Merges the buffered coordinates of ``metrics`` into this metric.

        Each incoming metric's ``x`` / ``y`` lists are concatenated along dim=1,
        moved to this metric's device and appended. Metrics with no buffered
        ``x`` are skipped.

        Returns:
            ``self``.
        """
        self._prepare_for_merge_state()
        for metric in metrics:
            if metric.x:
                metric_x = torch.cat(metric.x, dim=1).to(self.device)
                metric_y = torch.cat(metric.y, dim=1).to(self.device)
                self.x.append(metric_x)
                self.y.append(metric_y)
        return self

    @torch.inference_mode()
    def _prepare_for_merge_state(self: TAUC) -> None:
        """Collapses this metric's own buffered lists into single concatenated tensors."""
        if self.x and self.y:
            self.x = [torch.cat(self.x, dim=1)]
            self.y = [torch.cat(self.y, dim=1)]

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources