From 91ab0b6ce5a1dbe22600132bb4fa9ededc96b2ab Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 3 Jun 2021 17:49:40 -0700 Subject: [PATCH] Add proper TrainState checkpoint save/load. Some reorg/refactoring and other cleanup. More to go... --- timm/bits/__init__.py | 9 +- timm/bits/checkpoint.py | 121 ++++++++++----- timm/bits/checkpoint_manager.py | 219 ++++++++++++++++++++++++++++ timm/bits/device_env.py | 66 ++++++++- timm/bits/device_env_cuda.py | 9 +- timm/bits/device_env_factory.py | 26 ++-- timm/bits/device_env_xla.py | 19 ++- timm/bits/distributed.py | 15 +- timm/bits/metric.py | 13 +- timm/bits/{logger.py => monitor.py} | 8 +- timm/bits/train_services.py | 8 +- timm/bits/train_setup.py | 10 +- timm/bits/train_state.py | 35 +++-- timm/bits/updater.py | 4 + timm/bits/updater_deepspeed.py | 4 + timm/bits/updater_factory.py | 3 +- timm/data/loader.py | 4 +- timm/data/parsers/parser_tfds.py | 4 +- timm/utils/model.py | 47 +++--- train.py | 50 +++---- validate.py | 4 +- 21 files changed, 522 insertions(+), 156 deletions(-) create mode 100644 timm/bits/checkpoint_manager.py rename timm/bits/{logger.py => monitor.py} (98%) diff --git a/timm/bits/__init__.py b/timm/bits/__init__.py index c9960341..940e9e3e 100644 --- a/timm/bits/__init__.py +++ b/timm/bits/__init__.py @@ -1,14 +1,15 @@ from .avg_scalar import AvgMinMaxScalar from .avg_tensor import AvgTensor -from .device_env import DeviceEnv, DeviceEnvType +from .checkpoint_manager import CheckpointManager +from .device_env import DeviceEnv, DeviceEnvType, get_global_device, set_global_device, is_global_device from .device_env_cuda import DeviceEnvCuda -from .device_env_factory import initialize_device, get_device +from .device_env_factory import initialize_device from .device_env_xla import DeviceEnvXla from .distributed import distribute_bn, all_gather_recursive, all_reduce_recursive, broadcast_recursive,\ all_reduce_sequence, all_gather_sequence # from .evaluate import evaluate, eval_step -from .logger import Logger -from .metric import Metric, MetricValue +from .monitor import Monitor +from .metric import Metric, MetricValueT from .metric_accuracy import AccuracyTopK from .tracker import Tracker # from .task_metrics import TaskMetrics, TaskMetricsClassify diff --git a/timm/bits/checkpoint.py b/timm/bits/checkpoint.py index b7ff1909..df21ab5e 100644 --- a/timm/bits/checkpoint.py +++ b/timm/bits/checkpoint.py @@ -1,17 +1,73 @@ import logging import os from collections import OrderedDict +from typing import Dict, Any, Callable import torch from timm.utils import unwrap_model -from .train_state import TrainState, serialize_train_state, deserialize_train_state +from .device_env import DeviceEnv +from .train_state import TrainState _logger = logging.getLogger(__name__) -def _load_state_dict(checkpoint, state_dict_key='state_dict'): +def save_train_state( + checkpoint_path: str, # FIXME pass base path + file pattern + epoch / step separately for DS? + train_state: TrainState, + extra_state: Dict[str, Any] = None, + unwrap_fn: Callable = unwrap_model, + dev_env: DeviceEnv = None, + log_info: bool = True): + + assert not train_state.updater.deepspeed + # DeepSpeed has a fully custom checkpoint saving setup, it is not possible + # specify a filename, checkpoints needed to be saved from all ranks, etc + # if train_state.updater.deepspeed: + # save_train_state_deepspeed(train_state, checkpoint_path) + + dev_env = dev_env or DeviceEnv.instance() + state_dict = train_state.state_dict(unwrap_fn=unwrap_fn) + if extra_state: + state_dict.update(extra_state) + if dev_env.type_xla: + # XLA state dict needs to be moved to CPU before save, this is normally done by xm.save + state_dict = dev_env.state_dict_to_cpu(state_dict) + torch.save(state_dict, checkpoint_path) + + +def load_train_state( + train_state: TrainState, + checkpoint_path: str, # FIXME pass base path + file pattern + epoch / step separately for DS + unwrap_fn: Callable = None, + load_opt: bool = True, + dev_env: DeviceEnv = None, + log_info: bool = True +): + unwrap_fn = unwrap_fn or unwrap_model + if not os.path.isfile(checkpoint_path): + _logger.error("No valid resume checkpoint found at '{}'".format(checkpoint_path)) + raise FileNotFoundError() + + if log_info: + _logger.info('Restoring training state from checkpoint...') + + checkpoint = torch.load(checkpoint_path, map_location='cpu') + assert isinstance(checkpoint, dict) + + if not checkpoint.get('version', 0) > 2: + load_legacy_checkpoint(train_state, checkpoint=checkpoint, load_opt=load_opt, log_info=log_info) + if log_info: + _logger.info("Loaded legacy checkpoint '{}' (epoch {})".format(checkpoint_path, train_state.epoch)) + return + + train_state.load_state_dict(checkpoint, unwrap_fn=unwrap_fn) + if log_info: + _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, train_state.epoch)) + + +def _get_state_dict(checkpoint, state_dict_key='state_dict'): new_state_dict = OrderedDict() for k, v in checkpoint[state_dict_key].items(): name = k[7:] if k.startswith('module') else k @@ -19,48 +75,35 @@ def _load_state_dict(checkpoint, state_dict_key='state_dict'): return new_state_dict -def resume_train_checkpoint( +def load_legacy_checkpoint( train_state: TrainState, - checkpoint_path, - resume_opt=True, - deserialize_fn=deserialize_train_state, + checkpoint, + load_opt=True, log_info=True): - # FIXME this is a hacky adaptation of pre-bits resume to get up and running quickly - resume_epoch = None - if os.path.isfile(checkpoint_path): - checkpoint = torch.load(checkpoint_path, map_location='cpu') - assert isinstance(checkpoint, dict) and 'state_dict' in checkpoint + assert isinstance(checkpoint, dict) and 'state_dict' in checkpoint + train_state.model.load_state_dict(_get_state_dict(checkpoint)) + + if train_state.model_ema is not None and 'state_dict_ema' in checkpoint: if log_info: - _logger.info('Restoring model state from checkpoint...') + _logger.info('Restoring model (EMA) state from checkpoint...') + unwrap_model(train_state.model_ema).load_state_dict(_get_state_dict(checkpoint, 'state_dict_ema')) - train_state.model.load_state_dict(_load_state_dict(checkpoint)) + if load_opt: + if train_state.updater.optimizer is not None and 'optimizer' in checkpoint: + if log_info: + _logger.info('Restoring optimizer state from checkpoint...') + train_state.updater.optimizer.load_state_dict(checkpoint['optimizer']) - if train_state.model_ema is not None and 'state_dict_ema' in checkpoint: + scaler_state_dict_key = 'amp_scaler' + if train_state.updater.grad_scaler is not None and scaler_state_dict_key in checkpoint: if log_info: - _logger.info('Restoring model (EMA) state from checkpoint...') - unwrap_model(train_state.model_ema).load_state_dict(_load_state_dict(checkpoint, 'state_dict_ema')) - - if resume_opt: - if train_state.updater.optimizer is not None and 'optimizer' in checkpoint: - if log_info: - _logger.info('Restoring optimizer state from checkpoint...') - train_state.updater.optimizer.load_state_dict(checkpoint['optimizer']) - - scaler_state_dict_key = 'amp_scaler' - if train_state.updater.grad_scaler is not None and scaler_state_dict_key in checkpoint: - if log_info: - _logger.info('Restoring AMP loss scaler state from checkpoint...') - train_state.updater.grad_scaler.load_state_dict(checkpoint[scaler_state_dict_key]) - - if 'epoch' in checkpoint: - resume_epoch = checkpoint['epoch'] - if 'version' in checkpoint and checkpoint['version'] > 1: - resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save - train_state.epoch = resume_epoch # FIXME use replace if we make train_state read-only + _logger.info('Restoring AMP loss scaler state from checkpoint...') + train_state.updater.grad_scaler.load_state_dict(checkpoint[scaler_state_dict_key]) + + if 'epoch' in checkpoint: + resume_epoch = checkpoint['epoch'] + if 'version' in checkpoint and checkpoint['version'] > 1: + resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save + train_state.epoch = resume_epoch # FIXME use replace if we make train_state read-only - if log_info: - _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) - else: - _logger.error("No valid resume checkpoint found at '{}'".format(checkpoint_path)) - raise FileNotFoundError() diff --git a/timm/bits/checkpoint_manager.py b/timm/bits/checkpoint_manager.py new file mode 100644 index 00000000..b051e126 --- /dev/null +++ b/timm/bits/checkpoint_manager.py @@ -0,0 +1,219 @@ +""" Checkpoint Manager + +Track top-n training checkpoints and maintain recovery checkpoints on specified intervals. + +Hacked together by / Copyright 2021 Ross Wightman +""" +import glob +import logging +import operator +import os +import shutil +from typing import Optional, Dict, Callable, List +from dataclasses import dataclass, replace + + +from .checkpoint import save_train_state +from .train_state import TrainState + +_logger = logging.getLogger(__name__) + + +@dataclass +class CheckpointInfo: + path: str = '' + metrics: Dict[str, float] = None # all metrics at time of checkpoint save + metric_name: str = 'loss' + metric_decreasing: bool = True + epoch: int = 0 + global_step: int = 0 + + @property + def valid_key(self): + return self.metric_name and self.metrics and self.metric_name in self.metrics + + @property + def sort_key(self): + return self.metrics[self.metric_name] if self.valid_key else self.epoch + + @property + def decreasing_key(self): + return self.metric_decreasing if self.valid_key else False + + +class CheckpointManager: + def __init__( + self, + hparams=None, + save_state_fn=None, + checkpoint_dir='', + recovery_dir='', + checkpoint_tmpl=None, + recovery_tmpl=None, + metric_name='loss', + metric_decreasing=True, + max_history=10): + + # extra items to include in checkpoint + self.hparams = hparams # train arguments (config / hparams) # FIXME this will change with new config system + + # state + self.checkpoint_files: List[CheckpointInfo] = [] # (filename, metric) tuples in order of decreasing betterness + self.best_checkpoint = None + self.curr_recovery_file = '' + self.prev_recovery_file = '' + self.can_hardlink = True + + # util / helper fn + self.save_state_fn = save_state_fn or save_train_state + + # file / folder config + self.extension = '.pth.tar' + self.checkpoint_dir = checkpoint_dir + self.recovery_dir = recovery_dir + self.checkpoint_tmpl = (checkpoint_tmpl or 'checkpoint-{index}') + self.extension + self.recovery_tmpl = (recovery_tmpl or 'recovery-{index}') + self.extension + + # ordering / history config + self.metric_name = metric_name + self.metric_decreasing = metric_decreasing + self.metric_cmp_fn = operator.lt if metric_decreasing else operator.gt + self.max_history = max_history + assert self.max_history >= 1 + + def _replace(self, src, dst): + if self.can_hardlink: + try: + if os.path.exists(dst): + os.unlink(dst) # required for Windows support. + except Exception as e: + self.can_hardlink = False + os.replace(src, dst) + + def _duplicate(self, src, dst): + if self.can_hardlink: + try: + if os.path.exists(dst): + # for Windows + os.unlink(dst) + os.link(src, dst) + return + except Exception as e: + self.can_hardlink = False + shutil.copy2(src, dst) + + def _save(self, save_path, train_state: TrainState, metrics: Optional[Dict[str, float]] = None): + extra_state = dict( + # version < 2 increments epoch before save + # version < 3, pre timm bits + # version 3, first timm bits checkpoitns + version=3, + ) + if self.hparams is not None: + extra_state.update(dict(arch=self.hparams['model'], hparams=self.hparams)) + else: + arch = getattr(train_state.model, 'default_cfg', dict()).get('architecture', None) + if arch is None: + arch = type(train_state.model).__name__.lower() + extra_state.update(dict(arch=arch)) + if metrics is not None: + # save the metrics and how we originally sorted them in the checkpoint for future comparisons + extra_state.update(dict( + metrics=metrics, + metric_name=self.metric_name, + metric_decreasing=self.metric_decreasing + )) + + self.save_state_fn(save_path, train_state, extra_state) + + checkpoint_info = CheckpointInfo( + path=save_path, + metrics=metrics, + metric_name=self.metric_name, + metric_decreasing=self.metric_decreasing, + epoch=train_state.epoch, + global_step=train_state.step_count_global, + ) + return checkpoint_info + + def _udpate_checkpoints(self, info: CheckpointInfo): + self.checkpoint_files.append(info) + self.checkpoint_files = sorted( + self.checkpoint_files, + key=lambda x: x.sort_key, + reverse=not info.decreasing_key, # sort in descending order if a lower metric is not better + ) + + def _cleanup_checkpoints(self, trim=0): + trim = min(len(self.checkpoint_files), trim) + delete_index = self.max_history - trim + if delete_index < 0 or len(self.checkpoint_files) <= delete_index: + return + to_delete = self.checkpoint_files[delete_index:] + for d in to_delete: + try: + _logger.debug("Cleaning checkpoint: {}".format(d)) + os.remove(d[0]) + except Exception as e: + _logger.error("Exception '{}' while deleting checkpoint".format(e)) + self.checkpoint_files = self.checkpoint_files[:delete_index] + + def _compare_metric(self, lhs: CheckpointInfo, rhs: CheckpointInfo): + # compare metrics against an existing checkpoint + if not lhs or not lhs.valid_key or not rhs or not rhs.valid_key: + # always assume lhs metrics are better if there are no usable metrics to compare + return True + return self.metric_cmp_fn(lhs.sort_key, rhs.sort_key) + + def save_checkpoint(self, train_state: TrainState, metrics: Optional[Dict[str, float]] = None): + assert train_state.epoch >= 0 + tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension) + last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension) + curr_checkpoint = self._save(tmp_save_path, train_state, metrics) + self._replace(tmp_save_path, last_save_path) + + worst_checkpoint = self.checkpoint_files[-1] if self.checkpoint_files else None + if len(self.checkpoint_files) < self.max_history or self._compare_metric(curr_checkpoint, worst_checkpoint): + if len(self.checkpoint_files) >= self.max_history: + self._cleanup_checkpoints(1) + + filename = self.checkpoint_tmpl.format(index=train_state.epoch) + save_path = os.path.join(self.checkpoint_dir, filename) + curr_checkpoint = replace(curr_checkpoint, path=save_path) + self._duplicate(last_save_path, save_path) + self._udpate_checkpoints(curr_checkpoint) + + checkpoints_str = "Current checkpoints:\n" + for c in self.checkpoint_files: + checkpoints_str += f' {c.path}, {c.sort_key}\n'.format(c) + _logger.info(checkpoints_str) + + if curr_checkpoint.valid_key and self._compare_metric(curr_checkpoint, self.best_checkpoint): + self.best_checkpoint = curr_checkpoint + best_save_path = os.path.join(self.checkpoint_dir, 'best' + self.extension) + self._duplicate(last_save_path, best_save_path) + + return None if self.best_checkpoint is None else curr_checkpoint + + def save_recovery(self, train_state: TrainState): + tmp_save_path = os.path.join(self.recovery_dir, 'recovery_tmp' + self.extension) + self._save(tmp_save_path, train_state) + + filename = self.recovery_tmpl.format(index=train_state.step_count_global) + save_path = os.path.join(self.recovery_dir, filename) + self._replace(tmp_save_path, save_path) + + if os.path.exists(self.prev_recovery_file): + try: + _logger.debug("Cleaning recovery: {}".format(self.prev_recovery_file)) + os.remove(self.prev_recovery_file) + except Exception as e: + _logger.error("Exception '{}' while removing {}".format(e, self.prev_recovery_file)) + self.prev_recovery_file = self.curr_recovery_file + self.curr_recovery_file = save_path + + def find_recovery(self): + recovery_path = os.path.join(self.recovery_dir, self.recovery_prefix) + files = glob.glob(recovery_path + '*' + self.extension) + files = sorted(files) + return files[0] if len(files) else '' diff --git a/timm/bits/device_env.py b/timm/bits/device_env.py index bac9b0ab..0a926e69 100644 --- a/timm/bits/device_env.py +++ b/timm/bits/device_env.py @@ -1,7 +1,7 @@ import abc from contextlib import suppress from enum import Enum -from typing import Callable, Union, Optional, List, Tuple +from typing import Callable, Union, Optional, List, Tuple, Dict, Any from dataclasses import dataclass, field, InitVar import torch @@ -18,10 +18,21 @@ class DeviceEnvType(Enum): XLA = "xla" +def state_dict_apply(state_dict: Dict[str, Any], apply_fn, select_fn=lambda x: x.isinstance(torch.Tensor)): + out_dict = {} + for k, v in state_dict.items(): + if isinstance(v, dict): + out_dict[k] = state_dict_apply(v, apply_fn, select_fn) + else: + out_dict[k] = apply_fn(v) if select_fn(v) else v + return out_dict + + @dataclass class DeviceEnv: device_type: InitVar[Optional[str]] = None device_index: InitVar[Optional[int]] = None + channels_last: InitVar[bool] = False device: torch.device = field(init=False) # set from device_type + device_index or post_init logic world_size: Optional[int] = None # set by post_init from env when None @@ -32,7 +43,12 @@ class DeviceEnv: memory_format: Optional[torch.memory_format] = None dtype: Optional[torch.dtype] = None - def __post_init__(self, device_type: Optional[str], device_index: Optional[int]): + def __post_init__( + self, + device_type: Optional[str], + device_index: Optional[int], + channels_last: bool, + ): device_type = device_type or 'cpu' self.device = torch.device(device_type) if device_index is None \ else torch.device(device_type, device_index) @@ -41,6 +57,17 @@ class DeviceEnv: self.global_rank = 0 if self.global_rank is None else self.global_rank if self.autocast is None: self.autocast = suppress + if channels_last: + self.memory_format = torch.channels_last + + @staticmethod + def is_instance(): + return is_global_device() + + @staticmethod + def instance(): + # throws if called before global device is set / initialized + return get_global_device() @property def type(self) -> DeviceEnvType: @@ -81,11 +108,23 @@ class DeviceEnv: def wrap_parallel(self, *modules): pass + def to_cpu(self, *modules: torch.nn.Module): + moved = [m.cpu() for m in modules] + return moved[0] if len(moved) == 1 else moved + def to_device(self, *modules: torch.nn.Module): - # FIXME handling dtype / memformat... disable flags, enable flags, diff fn? + # FIXME handling dtype? Do we want separate dtype for data vs model? moved = [m.to(device=self.device, memory_format=self.memory_format) for m in modules] return moved[0] if len(moved) == 1 else moved + def state_dict_to_cpu(self, state: Dict[str, Any]): + cpu_state = state_dict_apply(state, apply_fn=lambda x: x.cpu()) + return cpu_state + + def state_dict_to_device(self, state: Dict[str, Any]): + cpu_state = state_dict_apply(state, apply_fn=lambda x: x.to(self.device)) + return cpu_state + def mark_step(self): pass # NO-OP for non-XLA devices @@ -126,3 +165,24 @@ class DeviceEnv: def barrier(self): dist.barrier() + + +# Global device environment singleton instance +_global_device_env: Optional[DeviceEnv] = None + + +def is_global_device(): + return _global_device_env is not None + + +def get_global_device() -> DeviceEnv: + if not is_global_device(): + raise RuntimeError('Please initialize device environment by calling initialize_device / set_global_device.') + return _global_device_env + + +def set_global_device(device: DeviceEnv): + global _global_device_env + if _global_device_env is not None: + raise RuntimeError('Global device is already set, it should NOT be set again.') + _global_device_env = device diff --git a/timm/bits/device_env_cuda.py b/timm/bits/device_env_cuda.py index 7358e405..c57dfda5 100644 --- a/timm/bits/device_env_cuda.py +++ b/timm/bits/device_env_cuda.py @@ -16,7 +16,12 @@ def is_cuda_available(): @dataclass class DeviceEnvCuda(DeviceEnv): - def __post_init__(self, device_type: str, device_index: Optional[int]): + def __post_init__( + self, + device_type: Optional[str], + device_index: Optional[int], + channels_last: bool, + ): assert torch.cuda.device_count() torch.backends.cudnn.benchmark = True setup_world_size = self.world_size or int(os.environ.get('WORLD_SIZE', 1)) @@ -43,6 +48,8 @@ class DeviceEnvCuda(DeviceEnv): self.global_rank = 0 if self.autocast is None: self.autocast = torch.cuda.amp.autocast if self.amp else suppress + if channels_last: + self.memory_format = torch.channels_last @property def type(self) -> DeviceEnvType: diff --git a/timm/bits/device_env_factory.py b/timm/bits/device_env_factory.py index 2037a39e..bb92daab 100644 --- a/timm/bits/device_env_factory.py +++ b/timm/bits/device_env_factory.py @@ -1,15 +1,15 @@ -from .device_env import DeviceEnv +import logging + +from .device_env import DeviceEnv, is_global_device, get_global_device, set_global_device from .device_env_cuda import DeviceEnvCuda, is_cuda_available from .device_env_xla import DeviceEnvXla, is_xla_available -_device_env = None +_logger = logging.getLogger(__name__) def initialize_device(force_cpu: bool = False, **kwargs) -> DeviceEnv: - global _device_env - if _device_env is not None: - # warning - return _device_env + if is_global_device(): + return get_global_device() denv = None if not force_cpu: @@ -23,14 +23,10 @@ def initialize_device(force_cpu: bool = False, **kwargs) -> DeviceEnv: if denv is None: denv = DeviceEnv() - print(denv) # FIXME DEBUG - _device_env = denv - return denv - - -def get_device() -> DeviceEnv: - if _device_env is None: - raise RuntimeError('Please initialize device environment by calling initialize_device first.') - return _device_env + _logger.info(f'Initialized device {denv.device}. ' + f'Rank: {denv.global_rank} ({denv.local_rank}) of {denv.world_size}.') + print(denv) # FIXME temporary print for debugging + set_global_device(denv) + return denv diff --git a/timm/bits/device_env_xla.py b/timm/bits/device_env_xla.py index a565c1c8..46517f7a 100644 --- a/timm/bits/device_env_xla.py +++ b/timm/bits/device_env_xla.py @@ -1,7 +1,7 @@ import os from contextlib import suppress from dataclasses import dataclass, field, InitVar -from typing import Optional +from typing import Optional, Dict import torch from torch.distributed import ReduceOp @@ -42,7 +42,12 @@ def is_xla_available(xla_device_type=None): @dataclass class DeviceEnvXla(DeviceEnv): - def __post_init__(self, device_type: Optional[str], device_idx: Optional[int]): + def __post_init__( + self, + device_type: Optional[str], + device_idx: Optional[int], + channels_last: bool, + ): if device_type is not None: device_type = device_type.upper() assert device_type in ('TPU', 'GPU', 'CPU'), "XLA device type must be one of ('TPU', 'GPU', 'CPU')" @@ -59,6 +64,8 @@ class DeviceEnvXla(DeviceEnv): assert xa is not None, 'XLA AMP is not present on this build' if self.autocast is None: self.autocast = xa.autocast if self.amp else suppress + if channels_last: + self.memory_format = torch.channels_last @property def type(self) -> DeviceEnvType: @@ -114,3 +121,11 @@ class DeviceEnvXla(DeviceEnv): def barrier(self): xm.rendezvous('timm.bits.dist_barrier') + + def state_dict_to_cpu(self, state: Dict[str, torch.Tensor]): + cpu_state = xm._maybe_convert_to_cpu(state, convert=True) + return cpu_state + + def state_dict_to_device(self, state: Dict[str, torch.Tensor]): + device_state = xm.send_cpu_data_to_device(state, device=self.device) + return device_state diff --git a/timm/bits/distributed.py b/timm/bits/distributed.py index 55f9adf5..0b5df830 100644 --- a/timm/bits/distributed.py +++ b/timm/bits/distributed.py @@ -5,8 +5,7 @@ from torch.distributed import ReduceOp from timm.utils import unwrap_model -from .device_env import DeviceEnv, DeviceEnvType -from .device_env_factory import get_device +from .device_env import DeviceEnv TensorSeq = Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor], Dict[Any, torch.Tensor]] @@ -22,7 +21,7 @@ def _validate_type(tensor: TensorSeq): def distribute_bn(model: torch.nn.Module, reduce: bool = False, dev_env: DeviceEnv = None): if dev_env is None: - dev_env = get_device() + dev_env = DeviceEnv.instance() # ensure every node has the same running bn stats for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True): if ('running_mean' in bn_name) or ('running_var' in bn_name): @@ -40,7 +39,7 @@ def all_gather_recursive(tensor: TensorSeq, cat_dim=0, dev_env: DeviceEnv = None """ _validate_type(tensor) if dev_env is None: - dev_env = get_device() + dev_env = DeviceEnv.instance() if isinstance(tensor, torch.Tensor): return dev_env.all_gather(tensor, cat_dim=cat_dim) elif isinstance(tensor, dict): @@ -55,7 +54,7 @@ def all_reduce_recursive(tensor: TensorSeq, op=ReduceOp.SUM, average=False, dev_ """ _validate_type(tensor) if dev_env is None: - dev_env = get_device() + dev_env = DeviceEnv.instance() if isinstance(tensor, torch.Tensor): return dev_env.all_reduce_(tensor, op=op, average=average) elif isinstance(tensor, dict): @@ -70,7 +69,7 @@ def broadcast_recursive(tensor: TensorSeq, src_rank: int, dev_env: DeviceEnv = N """ _validate_type(tensor) if dev_env is None: - dev_env = get_device() + dev_env = DeviceEnv.instance() if isinstance(tensor, torch.Tensor): return dev_env.broadcast_(tensor, src_rank=src_rank) elif isinstance(tensor, dict): @@ -85,7 +84,7 @@ def all_gather_sequence(tensor: TensorSeq, cat_dim: int = 0, dev_env: DeviceEnv """ _validate_type(tensor) if dev_env is None: - dev_env = get_device() + dev_env = DeviceEnv.instance() with torch.no_grad(): names = None @@ -124,7 +123,7 @@ def all_reduce_sequence(tensor: TensorSeq, op=ReduceOp.SUM, average=False, dev_e """ _validate_type(tensor) if dev_env is None: - dev_env = get_device() + dev_env = DeviceEnv.instance() with torch.no_grad(): names = None diff --git a/timm/bits/metric.py b/timm/bits/metric.py index 7a5cc997..b18282b8 100644 --- a/timm/bits/metric.py +++ b/timm/bits/metric.py @@ -6,14 +6,13 @@ import torch from torch.distributed import ReduceOp from .device_env import DeviceEnv -from .device_env_factory import get_device from .distributed import all_gather_sequence, all_reduce_sequence -MetricValue = Union[float, torch.Tensor, List[float], List[torch.Tensor]] +MetricValueT = Union[float, torch.Tensor, List[float], List[torch.Tensor]] @dataclass class ValueInfo: - initial: Optional[MetricValue] = 0. + initial: Optional[MetricValueT] = 0. dtype: torch.dtype = torch.float32 dist_reduce: str = 'sum' dist_average: bool = False @@ -23,10 +22,10 @@ class Metric(abc.ABC): def __init__(self, dev_env: DeviceEnv = None): self._infos: Dict[str, ValueInfo] = {} - self._values: Dict[str, Optional[MetricValue]] = {} - self._values_dist: Dict[str, Optional[MetricValue]] = {} + self._values: Dict[str, Optional[MetricValueT]] = {} + self._values_dist: Dict[str, Optional[MetricValueT]] = {} if dev_env is None: - dev_env = get_device() + dev_env = DeviceEnv.instance() self._dev_env = dev_env def _register_value(self, name: str, info: Optional[ValueInfo] = None): @@ -117,7 +116,7 @@ class Metric(abc.ABC): names.append(name) values.append(value) reductions.append(_args(info.dist_reduce)) - same_dsr = False + if same_dsr: do_gather, reduce_kwargs = reductions[0] if do_gather: diff --git a/timm/bits/logger.py b/timm/bits/monitor.py similarity index 98% rename from timm/bits/logger.py rename to timm/bits/monitor.py index a7948a8b..af397e1a 100644 --- a/timm/bits/logger.py +++ b/timm/bits/monitor.py @@ -21,8 +21,6 @@ except ImportError: HAS_WANDB = False -from .device_env_factory import get_device - # FIXME old formatting for reference, to remove # # def log_eval(batch_idx, last_idx, batch_time, loss, top1, top5, log_suffix=''): @@ -122,19 +120,19 @@ def _add_kwargs(text_update, name_map=None, **kwargs): text_update += [_to_str(name, v)] -class Logger: +class Monitor: def __init__( self, experiment_name=None, output_dir=None, - python_logger=None, + logger=None, hparams=None, log_wandb=False, output_enabled=True, ): self.output_dir = output_dir # for tensorboard, csv, text file (TODO) logging - self.logger = python_logger or logging.getLogger('log') + self.logger = logger or logging.getLogger('log') hparams = hparams or {} # Setup CSV writer(s) diff --git a/timm/bits/train_services.py b/timm/bits/train_services.py index 286a4afc..5ead002d 100644 --- a/timm/bits/train_services.py +++ b/timm/bits/train_services.py @@ -1,13 +1,13 @@ from dataclasses import dataclass -from .logger import Logger -from timm.utils.checkpoint_saver import CheckpointSaver +from .monitor import Monitor +from .checkpoint_manager import CheckpointManager @dataclass class TrainServices: """ Train Loop Services """ - logger: Logger = None - saver: CheckpointSaver = None + logger: Monitor = None + checkpoint_manager: CheckpointManager = None diff --git a/timm/bits/train_setup.py b/timm/bits/train_setup.py index 3884958b..1480de63 100644 --- a/timm/bits/train_setup.py +++ b/timm/bits/train_setup.py @@ -13,7 +13,7 @@ try: except ImportError: ds = None -from .checkpoint import resume_train_checkpoint +from .checkpoint import load_train_state from .device_env import DeviceEnv from .train_cfg import TrainCfg from .train_state import TrainState @@ -90,10 +90,10 @@ def setup_model_and_optimizer( if resume_path: # FIXME this is not implemented yet, do a hack job before proper TrainState serialization? - resume_train_checkpoint( + load_train_state( train_state, resume_path, - resume_opt=resume_opt, + load_opt=resume_opt, log_info=dev_env.primary) if dev_env.distributed: @@ -141,10 +141,10 @@ def setup_model_and_optimizer_deepspeed( if resume_path: # FIXME deepspeed resumes differently - resume_train_checkpoint( + load_legacy_checkpoint( train_state, resume_path, - resume_opt=resume_opt, + load_opt=resume_opt, log_info=dev_env.primary) if dev_env.distributed: diff --git a/timm/bits/train_state.py b/timm/bits/train_state.py index 9a9a0d92..9c47b5fd 100644 --- a/timm/bits/train_state.py +++ b/timm/bits/train_state.py @@ -4,6 +4,8 @@ from typing import Dict, Any from torch import nn as nn from timm.scheduler import Scheduler +from timm.utils import get_state_dict, unwrap_model + from .updater import Updater @@ -16,18 +18,33 @@ class TrainState: lr_scheduler: Scheduler = None model_ema: nn.Module = None - step_count_epoch: int = 0 - step_count_global: int = 0 epoch: int = 0 + step_count: int = 0 + step_count_global: int = 0 def __post_init__(self): assert self.model is not None assert self.updater is not None - -def serialize_train_state(train_state: TrainState): - pass - - -def deserialize_train_state(checkpoint: Dict[str, Any]): - pass \ No newline at end of file + def state_dict(self, unwrap_fn=unwrap_model): + state = dict( + epoch=self.epoch, + step_count=self.step_count, + step_count_global=self.step_count_global, + model=get_state_dict(self.model, unwrap_fn), + model_ema=None if self.model_ema is None else get_state_dict(self.model_ema, unwrap_fn), + ) + # FIXME lr_scheduler state save? + state.update(self.updater.state_dict()) + return state + + def load_state_dict(self, state_dict, unwrap_fn=unwrap_model): + self.epoch = state_dict['epoch'] + self.step_count = state_dict['step_count'] + self.step_count_global = state_dict['step_count_global'] + + unwrap_fn(self.model).load_state_dict(state_dict.get('model')) + if 'model_ema' in state_dict and self.model_ema is not None: + unwrap_fn(self.model_ema).load_state_dict(state_dict.get('model_ema')) + + self.updater.load_state_dict(state_dict) diff --git a/timm/bits/updater.py b/timm/bits/updater.py index 422d12ec..0bf1c451 100644 --- a/timm/bits/updater.py +++ b/timm/bits/updater.py @@ -56,6 +56,7 @@ class Updater: state_dict = dict(optimizer=self.optimizer.state_dict()) if self.grad_scaler is not None: state_dict['grad_scaler'] = self.grad_scaler.state_dict() + return state_dict def load_state_dict(self, state_dict): if 'optimizer' in state_dict: @@ -66,3 +67,6 @@ class Updater: def after_step(self, after_step_fn, *args): after_step_fn(*args) + @property + def deepspeed(self): + return False diff --git a/timm/bits/updater_deepspeed.py b/timm/bits/updater_deepspeed.py index e080a7de..f3c4b3b0 100644 --- a/timm/bits/updater_deepspeed.py +++ b/timm/bits/updater_deepspeed.py @@ -24,3 +24,7 @@ class UpdaterDeepSpeed(Updater): self.model.backward(loss) self.model.step() self.reset() + + @property + def deepspeed(self): + return True diff --git a/timm/bits/updater_factory.py b/timm/bits/updater_factory.py index 24ef76c0..c3fd9e45 100644 --- a/timm/bits/updater_factory.py +++ b/timm/bits/updater_factory.py @@ -3,7 +3,6 @@ from typing import Callable, Optional, Union, Any import torch from .device_env import DeviceEnv, DeviceEnvType -from .device_env_factory import get_device from .updater import Updater from .updater_cuda import UpdaterCudaWithScaler from .updater_deepspeed import UpdaterDeepSpeed @@ -21,7 +20,7 @@ def create_updater( ) -> Updater: if not dev_env: - dev_env = get_device() + dev_env = DeviceEnv.instance() updater_kwargs = dict(model=model, optimizer=optimizer, clip_fn=clip_fn, clip_value=clip_value) use_scaler = dev_env.amp diff --git a/timm/data/loader.py b/timm/data/loader.py index 5ddcc6d2..e8722b29 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -8,7 +8,7 @@ Hacked together by / Copyright 2020 Ross Wightman import torch.utils.data -from timm.bits import get_device, DeviceEnvType +from timm.bits import DeviceEnv from .fetcher import Fetcher from .prefetcher_cuda import PrefetcherCuda @@ -75,7 +75,7 @@ def create_loader( ) if dev_env is None: - dev_env = get_device() + dev_env = DeviceEnv.instance() sampler = None if dev_env.distributed and not isinstance(dataset, torch.utils.data.IterableDataset): diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/parsers/parser_tfds.py index 519be03d..32dac26d 100644 --- a/timm/data/parsers/parser_tfds.py +++ b/timm/data/parsers/parser_tfds.py @@ -23,7 +23,7 @@ except ImportError as e: exit(1) from .parser import Parser -from timm.bits import get_device +from timm.bits import get_global_device MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities SHUFFLE_SIZE = 16834 # samples to shuffle in DS queue @@ -80,7 +80,7 @@ class ParserTfds(Parser): self.worker_info = None self.dist_rank = 0 self.dist_num_replicas = 1 - dev_env = get_device() + dev_env = get_global_device() # FIXME allow to work without devenv usage? if dev_env.distributed and dev_env.world_size > 1: self.dist_rank = dev_env.global_rank diff --git a/timm/utils/model.py b/timm/utils/model.py index bd46e2f4..66f7480e 100644 --- a/timm/utils/model.py +++ b/timm/utils/model.py @@ -3,33 +3,38 @@ Hacked together by / Copyright 2020 Ross Wightman """ from .model_ema import ModelEma -import torch +import torch import fnmatch -def unwrap_model(model): - if isinstance(model, ModelEma): - return unwrap_model(model.ema) - else: - return model.module if hasattr(model, 'module') else model +_SUB_MODULE_ATTR = ('module', 'model') + + +def unwrap_model(model, recursive=True): + for attr in _SUB_MODULE_ATTR: + sub_module = getattr(model, attr, None) + if sub_module is not None: + return unwrap_model(sub_module) if recursive else sub_module + return model def get_state_dict(model, unwrap_fn=unwrap_model): return unwrap_fn(model).state_dict() -def avg_sq_ch_mean(model, input, output): - "calculate average channel square mean of output activations" - return torch.mean(output.mean(axis=[0,2,3])**2).item() +def avg_sq_ch_mean(model, input, output): + """calculate average channel square mean of output activations + """ + return torch.mean(output.mean(axis=[0, 2, 3]) ** 2).item() -def avg_ch_var(model, input, output): - "calculate average channel variance of output activations" - return torch.mean(output.var(axis=[0,2,3])).item()\ +def avg_ch_var(model, input, output): + """calculate average channel variance of output activations""" + return torch.mean(output.var(axis=[0, 2, 3])).item() -def avg_ch_var_residual(model, input, output): - "calculate average channel variance of output activations" - return torch.mean(output.var(axis=[0,2,3])).item() +def avg_ch_var_residual(model, input, output): + """calculate average channel variance of output activations""" + return torch.mean(output.var(axis=[0, 2, 3])).item() class ActivationStatsHook: @@ -58,15 +63,16 @@ class ActivationStatsHook: raise ValueError("Please provide `hook_fns` for each `hook_fn_locs`, \ their lengths are different.") self.stats = dict((hook_fn.__name__, []) for hook_fn in hook_fns) - for hook_fn_loc, hook_fn in zip(hook_fn_locs, hook_fns): + for hook_fn_loc, hook_fn in zip(hook_fn_locs, hook_fns): self.register_hook(hook_fn_loc, hook_fn) def _create_hook(self, hook_fn): def append_activation_stats(module, input, output): out = hook_fn(module, input, output) self.stats[hook_fn.__name__].append(out) + return append_activation_stats - + def register_hook(self, hook_fn_loc, hook_fn): for name, module in self.model.named_modules(): if not fnmatch.fnmatch(name, hook_fn_loc): @@ -74,9 +80,9 @@ class ActivationStatsHook: module.register_forward_hook(self._create_hook(hook_fn)) -def extract_spp_stats(model, +def extract_spp_stats(model, hook_fn_locs, - hook_fns, + hook_fns, input_shape=[8, 3, 224, 224]): """Extract average square channel mean and variance of activations during forward pass to plot Signal Propogation Plots (SPP). @@ -84,9 +90,8 @@ def extract_spp_stats(model, Paper: https://arxiv.org/abs/2101.08692 Example Usage: https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950 - """ + """ x = torch.normal(0., 1., input_shape) hook = ActivationStatsHook(model, hook_fn_locs=hook_fn_locs, hook_fns=hook_fns) _ = model(x) return hook.stats - \ No newline at end of file diff --git a/train.py b/train.py index 51645e4d..c6142542 100755 --- a/train.py +++ b/train.py @@ -28,14 +28,14 @@ import torch import torch.nn as nn import torchvision.utils -from timm.bits import initialize_device, setup_model_and_optimizer, DeviceEnv, Logger, Tracker,\ - TrainState, TrainServices, TrainCfg, AccuracyTopK, AvgTensor, distribute_bn +from timm.bits import initialize_device, setup_model_and_optimizer, DeviceEnv, Monitor, Tracker,\ + TrainState, TrainServices, TrainCfg, CheckpointManager, AccuracyTopK, AvgTensor, distribute_bn from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset from timm.models import create_model, safe_model_name, convert_splitbn_model from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy -from timm.optim import create_optimizer_v2, optimizer_kwargs +from timm.optim import optimizer_kwargs from timm.scheduler import create_scheduler -from timm.utils import setup_default_logging, random_seed, get_outdir, CheckpointSaver +from timm.utils import setup_default_logging, random_seed, get_outdir, unwrap_model _logger = logging.getLogger('train') @@ -276,7 +276,7 @@ def main(): setup_default_logging() args, args_text = _parse_args() - dev_env = initialize_device(amp=args.amp) + dev_env = initialize_device(amp=args.amp, channels_last=args.channels_last) if dev_env.distributed: _logger.info('Training in distributed mode with multiple processes, 1 device per process. Process %d, total %d.' % (dev_env.global_rank, dev_env.world_size)) @@ -293,13 +293,17 @@ def main(): # FIXME perhaps keep the same and just set diff seeds for dataloader worker process? what about TFDS? random_seed(args.seed, dev_env.global_rank) - data_config, loader_eval, loader_train = setup_data(args, train_state.model.default_cfg, dev_env, mixup_active) + data_config, loader_eval, loader_train = setup_data( + args, + unwrap_model(train_state.model).default_cfg, + dev_env, + mixup_active) - # setup checkpoint saver + # setup checkpoint manager eval_metric = args.eval_metric best_metric = None best_epoch = None - saver = None + checkpoint_manager = None output_dir = None if dev_env.primary: if args.experiment: @@ -311,24 +315,20 @@ def main(): str(data_config['input_size'][-1]) ]) output_dir = get_outdir(args.output if args.output else './output/train', exp_name) - decreasing = True if eval_metric == 'loss' else False - saver = CheckpointSaver( # TODO CheckpointSaverV2 - model=train_state.model, - optimizer=train_state.updater.optimizer, - args=args, - model_ema=train_state.model_ema, - amp_scaler=train_state.updater.grad_scaler, + checkpoint_manager = CheckpointManager( + hparams=vars(args), checkpoint_dir=output_dir, recovery_dir=output_dir, - decreasing=decreasing, + metric_name=eval_metric, + metric_decreasing=True if eval_metric == 'loss' else False, max_history=args.checkpoint_hist) with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) services = TrainServices( - logger=Logger( - output_dir=output_dir, python_logger=_logger, hparams=vars(args), output_enabled=dev_env.primary), - saver=saver, + logger=Monitor( + output_dir=output_dir, logger=_logger, hparams=vars(args), output_enabled=dev_env.primary), + checkpoint_manager=checkpoint_manager, ) try: @@ -379,10 +379,10 @@ def main(): if services.logger is not None: services.logger.write_summary(index=epoch, results=dict(train=train_metrics, eval=eval_metrics)) - if saver is not None: + if checkpoint_manager is not None: # save proper checkpoint with eval metric - save_metric = eval_metrics[eval_metric] - best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric) + best_checkpoint = checkpoint_manager.save_checkpoint(train_state, eval_metrics) + best_metric, best_epoch = best_checkpoint.sort_key, best_checkpoint.epoch train_state = replace(train_state, epoch=epoch + 1) @@ -629,9 +629,9 @@ def after_train_step( lr=lr_avg, ) - if services.saver is not None and cfg.recovery_interval and ( + if services.checkpoint_manager is not None and cfg.recovery_interval and ( end_step or (step_idx + 1) % cfg.recovery_interval == 0): - services.saver.save_recovery(state.epoch, batch_idx=step_idx) + services.checkpoint_manager.save_recovery(state.epoch, batch_idx=step_idx) if state.lr_scheduler is not None: state.lr_scheduler.step_update(num_updates=state.step_count_global) @@ -641,7 +641,7 @@ def evaluate( model: nn.Module, loss_fn: nn.Module, loader, - logger: Logger, + logger: Monitor, dev_env: DeviceEnv, phase_suffix: str = '', log_interval: int = 10, diff --git a/validate.py b/validate.py index 89d70982..cee359c3 100755 --- a/validate.py +++ b/validate.py @@ -18,7 +18,7 @@ import torch.nn as nn import torch.nn.parallel from collections import OrderedDict -from timm.bits import initialize_device, Tracker, Logger, AccuracyTopK, AvgTensor +from timm.bits import initialize_device, Tracker, Monitor, AccuracyTopK, AvgTensor from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet from timm.utils import natural_key, setup_default_logging @@ -154,7 +154,7 @@ def validate(args): pin_memory=args.pin_mem, tf_preprocessing=args.tf_preprocessing) - logger = Logger(python_logger=_logger) + logger = Monitor(logger=_logger) tracker = Tracker() losses = AvgTensor() accuracy = AccuracyTopK(dev_env=dev_env)