You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/timm/bits/metric_accuracy.py

99 lines
3.1 KiB

import torch
from typing import Optional, Tuple, Dict
from .device_env import DeviceEnv
from .metric import Metric, ValueInfo
class Accuracy(Metric):
def __init__(self, threshold=0.5, multi_label=False, dev_env=None):
super().__init__(dev_env=dev_env)
self.threshold = threshold
self.eps = 1e-8
self.multi_label = multi_label
# statistics / counts
self._register_value('correct')
self._register_value('total')
def _update(self, predictions, target):
raise NotImplemented()
def _compute(self):
raise NotImplemented()
# class AccuracyTopK(torch.nn.Module):
#
# def __init__(self, topk=(1, 5), device=None):
# super().__init__()
# self.eps = 1e-8
# self.device = device
# self.topk = topk
# self.maxk = max(topk)
# # FIXME handle distributed operation
#
# # statistics / counts
# self.reset()
#
# def update(self, predictions: torch.Tensor, target: torch.Tensor):
# sorted_indices = predictions.topk(self.maxk, dim=1)[1]
# sorted_indices.t_()
# correct = sorted_indices.eq(target.reshape(1, -1).expand_as(sorted_indices))
#
# batch_size = target.shape[0]
# correct_k = {k: correct[:k].reshape(-1).float().sum(0) for k in self.topk}
# for k, v in correct_k.items():
# attr = f'_correct_top{k}'
# old_v = getattr(self, attr)
# setattr(self, attr, old_v + v)
# self._total_sum += batch_size
#
# def reset(self):
# for k in self.topk:
# setattr(self, f'_correct_top{k}', torch.tensor(0, dtype=torch.float32))
# self._total_sum = torch.tensor(0, dtype=torch.float32)
#
# @property
# def counts(self):
# pass
#
# def compute(self) -> Dict[str, torch.Tensor]:
# # FIXME handle distributed reduction
# return {f'top{k}': 100 * getattr(self, f'_correct_top{k}') / self._total_sum for k in self.topk}
class AccuracyTopK(Metric):
def __init__(self, topk=(1, 5), dev_env: DeviceEnv = None):
super().__init__(dev_env=dev_env)
self.eps = 1e-8
self.topk = topk
self.maxk = max(topk)
# statistics / counts
for k in self.topk:
self._register_value(f'top{k}')
self._register_value('total')
self.reset()
def _update(self, predictions: torch.Tensor, target: torch.Tensor):
batch_size = predictions.shape[0]
sorted_indices = predictions.topk(self.maxk, dim=1)[1]
target_reshape = target.reshape(-1, 1).expand_as(sorted_indices)
correct = sorted_indices.eq(target_reshape).float().sum(0)
for k in self.topk:
attr_name = f'top{k}'
correct_at_k = correct[:k].sum()
setattr(self, attr_name, getattr(self, attr_name) + correct_at_k)
self.total += batch_size
def _compute(self) -> Dict[str, torch.Tensor]:
assert self.total is not None
output = {}
for k in self.topk:
attr_name = f'top{k}'
output[attr_name] = 100 * getattr(self, attr_name) / self.total
return output