Some cleanup/improvements to AugMix impl:

* make 'increasing' levels for Contrast, Color, Brightness, Saturation ops
* remove recursion from faster blending mix
* add config striing parsing for AugMix
pull/74/head
Ross Wightman 5 years ago
parent 232ab7fb12
commit 3afc2a4dc0

@ -177,6 +177,14 @@ def _enhance_level_to_arg(level, _hparams):
return (level / _MAX_LEVEL) * 1.8 + 0.1,
def _enhance_increasing_level_to_arg(level, _hparams):
# the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
# range [0.1, 1.9]
level = (level / _MAX_LEVEL) * .9
level = 1.0 + _randomly_negate(level)
return level,
def _shear_level_to_arg(level, _hparams):
# range [-0.3, 0.3]
level = (level / _MAX_LEVEL) * 0.3
@ -247,12 +255,16 @@ LEVEL_TO_ARG = {
'PosterizeIncreasing': _posterize_increasing_level_to_arg,
'PosterizeOriginal': _posterize_original_level_to_arg,
'Solarize': _solarize_level_to_arg,
'SolarizeIncreasing': _solarize_level_to_arg,
'SolarizeIncreasing': _solarize_increasing_level_to_arg,
'SolarizeAdd': _solarize_add_level_to_arg,
'Color': _enhance_level_to_arg,
'ColorIncreasing': _enhance_increasing_level_to_arg,
'Contrast': _enhance_level_to_arg,
'ContrastIncreasing': _enhance_increasing_level_to_arg,
'Brightness': _enhance_level_to_arg,
'BrightnessIncreasing': _enhance_increasing_level_to_arg,
'Sharpness': _enhance_level_to_arg,
'SharpnessIncreasing': _enhance_increasing_level_to_arg,
'ShearX': _shear_level_to_arg,
'ShearY': _shear_level_to_arg,
'TranslateX': _translate_abs_level_to_arg,
@ -274,9 +286,13 @@ NAME_TO_OP = {
'SolarizeIncreasing': solarize,
'SolarizeAdd': solarize_add,
'Color': color,
'ColorIncreasing': color,
'Contrast': contrast,
'ContrastIncreasing': contrast,
'Brightness': brightness,
'BrightnessIncreasing': brightness,
'Sharpness': sharpness,
'SharpnessIncreasing': sharpness,
'ShearX': shear_x,
'ShearY': shear_y,
'TranslateX': translate_x_abs,
@ -527,6 +543,27 @@ _RAND_TRANSFORMS = [
]
_RAND_INCREASING_TRANSFORMS = [
'AutoContrast',
'Equalize',
'Invert',
'Rotate',
'PosterizeIncreasing',
'SolarizeIncreasing',
'SolarizeAdd',
'ColorIncreasing',
'ContrastIncreasing',
'BrightnessIncreasing',
'SharpnessIncreasing',
'ShearX',
'ShearY',
'TranslateXRel',
'TranslateYRel',
#'Cutout' # FIXME I implement this as random erasing separately
]
# These experimental weights are based loosely on the relative improvements mentioned in paper.
# They may not result in increased performance, but could likely be tuned to so.
_RAND_CHOICE_WEIGHTS_0 = {
@ -626,9 +663,10 @@ def rand_augment_transform(config_str, hparams):
_AUGMIX_TRANSFORMS = [
'AutoContrast',
'Contrast', # not in paper
'Brightness', # not in paper
'Sharpness', # not in paper
'ColorIncreasing', # not in paper
'ContrastIncreasing', # not in paper
'BrightnessIncreasing', # not in paper
'SharpnessIncreasing', # not in paper
'Equalize',
'Rotate',
'PosterizeIncreasing',
@ -653,21 +691,38 @@ class AugMixAugment:
self.alpha = alpha
self.width = width
self.depth = depth
self.recursive = True
def _apply_recursive(self, img, ws, prod=1.):
alpha = ws[-1] / prod
if len(ws) > 1:
img = self._apply_recursive(img, ws[:-1], prod * (1 - alpha))
self.blended = True
def _calc_blended_weights(self, ws, m):
ws = ws * m
cump = 1.
rws = []
for w in ws[::-1]:
alpha = w / cump
cump *= (1 - alpha)
rws.append(alpha)
return np.array(rws[::-1], dtype=np.float32)
def _apply_blended(self, img, ws, m):
# This is my first crack and implementing a slightly faster mixed augmentation. Instead
# of accumulating the mix for each chain in a Numpy array and then blending with original,
# it recomputes the blending coefficients and applies one PIL image blend per chain.
# TODO I've verified the results are in the right ballpark but they differ by more than rounding.
img_orig = img.copy()
ws = self._calc_blended_weights(ws, m)
for w in ws:
depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
ops = np.random.choice(self.ops, depth, replace=True)
img_aug = img # no ops are in-place, deep copy not necessary
img_aug = img_orig # no ops are in-place, deep copy not necessary
for op in ops:
img_aug = op(img_aug)
return Image.blend(img, img_aug, alpha)
img = Image.blend(img, img_aug, w)
return img
def _apply_basic(self, img, ws, m):
# This is a literal adaptation of the paper/official implementation without normalizations and
# PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the
# typical augmentation transforms, could use a GPU / Kornia implementation.
w, h = img.size
c = len(img.getbands())
mixed = np.zeros((w, h, c), dtype=np.float32)
@ -686,30 +741,53 @@ class AugMixAugment:
def __call__(self, img):
mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width))
m = np.float32(np.random.beta(self.alpha, self.alpha))
if self.recursive:
mixing_weights *= m
mixed = self._apply_recursive(img, mixing_weights)
if self.blended:
mixed = self._apply_blended(img, mixing_weights, m)
else:
mixed = self._apply_basic(img, mixing_weights, m)
return mixed
def augment_and_mix_transform(config_str, hparams):
"""Perform AugMix augmentations and compute mixture.
Args:
image: Raw input image as float32 np.ndarray of shape (h, w, c)
severity: Severity of underlying augmentation operators (between 1 to 10).
width: Width of augmentation chain
depth: Depth of augmentation chain. -1 enables stochastic depth uniformly
from [1, 3]
alpha: Probability coefficient for Beta and Dirichlet distributions.
Returns:
mixed: Augmented and mixed image.
""" Create AugMix PyTorch transform
:param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
sections, not order sepecific determine
'm' - integer magnitude (severity) of augmentation mix (default: 3)
'w' - integer width of augmentation chain (default: 3)
'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
'mstd' - float std deviation of magnitude noise applied (default: 0)
Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
:param hparams: Other hparams (kwargs) for the Augmentation transforms
:return: A PyTorch compatible Transform
"""
# FIXME parse args from config str
severity = 3
magnitude = 3
width = 3
depth = -1
alpha = 1.
ops = augmix_ops(magnitude=severity, hparams=hparams)
return AugMixAugment(ops, alpha, width, depth)
config = config_str.split('-')
assert config[0] == 'augmix'
config = config[1:]
for c in config:
cs = re.split(r'(\d.*)', c)
if len(cs) < 2:
continue
key, val = cs[:2]
if key == 'mstd':
# noise param injected via hparams for now
hparams.setdefault('magnitude_std', float(val))
elif key == 'm':
magnitude = int(val)
elif key == 'w':
width = int(val)
elif key == 'd':
depth = int(val)
elif key == 'a':
alpha = float(val)
else:
assert False, 'Unknown AugMix config section'
ops = augmix_ops(magnitude=magnitude, hparams=hparams)
return AugMixAugment(ops, alpha=alpha, width=width, depth=depth)

Loading…
Cancel
Save