You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/timm/bits/metric_precision_recall.py

118 lines
4.3 KiB

import torch
import torch.nn.functional as F
class PrecisionRecall:
def __init__(self, threshold=0.5, multi_label=False, device=None):
self.threshold = threshold
self.device = device
self.multi_label = multi_label
# statistics
# the total number of true positive instances under each class
# Shape: (num_classes, )
self._tp_sum = None
# the total number of instances
# Shape: (num_classes, )
self._total_sum = None
# the total number of instances under each _predicted_ class,
# including true positives and false positives
# Shape: (num_classes, )
self._pred_sum = None
# the total number of instances under each _true_ class,
# including true positives and false negatives
# Shape: (num_classes, )
self._true_sum = None
self.reset()
def reset(self):
self._tp_sum = None
self._total_sum = None
self._pred_sum = None
self._true_sum = None
def update(self, predictions, target):
output_type = predictions.type()
num_classes = predictions.size(-1)
if self.multi_label:
if self.threshold is not None:
predictions = (predictions > self.threshold).type(output_type)
predictions = predictions.t().reshape(num_classes, -1)
target = target.t().reshape(num_classes, -1)
else:
target = F.one_hot(target.view(-1), num_classes=num_classes)
indices = torch.argmax(predictions, dim=1).view(-1)
predictions = F.one_hot(indices, num_classes=num_classes)
# FIXME make sure binary case works
target = target.type(output_type)
correct = (target * predictions > 0).type(output_type)
pred_positives = predictions.sum(dim=0)
target_positives = target.sum(dim=0)
if correct.sum() == 0:
true_positives = torch.zeros_like(pred_positives)
else:
true_positives = correct.sum(dim=0)
if self._tp_sum is None:
self._tp_sum = torch.zeros(num_classes, device=self.device)
self._true_sum = torch.zeros(num_classes, device=self.device)
self._pred_sum = torch.zeros(num_classes, device=self.device)
self._total_sum = torch.tensor(0, device=self.device)
self._tp_sum += true_positives
self._pred_sum += pred_positives
self._true_sum += target_positives
self._total_sum += target.shape[0]
def counts_as_tuple(self, reduce=False):
tp_sum = self._tp_sum
pred_sum = self._pred_sum
true_sum = self._true_sum
total_sum = self._total_sum
if reduce:
tp_sum = reduce_tensor_sum(tp_sum)
pred_sum = reduce_tensor_sum(pred_sum)
true_sum = reduce_tensor_sum(true_sum)
total_sum = reduce_tensor_sum(total_sum)
return tp_sum, pred_sum, true_sum, total_sum
def counts(self, reduce=False):
tp_sum, pred_sum, true_sum, total_sum = self.counts_as_tuple(reduce=reduce)
return dict(tp_sum=tp_sum, pred_sum=pred_sum, true_sum=true_sum, total_sum=total_sum)
def confusion(self, reduce=False):
tp_sum, pred_sum, true_sum, total_sum = self.counts_as_tuple(reduce=reduce)
fp = pred_sum - tp_sum
fn = true_sum - tp_sum
tp = tp_sum
tn = total_sum - tp - fp - fn
return dict(tp=tp, fp=fp, fn=fn, tn=tn)
def compute(self, fscore_beta=1, average='micro', no_reduce=False, distributed=False):
tp_sum, pred_sum, true_sum, total_sum = self.counts_as_tuple(reduce=distributed)
if average == 'micro':
tp_sum = tp_sum.sum()
pred_sum = pred_sum.sum()
true_sum = true_sum.sum()
precision = tp_sum / pred_sum
recall = tp_sum / true_sum
beta_sq = fscore_beta ** 2
f1_denom = beta_sq * precision + recall
fscore = (1 + beta_sq) * precision * recall / f1_denom
if average == 'macro' and not no_reduce:
precision = precision.mean()
recall = recall.mean()
fscore = fscore.mean()
return dict(fscore=fscore, precision=precision, recall=recall)
return dict(fscore=fscore, precision=precision, recall=recall)