Cleanup Apex vs native AMP scaler state save/load. Cleanup CheckpointSaver a bit.

pull/233/head
Ross Wightman 5 years ago
parent 80c9d9cc72
commit 9c297ec67d

@@ -48,30 +48,41 @@ def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
     model.load_state_dict(state_dict, strict=strict)


-def resume_checkpoint(model, checkpoint_path):
-    other_state = {}
+def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
     resume_epoch = None
     if os.path.isfile(checkpoint_path):
         checkpoint = torch.load(checkpoint_path, map_location='cpu')
         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
+            if log_info:
+                _logger.info('Restoring model state from checkpoint...')
             new_state_dict = OrderedDict()
             for k, v in checkpoint['state_dict'].items():
                 name = k[7:] if k.startswith('module') else k
                 new_state_dict[name] = v
             model.load_state_dict(new_state_dict)
-            if 'optimizer' in checkpoint:
-                other_state['optimizer'] = checkpoint['optimizer']
-            if 'amp' in checkpoint:
-                other_state['amp'] = checkpoint['amp']
+
+            if optimizer is not None and 'optimizer' in checkpoint:
+                if log_info:
+                    _logger.info('Restoring optimizer state from checkpoint...')
+                optimizer.load_state_dict(checkpoint['optimizer'])
+
+            if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
+                if log_info:
+                    _logger.info('Restoring AMP loss scaler state from checkpoint...')
+                loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
+
             if 'epoch' in checkpoint:
                 resume_epoch = checkpoint['epoch']
                 if 'version' in checkpoint and checkpoint['version'] > 1:
                     resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save
-            _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
+
+            if log_info:
+                _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
         else:
             model.load_state_dict(checkpoint)
-            _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
-        return other_state, resume_epoch
+            if log_info:
+                _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
+        return resume_epoch
     else:
         _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
         raise FileNotFoundError()
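
The reworked resume_checkpoint folds optimizer and AMP scaler restoration into the function itself instead of returning a state dict for the caller to unpack. As a quick orientation, this is roughly how the new call site looks (it mirrors the train.py hunk further below; `model`, `optimizer`, `loss_scaler` and `args` are assumed to already exist in the caller):

    # sketch of the new resume path, not additional API introduced by this commit
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model, args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)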

@@ -37,20 +37,67 @@ def unwrap_model(model):
     return model.module if hasattr(model, 'module') else model


-def get_state_dict(model):
-    return unwrap_model(model).state_dict()
+def get_state_dict(model, unwrap_fn=unwrap_model):
+    return unwrap_fn(model).state_dict()
+
+
+class ApexScaler:
+    state_dict_key = "amp"
+
+    def __call__(self, loss, optimizer):
+        with amp.scale_loss(loss, optimizer) as scaled_loss:
+            scaled_loss.backward()
+        optimizer.step()
+
+    def state_dict(self):
+        if 'state_dict' in amp.__dict__:
+            return amp.state_dict()
+
+    def load_state_dict(self, state_dict):
+        if 'load_state_dict' in amp.__dict__:
+            amp.load_state_dict(state_dict)
+
+
+class NativeScaler:
+    state_dict_key = "amp_scaler"
+
+    def __init__(self):
+        self._scaler = torch.cuda.amp.GradScaler()
+
+    def __call__(self, loss, optimizer):
+        self._scaler.scale(loss).backward()
+        self._scaler.step(optimizer)
+        self._scaler.update()
+
+    def state_dict(self):
+        return self._scaler.state_dict()
+
+    def load_state_dict(self, state_dict):
+        self._scaler.load_state_dict(state_dict)


 class CheckpointSaver:
     def __init__(
             self,
+            model,
+            optimizer,
+            args=None,
+            model_ema=None,
+            amp_scaler=None,
             checkpoint_prefix='checkpoint',
             recovery_prefix='recovery',
             checkpoint_dir='',
             recovery_dir='',
             decreasing=False,
             max_history=10,
-            save_amp=False):
+            unwrap_fn=unwrap_model):
+
+        # objects to save state_dicts of
+        self.model = model
+        self.optimizer = optimizer
+        self.args = args
+        self.model_ema = model_ema
+        self.amp_scaler = amp_scaler

         # state
         self.checkpoint_files = []  # (filename, metric) tuples in order of decreasing betterness
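
ApexScaler and NativeScaler now live in timm.utils and share one interface (__call__, state_dict, load_state_dict, plus a state_dict_key naming their slot in the checkpoint), so the training loop and the checkpoint code can treat Apex and native AMP uniformly. A minimal sketch of a training step, assuming `loss_scaler` is one of these scalers or None and `amp_autocast` is torch.cuda.amp.autocast (or contextlib.suppress when AMP is off):

    with amp_autocast():
        output = model(input)
        loss = loss_fn(output, target)

    optimizer.zero_grad()
    if loss_scaler is not None:
        loss_scaler(loss, optimizer)  # scales the loss, runs backward(), steps the optimizer
    else:
        loss.backward()
        optimizer.step()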
@@ -68,14 +115,14 @@ class CheckpointSaver:
         self.decreasing = decreasing  # a lower metric is better if True
         self.cmp = operator.lt if decreasing else operator.gt  # True if lhs better than rhs
         self.max_history = max_history
-        self.save_apex_amp = save_amp  # save APEX amp state
+        self.unwrap_fn = unwrap_fn
         assert self.max_history >= 1

-    def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None):
+    def save_checkpoint(self, epoch, metric=None):
         assert epoch >= 0
         tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
         last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
-        self._save(tmp_save_path, model, optimizer, args, epoch, model_ema, metric)
+        self._save(tmp_save_path, epoch, metric)
         if os.path.exists(last_save_path):
             os.unlink(last_save_path)  # required for Windows support.
         os.rename(tmp_save_path, last_save_path)
@@ -107,19 +154,21 @@ class CheckpointSaver:
         return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)

-    def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None):
+    def _save(self, save_path, epoch, metric=None):
         save_state = {
             'epoch': epoch,
-            'arch': args.model,
-            'state_dict': get_state_dict(model),
-            'optimizer': optimizer.state_dict(),
-            'args': args,
+            'arch': type(self.model).__name__.lower(),
+            'state_dict': get_state_dict(self.model, self.unwrap_fn),
+            'optimizer': self.optimizer.state_dict(),
             'version': 2,  # version < 2 increments epoch before save
         }
-        if self.save_apex_amp and 'state_dict' in amp.__dict__:
-            save_state['amp'] = amp.state_dict()
-        if model_ema is not None:
-            save_state['state_dict_ema'] = get_state_dict(model_ema)
+        if self.args is not None:
+            save_state['arch'] = self.args.model
+            save_state['args'] = self.args
+        if self.amp_scaler is not None:
+            save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict()
+        if self.model_ema is not None:
+            save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn)
         if metric is not None:
             save_state['metric'] = metric
         torch.save(save_state, save_path)
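
With the saver owning its objects, _save writes the scaler state under the scaler's own state_dict_key ('amp' for Apex, 'amp_scaler' for native AMP), alongside the model, optimizer, EMA and args entries. A quick way to inspect what ends up on disk (the path and the '.pth.tar' extension here are illustrative defaults, not fixed by this hunk):

    import torch

    ckpt = torch.load('output/train/<exp_name>/last.pth.tar', map_location='cpu')
    print(sorted(ckpt.keys()))
    # e.g. ['amp_scaler', 'arch', 'args', 'epoch', 'metric', 'optimizer', 'state_dict', 'state_dict_ema', 'version']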
@@ -138,11 +187,11 @@ class CheckpointSaver:
                 _logger.error("Exception '{}' while deleting checkpoint".format(e))
         self.checkpoint_files = self.checkpoint_files[:delete_index]

-    def save_recovery(self, model, optimizer, args, epoch, model_ema=None, batch_idx=0):
+    def save_recovery(self, epoch, batch_idx=0):
         assert epoch >= 0
         filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
         save_path = os.path.join(self.recovery_dir, filename)
-        self._save(save_path, model, optimizer, args, epoch, model_ema)
+        self._save(save_path, epoch)
         if os.path.exists(self.last_recovery_file):
             try:
                 _logger.debug("Cleaning recovery: {}".format(self.last_recovery_file))
@@ -336,3 +385,16 @@ def add_bool_arg(parser, name, default=False, help=''):
     group.add_argument('--' + name, dest=dest_name, action='store_true', help=help)
     group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help)
     parser.set_defaults(**{dest_name: default})
+
+
+def set_jit_legacy():
+    """ Set JIT executor to legacy w/ support for op fusion
+    This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes
+    in the JIT exectutor. These API are not supported so could change.
+    """
+    #
+    assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!"
+    torch._C._jit_set_profiling_executor(False)
+    torch._C._jit_set_profiling_mode(False)
+    torch._C._jit_override_can_fuse_on_gpu(True)
+    #torch._C._jit_set_texpr_fuser_enabled(True)
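
set_jit_legacy is meant to be called before scripting a model when benchmarking under PyTorch 1.5-1.6; validate.py now imports it (see the last hunk). A hedged sketch of the intended call pattern, where the guard flag is illustrative and not defined by this commit:

    if legacy_jit:  # hypothetical toggle, e.g. a --legacy-jit style CLI flag
        set_jit_legacy()  # switch back to the legacy fuser before scripting
    model = torch.jit.script(model)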

@@ -20,7 +20,6 @@ import yaml
 from datetime import datetime
 from contextlib import suppress

-import torch
 import torch.nn as nn
 import torchvision.utils
 from torch.nn.parallel import DistributedDataParallel as NativeDDP
@@ -31,6 +30,7 @@ from timm.utils import *
 from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
 from timm.optim import create_optimizer
 from timm.scheduler import create_scheduler
+from timm.utils import ApexScaler, NativeScaler

 try:
     from apex import amp
@@ -264,23 +264,6 @@ def _parse_args():
     return args, args_text


-class ApexScaler:
-    def __call__(self, loss, optimizer):
-        with amp.scale_loss(loss, optimizer) as scaled_loss:
-            scaled_loss.backward()
-        optimizer.step()
-
-
-class NativeScaler:
-    def __init__(self):
-        self._scaler = torch.cuda.amp.GradScaler()
-
-    def __call__(self, loss, optimizer):
-        self._scaler.scale(loss).backward()
-        self._scaler.step(optimizer)
-        self._scaler.update()
-
-
 def main():
     setup_default_logging()
     args, args_text = _parse_args()
@@ -389,20 +372,13 @@ def main():
         _logger.info('AMP not enabled. Training in float32.')

     # optionally resume from a checkpoint
-    resume_state = {}
     resume_epoch = None
     if args.resume:
-        resume_state, resume_epoch = resume_checkpoint(model, args.resume)
-        if resume_state and not args.no_resume_opt:
-            if 'optimizer' in resume_state:
-                if args.local_rank == 0:
-                    _logger.info('Restoring optimizer state from checkpoint')
-                optimizer.load_state_dict(resume_state['optimizer'])
-            if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
-                if args.local_rank == 0:
-                    _logger.info('Restoring NVIDIA AMP state from checkpoint')
-                amp.load_state_dict(resume_state['amp'])
-        del resume_state
+        resume_epoch = resume_checkpoint(
+            model, args.resume,
+            optimizer=None if args.no_resume_opt else optimizer,
+            loss_scaler=None if args.no_resume_opt else loss_scaler,
+            log_info=args.local_rank == 0)

     model_ema = None
     if args.model_ema:
@@ -555,7 +531,9 @@ def main():
         ])
         output_dir = get_outdir(output_base, 'train', exp_name)
         decreasing = True if eval_metric == 'loss' else False
-        saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing, save_amp=use_amp == 'apex')
+        saver = CheckpointSaver(
+            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
+            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing)

         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
             f.write(args_text)
@@ -594,8 +572,7 @@ def main():
             if saver is not None:
                 # save proper checkpoint with eval metric
                 save_metric = eval_metrics[eval_metric]
-                best_metric, best_epoch = saver.save_checkpoint(
-                    model, optimizer, args, epoch=epoch, model_ema=model_ema, metric=save_metric)
+                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

     except KeyboardInterrupt:
         pass
@@ -688,8 +665,7 @@ def train_epoch(
         if saver is not None and args.recovery_interval and (
                 last_batch or (batch_idx + 1) % args.recovery_interval == 0):
-            saver.save_recovery(model, optimizer, args, epoch, model_ema=model_ema, batch_idx=batch_idx)
+            saver.save_recovery(epoch, batch_idx=batch_idx)

         if lr_scheduler is not None:
             lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

@@ -21,7 +21,7 @@ from contextlib import suppress

 from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
 from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config, RealLabelsImagenet
-from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
+from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy

 has_apex = False
 try:
@@ -102,19 +102,6 @@ parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
                     help='Valid label indices txt file for validation of partial label space')


-def set_jit_legacy():
-    """ Set JIT executor to legacy w/ support for op fusion
-    This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes
-    in the JIT exectutor. These API are not supported so could change.
-    """
-    #
-    assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!"
-    torch._C._jit_set_profiling_executor(False)
-    torch._C._jit_set_profiling_mode(False)
-    torch._C._jit_override_can_fuse_on_gpu(True)
-    #torch._C._jit_set_texpr_fuser_enabled(True)
-
-
 def validate(args):
     # might as well try to validate something
     args.pretrained = args.pretrained or not args.checkpoint
