Add 3-Augment support to auto_augment.py, clean up weighted choice handling, and allow adjusting the per-op probability via the arg string

pull/1222/merge
Ross Wightman 2 years ago
parent e98c93264c
commit e3b2f5be0a

@ -1,4 +1,4 @@
""" AutoAugment, RandAugment, and AugMix for PyTorch """ AutoAugment, RandAugment, AugMix, and 3-Augment for PyTorch
This code implements the searched ImageNet policies with various tweaks and improvements and This code implements the searched ImageNet policies with various tweaks and improvements and
does not include any of the search code. does not include any of the search code.
@ -9,18 +9,24 @@ AA and RA Implementation adapted from:
AugMix adapted from: AugMix adapted from:
https://github.com/google-research/augmix https://github.com/google-research/augmix
3-Augment based on: https://github.com/facebookresearch/deit/blob/main/README_revenge.md
Papers: Papers:
AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501 AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501
Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172 Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172
RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781 AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781
3-Augment: DeiT III: Revenge of the ViT - https://arxiv.org/abs/2204.07118
Hacked together by / Copyright 2019, Ross Wightman Hacked together by / Copyright 2019, Ross Wightman
""" """
import random import random
import math import math
import re import re
from PIL import Image, ImageOps, ImageEnhance, ImageChops from functools import partial
from typing import Dict, List, Optional, Union
from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageFilter
import PIL import PIL
import numpy as np import numpy as np
@ -175,6 +181,24 @@ def sharpness(img, factor, **__):
return ImageEnhance.Sharpness(img).enhance(factor) return ImageEnhance.Sharpness(img).enhance(factor)
def gaussian_blur(img, factor, **__):
    """Blur ``img`` with a Gaussian filter whose radius is ``factor``.

    Any extra keyword arguments are accepted and ignored.
    """
    blur_filter = ImageFilter.GaussianBlur(radius=factor)
    return img.filter(blur_filter)
def gaussian_blur_rand(img, factor, **__):
    """Blur ``img`` with a Gaussian filter of randomly drawn radius.

    The radius is sampled with ``random.uniform`` between 0.1 and ``2.0 * factor``.
    Any extra keyword arguments are accepted and ignored.
    """
    lower = 0.1
    upper = 2.0 * factor
    radius = random.uniform(lower, upper)
    return img.filter(ImageFilter.GaussianBlur(radius=radius))
def desaturate(img, factor, **_):
    """Reduce the colour saturation of ``img``.

    ``factor`` is inverted (``1 - factor``) and clamped to [0, 1] before being
    passed to ``ImageEnhance.Color``, so a larger ``factor`` removes more colour.
    Any extra keyword arguments are accepted and ignored.
    """
    inverted = 1. - factor
    enhance_factor = min(1., max(0., inverted))
    # enhance factor 0 = grayscale, 1.0 = no-change
    enhancer = ImageEnhance.Color(img)
    return enhancer.enhance(enhance_factor)
def _randomly_negate(v): def _randomly_negate(v):
"""With 50% prob, negate the value""" """With 50% prob, negate the value"""
return -v if random.random() > 0.5 else v return -v if random.random() > 0.5 else v
@ -200,6 +224,14 @@ def _enhance_increasing_level_to_arg(level, _hparams):
return level, return level,
def _minmax_level_to_arg(level, _hparams, min_val=0., max_val=1.0, clamp=True):
    """Linearly map a magnitude level onto the range [min_val, max_val].

    Args:
        level: Magnitude level; a level of ``_LEVEL_DENOM`` maps to ``max_val``.
        _hparams: Unused; kept so the signature matches the other level-to-arg functions.
        min_val: Value produced at level 0.
        max_val: Value produced at level ``_LEVEL_DENOM``.
        clamp: If True, restrict the result to [min_val, max_val].

    Returns:
        A 1-tuple holding the mapped value.
    """
    level = (level / _LEVEL_DENOM)
    # The interpolated value has to be assigned back; as a bare expression it is discarded.
    level = min_val + (max_val - min_val) * level
    if clamp:
        # Upper-bound by max_val, then lower-bound by min_val.
        level = max(min_val, min(max_val, level))
    return level,
def _shear_level_to_arg(level, _hparams): def _shear_level_to_arg(level, _hparams):
# range [-0.3, 0.3] # range [-0.3, 0.3]
level = (level / _LEVEL_DENOM) * 0.3 level = (level / _LEVEL_DENOM) * 0.3
@ -246,7 +278,7 @@ def _posterize_original_level_to_arg(level, _hparams):
def _solarize_level_to_arg(level, _hparams): def _solarize_level_to_arg(level, _hparams):
# range [0, 256] # range [0, 256]
# intensity/severity of augmentation decreases with level # intensity/severity of augmentation decreases with level
return int((level / _LEVEL_DENOM) * 256), return min(256, int((level / _LEVEL_DENOM) * 256)),
def _solarize_increasing_level_to_arg(level, _hparams): def _solarize_increasing_level_to_arg(level, _hparams):
@ -257,7 +289,7 @@ def _solarize_increasing_level_to_arg(level, _hparams):
def _solarize_add_level_to_arg(level, _hparams): def _solarize_add_level_to_arg(level, _hparams):
# range [0, 110] # range [0, 110]
return int((level / _LEVEL_DENOM) * 110), return min(128, int((level / _LEVEL_DENOM) * 110)),
LEVEL_TO_ARG = { LEVEL_TO_ARG = {
@ -286,6 +318,9 @@ LEVEL_TO_ARG = {
'TranslateY': _translate_abs_level_to_arg, 'TranslateY': _translate_abs_level_to_arg,
'TranslateXRel': _translate_rel_level_to_arg, 'TranslateXRel': _translate_rel_level_to_arg,
'TranslateYRel': _translate_rel_level_to_arg, 'TranslateYRel': _translate_rel_level_to_arg,
'Desaturate': partial(_minmax_level_to_arg, min_val=0.5, max_val=1.0),
'GaussianBlur': partial(_minmax_level_to_arg, min_val=0.1, max_val=2.0),
'GaussianBlurRand': _minmax_level_to_arg,
} }
@ -314,6 +349,9 @@ NAME_TO_OP = {
'TranslateY': translate_y_abs, 'TranslateY': translate_y_abs,
'TranslateXRel': translate_x_rel, 'TranslateXRel': translate_x_rel,
'TranslateYRel': translate_y_rel, 'TranslateYRel': translate_y_rel,
'Desaturate': desaturate,
'GaussianBlur': gaussian_blur,
'GaussianBlurRand': gaussian_blur_rand,
} }
@ -347,6 +385,7 @@ class AugmentOp:
if self.magnitude_std > 0: if self.magnitude_std > 0:
# magnitude randomization enabled # magnitude randomization enabled
if self.magnitude_std == float('inf'): if self.magnitude_std == float('inf'):
# inf == uniform sampling
magnitude = random.uniform(0, magnitude) magnitude = random.uniform(0, magnitude)
elif self.magnitude_std > 0: elif self.magnitude_std > 0:
magnitude = random.gauss(magnitude, self.magnitude_std) magnitude = random.gauss(magnitude, self.magnitude_std)
@ -499,6 +538,16 @@ def auto_augment_policy_originalr(hparams):
return pc return pc
def auto_augment_policy_3a(hparams):
    """Build the '3a' policy: three sub-policies, each holding a single ``AugmentOp``.

    Each spec tuple is unpacked positionally into ``AugmentOp``; the values look
    like (op name, probability, magnitude) -- confirm against ``AugmentOp.__init__``.
    """
    op_specs = (
        ('Solarize', 1.0, 5),  # 128 solarize threshold @ 5 magnitude
        ('Desaturate', 1.0, 10),  # grayscale at 10 magnitude
        ('GaussianBlurRand', 1.0, 10),
    )
    # One single-op sub-policy per spec, in declaration order.
    return [[AugmentOp(*spec, hparams=hparams)] for spec in op_specs]
def auto_augment_policy(name='v0', hparams=None): def auto_augment_policy(name='v0', hparams=None):
hparams = hparams or _HPARAMS_DEFAULT hparams = hparams or _HPARAMS_DEFAULT
if name == 'original': if name == 'original':
@ -509,6 +558,8 @@ def auto_augment_policy(name='v0', hparams=None):
return auto_augment_policy_v0(hparams) return auto_augment_policy_v0(hparams)
elif name == 'v0r': elif name == 'v0r':
return auto_augment_policy_v0r(hparams) return auto_augment_policy_v0r(hparams)
elif name == '3a':
return auto_augment_policy_3a(hparams)
else: else:
assert False, 'Unknown AA policy (%s)' % name assert False, 'Unknown AA policy (%s)' % name
@ -534,19 +585,23 @@ class AutoAugment:
return fs return fs
def auto_augment_transform(config_str, hparams): def auto_augment_transform(config_str: str, hparams: Optional[Dict] = None):
""" """
Create a AutoAugment transform Create a AutoAugment transform
:param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by Args:
dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr'). config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
The remaining sections, not order sepecific determine dashes ('-').
The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
The remaining sections:
'mstd' - float std deviation of magnitude noise applied 'mstd' - float std deviation of magnitude noise applied
Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5 Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
:param hparams: Other hparams (kwargs) for the AutoAugmentation scheme hparams: Other hparams (kwargs) for the AutoAugmentation scheme
:return: A PyTorch compatible Transform Returns:
A PyTorch compatible Transform
""" """
config = config_str.split('-') config = config_str.split('-')
policy_name = config[0] policy_name = config[0]
@ -605,42 +660,80 @@ _RAND_INCREASING_TRANSFORMS = [
] ]
# Transform names returned by rand_augment_choices() for the name '3a'.
_RAND_3A = [
'SolarizeIncreasing',
'Desaturate',
'GaussianBlur',
]
# Transform name -> relative selection weight, returned by rand_augment_choices() for the name '3aw'.
# rand_augment_transform() passes dict-valued transform sets to _get_weighted_transforms(),
# which divides the weights by their sum, so the values here need not add up to 1.
_RAND_CHOICE_3A = {
'SolarizeIncreasing': 6,
'Desaturate': 6,
'GaussianBlur': 6,
'Rotate': 3,
'ShearX': 2,
'ShearY': 2,
'PosterizeIncreasing': 1,
'AutoContrast': 1,
'ColorIncreasing': 1,
'SharpnessIncreasing': 1,
'ContrastIncreasing': 1,
'BrightnessIncreasing': 1,
'Equalize': 1,
'Invert': 1,
}
# These experimental weights are based loosely on the relative improvements mentioned in paper. # These experimental weights are based loosely on the relative improvements mentioned in paper.
# They may not result in increased performance, but could likely be tuned to so. # They may not result in increased performance, but could likely be tuned to so.
_RAND_CHOICE_WEIGHTS_0 = { _RAND_CHOICE_WEIGHTS_0 = {
'Rotate': 0.3, 'Rotate': 3,
'ShearX': 0.2, 'ShearX': 2,
'ShearY': 0.2, 'ShearY': 2,
'TranslateXRel': 0.1, 'TranslateXRel': 1,
'TranslateYRel': 0.1, 'TranslateYRel': 1,
'Color': .025, 'ColorIncreasing': .25,
'Sharpness': 0.025, 'SharpnessIncreasing': 0.25,
'AutoContrast': 0.025, 'AutoContrast': 0.25,
'Solarize': .005, 'SolarizeIncreasing': .05,
'SolarizeAdd': .005, 'SolarizeAdd': .05,
'Contrast': .005, 'ContrastIncreasing': .05,
'Brightness': .005, 'BrightnessIncreasing': .05,
'Equalize': .005, 'Equalize': .05,
'Posterize': 0, 'PosterizeIncreasing': 0.05,
'Invert': 0, 'Invert': 0.05,
} }
def _select_rand_weights(weight_idx=0, transforms=None): def _get_weighted_transforms(transforms: Dict):
transforms = transforms or _RAND_TRANSFORMS transforms, probs = list(zip(*transforms.items()))
assert weight_idx == 0 # only one set of weights currently probs = np.array(probs)
rand_weights = _RAND_CHOICE_WEIGHTS_0 probs = probs / np.sum(probs)
probs = [rand_weights[k] for k in transforms] return transforms, probs
probs /= np.sum(probs)
return probs
def rand_augment_choices(name: str, increasing=True):
    """Return the RandAugment transform set selected by ``name``.

    Args:
        name: 'weights', '3aw' or '3a' select a dedicated set; any other value
            falls back to the default transform sets.
        increasing: For the fallback only, choose ``_RAND_INCREASING_TRANSFORMS``
            over ``_RAND_TRANSFORMS``.

    Returns:
        The matching module-level transform collection: a name -> weight dict for
        'weights' and '3aw', a list of names for '3a'.
    """
    if name == 'weights':
        return _RAND_CHOICE_WEIGHTS_0
    if name == '3aw':
        return _RAND_CHOICE_3A
    if name == '3a':
        return _RAND_3A
    # Unrecognised names fall through to the default sets.
    if increasing:
        return _RAND_INCREASING_TRANSFORMS
    return _RAND_TRANSFORMS
def rand_augment_ops(magnitude=10, hparams=None, transforms=None): def rand_augment_ops(
magnitude: Union[int, float] = 10,
prob: float = 0.5,
hparams: Optional[Dict] = None,
transforms: Optional[Union[Dict, List]] = None,
):
hparams = hparams or _HPARAMS_DEFAULT hparams = hparams or _HPARAMS_DEFAULT
transforms = transforms or _RAND_TRANSFORMS transforms = transforms or _RAND_TRANSFORMS
return [AugmentOp( return [AugmentOp(
name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms] name, prob=prob, magnitude=magnitude, hparams=hparams) for name in transforms]
class RandAugment: class RandAugment:
@ -648,11 +741,16 @@ class RandAugment:
self.ops = ops self.ops = ops
self.num_layers = num_layers self.num_layers = num_layers
self.choice_weights = choice_weights self.choice_weights = choice_weights
print(self.ops, self.choice_weights)
def __call__(self, img): def __call__(self, img):
# no replacement when using weighted choice # no replacement when using weighted choice
ops = np.random.choice( ops = np.random.choice(
self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights) self.ops,
self.num_layers,
replace=self.choice_weights is None,
p=self.choice_weights,
)
for op in ops: for op in ops:
img = op(img) img = op(img)
return img return img
@ -665,34 +763,48 @@ class RandAugment:
return fs return fs
def rand_augment_transform(config_str, hparams): def rand_augment_transform(
config_str: str,
hparams: Optional[Dict] = None,
transforms: Optional[Union[str, Dict, List]] = None,
):
""" """
Create a RandAugment transform Create a RandAugment transform
:param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by Args:
dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining config_str (str): String defining configuration of random augmentation. Consists of multiple sections separated
sections, not order sepecific determine by dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand').
The remaining sections, not order sepecific determine
'm' - integer magnitude of rand augment 'm' - integer magnitude of rand augment
'n' - integer num layers (number of transform ops selected per image) 'n' - integer num layers (number of transform ops selected per image)
'w' - integer probabiliy weight index (index of a set of weights to influence choice of op) 'p' - float probability of applying each layer (default 0.5)
'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100) 'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100)
'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10) 'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10)
'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
't' - str name of transform set to use
Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 'rand-mstd1-tweights' results in mag std 1.0, weighted transforms, default mag of 10 and num_layers 2
:param hparams: Other hparams (kwargs) for the RandAugmentation scheme hparams (dict): Other hparams (kwargs) for the RandAugmentation scheme
:return: A PyTorch compatible Transform Returns:
A PyTorch compatible Transform
""" """
magnitude = _LEVEL_DENOM # default to _LEVEL_DENOM for magnitude (currently 10) magnitude = _LEVEL_DENOM # default to _LEVEL_DENOM for magnitude (currently 10)
num_layers = 2 # default to 2 ops per image num_layers = 2 # default to 2 ops per image
weight_idx = None # default to no probability weights for op choice increasing = False
transforms = _RAND_TRANSFORMS prob = 0.5
config = config_str.split('-') config = config_str.split('-')
assert config[0] == 'rand' assert config[0] == 'rand'
config = config[1:] config = config[1:]
for c in config: for c in config:
if c.startswith('t'):
# NOTE old 'w' key was removed, 'w0' is not equivalent to 'tweights'
val = str(c[1:])
if transforms is None:
transforms = val
else:
# numeric options
cs = re.split(r'(\d.*)', c) cs = re.split(r'(\d.*)', c)
if len(cs) < 2: if len(cs) < 2:
continue continue
@ -709,17 +821,26 @@ def rand_augment_transform(config_str, hparams):
hparams.setdefault('magnitude_max', int(val)) hparams.setdefault('magnitude_max', int(val))
elif key == 'inc': elif key == 'inc':
if bool(val): if bool(val):
transforms = _RAND_INCREASING_TRANSFORMS increasing = True
elif key == 'm': elif key == 'm':
magnitude = int(val) magnitude = int(val)
elif key == 'n': elif key == 'n':
num_layers = int(val) num_layers = int(val)
elif key == 'w': elif key == 'p':
weight_idx = int(val) prob = float(val)
else: else:
assert False, 'Unknown RandAugment config section' assert False, 'Unknown RandAugment config section'
ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx) if isinstance(transforms, str):
transforms = rand_augment_choices(transforms, increasing=increasing)
elif transforms is None:
transforms = _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS
choice_weights = None
if isinstance(transforms, Dict):
transforms, choice_weights = _get_weighted_transforms(transforms)
ra_ops = rand_augment_ops(magnitude=magnitude, prob=prob, hparams=hparams, transforms=transforms)
return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
@ -740,11 +861,19 @@ _AUGMIX_TRANSFORMS = [
] ]
def augmix_ops(magnitude=10, hparams=None, transforms=None): def augmix_ops(
magnitude: Union[int, float] = 10,
hparams: Optional[Dict] = None,
transforms: Optional[Union[str, Dict, List]] = None,
):
hparams = hparams or _HPARAMS_DEFAULT hparams = hparams or _HPARAMS_DEFAULT
transforms = transforms or _AUGMIX_TRANSFORMS transforms = transforms or _AUGMIX_TRANSFORMS
return [AugmentOp( return [AugmentOp(
name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms] name,
prob=1.0,
magnitude=magnitude,
hparams=hparams
) for name in transforms]
class AugMixAugment: class AugMixAugment:
@ -820,12 +949,13 @@ class AugMixAugment:
return fs return fs
def augment_and_mix_transform(config_str, hparams): def augment_and_mix_transform(config_str: str, hparams: Optional[Dict] = None):
""" Create AugMix PyTorch transform """ Create AugMix PyTorch transform
:param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by Args:
dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining config_str (str): String defining configuration of random augmentation. Consists of multiple sections separated
sections, not order sepecific determine by dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand').
The remaining sections, not order sepecific determine
'm' - integer magnitude (severity) of augmentation mix (default: 3) 'm' - integer magnitude (severity) of augmentation mix (default: 3)
'w' - integer width of augmentation chain (default: 3) 'w' - integer width of augmentation chain (default: 3)
'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1) 'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
@ -833,9 +963,10 @@ def augment_and_mix_transform(config_str, hparams):
'mstd' - float std deviation of magnitude noise applied (default: 0) 'mstd' - float std deviation of magnitude noise applied (default: 0)
Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2 Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
:param hparams: Other hparams (kwargs) for the Augmentation transforms hparams: Other hparams (kwargs) for the Augmentation transforms
:return: A PyTorch compatible Transform Returns:
A PyTorch compatible Transform
""" """
magnitude = 3 magnitude = 3
width = 3 width = 3

@ -59,6 +59,7 @@ def transforms_imagenet_train(
re_count=1, re_count=1,
re_num_splits=0, re_num_splits=0,
separate=False, separate=False,
force_color_jitter=False,
): ):
""" """
If separate==True, the transforms are returned as a tuple of 3 separate transforms If separate==True, the transforms are returned as a tuple of 3 separate transforms
@ -77,8 +78,12 @@ def transforms_imagenet_train(
primary_tfl += [transforms.RandomVerticalFlip(p=vflip)] primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]
secondary_tfl = [] secondary_tfl = []
disable_color_jitter = False
if auto_augment: if auto_augment:
assert isinstance(auto_augment, str) assert isinstance(auto_augment, str)
# color jitter is typically disabled if AA/RA on,
# this allows override without breaking old hparm cfgs
disable_color_jitter = not (force_color_jitter or '3a' in auto_augment)
if isinstance(img_size, (tuple, list)): if isinstance(img_size, (tuple, list)):
img_size_min = min(img_size) img_size_min = min(img_size)
else: else:
@ -96,8 +101,9 @@ def transforms_imagenet_train(
secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)] secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)]
else: else:
secondary_tfl += [auto_augment_transform(auto_augment, aa_params)] secondary_tfl += [auto_augment_transform(auto_augment, aa_params)]
elif color_jitter is not None:
# color jitter is enabled when not using AA if color_jitter is not None and not disable_color_jitter:
# color jitter is enabled when not using AA or when forced
if isinstance(color_jitter, (list, tuple)): if isinstance(color_jitter, (list, tuple)):
# color jitter should be a 3-tuple/list if spec brightness/contrast/saturation # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
# or 4 if also augmenting hue # or 4 if also augmenting hue

Loading…
Cancel
Save