import abc
from typing import Union, Optional, List, Tuple, Dict
from dataclasses import dataclass

import torch
from torch.distributed import ReduceOp

from .device_env import DeviceEnv
from .distributed import all_gather_sequence, all_reduce_sequence

MetricValueT = Union[float, torch.Tensor, List[float], List[torch.Tensor]]


@dataclass
class ValueInfo:
    """Describes how a registered metric value is initialized and reduced across processes."""
    initial: Optional[MetricValueT] = 0.
    dtype: torch.dtype = torch.float32
    dist_reduce: str = 'sum'  # 'sum' -> all_reduce, 'cat' -> all_gather along dim 0
    dist_average: bool = False  # divide the reduced value by world_size


class Metric(abc.ABC):

    def __init__(self, dev_env: Optional[DeviceEnv] = None):
        self._infos: Dict[str, ValueInfo] = {}
        self._values: Dict[str, Optional[MetricValueT]] = {}
        self._values_dist: Dict[str, Optional[MetricValueT]] = {}
        if dev_env is None:
            dev_env = DeviceEnv.instance()
        self._dev_env = dev_env

    def _register_value(self, name: str, info: Optional[ValueInfo] = None):
        info = info or ValueInfo()
        self._infos[name] = info

    # def get_value(self, name: str, use_dist=True):
    #     if use_dist:
    #         return self._values_dist.get(name, self._values.get(name))
    #     else:
    #         return self._values.get(name)

    def __getattr__(self, item):
        # registered values are exposed as attributes, preferring the distributed
        # (reduced) value when one has been computed
        if item not in self._infos:
            raise AttributeError(item)
        value = self._values_dist.get(item, self._values.get(item, None))
        return value

    def __setattr__(self, key, value):
        # assignments to registered value names are routed into self._values
        if '_infos' in self.__dict__ and key in self._infos:
            self._values[key] = value
        else:
            super().__setattr__(key, value)

    def update(
            self,
            predictions: Union[torch.Tensor, Dict[str, torch.Tensor]],
            target: Union[torch.Tensor, Dict[str, torch.Tensor]],
    ):
        self._update(predictions, target)

    def _update(
            self,
            predictions: Union[torch.Tensor, Dict[str, torch.Tensor]],
            target: Union[torch.Tensor, Dict[str, torch.Tensor]],
    ):
        pass

    def reset(self):
        self._values = {}
        self._values_dist = {}
        for name, info in self._infos.items():
            # if info specifies an initial value, we reset here, otherwise set to None and leave it to child class
            if info.initial is not None:
                if isinstance(info.initial, torch.Tensor):
                    tensor = info.initial.detach().clone()
                else:
                    tensor = torch.ones([], dtype=info.dtype) * info.initial  # scalar
                self._values[name] = tensor.to(device=self._dev_env.device, dtype=info.dtype)
            else:
                self._values[name] = None
        self._reset()

    def _reset(self):
        pass

    def compute(self) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[str, torch.Tensor]]:
        if self._dev_env.distributed:
            self._distribute_values()
        results = self._compute()
        self._values_dist = {}  # distributed values are only valid for this compute() call
        return results

    @abc.abstractmethod
    def _compute(self) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[str, torch.Tensor]]:
        pass

    def _distribute_values(self):
        if not self._infos or not self._values:
            return

        def _args(op: str):
            # map a dist_reduce mode to (do_gather, kwargs for the collective)
            if op == 'cat':
                return True, dict(cat_dim=0)
            else:
                return False, dict(op=ReduceOp.SUM)

        prev_dsr = None
        same_dsr = True
        names = []
        values = []
        reductions = []
        for name, value in self._values.items():
            if value is not None:
                info = self._infos[name]
                dsr = (value.dtype, value.shape, info.dist_reduce)
                if prev_dsr is not None and prev_dsr != dsr:
                    same_dsr = False
                prev_dsr = dsr
                names.append(name)
                values.append(value)
                reductions.append(_args(info.dist_reduce))
        if not names:
            return  # all registered values are None, nothing to reduce

        if same_dsr:
            # fast path: all values share dtype/shape/reduction, reduce them as one sequence
            do_gather, reduce_kwargs = reductions[0]
            if do_gather:
                reduced_values = all_gather_sequence(values, dev_env=self._dev_env, **reduce_kwargs)
            else:
                reduced_values = all_reduce_sequence(values, dev_env=self._dev_env, **reduce_kwargs)
            for name, reduced_value in zip(names, reduced_values):
                info = self._infos[name]
                if info.dist_average:
                    reduced_value /= self._dev_env.world_size
                self._values_dist[name] = reduced_value
        else:
            # slow path: each value needs its own collective call
            for n, v, r in zip(names, values, reductions):
                info = self._infos[n]
                do_gather, reduce_kwargs = r
                if do_gather:
                    reduced_value = self._dev_env.all_gather(v, **reduce_kwargs)
                else:
                    reduced_value = self._dev_env.all_reduce(v, **reduce_kwargs)
                if info.dist_average:
                    reduced_value /= self._dev_env.world_size
                self._values_dist[n] = reduced_value
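

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original API): a minimal top-1
# accuracy metric built on the Metric/ValueInfo base above. The `Accuracy`
# class and its value names ('correct', 'total') are hypothetical; only the
# base class machinery they exercise is defined in this module.
# ---------------------------------------------------------------------------
class Accuracy(Metric):

    def __init__(self, dev_env: Optional[DeviceEnv] = None):
        super().__init__(dev_env=dev_env)
        # defaults (initial=0., dist_reduce='sum') mean both counters are
        # summed across processes before _compute() runs
        self._register_value('correct')
        self._register_value('total')
        self.reset()  # materialize the initial tensors for registered values

    def _update(self, predictions: torch.Tensor, target: torch.Tensor):
        # registered names behave like attributes via __getattr__/__setattr__
        pred_labels = predictions.argmax(dim=-1)
        self.correct = self.correct + (pred_labels == target).sum()
        self.total = self.total + target.numel()

    def _compute(self) -> torch.Tensor:
        return self.correct / self.total


# Typical flow, assuming a DeviceEnv has been initialized for this process:
#   metric = Accuracy()
#   metric.update(logits, labels)   # called once per batch
#   acc = metric.compute()          # reduces across processes if distributed
#   metric.reset()                  # clear counters before the next epoch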