From 5b9c69e80a5ff4ecfd62429246e51a1de0f834fe Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 21 May 2021 18:08:06 -0700
Subject: [PATCH] Add basic training resume based on legacy code

---
 timm/bits/checkpoint.py | 90 ++++++++++++++++++++++-------------------
 train.py                | 25 ++++++------
 2 files changed, 62 insertions(+), 53 deletions(-)

diff --git a/timm/bits/checkpoint.py b/timm/bits/checkpoint.py
index 3c191b0a..b7ff1909 100644
--- a/timm/bits/checkpoint.py
+++ b/timm/bits/checkpoint.py
@@ -4,55 +4,63 @@ from collections import OrderedDict
 
 import torch
 
+from timm.utils import unwrap_model
+
 from .train_state import TrainState, serialize_train_state, deserialize_train_state
 
 _logger = logging.getLogger(__name__)
 
 
+def _load_state_dict(checkpoint, state_dict_key='state_dict'):
+    new_state_dict = OrderedDict()
+    for k, v in checkpoint[state_dict_key].items():
+        name = k[7:] if k.startswith('module') else k
+        new_state_dict[name] = v
+    return new_state_dict
+
+
 def resume_train_checkpoint(
-        train_state,
+        train_state: TrainState,
         checkpoint_path,
         resume_opt=True,
         deserialize_fn=deserialize_train_state,
         log_info=True):
-    raise NotImplementedError
-
-    # resume_epoch = None
-    # if os.path.isfile(checkpoint_path):
-    #     checkpoint = torch.load(checkpoint_path, map_location='cpu')
-    #
-    #     if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
-    #         if log_info:
-    #             _logger.info('Restoring model state from checkpoint...')
-    #         new_state_dict = OrderedDict()
-    #         for k, v in checkpoint['state_dict'].items():
-    #             name = k[7:] if k.startswith('module') else k
-    #             new_state_dict[name] = v
-    #         model.load_state_dict(new_state_dict)
-    #
-    #         if optimizer is not None and 'optimizer' in checkpoint:
-    #             if log_info:
-    #                 _logger.info('Restoring optimizer state from checkpoint...')
-    #             optimizer.load_state_dict(checkpoint['optimizer'])
-    #
-    #         if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
-    #             if log_info:
-    #                 _logger.info('Restoring AMP loss scaler state from checkpoint...')
-    #             loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
-    #
-    #         if 'epoch' in checkpoint:
-    #             resume_epoch = checkpoint['epoch']
-    #             if 'version' in checkpoint and checkpoint['version'] > 1:
-    #                 resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save
-    #
-    #         if log_info:
-    #             _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
-    #     else:
-    #         model.load_state_dict(checkpoint)
-    #         if log_info:
-    #             _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
-    #     return resume_epoch
-    # else:
-    #     _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
-    #     raise FileNotFoundError()
+    # FIXME this is a hacky adaptation of pre-bits resume to get up and running quickly
+    resume_epoch = None
+    if os.path.isfile(checkpoint_path):
+        checkpoint = torch.load(checkpoint_path, map_location='cpu')
+        assert isinstance(checkpoint, dict) and 'state_dict' in checkpoint
+        if log_info:
+            _logger.info('Restoring model state from checkpoint...')
+
+        train_state.model.load_state_dict(_load_state_dict(checkpoint))
+
+        if train_state.model_ema is not None and 'state_dict_ema' in checkpoint:
+            if log_info:
+                _logger.info('Restoring model (EMA) state from checkpoint...')
+            unwrap_model(train_state.model_ema).load_state_dict(_load_state_dict(checkpoint, 'state_dict_ema'))
+
+        if resume_opt:
+            if train_state.updater.optimizer is not None and 'optimizer' in checkpoint:
+                if log_info:
+                    _logger.info('Restoring optimizer state from checkpoint...')
+                train_state.updater.optimizer.load_state_dict(checkpoint['optimizer'])
+
+            scaler_state_dict_key = 'amp_scaler'
+            if train_state.updater.grad_scaler is not None and scaler_state_dict_key in checkpoint:
+                if log_info:
+                    _logger.info('Restoring AMP loss scaler state from checkpoint...')
+                train_state.updater.grad_scaler.load_state_dict(checkpoint[scaler_state_dict_key])
+
+        if 'epoch' in checkpoint:
+            resume_epoch = checkpoint['epoch']
+            if 'version' in checkpoint and checkpoint['version'] > 1:
+                resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save
+            train_state.epoch = resume_epoch  # FIXME use replace if we make train_state read-only
+
+        if log_info:
+            _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
+    else:
+        _logger.error("No valid resume checkpoint found at '{}'".format(checkpoint_path))
+        raise FileNotFoundError()
diff --git a/train.py b/train.py
index 95f5cb7e..51645e4d 100755
--- a/train.py
+++ b/train.py
@@ -340,11 +340,11 @@ def main():
                 loader_train.mixup_enabled = False
 
             train_metrics = train_one_epoch(
-                dev_env=dev_env,
                 state=train_state,
-                services=services,
                 cfg=train_cfg,
-                loader=loader_train
+                services=services,
+                loader=loader_train,
+                dev_env=dev_env,
             )
 
             if dev_env.distributed and args.dist_bn in ('broadcast', 'reduce'):
@@ -356,8 +356,8 @@
                 train_state.model,
                 train_state.eval_loss,
                 loader_eval,
-                dev_env,
-                logger=services.logger)
+                services.logger,
+                dev_env)
 
             if train_state.model_ema is not None and not args.model_ema_force_cpu:
                 if dev_env.distributed and args.dist_bn in ('broadcast', 'reduce'):
@@ -367,8 +367,8 @@
                     train_state.model_ema.module,
                     train_state.eval_loss,
                     loader_eval,
+                    services.logger,
                     dev_env,
-                    logger=services.logger,
                     phase_suffix='EMA')
                 eval_metrics = ema_eval_metrics
 
@@ -432,6 +432,7 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):
         clip_value=args.clip_grad,
         model_ema=args.model_ema,
         model_ema_decay=args.model_ema_decay,
+        resume_path=args.resume,
         use_syncbn=args.sync_bn,
     )
 
@@ -543,11 +544,11 @@ def setup_data(args, default_cfg, dev_env, mixup_active):
 
 
 def train_one_epoch(
-        dev_env: DeviceEnv,
         state: TrainState,
         cfg: TrainCfg,
         services: TrainServices,
         loader,
+        dev_env: DeviceEnv,
 ):
     tracker = Tracker()
     loss_meter = AvgTensor()
@@ -571,10 +572,10 @@
 
             state.updater.after_step(
                 after_train_step,
-                dev_env,
                 state,
-                services,
                 cfg,
+                services,
+                dev_env,
                 step_idx,
                 step_end_idx,
                 tracker,
@@ -592,10 +593,10 @@
 
 
 def after_train_step(
-        dev_env: DeviceEnv,
         state: TrainState,
-        services: TrainServices,
         cfg: TrainCfg,
+        services: TrainServices,
+        dev_env: DeviceEnv,
         step_idx: int,
         step_end_idx: int,
         tracker: Tracker,
@@ -640,8 +641,8 @@ def evaluate(
         model: nn.Module,
         loss_fn: nn.Module,
         loader,
-        dev_env: DeviceEnv,
         logger: Logger,
+        dev_env: DeviceEnv,
         phase_suffix: str = '',
         log_interval: int = 10,
 ):
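
Note (not part of the patch): a minimal sketch of how the new resume path might be driven from setup code. It assumes a populated TrainState whose updater holds the optimizer (and optionally an AMP grad scaler), and a checkpoint path coming from the --resume CLI argument that the patch threads through as resume_path; the maybe_resume helper name is hypothetical, only resume_train_checkpoint comes from the patch itself.

from timm.bits.checkpoint import resume_train_checkpoint


def maybe_resume(train_state, resume_path, resume_opt=True):
    # Hypothetical helper: restore model / EMA / optimizer / scaler state
    # in place when a resume path is given, else leave the state untouched.
    if resume_path:
        resume_train_checkpoint(train_state, resume_path, resume_opt=resume_opt)
        # train_state.epoch now holds the epoch to start from (the next epoch
        # for version > 1 checkpoints, which increment before saving).
    return train_state

Since resume_train_checkpoint raises FileNotFoundError on a bad path rather than silently starting fresh, a caller that wants optional resume needs the existence check on its side or must catch that exception.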