From e3b2f5be0afaa6f2dd71be17a689b44b126a3ce9 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 10 Dec 2022 16:25:50 -0800 Subject: [PATCH] Add 3-Augment support to auto_augment.py, clean up weighted choice handling, and allow adjust per op prob via arg string --- timm/data/auto_augment.py | 325 ++++++++++++++++++++++---------- timm/data/transforms_factory.py | 10 +- 2 files changed, 236 insertions(+), 99 deletions(-) diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index 1b51ccb4..e461f67c 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -1,4 +1,4 @@ -""" AutoAugment, RandAugment, and AugMix for PyTorch +""" AutoAugment, RandAugment, AugMix, and 3-Augment for PyTorch This code implements the searched ImageNet policies with various tweaks and improvements and does not include any of the search code. @@ -9,18 +9,24 @@ AA and RA Implementation adapted from: AugMix adapted from: https://github.com/google-research/augmix +3-Augment based on: https://github.com/facebookresearch/deit/blob/main/README_revenge.md + Papers: AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501 Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172 RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781 + 3-Augment: DeiT III: Revenge of the ViT - https://arxiv.org/abs/2204.07118 Hacked together by / Copyright 2019, Ross Wightman """ import random import math import re -from PIL import Image, ImageOps, ImageEnhance, ImageChops +from functools import partial +from typing import Dict, List, Optional, Union + +from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageFilter import PIL import numpy as np @@ -175,6 +181,24 @@ def sharpness(img, factor, **__): return ImageEnhance.Sharpness(img).enhance(factor) +def gaussian_blur(img, factor, **__): + img = img.filter(ImageFilter.GaussianBlur(radius=factor)) + return img + + +def gaussian_blur_rand(img, factor, **__): + radius_min = 0.1 + radius_max = 2.0 + img = img.filter(ImageFilter.GaussianBlur(radius=random.uniform(radius_min, radius_max * factor))) + return img + + +def desaturate(img, factor, **_): + factor = min(1., max(0., 1. - factor)) + # enhance factor 0 = grayscale, 1.0 = no-change + return ImageEnhance.Color(img).enhance(factor) + + def _randomly_negate(v): """With 50% prob, negate the value""" return -v if random.random() > 0.5 else v @@ -200,6 +224,14 @@ def _enhance_increasing_level_to_arg(level, _hparams): return level, +def _minmax_level_to_arg(level, _hparams, min_val=0., max_val=1.0, clamp=True): + level = (level / _LEVEL_DENOM) + min_val + (max_val - min_val) * level + if clamp: + level = min(min_val, max(max_val, level)) + return level, + + def _shear_level_to_arg(level, _hparams): # range [-0.3, 0.3] level = (level / _LEVEL_DENOM) * 0.3 @@ -246,7 +278,7 @@ def _posterize_original_level_to_arg(level, _hparams): def _solarize_level_to_arg(level, _hparams): # range [0, 256] # intensity/severity of augmentation decreases with level - return int((level / _LEVEL_DENOM) * 256), + return min(256, int((level / _LEVEL_DENOM) * 256)), def _solarize_increasing_level_to_arg(level, _hparams): @@ -257,7 +289,7 @@ def _solarize_increasing_level_to_arg(level, _hparams): def _solarize_add_level_to_arg(level, _hparams): # range [0, 110] - return int((level / _LEVEL_DENOM) * 110), + return min(128, int((level / _LEVEL_DENOM) * 110)), LEVEL_TO_ARG = { @@ -286,6 +318,9 @@ LEVEL_TO_ARG = { 'TranslateY': _translate_abs_level_to_arg, 'TranslateXRel': _translate_rel_level_to_arg, 'TranslateYRel': _translate_rel_level_to_arg, + 'Desaturate': partial(_minmax_level_to_arg, min_val=0.5, max_val=1.0), + 'GaussianBlur': partial(_minmax_level_to_arg, min_val=0.1, max_val=2.0), + 'GaussianBlurRand': _minmax_level_to_arg, } @@ -314,6 +349,9 @@ NAME_TO_OP = { 'TranslateY': translate_y_abs, 'TranslateXRel': translate_x_rel, 'TranslateYRel': translate_y_rel, + 'Desaturate': desaturate, + 'GaussianBlur': gaussian_blur, + 'GaussianBlurRand': gaussian_blur_rand, } @@ -347,6 +385,7 @@ class AugmentOp: if self.magnitude_std > 0: # magnitude randomization enabled if self.magnitude_std == float('inf'): + # inf == uniform sampling magnitude = random.uniform(0, magnitude) elif self.magnitude_std > 0: magnitude = random.gauss(magnitude, self.magnitude_std) @@ -499,6 +538,16 @@ def auto_augment_policy_originalr(hparams): return pc +def auto_augment_policy_3a(hparams): + policy = [ + [('Solarize', 1.0, 5)], # 128 solarize threshold @ 5 magnitude + [('Desaturate', 1.0, 10)], # grayscale at 10 magnitude + [('GaussianBlurRand', 1.0, 10)], + ] + pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] + return pc + + def auto_augment_policy(name='v0', hparams=None): hparams = hparams or _HPARAMS_DEFAULT if name == 'original': @@ -509,6 +558,8 @@ def auto_augment_policy(name='v0', hparams=None): return auto_augment_policy_v0(hparams) elif name == 'v0r': return auto_augment_policy_v0r(hparams) + elif name == '3a': + return auto_augment_policy_3a(hparams) else: assert False, 'Unknown AA policy (%s)' % name @@ -534,19 +585,23 @@ class AutoAugment: return fs -def auto_augment_transform(config_str, hparams): +def auto_augment_transform(config_str: str, hparams: Optional[Dict] = None): """ Create a AutoAugment transform - :param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by - dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr'). - The remaining sections, not order sepecific determine - 'mstd' - float std deviation of magnitude noise applied - Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5 + Args: + config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by + dashes ('-'). + The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr'). - :param hparams: Other hparams (kwargs) for the AutoAugmentation scheme + The remaining sections: + 'mstd' - float std deviation of magnitude noise applied + Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5 - :return: A PyTorch compatible Transform + hparams: Other hparams (kwargs) for the AutoAugmentation scheme + + Returns: + A PyTorch compatible Transform """ config = config_str.split('-') policy_name = config[0] @@ -605,42 +660,80 @@ _RAND_INCREASING_TRANSFORMS = [ ] +_RAND_3A = [ + 'SolarizeIncreasing', + 'Desaturate', + 'GaussianBlur', +] + + +_RAND_CHOICE_3A = { + 'SolarizeIncreasing': 6, + 'Desaturate': 6, + 'GaussianBlur': 6, + 'Rotate': 3, + 'ShearX': 2, + 'ShearY': 2, + 'PosterizeIncreasing': 1, + 'AutoContrast': 1, + 'ColorIncreasing': 1, + 'SharpnessIncreasing': 1, + 'ContrastIncreasing': 1, + 'BrightnessIncreasing': 1, + 'Equalize': 1, + 'Invert': 1, +} + # These experimental weights are based loosely on the relative improvements mentioned in paper. # They may not result in increased performance, but could likely be tuned to so. _RAND_CHOICE_WEIGHTS_0 = { - 'Rotate': 0.3, - 'ShearX': 0.2, - 'ShearY': 0.2, - 'TranslateXRel': 0.1, - 'TranslateYRel': 0.1, - 'Color': .025, - 'Sharpness': 0.025, - 'AutoContrast': 0.025, - 'Solarize': .005, - 'SolarizeAdd': .005, - 'Contrast': .005, - 'Brightness': .005, - 'Equalize': .005, - 'Posterize': 0, - 'Invert': 0, + 'Rotate': 3, + 'ShearX': 2, + 'ShearY': 2, + 'TranslateXRel': 1, + 'TranslateYRel': 1, + 'ColorIncreasing': .25, + 'SharpnessIncreasing': 0.25, + 'AutoContrast': 0.25, + 'SolarizeIncreasing': .05, + 'SolarizeAdd': .05, + 'ContrastIncreasing': .05, + 'BrightnessIncreasing': .05, + 'Equalize': .05, + 'PosterizeIncreasing': 0.05, + 'Invert': 0.05, } -def _select_rand_weights(weight_idx=0, transforms=None): - transforms = transforms or _RAND_TRANSFORMS - assert weight_idx == 0 # only one set of weights currently - rand_weights = _RAND_CHOICE_WEIGHTS_0 - probs = [rand_weights[k] for k in transforms] - probs /= np.sum(probs) - return probs +def _get_weighted_transforms(transforms: Dict): + transforms, probs = list(zip(*transforms.items())) + probs = np.array(probs) + probs = probs / np.sum(probs) + return transforms, probs + +def rand_augment_choices(name: str, increasing=True): + if name == 'weights': + return _RAND_CHOICE_WEIGHTS_0 + elif name == '3aw': + return _RAND_CHOICE_3A + elif name == '3a': + return _RAND_3A + else: + return _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS -def rand_augment_ops(magnitude=10, hparams=None, transforms=None): + +def rand_augment_ops( + magnitude: Union[int, float] = 10, + prob: float = 0.5, + hparams: Optional[Dict] = None, + transforms: Optional[Union[Dict, List]] = None, +): hparams = hparams or _HPARAMS_DEFAULT transforms = transforms or _RAND_TRANSFORMS return [AugmentOp( - name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms] + name, prob=prob, magnitude=magnitude, hparams=hparams) for name in transforms] class RandAugment: @@ -648,11 +741,16 @@ class RandAugment: self.ops = ops self.num_layers = num_layers self.choice_weights = choice_weights + print(self.ops, self.choice_weights) def __call__(self, img): # no replacement when using weighted choice ops = np.random.choice( - self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights) + self.ops, + self.num_layers, + replace=self.choice_weights is None, + p=self.choice_weights, + ) for op in ops: img = op(img) return img @@ -665,61 +763,84 @@ class RandAugment: return fs -def rand_augment_transform(config_str, hparams): +def rand_augment_transform( + config_str: str, + hparams: Optional[Dict] = None, + transforms: Optional[Union[str, Dict, List]] = None, +): """ Create a RandAugment transform - :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by - dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining - sections, not order sepecific determine - 'm' - integer magnitude of rand augment - 'n' - integer num layers (number of transform ops selected per image) - 'w' - integer probabiliy weight index (index of a set of weights to influence choice of op) - 'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100) - 'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10) - 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) - Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 - 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 - - :param hparams: Other hparams (kwargs) for the RandAugmentation scheme - - :return: A PyTorch compatible Transform + Args: + config_str (str): String defining configuration of random augmentation. Consists of multiple sections separated + by dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). + The remaining sections, not order sepecific determine + 'm' - integer magnitude of rand augment + 'n' - integer num layers (number of transform ops selected per image) + 'p' - float probability of applying each layer (default 0.5) + 'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100) + 'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10) + 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) + 't' - str name of transform set to use + Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 + 'rand-mstd1-tweights' results in mag std 1.0, weighted transforms, default mag of 10 and num_layers 2 + + hparams (dict): Other hparams (kwargs) for the RandAugmentation scheme + + Returns: + A PyTorch compatible Transform """ magnitude = _LEVEL_DENOM # default to _LEVEL_DENOM for magnitude (currently 10) num_layers = 2 # default to 2 ops per image - weight_idx = None # default to no probability weights for op choice - transforms = _RAND_TRANSFORMS + increasing = False + prob = 0.5 config = config_str.split('-') assert config[0] == 'rand' config = config[1:] for c in config: - cs = re.split(r'(\d.*)', c) - if len(cs) < 2: - continue - key, val = cs[:2] - if key == 'mstd': - # noise param / randomization of magnitude values - mstd = float(val) - if mstd > 100: - # use uniform sampling in 0 to magnitude if mstd is > 100 - mstd = float('inf') - hparams.setdefault('magnitude_std', mstd) - elif key == 'mmax': - # clip magnitude between [0, mmax] instead of default [0, _LEVEL_DENOM] - hparams.setdefault('magnitude_max', int(val)) - elif key == 'inc': - if bool(val): - transforms = _RAND_INCREASING_TRANSFORMS - elif key == 'm': - magnitude = int(val) - elif key == 'n': - num_layers = int(val) - elif key == 'w': - weight_idx = int(val) + if c.startswith('t'): + # NOTE old 'w' key was removed, 'w0' is not equivalent to 'tweights' + val = str(c[1:]) + if transforms is None: + transforms = val else: - assert False, 'Unknown RandAugment config section' - ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms) - choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx) + # numeric options + cs = re.split(r'(\d.*)', c) + if len(cs) < 2: + continue + key, val = cs[:2] + if key == 'mstd': + # noise param / randomization of magnitude values + mstd = float(val) + if mstd > 100: + # use uniform sampling in 0 to magnitude if mstd is > 100 + mstd = float('inf') + hparams.setdefault('magnitude_std', mstd) + elif key == 'mmax': + # clip magnitude between [0, mmax] instead of default [0, _LEVEL_DENOM] + hparams.setdefault('magnitude_max', int(val)) + elif key == 'inc': + if bool(val): + increasing = True + elif key == 'm': + magnitude = int(val) + elif key == 'n': + num_layers = int(val) + elif key == 'p': + prob = float(val) + else: + assert False, 'Unknown RandAugment config section' + + if isinstance(transforms, str): + transforms = rand_augment_choices(transforms, increasing=increasing) + elif transforms is None: + transforms = _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS + + choice_weights = None + if isinstance(transforms, Dict): + transforms, choice_weights = _get_weighted_transforms(transforms) + + ra_ops = rand_augment_ops(magnitude=magnitude, prob=prob, hparams=hparams, transforms=transforms) return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) @@ -740,11 +861,19 @@ _AUGMIX_TRANSFORMS = [ ] -def augmix_ops(magnitude=10, hparams=None, transforms=None): +def augmix_ops( + magnitude: Union[int, float] = 10, + hparams: Optional[Dict] = None, + transforms: Optional[Union[str, Dict, List]] = None, +): hparams = hparams or _HPARAMS_DEFAULT transforms = transforms or _AUGMIX_TRANSFORMS return [AugmentOp( - name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms] + name, + prob=1.0, + magnitude=magnitude, + hparams=hparams + ) for name in transforms] class AugMixAugment: @@ -820,22 +949,24 @@ class AugMixAugment: return fs -def augment_and_mix_transform(config_str, hparams): +def augment_and_mix_transform(config_str: str, hparams: Optional[Dict] = None): """ Create AugMix PyTorch transform - :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by - dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining - sections, not order sepecific determine - 'm' - integer magnitude (severity) of augmentation mix (default: 3) - 'w' - integer width of augmentation chain (default: 3) - 'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1) - 'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0) - 'mstd' - float std deviation of magnitude noise applied (default: 0) - Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2 - - :param hparams: Other hparams (kwargs) for the Augmentation transforms - - :return: A PyTorch compatible Transform + Args: + config_str (str): String defining configuration of random augmentation. Consists of multiple sections separated + by dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). + The remaining sections, not order sepecific determine + 'm' - integer magnitude (severity) of augmentation mix (default: 3) + 'w' - integer width of augmentation chain (default: 3) + 'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1) + 'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0) + 'mstd' - float std deviation of magnitude noise applied (default: 0) + Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2 + + hparams: Other hparams (kwargs) for the Augmentation transforms + + Returns: + A PyTorch compatible Transform """ magnitude = 3 width = 3 diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py index 6c28383a..7749b206 100644 --- a/timm/data/transforms_factory.py +++ b/timm/data/transforms_factory.py @@ -59,6 +59,7 @@ def transforms_imagenet_train( re_count=1, re_num_splits=0, separate=False, + force_color_jitter=False, ): """ If separate==True, the transforms are returned as a tuple of 3 separate transforms @@ -77,8 +78,12 @@ def transforms_imagenet_train( primary_tfl += [transforms.RandomVerticalFlip(p=vflip)] secondary_tfl = [] + disable_color_jitter = False if auto_augment: assert isinstance(auto_augment, str) + # color jitter is typically disabled if AA/RA on, + # this allows override without breaking old hparm cfgs + disable_color_jitter = not (force_color_jitter or '3a' in auto_augment) if isinstance(img_size, (tuple, list)): img_size_min = min(img_size) else: @@ -96,8 +101,9 @@ def transforms_imagenet_train( secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)] else: secondary_tfl += [auto_augment_transform(auto_augment, aa_params)] - elif color_jitter is not None: - # color jitter is enabled when not using AA + + if color_jitter is not None and not disable_color_jitter: + # color jitter is enabled when not using AA or when forced if isinstance(color_jitter, (list, tuple)): # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation # or 4 if also augmenting hue