diff --git a/timm/models/helpers.py b/timm/models/helpers.py index f1702af0..ac119295 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -48,30 +48,41 @@ def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True): model.load_state_dict(state_dict, strict=strict) -def resume_checkpoint(model, checkpoint_path): - other_state = {} +def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True): resume_epoch = None if os.path.isfile(checkpoint_path): checkpoint = torch.load(checkpoint_path, map_location='cpu') if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + if log_info: + _logger.info('Restoring model state from checkpoint...') new_state_dict = OrderedDict() for k, v in checkpoint['state_dict'].items(): name = k[7:] if k.startswith('module') else k new_state_dict[name] = v model.load_state_dict(new_state_dict) - if 'optimizer' in checkpoint: - other_state['optimizer'] = checkpoint['optimizer'] - if 'amp' in checkpoint: - other_state['amp'] = checkpoint['amp'] + + if optimizer is not None and 'optimizer' in checkpoint: + if log_info: + _logger.info('Restoring optimizer state from checkpoint...') + optimizer.load_state_dict(checkpoint['optimizer']) + + if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint: + if log_info: + _logger.info('Restoring AMP loss scaler state from checkpoint...') + loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key]) + if 'epoch' in checkpoint: resume_epoch = checkpoint['epoch'] if 'version' in checkpoint and checkpoint['version'] > 1: resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save - _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) + + if log_info: + _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) else: model.load_state_dict(checkpoint) - _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) - return other_state, resume_epoch + if log_info: + _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) + return resume_epoch else: _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) raise FileNotFoundError() diff --git a/timm/utils.py b/timm/utils.py index 6c10d283..94f85d84 100644 --- a/timm/utils.py +++ b/timm/utils.py @@ -37,20 +37,67 @@ def unwrap_model(model): return model.module if hasattr(model, 'module') else model -def get_state_dict(model): - return unwrap_model(model).state_dict() +def get_state_dict(model, unwrap_fn=unwrap_model): + return unwrap_fn(model).state_dict() + + +class ApexScaler: + state_dict_key = "amp" + + def __call__(self, loss, optimizer): + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + optimizer.step() + + def state_dict(self): + if 'state_dict' in amp.__dict__: + return amp.state_dict() + + def load_state_dict(self, state_dict): + if 'load_state_dict' in amp.__dict__: + amp.load_state_dict(state_dict) + + +class NativeScaler: + state_dict_key = "amp_scaler" + + def __init__(self): + self._scaler = torch.cuda.amp.GradScaler() + + def __call__(self, loss, optimizer): + self._scaler.scale(loss).backward() + self._scaler.step(optimizer) + self._scaler.update() + + def state_dict(self): + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) class CheckpointSaver: def __init__( self, + model, + optimizer, + args=None, + model_ema=None, + amp_scaler=None, checkpoint_prefix='checkpoint', recovery_prefix='recovery', checkpoint_dir='', recovery_dir='', decreasing=False, max_history=10, - save_amp=False): + unwrap_fn=unwrap_model): + + # objects to save state_dicts of + self.model = model + self.optimizer = optimizer + self.args = args + self.model_ema = model_ema + self.amp_scaler = amp_scaler # state self.checkpoint_files = [] # (filename, metric) tuples in order of decreasing betterness @@ -68,14 +115,14 @@ class CheckpointSaver: self.decreasing = decreasing # a lower metric is better if True self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs self.max_history = max_history - self.save_apex_amp = save_amp # save APEX amp state + self.unwrap_fn = unwrap_fn assert self.max_history >= 1 - def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None): + def save_checkpoint(self, epoch, metric=None): assert epoch >= 0 tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension) last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension) - self._save(tmp_save_path, model, optimizer, args, epoch, model_ema, metric) + self._save(tmp_save_path, epoch, metric) if os.path.exists(last_save_path): os.unlink(last_save_path) # required for Windows support. os.rename(tmp_save_path, last_save_path) @@ -107,19 +154,21 @@ class CheckpointSaver: return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch) - def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None): + def _save(self, save_path, epoch, metric=None): save_state = { 'epoch': epoch, - 'arch': args.model, - 'state_dict': get_state_dict(model), - 'optimizer': optimizer.state_dict(), - 'args': args, + 'arch': type(self.model).__name__.lower(), + 'state_dict': get_state_dict(self.model, self.unwrap_fn), + 'optimizer': self.optimizer.state_dict(), 'version': 2, # version < 2 increments epoch before save } - if self.save_apex_amp and 'state_dict' in amp.__dict__: - save_state['amp'] = amp.state_dict() - if model_ema is not None: - save_state['state_dict_ema'] = get_state_dict(model_ema) + if self.args is not None: + save_state['arch'] = self.args.model + save_state['args'] = self.args + if self.amp_scaler is not None: + save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict() + if self.model_ema is not None: + save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn) if metric is not None: save_state['metric'] = metric torch.save(save_state, save_path) @@ -138,11 +187,11 @@ class CheckpointSaver: _logger.error("Exception '{}' while deleting checkpoint".format(e)) self.checkpoint_files = self.checkpoint_files[:delete_index] - def save_recovery(self, model, optimizer, args, epoch, model_ema=None, batch_idx=0): + def save_recovery(self, epoch, batch_idx=0): assert epoch >= 0 filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension save_path = os.path.join(self.recovery_dir, filename) - self._save(save_path, model, optimizer, args, epoch, model_ema) + self._save(save_path, epoch) if os.path.exists(self.last_recovery_file): try: _logger.debug("Cleaning recovery: {}".format(self.last_recovery_file)) @@ -336,3 +385,16 @@ def add_bool_arg(parser, name, default=False, help=''): group.add_argument('--' + name, dest=dest_name, action='store_true', help=help) group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help) parser.set_defaults(**{dest_name: default}) + + +def set_jit_legacy(): + """ Set JIT executor to legacy w/ support for op fusion + This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes + in the JIT exectutor. These API are not supported so could change. + """ + # + assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!" + torch._C._jit_set_profiling_executor(False) + torch._C._jit_set_profiling_mode(False) + torch._C._jit_override_can_fuse_on_gpu(True) + #torch._C._jit_set_texpr_fuser_enabled(True) diff --git a/train.py b/train.py index e017605d..ff421d53 100755 --- a/train.py +++ b/train.py @@ -20,7 +20,6 @@ import yaml from datetime import datetime from contextlib import suppress -import torch import torch.nn as nn import torchvision.utils from torch.nn.parallel import DistributedDataParallel as NativeDDP @@ -31,6 +30,7 @@ from timm.utils import * from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy from timm.optim import create_optimizer from timm.scheduler import create_scheduler +from timm.utils import ApexScaler, NativeScaler try: from apex import amp @@ -264,23 +264,6 @@ def _parse_args(): return args, args_text -class ApexScaler: - def __call__(self, loss, optimizer): - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - optimizer.step() - - -class NativeScaler: - def __init__(self): - self._scaler = torch.cuda.amp.GradScaler() - - def __call__(self, loss, optimizer): - self._scaler.scale(loss).backward() - self._scaler.step(optimizer) - self._scaler.update() - - def main(): setup_default_logging() args, args_text = _parse_args() @@ -389,20 +372,13 @@ def main(): _logger.info('AMP not enabled. Training in float32.') # optionally resume from a checkpoint - resume_state = {} resume_epoch = None if args.resume: - resume_state, resume_epoch = resume_checkpoint(model, args.resume) - if resume_state and not args.no_resume_opt: - if 'optimizer' in resume_state: - if args.local_rank == 0: - _logger.info('Restoring optimizer state from checkpoint') - optimizer.load_state_dict(resume_state['optimizer']) - if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__: - if args.local_rank == 0: - _logger.info('Restoring NVIDIA AMP state from checkpoint') - amp.load_state_dict(resume_state['amp']) - del resume_state + resume_epoch = resume_checkpoint( + model, args.resume, + optimizer=None if args.no_resume_opt else optimizer, + loss_scaler=None if args.no_resume_opt else loss_scaler, + log_info=args.local_rank == 0) model_ema = None if args.model_ema: @@ -555,7 +531,9 @@ def main(): ]) output_dir = get_outdir(output_base, 'train', exp_name) decreasing = True if eval_metric == 'loss' else False - saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing, save_amp=use_amp == 'apex') + saver = CheckpointSaver( + model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, + checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing) with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) @@ -594,8 +572,7 @@ def main(): if saver is not None: # save proper checkpoint with eval metric save_metric = eval_metrics[eval_metric] - best_metric, best_epoch = saver.save_checkpoint( - model, optimizer, args, epoch=epoch, model_ema=model_ema, metric=save_metric) + best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric) except KeyboardInterrupt: pass @@ -688,8 +665,7 @@ def train_epoch( if saver is not None and args.recovery_interval and ( last_batch or (batch_idx + 1) % args.recovery_interval == 0): - - saver.save_recovery(model, optimizer, args, epoch, model_ema=model_ema, batch_idx=batch_idx) + saver.save_recovery(epoch, batch_idx=batch_idx) if lr_scheduler is not None: lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) diff --git a/validate.py b/validate.py index f20c35ba..587ac6a6 100755 --- a/validate.py +++ b/validate.py @@ -21,7 +21,7 @@ from contextlib import suppress from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config, RealLabelsImagenet -from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging +from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy has_apex = False try: @@ -102,19 +102,6 @@ parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME', help='Valid label indices txt file for validation of partial label space') -def set_jit_legacy(): - """ Set JIT executor to legacy w/ support for op fusion - This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes - in the JIT exectutor. These API are not supported so could change. - """ - # - assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!" - torch._C._jit_set_profiling_executor(False) - torch._C._jit_set_profiling_mode(False) - torch._C._jit_override_can_fuse_on_gpu(True) - #torch._C._jit_set_texpr_fuser_enabled(True) - - def validate(args): # might as well try to validate something args.pretrained = args.pretrained or not args.checkpoint