From 9d927a389af0947b6f597bee6aebb66ddc76ec99 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 1 Mar 2019 22:03:42 -0800
Subject: [PATCH] Add adabound, random erasing

---
 models/random_erasing.py |  61 ++++++++++++++++++++
 models/transforms.py     |  17 ++++--
 optim/adabound.py        | 118 +++++++++++++++++++++++++++++++++++++++
 train.py                 |  10 +++-
 4 files changed, 198 insertions(+), 8 deletions(-)
 create mode 100644 models/random_erasing.py
 create mode 100644 optim/adabound.py

diff --git a/models/random_erasing.py b/models/random_erasing.py
new file mode 100644
index 00000000..f544525a
--- /dev/null
+++ b/models/random_erasing.py
@@ -0,0 +1,61 @@
+from __future__ import absolute_import
+
+from torchvision.transforms import *
+
+from PIL import Image
+import random
+import math
+import numpy as np
+import torch
+
+
+class RandomErasing:
+    """ Randomly selects a rectangle region in an image and erases its pixels.
+        'Random Erasing Data Augmentation' by Zhong et al.
+        See https://arxiv.org/pdf/1708.04896.pdf
+    Args:
+         probability: The probability that the Random Erasing operation will be performed.
+         sl: Minimum proportion of erased area against input image.
+         sh: Maximum proportion of erased area against input image.
+         min_aspect: Minimum aspect ratio of erased area.
+         mean: Erasing value used when neither per_pixel nor random is enabled.
+    """
+
+    def __init__(
+            self,
+            probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3,
+            per_pixel=False, random=False,
+            pl=0, ph=1., mean=[0.485, 0.456, 0.406]):
+        self.probability = probability
+        self.mean = torch.tensor(mean)
+        self.sl = sl
+        self.sh = sh
+        self.min_aspect = min_aspect
+        self.pl = pl
+        self.ph = ph
+        self.per_pixel = per_pixel  # per pixel random, bounded by [pl, ph]
+        self.random = random  # per block random, bounded by [pl, ph]
+
+    def __call__(self, img):
+        if random.random() > self.probability:
+            return img
+
+        chan, img_h, img_w = img.size()
+        area = img_h * img_w
+        for attempt in range(100):
+            target_area = random.uniform(self.sl, self.sh) * area
+            aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect)
+
+            h = int(round(math.sqrt(target_area * aspect_ratio)))
+            w = int(round(math.sqrt(target_area / aspect_ratio)))
+            c = torch.empty((chan)).uniform_(self.pl, self.ph) if self.random else self.mean[:chan]
+            if w < img_w and h < img_h:
+                top = random.randint(0, img_h - h)
+                left = random.randint(0, img_w - w)
+                if self.per_pixel:
+                    img[:, top:top + h, left:left + w] = torch.empty((chan, h, w)).uniform_(self.pl, self.ph)
+                else:
+                    img[:, top:top + h, left:left + w] = c
+                return img
+
+        return img
diff --git a/models/transforms.py b/models/transforms.py
index 49aaca57..94768b48 100644
--- a/models/transforms.py
+++ b/models/transforms.py
@@ -2,7 +2,7 @@ import torch
 from torchvision import transforms
 from PIL import Image
 import math
-
+from models.random_erasing import RandomErasing
 
 DEFAULT_CROP_PCT = 0.875
 
@@ -21,7 +21,12 @@ class LeNormalize(object):
         return tensor
 
 
-def transforms_imagenet_train(model_name, img_size=224, scale=(0.1, 1.0), color_jitter=(0.4, 0.4, 0.4)):
+def transforms_imagenet_train(
+        model_name,
+        img_size=224,
+        scale=(0.1, 1.0),
+        color_jitter=(0.4, 0.4, 0.4),
+        random_erasing=0.4):
     if 'dpn' in model_name:
         normalize = transforms.Normalize(
             mean=IMAGENET_DPN_MEAN,
@@ -33,12 +38,14 @@ def transforms_imagenet_train(model_name, img_size=224, scale=(0.1, 1.0), color_
             mean=IMAGENET_DEFAULT_MEAN,
             std=IMAGENET_DEFAULT_STD)
 
-    return transforms.Compose([
+    tfl = [
         transforms.RandomResizedCrop(img_size, scale=scale),
         transforms.RandomHorizontalFlip(),
         transforms.ColorJitter(*color_jitter),
-        transforms.ToTensor(),
-        normalize])
+        transforms.ToTensor()]
+    if random_erasing > 0.:
+        tfl.append(RandomErasing(random_erasing, per_pixel=True))
+    return transforms.Compose(tfl + [normalize])
 
 
 def transforms_imagenet_eval(model_name, img_size=224, crop_pct=None):
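Note: RandomErasing operates on the CHW tensor produced by transforms.ToTensor(), which is why transforms_imagenet_train appends it after ToTensor() and before normalization. A minimal sketch of exercising the transform on its own, assuming the module path added by this patch (models/random_erasing.py); the image tensor is synthetic:

```python
import torch

from models.random_erasing import RandomErasing

# Fake CHW image standing in for the output of transforms.ToTensor().
img = torch.rand(3, 224, 224)

# per_pixel=True matches the setting used in transforms_imagenet_train:
# the erased rectangle is filled with per-pixel uniform noise in [pl, ph]
# rather than the constant per-channel mean.
erase = RandomErasing(probability=1.0, per_pixel=True)

# The transform modifies its input in place and returns it, so clone first.
out = erase(img.clone())

# If the size/aspect sampling succeeded, the erased region now differs.
print((out != img).any().item())
```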
diff --git a/optim/adabound.py b/optim/adabound.py
new file mode 100644
index 00000000..161a2e86
--- /dev/null
+++ b/optim/adabound.py
@@ -0,0 +1,118 @@
+import math
+import torch
+from torch.optim import Optimizer
+
+
+class AdaBound(Optimizer):
+    """Implements AdaBound algorithm.
+    It has been proposed in `Adaptive Gradient Methods with Dynamic Bound of Learning Rate`_.
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): Adam learning rate (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.999))
+        final_lr (float, optional): final (SGD) learning rate (default: 0.1)
+        gamma (float, optional): convergence speed of the bound functions (default: 1e-3)
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        amsbound (boolean, optional): whether to use the AMSBound variant of this algorithm
+    .. Adaptive Gradient Methods with Dynamic Bound of Learning Rate:
+        https://openreview.net/forum?id=Bkg3g2R9FX
+    """
+
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), final_lr=0.1, gamma=1e-3,
+                 eps=1e-8, weight_decay=0, amsbound=False):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        if not 0.0 <= final_lr:
+            raise ValueError("Invalid final learning rate: {}".format(final_lr))
+        if not 0.0 <= gamma < 1.0:
+            raise ValueError("Invalid gamma parameter: {}".format(gamma))
+        defaults = dict(lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, eps=eps,
+                        weight_decay=weight_decay, amsbound=amsbound)
+        super(AdaBound, self).__init__(params, defaults)
+
+        self.base_lrs = list(map(lambda group: group['lr'], self.param_groups))
+
+    def __setstate__(self, state):
+        super(AdaBound, self).__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('amsbound', False)
+
+    def step(self, closure=None):
+        """Performs a single optimization step.
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group, base_lr in zip(self.param_groups, self.base_lrs):
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError(
+                        'AdaBound does not support sparse gradients')
+                amsbound = group['amsbound']
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p.data)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+                    if amsbound:
+                        # Maintains max of all exp. moving avg. of sq. grad. values
+                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                if amsbound:
+                    max_exp_avg_sq = state['max_exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                state['step'] += 1
+
+                if group['weight_decay'] != 0:
+                    grad = grad.add(group['weight_decay'], p.data)
+
+                # Decay the first and second moment running average coefficient
+                exp_avg.mul_(beta1).add_(1 - beta1, grad)
+                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                if amsbound:
+                    # Maintains the maximum of all 2nd moment running avg. till now
+                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+                    # Use the max. for normalizing running avg. of gradient
+                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
+                else:
+                    denom = exp_avg_sq.sqrt().add_(group['eps'])
+
+                bias_correction1 = 1 - beta1 ** state['step']
+                bias_correction2 = 1 - beta2 ** state['step']
+                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
+
+                # Applies bounds on actual learning rate
+                # lr_scheduler cannot affect final_lr, this is a workaround to apply lr decay
+                final_lr = group['final_lr'] * group['lr'] / base_lr
+                lower_bound = final_lr * (1 - 1 / (group['gamma'] * state['step'] + 1))
+                upper_bound = final_lr * (1 + 1 / (group['gamma'] * state['step']))
+                step_size = torch.full_like(denom, step_size)
+                step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg)
+
+                p.data.add_(-step_size)
+
+        return loss
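Note: the clamp in step() is what moves AdaBound from Adam-like toward SGD-like behaviour: the adaptive per-element step size (step_size / denom) is clipped into [lower_bound, upper_bound], and both bounds converge to final_lr as state['step'] grows. The rescaling final_lr * group['lr'] / base_lr additionally lets an external lr scheduler decay the bound target, per the "workaround" comment above. A small standalone sketch of just the bound schedule, using the constructor defaults (final_lr=0.1, gamma=1e-3) as illustrative values:

```python
# Reproduces only the bound schedule from AdaBound.step(), not the full update.
final_lr = 0.1   # default final (SGD) learning rate
gamma = 1e-3     # default convergence speed of the bounds

for step in (1, 10, 100, 1000, 10000):
    lower = final_lr * (1 - 1 / (gamma * step + 1))
    upper = final_lr * (1 + 1 / (gamma * step))
    print(f"step={step:>6}  bounds=({lower:.5f}, {upper:.5f})")

# Early on the interval is very wide (lower ~0, upper >> final_lr), so the
# Adam-style adaptive step passes through mostly unclipped; as step grows the
# bounds squeeze toward final_lr and the update approaches plain SGD.
```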
diff --git a/train.py b/train.py
index c8432fa7..df7d91df 100644
--- a/train.py
+++ b/train.py
@@ -6,7 +6,7 @@ from datetime import datetime
 from dataset import Dataset
 from models import model_factory, transforms_imagenet_eval, transforms_imagenet_train
 from utils import *
-from optim import nadam
+from optim import nadam, adabound
 import scheduler
 
 import torch
@@ -166,6 +166,10 @@ def main():
     elif args.opt.lower() == 'nadam':
         optimizer = nadam.Nadam(
             model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
+    elif args.opt.lower() == 'adabound':
+        optimizer = adabound.AdaBound(
+            model.parameters(), lr=args.lr / 1000, weight_decay=args.weight_decay, eps=args.opt_eps,
+            final_lr=args.lr)
     elif args.opt.lower() == 'adadelta':
         optimizer = optim.Adadelta(
             model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
@@ -185,12 +189,12 @@ def main():
     lr_scheduler = scheduler.CosineLRScheduler(
         optimizer,
         t_initial=args.epochs,
-        t_mul=1.5,
+        t_mul=1.0,
         lr_min=1e-5,
         decay_rate=args.decay_rate,
         warmup_lr_init=1e-4,
         warmup_t=3,
-        cycle_limit=3,
+        cycle_limit=1,
         t_in_epochs=True,
     )
     num_epochs = lr_scheduler.get_cycle_length() + 10
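Note: the train.py changes select AdaBound with an Adam-scale starting lr (args.lr / 1000) while using args.lr itself as the SGD-style final_lr, and reduce the cosine schedule to a single cycle (t_mul=1.0, cycle_limit=1). A rough sketch of the equivalent manual wiring, assuming the modules added by this patch (optim/adabound.py, models/transforms.py); the model and the args values below are placeholders, not values from the patch:

```python
import torch.nn as nn

from models.transforms import transforms_imagenet_train
from optim.adabound import AdaBound

# Placeholders standing in for the argparse values used in train.py.
base_lr = 0.01          # args.lr
weight_decay = 1e-4     # args.weight_decay
opt_eps = 1e-8          # args.opt_eps

model = nn.Linear(10, 10)  # stand-in for a network built by model_factory

# Training transform with random erasing enabled (new default probability 0.4).
train_tf = transforms_imagenet_train('resnet50', img_size=224)

# Mirrors the 'adabound' branch above: start at an Adam-like lr and let the
# dynamic bounds pull the step size toward the SGD-like final_lr = args.lr.
optimizer = AdaBound(
    model.parameters(), lr=base_lr / 1000, weight_decay=weight_decay,
    eps=opt_eps, final_lr=base_lr)
```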