You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
143 lines
5.0 KiB
143 lines
5.0 KiB
4 years ago
|
import abc
|
||
|
from typing import Callable, Union, Optional, List, Tuple, Dict
|
||
|
from dataclasses import dataclass
|
||
|
|
||
|
import torch
|
||
|
from torch.distributed import ReduceOp
|
||
|
|
||
|
from .device_env import DeviceEnv
|
||
|
from .device_env_factory import get_device
|
||
|
from .distributed import all_gather_sequence, all_reduce_sequence
|
||
|
|
||
|
# A metric state value: a python float, a tensor, or a list of either.
MetricValue = Union[float, torch.Tensor, List[float], List[torch.Tensor]]


@dataclass
class ValueInfo:
    """Describes one registered metric value.

    Attributes are consumed by ``Metric``:

    * ``initial``: value installed by ``Metric.reset()``. A tensor is cloned,
      any other non-None value is multiplied into a scalar tensor. ``None``
      leaves the value unset for the subclass to initialise.
    * ``dtype``: dtype the initial value is created with / cast to on reset.
    * ``dist_reduce``: ``'cat'`` selects a gather along dim 0 across
      processes; every other string is reduced with ``ReduceOp.SUM``.
    * ``dist_average``: when True the reduced value is divided by the
      world size.
    """

    initial: Optional[MetricValue] = 0.0
    dtype: torch.dtype = torch.float32
    dist_reduce: str = "sum"
    dist_average: bool = False
|
||
|
|
||
|
|
||
|
class Metric(abc.ABC):
    """Abstract base for metrics that accumulate named state values.

    Subclasses register each piece of state with ``_register_value`` and then
    read / write it like a normal attribute (e.g. ``self.total += x``).
    Attribute access for registered names is routed through ``__getattr__``
    and ``__setattr__`` into ``self._values``. While ``compute`` runs in a
    distributed environment, reads return the cross-process result stored in
    ``self._values_dist`` instead of the local value.

    Registered values read as ``None`` until ``reset()`` has been called.
    Subclasses must implement ``_compute`` and normally override ``_update``
    (plus ``_reset`` for values registered with ``initial=None``).
    """

    def __init__(self, dev_env: Optional['DeviceEnv'] = None):
        """Create empty metric state.

        Args:
            dev_env: Device environment used for tensor placement and
                distributed reduction. When None, ``get_device()`` is used.
        """
        self._infos: Dict[str, ValueInfo] = {}
        self._values: Dict[str, Optional[MetricValue]] = {}
        self._values_dist: Dict[str, Optional[MetricValue]] = {}
        if dev_env is None:
            dev_env = get_device()
        self._dev_env = dev_env

    def _register_value(self, name: str, info: Optional['ValueInfo'] = None):
        """Declare ``name`` as a managed metric value.

        Args:
            name: Attribute name under which the value is accessed.
            info: Initialisation / reduction settings; a default
                ``ValueInfo()`` is used when omitted (or falsy).
        """
        info = info or ValueInfo()
        self._infos[name] = info

    def __getattr__(self, item):
        """Resolve registered value names; called only when normal lookup fails.

        Returns the distributed value if one is currently present, otherwise
        the local value, otherwise None.

        Raises:
            AttributeError: if ``item`` is not a registered value name.
        """
        # Go through __dict__ directly: if '_infos' itself is missing (before
        # __init__ has run, or while copying / unpickling) a plain
        # ``self._infos`` would re-enter __getattr__ and recurse without end.
        state = self.__dict__
        infos = state.get('_infos')
        if infos is None or item not in infos:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{item}'")
        local_values = state.get('_values', {})
        dist_values = state.get('_values_dist', {})
        return dist_values.get(item, local_values.get(item, None))

    def __setattr__(self, key, value):
        """Store registered value names in ``_values``; everything else normally."""
        # The '_infos' membership test keeps assignments made during __init__
        # (before the registry exists) on the regular attribute path.
        if '_infos' in self.__dict__ and key in self._infos:
            self._values[key] = value
        else:
            super().__setattr__(key, value)

    def update(
            self,
            predictions: Union[torch.Tensor, Dict[str, torch.Tensor]],
            target: Union[torch.Tensor, Dict[str, torch.Tensor]]):
        """Accumulate one batch by delegating to ``_update``."""
        self._update(predictions, target)

    def _update(
            self,
            predictions: Union[torch.Tensor, Dict[str, torch.Tensor]],
            target: Union[torch.Tensor, Dict[str, torch.Tensor]]):
        """Subclass hook for accumulating a batch; the base version does nothing."""
        pass

    def reset(self):
        """Discard all local and distributed values and re-create initial state.

        Every registered value with a non-None ``initial`` is rebuilt as a
        tensor on the environment's device with the registered dtype; values
        with ``initial=None`` are set to None. ``_reset`` is called last so
        subclasses can finish initialisation.
        """
        self._values = {}
        self._values_dist = {}
        for name, info in self._infos.items():
            if info.initial is None:
                # no initial value specified, left to the child class (_reset)
                self._values[name] = None
                continue
            if isinstance(info.initial, torch.Tensor):
                # clone so the accumulator never aliases the registered initial
                tensor = info.initial.detach().clone()
            else:
                tensor = torch.ones([], dtype=info.dtype) * info.initial  # scalar
            self._values[name] = tensor.to(device=self._dev_env.device, dtype=info.dtype)
        self._reset()

    def _reset(self):
        """Subclass hook invoked at the end of ``reset``; the base version does nothing."""
        pass

    def compute(self) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[str, torch.Tensor]]:
        """Compute the metric result from the accumulated values.

        In a distributed environment the registered values are first reduced
        across processes; ``_compute`` then sees the reduced values through
        normal attribute access.

        Returns:
            Whatever the subclass ``_compute`` returns.
        """
        try:
            if self._dev_env.distributed:
                self._distribute_values()
            return self._compute()
        finally:
            # Always drop the distributed snapshot, even if reduction or
            # _compute raised, so later reads see the local values again.
            self._values_dist = {}

    @abc.abstractmethod
    def _compute(self) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[str, torch.Tensor]]:
        """Produce the metric result from the registered values (subclass-defined)."""
        pass

    def _distribute_values(self):
        """Reduce / gather every non-None value across processes into ``_values_dist``.

        Values registered with ``dist_reduce == 'cat'`` are gathered with
        ``cat_dim=0``; any other setting is reduced with ``ReduceOp.SUM``.
        Values with ``dist_average`` set are divided by the world size.
        """
        if not self._infos or not self._values:
            return

        def _args(op: str):
            # -> (use_gather, kwargs for the gather / reduce call)
            if op == 'cat':
                return True, dict(cat_dim=0)
            else:
                return False, dict(op=ReduceOp.SUM)

        prev_dsr = None
        same_dsr = True
        names = []
        values = []
        reductions = []
        for name, value in self._values.items():
            if value is None:
                continue
            info = self._infos[name]
            # dtype / shape / reduction signature, used to detect whether all
            # values could share a single batched collective call
            dsr = (value.dtype, value.shape, info.dist_reduce)
            if prev_dsr is not None and prev_dsr != dsr:
                same_dsr = False
            prev_dsr = dsr
            names.append(name)
            values.append(value)
            reductions.append(_args(info.dist_reduce))

        if not names:
            # nothing to reduce; also keeps reductions[0] below safe
            return

        # NOTE(review): the batched path is force-disabled here, so the
        # per-value branch below is always taken. Presumably a deliberate
        # workaround for the sequence helpers -- confirm before removing.
        same_dsr = False

        world_size = self._dev_env.world_size
        if same_dsr:
            do_gather, reduce_kwargs = reductions[0]
            if do_gather:
                reduced_values = all_gather_sequence(values, dev_env=self._dev_env, **reduce_kwargs)
            else:
                reduced_values = all_reduce_sequence(values, dev_env=self._dev_env, **reduce_kwargs)
            for name, reduced_value in zip(names, reduced_values):
                if self._infos[name].dist_average:
                    # Out-of-place: works for integer dtypes and never touches
                    # the local value if the helper handed back the same tensor.
                    reduced_value = reduced_value / world_size
                self._values_dist[name] = reduced_value
        else:
            for n, v, (do_gather, reduce_kwargs) in zip(names, values, reductions):
                if do_gather:
                    reduced_value = self._dev_env.all_gather(v, **reduce_kwargs)
                else:
                    reduced_value = self._dev_env.all_reduce(v, **reduce_kwargs)
                if self._infos[n].dist_average:
                    # Out-of-place for the same reasons as in the batched path.
                    reduced_value = reduced_value / world_size
                self._values_dist[n] = reduced_value
|