Add support for new AMP checkpointing support w/ amp.state_dict

pull/32/head
Ross Wightman 5 years ago
parent ba3c97c3ad
commit 3d9c8a6489

@ -29,7 +29,7 @@ def load_checkpoint(model, checkpoint_path, use_ema=False):
def resume_checkpoint(model, checkpoint_path):
optimizer_state = None
other_state = {}
resume_epoch = None
if os.path.isfile(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location='cpu')
@ -40,7 +40,9 @@ def resume_checkpoint(model, checkpoint_path):
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
if 'optimizer' in checkpoint:
optimizer_state = checkpoint['optimizer']
other_state['optimizer'] = checkpoint['optimizer']
if 'amp' in checkpoint:
other_state['amp'] = checkpoint['amp']
if 'epoch' in checkpoint:
resume_epoch = checkpoint['epoch']
if 'version' in checkpoint and checkpoint['version'] > 1:
@ -49,7 +51,7 @@ def resume_checkpoint(model, checkpoint_path):
else:
model.load_state_dict(checkpoint)
logging.info("Loaded checkpoint '{}'".format(checkpoint_path))
return optimizer_state, resume_epoch
return other_state, resume_epoch
else:
logging.error("No checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError()

@ -11,6 +11,12 @@ import operator
import logging
import numpy as np
from collections import OrderedDict
try:
from apex import amp
has_apex = True
except ImportError:
amp = None
has_apex = False
from torch import distributed as dist
@ -50,7 +56,7 @@ class CheckpointSaver:
self.max_history = max_history
assert self.max_history >= 1
def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None):
def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None, use_amp=False):
assert epoch >= 0
worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
if (len(self.checkpoint_files) < self.max_history
@ -59,7 +65,7 @@ class CheckpointSaver:
self._cleanup_checkpoints(1)
filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
save_path = os.path.join(self.checkpoint_dir, filename)
self._save(save_path, model, optimizer, args, epoch, model_ema, metric)
self._save(save_path, model, optimizer, args, epoch, model_ema, metric, use_amp)
self.checkpoint_files.append((save_path, metric))
self.checkpoint_files = sorted(
self.checkpoint_files, key=lambda x: x[1],
@ -77,7 +83,7 @@ class CheckpointSaver:
return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)
def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None):
def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None, use_amp=False):
save_state = {
'epoch': epoch,
'arch': args.model,
@ -86,6 +92,8 @@ class CheckpointSaver:
'args': args,
'version': 2, # version < 2 increments epoch before save
}
if use_amp and 'state_dict' in amp.__dict__:
save_state['amp'] = amp.state_dict()
if model_ema is not None:
save_state['state_dict_ema'] = get_state_dict(model_ema)
if metric is not None:
@ -106,11 +114,11 @@ class CheckpointSaver:
logging.error("Exception '{}' while deleting checkpoint".format(e))
self.checkpoint_files = self.checkpoint_files[:delete_index]
def save_recovery(self, model, optimizer, args, epoch, model_ema=None, batch_idx=0):
def save_recovery(self, model, optimizer, args, epoch, model_ema=None, use_amp=False, batch_idx=0):
assert epoch >= 0
filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
save_path = os.path.join(self.recovery_dir, filename)
self._save(save_path, model, optimizer, args, epoch, model_ema)
self._save(save_path, model, optimizer, args, epoch, model_ema, use_amp=use_amp)
if os.path.exists(self.last_recovery_file):
try:
logging.debug("Cleaning recovery: {}".format(self.last_recovery_file))

@ -38,6 +38,8 @@ parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH'
help='Initialize model from this checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument('--no-resume-opt', action='store_true', default=False,
help='prevent resume of optimizer state when resuming model')
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
help='number of label classes (default: 1000)')
parser.add_argument('--gp', default='avg', type=str, metavar='POOL',
@ -189,12 +191,6 @@ def main():
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
# optionally resume from a checkpoint
optimizer_state = None
resume_epoch = None
if args.resume:
optimizer_state, resume_epoch = resume_checkpoint(model, args.resume)
if args.num_gpu > 1:
if args.amp:
logging.warning(
@ -205,8 +201,6 @@ def main():
model.cuda()
optimizer = create_optimizer(args, model)
if optimizer_state is not None:
optimizer.load_state_dict(optimizer_state)
use_amp = False
if has_apex and args.amp:
@ -216,6 +210,22 @@ def main():
logging.info('NVIDIA APEX {}. AMP {}.'.format(
'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))
# optionally resume from a checkpoint
resume_state = {}
resume_epoch = None
if args.resume:
resume_state, resume_epoch = resume_checkpoint(model, args.resume)
if resume_state and not args.no_resume_opt:
if 'optimizer' in resume_state:
if args.local_rank == 0:
logging.info('Restoring Optimizer state from checkpoint')
optimizer.load_state_dict(resume_state['optimizer'])
if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
if args.local_rank == 0:
logging.info('Restoring NVIDIA AMP state from checkpoint')
amp.load_state_dict(resume_state['amp'])
resume_state = None
model_ema = None
if args.model_ema:
# Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
@ -363,7 +373,7 @@ def main():
save_metric = eval_metrics[eval_metric]
best_metric, best_epoch = saver.save_checkpoint(
model, optimizer, args,
epoch=epoch, model_ema=model_ema, metric=save_metric)
epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=use_amp)
except KeyboardInterrupt:
pass
@ -456,7 +466,7 @@ def train_epoch(
if saver is not None and args.recovery_interval and (
last_batch or (batch_idx + 1) % args.recovery_interval == 0):
saver.save_recovery(
model, optimizer, args, epoch, model_ema=model_ema, batch_idx=batch_idx)
model, optimizer, args, epoch, model_ema=model_ema, use_amp=use_amp, batch_idx=batch_idx)
if lr_scheduler is not None:
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

Loading…
Cancel
Save