diff --git a/timm/bits/__init__.py b/timm/bits/__init__.py index 33080c73..c9960341 100644 --- a/timm/bits/__init__.py +++ b/timm/bits/__init__.py @@ -1,10 +1,25 @@ +from .avg_scalar import AvgMinMaxScalar +from .avg_tensor import AvgTensor +from .device_env import DeviceEnv, DeviceEnvType +from .device_env_cuda import DeviceEnvCuda from .device_env_factory import initialize_device, get_device -from .device_env import DeviceEnv -#from .evaluate import evaluate, eval_step +from .device_env_xla import DeviceEnvXla +from .distributed import distribute_bn, all_gather_recursive, all_reduce_recursive, broadcast_recursive,\ + all_reduce_sequence, all_gather_sequence +# from .evaluate import evaluate, eval_step from .logger import Logger -#from .task import TaskClassify +from .metric import Metric, MetricValue +from .metric_accuracy import AccuracyTopK +from .tracker import Tracker +# from .task_metrics import TaskMetrics, TaskMetricsClassify +from .train_cfg import TrainCfg +from .train_services import TrainServices +from .train_setup import setup_model_and_optimizer +from .train_state import TrainState +# from .task import TaskClassify from .updater import Updater +from .updater_cuda import UpdaterCudaWithScaler +from .updater_deepspeed import UpdaterDeepSpeed from .updater_factory import create_updater -from .tracker import Tracker -#from .task_metrics import TaskMetrics, TaskMetricsClassify -#from .train import train_one_epoch, TrainServices, TrainState, TrainCfg, Experiment \ No newline at end of file +from .updater_xla import UpdaterXla, UpdaterXlaWithScaler +# from .train import train_one_epoch, Experiment diff --git a/timm/metrics/scalar_avg.py b/timm/bits/avg_scalar.py similarity index 96% rename from timm/metrics/scalar_avg.py rename to timm/bits/avg_scalar.py index f5d95807..04d41c8e 100644 --- a/timm/metrics/scalar_avg.py +++ b/timm/bits/avg_scalar.py @@ -1,4 +1,4 @@ -class ScalarAvgMinMax: +class AvgMinMaxScalar: """Computes and stores the average and current value""" def __init__(self): diff --git a/timm/metrics/tensor_avg.py b/timm/bits/avg_tensor.py similarity index 98% rename from timm/metrics/tensor_avg.py rename to timm/bits/avg_tensor.py index c9a3489b..0aaf92e3 100644 --- a/timm/metrics/tensor_avg.py +++ b/timm/bits/avg_tensor.py @@ -1,7 +1,7 @@ import torch -class TensorAvg: +class AvgTensor: """Computes and stores the average and current value""" def __init__(self): diff --git a/timm/bits/checkpoint.py b/timm/bits/checkpoint.py new file mode 100644 index 00000000..3c191b0a --- /dev/null +++ b/timm/bits/checkpoint.py @@ -0,0 +1,58 @@ +import logging +import os +from collections import OrderedDict + +import torch + +from .train_state import TrainState, serialize_train_state, deserialize_train_state + +_logger = logging.getLogger(__name__) + + +def resume_train_checkpoint( + train_state, + checkpoint_path, + resume_opt=True, + deserialize_fn=deserialize_train_state, + log_info=True): + + raise NotImplementedError + + # resume_epoch = None + # if os.path.isfile(checkpoint_path): + # checkpoint = torch.load(checkpoint_path, map_location='cpu') + # + # if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + # if log_info: + # _logger.info('Restoring model state from checkpoint...') + # new_state_dict = OrderedDict() + # for k, v in checkpoint['state_dict'].items(): + # name = k[7:] if k.startswith('module') else k + # new_state_dict[name] = v + # model.load_state_dict(new_state_dict) + # + # if optimizer is not None and 'optimizer' in checkpoint: + # if log_info: + # 
_logger.info('Restoring optimizer state from checkpoint...') + # optimizer.load_state_dict(checkpoint['optimizer']) + # + # if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint: + # if log_info: + # _logger.info('Restoring AMP loss scaler state from checkpoint...') + # loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key]) + # + # if 'epoch' in checkpoint: + # resume_epoch = checkpoint['epoch'] + # if 'version' in checkpoint and checkpoint['version'] > 1: + # resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save + # + # if log_info: + # _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) + # else: + # model.load_state_dict(checkpoint) + # if log_info: + # _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) + # return resume_epoch + # else: + # _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) + # raise FileNotFoundError() diff --git a/timm/bits/device_env.py b/timm/bits/device_env.py index 646d64f4..7307823e 100644 --- a/timm/bits/device_env.py +++ b/timm/bits/device_env.py @@ -1,58 +1,130 @@ -import torch import abc +from contextlib import suppress +from enum import Enum +from typing import Callable, Union, Optional, List, Tuple +from dataclasses import dataclass, field, InitVar +import torch +import torch.distributed as dist -class DeviceEnv(abc.ABC): +TensorList = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]] - @property - @abc.abstractmethod - def device(self) -> torch.device: - pass - @property - @abc.abstractmethod - def local_rank(self) -> int: - pass +class DeviceEnvType(Enum): + """ Device Environment Types + """ + CPU = "cpu" + CUDA = "cuda" + XLA = "xla" + + +@dataclass +class DeviceEnv: + device_type: InitVar[Optional[str]] = None + device_index: InitVar[Optional[int]] = None + + device: torch.device = field(init=False) # set from device_type + device_index or post_init logic + world_size: Optional[int] = None # set by post_init from env when None + local_rank: Optional[int] = None # set by post_init from env when None + global_rank: Optional[int] = None # set by post_init from env when None + amp: bool = False + autocast: Optional[Callable] = None # set by post_init from env when None + memory_format: Optional[torch.memory_format] = None + dtype: Optional[torch.dtype] = None + + def __post_init__(self, device_type: Optional[str], device_index: Optional[int]): + device_type = device_type or 'cpu' + self.device = torch.device(device_type) if device_index is None \ + else torch.device(device_type, device_index) + self.world_size = 1 if self.world_size is None else self.world_size + self.local_rank = 0 if self.local_rank is None else self.local_rank + self.global_rank = 0 if self.global_rank is None else self.global_rank + if self.autocast is None: + self.autocast = suppress @property - @abc.abstractmethod - def global_rank(self) -> int: - pass + def type(self) -> DeviceEnvType: + if self.device.type == 'cpu': + return DeviceEnvType.CPU + elif self.device.type == 'cuda': + return DeviceEnvType.CUDA + elif self.device.type == 'xla': + return DeviceEnvType.XLA + else: + assert False, "Unexpected device type for base DevEnv impl." 
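
For orientation, the dataclass above is meant to be consumed roughly as follows; a minimal sketch (the model is a stand-in, and initialize_device is the factory that appears further down in this patch):

    import torch
    import torch.nn as nn
    from timm.bits import initialize_device

    dev_env = initialize_device()                # XLA if present, else CUDA, else a plain CPU DeviceEnv
    model = dev_env.to_device(nn.Linear(8, 2))   # moves to dev_env.device, applying memory_format when set
    if dev_env.distributed:
        model = dev_env.wrap_distributed(model)  # DDP on CUDA, no-op wrapper on XLA
    x = torch.randn(4, 8, device=dev_env.device)
    with dev_env.autocast():                     # contextlib.suppress unless amp was requested
        y = model(x)
    if dev_env.primary:                          # local rank 0 only
        print(dev_env.type, dev_env.world_size, y.shape)
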
 
     @property
-    @abc.abstractmethod
-    def is_distributed(self) -> bool:
-        pass
+    def type_cuda(self):
+        # shortcut for common cuda device type
+        return self.type == DeviceEnvType.CUDA
 
     @property
-    @abc.abstractmethod
-    def world_size(self) -> int:
-        pass
+    def type_xla(self):
+        # shortcut for common xla device type
+        return self.type == DeviceEnvType.XLA
 
     @property
-    @abc.abstractmethod
-    def is_master(self) -> bool:
-        pass
+    def distributed(self):
+        return self.world_size > 1
 
     @property
-    @abc.abstractmethod
-    def type(self) -> str:
-        pass
+    def primary(self):
+        return self.local_rank == 0
 
     @property
-    @abc.abstractmethod
-    def autocast(self):
-        pass
+    def global_primary(self):
+        return self.global_rank == 0
 
-    @abc.abstractmethod
     def wrap_distributed(self, *modules):
         pass
 
-    @abc.abstractmethod
-    def to_device(self, *modules: torch.nn.Module):
+    def wrap_parallel(self, *modules):
         pass
 
-    #@abc.abstractmethod
+    def to_device(self, *modules: torch.nn.Module):
+        # FIXME handling dtype / memformat... disable flags, enable flags, diff fn?
+        moved = [m.to(device=self.device, memory_format=self.memory_format) for m in modules]
+        return moved[0] if len(moved) == 1 else moved
+
     def mark_step(self):
-        # FIXME this is for XLA only, make it common to all devices w/ appropriate no-ops?
-        pass
\ No newline at end of file
+        pass  # NO-OP for non-XLA devices
+
+    def all_reduce_(self, tensor: TensorList, op=dist.ReduceOp.SUM, average=False):
+        dist.all_reduce(tensor, op=op)
+        if average:
+            tensor.div_(self.world_size)
+        return tensor
+
+    def all_reduce(self, tensor: torch.Tensor, op=dist.ReduceOp.SUM, average=False):
+        reduce_tensor = tensor.clone()
+        dist.all_reduce(reduce_tensor, op=op)
+        if average:
+            reduce_tensor = reduce_tensor / self.world_size
+        return reduce_tensor
+
+    def all_gather(self, tensor: torch.Tensor, cat_dim=0):
+        output_tensors = [torch.empty_like(tensor) for _ in range(self.world_size)]
+        dist.all_gather(output_tensors, tensor)
+        return torch.cat(output_tensors, cat_dim)
+
+    def all_to_all(self, tensor: torch.Tensor, num_splits, split_dim, cat_dim=0):
+        input_tensors = torch.chunk(tensor, num_splits, split_dim)
+        output_tensors = [torch.empty_like(input_tensors[0]) for _ in range(self.world_size)]
+        dist.all_to_all(output_tensors, input_tensors)
+        return torch.cat(output_tensors, cat_dim)
+
+    def broadcast_(self, tensor: torch.Tensor, src_rank=0):
+        dist.broadcast(tensor, src=src_rank)
+        return tensor
+
+    def broadcast(self, tensor: Optional[torch.Tensor] = None, src_rank=0):
+        if self.global_rank != src_rank:
+            tensor = torch.empty_like(tensor)
+        assert tensor is not None
+        dist.broadcast(tensor, src=src_rank)
+        return tensor
+
+    def barrier(self):
+        dist.barrier()
diff --git a/timm/bits/device_env_cuda.py b/timm/bits/device_env_cuda.py
index d609bd2a..7358e405 100644
--- a/timm/bits/device_env_cuda.py
+++ b/timm/bits/device_env_cuda.py
@@ -1,92 +1,58 @@
 import os
 from contextlib import suppress
+from dataclasses import dataclass, field, InitVar
+from typing import Optional
 
 import torch
-from torch.nn.parallel import DistributedDataParallel
+from torch.nn.parallel import DistributedDataParallel, DataParallel
 
-from .device_env import DeviceEnv
+from .device_env import DeviceEnv, DeviceEnvType
 
 
 def is_cuda_available():
     return torch.cuda.is_available()
 
 
+@dataclass
 class DeviceEnvCuda(DeviceEnv):
 
-    def __init__(self, device_idx=None, local_rank=None, amp=False, memory_format=None):
+    def __post_init__(self, device_type: str, device_index:
Optional[int]): assert torch.cuda.device_count() torch.backends.cudnn.benchmark = True - self._local_rank = 0 - self._distributed = False - self._world_size = 1 - self._global_rank = 0 - if 'WORLD_SIZE' in os.environ: - self._distributed = int(os.environ['WORLD_SIZE']) > 1 - if self._distributed: - if local_rank is None: + setup_world_size = self.world_size or int(os.environ.get('WORLD_SIZE', 1)) + assert setup_world_size + if setup_world_size > 1: + # setup distributed + assert device_index is None + if self.local_rank is None: lr = os.environ.get('LOCAL_RANK', None) if lr is None: raise RuntimeError( 'At least one of LOCAL_RANK env variable or local_rank arg must be set to valid integer.') - self._local_rank = lr - else: - self._local_rank = int(local_rank) - self._device = torch.device('cuda:%d' % self._local_rank) - torch.cuda.set_device(self._local_rank) + self.local_rank = int(lr) + self.device = torch.device('cuda:%d' % self.local_rank) + torch.cuda.set_device(self.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') - self._world_size = torch.distributed.get_world_size() - self._global_rank = torch.distributed.get_rank() + self.world_size = torch.distributed.get_world_size() + assert self.world_size == setup_world_size + self.global_rank = torch.distributed.get_rank() else: - self._device = torch.device('cuda' if device_idx is None else f'cuda:{device_idx}') - self._memory_format = memory_format - if amp: - self._amp = amp - self._autocast = torch.cuda.amp.autocast - else: - self._amp = amp - self._autocast = suppress - - @property - def device(self): - return self._device - - @property - def local_rank(self): - return self._local_rank - - @property - def global_rank(self): - return self._global_rank - - @property - def is_distributed(self): - return self._distributed + self.device = torch.device('cuda' if device_index is None else f'cuda:{device_index}') + self.local_rank = 0 + self.world_size = 1 + self.global_rank = 0 + if self.autocast is None: + self.autocast = torch.cuda.amp.autocast if self.amp else suppress @property - def world_size(self): - return self._world_size - - @property - def is_master(self): - return self._local_rank == 0 - - @property - def type(self) -> str: - return 'cuda' - - @property - def amp(self) -> bool: - return self._amp - - @property - def autocast(self): - return self._autocast + def type(self) -> DeviceEnvType: + return DeviceEnvType.CUDA def wrap_distributed(self, *modules, **kwargs): - wrapped = [DistributedDataParallel(m, device_ids=[self._local_rank], **kwargs) for m in modules] + wrapped = [DistributedDataParallel(m, device_ids=[self.local_rank], **kwargs) for m in modules] return wrapped[0] if len(wrapped) == 1 else wrapped - def to_device(self, *modules: torch.nn.Module): - # FIXME handling dtype / memformat... disable flags, enable flags, diff fn? 
- moved = [m.to(device=self._device, memory_format=self._memory_format) for m in modules] - return moved[0] if len(moved) == 1 else moved + def wrap_parallel(self, *modules, **kwargs): + assert not self.distributed + wrapped = [DataParallel(m, **kwargs) for m in modules] + return wrapped[0] if len(wrapped) == 1 else wrapped diff --git a/timm/bits/device_env_factory.py b/timm/bits/device_env_factory.py index f6dc14f3..2037a39e 100644 --- a/timm/bits/device_env_factory.py +++ b/timm/bits/device_env_factory.py @@ -1,10 +1,11 @@ +from .device_env import DeviceEnv from .device_env_cuda import DeviceEnvCuda, is_cuda_available from .device_env_xla import DeviceEnvXla, is_xla_available _device_env = None -def initialize_device(force_cpu: bool = False, xla_device_type=None, **kwargs): +def initialize_device(force_cpu: bool = False, **kwargs) -> DeviceEnv: global _device_env if _device_env is not None: # warning @@ -12,21 +13,22 @@ def initialize_device(force_cpu: bool = False, xla_device_type=None, **kwargs): denv = None if not force_cpu: + xla_device_type = kwargs.get('xla_device_type', None) if is_xla_available(xla_device_type): - # XLA supports more than just TPU, but by default will only look at TPU - denv = DeviceEnvXla(**kwargs, xla_device_type=xla_device_type) + # XLA supports more than just TPU, will search in order TPU, GPU, CPU + denv = DeviceEnvXla(**kwargs) elif is_cuda_available(): denv = DeviceEnvCuda(**kwargs) if denv is None: - # FIXME implement CPU support - raise NotImplementedError() + denv = DeviceEnv() + print(denv) # FIXME DEBUG _device_env = denv return denv -def get_device(): +def get_device() -> DeviceEnv: if _device_env is None: raise RuntimeError('Please initialize device environment by calling initialize_device first.') return _device_env diff --git a/timm/bits/device_env_xla.py b/timm/bits/device_env_xla.py index 518cd993..cc9ea3dd 100644 --- a/timm/bits/device_env_xla.py +++ b/timm/bits/device_env_xla.py @@ -1,6 +1,10 @@ import os from contextlib import suppress +from dataclasses import dataclass, field, InitVar +from typing import Optional + import torch +from torch.distributed import ReduceOp try: import torch_xla.core.xla_model as xm @@ -15,78 +19,102 @@ try: except ImportError as e: xa = None -from .device_env import DeviceEnv +from .device_env import DeviceEnv, DeviceEnvType, TensorList + + +_PT_TO_XM_OP = { + ReduceOp.SUM: 'sum', + ReduceOp.PRODUCT: 'prod', + ReduceOp.MIN: 'min', + ReduceOp.MAX: 'max', + ReduceOp.BAND: 'and', + ReduceOp.BOR: 'or', +} def is_xla_available(xla_device_type=None): if not _HAS_XLA: return False supported_devs = xm.get_xla_supported_devices(devkind=xla_device_type) - print(supported_devs) return len(supported_devs) >= 1 +@dataclass class DeviceEnvXla(DeviceEnv): - def __init__(self, xla_device_type=None, device_idx=None, local_rank=0, amp=False): - self._device = xm.xla_device(n=device_idx, devkind=xla_device_type) - self._local_rank = xm.get_local_ordinal(local_rank) - self._world_size = xm.xrt_world_size() - self._distributed = self._world_size > 1 - self._global_rank = 0 - if self._distributed: - self._global_rank = xm.get_ordinal() - if amp: - assert xa is not None, 'XLA AMP is not present on this build' - self._autocast = xa.autocast + def __post_init__(self, device_type: Optional[str], device_idx: Optional[int]): + if device_type is not None: + device_type = device_type.upper() + assert device_type in ('TPU', 'GPU', 'CPU'), "XLA device type must be one of ('TPU', 'GPU', 'CPU')" + self.device = xm.xla_device(n=device_idx, 
devkind=device_type) + self.world_size = xm.xrt_world_size() + if self.distributed: + assert device_idx is None, "device_index is based on local rank for distributed XLA mode" + self.local_rank = xm.get_local_ordinal() + self.global_rank = xm.get_ordinal() else: - self._autocast = suppress - self._memory_format = None - - @property - def device(self): - return self._device - - @property - def local_rank(self): - return self._local_rank - - @property - def global_rank(self): - return self._global_rank - - @property - def is_distributed(self): - return self._distributed - - @property - def world_size(self): - return self._world_size - - @property - def is_master(self): - return self._global_rank == 0 - - @property - def type(self) -> str: - return 'xla' - - @property - def amp(self) -> bool: - return False + self.local_rank = 0 + self.global_rank = 0 + if self.amp: + assert xa is not None, 'XLA AMP is not present on this build' + if self.autocast is None: + self.autocast = xa.autocast if self.amp else suppress @property - def autocast(self): - return self._autocast + def type(self) -> DeviceEnvType: + return DeviceEnvType.XLA def wrap_distributed(self, *modules): - # NO-OP - wrapped = [m for m in modules] + wrapped = [m for m in modules] # NO-OP return wrapped[0] if len(wrapped) == 1 else wrapped - def to_device(self, *modules: torch.nn.Module): - moved = [m.to(device=self._device, memory_format=self._memory_format) for m in modules] - return moved[0] if len(moved) == 1 else moved + def wrap_parallel(self, *modules): + assert False, "Not implemented" def mark_step(self): xm.mark_step() + + def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, average=False): + assert isinstance(tensor, torch.Tensor) # unlike in-place variant, lists/tuples not allowed + op = _PT_TO_XM_OP[op] + scale = 1.0 + if average: + scale /= self.world_size + return xm.all_reduce(op, tensor, scale=scale) + + def all_reduce_(self, tensor: TensorList, op=ReduceOp.SUM, average=False): + op = _PT_TO_XM_OP[op] + scale = 1.0 + wrapped = False + if isinstance(tensor, torch.Tensor): + tensor = [tensor] # bare tensors are not operated on in-place + wrapped = True + if average: + scale /= self.world_size + xm.all_reduce(op, tensor, scale=scale) + if wrapped: + tensor = tensor[0] + return tensor + + def all_gather(self, tensor: torch.Tensor, cat_dim=0): + output = xm.all_gather(tensor, cat_dim) + return output + + def all_to_all(self, tensor, num_splits, split_dim, cat_dim=0): + output = xm.all_to_all(tensor, split_dim, cat_dim, num_splits) + return output + + def broadcast(self, tensor: torch.Tensor, src_rank=0): + if self.global_rank != src_rank: + reduce_tensor = torch.zeros_like(tensor) + xm.all_reduce('sum', reduce_tensor) + else: + xm.all_reduce('sum', tensor) + return tensor + + def broadcast_(self, tensor: torch.Tensor, src_rank=0): + out_tensor = self.broadcast(tensor, src_rank) + return tensor.copy_(out_tensor) + + def barrier(self): + xm.rendezvous('timm.bits.dist_barrier') diff --git a/timm/bits/distributed.py b/timm/bits/distributed.py new file mode 100644 index 00000000..55f9adf5 --- /dev/null +++ b/timm/bits/distributed.py @@ -0,0 +1,151 @@ +from typing import Dict, Tuple, List, Union, Any, Callable + +import torch +from torch.distributed import ReduceOp + +from timm.utils import unwrap_model + +from .device_env import DeviceEnv, DeviceEnvType +from .device_env_factory import get_device + + +TensorSeq = Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor], Dict[Any, torch.Tensor]] + + +def 
_validate_type(tensor: TensorSeq): + if isinstance(tensor, (dict, list, tuple)): + if not tensor: + return + else: + assert isinstance(tensor, torch.Tensor) + + +def distribute_bn(model: torch.nn.Module, reduce: bool = False, dev_env: DeviceEnv = None): + if dev_env is None: + dev_env = get_device() + # ensure every node has the same running bn stats + for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True): + if ('running_mean' in bn_name) or ('running_var' in bn_name): + if reduce: + # average bn stats across whole group + dev_env.all_reduce_(bn_buf, average=True) + else: + # broadcast bn stats from rank 0 to whole group + dev_env.broadcast_(bn_buf, 0) + + +def all_gather_recursive(tensor: TensorSeq, cat_dim=0, dev_env: DeviceEnv = None): + """ Recursive all gather via DeviceEnv distributed primitives + FIXME add group support + """ + _validate_type(tensor) + if dev_env is None: + dev_env = get_device() + if isinstance(tensor, torch.Tensor): + return dev_env.all_gather(tensor, cat_dim=cat_dim) + elif isinstance(tensor, dict): + return {k: all_gather_recursive(v, dev_env=dev_env) for k, v in tensor.items()} + elif isinstance(tensor, (tuple, list)): + return type(tensor)(all_gather_recursive(v, dev_env=dev_env) for v in tensor) + + +def all_reduce_recursive(tensor: TensorSeq, op=ReduceOp.SUM, average=False, dev_env: DeviceEnv = None): + """ Recursive all reduce via DeviceEnv distributed primitives + FIXME add group support + """ + _validate_type(tensor) + if dev_env is None: + dev_env = get_device() + if isinstance(tensor, torch.Tensor): + return dev_env.all_reduce_(tensor, op=op, average=average) + elif isinstance(tensor, dict): + return {k: all_reduce_recursive(v, op=op, average=average, dev_env=dev_env) for k, v in tensor.items()} + elif isinstance(tensor, (tuple, list)): + return type(tensor)(all_reduce_recursive(v, op=op, average=average, dev_env=dev_env) for v in tensor) + + +def broadcast_recursive(tensor: TensorSeq, src_rank: int, dev_env: DeviceEnv = None): + """ Recursive broadcast via DeviceEnv distributed primitives + FIXME add group support + """ + _validate_type(tensor) + if dev_env is None: + dev_env = get_device() + if isinstance(tensor, torch.Tensor): + return dev_env.broadcast_(tensor, src_rank=src_rank) + elif isinstance(tensor, dict): + return {k: broadcast_recursive(v, src_rank=src_rank, dev_env=dev_env) for k, v in tensor.items()} + elif isinstance(tensor, (tuple, list)): + return type(tensor)(broadcast_recursive(v, src_rank=src_rank, dev_env=dev_env) for v in tensor) + + +def all_gather_sequence(tensor: TensorSeq, cat_dim: int = 0, dev_env: DeviceEnv = None): + """ All gather a flat Tensor sequence (dict, list, tuple) of same shape + + """ + _validate_type(tensor) + if dev_env is None: + dev_env = get_device() + + with torch.no_grad(): + names = None + # merge values into one tensor for reduction + if isinstance(tensor, dict): + names = tensor.keys() + gather_values = tuple(tensor.values()) + elif isinstance(tensor, (tuple, list)): + gather_values = tensor + else: + gather_values = (tensor,) + + gather_values = torch.stack(gather_values, dim=0) + gather_values = dev_env.all_gather(gather_values, cat_dim=cat_dim + 1).unbind(dim=0) + + # separate reduced values into original structure + if isinstance(tensor, dict): + gather_values = {k: v for k, v in zip(names, gather_values)} + elif isinstance(tensor, (tuple, list)): + gather_values = type(tensor)(v for v in gather_values) + else: + gather_values = gather_values[0] + + return gather_values + + +def 
all_reduce_sequence(tensor: TensorSeq, op=ReduceOp.SUM, average=False, dev_env: DeviceEnv = None): + """ + All reduce the tensors in a flat Tensor sequence (dict, list, tuple) of same tensor shape + + Args: + tensor (dict): inputs to be reduced. All the values must be scalar Tensor. + average (bool): whether to do average or sum + Returns: + a sequence with the same type as input (dict, list, tuple) + """ + _validate_type(tensor) + if dev_env is None: + dev_env = get_device() + + with torch.no_grad(): + names = None + # merge values into one tensor for reduction + if isinstance(tensor, dict): + names = tensor.keys() + reduce_values = tuple(tensor.values()) + elif isinstance(tensor, (tuple, list)): + reduce_values = tensor + else: + reduce_values = (tensor,) + + reduce_values = torch.stack(reduce_values, dim=0) + dev_env.all_reduce_(reduce_values, op=op, average=average) + reduce_values = reduce_values.unbind(dim=0) + # separate reduced values into original structure + if isinstance(tensor, dict): + reduce_values = {k: v for k, v in zip(names, reduce_values)} + elif isinstance(tensor, (tuple, list)): + reduce_values = type(tensor)(v for v in reduce_values) + else: + reduce_values = reduce_values[0] + + return reduce_values \ No newline at end of file diff --git a/timm/bits/distributed_torch.py b/timm/bits/distributed_torch.py new file mode 100644 index 00000000..20f7036c --- /dev/null +++ b/timm/bits/distributed_torch.py @@ -0,0 +1,190 @@ +""" PyTorch distributed helpers + +Some of this lifted from Detectron2 with other fns added by myself. + +FIXME many functions remain unfinished/untested +""" +from typing import Dict, Tuple, List, Union, Any, Callable + +import torch +import torch.distributed as dist +from torch.distributed import ReduceOp + +TensorSeq = Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor], Dict[Any, torch.Tensor]] + + +def synchronize_torch(): + """ + Helper function to synchronize (barrier) among all processes when + using distributed training + """ + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + if world_size == 1: + return + dist.barrier() + + +def all_reduce_sequence_torch(values: TensorSeq, op=ReduceOp.SUM, average=False, group=None): + """ + All reduce the tensors in a sequence (dict, list, tuple) + + Args: + values (dict): inputs to be reduced. All the values must be scalar Tensor. + average (bool): whether to do average or sum + Returns: + a sequence with the same type as input (dict, list, tuple) + """ + world_size = dist.get_world_size(group) + if world_size <= 1: + return values + + with torch.no_grad(): + names = None + if isinstance(values, dict): + names = values.keys() + reduce_values = torch.stack(tuple(values.values()), dim=0) + elif isinstance(values, (tuple, list)): + reduce_values = torch.stack(values, dim=0) + else: + reduce_values = values + dist.all_reduce(reduce_values, op=op, group=group) + if average: + reduce_values /= world_size + if isinstance(values, dict): + reduce_values = {k: v for k, v in zip(names, reduce_values)} + elif isinstance(values, (tuple, list)): + reduce_values = type(values)(v for v in reduce_values) + return reduce_values + + +def reduce_sequence_torch(values: TensorSeq, dst_rank=0, op=ReduceOp.SUM, average=False, group=None): + """ + All reduce the tensors in a sequence (dict, list, tuple) + + Args: + values (dict): inputs to be reduced. All the values must be scalar Tensor. 
+ average (bool): whether to do average or sum + Returns: + a sequence with the same type as input (dict, list, tuple) + """ + world_size = dist.get_world_size(group) + this_rank = dist.get_rank() + if world_size <= 1: + return values + + with torch.no_grad(): + names = None + if isinstance(values, dict): + names = values.keys() + reduce_values = torch.stack(tuple(values.values()), dim=0) + elif isinstance(values, (tuple, list)): + reduce_values = torch.stack(values, dim=0) + else: + reduce_values = values + reduce_values = torch.stack(reduce_values, dim=0) + dist.reduce(reduce_values, dst=dst_rank, op=op, group=group) + if average and this_rank == dst_rank: + reduce_values /= world_size + if isinstance(values, dict): + reduce_values = {k: v for k, v in zip(names, reduce_values)} + elif isinstance(values, (tuple, list)): + reduce_values = type(values)(v for v in reduce_values) + return reduce_values + + +def all_gather_sequence_torch(values: TensorSeq, group=None, join_fn=torch.cat, join_dim=0): + world_size = dist.get_world_size(group) + + def _do_gather(tensor): + tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] + dist.all_gather(tensor_list, tensor, group=group) + return join_fn(tensor_list, dim=join_dim) + + if isinstance(values, dict): + gathered = {k: _do_gather(v) for k, v in values.items()} + return gathered + elif isinstance(values, (list, tuple)): + gathered = type(values)(_do_gather(v) for v in values) + return gathered + else: + # if not a dict, list, tuple, expect a singular tensor + assert isinstance(values, torch.Tensor) + return _do_gather(values) + + +def gather_sequence_torch(values: TensorSeq, dst_rank, group=None, join_fn=torch.cat, join_dim=0): + world_size = dist.get_world_size(group) + this_rank = dist.get_rank(group) + + def _do_gather(tensor): + tensor_list = None + if this_rank == dst_rank: + tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] + dist.gather(tensor, tensor_list, dst=dst_rank, group=group) + return join_fn(tensor_list, dim=join_dim) + + if isinstance(values, dict): + gathered = {k: _do_gather(v) for k, v in values.items()} + return gathered + elif isinstance(values, (list, tuple)): + gathered = type(values)(_do_gather(v) for v in values) + return gathered + else: + # if not a dict, list, tuple, expect a singular tensor + assert isinstance(values, torch.Tensor) + return _do_gather(values) + + +def all_gather_torch(value: TensorSeq, group=None, join_fn: Callable = None, join_dim=0): + if isinstance(value, torch.Tensor): + world_size = dist.get_world_size(group) + out_tensors = [torch.empty_like(value) for _ in range(world_size)] + dist.all_gather(out_tensors, value, group=group) + if join_fn is not None: + out_tensors = join_fn(out_tensors, dim=join_dim) + return out_tensors + elif isinstance(value, dict): + return {k: all_gather_torch(v, group, join_fn, join_dim) for k, v in value.items()} + elif isinstance(value, (tuple, list)): + return type(value)(all_gather_torch(v, group, join_fn, join_dim) for v in value) + + +def gather_torch(value: TensorSeq, dst_rank=0, group=None, join_fn: Callable = None, join_dim=0): + if isinstance(value, torch.Tensor): + world_size = dist.get_world_size(group) + this_rank = dist.get_rank() + out_tensors = None + if this_rank == dst_rank: + out_tensors = [torch.empty_like(value) for _ in range(world_size)] + dist.gather(value, out_tensors, dst=dst_rank, group=group) + if join_fn is not None: + out_tensors = join_fn(out_tensors, dim=join_dim) + return out_tensors + elif isinstance(value, 
dict): + return {k: gather_torch(v, dst_rank, group, join_fn, join_dim) for k, v in value.items()} + elif isinstance(value, (tuple, list)): + return type(value)(gather_torch(v, dst_rank, group, join_fn, join_dim) for v in value) + + +def all_reduce_torch(value: TensorSeq, op=ReduceOp.SUM, average=False, group=None): + if isinstance(value, torch.Tensor): + dist.all_reduce(value, op=op, group=group) + if average: + value /= dist.get_world_size(group) + elif isinstance(value, dict): + return {k: all_reduce_torch(v, op=op, average=average, group=group) for k, v in value.items()} + elif isinstance(value, (tuple, list)): + return type(value)(all_reduce_torch(v, op=op, average=average, group=group) for v in value) + + +def broadcast_torch(value: TensorSeq, src_rank: int = 0, group=None): + if isinstance(value, torch.Tensor): + return dist.broadcast(value, src=src_rank, group=group) + elif isinstance(value, dict): + return {k: broadcast_torch(v, src_rank=src_rank, group=group) for k, v in value.items()} + elif isinstance(value, (tuple, list)): + return type(value)(broadcast_torch(v, src_rank=src_rank, group=group) for v in value) \ No newline at end of file diff --git a/timm/bits/grad_clipper.py b/timm/bits/grad_clip.py similarity index 58% rename from timm/bits/grad_clipper.py rename to timm/bits/grad_clip.py index 232f5fc0..ba1d846c 100644 --- a/timm/bits/grad_clipper.py +++ b/timm/bits/grad_clip.py @@ -16,21 +16,11 @@ def get_clip_grad_fn(mode: str = 'norm', norm_type: float = 2.0): assert False, f"Unknown clip mode ({mode})." -def get_clip_parameters(model): +def get_clip_parameters(model, skip_last=0): if hasattr(model, 'get_clip_parameters'): return model.get_clip_parameters() else: - return model.parameters() - - -class GradClipper: - - def __init__(self, model, clip_value, clip_mode='norm'): - self.model = model - self.clip_fn = get_clip_grad_fn(clip_mode) - self.clip_value = clip_value - self.enabled = True - - def __call__(self): - if self.enabled: - self.clip_fn(get_clip_parameters(self.model), self.clip_value) \ No newline at end of file + if skip_last: + return list(model.parameters())[::-skip_last] + else: + return model.parameters() diff --git a/timm/bits/logger.py b/timm/bits/logger.py index d9ad41af..a7948a8b 100644 --- a/timm/bits/logger.py +++ b/timm/bits/logger.py @@ -21,6 +21,8 @@ except ImportError: HAS_WANDB = False +from .device_env_factory import get_device + # FIXME old formatting for reference, to remove # # def log_eval(batch_idx, last_idx, batch_time, loss, top1, top5, log_suffix=''): @@ -84,10 +86,16 @@ class SummaryCsv: dw.writerow(row_dict) +_sci_keys = {'lr'} + + def _add_kwargs(text_update, name_map=None, **kwargs): def _to_str(key, val): if isinstance(val, float): - return f'{key}: {val:.4f}' + if key.lower() in _sci_keys: + return f'{key}: {val:.3e} ' + else: + return f'{key}: {val:.4f}' else: return f'{key}: {val}' @@ -120,12 +128,13 @@ class Logger: self, experiment_name=None, output_dir=None, - logger=None, + python_logger=None, + hparams=None, log_wandb=False, - hparams=None): - - self.output_dir = output_dir # for tensorboard, csv, console logging to file? - self.logger = logger or logging.getLogger('log') + output_enabled=True, + ): + self.output_dir = output_dir # for tensorboard, csv, text file (TODO) logging + self.logger = python_logger or logging.getLogger('log') hparams = hparams or {} # Setup CSV writer(s) @@ -146,28 +155,32 @@ class Logger: _logger.warning("You've requested to log metrics to wandb but package not found. 
" "Metrics not being logged to wandb, try `pip install wandb`") + self.output_enabled = output_enabled # FIXME image save def log_step( self, phase: str, step: int, - end_step: Optional[int] = None, + step_end: Optional[int] = None, + epoch: Optional[int] = None, loss: Optional[float] = None, rate: Optional[float] = None, - epoch: Optional[int] = None, phase_suffix: str = '', **kwargs, ): """ log train/eval step """ - phase_title = f'{phase.capitalize()} ({phase_suffix})' if phase_suffix else f'{phase.capitalize()}' - progress = 100. * step / end_step if end_step else 0. + if not self.output_enabled: + return + + phase_title = f'{phase.capitalize()} ({phase_suffix})' if phase_suffix else f'{phase.capitalize()}:' + progress = 100. * step / step_end if step_end else 0. text_update = [ phase_title, - f'Epoch: {epoch}' if epoch is not None else None, - f'Step: {step}' if end_step is None else None, - f'Step: [{step}/{end_step} ({progress:>3.0f}%)]' if end_step is not None else None, + f'{epoch}' if epoch is not None else None, + f'[{step}]' if step_end is None else None, + f'[{step}/{step_end} ({progress:>3.0f}%)]' if step_end is not None else None, f'Rate: {rate:.2f}/s' if rate is not None else None, f'Loss: {loss:.5f}' if loss is not None else None, ] @@ -187,6 +200,9 @@ class Logger: ): """log completion of evaluation or training phase """ + if not self.output_enabled: + return + title = [ f'{phase.capitalize()}', f'epoch: {epoch}' if epoch is not None else None, @@ -212,6 +228,9 @@ class Logger: index: value for row index (typically epoch #) index_name: name for row index header (typically 'epoch') """ + if not self.output_enabled: + return + row_dict = summary_row_dict(index=index, index_name=index_name, results=results) if self.csv_writer: self.csv_writer.update(row_dict) diff --git a/timm/bits/metric.py b/timm/bits/metric.py new file mode 100644 index 00000000..7a5cc997 --- /dev/null +++ b/timm/bits/metric.py @@ -0,0 +1,142 @@ +import abc +from typing import Callable, Union, Optional, List, Tuple, Dict +from dataclasses import dataclass + +import torch +from torch.distributed import ReduceOp + +from .device_env import DeviceEnv +from .device_env_factory import get_device +from .distributed import all_gather_sequence, all_reduce_sequence + +MetricValue = Union[float, torch.Tensor, List[float], List[torch.Tensor]] + +@dataclass +class ValueInfo: + initial: Optional[MetricValue] = 0. 
+ dtype: torch.dtype = torch.float32 + dist_reduce: str = 'sum' + dist_average: bool = False + + +class Metric(abc.ABC): + + def __init__(self, dev_env: DeviceEnv = None): + self._infos: Dict[str, ValueInfo] = {} + self._values: Dict[str, Optional[MetricValue]] = {} + self._values_dist: Dict[str, Optional[MetricValue]] = {} + if dev_env is None: + dev_env = get_device() + self._dev_env = dev_env + + def _register_value(self, name: str, info: Optional[ValueInfo] = None): + info = info or ValueInfo() + self._infos[name] = info + + # def get_value(self, name: str, use_dist=True): + # if use_dist: + # return self._values_dist.get(name, self._values.get(name)) + # else: + # return self._values.get(name) + + def __getattr__(self, item): + if item not in self._infos: + raise AttributeError + value = self._values_dist.get(item, self._values.get(item, None)) + return value + + def __setattr__(self, key, value): + if '_infos' in self.__dict__ and key in self._infos: + self._values[key] = value + else: + super().__setattr__(key, value) + + def update( + self, + predictions: Union[torch.Tensor, Dict[str, torch.Tensor]], + target: Union[torch.Tensor, Dict[str, torch.Tensor]]): + self._update(predictions, target) + + def _update( + self, + predictions: Union[torch.Tensor, Dict[str, torch.Tensor]], + target: Union[torch.Tensor, Dict[str, torch.Tensor]]): + pass + + def reset(self): + self._values = {} + self._values_dist = {} + for name, info in self._infos.items(): + # if info specifies an initial value, we reset here, otherwise set to None and leave it to child class + if info.initial is not None: + if isinstance(info.initial, torch.Tensor): + tensor = info.initial.detach().clone() + else: + tensor = torch.ones([], dtype=info.dtype) * info.initial # scalar + self._values[name] = tensor.to(device=self._dev_env.device, dtype=info.dtype) + else: + self._values[name] = None + self._reset() + + def _reset(self): + pass + + def compute(self) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[str, torch.Tensor]]: + if self._dev_env.distributed: + self._distribute_values() + results = self._compute() + self._values_dist = {} + return results + + @abc.abstractmethod + def _compute(self) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[str, torch.Tensor]]: + pass + + def _distribute_values(self): + if not self._infos or not self._values: + return + + def _args(op: str): + if op == 'cat': + return True, dict(cat_dim=0) + else: + return False, dict(op=ReduceOp.SUM) + + prev_dsr = None + same_dsr = True + names = [] + values = [] + reductions = [] + for name, value in self._values.items(): + if value is not None: + info = self._infos[name] + dsr = (value.dtype, value.shape, info.dist_reduce) + if prev_dsr is not None and prev_dsr != dsr: + same_dsr = False + prev_dsr = dsr + names.append(name) + values.append(value) + reductions.append(_args(info.dist_reduce)) + same_dsr = False + if same_dsr: + do_gather, reduce_kwargs = reductions[0] + if do_gather: + reduced_values = all_gather_sequence(values, dev_env=self._dev_env, **reduce_kwargs) + else: + reduced_values = all_reduce_sequence(values, dev_env=self._dev_env, **reduce_kwargs) + for name, reduced_value in zip(names, reduced_values): + info = self._infos[name] + if info.dist_average: + reduced_value /= self._dev_env.world_size + self._values_dist[name] = reduced_value + else: + for n, v, r in zip(names, values, reductions): + info = self._infos[n] + do_gather, reduce_kwargs = r + if do_gather: + reduced_value = self._dev_env.all_gather(v, **reduce_kwargs) 
+ else: + reduced_value = self._dev_env.all_reduce(v, **reduce_kwargs) + if info.dist_average: + reduced_value /= self._dev_env.world_size + self._values_dist[n] = reduced_value diff --git a/timm/bits/metric_accuracy.py b/timm/bits/metric_accuracy.py new file mode 100644 index 00000000..0db72c6d --- /dev/null +++ b/timm/bits/metric_accuracy.py @@ -0,0 +1,98 @@ +import torch +from typing import Optional, Tuple, Dict + +from .device_env import DeviceEnv +from .metric import Metric, ValueInfo + + +class Accuracy(Metric): + + def __init__(self, threshold=0.5, multi_label=False, dev_env=None): + super().__init__(dev_env=dev_env) + self.threshold = threshold + self.eps = 1e-8 + self.multi_label = multi_label + + # statistics / counts + self._register_value('correct') + self._register_value('total') + + def _update(self, predictions, target): + raise NotImplemented() + + def _compute(self): + raise NotImplemented() + + +# class AccuracyTopK(torch.nn.Module): +# +# def __init__(self, topk=(1, 5), device=None): +# super().__init__() +# self.eps = 1e-8 +# self.device = device +# self.topk = topk +# self.maxk = max(topk) +# # FIXME handle distributed operation +# +# # statistics / counts +# self.reset() +# +# def update(self, predictions: torch.Tensor, target: torch.Tensor): +# sorted_indices = predictions.topk(self.maxk, dim=1)[1] +# sorted_indices.t_() +# correct = sorted_indices.eq(target.reshape(1, -1).expand_as(sorted_indices)) +# +# batch_size = target.shape[0] +# correct_k = {k: correct[:k].reshape(-1).float().sum(0) for k in self.topk} +# for k, v in correct_k.items(): +# attr = f'_correct_top{k}' +# old_v = getattr(self, attr) +# setattr(self, attr, old_v + v) +# self._total_sum += batch_size +# +# def reset(self): +# for k in self.topk: +# setattr(self, f'_correct_top{k}', torch.tensor(0, dtype=torch.float32)) +# self._total_sum = torch.tensor(0, dtype=torch.float32) +# +# @property +# def counts(self): +# pass +# +# def compute(self) -> Dict[str, torch.Tensor]: +# # FIXME handle distributed reduction +# return {f'top{k}': 100 * getattr(self, f'_correct_top{k}') / self._total_sum for k in self.topk} + + +class AccuracyTopK(Metric): + + def __init__(self, topk=(1, 5), dev_env: DeviceEnv = None): + super().__init__(dev_env=dev_env) + self.eps = 1e-8 + self.topk = topk + self.maxk = max(topk) + + # statistics / counts + for k in self.topk: + self._register_value(f'top{k}') + self._register_value('total') + self.reset() + + def _update(self, predictions: torch.Tensor, target: torch.Tensor): + batch_size = predictions.shape[0] + sorted_indices = predictions.topk(self.maxk, dim=1)[1] + target_reshape = target.reshape(-1, 1).expand_as(sorted_indices) + correct = sorted_indices.eq(target_reshape).float().sum(0) + for k in self.topk: + attr_name = f'top{k}' + correct_at_k = correct[:k].sum() + setattr(self, attr_name, getattr(self, attr_name) + correct_at_k) + self.total += batch_size + + def _compute(self) -> Dict[str, torch.Tensor]: + assert self.total is not None + output = {} + for k in self.topk: + attr_name = f'top{k}' + output[attr_name] = 100 * getattr(self, attr_name) / self.total + return output diff --git a/timm/metrics/precision_recall.py b/timm/bits/metric_precision_recall.py similarity index 100% rename from timm/metrics/precision_recall.py rename to timm/bits/metric_precision_recall.py diff --git a/timm/bits/tracker.py b/timm/bits/tracker.py index 12e0106b..7abbf95e 100644 --- a/timm/bits/tracker.py +++ b/timm/bits/tracker.py @@ -1,16 +1,16 @@ import time from typing import Optional 
-from timm.metrics import ScalarAvgMinMax +from .avg_scalar import AvgMinMaxScalar class Tracker: def __init__(self): - self.data_time = ScalarAvgMinMax() # time for data loader to produce batch of samples - self.step_time = ScalarAvgMinMax() # time for model step - self.iter_time = ScalarAvgMinMax() # full iteration time incl. data, step, and book-keeping - self.epoch_time = ScalarAvgMinMax() + self.data_time = AvgMinMaxScalar() # time for data loader to produce batch of samples + self.step_time = AvgMinMaxScalar() # time for model step + self.iter_time = AvgMinMaxScalar() # full iteration time incl. data, step, and book-keeping + self.epoch_time = AvgMinMaxScalar() self.iter_timestamp: Optional[float] = None self.prev_timestamp: Optional[float] = None @@ -48,3 +48,12 @@ class Tracker: self.epoch_time.update(epoch_time) self.epoch_timestamp = timestamp + def get_avg_iter_rate(self, num_per_iter: int): + if num_per_iter == 0 or self.iter_time.avg == 0: + return 0 + return num_per_iter / self.iter_time.avg + + def get_last_iter_rate(self, num_per_iter: int): + if num_per_iter == 0 or self.iter_time.val == 0: + return 0 + return num_per_iter / self.iter_time.val diff --git a/timm/bits/train_cfg.py b/timm/bits/train_cfg.py new file mode 100644 index 00000000..d7b35faf --- /dev/null +++ b/timm/bits/train_cfg.py @@ -0,0 +1,12 @@ +from dataclasses import dataclass + + +@dataclass +class TrainCfg: + """ Train Loop Configuration + Dataclass to propagate training configuration values + """ + num_epochs: int = 0 + log_interval: int = 50 + recovery_interval: int = 0 + accumulate_steps: int = 0 diff --git a/timm/bits/train_services.py b/timm/bits/train_services.py new file mode 100644 index 00000000..286a4afc --- /dev/null +++ b/timm/bits/train_services.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass + +from .logger import Logger +from timm.utils.checkpoint_saver import CheckpointSaver + + +@dataclass +class TrainServices: + """ Train Loop Services + """ + logger: Logger = None + saver: CheckpointSaver = None + diff --git a/timm/bits/train_setup.py b/timm/bits/train_setup.py new file mode 100644 index 00000000..992546a7 --- /dev/null +++ b/timm/bits/train_setup.py @@ -0,0 +1,153 @@ +import dataclasses +from typing import Callable, Union, Optional +import logging + +import torch +import torch.nn as nn + +from timm.optim import create_optimizer_v2 +from timm.utils import ModelEmaV2 + +try: + import deepspeed as ds +except ImportError: + ds = None + +from .checkpoint import resume_train_checkpoint +from .device_env import DeviceEnv +from .train_cfg import TrainCfg +from .train_state import TrainState +from .updater_factory import create_updater + + +_logger = logging.getLogger(__name__) + + +def setup_model_and_optimizer( + dev_env: DeviceEnv, + model: nn.Module, + optimizer: Union[Callable, str], + optimizer_cfg, + clip_fn: Optional[Union[Callable, str]] = None, + clip_value: Optional[float] = None, + model_ema: bool = False, + model_ema_decay: float = 0.9999, + use_syncbn: bool = False, + resume_path: str = '', + resume_opt: bool = True, + deepspeed: bool = False, +): + """ + + Args: + dev_env: + model: + optimizer: + optimizer_cfg: + clip_value: + clip_fn: + model_ema: + model_ema_decay: + use_syncbn: + resume_path: + resume_opt: + + Returns: + + """ + if deepspeed: + return setup_model_and_optimizer_deepspeed( + dev_env=dev_env, model=model, optimizer=optimizer, optimizer_cfg=optimizer_cfg, + clip_fn=clip_fn, clip_value=clip_value, model_ema=model_ema, model_ema_decay=model_ema_decay, + 
resume_path=resume_path, resume_opt=resume_opt, + ) + + dev_env.to_device(model) + + if use_syncbn and dev_env.distributed: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + if dev_env.primary: + _logger.info( + 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' + 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') + + if isinstance(optimizer, Callable): + optimizer = optimizer(model=model, **optimizer_cfg) + else: + optimizer = create_optimizer_v2(model=model, **optimizer_cfg) + + updater = create_updater( + model=model, + optimizer=optimizer, + clip_fn=clip_fn, + clip_value=clip_value, + ) + + # ema model + model_ema = ModelEmaV2(model, decay=model_ema_decay) if model_ema else None + + train_state = TrainState(model=model, updater=updater, model_ema=model_ema) + + if resume_path: + resume_train_checkpoint( + train_state, + resume_path, + resume_opt=resume_opt, + log_info=dev_env.primary) + + if dev_env.distributed: + train_state = dataclasses.replace( + train_state, model=dev_env.wrap_distributed(train_state.model)) + + return train_state + + +def setup_model_and_optimizer_deepspeed( + dev_env: DeviceEnv, + model: nn.Module, + optimizer: Union[Callable, str], + optimizer_cfg, + clip_fn: Optional[Union[Callable, str]] = None, + clip_value: Optional[float] = None, + model_ema: bool = False, + model_ema_decay: float = 0.9999, + use_syncbn: bool = False, + resume_path: str = '', + resume_opt: bool = True, +): + dev_env.to_device(model) + + if isinstance(optimizer, Callable): + optimizer = optimizer(model=model, **optimizer_cfg) + else: + optimizer = create_optimizer_v2(model=model, **optimizer_cfg) + + model = ds.initialize(model=model, optimizer=optimizer, dist_init_required=False) + + updater = create_updater( + model=model, + optimizer=optimizer, + clip_fn=clip_fn, + clip_value=clip_value, + deepspeed=True, + ) + + # ema model + # FIXME how to do EMA w/ deepspeed? 
+ model_ema = ModelEmaV2(model, decay=model_ema_decay) if model_ema else None + + train_state = TrainState(model=model, updater=updater, model_ema=model_ema) + + if resume_path: + # FIXME deepspeed resumes differently + resume_train_checkpoint( + train_state, + resume_path, + resume_opt=resume_opt, + log_info=dev_env.primary) + + if dev_env.distributed: + train_state = dataclasses.replace( + train_state, model=dev_env.wrap_distributed(train_state.model)) + + return train_state diff --git a/timm/bits/train_state.py b/timm/bits/train_state.py new file mode 100644 index 00000000..9a9a0d92 --- /dev/null +++ b/timm/bits/train_state.py @@ -0,0 +1,33 @@ +from dataclasses import dataclass +from typing import Dict, Any + +from torch import nn as nn + +from timm.scheduler import Scheduler +from .updater import Updater + + +@dataclass +class TrainState: + model: nn.Module = None + train_loss: nn.Module = None + eval_loss: nn.Module = None + updater: Updater = None + lr_scheduler: Scheduler = None + model_ema: nn.Module = None + + step_count_epoch: int = 0 + step_count_global: int = 0 + epoch: int = 0 + + def __post_init__(self): + assert self.model is not None + assert self.updater is not None + + +def serialize_train_state(train_state: TrainState): + pass + + +def deserialize_train_state(checkpoint: Dict[str, Any]): + pass \ No newline at end of file diff --git a/timm/bits/updater.py b/timm/bits/updater.py index 6612c8ea..422d12ec 100644 --- a/timm/bits/updater.py +++ b/timm/bits/updater.py @@ -1,54 +1,68 @@ +from dataclasses import dataclass, field, InitVar +from functools import partial from typing import Callable, Optional, Union import torch +import torch.nn as nn -from .grad_clipper import GradClipper +from .grad_clip import get_clip_grad_fn, get_clip_parameters +@dataclass class Updater: - def __init__( - self, - optimizer: torch.optim.Optimizer, - clip_value: Optional[Union[Callable, float]] = None, - clip_mode: str = 'norm'): - - self.optimizer = optimizer - self.clipper: Optional[GradClipper] = None - if clip_value is not None: - if isinstance(clip_value, Callable): - self.clipper = clip_value + model: nn.Module = None + optimizer: torch.optim.Optimizer = None # FIXME handle multiple optimizers per-model + clip_fn: Optional[Union[Callable, str]] = None + clip_value: Optional[float] = None + clip_params_fn: Optional[Callable] = None + grad_scaler: Optional[Callable] = None + create_graph: Optional[bool] = None + after_step_closure: bool = False + + def __post_init__(self): + assert self.model is not None + assert self.optimizer is not None + if self.clip_fn is not None: + if isinstance(self.clip_fn, Callable): + skip_last = 0 else: - GradClipper(clip_value, clip_mode) - self.scaler = None - self.create_graph = getattr(self.optimizer, 'second_order', False) - self.num_accumulated = 0 + assert isinstance(self.clip_fn, str) + skip_last = 2 if 'agc' in self.clip_fn else 0 + self.clip_fn = get_clip_grad_fn(self.clip_fn) + assert self.clip_value is not None + self.clip_params_fn = partial(get_clip_parameters, model=self.model, skip_last=skip_last) + if self.create_graph is None: + self.create_graph = getattr(self.optimizer, 'second_order', False) self.after_step_closure = False + def reset(self): + self.optimizer.zero_grad() + def apply(self, loss: torch.Tensor, accumulate=False): loss.backward(create_graph=self.create_graph) - if self.clipper is not None: - self.clipper() - if not accumulate: - self.optimizer.step() - self.reset() - else: - self.num_accumulated += 1 + if accumulate: + return + if 
self.clip_fn is not None: + self.clip_fn(self.clip_params_fn(), self.clip_value) + self.optimizer.step() + self.reset() - def reset(self): - self.optimizer.zero_grad() - self.num_accumulated = 0 + def get_average_lr(self): + lrl = [param_group['lr'] for param_group in self.optimizer.param_groups if param_group['lr'] > 0] + return sum(lrl) / len(lrl) def state_dict(self): state_dict = dict(optimizer=self.optimizer.state_dict()) - if self.scaler is not None: - state_dict['scaler'] = self.scaler.state_dict() + if self.grad_scaler is not None: + state_dict['grad_scaler'] = self.grad_scaler.state_dict() def load_state_dict(self, state_dict): if 'optimizer' in state_dict: self.optimizer.load_state_dict(state_dict['optimizer']) - if 'scaler' in state_dict and self.scaler is not None: - self.scaler.load_state_dict(state_dict['scaler']) - + if 'grad_scaler' in state_dict and self.grad_scaler is not None: + self.grad_scaler.load_state_dict(state_dict['grad_scaler']) + def after_step(self, after_step_fn, *args): + after_step_fn(*args) diff --git a/timm/bits/updater_cuda.py b/timm/bits/updater_cuda.py index 799aef00..33f984db 100644 --- a/timm/bits/updater_cuda.py +++ b/timm/bits/updater_cuda.py @@ -1,36 +1,30 @@ -from typing import Callable, Optional, Union, Any +from dataclasses import dataclass, field, InitVar +from typing import Dict, Any import torch from .updater import Updater -class UpdaterCuda(Updater): - def __init__( - self, - optimizer: torch.optim.Optimizer, - clip_value: Optional[Union[Callable, float]] = None, - clip_mode: str = 'norm', - use_scaler: bool = False, - scaler_kwargs: Any = None, - ): - super().__init__(optimizer=optimizer, clip_value=clip_value, clip_mode=clip_mode) +@dataclass +class UpdaterCudaWithScaler(Updater): + + scaler_kwargs: InitVar[Dict[str, Any]] = None + + def __post_init__(self, scaler_kwargs: Dict[str, Any]): + super().__post_init__() scaler_kwargs = scaler_kwargs or {} - if use_scaler: - self.scaler = torch.cuda.amp.GradScaler(**scaler_kwargs) + self.grad_scaler = torch.cuda.amp.GradScaler(**scaler_kwargs) def apply(self, loss: torch.Tensor, accumulate=False): - if self.scaler is not None: - self.scaler.scale(loss).backward(create_graph=self.create_graph) - if self.clipper is not None: - self.scaler.unscale_(self.optimizer) # unscale the gradients of optimizer's assigned params in-place - self.clipper() - if not accumulate: - self.scaler.step(self.optimizer) - self.reset() - else: - self.num_accumulated += 1 - self.scaler.update() - else: - Updater.apply(self, loss, accumulate) - + self.grad_scaler.scale(loss).backward(create_graph=self.create_graph) + if accumulate: + # unscale first? 
+ return + if self.clip_fn is not None: + # unscale the gradients of optimizer's assigned params in-place + self.grad_scaler.unscale_(self.optimizer) + self.clip_fn(self.clip_params_fn(), self.clip_value) + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() + self.reset() diff --git a/timm/bits/updater_deepspeed.py b/timm/bits/updater_deepspeed.py new file mode 100644 index 00000000..e080a7de --- /dev/null +++ b/timm/bits/updater_deepspeed.py @@ -0,0 +1,26 @@ +from dataclasses import dataclass, field, InitVar + +import torch +try: + import deepspeed as ds +except ImportError as e: + ds = None + +from .updater import Updater + + +@dataclass +class UpdaterDeepSpeed(Updater): + + def __post_init__(self): + super().__post_init__() + # FIXME not sure how to deal with model.module / grad clipping w/ DS engine interface + assert isinstance(self.model, ds.DeepSpeedEngine) + + def reset(self): + self.model.zero_grad() + + def apply(self, loss: torch.Tensor, accumulate=False): + self.model.backward(loss) + self.model.step() + self.reset() diff --git a/timm/bits/updater_factory.py b/timm/bits/updater_factory.py index aba008d2..24ef76c0 100644 --- a/timm/bits/updater_factory.py +++ b/timm/bits/updater_factory.py @@ -2,29 +2,38 @@ from typing import Callable, Optional, Union, Any import torch -from .device_env import DeviceEnv +from .device_env import DeviceEnv, DeviceEnvType from .device_env_factory import get_device from .updater import Updater -from .updater_cuda import UpdaterCuda -from .updater_xla import UpdaterXla +from .updater_cuda import UpdaterCudaWithScaler +from .updater_deepspeed import UpdaterDeepSpeed +from .updater_xla import UpdaterXla, UpdaterXlaWithScaler def create_updater( + model: torch.nn.Module, optimizer: torch.optim.Optimizer, + clip_fn: Optional[Union[Callable, str]] = None, + clip_value: Optional[float] = None, + scaler_kwargs: Any = None, dev_env: Optional[DeviceEnv] = None, - clip_value: Optional[Union[Callable, float]] = None, - clip_mode: str = 'norm', - scaler_kwargs: Any = None) -> Updater: + deepspeed: bool = False, +) -> Updater: if not dev_env: dev_env = get_device() - updater_kwargs = dict( - optimizer=optimizer, clip_value=clip_value, clip_mode=clip_mode, scaler_kwargs=scaler_kwargs) - if dev_env.type == 'xla': - return UpdaterXla(**updater_kwargs, use_scaler=dev_env.amp) - elif dev_env.type == 'cuda': - return UpdaterCuda(**updater_kwargs, use_scaler=dev_env.amp) - else: - updater_kwargs.pop('scaler_kwargs', None) - return Updater(**updater_kwargs) + updater_kwargs = dict(model=model, optimizer=optimizer, clip_fn=clip_fn, clip_value=clip_value) + use_scaler = dev_env.amp + if use_scaler: + updater_kwargs['scaler_kwargs'] = scaler_kwargs + updater_cls = Updater + if dev_env.type == DeviceEnvType.XLA: + updater_cls = UpdaterXlaWithScaler if use_scaler else UpdaterXla + elif dev_env.type == DeviceEnvType.CUDA and use_scaler: + updater_cls = UpdaterCudaWithScaler + elif deepspeed: + del updater_kwargs['scaler_kwargs'] + updater_cls = UpdaterDeepSpeed + + return updater_cls(**updater_kwargs) diff --git a/timm/bits/updater_xla.py b/timm/bits/updater_xla.py index 25287ad9..935e1994 100644 --- a/timm/bits/updater_xla.py +++ b/timm/bits/updater_xla.py @@ -1,6 +1,8 @@ -from typing import Callable, Optional, Union, Any +from dataclasses import dataclass, field, InitVar +from typing import Any, Dict import torch +import torch.nn as nn try: import torch_xla.core.xla_model as xm @@ -18,41 +20,49 @@ except ImportError as e: from .updater import Updater 
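
The factory above resolves an Updater variant from the active device environment (XLA, CUDA AMP with a scaler, DeepSpeed, or the plain base class). A rough usage sketch with a throwaway model, using the string clip modes handled by grad_clip.py:

    import torch
    import torch.nn as nn
    from timm.bits import initialize_device, create_updater

    dev_env = initialize_device()
    model = dev_env.to_device(nn.Linear(8, 2))
    updater = create_updater(
        model=model,
        optimizer=torch.optim.SGD(model.parameters(), lr=0.1),
        clip_fn='norm',   # strings resolve via get_clip_grad_fn; 'agc' also drops the last params from clipping
        clip_value=1.0,
    )
    for i, x in enumerate(torch.randn(4, 4, 8)):
        loss = model(x.to(dev_env.device)).mean()
        # accumulate gradients on even steps, then clip + optimizer.step() + zero_grad on odd steps
        updater.apply(loss, accumulate=(i % 2 == 0))
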
+@dataclass class UpdaterXla(Updater): - def __init__( - self, - optimizer: torch.optim.Optimizer, - clip_value: Optional[Union[Callable, float]] = None, - clip_mode: str = 'norm', - use_scaler: bool = False, - scaler_kwargs: Any = None, - ): - super().__init__(optimizer=optimizer, clip_value=clip_value, clip_mode=clip_mode) + def __post_init__(self): + super().__post_init__() self.after_step_closure = True - if use_scaler: - assert xa is not None, 'XLA AMP not present in this build' - self.scaler = xa.GradScaler(**scaler_kwargs) def apply(self, loss: torch.Tensor, accumulate: bool = False): - if self.scaler is None: - loss.backward(create_graph=self.create_graph) - gradients = xm._fetch_gradients(self.optimizer) - xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size()) - if self.clipper is not None: - self.clipper() - if not accumulate: - xm.optimizer_step(self.optimizer) - else: - self.scaler.scale(loss).backward(create_graph=self.create_graph) - if self.clipper is not None: - self.scaler.unscale_(self.optimizer) # unscale the gradients of optimizer's assigned params in-place - self.clipper() - if not accumulate: - self.scaler.step(self.optimizer) - self.reset() - self.scaler.update() + loss.backward(create_graph=self.create_graph) + if accumulate: + return + xm.reduce_gradients(self.optimizer) + if self.clip_fn is not None: + self.clip_fn(self.clip_params_fn(), self.clip_value) + self.optimizer.step() + xm.mark_step() + self.reset() def after_step(self, after_step_fn, *args): - xm.add_step_closure(after_step_fn, *args) + xm.add_step_closure(after_step_fn, args) + +@dataclass +class UpdaterXlaWithScaler(UpdaterXla): + + scaler_kwargs: InitVar[Dict[str, Any]] = None + + def __post_init__(self, scaler_kwargs: Dict[str, Any]): + super().__post_init__() + scaler_kwargs = scaler_kwargs or {} + assert xa is not None, 'XLA AMP not present in this build' + self.scaler = xa.GradScaler(**scaler_kwargs) + + def apply(self, loss: torch.Tensor, accumulate: bool = False): + self.scaler.scale(loss).backward(create_graph=self.create_graph) + if accumulate: + # unscale first? 
+ return + xm.reduce_gradients(self.optimizer) + if self.clip_fn is not None: + self.scaler.unscale_(self.optimizer) # unscale the gradients of optimizer's assigned params in-place + self.clip_fn(self.clip_params_fn(), self.clip_value) + self.scaler.step(self.optimizer) + self.scaler.update() + xm.mark_step() + self.reset() diff --git a/timm/data/fetcher.py b/timm/data/fetcher.py index 1cbc3fe5..ec5afe8a 100644 --- a/timm/data/fetcher.py +++ b/timm/data/fetcher.py @@ -23,24 +23,20 @@ class Fetcher: re_count=1, re_num_splits=0): self.loader = loader - self.mean = torch.tensor([x * 255 for x in mean]).view(1, 3, 1, 1) - self.std = torch.tensor([x * 255 for x in std]).view(1, 3, 1, 1) self.device = torch.device(device) self.dtype = dtype or torch.float32 - if device: - self.mean.to(device=device, dtype=self.dtype) - self.std.to(device=device, dtype=self.dtype) + self.mean = torch.tensor([x * 255 for x in mean], dtype=self.dtype, device=self.device).view(1, 3, 1, 1) + self.std = torch.tensor([x * 255 for x in std], dtype=self.dtype, device=self.device).view(1, 3, 1, 1) if re_prob > 0.: self.random_erasing = RandomErasing( - probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits) + probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device=device) else: self.random_erasing = None def __iter__(self): for sample, target in self.loader: - sample = sample.to(device=self.device) + sample = sample.to(device=self.device, dtype=self.dtype).sub_(self.mean).div_(self.std) target = target.to(device=self.device) - sample = sample.to(dtype=self.dtype).sub_(self.mean).div_(self.std) if self.random_erasing is not None: sample = self.random_erasing(sample) yield sample, target diff --git a/timm/data/loader.py b/timm/data/loader.py index 45d40908..5ddcc6d2 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -8,7 +8,7 @@ Hacked together by / Copyright 2020 Ross Wightman import torch.utils.data -from timm.bits import get_device +from timm.bits import get_device, DeviceEnvType from .fetcher import Fetcher from .prefetcher_cuda import PrefetcherCuda @@ -78,7 +78,7 @@ def create_loader( dev_env = get_device() sampler = None - if dev_env.is_distributed and not isinstance(dataset, torch.utils.data.IterableDataset): + if dev_env.distributed and not isinstance(dataset, torch.utils.data.IterableDataset): if is_training: sampler = torch.utils.data.distributed.DistributedSampler( dataset, num_replicas=dev_env.world_size, rank=dev_env.global_rank) @@ -117,7 +117,7 @@ def create_loader( re_count=re_count, re_num_splits=re_num_splits ) - if dev_env.type == 'cuda': + if dev_env.type_cuda: loader = PrefetcherCuda(loader, **fetcher_kwargs) else: loader = Fetcher(loader, device=dev_env.device, **fetcher_kwargs) diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/parsers/parser_tfds.py index 92495d12..519be03d 100644 --- a/timm/data/parsers/parser_tfds.py +++ b/timm/data/parsers/parser_tfds.py @@ -82,7 +82,7 @@ class ParserTfds(Parser): self.dist_num_replicas = 1 dev_env = get_device() # FIXME allow to work without devenv usage? 
- if dev_env.is_distributed and dev_env.world_size > 1: + if dev_env.distributed and dev_env.world_size > 1: self.dist_rank = dev_env.global_rank self.dist_num_replicas = dev_env.world_size # if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1: @@ -150,8 +150,10 @@ class ParserTfds(Parser): ds = self.builder.as_dataset( split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config) # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers - ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers) - ds.options().experimental_threading.max_intra_op_parallelism = 1 + options = tf.data.Options() + options.experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers) + options.experimental_threading.max_intra_op_parallelism = 1 + ds = ds.with_options(options) if self.is_training or self.repeats > 1: # to prevent excessive drop_last batch behaviour w/ IterableDatasets # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading diff --git a/timm/metrics/__init__.py b/timm/metrics/__init__.py deleted file mode 100644 index 93a2773e..00000000 --- a/timm/metrics/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .accuracy import Accuracy, AccuracyTopK -from .precision_recall import PrecisionRecall -from .scalar_avg import ScalarAvgMinMax -from .tensor_avg import TensorAvg, TensorEma diff --git a/timm/metrics/accuracy.py b/timm/metrics/accuracy.py deleted file mode 100644 index b58a3781..00000000 --- a/timm/metrics/accuracy.py +++ /dev/null @@ -1,114 +0,0 @@ -import torch -from typing import Optional, Tuple, Dict - - -class Accuracy(torch.nn.Module): - - def __init__(self, threshold=0.5, multi_label=False): - self.threshold = threshold - self.eps = 1e-8 - self.multi_label = multi_label - - # statistics / counts - self._correct_sum = torch.tensor(0, dtype=torch.long) - self._total_sum = torch.tensor(0, dtype=torch.long) - - def update(self, predictions, target): - raise NotImplemented() - - def reset(self): - self._correct_sum = 0 - self._total_sum = 0 - - @property - def counts(self): - pass - - def compute(self): - raise NotImplemented() - - -class AccuracyTopK(torch.nn.Module): - - def __init__(self, topk=(1, 5), device=None): - super().__init__() - self.eps = 1e-8 - self.device = device - self.topk = topk - self.maxk = max(topk) - # FIXME handle distributed operation - - # statistics / counts - self.reset() - - def update(self, predictions: torch.Tensor, target: torch.Tensor): - sorted_indices = predictions.topk(self.maxk, dim=1)[1] - sorted_indices.t_() - correct = sorted_indices.eq(target.reshape(1, -1).expand_as(sorted_indices)) - - batch_size = target.shape[0] - correct_k = {k: correct[:k].reshape(-1).float().sum(0) for k in self.topk} - for k, v in correct_k.items(): - attr = f'_correct_top{k}' - old_v = getattr(self, attr) - setattr(self, attr, old_v + v) - self._total_sum += batch_size - - def reset(self): - for k in self.topk: - setattr(self, f'_correct_top{k}', torch.tensor(0, dtype=torch.float32)) - self._total_sum = torch.tensor(0, dtype=torch.float32) - - @property - def counts(self): - pass - - def compute(self) -> Dict[str, torch.Tensor]: - # FIXME handle distributed reduction - return {f'top{k}': 100 * getattr(self, f'_correct_top{k}') / self._total_sum for k in self.topk} - - -# -# class AccuracyTopK: -# -# def __init__(self, topk=(1, 5), device=None): -# self.eps = 1e-8 -# self.device = device -# self.topk = topk -# self.maxk = 
max(topk) -# -# # statistics / counts -# self._correct_sum = None -# self._total_sum = None -# -# def _check_init(self, device): -# to_device = self.device if self.device else device -# if self._correct_sum is None: -# self._correct_sum = {f'top{k}': torch.tensor(0., device=to_device) for k in self.topk} -# if self._total_sum is None: -# self._total_sum = torch.tensor(0, dtype=torch.long, device=to_device) -# -# def update(self, predictions: torch.Tensor, target: torch.Tensor): -# sorted_indices = predictions.topk(self.maxk, dim=1)[1] -# sorted_indices.t_() -# correct = sorted_indices.eq(target.reshape(1, -1).expand_as(sorted_indices)) -# -# batch_size = target.shape[0] -# correct_k = {f'top{k}': correct[:k].reshape(-1).float().sum(0) for k in self.topk} -# self._check_init(device=predictions.device) -# for k, v in correct_k.items(): -# old_v = self._correct_sum[k] -# self._correct_sum[k] = old_v + v -# self._total_sum += batch_size -# -# def reset(self): -# self._correct_sum = None -# self._total_sum = None -# -# @property -# def counts(self): -# pass -# -# def compute(self) -> Dict[str, torch.Tensor]: -# assert self._correct_sum is not None and self._total_sum is not None -# return {k: 100 * v / self._total_sum for k, v in self._correct_sum.items()} diff --git a/timm/scheduler/__init__.py b/timm/scheduler/__init__.py index 6a778982..60f5e3df 100644 --- a/timm/scheduler/__init__.py +++ b/timm/scheduler/__init__.py @@ -3,3 +3,4 @@ from .plateau_lr import PlateauLRScheduler from .step_lr import StepLRScheduler from .tanh_lr import TanhLRScheduler from .scheduler_factory import create_scheduler +from .scheduler import Scheduler \ No newline at end of file diff --git a/timm/utils/checkpoint_saver.py b/timm/utils/checkpoint_saver.py index 6aad74ee..7a13306e 100644 --- a/timm/utils/checkpoint_saver.py +++ b/timm/utils/checkpoint_saver.py @@ -108,7 +108,8 @@ class CheckpointSaver: save_state['arch'] = self.args.model save_state['args'] = self.args if self.amp_scaler is not None: - save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict() + amp_key = getattr(self.amp_scaler, 'state_dict_key', 'amp_scaler') + save_state[amp_key] = self.amp_scaler.state_dict() if self.model_ema is not None: save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn) if metric is not None: diff --git a/timm/utils/clip_grad.py b/timm/utils/clip_grad.py index 7eb40697..d1279ac9 100644 --- a/timm/utils/clip_grad.py +++ b/timm/utils/clip_grad.py @@ -3,7 +3,11 @@ import torch from timm.utils.agc import adaptive_clip_grad -def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0): +def dispatch_clip_grad( + parameters, + value: float, + mode: str = 'norm', + norm_type: float = 2.0): """ Dispatch to gradient clipping method Args: diff --git a/timm/utils/distributed.py b/timm/utils/distributed.py index 3c5dba8c..528b7d42 100644 --- a/timm/utils/distributed.py +++ b/timm/utils/distributed.py @@ -21,8 +21,8 @@ def distribute_bn(model, world_size, reduce=False): if ('running_mean' in bn_name) or ('running_var' in bn_name): if reduce: # average bn stats across whole group - torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM) + torch.distributed.all_reduce_recursive(bn_buf, op=dist.ReduceOp.SUM) bn_buf /= float(world_size) else: # broadcast bn stats from rank 0 to whole group - torch.distributed.broadcast(bn_buf, 0) + torch.distributed.broadcast_recursive(bn_buf, 0) diff --git a/train.py b/train.py index de627929..05da82e2 100755 --- a/train.py +++ 
b/train.py @@ -21,13 +21,15 @@ import os import logging from collections import OrderedDict from datetime import datetime +from dataclasses import replace +from typing import Tuple import torch import torch.nn as nn import torchvision.utils -from timm.bits import initialize_device, DeviceEnv, create_updater, Updater, Logger, Tracker -from timm.metrics import TensorAvg, AccuracyTopK +from timm.bits import initialize_device, setup_model_and_optimizer, DeviceEnv, Logger, Tracker,\ + TrainState, TrainServices, TrainCfg, AccuracyTopK, AvgTensor from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\ convert_splitbn_model, model_parameters @@ -276,7 +278,7 @@ def main(): args, args_text = _parse_args() dev_env = initialize_device(amp=args.amp) - if dev_env.is_distributed: + if dev_env.distributed: _logger.info('Training in distributed mode with multiple processes, 1 device per process. Process %d, total %d.' % (dev_env.global_rank, dev_env.world_size)) else: @@ -284,6 +286,111 @@ def main(): random_seed(args.seed, dev_env.global_rank) + mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None + + train_state, train_cfg = setup_train_task(args, dev_env, mixup_active) + + data_config, loader_eval, loader_train = setup_data(args, train_state.model, dev_env, mixup_active) + + # setup checkpoint saver + eval_metric = args.eval_metric + best_metric = None + best_epoch = None + saver = None + output_dir = None + if dev_env.primary: + if args.experiment: + exp_name = args.experiment + else: + exp_name = '-'.join([ + datetime.now().strftime("%Y%m%d-%H%M%S"), + safe_model_name(args.model), + str(data_config['input_size'][-1]) + ]) + output_dir = get_outdir(args.output if args.output else './output/train', exp_name) + decreasing = True if eval_metric == 'loss' else False + saver = CheckpointSaver( # TODO CheckpointSaverV2 + model=train_state.model, + optimizer=train_state.updater.optimizer, + args=args, + model_ema=train_state.model_ema, + amp_scaler=train_state.updater.grad_scaler, + checkpoint_dir=output_dir, + recovery_dir=output_dir, + decreasing=decreasing, + max_history=args.checkpoint_hist) + with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: + f.write(args_text) + + services = TrainServices( + logger=Logger( + output_dir=output_dir, python_logger=_logger, hparams=vars(args), output_enabled=dev_env.primary), + saver=saver, + ) + + try: + for epoch in range(train_state.epoch, train_cfg.num_epochs): + if dev_env.distributed and hasattr(loader_train.sampler, 'set_epoch'): + loader_train.sampler.set_epoch(epoch) + if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: + if loader_train.mixup_enabled: + loader_train.mixup_enabled = False + + train_metrics = train_one_epoch( + dev_env=dev_env, + state=train_state, + services=services, + cfg=train_cfg, + loader=loader_train + ) + + if dev_env.distributed and args.dist_bn in ('broadcast', 'reduce'): + if dev_env.primary: + _logger.info("Distributing BatchNorm running means and vars") + distribute_bn(train_state.model, dev_env.world_size, args.dist_bn == 'reduce') + + eval_metrics = evaluate( + train_state.model, + train_state.eval_loss, + loader_eval, + dev_env, + logger=services.logger) + + if train_state.model_ema is not None and not args.model_ema_force_cpu: + if dev_env.distributed and args.dist_bn in ('broadcast', 'reduce'): + distribute_bn(train_state.model_ema, dev_env.world_size, args.dist_bn
== 'reduce') + + ema_eval_metrics = evaluate( + train_state.model_ema.module, + train_state.eval_loss, + loader_eval, + dev_env, + logger=services.logger, + phase_suffix='EMA') + eval_metrics = ema_eval_metrics + + if train_state.lr_scheduler is not None: + # step LR for next epoch + train_state.lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) + + if services.logger is not None: + services.logger.write_summary(index=epoch, results=dict(train=train_metrics, eval=eval_metrics)) + + if saver is not None: + # save proper checkpoint with eval metric + save_metric = eval_metrics[eval_metric] + best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric) + + train_state = replace(train_state, epoch=epoch + 1) + + except KeyboardInterrupt: + pass + if best_metric is not None: + _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) + + +def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool): + model = create_model( args.model, pretrained=args.pretrained, @@ -302,82 +409,69 @@ def main(): assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly - if dev_env.is_master: + if dev_env.primary: _logger.info( f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}') - data_config = resolve_data_config(vars(args), model=model, verbose=dev_env.is_master) - # setup augmentation batch splits for contrastive loss or split bn - num_aug_splits = 0 - if args.aug_splits > 0: - assert args.aug_splits > 1, 'A split of 1 makes no sense' - num_aug_splits = args.aug_splits - + assert args.aug_splits == 0 or args.aug_splits > 1, 'A split of 1 makes no sense' # enable split bn (separate bn stats per batch-portion) if args.split_bn: - assert num_aug_splits > 1 or args.resplit - model = convert_splitbn_model(model, max(num_aug_splits, 2)) - - # move model to GPU, enable channels last layout if set - dev_env.to_device(model) - - # setup synchronized BatchNorm for distributed training - if dev_env.is_distributed and args.sync_bn: - assert not args.split_bn - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) - if dev_env.is_master: - _logger.info( - 'Converted model to use Synchronized BatchNorm. 
WARNING: You may have issues if using ' - 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') - - if args.torchscript: - assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' - model = torch.jit.script(model) - - updater = create_updater( - create_optimizer_v2(model, **optimizer_kwargs(cfg=args)), - clip_value=args.clip_grad, clip_mode=args.clip_mode) - - # optionally resume from a checkpoint - resume_epoch = None - if args.resume: - resume_epoch = resume_checkpoint( - model, args.resume, - optimizer=None if args.no_resume_opt else updater.optimizer, - loss_scaler=None if args.no_resume_opt else updater.scaler, - log_info=dev_env.is_master) - - # setup exponential moving average of model weights, SWA could be used here too - model_ema = None - if args.model_ema: - # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper - model_ema = ModelEmaV2( - model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None) - if args.resume: - load_checkpoint(model_ema.module, args.resume, use_ema=True) - - # setup distributed training - if dev_env.is_distributed: - if dev_env.is_master: - _logger.info("Distributing model.") - model = dev_env.wrap_distributed(model) - # NOTE: EMA model does not need to be wrapped by DDP + assert args.aug_splits > 1 or args.resplit + model = convert_splitbn_model(model, max(args.aug_splits, 2)) + + train_state = setup_model_and_optimizer( + dev_env=dev_env, + model=model, + optimizer=args.opt, + optimizer_cfg=optimizer_kwargs(cfg=args), + clip_fn=args.clip_mode if args.clip_grad is not None else None, + clip_value=args.clip_grad, + model_ema=args.model_ema, + model_ema_decay=args.model_ema_decay, + use_syncbn=args.sync_bn, + ) # setup learning rate schedule and starting epoch - lr_scheduler, num_epochs = create_scheduler(args, updater.optimizer) - start_epoch = 0 - if args.start_epoch is not None: - # a specified start_epoch will always override the resume epoch - start_epoch = args.start_epoch - elif resume_epoch is not None: - start_epoch = resume_epoch - if lr_scheduler is not None and start_epoch > 0: - lr_scheduler.step(start_epoch) - - if dev_env.is_master: + # FIXME move into updater? 
+ lr_scheduler, num_epochs = create_scheduler(args, train_state.updater.optimizer) + if lr_scheduler is not None and train_state.epoch > 0: + lr_scheduler.step(train_state.epoch) + + # setup loss function + if args.jsd: + assert args.aug_splits > 1 # JSD only valid with aug splits set + train_loss_fn = JsdCrossEntropy(num_splits=args.aug_splits, smoothing=args.smoothing) + elif mixup_active: + # smoothing is handled with mixup target transform + train_loss_fn = SoftTargetCrossEntropy() + elif args.smoothing: + train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing) + else: + train_loss_fn = nn.CrossEntropyLoss() + eval_loss_fn = nn.CrossEntropyLoss() + dev_env.to_device(train_loss_fn, eval_loss_fn) + + if dev_env.primary: _logger.info('Scheduled epochs: {}'.format(num_epochs)) + train_state = replace( + train_state, + lr_scheduler=lr_scheduler, + train_loss=train_loss_fn, + eval_loss=eval_loss_fn) + + train_cfg = TrainCfg( + num_epochs=num_epochs, + log_interval=args.log_interval, + recovery_interval=args.recovery_interval) + + return train_state, train_cfg + + +def setup_data(args, model, dev_env, mixup_active): + data_config = resolve_data_config(vars(args), model=model, verbose=dev_env.primary) + # create the train and eval datasets dataset_train = create_dataset( args.dataset, @@ -388,18 +482,17 @@ def main(): # setup mixup / cutmix collate_fn = None - mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None if mixup_active: mixup_args = dict( mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, label_smoothing=args.smoothing, num_classes=args.num_classes) - assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) + assert not args.aug_splits # collate conflict (need to support deinterleaving in collate mixup) collate_fn = FastCollateMixup(**mixup_args) # wrap dataset in AugMix helper - if num_aug_splits > 1: - dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) + if args.aug_splits > 1: + dataset_train = AugMixDataset(dataset_train, num_splits=args.aug_splits) # create data loaders w/ augmentation pipeiine train_interpolation = args.train_interpolation @@ -421,7 +514,7 @@ def main(): vflip=args.vflip, color_jitter=args.color_jitter, auto_augment=args.aa, - num_aug_splits=num_aug_splits, + num_aug_splits=args.aug_splits, interpolation=train_interpolation, mean=data_config['mean'], std=data_config['std'], @@ -443,169 +536,107 @@ def main(): crop_pct=data_config['crop_pct'], pin_memory=args.pin_mem, ) + return data_config, loader_eval, loader_train - # setup loss function - if args.jsd: - assert num_aug_splits > 1 # JSD only valid with aug splits set - train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing) - elif mixup_active: - # smoothing is handled with mixup target transform - train_loss_fn = SoftTargetCrossEntropy() - elif args.smoothing: - train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing) - else: - train_loss_fn = nn.CrossEntropyLoss() - validate_loss_fn = nn.CrossEntropyLoss() - dev_env.to_device(train_loss_fn, validate_loss_fn) - - # setup checkpoint saver and eval metric tracking - eval_metric = args.eval_metric - best_metric = None - best_epoch = None - saver = None - output_dir = None - if dev_env.is_master: - if args.experiment: - exp_name = args.experiment - else: - exp_name = '-'.join([ -
datetime.now().strftime("%Y%m%d-%H%M%S"), - safe_model_name(args.model), - str(data_config['input_size'][-1]) - ]) - output_dir = get_outdir(args.output if args.output else './output/train', exp_name) - decreasing = True if eval_metric == 'loss' else False - saver = CheckpointSaver( - model=model, optimizer=updater.optimizer, args=args, model_ema=model_ema, amp_scaler=updater.scaler, - checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist) - with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: - f.write(args_text) - - logger = Logger(output_dir=output_dir, logger=_logger, hparams=vars(args)) - try: - for epoch in range(start_epoch, num_epochs): - if dev_env.is_distributed and hasattr(loader_train.sampler, 'set_epoch'): - loader_train.sampler.set_epoch(epoch) - if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: - if loader_train.mixup_enabled: - loader_train.mixup_enabled = False +def train_one_epoch( + dev_env: DeviceEnv, + state: TrainState, + cfg: TrainCfg, + services: TrainServices, + loader, +): + tracker = Tracker() + loss_meter = AvgTensor() - train_metrics = train_one_epoch( - epoch, model, loader_train, updater, train_loss_fn, dev_env, - lr_scheduler=lr_scheduler, saver=saver, logger=logger, model_ema=model_ema, - log_interval=args.log_interval, recovery_interval=args.recovery_interval) + state.model.train() + state.updater.reset() # zero-grad - if dev_env.is_distributed and args.dist_bn in ('broadcast', 'reduce'): - if dev_env.is_master: - _logger.info("Distributing BatchNorm running means and vars") - distribute_bn(model, dev_env.world_size, args.dist_bn == 'reduce') + step_end_idx = len(loader) - 1 + tracker.mark_iter() + for step_idx, (sample, target) in enumerate(loader): + tracker.mark_iter_data_end() - eval_metrics = evaluate(model, loader_eval, validate_loss_fn, dev_env, logger=logger) + # FIXME move forward + loss into model 'task' wrapper + with dev_env.autocast(): + output = state.model(sample) + loss = state.train_loss(output, target) - if model_ema is not None and not args.model_ema_force_cpu: - if dev_env.is_distributed and args.dist_bn in ('broadcast', 'reduce'): - distribute_bn(model_ema, dev_env.world_size, args.dist_bn == 'reduce') + state.updater.apply(loss) - ema_eval_metrics = evaluate( - model_ema.module, loader_eval, validate_loss_fn, dev_env, - logger=logger, phase_suffix='EMA') - eval_metrics = ema_eval_metrics + tracker.mark_iter_step_end() - if lr_scheduler is not None: - # step LR for next epoch - lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) + state.updater.after_step( + after_train_step, + dev_env, + state, + services, + cfg, + step_idx, + step_end_idx, + tracker, + loss_meter, + (output, target, loss), + ) - if logger is not None: - logger.write_summary(index=epoch, results=dict(train=train_metrics, eval=eval_metrics)) + tracker.mark_iter() + # end for - if saver is not None: - # save proper checkpoint with eval metric - save_metric = eval_metrics[eval_metric] - best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric) + if hasattr(state.updater.optimizer, 'sync_lookahead'): + state.updater.optimizer.sync_lookahead() - except KeyboardInterrupt: - pass - if best_metric is not None: - _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) + return OrderedDict([('loss', loss_meter.compute().item())]) -def train_one_epoch( - epoch: int, - model: nn.Module, - loader, - updater: Updater, - loss_fn: nn.Module, +def after_train_step( 
dev_env: DeviceEnv, - lr_scheduler=None, - saver: CheckpointSaver = None, - logger: Logger = None, - model_ema: nn.Module = None, - log_interval: int = 50, - recovery_interval: int = 0, + state: TrainState, + services: TrainServices, + cfg: TrainCfg, + step_idx: int, + step_end_idx: int, + tracker: Tracker, + loss_meter: AvgTensor, + tensors: Tuple[torch.Tensor, ...], ): - tracker = Tracker() - losses_m = TensorAvg() + end_step = step_idx == step_end_idx - model.train() - - end_idx = len(loader) - 1 - num_updates = epoch * len(loader) - batch_size = 0 - tracker.mark_iter() - for step_idx, (sample, target) in enumerate(loader): - tracker.mark_iter_data_end() - last_step = step_idx == end_idx - batch_size = max(batch_size, sample.shape[0]) - - with dev_env.autocast(): - output = model(sample) - loss = loss_fn(output, target) - - updater.reset() - updater.apply(loss) + with torch.no_grad(): + output, target, loss = tensors + loss_meter.update(loss, output.shape[0]) - dev_env.mark_step() # FIXME - tracker.mark_iter_step_end() - losses_m.update(loss, sample.size(0)) - if model_ema is not None: - model_ema.update(model) + if state.model_ema is not None: + state.model_ema.update(state.model) - num_updates += 1 - if last_step or (step_idx + 1) % log_interval == 0: - lrl = [param_group['lr'] for param_group in updater.optimizer.param_groups] - lr = sum(lrl) / len(lrl) + state = replace(state, step_count_global=state.step_count_global + 1) - if dev_env.is_master and logger is not None: - loss_avg = losses_m.compute() - logger.log_step( + if services.logger is not None and (end_step or (step_idx + 1) % cfg.log_interval == 0): + global_batch_size = dev_env.world_size * output.shape[0] + loss_avg = loss_meter.compute() + if services.logger is not None: + lr_avg = state.updater.get_average_lr() + services.logger.log_step( 'Train', step=step_idx, - end_step=end_idx, + step_end=step_end_idx, + epoch=state.epoch, loss=loss_avg.item(), - rate=(dev_env.world_size * batch_size) / tracker.iter_time.avg, - lr=lr, + rate=tracker.get_avg_iter_rate(global_batch_size), + lr=lr_avg, ) - if saver is not None and recovery_interval and (last_step or (step_idx + 1) % recovery_interval == 0): - saver.save_recovery(epoch, batch_idx=step_idx) - - if lr_scheduler is not None: - lr_scheduler.step_update(num_updates=num_updates) + if services.saver is not None and cfg.recovery_interval and ( + end_step or (step_idx + 1) % cfg.recovery_interval == 0): + services.saver.save_recovery(state.epoch, batch_idx=step_idx) - tracker.mark_iter() - # end for - - if hasattr(updater.optimizer, 'sync_lookahead'): - updater.optimizer.sync_lookahead() - - return OrderedDict([('loss', losses_m.compute().item())]) + if state.lr_scheduler is not None: + state.lr_scheduler.step_update(num_updates=state.step_count_global) def evaluate( model: nn.Module, - loader, loss_fn: nn.Module, + loader, dev_env: DeviceEnv, logger: Logger, phase_suffix: str = '', @@ -613,7 +644,7 @@ ): tracker = Tracker() - losses_m = TensorAvg() + losses_m = AvgTensor() accuracy_m = AccuracyTopK() model.eval() @@ -636,13 +667,13 @@ losses_m.update(loss, output.size(0)) accuracy_m.update(output, target) - if dev_env.is_master and (last_step or step_idx % log_interval == 0): + if last_step or step_idx % log_interval == 0: top1, top5 = accuracy_m.compute().values() loss_avg = losses_m.compute() logger.log_step( 'Eval', step=step_idx, - num_steps=end_idx, + step_end=end_idx, loss=loss_avg.item(), top1=top1.item(), top5=top5.item(), diff --git a/validate.py
b/validate.py index add23469..b7538d9f 100755 --- a/validate.py +++ b/validate.py @@ -18,8 +18,7 @@ import torch.nn as nn import torch.nn.parallel from collections import OrderedDict -from timm.bits import initialize_device, Tracker, Logger -from timm.metrics import AccuracyTopK, TensorAvg +from timm.bits import initialize_device, Tracker, Logger, AccuracyTopK, AvgTensor from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet from timm.utils import natural_key, setup_default_logging @@ -155,10 +154,10 @@ def validate(args): pin_memory=args.pin_mem, tf_preprocessing=args.tf_preprocessing) - logger = Logger(logger=_logger) + logger = Logger(python_logger=_logger) tracker = Tracker() - losses = TensorAvg() - accuracy = AccuracyTopK().to(dev_env.device) + losses = AvgTensor() + accuracy = AccuracyTopK(dev_env=dev_env) model.eval() num_steps = len(loader) @@ -175,10 +174,8 @@ def validate(args): output = output[:, valid_labels] loss = criterion(output, target) - if dev_env.type == 'cuda': + if dev_env.type_cuda: torch.cuda.synchronize() - #elif dev_env.type == 'xla': - # dev_env.mark_step() tracker.mark_iter_step_end() losses.update(loss.detach(), sample.size(0)) @@ -186,7 +183,7 @@ def validate(args): real_labels.add_result(output) accuracy.update(output.detach(), target) - if dev_env.type == 'xla': + if dev_env.type_xla: dev_env.mark_step() tracker.mark_iter()
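Usage sketch (not part of the patch): the snippet below is a minimal, hedged example of how the refactored updater API above is intended to be driven end to end (device env, create_updater, autocast, gradient accumulation via the accumulate flag). The toy model, random batches, and hyperparameter values are illustrative placeholders only; the clip_fn string mirrors what train.py forwards from --clip-mode.

```python
# Minimal sketch of the Updater / create_updater interface from this PR.
# The Linear model, random data, and hyperparameters are illustrative stand-ins.
import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.bits import initialize_device, create_updater

dev_env = initialize_device(amp=True)   # resolves the CUDA / XLA / CPU device environment
model = nn.Linear(32, 10)
dev_env.to_device(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# create_updater picks Updater / UpdaterCudaWithScaler / UpdaterXla(WithScaler) from dev_env
updater = create_updater(
    model=model,
    optimizer=optimizer,
    clip_fn='norm',       # same string form train.py passes via --clip-mode
    clip_value=1.0,
    dev_env=dev_env,
)

accum_steps = 2
updater.reset()  # zero grads up front, as train_one_epoch does
for step in range(4):
    x = torch.randn(8, 32, device=dev_env.device)
    y = torch.randint(0, 10, (8,), device=dev_env.device)
    with dev_env.autocast():
        loss = F.cross_entropy(model(x), y)
    # apply() runs backward; scaling, clipping, optimizer.step() and the grad reset
    # only happen on non-accumulation steps
    updater.apply(loss, accumulate=(step + 1) % accum_steps != 0)
```

The `after_step` hook added to the base class is the other half of this pattern: the train loop hands its logging/metric work to `updater.after_step(...)`, which runs the callback immediately on the base and CUDA updaters and defers it through `xm.add_step_closure` on XLA.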
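One detail of the `Fetcher` change worth highlighting: the mean/std buffers are now created directly with the target device and dtype, and normalization happens in the same statement as the host-to-device transfer (the old code called `.to()` on the buffers without assigning the result, so they silently stayed on CPU). A rough standalone illustration of the new pattern, using illustrative ImageNet constants and batch shape rather than values taken from the diff:

```python
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.float32
# buffers built directly on the target device/dtype, scaled to the 0-255 input range
mean = torch.tensor([x * 255 for x in (0.485, 0.456, 0.406)], dtype=dtype, device=device).view(1, 3, 1, 1)
std = torch.tensor([x * 255 for x in (0.229, 0.224, 0.225)], dtype=dtype, device=device).view(1, 3, 1, 1)

uint8_batch = torch.randint(0, 256, (4, 3, 224, 224), dtype=torch.uint8)  # stand-in for loader output
# single transfer + cast, then in-place normalize on the device
sample = uint8_batch.to(device=device, dtype=dtype).sub_(mean).div_(std)
```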
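On the `ParserTfds` change: mutating the object returned by `ds.options()` in place is not reliably applied by newer TF releases (presumably the motivation here), so the threading limits are now set on a fresh `tf.data.Options` and attached via `with_options`. Restated standalone, with an example threadpool size in place of the `MAX_TP_SIZE // num_workers` computation:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(8)  # stand-in for the TFDS builder output
options = tf.data.Options()
options.experimental_threading.private_threadpool_size = 2
options.experimental_threading.max_intra_op_parallelism = 1
ds = ds.with_options(options)
```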