From 6946281fde9536e123298602e5cb8835820dc965 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 28 Jul 2019 10:27:26 -0700 Subject: [PATCH 1/3] Experimenting with random erasing changes --- timm/data/random_erasing.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/timm/data/random_erasing.py b/timm/data/random_erasing.py index e66f7b95..2a105128 100644 --- a/timm/data/random_erasing.py +++ b/timm/data/random_erasing.py @@ -43,6 +43,7 @@ class RandomErasing: self.sl = sl self.sh = sh self.min_aspect = min_aspect + self.max_count = 8 mode = mode.lower() self.rand_color = False self.per_pixel = False @@ -58,18 +59,20 @@ class RandomErasing: if random.random() > self.probability: return area = img_h * img_w - for attempt in range(100): - target_area = random.uniform(self.sl, self.sh) * area - aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect) - h = int(round(math.sqrt(target_area * aspect_ratio))) - w = int(round(math.sqrt(target_area / aspect_ratio))) - if w < img_w and h < img_h: - top = random.randint(0, img_h - h) - left = random.randint(0, img_w - w) - img[:, top:top + h, left:left + w] = _get_pixels( - self.per_pixel, self.rand_color, (chan, h, w), - dtype=dtype, device=self.device) - break + count = random.randint(1, self.max_count) + for _ in range(count): + for attempt in range(10): + target_area = random.uniform(self.sl / count, self.sh / count) * area + aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < img_w and h < img_h: + top = random.randint(0, img_h - h) + left = random.randint(0, img_w - w) + img[:, top:top + h, left:left + w] = _get_pixels( + self.per_pixel, self.rand_color, (chan, h, w), + dtype=dtype, device=self.device) + break def __call__(self, input): if len(input.size()) == 3: From 66634d2200efd5cb3f0f8ebb45c7628c57549fe3 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 12 Aug 2019 15:53:40 -0700 Subject: [PATCH 2/3] Add support to split random erasing blocks into randomly selected number with --recount arg. Fix random selection of aspect ratios. 
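For reference, a minimal standalone sketch (not part of the patch; the function name and defaults below are illustrative) of the sampling scheme this commit moves RandomErasing toward: the total erase-area budget stays in [sl, sh] but is divided across a randomly chosen number of regions, and aspect ratios are drawn log-uniformly so tall and wide boxes are equally likely:

    import math
    import random

    def sample_erase_boxes(img_h, img_w, sl=0.02, sh=1/3, min_aspect=0.3,
                           min_count=1, max_count=1):
        # Illustrative only: mirrors the per-count area split and the
        # log-uniform aspect-ratio draw used in this patch series.
        boxes = []
        area = img_h * img_w
        count = min_count if min_count == max_count else \
            random.randint(min_count, max_count)
        for _ in range(count):
            for _attempt in range(10):
                # each region gets 1/count of the sampled area fraction
                target_area = random.uniform(sl, sh) * area / count
                # log-uniform aspect ratio: exp(U(log(a_min), log(1/a_min)))
                log_ratio = (math.log(min_aspect), math.log(1 / min_aspect))
                aspect_ratio = math.exp(random.uniform(*log_ratio))
                h = int(round(math.sqrt(target_area * aspect_ratio)))
                w = int(round(math.sqrt(target_area / aspect_ratio)))
                if w < img_w and h < img_h:
                    top = random.randint(0, img_h - h)
                    left = random.randint(0, img_w - w)
                    boxes.append((top, left, h, w))
                    break
        return boxes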
--- timm/data/loader.py | 5 ++++- timm/data/random_erasing.py | 15 ++++++++++----- timm/data/transforms.py | 23 +++++++++++++++-------- train.py | 3 +++ 4 files changed, 32 insertions(+), 14 deletions(-) diff --git a/timm/data/loader.py b/timm/data/loader.py index 1198d5e5..815a19da 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -20,6 +20,7 @@ class PrefetchLoader: loader, rand_erase_prob=0., rand_erase_mode='const', + rand_erase_count=1, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, fp16=False): @@ -32,7 +33,7 @@ class PrefetchLoader: self.std = self.std.half() if rand_erase_prob > 0.: self.random_erasing = RandomErasing( - probability=rand_erase_prob, mode=rand_erase_mode) + probability=rand_erase_prob, mode=rand_erase_mode, max_count=rand_erase_count) else: self.random_erasing = None @@ -94,6 +95,7 @@ def create_loader( use_prefetcher=True, rand_erase_prob=0., rand_erase_mode='const', + rand_erase_count=1, color_jitter=0.4, interpolation='bilinear', mean=IMAGENET_DEFAULT_MEAN, @@ -160,6 +162,7 @@ def create_loader( loader, rand_erase_prob=rand_erase_prob if is_training else 0., rand_erase_mode=rand_erase_mode, + rand_erase_count=rand_erase_count, mean=mean, std=std, fp16=fp16) diff --git a/timm/data/random_erasing.py b/timm/data/random_erasing.py index 2a105128..e944f22c 100644 --- a/timm/data/random_erasing.py +++ b/timm/data/random_erasing.py @@ -33,17 +33,20 @@ class RandomErasing: 'const' - erase block is constant color of 0 for all channels 'rand' - erase block is same per-cannel random (normal) color 'pixel' - erase block is per-pixel random (normal) color + max_count: maximum number of erasing blocks per image, area per box is scaled by count. + per-image count is randomly chosen between 1 and this value. """ def __init__( self, probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3, - mode='const', device='cuda'): + mode='const', max_count=1, device='cuda'): self.probability = probability self.sl = sl self.sh = sh self.min_aspect = min_aspect - self.max_count = 8 + self.min_count = 1 + self.max_count = max_count mode = mode.lower() self.rand_color = False self.per_pixel = False @@ -59,11 +62,13 @@ class RandomErasing: if random.random() > self.probability: return area = img_h * img_w - count = random.randint(1, self.max_count) + count = self.min_count if self.min_count == self.max_count else \ + random.randint(self.min_count, self.max_count) for _ in range(count): for attempt in range(10): - target_area = random.uniform(self.sl / count, self.sh / count) * area - aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect) + target_area = random.uniform(self.sl, self.sh) * area / count + log_ratio = (math.log(self.min_aspect), math.log(1 / self.min_aspect)) + aspect_ratio = math.exp(random.uniform(*log_ratio)) h = int(round(math.sqrt(target_area * aspect_ratio))) w = int(round(math.sqrt(target_area / aspect_ratio))) if w < img_w and h < img_h: diff --git a/timm/data/transforms.py b/timm/data/transforms.py index 13a6ff01..93796a04 100644 --- a/timm/data/transforms.py +++ b/timm/data/transforms.py @@ -107,24 +107,31 @@ class RandomResizedCropAndInterpolation(object): for attempt in range(10): target_area = random.uniform(*scale) * area - aspect_ratio = random.uniform(*ratio) + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) w = int(round(math.sqrt(target_area * aspect_ratio))) h = int(round(math.sqrt(target_area / aspect_ratio))) - if random.random() < 0.5 and min(ratio) <= (h / w) <= max(ratio): - 
w, h = h, w - if w <= img.size[0] and h <= img.size[1]: i = random.randint(0, img.size[1] - h) j = random.randint(0, img.size[0] - w) return i, j, h, w - # Fallback - w = min(img.size[0], img.size[1]) - i = (img.size[1] - w) // 2 + # Fallback to central crop + in_ratio = img.size[0] / img.size[1] + if in_ratio < min(ratio): + w = img.size[0] + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = img.size[1] + w = int(round(h * max(ratio))) + else: # whole image + w = img.size[0] + h = img.size[1] + i = (img.size[1] - h) // 2 j = (img.size[0] - w) // 2 - return i, j, w, w + return i, j, h, w def __call__(self, img): """ diff --git a/train.py b/train.py index 51006a0d..c795927d 100644 --- a/train.py +++ b/train.py @@ -91,6 +91,8 @@ parser.add_argument('--reprob', type=float, default=0., metavar='PCT', help='Random erase prob (default: 0.)') parser.add_argument('--remode', type=str, default='const', help='Random erase mode (default: "const")') +parser.add_argument('--recount', type=int, default=1, + help='Random erase count (default: 1)') parser.add_argument('--mixup', type=float, default=0.0, help='mixup alpha, mixup enabled if > 0. (default: 0.)') parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', @@ -273,6 +275,7 @@ def main(): use_prefetcher=args.prefetcher, rand_erase_prob=args.reprob, rand_erase_mode=args.remode, + rand_erase_count=args.recount, color_jitter=args.color_jitter, interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'], mean=data_config['mean'], From fac58f609a9a96ef1b273e7b7ab8380ba3743f54 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 28 Aug 2019 00:14:10 -0700 Subject: [PATCH 3/3] Add RAdam, NovoGrad, Lookahead, and AdamW optimizers, a few ResNet tweaks and scheduler factory tweak. * Add some of the trendy new optimizers. Decent results but not clearly better than the standards. * Can create a None scheduler for constant LR * ResNet defaults to zero_init of last BN in residual * add resnet50d config --- README.md | 5 +- timm/models/resnet.py | 28 ++++- timm/optim/__init__.py | 4 + timm/optim/adamw.py | 117 +++++++++++++++++++++ timm/optim/lookahead.py | 88 ++++++++++++++++ timm/optim/novograd.py | 77 ++++++++++++++ timm/optim/optim_factory.py | 34 +++++-- timm/optim/radam.py | 152 ++++++++++++++++++++++++++++ timm/scheduler/scheduler_factory.py | 3 +- train.py | 17 ++-- 10 files changed, 507 insertions(+), 18 deletions(-) create mode 100644 timm/optim/adamw.py create mode 100644 timm/optim/lookahead.py create mode 100644 timm/optim/novograd.py create mode 100644 timm/optim/radam.py diff --git a/README.md b/README.md index 5ab92987..66cea1d5 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,10 @@ The work of many others is present here. 
I've tried to make sure all source mate * [Myself](https://github.com/rwightman/pytorch-dpn-pretrained) * LR scheduler ideas from [AllenNLP](https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers), [FAIRseq](https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler), and SGDR: Stochastic Gradient Descent with Warm Restarts (https://arxiv.org/abs/1608.03983) * Random Erasing from [Zhun Zhong](https://github.com/zhunzhong07/Random-Erasing/blob/master/transforms.py) (https://arxiv.org/abs/1708.04896) - +* Optimizers: + * RAdam by [Liyuan Liu](https://github.com/LiyuanLucasLiu/RAdam) (https://arxiv.org/abs/1908.03265) + * NovoGrad by [Masashi Kimura](https://github.com/convergence-lab/novograd) (https://arxiv.org/abs/1905.11286) + * Lookahead adapted from impl by [Liam](https://github.com/alphadl/lookahead.pytorch) (https://arxiv.org/abs/1907.08610) ## Models I've included a few of my favourite models, but this is not an exhaustive collection. You can't do better than Cadene's collection in that regard. Most models do have pretrained weights from their respective sources or original authors. diff --git a/timm/models/resnet.py b/timm/models/resnet.py index eff83066..89649f7c 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -44,6 +44,9 @@ default_cfgs = { 'resnet50': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/rw_resnet50-86acaeed.pth', interpolation='bicubic'), + 'resnet50d': _cfg( + url='', + interpolation='bicubic'), 'resnet101': _cfg(url='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'), 'resnet152': _cfg(url='https://download.pytorch.org/models/resnet152-b121ed2d.pth'), 'tv_resnet34': _cfg(url='https://download.pytorch.org/models/resnet34-333f7ec4.pth'), @@ -259,7 +262,7 @@ class ResNet(nn.Module): def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False, cardinality=1, base_width=64, stem_width=64, deep_stem=False, block_reduce_first=1, down_kernel_size=1, avg_down=False, dilated=False, - norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg'): + norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg', zero_init_last_bn=True): self.num_classes = num_classes self.inplanes = stem_width * 2 if deep_stem else 64 self.cardinality = cardinality @@ -296,11 +299,16 @@ class ResNet(nn.Module): self.num_features = 512 * block.expansion self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) - for m in self.modules(): + last_bn_name = 'bn3' if 'Bottleneck' in block.__name__ else 'bn2' + for n, m in self.named_modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, nn.BatchNorm2d): - nn.init.constant_(m.weight, 1.) + if zero_init_last_bn and 'layer' in n and last_bn_name in n: + # Initialize weight/gamma of last BN in each residual block to zero + nn.init.constant_(m.weight, 0.) + else: + nn.init.constant_(m.weight, 1.) nn.init.constant_(m.bias, 0.) def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1, @@ -434,6 +442,20 @@ def resnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model +def resnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-50-D model. 
+ """ + default_cfg = default_cfgs['resnet50d'] + model = ResNet( + Bottleneck, [3, 4, 6, 3], stem_width=32, deep_stem=True, avg_down=True, + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + @register_model def resnet101(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-101 model. diff --git a/timm/optim/__init__.py b/timm/optim/__init__.py index fe97041d..3213cd68 100644 --- a/timm/optim/__init__.py +++ b/timm/optim/__init__.py @@ -1,3 +1,7 @@ from .nadam import Nadam from .rmsprop_tf import RMSpropTF +from .adamw import AdamW +from .radam import RAdam +from .novograd import NovoGrad +from .lookahead import Lookahead from .optim_factory import create_optimizer diff --git a/timm/optim/adamw.py b/timm/optim/adamw.py new file mode 100644 index 00000000..66f9a959 --- /dev/null +++ b/timm/optim/adamw.py @@ -0,0 +1,117 @@ +""" AdamW Optimizer +Impl copied from PyTorch master +""" +import math +import torch +from torch.optim.optimizer import Optimizer + + +class AdamW(Optimizer): + r"""Implements AdamW algorithm. + + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=1e-2, amsgrad=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad) + super(AdamW, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdamW, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + # Perform stepweight decay + p.data.mul_(1 - group['lr'] * group['weight_decay']) + + # Perform optimization step + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') + amsgrad = group['amsgrad'] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + if amsgrad: + max_exp_avg_sq = state['max_exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(1 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. of gradient + denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + else: + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + + step_size = group['lr'] / bias_correction1 + + p.data.addcdiv_(-step_size, exp_avg, denom) + + return loss diff --git a/timm/optim/lookahead.py b/timm/optim/lookahead.py new file mode 100644 index 00000000..cc1fb495 --- /dev/null +++ b/timm/optim/lookahead.py @@ -0,0 +1,88 @@ +""" Lookahead Optimizer Wrapper. 
+Implementation modified from: https://github.com/alphadl/lookahead.pytorch +Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 +""" +import torch +from torch.optim.optimizer import Optimizer +from collections import defaultdict + + +class Lookahead(Optimizer): + def __init__(self, base_optimizer, alpha=0.5, k=6): + if not 0.0 <= alpha <= 1.0: + raise ValueError(f'Invalid slow update rate: {alpha}') + if not 1 <= k: + raise ValueError(f'Invalid lookahead steps: {k}') + self.alpha = alpha + self.k = k + self.base_optimizer = base_optimizer + self.param_groups = self.base_optimizer.param_groups + self.defaults = base_optimizer.defaults + self.state = defaultdict(dict) + for group in self.param_groups: + group["step_counter"] = 0 + + def update_slow_weights(self, group): + for fast_p in group["params"]: + if fast_p.grad is None: + continue + param_state = self.state[fast_p] + if "slow_buffer" not in param_state: + param_state["slow_buffer"] = torch.empty_like(fast_p.data) + param_state["slow_buffer"].copy_(fast_p.data) + slow = param_state["slow_buffer"] + slow.add_(self.alpha, fast_p.data - slow) + fast_p.data.copy_(slow) + + def sync_lookahead(self): + for group in self.param_groups: + self.update_slow_weights(group) + + def step(self, closure=None): + loss = self.base_optimizer.step(closure) + for group in self.param_groups: + group['step_counter'] += 1 + if group['step_counter'] % self.k == 0: + self.update_slow_weights(group) + return loss + + def state_dict(self): + fast_state_dict = self.base_optimizer.state_dict() + slow_state = { + (id(k) if isinstance(k, torch.Tensor) else k): v + for k, v in self.state.items() + } + fast_state = fast_state_dict["state"] + param_groups = fast_state_dict["param_groups"] + return { + "state": fast_state, + "slow_state": slow_state, + "param_groups": param_groups, + } + + def load_state_dict(self, state_dict): + if 'slow_state' not in state_dict: + print('Loading state_dict from optimizer without Lookahead applied') + state_dict['slow_state'] = defaultdict(dict) + slow_state_dict = { + "state": state_dict["slow_state"], + "param_groups": state_dict["param_groups"], + } + fast_state_dict = { + "state": state_dict["state"], + "param_groups": state_dict["param_groups"], + } + super(Lookahead, self).load_state_dict(slow_state_dict) + self.base_optimizer.load_state_dict(fast_state_dict) + + def add_param_group(self, param_group): + r"""Add a param group to the :class:`Optimizer` s `param_groups`. + This can be useful when fine tuning a pre-trained network as frozen + layers can be made trainable and added to the :class:`Optimizer` as + training progresses. + Args: + param_group (dict): Specifies what Tensors should be optimized along + with group specific optimization options. + """ + param_group['step_counter'] = 0 + self.base_optimizer.add_param_group(param_group) diff --git a/timm/optim/novograd.py b/timm/optim/novograd.py new file mode 100644 index 00000000..4137c6aa --- /dev/null +++ b/timm/optim/novograd.py @@ -0,0 +1,77 @@ +"""NovoGrad Optimizer. 
+Original impl by Masashi Kimura (Convergence Lab): https://github.com/convergence-lab/novograd +Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks` + - https://arxiv.org/abs/1905.11286 +""" + +import torch +from torch.optim.optimizer import Optimizer +import math + + +class NovoGrad(Optimizer): + def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + super(NovoGrad, self).__init__(params, defaults) + self._lr = lr + self._beta1 = betas[0] + self._beta2 = betas[1] + self._eps = eps + self._wd = weight_decay + self._grad_averaging = grad_averaging + + self._momentum_initialized = False + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + if not self._momentum_initialized: + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + state = self.state[p] + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('NovoGrad does not support sparse gradients') + + v = torch.norm(grad)**2 + m = grad/(torch.sqrt(v) + self._eps) + self._wd * p.data + state['step'] = 0 + state['v'] = v + state['m'] = m + state['grad_ema'] = None + self._momentum_initialized = True + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + state = self.state[p] + state['step'] += 1 + + step, v, m = state['step'], state['v'], state['m'] + grad_ema = state['grad_ema'] + + grad = p.grad.data + g2 = torch.norm(grad)**2 + grad_ema = g2 if grad_ema is None else grad_ema * \ + self._beta2 + g2 * (1. - self._beta2) + grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps) + + if self._grad_averaging: + grad *= (1. - self._beta1) + + g2 = torch.norm(grad)**2 + v = self._beta2*v + (1. - self._beta2)*g2 + m = self._beta1*m + (grad / (torch.sqrt(v) + self._eps) + self._wd * p.data) + bias_correction1 = 1 - self._beta1 ** step + bias_correction2 = 1 - self._beta2 ** step + step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 + + state['v'], state['m'] = v, m + state['grad_ema'] = grad_ema + p.data.add_(-step_size, m) + return loss diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py index 7fe3e1e4..c51bdf20 100644 --- a/timm/optim/optim_factory.py +++ b/timm/optim/optim_factory.py @@ -1,5 +1,5 @@ from torch import optim as optim -from timm.optim import Nadam, RMSpropTF +from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, Lookahead def add_weight_decay(model, weight_decay=1e-5, skip_list=()): @@ -18,35 +18,55 @@ def add_weight_decay(model, weight_decay=1e-5, skip_list=()): def create_optimizer(args, model, filter_bias_and_bn=True): + opt_lower = args.opt.lower() weight_decay = args.weight_decay + if opt_lower == 'adamw' or opt_lower == 'radam': + # compensate for the way current AdamW and RAdam optimizers + # apply the weight-decay + weight_decay /= args.lr if weight_decay and filter_bias_and_bn: parameters = add_weight_decay(model, weight_decay) weight_decay = 0. 
else: parameters = model.parameters() - if args.opt.lower() == 'sgd': + opt_split = opt_lower.split('_') + opt_lower = opt_split[-1] + if opt_lower == 'sgd': optimizer = optim.SGD( parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True) - elif args.opt.lower() == 'adam': + elif opt_lower == 'adam': optimizer = optim.Adam( parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) - elif args.opt.lower() == 'nadam': + elif opt_lower == 'adamw': + optimizer = AdamW( + parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) + elif opt_lower == 'nadam': optimizer = Nadam( parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) - elif args.opt.lower() == 'adadelta': + elif opt_lower == 'radam': + optimizer = RAdam( + parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) + elif opt_lower == 'adadelta': optimizer = optim.Adadelta( parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) - elif args.opt.lower() == 'rmsprop': + elif opt_lower == 'rmsprop': optimizer = optim.RMSprop( parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps, momentum=args.momentum, weight_decay=weight_decay) - elif args.opt.lower() == 'rmsproptf': + elif opt_lower == 'rmsproptf': optimizer = RMSpropTF( parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps, momentum=args.momentum, weight_decay=weight_decay) + elif opt_lower == 'novograd': + optimizer = NovoGrad(parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) else: assert False and "Invalid optimizer" raise ValueError + + if len(opt_split) > 1: + if opt_split[0] == 'lookahead': + optimizer = Lookahead(optimizer) + return optimizer diff --git a/timm/optim/radam.py b/timm/optim/radam.py new file mode 100644 index 00000000..9987a334 --- /dev/null +++ b/timm/optim/radam.py @@ -0,0 +1,152 @@ +"""RAdam Optimizer. 
+Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam +Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265 +""" +import math +import torch +from torch.optim.optimizer import Optimizer, required + + +class RAdam(Optimizer): + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + self.buffer = [[None, None, None] for ind in range(10)] + super(RAdam, self).__init__(params, defaults) + + def __setstate__(self, state): + super(RAdam, self).__setstate__(state) + + def step(self, closure=None): + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError('RAdam does not support sparse gradients') + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_data_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + state['step'] += 1 + buffered = self.buffer[int(state['step'] % 10)] + if state['step'] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state['step'] + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = group['lr'] * math.sqrt( + (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( + N_sma_max - 2)) / (1 - beta1 ** state['step']) + else: + step_size = group['lr'] / (1 - beta1 ** state['step']) + buffered[2] = step_size + + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + + # more conservative since it's an approximated value + if N_sma >= 5: + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_data_fp32.addcdiv_(-step_size, exp_avg, denom) + else: + p_data_fp32.add_(-step_size, exp_avg) + + p.data.copy_(p_data_fp32) + + return loss + + +class PlainRAdam(Optimizer): + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + + super(PlainRAdam, self).__init__(params, defaults) + + def __setstate__(self, state): + super(PlainRAdam, self).__setstate__(state) + + def step(self, closure=None): + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError('RAdam does not support sparse gradients') + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_data_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = 
state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + state['step'] += 1 + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = group['lr'] * math.sqrt( + (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( + N_sma_max - 2)) / (1 - beta1 ** state['step']) + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_data_fp32.addcdiv_(-step_size, exp_avg, denom) + else: + step_size = group['lr'] / (1 - beta1 ** state['step']) + p_data_fp32.add_(-step_size, exp_avg) + + p.data.copy_(p_data_fp32) + + return loss diff --git a/timm/scheduler/scheduler_factory.py b/timm/scheduler/scheduler_factory.py index 8f1032a1..80d37b37 100644 --- a/timm/scheduler/scheduler_factory.py +++ b/timm/scheduler/scheduler_factory.py @@ -5,6 +5,7 @@ from .step_lr import StepLRScheduler def create_scheduler(args, optimizer): num_epochs = args.epochs + lr_scheduler = None #FIXME expose cycle parms of the scheduler config to arguments if args.sched == 'cosine': lr_scheduler = CosineLRScheduler( @@ -31,7 +32,7 @@ def create_scheduler(args, optimizer): t_in_epochs=True, ) num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs - else: + elif args.sched == 'step': lr_scheduler = StepLRScheduler( optimizer, decay_t=args.decay_epochs, diff --git a/train.py b/train.py index c795927d..9c75e050 100644 --- a/train.py +++ b/train.py @@ -251,7 +251,7 @@ def main(): start_epoch = args.start_epoch elif resume_epoch is not None: start_epoch = resume_epoch - if start_epoch > 0: + if lr_scheduler is not None and start_epoch > 0: lr_scheduler.step(start_epoch) if args.local_rank == 0: @@ -285,10 +285,12 @@ def main(): collate_fn=collate_fn, ) - eval_dir = os.path.join(args.data, 'validation') + eval_dir = os.path.join(args.data, 'val') if not os.path.isdir(eval_dir): - logging.error('Validation folder does not exist at: {}'.format(eval_dir)) - exit(1) + eval_dir = os.path.join(args.data, 'validation') + if not os.path.isdir(eval_dir): + logging.error('Validation folder does not exist at: {}'.format(eval_dir)) + exit(1) dataset_eval = Dataset(eval_dir) loader_eval = create_loader( @@ -390,8 +392,7 @@ def train_epoch( last_batch = batch_idx == last_idx data_time_m.update(time.time() - end) if not args.prefetcher: - input = input.cuda() - target = target.cuda() + input, target = input.cuda(), target.cuda() if args.mixup > 0.: lam = 1. if not args.mixup_off_epoch or epoch < args.mixup_off_epoch: @@ -461,6 +462,10 @@ def train_epoch( lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) end = time.time() + # end for + + if hasattr(optimizer, 'sync_lookahead'): + optimizer.sync_lookahead() return OrderedDict([('loss', losses_m.avg)])
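As a closing usage note, a small sketch (illustrative only; the model and hyper-parameters are placeholders) of the optimizer selection convention these changes introduce: the --opt string is lower-cased and split on '_', the last token selects the base optimizer, a leading 'lookahead' token wraps it in the Lookahead class, and the training loop syncs the slow weights at the end of each epoch before validation/checkpointing:

    import torch
    import torch.nn as nn
    from timm.optim import RAdam, Lookahead

    model = nn.Linear(16, 4)                    # placeholder model

    opt_name = 'lookahead_radam'                # e.g. from --opt lookahead_radam
    opt_split = opt_name.lower().split('_')
    base = opt_split[-1]

    if base == 'radam':
        optimizer = RAdam(model.parameters(), lr=1e-3, weight_decay=0.)
    else:
        raise ValueError('Unsupported optimizer in this sketch: {}'.format(base))

    if len(opt_split) > 1 and opt_split[0] == 'lookahead':
        optimizer = Lookahead(optimizer)

    # one training step
    optimizer.zero_grad()
    loss = model(torch.randn(8, 16)).sum()
    loss.backward()
    optimizer.step()

    # at the end of the epoch, as train.py now does:
    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()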