Merge pull request #32 from rwightman/opt

More optimizer work
Ross Wightman authored 5 years ago, committed by GitHub
commit aff194f42c

@@ -29,7 +29,7 @@ def load_checkpoint(model, checkpoint_path, use_ema=False):
 def resume_checkpoint(model, checkpoint_path):
-    optimizer_state = None
+    other_state = {}
     resume_epoch = None
     if os.path.isfile(checkpoint_path):
         checkpoint = torch.load(checkpoint_path, map_location='cpu')
@@ -40,7 +40,9 @@ def resume_checkpoint(model, checkpoint_path):
                 new_state_dict[name] = v
             model.load_state_dict(new_state_dict)
             if 'optimizer' in checkpoint:
-                optimizer_state = checkpoint['optimizer']
+                other_state['optimizer'] = checkpoint['optimizer']
+            if 'amp' in checkpoint:
+                other_state['amp'] = checkpoint['amp']
             if 'epoch' in checkpoint:
                 resume_epoch = checkpoint['epoch']
                 if 'version' in checkpoint and checkpoint['version'] > 1:
@@ -49,7 +51,7 @@ def resume_checkpoint(model, checkpoint_path):
         else:
             model.load_state_dict(checkpoint)
             logging.info("Loaded checkpoint '{}'".format(checkpoint_path))
-        return optimizer_state, resume_epoch
+        return other_state, resume_epoch
     else:
         logging.error("No checkpoint found at '{}'".format(checkpoint_path))
         raise FileNotFoundError()
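
For context, the helper above now returns a dict of extra state rather than a bare optimizer state dict. A minimal sketch of consuming it (the import location, stand-in model, and checkpoint path are assumptions for illustration, not part of this diff):

    import torch.nn as nn
    from timm.models import resume_checkpoint    # assumed export location of the helper shown above

    model = nn.Linear(10, 2)                     # stand-in model for illustration
    other_state, resume_epoch = resume_checkpoint(model, './checkpoint-42.pth.tar')  # hypothetical path
    print(sorted(other_state))                   # e.g. ['amp', 'optimizer'] for a checkpoint saved with AMP
    print(resume_epoch)                          # epoch recorded in the checkpoint, or None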

@@ -3,5 +3,6 @@ from .rmsprop_tf import RMSpropTF
 from .adamw import AdamW
 from .radam import RAdam
 from .novograd import NovoGrad
+from .nvnovograd import NvNovoGrad
 from .lookahead import Lookahead
 from .optim_factory import create_optimizer

@@ -13,37 +13,40 @@ class Lookahead(Optimizer):
             raise ValueError(f'Invalid slow update rate: {alpha}')
         if not 1 <= k:
             raise ValueError(f'Invalid lookahead steps: {k}')
-        self.alpha = alpha
-        self.k = k
+        defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
         self.base_optimizer = base_optimizer
         self.param_groups = self.base_optimizer.param_groups
         self.defaults = base_optimizer.defaults
+        self.defaults.update(defaults)
         self.state = defaultdict(dict)
-        for group in self.param_groups:
-            group["step_counter"] = 0
+        # manually add our defaults to the param groups
+        for name, default in defaults.items():
+            for group in self.param_groups:
+                group.setdefault(name, default)

-    def update_slow_weights(self, group):
+    def update_slow(self, group):
         for fast_p in group["params"]:
             if fast_p.grad is None:
                 continue
             param_state = self.state[fast_p]
-            if "slow_buffer" not in param_state:
-                param_state["slow_buffer"] = torch.empty_like(fast_p.data)
-                param_state["slow_buffer"].copy_(fast_p.data)
-            slow = param_state["slow_buffer"]
-            slow.add_(self.alpha, fast_p.data - slow)
+            if 'slow_buffer' not in param_state:
+                param_state['slow_buffer'] = torch.empty_like(fast_p.data)
+                param_state['slow_buffer'].copy_(fast_p.data)
+            slow = param_state['slow_buffer']
+            slow.add_(group['lookahead_alpha'], fast_p.data - slow)
             fast_p.data.copy_(slow)

     def sync_lookahead(self):
         for group in self.param_groups:
-            self.update_slow_weights(group)
+            self.update_slow(group)

     def step(self, closure=None):
+        #assert id(self.param_groups) == id(self.base_optimizer.param_groups)
         loss = self.base_optimizer.step(closure)
         for group in self.param_groups:
-            group['step_counter'] += 1
-            if group['step_counter'] % self.k == 0:
-                self.update_slow_weights(group)
+            group['lookahead_step'] += 1
+            if group['lookahead_step'] % group['lookahead_k'] == 0:
+                self.update_slow(group)
         return loss

     def state_dict(self):
@@ -52,37 +55,36 @@ class Lookahead(Optimizer):
             (id(k) if isinstance(k, torch.Tensor) else k): v
             for k, v in self.state.items()
         }
-        fast_state = fast_state_dict["state"]
-        param_groups = fast_state_dict["param_groups"]
+        fast_state = fast_state_dict['state']
+        param_groups = fast_state_dict['param_groups']
         return {
-            "state": fast_state,
-            "slow_state": slow_state,
-            "param_groups": param_groups,
+            'state': fast_state,
+            'slow_state': slow_state,
+            'param_groups': param_groups,
         }

     def load_state_dict(self, state_dict):
+        fast_state_dict = {
+            'state': state_dict['state'],
+            'param_groups': state_dict['param_groups'],
+        }
+        self.base_optimizer.load_state_dict(fast_state_dict)
+
+        # We want to restore the slow state, but share param_groups reference
+        # with base_optimizer. This is a bit redundant but least code
+        slow_state_new = False
         if 'slow_state' not in state_dict:
-            print('Loading state_dict from optimizer without Lookahead applied')
+            print('Loading state_dict from optimizer without Lookahead applied.')
             state_dict['slow_state'] = defaultdict(dict)
+            slow_state_new = True
         slow_state_dict = {
-            "state": state_dict["slow_state"],
-            "param_groups": state_dict["param_groups"],
-        }
-        fast_state_dict = {
-            "state": state_dict["state"],
-            "param_groups": state_dict["param_groups"],
+            'state': state_dict['slow_state'],
+            'param_groups': state_dict['param_groups'],  # this is pointless but saves code
         }
         super(Lookahead, self).load_state_dict(slow_state_dict)
-        self.base_optimizer.load_state_dict(fast_state_dict)
-
-    def add_param_group(self, param_group):
-        r"""Add a param group to the :class:`Optimizer` s `param_groups`.
-        This can be useful when fine tuning a pre-trained network as frozen
-        layers can be made trainable and added to the :class:`Optimizer` as
-        training progresses.
-        Args:
-            param_group (dict): Specifies what Tensors should be optimized along
-                with group specific optimization options.
-        """
-        param_group['step_counter'] = 0
-        self.base_optimizer.add_param_group(param_group)
+        self.param_groups = self.base_optimizer.param_groups  # make both ref same container
+        if slow_state_new:
+            # reapply defaults to catch missing lookahead specific ones
+            for name, default in self.defaults.items():
+                for group in self.param_groups:
+                    group.setdefault(name, default)
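
As an aside (not part of the diff), the Lookahead wrapper keeps its hyper-parameters inside each param group after this change, so they travel with state_dict round-trips. A minimal usage sketch, assuming the constructor signature Lookahead(base_optimizer, alpha=0.5, k=6) and a PyTorch release contemporary with this PR (newer releases deprecate the add_(scalar, tensor) overload used internally):

    import torch
    import torch.nn as nn
    from timm.optim import Lookahead

    model = nn.Linear(10, 2)
    base = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    opt = Lookahead(base, alpha=0.5, k=6)    # stored per group as 'lookahead_alpha' / 'lookahead_k'

    x, y = torch.randn(4, 10), torch.randn(4, 2)
    for _ in range(6):
        opt.zero_grad()                      # operates on the param groups shared with the base optimizer
        nn.functional.mse_loss(model(x), y).backward()
        opt.step()                           # every k-th step also blends the slow buffers toward the fast weights
    opt.sync_lookahead()                     # e.g. before eval: apply the slow-weight update across all groups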

@@ -0,0 +1,118 @@
""" Nvidia NovoGrad Optimizer.
Original impl by Nvidia from Jasper example:
    - https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper
Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
    - https://arxiv.org/abs/1905.11286
"""
import torch
from torch.optim.optimizer import Optimizer
import math


class NvNovoGrad(Optimizer):
    """
    Implements Novograd algorithm.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.95, 0.98))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        grad_averaging: gradient averaging
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
    """

    def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8,
                 weight_decay=0, grad_averaging=False, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay,
                        grad_averaging=grad_averaging,
                        amsgrad=amsgrad)

        super(NvNovoGrad, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(NvNovoGrad, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Sparse gradients are not supported.')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                norm = torch.sum(torch.pow(grad, 2))

                if exp_avg_sq == 0:
                    exp_avg_sq.copy_(norm)
                else:
                    exp_avg_sq.mul_(beta2).add_(1 - beta2, norm)

                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                grad.div_(denom)
                if group['weight_decay'] != 0:
                    grad.add_(group['weight_decay'], p.data)
                if group['grad_averaging']:
                    grad.mul_(1 - beta1)
                exp_avg.mul_(beta1).add_(grad)

                p.data.add_(-group['lr'], exp_avg)

        return loss
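
For reference (not in the diff itself), the new optimizer follows the standard torch.optim interface; a minimal sketch, again assuming a PyTorch version contemporary with this PR:

    import torch
    import torch.nn as nn
    from timm.optim import NvNovoGrad

    model = nn.Linear(10, 2)
    optimizer = NvNovoGrad(model.parameters(), lr=1e-3, betas=(0.95, 0.98),
                           weight_decay=1e-3, grad_averaging=False, amsgrad=False)

    for _ in range(3):
        optimizer.zero_grad()
        loss = model(torch.randn(8, 10)).pow(2).mean()
        loss.backward()
        optimizer.step()   # second moment is tracked as a single layer-wise scalar per parameter tensor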

@@ -1,5 +1,11 @@
+import torch
 from torch import optim as optim
-from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, Lookahead
+from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, NvNovoGrad, Lookahead
+
+try:
+    from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
+    has_apex = True
+except ImportError:
+    has_apex = False

 def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
@@ -20,9 +26,10 @@ def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
 def create_optimizer(args, model, filter_bias_and_bn=True):
     opt_lower = args.opt.lower()
     weight_decay = args.weight_decay
-    if opt_lower == 'adamw' or opt_lower == 'radam':
-        # compensate for the way current AdamW and RAdam optimizers
-        # apply the weight-decay
+    if 'adamw' in opt_lower or 'radam' in opt_lower:
+        # Compensate for the way current AdamW and RAdam optimizers apply LR to the weight-decay
+        # I don't believe they follow the paper or original Torch7 impl which schedules weight
+        # decay based on the ratio of current_lr/initial_lr
         weight_decay /= args.lr
     if weight_decay and filter_bias_and_bn:
         parameters = add_weight_decay(model, weight_decay)
@@ -30,12 +37,14 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
     else:
         parameters = model.parameters()

+    if 'fused' in opt_lower:
+        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
+
     opt_split = opt_lower.split('_')
     opt_lower = opt_split[-1]
     if opt_lower == 'sgd':
         optimizer = optim.SGD(
-            parameters, lr=args.lr,
-            momentum=args.momentum, weight_decay=weight_decay, nesterov=True)
+            parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True)
     elif opt_lower == 'adam':
         optimizer = optim.Adam(
             parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
@@ -61,6 +70,22 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
             momentum=args.momentum, weight_decay=weight_decay)
     elif opt_lower == 'novograd':
         optimizer = NovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
+    elif opt_lower == 'nvnovograd':
+        optimizer = NvNovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
+    elif opt_lower == 'fusedsgd':
+        optimizer = FusedSGD(
+            parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True)
+    elif opt_lower == 'fusedadam':
+        optimizer = FusedAdam(
+            parameters, lr=args.lr, adam_w_mode=False, weight_decay=weight_decay, eps=args.opt_eps)
+    elif opt_lower == 'fusedadamw':
+        optimizer = FusedAdam(
+            parameters, lr=args.lr, adam_w_mode=True, weight_decay=weight_decay, eps=args.opt_eps)
+    elif opt_lower == 'fusedlamb':
+        optimizer = FusedLAMB(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
+    elif opt_lower == 'fusednovograd':
+        optimizer = FusedNovoGrad(
+            parameters, lr=args.lr, betas=(0.95, 0.98), weight_decay=weight_decay, eps=args.opt_eps)
     else:
         assert False and "Invalid optimizer"
         raise ValueError
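
To illustrate how the factory is driven (the attribute names mirror the ones read in the hunk above; everything else here is an assumption for the sketch):

    import argparse
    import torch.nn as nn
    from timm.optim import create_optimizer

    model = nn.Linear(10, 2)
    args = argparse.Namespace(opt='nvnovograd', lr=0.01, weight_decay=2e-5,
                              momentum=0.9, opt_eps=1e-8)
    optimizer = create_optimizer(args, model)   # 'fused*' names additionally require NVIDIA apex and CUDA
    print(type(optimizer).__name__)             # NvNovoGrad

Note that the weight-decay compensation path only triggers for names containing 'adamw' or 'radam', where the configured decay is divided by the learning rate before being handed to the optimizer.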

@@ -11,6 +11,12 @@ import operator
 import logging
 import numpy as np
 from collections import OrderedDict
+
+try:
+    from apex import amp
+    has_apex = True
+except ImportError:
+    amp = None
+    has_apex = False

 from torch import distributed as dist
@@ -50,7 +56,7 @@ class CheckpointSaver:
         self.max_history = max_history
         assert self.max_history >= 1

-    def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None):
+    def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None, use_amp=False):
         assert epoch >= 0
         worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
         if (len(self.checkpoint_files) < self.max_history
@@ -59,7 +65,7 @@ class CheckpointSaver:
             self._cleanup_checkpoints(1)
             filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
             save_path = os.path.join(self.checkpoint_dir, filename)
-            self._save(save_path, model, optimizer, args, epoch, model_ema, metric)
+            self._save(save_path, model, optimizer, args, epoch, model_ema, metric, use_amp)
             self.checkpoint_files.append((save_path, metric))
             self.checkpoint_files = sorted(
                 self.checkpoint_files, key=lambda x: x[1],
@@ -77,7 +83,7 @@ class CheckpointSaver:
         return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)

-    def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None):
+    def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None, use_amp=False):
         save_state = {
             'epoch': epoch,
             'arch': args.model,
@@ -86,6 +92,8 @@ class CheckpointSaver:
             'args': args,
             'version': 2,  # version < 2 increments epoch before save
         }
+        if use_amp and 'state_dict' in amp.__dict__:
+            save_state['amp'] = amp.state_dict()
         if model_ema is not None:
             save_state['state_dict_ema'] = get_state_dict(model_ema)
         if metric is not None:
@@ -106,11 +114,11 @@ class CheckpointSaver:
                 logging.error("Exception '{}' while deleting checkpoint".format(e))
         self.checkpoint_files = self.checkpoint_files[:delete_index]

-    def save_recovery(self, model, optimizer, args, epoch, model_ema=None, batch_idx=0):
+    def save_recovery(self, model, optimizer, args, epoch, model_ema=None, use_amp=False, batch_idx=0):
         assert epoch >= 0
         filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
         save_path = os.path.join(self.recovery_dir, filename)
-        self._save(save_path, model, optimizer, args, epoch, model_ema)
+        self._save(save_path, model, optimizer, args, epoch, model_ema, use_amp=use_amp)
         if os.path.exists(self.last_recovery_file):
             try:
                 logging.debug("Cleaning recovery: {}".format(self.last_recovery_file))
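
Illustrative only: a checkpoint written via save_checkpoint(..., use_amp=True) with apex amp initialized simply carries an extra 'amp' entry next to the usual fields (the path below is hypothetical):

    import torch

    ckpt = torch.load('output/train/checkpoint-42.pth.tar', map_location='cpu')
    print('amp' in ckpt)        # True only when saved with use_amp=True and amp.state_dict was available
    print(ckpt.get('version'))  # 2 for checkpoints written by this saver
    print(ckpt.get('epoch'), ckpt.get('arch'))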

@@ -38,6 +38,8 @@ parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH'
                     help='Initialize model from this checkpoint (default: none)')
 parser.add_argument('--resume', default='', type=str, metavar='PATH',
                     help='Resume full model and optimizer state from checkpoint (default: none)')
+parser.add_argument('--no-resume-opt', action='store_true', default=False,
+                    help='prevent resume of optimizer state when resuming model')
 parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
                     help='number of label classes (default: 1000)')
 parser.add_argument('--gp', default='avg', type=str, metavar='POOL',
@@ -189,12 +191,6 @@ def main():
     data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

-    # optionally resume from a checkpoint
-    optimizer_state = None
-    resume_epoch = None
-    if args.resume:
-        optimizer_state, resume_epoch = resume_checkpoint(model, args.resume)
-
     if args.num_gpu > 1:
         if args.amp:
             logging.warning(
@@ -205,8 +201,6 @@ def main():
         model.cuda()

     optimizer = create_optimizer(args, model)
-    if optimizer_state is not None:
-        optimizer.load_state_dict(optimizer_state)

     use_amp = False
     if has_apex and args.amp:
@@ -216,6 +210,22 @@ def main():
         logging.info('NVIDIA APEX {}. AMP {}.'.format(
             'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))

+    # optionally resume from a checkpoint
+    resume_state = {}
+    resume_epoch = None
+    if args.resume:
+        resume_state, resume_epoch = resume_checkpoint(model, args.resume)
+    if resume_state and not args.no_resume_opt:
+        if 'optimizer' in resume_state:
+            if args.local_rank == 0:
+                logging.info('Restoring Optimizer state from checkpoint')
+            optimizer.load_state_dict(resume_state['optimizer'])
+        if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
+            if args.local_rank == 0:
+                logging.info('Restoring NVIDIA AMP state from checkpoint')
+            amp.load_state_dict(resume_state['amp'])
+    resume_state = None
+
     model_ema = None
     if args.model_ema:
         # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
@@ -363,7 +373,7 @@ def main():
                 save_metric = eval_metrics[eval_metric]
                 best_metric, best_epoch = saver.save_checkpoint(
                     model, optimizer, args,
-                    epoch=epoch, model_ema=model_ema, metric=save_metric)
+                    epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=use_amp)

     except KeyboardInterrupt:
         pass
@@ -456,7 +466,7 @@ def train_epoch(
             if saver is not None and args.recovery_interval and (
                     last_batch or (batch_idx + 1) % args.recovery_interval == 0):
                 saver.save_recovery(
-                    model, optimizer, args, epoch, model_ema=model_ema, batch_idx=batch_idx)
+                    model, optimizer, args, epoch, model_ema=model_ema, use_amp=use_amp, batch_idx=batch_idx)

             if lr_scheduler is not None:
                 lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
