From ba3c97c3adc6919ce431049d8978a81dc2465de9 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 29 Aug 2019 15:14:35 -0700 Subject: [PATCH 1/3] Some Lookahead cleanup and fixes --- timm/optim/lookahead.py | 80 ++++++++++++++++++++-------------------- timm/optim/nvnovograd.py | 0 2 files changed, 41 insertions(+), 39 deletions(-) create mode 100644 timm/optim/nvnovograd.py diff --git a/timm/optim/lookahead.py b/timm/optim/lookahead.py index cc1fb495..7a58e0a6 100644 --- a/timm/optim/lookahead.py +++ b/timm/optim/lookahead.py @@ -13,37 +13,40 @@ class Lookahead(Optimizer): raise ValueError(f'Invalid slow update rate: {alpha}') if not 1 <= k: raise ValueError(f'Invalid lookahead steps: {k}') - self.alpha = alpha - self.k = k + defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) self.base_optimizer = base_optimizer self.param_groups = self.base_optimizer.param_groups self.defaults = base_optimizer.defaults + self.defaults.update(defaults) self.state = defaultdict(dict) - for group in self.param_groups: - group["step_counter"] = 0 + # manually add our defaults to the param groups + for name, default in defaults.items(): + for group in self.param_groups: + group.setdefault(name, default) - def update_slow_weights(self, group): + def update_slow(self, group): for fast_p in group["params"]: if fast_p.grad is None: continue param_state = self.state[fast_p] - if "slow_buffer" not in param_state: - param_state["slow_buffer"] = torch.empty_like(fast_p.data) - param_state["slow_buffer"].copy_(fast_p.data) - slow = param_state["slow_buffer"] - slow.add_(self.alpha, fast_p.data - slow) + if 'slow_buffer' not in param_state: + param_state['slow_buffer'] = torch.empty_like(fast_p.data) + param_state['slow_buffer'].copy_(fast_p.data) + slow = param_state['slow_buffer'] + slow.add_(group['lookahead_alpha'], fast_p.data - slow) fast_p.data.copy_(slow) def sync_lookahead(self): for group in self.param_groups: - self.update_slow_weights(group) + self.update_slow(group) def step(self, closure=None): + #assert id(self.param_groups) == id(self.base_optimizer.param_groups) loss = self.base_optimizer.step(closure) for group in self.param_groups: - group['step_counter'] += 1 - if group['step_counter'] % self.k == 0: - self.update_slow_weights(group) + group['lookahead_step'] += 1 + if group['lookahead_step'] % group['lookahead_k'] == 0: + self.update_slow(group) return loss def state_dict(self): @@ -52,37 +55,36 @@ class Lookahead(Optimizer): (id(k) if isinstance(k, torch.Tensor) else k): v for k, v in self.state.items() } - fast_state = fast_state_dict["state"] - param_groups = fast_state_dict["param_groups"] + fast_state = fast_state_dict['state'] + param_groups = fast_state_dict['param_groups'] return { - "state": fast_state, - "slow_state": slow_state, - "param_groups": param_groups, + 'state': fast_state, + 'slow_state': slow_state, + 'param_groups': param_groups, } def load_state_dict(self, state_dict): + fast_state_dict = { + 'state': state_dict['state'], + 'param_groups': state_dict['param_groups'], + } + self.base_optimizer.load_state_dict(fast_state_dict) + + # We want to restore the slow state, but share param_groups reference + # with base_optimizer. 
This is a bit redundant but least code + slow_state_new = False if 'slow_state' not in state_dict: - print('Loading state_dict from optimizer without Lookahead applied') + print('Loading state_dict from optimizer without Lookahead applied.') state_dict['slow_state'] = defaultdict(dict) + slow_state_new = True slow_state_dict = { - "state": state_dict["slow_state"], - "param_groups": state_dict["param_groups"], - } - fast_state_dict = { - "state": state_dict["state"], - "param_groups": state_dict["param_groups"], + 'state': state_dict['slow_state'], + 'param_groups': state_dict['param_groups'], # this is pointless but saves code } super(Lookahead, self).load_state_dict(slow_state_dict) - self.base_optimizer.load_state_dict(fast_state_dict) - - def add_param_group(self, param_group): - r"""Add a param group to the :class:`Optimizer` s `param_groups`. - This can be useful when fine tuning a pre-trained network as frozen - layers can be made trainable and added to the :class:`Optimizer` as - training progresses. - Args: - param_group (dict): Specifies what Tensors should be optimized along - with group specific optimization options. - """ - param_group['step_counter'] = 0 - self.base_optimizer.add_param_group(param_group) + self.param_groups = self.base_optimizer.param_groups # make both ref same container + if slow_state_new: + # reapply defaults to catch missing lookahead specific ones + for name, default in self.defaults.items(): + for group in self.param_groups: + group.setdefault(name, default) diff --git a/timm/optim/nvnovograd.py b/timm/optim/nvnovograd.py new file mode 100644 index 00000000..e69de29b From 3d9c8a6489541b1863716a2d20241fa74f1238fa Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 29 Aug 2019 15:19:18 -0700 Subject: [PATCH 2/3] Add support for new AMP checkpointing support w/ amp.state_dict --- timm/models/helpers.py | 8 +++++--- timm/utils.py | 18 +++++++++++++----- train.py | 30 ++++++++++++++++++++---------- 3 files changed, 38 insertions(+), 18 deletions(-) diff --git a/timm/models/helpers.py b/timm/models/helpers.py index a37f5062..16bdacfb 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -29,7 +29,7 @@ def load_checkpoint(model, checkpoint_path, use_ema=False): def resume_checkpoint(model, checkpoint_path): - optimizer_state = None + other_state = {} resume_epoch = None if os.path.isfile(checkpoint_path): checkpoint = torch.load(checkpoint_path, map_location='cpu') @@ -40,7 +40,9 @@ def resume_checkpoint(model, checkpoint_path): new_state_dict[name] = v model.load_state_dict(new_state_dict) if 'optimizer' in checkpoint: - optimizer_state = checkpoint['optimizer'] + other_state['optimizer'] = checkpoint['optimizer'] + if 'amp' in checkpoint: + other_state['amp'] = checkpoint['amp'] if 'epoch' in checkpoint: resume_epoch = checkpoint['epoch'] if 'version' in checkpoint and checkpoint['version'] > 1: @@ -49,7 +51,7 @@ def resume_checkpoint(model, checkpoint_path): else: model.load_state_dict(checkpoint) logging.info("Loaded checkpoint '{}'".format(checkpoint_path)) - return optimizer_state, resume_epoch + return other_state, resume_epoch else: logging.error("No checkpoint found at '{}'".format(checkpoint_path)) raise FileNotFoundError() diff --git a/timm/utils.py b/timm/utils.py index 7de38a80..8ed8f195 100644 --- a/timm/utils.py +++ b/timm/utils.py @@ -11,6 +11,12 @@ import operator import logging import numpy as np from collections import OrderedDict +try: + from apex import amp + has_apex = True +except ImportError: + amp = None + has_apex = 
False from torch import distributed as dist @@ -50,7 +56,7 @@ class CheckpointSaver: self.max_history = max_history assert self.max_history >= 1 - def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None): + def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None, use_amp=False): assert epoch >= 0 worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None if (len(self.checkpoint_files) < self.max_history @@ -59,7 +65,7 @@ class CheckpointSaver: self._cleanup_checkpoints(1) filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension save_path = os.path.join(self.checkpoint_dir, filename) - self._save(save_path, model, optimizer, args, epoch, model_ema, metric) + self._save(save_path, model, optimizer, args, epoch, model_ema, metric, use_amp) self.checkpoint_files.append((save_path, metric)) self.checkpoint_files = sorted( self.checkpoint_files, key=lambda x: x[1], @@ -77,7 +83,7 @@ class CheckpointSaver: return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch) - def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None): + def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None, use_amp=False): save_state = { 'epoch': epoch, 'arch': args.model, @@ -86,6 +92,8 @@ class CheckpointSaver: 'args': args, 'version': 2, # version < 2 increments epoch before save } + if use_amp and 'state_dict' in amp.__dict__: + save_state['amp'] = amp.state_dict() if model_ema is not None: save_state['state_dict_ema'] = get_state_dict(model_ema) if metric is not None: @@ -106,11 +114,11 @@ class CheckpointSaver: logging.error("Exception '{}' while deleting checkpoint".format(e)) self.checkpoint_files = self.checkpoint_files[:delete_index] - def save_recovery(self, model, optimizer, args, epoch, model_ema=None, batch_idx=0): + def save_recovery(self, model, optimizer, args, epoch, model_ema=None, use_amp=False, batch_idx=0): assert epoch >= 0 filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension save_path = os.path.join(self.recovery_dir, filename) - self._save(save_path, model, optimizer, args, epoch, model_ema) + self._save(save_path, model, optimizer, args, epoch, model_ema, use_amp=use_amp) if os.path.exists(self.last_recovery_file): try: logging.debug("Cleaning recovery: {}".format(self.last_recovery_file)) diff --git a/train.py b/train.py index 9c75e050..78c8a12c 100644 --- a/train.py +++ b/train.py @@ -38,6 +38,8 @@ parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH' help='Initialize model from this checkpoint (default: none)') parser.add_argument('--resume', default='', type=str, metavar='PATH', help='Resume full model and optimizer state from checkpoint (default: none)') +parser.add_argument('--no-resume-opt', action='store_true', default=False, + help='prevent resume of optimizer state when resuming model') parser.add_argument('--num-classes', type=int, default=1000, metavar='N', help='number of label classes (default: 1000)') parser.add_argument('--gp', default='avg', type=str, metavar='POOL', @@ -189,12 +191,6 @@ def main(): data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) - # optionally resume from a checkpoint - optimizer_state = None - resume_epoch = None - if args.resume: - optimizer_state, resume_epoch = resume_checkpoint(model, args.resume) - if args.num_gpu > 1: if args.amp: logging.warning( @@ -205,8 +201,6 @@ def main(): 
model.cuda() optimizer = create_optimizer(args, model) - if optimizer_state is not None: - optimizer.load_state_dict(optimizer_state) use_amp = False if has_apex and args.amp: @@ -216,6 +210,22 @@ def main(): logging.info('NVIDIA APEX {}. AMP {}.'.format( 'installed' if has_apex else 'not installed', 'on' if use_amp else 'off')) + # optionally resume from a checkpoint + resume_state = {} + resume_epoch = None + if args.resume: + resume_state, resume_epoch = resume_checkpoint(model, args.resume) + if resume_state and not args.no_resume_opt: + if 'optimizer' in resume_state: + if args.local_rank == 0: + logging.info('Restoring Optimizer state from checkpoint') + optimizer.load_state_dict(resume_state['optimizer']) + if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__: + if args.local_rank == 0: + logging.info('Restoring NVIDIA AMP state from checkpoint') + amp.load_state_dict(resume_state['amp']) + resume_state = None + model_ema = None if args.model_ema: # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper @@ -363,7 +373,7 @@ def main(): save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint( model, optimizer, args, - epoch=epoch, model_ema=model_ema, metric=save_metric) + epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=use_amp) except KeyboardInterrupt: pass @@ -456,7 +466,7 @@ def train_epoch( if saver is not None and args.recovery_interval and ( last_batch or (batch_idx + 1) % args.recovery_interval == 0): saver.save_recovery( - model, optimizer, args, epoch, model_ema=model_ema, batch_idx=batch_idx) + model, optimizer, args, epoch, model_ema=model_ema, use_amp=use_amp, batch_idx=batch_idx) if lr_scheduler is not None: lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) From 64966f61f70f64c44205b0c6b0944b268de22278 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 29 Aug 2019 15:21:38 -0700 Subject: [PATCH 3/3] Add Nvidia's NovogGrad impl from Jasper (cleaner/faster than current) and Apex Fused optimizers --- timm/optim/__init__.py | 1 + timm/optim/nvnovograd.py | 118 ++++++++++++++++++++++++++++++++++++ timm/optim/optim_factory.py | 37 +++++++++-- 3 files changed, 150 insertions(+), 6 deletions(-) diff --git a/timm/optim/__init__.py b/timm/optim/__init__.py index 3213cd68..994b36d2 100644 --- a/timm/optim/__init__.py +++ b/timm/optim/__init__.py @@ -3,5 +3,6 @@ from .rmsprop_tf import RMSpropTF from .adamw import AdamW from .radam import RAdam from .novograd import NovoGrad +from .nvnovograd import NvNovoGrad from .lookahead import Lookahead from .optim_factory import create_optimizer diff --git a/timm/optim/nvnovograd.py b/timm/optim/nvnovograd.py index e69de29b..323312d2 100644 --- a/timm/optim/nvnovograd.py +++ b/timm/optim/nvnovograd.py @@ -0,0 +1,118 @@ +""" Nvidia NovoGrad Optimizer. +Original impl by Nvidia from Jasper example: + - https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper +Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks` + - https://arxiv.org/abs/1905.11286 +""" + +import torch +from torch.optim.optimizer import Optimizer +import math + + +class NvNovoGrad(Optimizer): + """ + Implements Novograd algorithm. 
+ + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.95, 0.98)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + grad_averaging: gradient averaging + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + """ + + def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8, + weight_decay=0, grad_averaging=False, amsgrad=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, + grad_averaging=grad_averaging, + amsgrad=amsgrad) + + super(NvNovoGrad, self).__init__(params, defaults) + + def __setstate__(self, state): + super(NvNovoGrad, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('Sparse gradients are not supported.') + amsgrad = group['amsgrad'] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + if amsgrad: + max_exp_avg_sq = state['max_exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + norm = torch.sum(torch.pow(grad, 2)) + + if exp_avg_sq == 0: + exp_avg_sq.copy_(norm) + else: + exp_avg_sq.mul_(beta2).add_(1 - beta2, norm) + + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. 
of gradient + denom = max_exp_avg_sq.sqrt().add_(group['eps']) + else: + denom = exp_avg_sq.sqrt().add_(group['eps']) + + grad.div_(denom) + if group['weight_decay'] != 0: + grad.add_(group['weight_decay'], p.data) + if group['grad_averaging']: + grad.mul_(1 - beta1) + exp_avg.mul_(beta1).add_(grad) + + p.data.add_(-group['lr'], exp_avg) + + return loss diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py index c51bdf20..553a6b6d 100644 --- a/timm/optim/optim_factory.py +++ b/timm/optim/optim_factory.py @@ -1,5 +1,11 @@ +import torch from torch import optim as optim -from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, Lookahead +from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, NvNovoGrad, Lookahead +try: + from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD + has_apex = True +except ImportError: + has_apex = False def add_weight_decay(model, weight_decay=1e-5, skip_list=()): @@ -20,9 +26,10 @@ def add_weight_decay(model, weight_decay=1e-5, skip_list=()): def create_optimizer(args, model, filter_bias_and_bn=True): opt_lower = args.opt.lower() weight_decay = args.weight_decay - if opt_lower == 'adamw' or opt_lower == 'radam': - # compensate for the way current AdamW and RAdam optimizers - # apply the weight-decay + if 'adamw' in opt_lower or 'radam' in opt_lower: + # Compensate for the way current AdamW and RAdam optimizers apply LR to the weight-decay + # I don't believe they follow the paper or original Torch7 impl which schedules weight + # decay based on the ratio of current_lr/initial_lr weight_decay /= args.lr if weight_decay and filter_bias_and_bn: parameters = add_weight_decay(model, weight_decay) @@ -30,12 +37,14 @@ def create_optimizer(args, model, filter_bias_and_bn=True): else: parameters = model.parameters() + if 'fused' in opt_lower: + assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' + opt_split = opt_lower.split('_') opt_lower = opt_split[-1] if opt_lower == 'sgd': optimizer = optim.SGD( - parameters, lr=args.lr, - momentum=args.momentum, weight_decay=weight_decay, nesterov=True) + parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True) elif opt_lower == 'adam': optimizer = optim.Adam( parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) @@ -61,6 +70,22 @@ def create_optimizer(args, model, filter_bias_and_bn=True): momentum=args.momentum, weight_decay=weight_decay) elif opt_lower == 'novograd': optimizer = NovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) + elif opt_lower == 'nvnovograd': + optimizer = NvNovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) + elif opt_lower == 'fusedsgd': + optimizer = FusedSGD( + parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True) + elif opt_lower == 'fusedadam': + optimizer = FusedAdam( + parameters, lr=args.lr, adam_w_mode=False, weight_decay=weight_decay, eps=args.opt_eps) + elif opt_lower == 'fusedadamw': + optimizer = FusedAdam( + parameters, lr=args.lr, adam_w_mode=True, weight_decay=weight_decay, eps=args.opt_eps) + elif opt_lower == 'fusedlamb': + optimizer = FusedLAMB(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) + elif opt_lower == 'fusednovograd': + optimizer = FusedNovoGrad( + parameters, lr=args.lr, betas=(0.95, 0.98), weight_decay=weight_decay, eps=args.opt_eps) else: assert False and "Invalid optimizer" raise ValueError
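
For reference, a minimal usage sketch of the refactored Lookahead from the first patch. It assumes the constructor keeps its Lookahead(base_optimizer, alpha=0.5, k=6) signature; the toy model, data, and hyper-parameters below are illustrative only.

    import torch
    import torch.nn as nn
    from timm.optim import Lookahead

    model = nn.Linear(10, 2)
    base = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    optimizer = Lookahead(base, alpha=0.5, k=6)

    # After this patch the lookahead hyper-parameters live in each param group
    # (not on the wrapper), so they survive state_dict()/load_state_dict().
    for group in optimizer.param_groups:
        assert group['lookahead_alpha'] == 0.5
        assert group['lookahead_k'] == 6
        assert group['lookahead_step'] == 0

    x, y = torch.randn(4, 10), torch.randn(4, 2)
    for _ in range(12):
        model.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()   # slow weights sync on every k-th step

    # Both the fast (inner) and slow (lookahead) state round-trip.
    state = optimizer.state_dict()
    optimizer.load_state_dict(state)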
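The second patch saves and restores Apex AMP loss-scale state alongside the model and optimizer. A sketch of that round-trip outside train.py, assuming a CUDA device and a recent Apex that exposes amp.state_dict() (the guard below mirrors CheckpointSaver._save()); the model, file name, and opt_level are illustrative.

    import torch
    from apex import amp

    model = torch.nn.Linear(10, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    # Save: include AMP state only when the installed Apex supports it.
    save_state = {
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    if 'state_dict' in amp.__dict__:
        save_state['amp'] = amp.state_dict()
    torch.save(save_state, 'checkpoint.pth.tar')

    # Resume: restore AMP state after amp.initialize(), as main() now does;
    # passing --no-resume-opt skips restoring optimizer and AMP state entirely.
    checkpoint = torch.load('checkpoint.pth.tar', map_location='cpu')
    if 'amp' in checkpoint and 'load_state_dict' in amp.__dict__:
        amp.load_state_dict(checkpoint['amp'])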
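The third patch exposes the new optimizers through create_optimizer() by name. A sketch of selecting them, assuming an argparse-style namespace that carries the fields the factory reads (opt, lr, weight_decay, momentum, opt_eps); the model and values are illustrative.

    from types import SimpleNamespace
    import torch.nn as nn
    from timm.optim import create_optimizer

    model = nn.Linear(10, 2)
    args = SimpleNamespace(opt='nvnovograd', lr=0.01, weight_decay=1e-5,
                           momentum=0.9, opt_eps=1e-8)
    optimizer = create_optimizer(args, model)   # NvNovoGrad

    # The fused variants ('fusedsgd', 'fusedadam', 'fusedadamw', 'fusedlamb',
    # 'fusednovograd') require Apex built with CUDA; the factory asserts on both.
    args.opt = 'fusedadamw'   # FusedAdam with adam_w_mode=True
    # optimizer = create_optimizer(args, model)  # on a CUDA machine with Apex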