Add 3-Augment support to auto_augment.py, clean up weighted choice handling, and allow adjust per op prob via arg string

pull/1222/merge
Ross Wightman 2 years ago
parent e98c93264c
commit e3b2f5be0a

@ -1,4 +1,4 @@
""" AutoAugment, RandAugment, and AugMix for PyTorch
""" AutoAugment, RandAugment, AugMix, and 3-Augment for PyTorch
This code implements the searched ImageNet policies with various tweaks and improvements and
does not include any of the search code.
@ -9,18 +9,24 @@ AA and RA Implementation adapted from:
AugMix adapted from:
https://github.com/google-research/augmix
3-Augment based on: https://github.com/facebookresearch/deit/blob/main/README_revenge.md
Papers:
AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501
Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172
RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781
3-Augment: DeiT III: Revenge of the ViT - https://arxiv.org/abs/2204.07118
Hacked together by / Copyright 2019, Ross Wightman
"""
import random
import math
import re
from PIL import Image, ImageOps, ImageEnhance, ImageChops
from functools import partial
from typing import Dict, List, Optional, Union
from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageFilter
import PIL
import numpy as np
@ -175,6 +181,24 @@ def sharpness(img, factor, **__):
return ImageEnhance.Sharpness(img).enhance(factor)
def gaussian_blur(img, factor, **__):
img = img.filter(ImageFilter.GaussianBlur(radius=factor))
return img
def gaussian_blur_rand(img, factor, **__):
radius_min = 0.1
radius_max = 2.0
img = img.filter(ImageFilter.GaussianBlur(radius=random.uniform(radius_min, radius_max * factor)))
return img
def desaturate(img, factor, **_):
factor = min(1., max(0., 1. - factor))
# enhance factor 0 = grayscale, 1.0 = no-change
return ImageEnhance.Color(img).enhance(factor)
def _randomly_negate(v):
"""With 50% prob, negate the value"""
return -v if random.random() > 0.5 else v
@ -200,6 +224,14 @@ def _enhance_increasing_level_to_arg(level, _hparams):
return level,
def _minmax_level_to_arg(level, _hparams, min_val=0., max_val=1.0, clamp=True):
level = (level / _LEVEL_DENOM)
min_val + (max_val - min_val) * level
if clamp:
level = min(min_val, max(max_val, level))
return level,
def _shear_level_to_arg(level, _hparams):
# range [-0.3, 0.3]
level = (level / _LEVEL_DENOM) * 0.3
@ -246,7 +278,7 @@ def _posterize_original_level_to_arg(level, _hparams):
def _solarize_level_to_arg(level, _hparams):
# range [0, 256]
# intensity/severity of augmentation decreases with level
return int((level / _LEVEL_DENOM) * 256),
return min(256, int((level / _LEVEL_DENOM) * 256)),
def _solarize_increasing_level_to_arg(level, _hparams):
@ -257,7 +289,7 @@ def _solarize_increasing_level_to_arg(level, _hparams):
def _solarize_add_level_to_arg(level, _hparams):
# range [0, 110]
return int((level / _LEVEL_DENOM) * 110),
return min(128, int((level / _LEVEL_DENOM) * 110)),
LEVEL_TO_ARG = {
@ -286,6 +318,9 @@ LEVEL_TO_ARG = {
'TranslateY': _translate_abs_level_to_arg,
'TranslateXRel': _translate_rel_level_to_arg,
'TranslateYRel': _translate_rel_level_to_arg,
'Desaturate': partial(_minmax_level_to_arg, min_val=0.5, max_val=1.0),
'GaussianBlur': partial(_minmax_level_to_arg, min_val=0.1, max_val=2.0),
'GaussianBlurRand': _minmax_level_to_arg,
}
@ -314,6 +349,9 @@ NAME_TO_OP = {
'TranslateY': translate_y_abs,
'TranslateXRel': translate_x_rel,
'TranslateYRel': translate_y_rel,
'Desaturate': desaturate,
'GaussianBlur': gaussian_blur,
'GaussianBlurRand': gaussian_blur_rand,
}
@ -347,6 +385,7 @@ class AugmentOp:
if self.magnitude_std > 0:
# magnitude randomization enabled
if self.magnitude_std == float('inf'):
# inf == uniform sampling
magnitude = random.uniform(0, magnitude)
elif self.magnitude_std > 0:
magnitude = random.gauss(magnitude, self.magnitude_std)
@ -499,6 +538,16 @@ def auto_augment_policy_originalr(hparams):
return pc
def auto_augment_policy_3a(hparams):
policy = [
[('Solarize', 1.0, 5)], # 128 solarize threshold @ 5 magnitude
[('Desaturate', 1.0, 10)], # grayscale at 10 magnitude
[('GaussianBlurRand', 1.0, 10)],
]
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
return pc
def auto_augment_policy(name='v0', hparams=None):
hparams = hparams or _HPARAMS_DEFAULT
if name == 'original':
@ -509,6 +558,8 @@ def auto_augment_policy(name='v0', hparams=None):
return auto_augment_policy_v0(hparams)
elif name == 'v0r':
return auto_augment_policy_v0r(hparams)
elif name == '3a':
return auto_augment_policy_3a(hparams)
else:
assert False, 'Unknown AA policy (%s)' % name
@ -534,19 +585,23 @@ class AutoAugment:
return fs
def auto_augment_transform(config_str, hparams):
def auto_augment_transform(config_str: str, hparams: Optional[Dict] = None):
"""
Create a AutoAugment transform
:param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
The remaining sections, not order sepecific determine
'mstd' - float std deviation of magnitude noise applied
Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
Args:
config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
dashes ('-').
The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
:param hparams: Other hparams (kwargs) for the AutoAugmentation scheme
The remaining sections:
'mstd' - float std deviation of magnitude noise applied
Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
:return: A PyTorch compatible Transform
hparams: Other hparams (kwargs) for the AutoAugmentation scheme
Returns:
A PyTorch compatible Transform
"""
config = config_str.split('-')
policy_name = config[0]
@ -605,42 +660,80 @@ _RAND_INCREASING_TRANSFORMS = [
]
_RAND_3A = [
'SolarizeIncreasing',
'Desaturate',
'GaussianBlur',
]
_RAND_CHOICE_3A = {
'SolarizeIncreasing': 6,
'Desaturate': 6,
'GaussianBlur': 6,
'Rotate': 3,
'ShearX': 2,
'ShearY': 2,
'PosterizeIncreasing': 1,
'AutoContrast': 1,
'ColorIncreasing': 1,
'SharpnessIncreasing': 1,
'ContrastIncreasing': 1,
'BrightnessIncreasing': 1,
'Equalize': 1,
'Invert': 1,
}
# These experimental weights are based loosely on the relative improvements mentioned in paper.
# They may not result in increased performance, but could likely be tuned to so.
_RAND_CHOICE_WEIGHTS_0 = {
'Rotate': 0.3,
'ShearX': 0.2,
'ShearY': 0.2,
'TranslateXRel': 0.1,
'TranslateYRel': 0.1,
'Color': .025,
'Sharpness': 0.025,
'AutoContrast': 0.025,
'Solarize': .005,
'SolarizeAdd': .005,
'Contrast': .005,
'Brightness': .005,
'Equalize': .005,
'Posterize': 0,
'Invert': 0,
'Rotate': 3,
'ShearX': 2,
'ShearY': 2,
'TranslateXRel': 1,
'TranslateYRel': 1,
'ColorIncreasing': .25,
'SharpnessIncreasing': 0.25,
'AutoContrast': 0.25,
'SolarizeIncreasing': .05,
'SolarizeAdd': .05,
'ContrastIncreasing': .05,
'BrightnessIncreasing': .05,
'Equalize': .05,
'PosterizeIncreasing': 0.05,
'Invert': 0.05,
}
def _select_rand_weights(weight_idx=0, transforms=None):
transforms = transforms or _RAND_TRANSFORMS
assert weight_idx == 0 # only one set of weights currently
rand_weights = _RAND_CHOICE_WEIGHTS_0
probs = [rand_weights[k] for k in transforms]
probs /= np.sum(probs)
return probs
def _get_weighted_transforms(transforms: Dict):
transforms, probs = list(zip(*transforms.items()))
probs = np.array(probs)
probs = probs / np.sum(probs)
return transforms, probs
def rand_augment_choices(name: str, increasing=True):
if name == 'weights':
return _RAND_CHOICE_WEIGHTS_0
elif name == '3aw':
return _RAND_CHOICE_3A
elif name == '3a':
return _RAND_3A
else:
return _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS
def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
def rand_augment_ops(
magnitude: Union[int, float] = 10,
prob: float = 0.5,
hparams: Optional[Dict] = None,
transforms: Optional[Union[Dict, List]] = None,
):
hparams = hparams or _HPARAMS_DEFAULT
transforms = transforms or _RAND_TRANSFORMS
return [AugmentOp(
name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
name, prob=prob, magnitude=magnitude, hparams=hparams) for name in transforms]
class RandAugment:
@ -648,11 +741,16 @@ class RandAugment:
self.ops = ops
self.num_layers = num_layers
self.choice_weights = choice_weights
print(self.ops, self.choice_weights)
def __call__(self, img):
# no replacement when using weighted choice
ops = np.random.choice(
self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
self.ops,
self.num_layers,
replace=self.choice_weights is None,
p=self.choice_weights,
)
for op in ops:
img = op(img)
return img
@ -665,61 +763,84 @@ class RandAugment:
return fs
def rand_augment_transform(config_str, hparams):
def rand_augment_transform(
config_str: str,
hparams: Optional[Dict] = None,
transforms: Optional[Union[str, Dict, List]] = None,
):
"""
Create a RandAugment transform
:param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
sections, not order sepecific determine
'm' - integer magnitude of rand augment
'n' - integer num layers (number of transform ops selected per image)
'w' - integer probabiliy weight index (index of a set of weights to influence choice of op)
'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100)
'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10)
'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
:param hparams: Other hparams (kwargs) for the RandAugmentation scheme
:return: A PyTorch compatible Transform
Args:
config_str (str): String defining configuration of random augmentation. Consists of multiple sections separated
by dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand').
The remaining sections, not order sepecific determine
'm' - integer magnitude of rand augment
'n' - integer num layers (number of transform ops selected per image)
'p' - float probability of applying each layer (default 0.5)
'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100)
'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10)
'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
't' - str name of transform set to use
Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
'rand-mstd1-tweights' results in mag std 1.0, weighted transforms, default mag of 10 and num_layers 2
hparams (dict): Other hparams (kwargs) for the RandAugmentation scheme
Returns:
A PyTorch compatible Transform
"""
magnitude = _LEVEL_DENOM # default to _LEVEL_DENOM for magnitude (currently 10)
num_layers = 2 # default to 2 ops per image
weight_idx = None # default to no probability weights for op choice
transforms = _RAND_TRANSFORMS
increasing = False
prob = 0.5
config = config_str.split('-')
assert config[0] == 'rand'
config = config[1:]
for c in config:
cs = re.split(r'(\d.*)', c)
if len(cs) < 2:
continue
key, val = cs[:2]
if key == 'mstd':
# noise param / randomization of magnitude values
mstd = float(val)
if mstd > 100:
# use uniform sampling in 0 to magnitude if mstd is > 100
mstd = float('inf')
hparams.setdefault('magnitude_std', mstd)
elif key == 'mmax':
# clip magnitude between [0, mmax] instead of default [0, _LEVEL_DENOM]
hparams.setdefault('magnitude_max', int(val))
elif key == 'inc':
if bool(val):
transforms = _RAND_INCREASING_TRANSFORMS
elif key == 'm':
magnitude = int(val)
elif key == 'n':
num_layers = int(val)
elif key == 'w':
weight_idx = int(val)
if c.startswith('t'):
# NOTE old 'w' key was removed, 'w0' is not equivalent to 'tweights'
val = str(c[1:])
if transforms is None:
transforms = val
else:
assert False, 'Unknown RandAugment config section'
ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
# numeric options
cs = re.split(r'(\d.*)', c)
if len(cs) < 2:
continue
key, val = cs[:2]
if key == 'mstd':
# noise param / randomization of magnitude values
mstd = float(val)
if mstd > 100:
# use uniform sampling in 0 to magnitude if mstd is > 100
mstd = float('inf')
hparams.setdefault('magnitude_std', mstd)
elif key == 'mmax':
# clip magnitude between [0, mmax] instead of default [0, _LEVEL_DENOM]
hparams.setdefault('magnitude_max', int(val))
elif key == 'inc':
if bool(val):
increasing = True
elif key == 'm':
magnitude = int(val)
elif key == 'n':
num_layers = int(val)
elif key == 'p':
prob = float(val)
else:
assert False, 'Unknown RandAugment config section'
if isinstance(transforms, str):
transforms = rand_augment_choices(transforms, increasing=increasing)
elif transforms is None:
transforms = _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS
choice_weights = None
if isinstance(transforms, Dict):
transforms, choice_weights = _get_weighted_transforms(transforms)
ra_ops = rand_augment_ops(magnitude=magnitude, prob=prob, hparams=hparams, transforms=transforms)
return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
@ -740,11 +861,19 @@ _AUGMIX_TRANSFORMS = [
]
def augmix_ops(magnitude=10, hparams=None, transforms=None):
def augmix_ops(
magnitude: Union[int, float] = 10,
hparams: Optional[Dict] = None,
transforms: Optional[Union[str, Dict, List]] = None,
):
hparams = hparams or _HPARAMS_DEFAULT
transforms = transforms or _AUGMIX_TRANSFORMS
return [AugmentOp(
name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms]
name,
prob=1.0,
magnitude=magnitude,
hparams=hparams
) for name in transforms]
class AugMixAugment:
@ -820,22 +949,24 @@ class AugMixAugment:
return fs
def augment_and_mix_transform(config_str, hparams):
def augment_and_mix_transform(config_str: str, hparams: Optional[Dict] = None):
""" Create AugMix PyTorch transform
:param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
sections, not order sepecific determine
'm' - integer magnitude (severity) of augmentation mix (default: 3)
'w' - integer width of augmentation chain (default: 3)
'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
'mstd' - float std deviation of magnitude noise applied (default: 0)
Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
:param hparams: Other hparams (kwargs) for the Augmentation transforms
:return: A PyTorch compatible Transform
Args:
config_str (str): String defining configuration of random augmentation. Consists of multiple sections separated
by dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand').
The remaining sections, not order sepecific determine
'm' - integer magnitude (severity) of augmentation mix (default: 3)
'w' - integer width of augmentation chain (default: 3)
'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
'mstd' - float std deviation of magnitude noise applied (default: 0)
Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
hparams: Other hparams (kwargs) for the Augmentation transforms
Returns:
A PyTorch compatible Transform
"""
magnitude = 3
width = 3

@ -59,6 +59,7 @@ def transforms_imagenet_train(
re_count=1,
re_num_splits=0,
separate=False,
force_color_jitter=False,
):
"""
If separate==True, the transforms are returned as a tuple of 3 separate transforms
@ -77,8 +78,12 @@ def transforms_imagenet_train(
primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]
secondary_tfl = []
disable_color_jitter = False
if auto_augment:
assert isinstance(auto_augment, str)
# color jitter is typically disabled if AA/RA on,
# this allows override without breaking old hparm cfgs
disable_color_jitter = not (force_color_jitter or '3a' in auto_augment)
if isinstance(img_size, (tuple, list)):
img_size_min = min(img_size)
else:
@ -96,8 +101,9 @@ def transforms_imagenet_train(
secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)]
else:
secondary_tfl += [auto_augment_transform(auto_augment, aa_params)]
elif color_jitter is not None:
# color jitter is enabled when not using AA
if color_jitter is not None and not disable_color_jitter:
# color jitter is enabled when not using AA or when forced
if isinstance(color_jitter, (list, tuple)):
# color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
# or 4 if also augmenting hue

Loading…
Cancel
Save