Add RAdam, NovoGrad, Lookahead, and AdamW optimizers, a few ResNet tweaks, and a scheduler factory tweak.

* Add some of the trendy new optimizers (RAdam, NovoGrad, AdamW, and a Lookahead wrapper). Decent results so far, but not clearly better than the standard choices (see the quick sketch below).
* The scheduler factory can now return a None scheduler, i.e. train at a constant LR
* ResNet now defaults to zero-initializing the weight of the last BN in each residual block (zero_init_last_bn=True)
* Add a resnet50d model config
Ref: pull/31/head
Author: Ross Wightman
Parent: f37e633e9b
Commit: fac58f609a
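For context, a quick sketch of how the new pieces fit together (a minimal example; `timm.create_model` is assumed to be exposed at the package level, and hyper-parameters are illustrative):

```python
import timm  # assumes the package-level create_model helper is available
from timm.optim import RAdam, Lookahead

# new ResNet-50-D config; the last BN in each residual block is zero-initialized by default
model = timm.create_model('resnet50d', pretrained=False)

# any base optimizer can be wrapped in Lookahead; the optimizer factory below
# does the same thing when the name is prefixed, e.g. --opt lookahead_radam
optimizer = Lookahead(RAdam(model.parameters(), lr=1e-3), alpha=0.5, k=6)
```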

@@ -13,7 +13,10 @@ The work of many others is present here. I've tried to make sure all source mate
* [Myself](https://github.com/rwightman/pytorch-dpn-pretrained)
* LR scheduler ideas from [AllenNLP](https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers), [FAIRseq](https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler), and SGDR: Stochastic Gradient Descent with Warm Restarts (https://arxiv.org/abs/1608.03983)
* Random Erasing from [Zhun Zhong](https://github.com/zhunzhong07/Random-Erasing/blob/master/transforms.py) (https://arxiv.org/abs/1708.04896)
* Optimizers:
    * RAdam by [Liyuan Liu](https://github.com/LiyuanLucasLiu/RAdam) (https://arxiv.org/abs/1908.03265)
    * NovoGrad by [Masashi Kimura](https://github.com/convergence-lab/novograd) (https://arxiv.org/abs/1905.11286)
    * Lookahead adapted from impl by [Liam](https://github.com/alphadl/lookahead.pytorch) (https://arxiv.org/abs/1907.08610)

## Models
I've included a few of my favourite models, but this is not an exhaustive collection. You can't do better than Cadene's collection in that regard. Most models do have pretrained weights from their respective sources or original authors.

@@ -44,6 +44,9 @@ default_cfgs = {
    'resnet50': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/rw_resnet50-86acaeed.pth',
        interpolation='bicubic'),
    'resnet50d': _cfg(
        url='',
        interpolation='bicubic'),
    'resnet101': _cfg(url='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'),
    'resnet152': _cfg(url='https://download.pytorch.org/models/resnet152-b121ed2d.pth'),
    'tv_resnet34': _cfg(url='https://download.pytorch.org/models/resnet34-333f7ec4.pth'),
@@ -259,7 +262,7 @@ class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False,
                 cardinality=1, base_width=64, stem_width=64, deep_stem=False,
                 block_reduce_first=1, down_kernel_size=1, avg_down=False, dilated=False,
                 norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg', zero_init_last_bn=True):
        self.num_classes = num_classes
        self.inplanes = stem_width * 2 if deep_stem else 64
        self.cardinality = cardinality
@@ -296,10 +299,15 @@ class ResNet(nn.Module):
        self.num_features = 512 * block.expansion
        self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
        last_bn_name = 'bn3' if 'Bottleneck' in block.__name__ else 'bn2'
        for n, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                if zero_init_last_bn and 'layer' in n and last_bn_name in n:
                    # Initialize weight/gamma of last BN in each residual block to zero
                    nn.init.constant_(m.weight, 0.)
                else:
                    nn.init.constant_(m.weight, 1.)
                nn.init.constant_(m.bias, 0.)
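A quick illustration of the zero-init behaviour above (a minimal sketch; the import path for the model module is assumed):

```python
import torch
from timm.models.resnet import ResNet, Bottleneck  # import path assumed

# with the default zero_init_last_bn=True the residual branch contributes nothing
# at init, so each block starts out as an identity mapping: the final BN weight
# (gamma) in every bottleneck block is all zeros
model = ResNet(Bottleneck, [3, 4, 6, 3])
for name, m in model.named_modules():
    if isinstance(m, torch.nn.BatchNorm2d) and 'layer' in name and 'bn3' in name:
        assert torch.all(m.weight == 0), name
```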
@@ -434,6 +442,20 @@ def resnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    return model


@register_model
def resnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Constructs a ResNet-50-D model.
    """
    default_cfg = default_cfgs['resnet50d']
    model = ResNet(
        Bottleneck, [3, 4, 6, 3], stem_width=32, deep_stem=True, avg_down=True,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


@register_model
def resnet101(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """Constructs a ResNet-101 model.

@@ -1,3 +1,7 @@
from .nadam import Nadam
from .rmsprop_tf import RMSpropTF
from .adamw import AdamW
from .radam import RAdam
from .novograd import NovoGrad
from .lookahead import Lookahead
from .optim_factory import create_optimizer

@@ -0,0 +1,117 @@
""" AdamW Optimizer
Impl copied from PyTorch master
"""
import math
import torch
from torch.optim.optimizer import Optimizer


class AdamW(Optimizer):
    r"""Implements AdamW algorithm.

    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=1e-2, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(AdamW, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamW, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                # Perform stepweight decay
                p.data.mul_(1 - group['lr'] * group['weight_decay'])

                # Perform optimization step
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                else:
                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

                step_size = group['lr'] / bias_correction1

                p.data.addcdiv_(-step_size, exp_avg, denom)

        return loss
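A brief usage sketch (model and hyper-parameters are illustrative); note that the decoupled decay above is applied directly to the weights as `p *= 1 - lr * weight_decay` rather than being folded into the gradient:

```python
import torch.nn as nn
from timm.optim import AdamW  # exported via timm/optim/__init__.py above

model = nn.Linear(16, 4)
optimizer = AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-2)
```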

@@ -0,0 +1,88 @@
""" Lookahead Optimizer Wrapper.
Implementation modified from: https://github.com/alphadl/lookahead.pytorch
Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
"""
import torch
from torch.optim.optimizer import Optimizer
from collections import defaultdict


class Lookahead(Optimizer):
    def __init__(self, base_optimizer, alpha=0.5, k=6):
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        self.alpha = alpha
        self.k = k
        self.base_optimizer = base_optimizer
        self.param_groups = self.base_optimizer.param_groups
        self.defaults = base_optimizer.defaults
        self.state = defaultdict(dict)
        for group in self.param_groups:
            group["step_counter"] = 0

    def update_slow_weights(self, group):
        for fast_p in group["params"]:
            if fast_p.grad is None:
                continue
            param_state = self.state[fast_p]
            if "slow_buffer" not in param_state:
                param_state["slow_buffer"] = torch.empty_like(fast_p.data)
                param_state["slow_buffer"].copy_(fast_p.data)
            slow = param_state["slow_buffer"]
            slow.add_(self.alpha, fast_p.data - slow)
            fast_p.data.copy_(slow)

    def sync_lookahead(self):
        for group in self.param_groups:
            self.update_slow_weights(group)

    def step(self, closure=None):
        loss = self.base_optimizer.step(closure)
        for group in self.param_groups:
            group['step_counter'] += 1
            if group['step_counter'] % self.k == 0:
                self.update_slow_weights(group)
        return loss

    def state_dict(self):
        fast_state_dict = self.base_optimizer.state_dict()
        slow_state = {
            (id(k) if isinstance(k, torch.Tensor) else k): v
            for k, v in self.state.items()
        }
        fast_state = fast_state_dict["state"]
        param_groups = fast_state_dict["param_groups"]
        return {
            "state": fast_state,
            "slow_state": slow_state,
            "param_groups": param_groups,
        }

    def load_state_dict(self, state_dict):
        if 'slow_state' not in state_dict:
            print('Loading state_dict from optimizer without Lookahead applied')
            state_dict['slow_state'] = defaultdict(dict)
        slow_state_dict = {
            "state": state_dict["slow_state"],
            "param_groups": state_dict["param_groups"],
        }
        fast_state_dict = {
            "state": state_dict["state"],
            "param_groups": state_dict["param_groups"],
        }
        super(Lookahead, self).load_state_dict(slow_state_dict)
        self.base_optimizer.load_state_dict(fast_state_dict)

    def add_param_group(self, param_group):
        r"""Add a param group to the :class:`Optimizer` s `param_groups`.

        This can be useful when fine tuning a pre-trained network as frozen
        layers can be made trainable and added to the :class:`Optimizer` as
        training progresses.

        Args:
            param_group (dict): Specifies what Tensors should be optimized along
                with group specific optimization options.
        """
        param_group['step_counter'] = 0
        self.base_optimizer.add_param_group(param_group)
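A brief usage sketch of the wrapper (model, base optimizer, and hyper-parameters are illustrative):

```python
import torch
import torch.nn as nn
from timm.optim import Lookahead

model = nn.Linear(8, 2)
base = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer = Lookahead(base, alpha=0.5, k=6)

loss = model(torch.randn(4, 8)).sum()
loss.backward()
optimizer.step()            # fast weights step every call, slow weights update every k-th call
optimizer.sync_lookahead()  # e.g. once per epoch: apply a slow update and copy it back into the fast weights
```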

@@ -0,0 +1,77 @@
"""NovoGrad Optimizer.
Original impl by Masashi Kimura (Convergence Lab): https://github.com/convergence-lab/novograd
Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
    - https://arxiv.org/abs/1905.11286
"""
import torch
from torch.optim.optimizer import Optimizer
import math


class NovoGrad(Optimizer):
    def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(NovoGrad, self).__init__(params, defaults)
        self._lr = lr
        self._beta1 = betas[0]
        self._beta2 = betas[1]
        self._eps = eps
        self._wd = weight_decay
        self._grad_averaging = grad_averaging
        self._momentum_initialized = False

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        if not self._momentum_initialized:
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    state = self.state[p]
                    grad = p.grad.data
                    if grad.is_sparse:
                        raise RuntimeError('NovoGrad does not support sparse gradients')

                    v = torch.norm(grad)**2
                    m = grad/(torch.sqrt(v) + self._eps) + self._wd * p.data
                    state['step'] = 0
                    state['v'] = v
                    state['m'] = m
                    state['grad_ema'] = None
            self._momentum_initialized = True

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                state['step'] += 1

                step, v, m = state['step'], state['v'], state['m']
                grad_ema = state['grad_ema']

                grad = p.grad.data
                g2 = torch.norm(grad)**2
                grad_ema = g2 if grad_ema is None else grad_ema * \
                    self._beta2 + g2 * (1. - self._beta2)
                grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps)

                if self._grad_averaging:
                    grad *= (1. - self._beta1)

                g2 = torch.norm(grad)**2
                v = self._beta2*v + (1. - self._beta2)*g2
                m = self._beta1*m + (grad / (torch.sqrt(v) + self._eps) + self._wd * p.data)
                bias_correction1 = 1 - self._beta1 ** step
                bias_correction2 = 1 - self._beta2 ** step
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                state['v'], state['m'] = v, m
                state['grad_ema'] = grad_ema
                p.data.add_(-step_size, m)
        return loss
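A brief usage sketch (model and hyper-parameters are illustrative); note that, per the code above, the second moment `v` is a single scalar per parameter tensor (layer-wise) rather than per-element as in Adam:

```python
import torch.nn as nn
from timm.optim import NovoGrad

model = nn.Conv2d(3, 8, kernel_size=3)
optimizer = NovoGrad(model.parameters(), lr=0.01, betas=(0.95, 0.98), weight_decay=1e-3)
```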

@@ -1,5 +1,5 @@
from torch import optim as optim
from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, Lookahead


def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
@@ -18,35 +18,55 @@ def add_weight_decay(model, weight_decay=1e-5, skip_list=()):

def create_optimizer(args, model, filter_bias_and_bn=True):
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if opt_lower == 'adamw' or opt_lower == 'radam':
        # compensate for the way current AdamW and RAdam optimizers
        # apply the weight-decay
        weight_decay /= args.lr
    if weight_decay and filter_bias_and_bn:
        parameters = add_weight_decay(model, weight_decay)
        weight_decay = 0.
    else:
        parameters = model.parameters()

    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd':
        optimizer = optim.SGD(
            parameters, lr=args.lr,
            momentum=args.momentum, weight_decay=weight_decay, nesterov=True)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(
            parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'adamw':
        optimizer = AdamW(
            parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'nadam':
        optimizer = Nadam(
            parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'radam':
        optimizer = RAdam(
            parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(
            parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(
            parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps,
            momentum=args.momentum, weight_decay=weight_decay)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(
            parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps,
            momentum=args.momentum, weight_decay=weight_decay)
    elif opt_lower == 'novograd':
        optimizer = NovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
    else:
        assert False and "Invalid optimizer"
        raise ValueError

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
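A worked example of the weight-decay compensation above (values illustrative): with `args.lr=1e-3` and `args.weight_decay=1e-4`, AdamW/RAdam are handed `weight_decay=0.1`; since both scale their decay term by the learning rate internally, the effective per-step decay comes back to `1e-3 * 0.1 = 1e-4`.

```python
import torch.nn as nn
from types import SimpleNamespace
from timm.optim import create_optimizer

# stand-in for the argparse namespace from train.py; only fields used above are set
args = SimpleNamespace(opt='adamw', lr=1e-3, weight_decay=1e-4, momentum=0.9, opt_eps=1e-8)
optimizer = create_optimizer(args, nn.Linear(8, 2))

# prefixing the optimizer name selects the Lookahead wrapper, e.g. opt='lookahead_sgd'
```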

@@ -0,0 +1,152 @@
"""RAdam Optimizer.
Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam
Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265
"""
import math
import torch
from torch.optim.optimizer import Optimizer, required


class RAdam(Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        self.buffer = [[None, None, None] for ind in range(10)]
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                exp_avg.mul_(beta1).add_(1 - beta1, grad)

                state['step'] += 1
                buffered = self.buffer[int(state['step'] % 10)]
                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        step_size = group['lr'] * math.sqrt(
                            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                                N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    else:
                        step_size = group['lr'] / (1 - beta1 ** state['step'])
                    buffered[2] = step_size

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
                else:
                    p_data_fp32.add_(-step_size, exp_avg)

                p.data.copy_(p_data_fp32)

        return loss


class PlainRAdam(Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(PlainRAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(PlainRAdam, self).__setstate__(state)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                exp_avg.mul_(beta1).add_(1 - beta1, grad)

                state['step'] += 1
                beta2_t = beta2 ** state['step']
                N_sma_max = 2 / (1 - beta2) - 1
                N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    step_size = group['lr'] * math.sqrt(
                        (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                            N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
                else:
                    step_size = group['lr'] / (1 - beta1 ** state['step'])
                    p_data_fp32.add_(-step_size, exp_avg)

                p.data.copy_(p_data_fp32)

        return loss
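A small check of the rectification threshold (with the default `betas=(0.9, 0.999)`): `N_sma` first reaches 5 at step 6, so the first few updates take the un-rectified branch above, which is just bias-corrected momentum without the adaptive denominator:

```python
# reproduces the N_sma schedule from step() above for the default beta2
beta2 = 0.999
N_sma_max = 2 / (1 - beta2) - 1          # = 1999
for step in range(1, 8):
    beta2_t = beta2 ** step
    N_sma = N_sma_max - 2 * step * beta2_t / (1 - beta2_t)
    print(step, round(N_sma, 3))         # rectification (N_sma >= 5) first applies at step 6
```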

@@ -5,6 +5,7 @@ from .step_lr import StepLRScheduler

def create_scheduler(args, optimizer):
    num_epochs = args.epochs
    lr_scheduler = None
    #FIXME expose cycle parms of the scheduler config to arguments
    if args.sched == 'cosine':
        lr_scheduler = CosineLRScheduler(
@@ -31,7 +32,7 @@ def create_scheduler(args, optimizer):
            t_in_epochs=True,
        )
        num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
    elif args.sched == 'step':
        lr_scheduler = StepLRScheduler(
            optimizer,
            decay_t=args.decay_epochs,

@@ -251,7 +251,7 @@ def main():
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
@@ -285,6 +285,8 @@ def main():
        collate_fn=collate_fn,
    )

    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            logging.error('Validation folder does not exist at: {}'.format(eval_dir))
@@ -390,8 +392,7 @@ def train_epoch(
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher:
            input, target = input.cuda(), target.cuda()
        if args.mixup > 0.:
            lam = 1.
            if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
@@ -461,6 +462,10 @@ def train_epoch(
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        end = time.time()
        # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    return OrderedDict([('loss', losses_m.avg)])
