You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
118 lines
4.3 KiB
118 lines
4.3 KiB
import torch
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class PrecisionRecall:
    """Accumulate per-class precision / recall / F-score statistics over batches.

    Call `update()` once per batch, then `compute()` (or `counts()` /
    `confusion()`) to read the aggregated result. `reset()` clears all state.

    Args:
        threshold: In multi-label mode, predictions strictly greater than this
            value count as positive. `None` disables thresholding and the raw
            prediction values are used as-is. Unused in single-label mode.
        multi_label: If True, `predictions` and `target` are both
            `(..., num_classes)` indicator/score tensors. If False,
            `predictions` holds per-class scores and `target` holds integer
            class indices.
        device: Device for the running statistics. If None, they are allocated
            on the device of the first batch of predictions.
    """

    def __init__(self, threshold=0.5, multi_label=False, device=None):
        self.threshold = threshold
        self.device = device
        self.multi_label = multi_label

        # Running statistics; allocated lazily by the first `update()` call
        # because the number of classes is only known from the data.

        # the total number of true positive instances under each class
        # Shape: (num_classes, )
        self._tp_sum = None

        # the total number of instances
        # Shape: scalar
        self._total_sum = None

        # the total number of instances under each _predicted_ class,
        # including true positives and false positives
        # Shape: (num_classes, )
        self._pred_sum = None

        # the total number of instances under each _true_ class,
        # including true positives and false negatives
        # Shape: (num_classes, )
        self._true_sum = None

        self.reset()

    def reset(self):
        """Discard all accumulated statistics."""
        self._tp_sum = None
        self._total_sum = None
        self._pred_sum = None
        self._true_sum = None

    @staticmethod
    def _safe_divide(numerator, denominator):
        """Element-wise `numerator / denominator`, yielding 0 where the denominator is 0."""
        is_zero = denominator == 0
        safe_denominator = torch.where(is_zero, torch.ones_like(denominator), denominator)
        quotient = numerator / safe_denominator
        return torch.where(is_zero, torch.zeros_like(quotient), quotient)

    def update(self, predictions, target):
        """Fold one batch into the running statistics.

        Args:
            predictions: Per-class scores with the class axis last,
                i.e. `(..., num_classes)`.
            target: In multi-label mode, a tensor shaped like `predictions`
                holding 0/1 indicators. Otherwise integer class indices with
                one entry per prediction row.
        """
        dtype = predictions.dtype
        num_classes = predictions.size(-1)
        if self.multi_label:
            if self.threshold is not None:
                predictions = (predictions > self.threshold).to(dtype)
            # Keep the (N, num_classes) layout so that the sums over dim 0
            # below produce per-class counts matching the accumulators.
            predictions = predictions.reshape(-1, num_classes)
            target = target.reshape(-1, num_classes)
        else:
            target = F.one_hot(target.reshape(-1), num_classes=num_classes)
            # Reduce over the same (last) axis that num_classes was read from.
            indices = torch.argmax(predictions, dim=-1).reshape(-1)
            predictions = F.one_hot(indices, num_classes=num_classes)
            # FIXME make sure binary case works

        target = target.to(device=predictions.device, dtype=dtype)
        correct = (target * predictions > 0).to(dtype)
        true_positives = correct.sum(dim=0)
        pred_positives = predictions.sum(dim=0)
        target_positives = target.sum(dim=0)

        if self._tp_sum is None:
            device = self.device if self.device is not None else predictions.device
            self._tp_sum = torch.zeros(num_classes, device=device)
            self._true_sum = torch.zeros(num_classes, device=device)
            self._pred_sum = torch.zeros(num_classes, device=device)
            self._total_sum = torch.tensor(0, device=device)

        # Batch counts may live on a different device than the accumulators.
        stats_device = self._tp_sum.device
        self._tp_sum += true_positives.to(stats_device)
        self._pred_sum += pred_positives.to(stats_device)
        self._true_sum += target_positives.to(stats_device)
        self._total_sum += target.shape[0]

    def counts_as_tuple(self, reduce=False):
        """Return `(tp_sum, pred_sum, true_sum, total_sum)`.

        Args:
            reduce: If True, each statistic is passed through
                `reduce_tensor_sum` first.
        """
        tp_sum = self._tp_sum
        pred_sum = self._pred_sum
        true_sum = self._true_sum
        total_sum = self._total_sum
        if reduce:
            # NOTE(review): `reduce_tensor_sum` is not defined or imported in
            # the visible part of this module; it must be in scope for
            # reduce=True to work -- confirm where it comes from.
            tp_sum = reduce_tensor_sum(tp_sum)
            pred_sum = reduce_tensor_sum(pred_sum)
            true_sum = reduce_tensor_sum(true_sum)
            total_sum = reduce_tensor_sum(total_sum)
        return tp_sum, pred_sum, true_sum, total_sum

    def counts(self, reduce=False):
        """Return the raw statistics as a dict keyed `tp_sum`, `pred_sum`, `true_sum`, `total_sum`."""
        tp_sum, pred_sum, true_sum, total_sum = self.counts_as_tuple(reduce=reduce)
        return dict(tp_sum=tp_sum, pred_sum=pred_sum, true_sum=true_sum, total_sum=total_sum)

    def confusion(self, reduce=False):
        """Return per-class confusion counts as a dict keyed `tp`, `fp`, `fn`, `tn`."""
        tp_sum, pred_sum, true_sum, total_sum = self.counts_as_tuple(reduce=reduce)
        fp = pred_sum - tp_sum
        fn = true_sum - tp_sum
        tp = tp_sum
        tn = total_sum - tp - fp - fn
        return dict(tp=tp, fp=fp, fn=fn, tn=tn)

    def compute(self, fscore_beta=1, average='micro', no_reduce=False, distributed=False):
        """Compute precision, recall and F-beta score from the accumulated statistics.

        Ratios whose denominator is zero (a class that was never predicted,
        never present in the targets, or has zero precision and recall) are
        reported as 0 rather than NaN so they cannot poison an average.

        Args:
            fscore_beta: Beta of the F-score; 1 gives the F1 score.
            average: 'micro' pools the counts over all classes before taking
                the ratios. 'macro' takes per-class ratios and averages them.
                Any other value returns the per-class ratios unaveraged.
            no_reduce: With `average='macro'`, skip the final mean and return
                the per-class values.
            distributed: Forwarded to `counts_as_tuple` as `reduce`.

        Returns:
            dict with keys `fscore`, `precision`, `recall`.
        """
        tp_sum, pred_sum, true_sum, total_sum = self.counts_as_tuple(reduce=distributed)
        if average == 'micro':
            tp_sum = tp_sum.sum()
            pred_sum = pred_sum.sum()
            true_sum = true_sum.sum()

        precision = self._safe_divide(tp_sum, pred_sum)
        recall = self._safe_divide(tp_sum, true_sum)
        beta_sq = fscore_beta ** 2
        fscore = self._safe_divide(
            (1 + beta_sq) * precision * recall,
            beta_sq * precision + recall,
        )

        if average == 'macro' and not no_reduce:
            precision = precision.mean()
            recall = recall.mean()
            fscore = fscore.mean()

        return dict(fscore=fscore, precision=precision, recall=recall)
|