SelectionRate
- class ignite.metrics.fairness.SelectionRate(output_transform=<function SelectionRate.<lambda>>, is_multilabel=False, device=device(type='cpu'), skip_unrolling=False)
Calculates the selection rate: the fraction of samples that receive a positive prediction, reported separately for each category (or label).
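One way to write this quantity is shown below. It is a sketch under stated assumptions, not a transcription of the library code: it assumes that multiclass scores are reduced to a predicted category by argmax (which is consistent with the example output further down) and that in the multilabel case each prediction is already a 0/1 indicator. The symbols N (number of samples seen), p_{i,k} (score of sample i for category k), \hat{y}_i and \hat{y}_{i,\ell} are introduced here for illustration only.

```latex
% Multiclass: fraction of samples whose predicted category is c
\mathrm{SR}_c = \frac{1}{N} \sum_{i=1}^{N} \mathbb{1}\left[\hat{y}_i = c\right],
\qquad \hat{y}_i = \arg\max_{k} \, p_{i,k}

% Multilabel: fraction of samples predicted positive for label \ell
\mathrm{SR}_\ell = \frac{1}{N} \sum_{i=1}^{N} \hat{y}_{i,\ell},
\qquad \hat{y}_{i,\ell} \in \{0, 1\}
```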
- update must receive output of the form (y_pred, y).
- y_pred must be in the following shape (batch_size, num_categories, …) or (batch_size, …).
- y must be in the following shape (batch_size, …). The selection rate depends only on the predictions, so y is not used in the computation.
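- Parameters:
  - output_transform: a callable that transforms the Engine's process_function output into the form expected by the metric, i.e. (y_pred, y).
  - is_multilabel: flag for the multilabel case. Default: False.
  - device: the device on which metric updates are accumulated. Default: CPU.
  - skip_unrolling: specifies whether the output should be unrolled before being fed to update. Default: False.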
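For the is_multilabel=True case, the following framework-free Python sketch illustrates the per-label computation. It assumes that y_pred already holds 0/1 indicators of shape (batch_size, num_labels); whether SelectionRate thresholds raw scores itself is not stated on this page, so treat this as an illustration of the definition rather than the library's code. The function name multilabel_selection_rate is hypothetical.

```python
from typing import List, Sequence


def multilabel_selection_rate(y_pred: Sequence[Sequence[int]]) -> List[float]:
    """Hypothetical helper (not part of ignite).

    Returns, for each label, the fraction of samples predicted positive (1).
    Assumes y_pred is a (batch_size, num_labels) grid of 0/1 indicators.
    """
    if not y_pred:
        raise ValueError("y_pred must contain at least one sample")
    num_samples = len(y_pred)
    num_labels = len(y_pred[0])
    positives = [0] * num_labels
    for row in y_pred:
        if len(row) != num_labels:
            raise ValueError("all rows must have the same number of labels")
        for label, value in enumerate(row):
            # Count a positive prediction for this label.
            positives[label] += 1 if value == 1 else 0
    return [count / num_samples for count in positives]


# Usage: 4 samples, 3 labels
print(multilabel_selection_rate([[1, 0, 1], [0, 0, 1], [1, 1, 1], [0, 0, 0]]))
# [0.5, 0.25, 0.75]
```

Label 0 is predicted positive for 2 of 4 samples, label 1 for 1 of 4, and label 2 for 3 of 4, which gives the three rates above.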
Examples
from collections import OrderedDict

import torch
from torch import nn, optim

from ignite.engine import *
from ignite.handlers import *
from ignite.metrics import *
from ignite.metrics.clustering import *
from ignite.metrics.fairness import *
from ignite.metrics.rec_sys import *
from ignite.metrics.regression import *
from ignite.utils import *

# create default evaluator for doctests
def eval_step(engine, batch):
    return batch

default_evaluator = Engine(eval_step)

# create default optimizer for doctests
param_tensor = torch.zeros([1], requires_grad=True)
default_optimizer = torch.optim.SGD([param_tensor], lr=0.1)

# create default trainer for doctests
# as handlers could be attached to the trainer,
# each test must define its own trainer using `.. testsetup:`
def get_default_trainer():
    def train_step(engine, batch):
        return batch
    return Engine(train_step)

# create default model for doctests
default_model = nn.Sequential(OrderedDict([
    ('base', nn.Linear(4, 2)),
    ('fc', nn.Linear(2, 1))
]))

manual_seed(666)
metric = SelectionRate()
metric.attach(default_evaluator, 'selection_rate')
y_pred = torch.tensor([[0.1, 0.9], [0.2, 0.8], [0.9, 0.1], [0.9, 0.1]])
y_true = torch.tensor([1, 1, 0, 0])  # ignored
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics['selection_rate'])
tensor([0.5000, 0.5000])
New in version 0.5.4.
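The example's output can be reproduced by hand. The framework-free sketch below assumes, consistently with that output, that each row of multiclass scores is reduced to its argmax category and that the rate for a category is its share of all samples. The helper names argmax and multiclass_selection_rate are hypothetical and not part of ignite.

```python
from typing import List, Sequence


def argmax(scores: Sequence[float]) -> int:
    # Index of the largest score; ties resolve to the lowest index (a choice made here).
    return max(range(len(scores)), key=lambda k: scores[k])


def multiclass_selection_rate(
    y_pred_scores: Sequence[Sequence[float]], num_categories: int
) -> List[float]:
    """Hypothetical helper: fraction of samples assigned to each category."""
    if not y_pred_scores:
        raise ValueError("y_pred_scores must contain at least one sample")
    counts = [0] * num_categories
    for row in y_pred_scores:
        counts[argmax(row)] += 1
    n = len(y_pred_scores)
    return [count / n for count in counts]


# Same scores as in the example above
y_pred = [[0.1, 0.9], [0.2, 0.8], [0.9, 0.1], [0.9, 0.1]]
print(multiclass_selection_rate(y_pred, num_categories=2))  # [0.5, 0.5]
```

The row-wise argmax is [1, 1, 0, 0], so each of the two categories is selected for 2 of the 4 samples, which matches tensor([0.5000, 0.5000]).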
Methods
- compute(): Computes the selection rate.
- reset(): Resets the metric to its initial state.
- update(output): Updates the metric's state using the passed batch output.
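The reset/update/compute lifecycle listed above can be illustrated with a small stand-in class. This is a hypothetical sketch, not ignite's implementation: the class name SelectionRateSketch is invented, it takes already-decided class indices as y_pred, and it raises RuntimeError where the real metric may behave differently.

```python
from typing import List, Sequence, Tuple


class SelectionRateSketch:
    """Hypothetical stand-in mirroring the reset/update/compute lifecycle.

    Not ignite's implementation: it accumulates per-category prediction
    counts over batches of predicted class indices.
    """

    def __init__(self, num_categories: int) -> None:
        self.num_categories = num_categories
        self.reset()

    def reset(self) -> None:
        # Clear all accumulated state.
        self._counts = [0] * self.num_categories
        self._num_examples = 0

    def update(self, output: Tuple[Sequence[int], Sequence[int]]) -> None:
        # output follows the (y_pred, y) convention; y is ignored.
        y_pred, _ = output
        for label in y_pred:
            self._counts[label] += 1
        self._num_examples += len(y_pred)

    def compute(self) -> List[float]:
        if self._num_examples == 0:
            raise RuntimeError("compute() called before any update()")
        return [count / self._num_examples for count in self._counts]


metric = SelectionRateSketch(num_categories=2)
metric.update(([1, 1], [1, 1]))  # first batch
metric.update(([0, 0], [0, 0]))  # second batch
print(metric.compute())  # [0.5, 0.5]
```

Accumulating raw counts and dividing only once in compute() means the result over several batches equals the result over their concatenation, regardless of how the samples are split into batches.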
- compute()
Computes the selection rate.
- Returns:
The selection rate for each category/label.
- Return type: