Some cleanup/improvements to AugMix impl:

* make 'increasing' levels for Contrast, Color, Brightness, Sharpness ops
* remove recursion from faster blending mix
* add config string parsing for AugMix
pull/74/head
Ross Wightman 5 years ago
parent 232ab7fb12
commit 3afc2a4dc0

@ -177,6 +177,14 @@ def _enhance_level_to_arg(level, _hparams):
return (level / _MAX_LEVEL) * 1.8 + 0.1, return (level / _MAX_LEVEL) * 1.8 + 0.1,
def _enhance_increasing_level_to_arg(level, _hparams):
    """Map a magnitude level to an enhancement factor that gets stronger as the level grows.

    A factor of 1.0 means 'no change'; the further the factor moves from 1.0 (towards
    0. or towards 2.0) the stronger the enhancement blend. For a level within
    [0, _MAX_LEVEL] the resulting factor lies in [0.1, 1.9].

    :param level: magnitude level, assumed to already be limited to [0, _MAX_LEVEL] by the caller
    :param _hparams: unused, accepted so the signature matches the other level-to-arg functions
    :return: a 1-tuple holding the enhancement factor
    """
    # distance from the neutral factor, [0, 0.9] for an in-range level
    offset = (level / _MAX_LEVEL) * .9
    # _randomly_negate presumably flips the sign at random (defined elsewhere in the file),
    # so the factor lands on either side of 1.0
    factor = 1.0 + _randomly_negate(offset)
    return (factor,)
def _shear_level_to_arg(level, _hparams): def _shear_level_to_arg(level, _hparams):
# range [-0.3, 0.3] # range [-0.3, 0.3]
level = (level / _MAX_LEVEL) * 0.3 level = (level / _MAX_LEVEL) * 0.3
@ -247,12 +255,16 @@ LEVEL_TO_ARG = {
'PosterizeIncreasing': _posterize_increasing_level_to_arg, 'PosterizeIncreasing': _posterize_increasing_level_to_arg,
'PosterizeOriginal': _posterize_original_level_to_arg, 'PosterizeOriginal': _posterize_original_level_to_arg,
'Solarize': _solarize_level_to_arg, 'Solarize': _solarize_level_to_arg,
'SolarizeIncreasing': _solarize_level_to_arg, 'SolarizeIncreasing': _solarize_increasing_level_to_arg,
'SolarizeAdd': _solarize_add_level_to_arg, 'SolarizeAdd': _solarize_add_level_to_arg,
'Color': _enhance_level_to_arg, 'Color': _enhance_level_to_arg,
'ColorIncreasing': _enhance_increasing_level_to_arg,
'Contrast': _enhance_level_to_arg, 'Contrast': _enhance_level_to_arg,
'ContrastIncreasing': _enhance_increasing_level_to_arg,
'Brightness': _enhance_level_to_arg, 'Brightness': _enhance_level_to_arg,
'BrightnessIncreasing': _enhance_increasing_level_to_arg,
'Sharpness': _enhance_level_to_arg, 'Sharpness': _enhance_level_to_arg,
'SharpnessIncreasing': _enhance_increasing_level_to_arg,
'ShearX': _shear_level_to_arg, 'ShearX': _shear_level_to_arg,
'ShearY': _shear_level_to_arg, 'ShearY': _shear_level_to_arg,
'TranslateX': _translate_abs_level_to_arg, 'TranslateX': _translate_abs_level_to_arg,
@ -274,9 +286,13 @@ NAME_TO_OP = {
'SolarizeIncreasing': solarize, 'SolarizeIncreasing': solarize,
'SolarizeAdd': solarize_add, 'SolarizeAdd': solarize_add,
'Color': color, 'Color': color,
'ColorIncreasing': color,
'Contrast': contrast, 'Contrast': contrast,
'ContrastIncreasing': contrast,
'Brightness': brightness, 'Brightness': brightness,
'BrightnessIncreasing': brightness,
'Sharpness': sharpness, 'Sharpness': sharpness,
'SharpnessIncreasing': sharpness,
'ShearX': shear_x, 'ShearX': shear_x,
'ShearY': shear_y, 'ShearY': shear_y,
'TranslateX': translate_x_abs, 'TranslateX': translate_x_abs,
@ -527,6 +543,27 @@ _RAND_TRANSFORMS = [
] ]
_RAND_INCREASING_TRANSFORMS = [
'AutoContrast',
'Equalize',
'Invert',
'Rotate',
'PosterizeIncreasing',
'SolarizeIncreasing',
'SolarizeAdd',
'ColorIncreasing',
'ContrastIncreasing',
'BrightnessIncreasing',
'SharpnessIncreasing',
'ShearX',
'ShearY',
'TranslateXRel',
'TranslateYRel',
#'Cutout' # FIXME I implement this as random erasing separately
]
# These experimental weights are based loosely on the relative improvements mentioned in paper. # These experimental weights are based loosely on the relative improvements mentioned in paper.
# They may not result in increased performance, but could likely be tuned to so. # They may not result in increased performance, but could likely be tuned to so.
_RAND_CHOICE_WEIGHTS_0 = { _RAND_CHOICE_WEIGHTS_0 = {
@ -626,9 +663,10 @@ def rand_augment_transform(config_str, hparams):
_AUGMIX_TRANSFORMS = [ _AUGMIX_TRANSFORMS = [
'AutoContrast', 'AutoContrast',
'Contrast', # not in paper 'ColorIncreasing', # not in paper
'Brightness', # not in paper 'ContrastIncreasing', # not in paper
'Sharpness', # not in paper 'BrightnessIncreasing', # not in paper
'SharpnessIncreasing', # not in paper
'Equalize', 'Equalize',
'Rotate', 'Rotate',
'PosterizeIncreasing', 'PosterizeIncreasing',
@ -653,21 +691,38 @@ class AugMixAugment:
self.alpha = alpha self.alpha = alpha
self.width = width self.width = width
self.depth = depth self.depth = depth
self.recursive = True self.blended = True
def _apply_recursive(self, img, ws, prod=1.): def _calc_blended_weights(self, ws, m):
alpha = ws[-1] / prod ws = ws * m
if len(ws) > 1: cump = 1.
img = self._apply_recursive(img, ws[:-1], prod * (1 - alpha)) rws = []
for w in ws[::-1]:
depth = self.depth if self.depth > 0 else np.random.randint(1, 4) alpha = w / cump
ops = np.random.choice(self.ops, depth, replace=True) cump *= (1 - alpha)
img_aug = img # no ops are in-place, deep copy not necessary rws.append(alpha)
for op in ops: return np.array(rws[::-1], dtype=np.float32)
img_aug = op(img_aug)
return Image.blend(img, img_aug, alpha) def _apply_blended(self, img, ws, m):
# This is my first crack and implementing a slightly faster mixed augmentation. Instead
# of accumulating the mix for each chain in a Numpy array and then blending with original,
# it recomputes the blending coefficients and applies one PIL image blend per chain.
# TODO I've verified the results are in the right ballpark but they differ by more than rounding.
img_orig = img.copy()
ws = self._calc_blended_weights(ws, m)
for w in ws:
depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
ops = np.random.choice(self.ops, depth, replace=True)
img_aug = img_orig # no ops are in-place, deep copy not necessary
for op in ops:
img_aug = op(img_aug)
img = Image.blend(img, img_aug, w)
return img
def _apply_basic(self, img, ws, m): def _apply_basic(self, img, ws, m):
# This is a literal adaptation of the paper/official implementation without normalizations and
# PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the
# typical augmentation transforms, could use a GPU / Kornia implementation.
w, h = img.size w, h = img.size
c = len(img.getbands()) c = len(img.getbands())
mixed = np.zeros((w, h, c), dtype=np.float32) mixed = np.zeros((w, h, c), dtype=np.float32)
@ -686,30 +741,53 @@ class AugMixAugment:
def __call__(self, img): def __call__(self, img):
mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width)) mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width))
m = np.float32(np.random.beta(self.alpha, self.alpha)) m = np.float32(np.random.beta(self.alpha, self.alpha))
if self.recursive: if self.blended:
mixing_weights *= m mixed = self._apply_blended(img, mixing_weights, m)
mixed = self._apply_recursive(img, mixing_weights)
else: else:
mixed = self._apply_basic(img, mixing_weights, m) mixed = self._apply_basic(img, mixing_weights, m)
return mixed return mixed
def augment_and_mix_transform(config_str, hparams): def augment_and_mix_transform(config_str, hparams):
"""Perform AugMix augmentations and compute mixture. """ Create AugMix PyTorch transform
Args:
image: Raw input image as float32 np.ndarray of shape (h, w, c) :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
severity: Severity of underlying augmentation operators (between 1 to 10). dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
width: Width of augmentation chain sections, not order sepecific determine
depth: Depth of augmentation chain. -1 enables stochastic depth uniformly 'm' - integer magnitude (severity) of augmentation mix (default: 3)
from [1, 3] 'w' - integer width of augmentation chain (default: 3)
alpha: Probability coefficient for Beta and Dirichlet distributions. 'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
Returns: 'mstd' - float std deviation of magnitude noise applied (default: 0)
mixed: Augmented and mixed image. Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
:param hparams: Other hparams (kwargs) for the Augmentation transforms
:return: A PyTorch compatible Transform
""" """
# FIXME parse args from config str magnitude = 3
severity = 3
width = 3 width = 3
depth = -1 depth = -1
alpha = 1. alpha = 1.
ops = augmix_ops(magnitude=severity, hparams=hparams) config = config_str.split('-')
return AugMixAugment(ops, alpha, width, depth) assert config[0] == 'augmix'
config = config[1:]
for c in config:
cs = re.split(r'(\d.*)', c)
if len(cs) < 2:
continue
key, val = cs[:2]
if key == 'mstd':
# noise param injected via hparams for now
hparams.setdefault('magnitude_std', float(val))
elif key == 'm':
magnitude = int(val)
elif key == 'w':
width = int(val)
elif key == 'd':
depth = int(val)
elif key == 'a':
alpha = float(val)
else:
assert False, 'Unknown AugMix config section'
ops = augmix_ops(magnitude=magnitude, hparams=hparams)
return AugMixAugment(ops, alpha=alpha, width=width, depth=depth)

Loading…
Cancel
Save