import torch
from typing import Optional, Tuple, Dict


class Accuracy(torch.nn.Module):
    def __init__(self, threshold=0.5, multi_label=False):
        super().__init__()
        self.threshold = threshold
        self.eps = 1e-8
        self.multi_label = multi_label

        # statistics / counts
        self._correct_sum = torch.tensor(0, dtype=torch.long)
        self._total_sum = torch.tensor(0, dtype=torch.long)

    def update(self, predictions, target):
        raise NotImplementedError()

    def reset(self):
        self._correct_sum = 0
        self._total_sum = 0

    @property
    def counts(self):
        pass

    def compute(self):
        raise NotImplementedError()
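

# Hypothetical sketch only (not part of the class above): one way the
# threshold-based counting used by Accuracy.update/compute could work, assuming
# `predictions` holds probabilities in [0, 1] and `target` holds {0, 1} labels
# of the same shape. The helper name is illustrative, not an existing API.
def _thresholded_correct_count(predictions: torch.Tensor, target: torch.Tensor,
                               threshold: float = 0.5) -> Tuple[torch.Tensor, int]:
    # binarize probabilities at the threshold, then count elementwise matches
    predicted = (predictions >= threshold).to(target.dtype)
    correct = predicted.eq(target).sum()
    return correct, target.numel()
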
class AccuracyTopK(torch.nn.Module):
    def __init__(self, topk=(1, 5), device=None):
        super().__init__()
        self.eps = 1e-8
        self.device = device
        self.topk = topk
        self.maxk = max(topk)
        # FIXME handle distributed operation

        # statistics / counts
        self.reset()

    def update(self, predictions: torch.Tensor, target: torch.Tensor):
        # indices of the top-maxk predicted classes for each sample
        sorted_indices = predictions.topk(self.maxk, dim=1)[1]
        sorted_indices.t_()  # shape (maxk, batch_size)
        correct = sorted_indices.eq(target.reshape(1, -1).expand_as(sorted_indices))

        batch_size = target.shape[0]
        # per k, number of samples whose target appears within the top-k predictions
        correct_k = {k: correct[:k].reshape(-1).float().sum(0) for k in self.topk}
        for k, v in correct_k.items():
            attr = f'_correct_top{k}'
            old_v = getattr(self, attr)
            setattr(self, attr, old_v + v)
        self._total_sum += batch_size

    def reset(self):
        for k in self.topk:
            setattr(self, f'_correct_top{k}', torch.tensor(0, dtype=torch.float32))
        self._total_sum = torch.tensor(0, dtype=torch.float32)

    @property
    def counts(self):
        pass

    def compute(self) -> Dict[str, torch.Tensor]:
        # FIXME handle distributed reduction
        return {f'top{k}': 100 * getattr(self, f'_correct_top{k}') / self._total_sum for k in self.topk}
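

# A possible way to address the distributed FIXMEs above: sum the accumulated
# counts across processes before calling compute(). This is only a sketch that
# assumes the caller has already initialized torch.distributed; the helper name
# `all_reduce_topk_counts` is not part of the original module.
def all_reduce_topk_counts(metric: AccuracyTopK):
    import torch.distributed as dist
    if dist.is_available() and dist.is_initialized():
        # sum per-process sample counts and per-k correct counts in place;
        # depending on the backend, the count tensors may need to live on the
        # matching device (e.g. CUDA for NCCL)
        dist.all_reduce(metric._total_sum, op=dist.ReduceOp.SUM)
        for k in metric.topk:
            dist.all_reduce(getattr(metric, f'_correct_top{k}'), op=dist.ReduceOp.SUM)
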
#
# class AccuracyTopK:
#
#     def __init__(self, topk=(1, 5), device=None):
#         self.eps = 1e-8
#         self.device = device
#         self.topk = topk
#         self.maxk = max(topk)
#
#         # statistics / counts
#         self._correct_sum = None
#         self._total_sum = None
#
#     def _check_init(self, device):
#         to_device = self.device if self.device else device
#         if self._correct_sum is None:
#             self._correct_sum = {f'top{k}': torch.tensor(0., device=to_device) for k in self.topk}
#         if self._total_sum is None:
#             self._total_sum = torch.tensor(0, dtype=torch.long, device=to_device)
#
#     def update(self, predictions: torch.Tensor, target: torch.Tensor):
#         sorted_indices = predictions.topk(self.maxk, dim=1)[1]
#         sorted_indices.t_()
#         correct = sorted_indices.eq(target.reshape(1, -1).expand_as(sorted_indices))
#
#         batch_size = target.shape[0]
#         correct_k = {f'top{k}': correct[:k].reshape(-1).float().sum(0) for k in self.topk}
#         self._check_init(device=predictions.device)
#         for k, v in correct_k.items():
#             old_v = self._correct_sum[k]
#             self._correct_sum[k] = old_v + v
#         self._total_sum += batch_size
#
#     def reset(self):
#         self._correct_sum = None
#         self._total_sum = None
#
#     @property
#     def counts(self):
#         pass
#
#     def compute(self) -> Dict[str, torch.Tensor]:
#         assert self._correct_sum is not None and self._total_sum is not None
#         return {k: 100 * v / self._total_sum for k, v in self._correct_sum.items()}
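

if __name__ == '__main__':
    # Minimal usage sketch with synthetic data (shapes and values are illustrative
    # only): accumulate top-1 / top-5 accuracy over a few batches, then read the
    # result. Not part of the original module.
    acc = AccuracyTopK(topk=(1, 5))
    for _ in range(4):
        logits = torch.randn(8, 10)          # batch of 8 samples, 10 classes
        labels = torch.randint(0, 10, (8,))  # integer class targets
        acc.update(logits, labels)
    print(acc.compute())  # e.g. {'top1': tensor(...), 'top5': tensor(...)}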