From 31453b039ef1217171cfef128ad6ca4e595787ce Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 22 Nov 2019 13:00:24 -0800 Subject: [PATCH] Update Auto/RandAugment comments, README, more. * Add a weighted choice option for RandAugment * Adjust magnitude noise/std naming, config --- README.md | 1 + timm/data/auto_augment.py | 127 ++++++++++++++++++++++++++++++-------- timm/data/transforms.py | 2 +- 3 files changed, 102 insertions(+), 28 deletions(-) diff --git a/README.md b/README.md index 615e68b1..bb7f4206 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,7 @@ Several (less common) features that I often utilize in my projects are included. * Training schedules and techniques that provide competitive results (Cosine LR, Random Erasing, Label Smoothing, etc) * Mixup (as in https://arxiv.org/abs/1710.09412) - currently implementing/testing * An inference script that dumps output to CSV is provided as an example +* AutoAugment (https://arxiv.org/abs/1805.09501) and RandAugment (https://arxiv.org/abs/1909.13719) ImageNet configurations modeled after impl for EfficientNet training (https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py) ## Results diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index 9d711cd1..d730c266 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -1,7 +1,7 @@ -""" Auto Augment +""" AutoAugment and RandAugment Implementation adapted from: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py -Papers: https://arxiv.org/abs/1805.09501 and https://arxiv.org/abs/1906.11172 +Papers: https://arxiv.org/abs/1805.09501, https://arxiv.org/abs/1906.11172, and https://arxiv.org/abs/1909.13719 Hacked together by Ross Wightman """ @@ -288,18 +288,18 @@ class AutoAugmentOp: resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION, ) - # If magnitude_noise is > 0, we introduce some randomness + # If magnitude_std is > 0, we introduce some randomness # in the usually fixed policy and sample magnitude from a normal distribution - # with mean `magnitude` and std-dev of `magnitude_noise`. + # with mean `magnitude` and std-dev of `magnitude_std`. # NOTE This is my own hack, being tested, not in papers or reference impls. - self.magnitude_noise = self.hparams.get('magnitude_noise', 0) + self.magnitude_std = self.hparams.get('magnitude_std', 0) def __call__(self, img): if random.random() > self.prob: return img magnitude = self.magnitude - if self.magnitude_noise and self.magnitude_noise > 0: - magnitude = random.gauss(magnitude, self.magnitude_noise) + if self.magnitude_std and self.magnitude_std > 0: + magnitude = random.gauss(magnitude, self.magnitude_std) magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple() return self.aug_fn(img, *level_args, **self.kwargs) @@ -464,16 +464,32 @@ class AutoAugment: def auto_augment_transform(config_str, hparams): + """ + Create a AutoAugment transform + + :param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by + dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr'). + The remaining sections, not order sepecific determine + 'mstd' - float std deviation of magnitude noise applied + Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5 + + :param hparams: Other hparams (kwargs) for the AutoAugmentation scheme + + :return: A PyTorch compatible Transform + """ config = config_str.split('-') policy_name = config[0] config = config[1:] for c in config: cs = re.split(r'(\d.*)', c) - if len(cs) >= 2: - key, val = cs[:2] - if key == 'noise': - # noise param injected via hparams for now - hparams.setdefault('magnitude_noise', float(val)) + if len(cs) < 2: + continue + key, val = cs[:2] + if key == 'mstd': + # noise param injected via hparams for now + hparams.setdefault('magnitude_std', float(val)) + else: + assert False, 'Unknown AutoAugment config section' aa_policy = auto_augment_policy(policy_name, hparams=hparams) return AutoAugment(aa_policy) @@ -498,6 +514,36 @@ _RAND_TRANSFORMS = [ ] +# These experimental weights are based loosely on the relative improvements mentioned in paper. +# They may not result in increased performance, but could likely be tuned to so. +_RAND_CHOICE_WEIGHTS_0 = { + 'Rotate': 0.3, + 'ShearX': 0.2, + 'ShearY': 0.2, + 'TranslateXRel': 0.1, + 'TranslateYRel': 0.1, + 'Color': .025, + 'Sharpness': 0.025, + 'AutoContrast': 0.025, + 'Solarize': .005, + 'SolarizeAdd': .005, + 'Contrast': .005, + 'Brightness': .005, + 'Equalize': .005, + 'PosterizeTpu': 0, + 'Invert': 0, +} + + +def _select_rand_weights(weight_idx=0, transforms=None): + transforms = transforms or _RAND_TRANSFORMS + assert weight_idx == 0 # only one set of weights currently + rand_weights = _RAND_CHOICE_WEIGHTS_0 + probs = [rand_weights[k] for k in transforms] + probs /= np.sum(probs) + return probs + + def rand_augment_ops(magnitude=10, hparams=None, transforms=None): hparams = hparams or _HPARAMS_DEFAULT transforms = transforms or _RAND_TRANSFORMS @@ -506,33 +552,60 @@ def rand_augment_ops(magnitude=10, hparams=None, transforms=None): class RandAugment: - def __init__(self, ops, num_layers=2): + def __init__(self, ops, num_layers=2, choice_weights=None): self.ops = ops self.num_layers = num_layers + self.choice_weights = choice_weights def __call__(self, img): - for _ in range(self.num_layers): - op = random.choice(self.ops) + # no replacement when using weighted choice + ops = np.random.choice( + self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights) + for op in ops: img = op(img) return img def rand_augment_transform(config_str, hparams): - magnitude = 10 - num_layers = 2 + """ + Create a RandAugment transform + + :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by + dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining + sections, not order sepecific determine + 'm' - integer magnitude of rand augment + 'n' - integer num layers (number of transform ops selected per image) + 'w' - integer probabiliy weight index (index of a set of weights to influence choice of op) + 'mstd' - float std deviation of magnitude noise applied + Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 + 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 + + :param hparams: Other hparams (kwargs) for the RandAugmentation scheme + + :return: A PyTorch compatible Transform + """ + magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) + num_layers = 2 # default to 2 ops per image + weight_idx = None # default to no probability weights for op choice config = config_str.split('-') assert config[0] == 'rand' config = config[1:] for c in config: cs = re.split(r'(\d.*)', c) - if len(cs) >= 2: - key, val = cs[:2] - if key == 'noise': - # noise param injected via hparams for now - hparams.setdefault('magnitude_noise', float(val)) - elif key == 'm': - magnitude = int(val) - elif key == 'n': - num_layers = int(val) + if len(cs) < 2: + continue + key, val = cs[:2] + if key == 'mstd': + # noise param injected via hparams for now + hparams.setdefault('magnitude_std', float(val)) + elif key == 'm': + magnitude = int(val) + elif key == 'n': + num_layers = int(val) + elif key == 'w': + weight_idx = int(val) + else: + assert False, 'Unknown RandAugment config section' ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams) - return RandAugment(ra_ops, num_layers) + choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx) + return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) diff --git a/timm/data/transforms.py b/timm/data/transforms.py index ac03b098..41f2a63e 100644 --- a/timm/data/transforms.py +++ b/timm/data/transforms.py @@ -190,7 +190,7 @@ def transforms_imagenet_train( ) if interpolation and interpolation != 'random': aa_params['interpolation'] = _pil_interp(interpolation) - if 'rand' in auto_augment: + if auto_augment.startswith('rand'): tfl += [rand_augment_transform(auto_augment, aa_params)] else: tfl += [auto_augment_transform(auto_augment, aa_params)]