You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/timm/metrics/accuracy.py

115 lines
3.6 KiB

import torch
from typing import Optional, Tuple, Dict
class Accuracy(torch.nn.Module):
    """Thresholded accuracy metric (not yet implemented).

    Holds configuration and running counters for an accuracy metric.
    ``update`` and ``compute`` are placeholders that raise
    ``NotImplementedError``.

    Args:
        threshold: Decision threshold. Stored, but not yet used by any method.
        multi_label: Flag for a multi-label mode. Stored, but not yet used.
    """

    def __init__(self, threshold=0.5, multi_label=False):
        # nn.Module must be initialised before use; without this call Module
        # machinery such as repr(), .to() and state_dict() raises AttributeError.
        super().__init__()
        self.threshold = threshold
        self.eps = 1e-8
        self.multi_label = multi_label
        # statistics / counts
        self.reset()

    def update(self, predictions, target):
        """Accumulate statistics for one batch. Not implemented yet.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    def reset(self):
        """Zero the running counters.

        The counters are always long tensors, whether they were set at
        construction time or by a later reset.
        """
        self._correct_sum = torch.tensor(0, dtype=torch.long)
        self._total_sum = torch.tensor(0, dtype=torch.long)

    @property
    def counts(self):
        """Placeholder; currently always returns ``None``."""
        return None

    def compute(self):
        """Compute accuracy from the accumulated counters. Not implemented yet.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()
class AccuracyTopK(torch.nn.Module):
    """Running top-k classification accuracy.

    For each ``k`` in ``topk``, counts how many samples, across all ``update``
    calls, had their target among the ``k`` highest-scoring predictions.
    ``compute`` reports these counts as percentages.

    Args:
        topk: The ``k`` values to track, e.g. ``(1, 5)``.
        device: Device on which the running counters are allocated. ``None``
            uses torch's default device.
    """

    def __init__(self, topk=(1, 5), device=None):
        super().__init__()
        self.eps = 1e-8
        self.device = device
        self.topk = topk
        self.maxk = max(topk)
        # FIXME handle distributed operation
        # statistics / counts
        self.reset()

    def update(self, predictions: torch.Tensor, target: torch.Tensor):
        """Accumulate top-k hit counts for one batch.

        Args:
            predictions: 2-D scores indexed ``(sample, class)``; the top-k is
                taken over dim 1.
            target: Class indices, one per sample. Any shape is accepted as
                long as it flattens to the number of samples.
        """
        # Never request more candidates than there are classes. Any k larger
        # than the class count then counts every sample as a hit, instead of
        # topk() raising "selected index k out of range".
        maxk = min(self.maxk, predictions.shape[1])
        sorted_indices = predictions.topk(maxk, dim=1)[1]
        # (sample, maxk) -> (maxk, sample): row i holds every sample's i-th best class
        sorted_indices.t_()
        correct = sorted_indices.eq(target.reshape(1, -1).expand_as(sorted_indices))
        batch_size = target.shape[0]
        for k in self.topk:
            attr = f'_correct_top{k}'
            # topk indices are distinct per sample, so each column has at most
            # one hit; summing the first k rows counts samples with a top-k hit.
            hits = correct[:k].reshape(-1).float().sum(0)
            setattr(self, attr, getattr(self, attr) + hits)
        self._total_sum += batch_size

    def reset(self):
        """Zero the per-k hit counters and the sample counter on ``self.device``."""
        for k in self.topk:
            setattr(self, f'_correct_top{k}', torch.tensor(0, dtype=torch.float32, device=self.device))
        self._total_sum = torch.tensor(0, dtype=torch.float32, device=self.device)

    @property
    def counts(self):
        """Placeholder; currently always returns ``None``."""
        return None

    def compute(self) -> Dict[str, torch.Tensor]:
        """Return ``{'top{k}': percentage}`` for each tracked ``k``.

        Values lie in [0, 100]. Before any ``update`` the sample count is zero,
        so every value is NaN (0 / 0).
        """
        # FIXME handle distributed reduction
        return {f'top{k}': 100 * getattr(self, f'_correct_top{k}') / self._total_sum for k in self.topk}
#
# class AccuracyTopK:
#
# def __init__(self, topk=(1, 5), device=None):
# self.eps = 1e-8
# self.device = device
# self.topk = topk
# self.maxk = max(topk)
#
# # statistics / counts
# self._correct_sum = None
# self._total_sum = None
#
# def _check_init(self, device):
# to_device = self.device if self.device else device
# if self._correct_sum is None:
# self._correct_sum = {f'top{k}': torch.tensor(0., device=to_device) for k in self.topk}
# if self._total_sum is None:
# self._total_sum = torch.tensor(0, dtype=torch.long, device=to_device)
#
# def update(self, predictions: torch.Tensor, target: torch.Tensor):
# sorted_indices = predictions.topk(self.maxk, dim=1)[1]
# sorted_indices.t_()
# correct = sorted_indices.eq(target.reshape(1, -1).expand_as(sorted_indices))
#
# batch_size = target.shape[0]
# correct_k = {f'top{k}': correct[:k].reshape(-1).float().sum(0) for k in self.topk}
# self._check_init(device=predictions.device)
# for k, v in correct_k.items():
# old_v = self._correct_sum[k]
# self._correct_sum[k] = old_v + v
# self._total_sum += batch_size
#
# def reset(self):
# self._correct_sum = None
# self._total_sum = None
#
# @property
# def counts(self):
# pass
#
# def compute(self) -> Dict[str, torch.Tensor]:
# assert self._correct_sum is not None and self._total_sum is not None
# return {k: 100 * v / self._total_sum for k, v in self._correct_sum.items()}