From 3afc2a4dc0db55919a897f8e2af8aeb315b10703 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 20 Dec 2019 22:32:05 -0800 Subject: [PATCH] Some cleanup/improvements to AugMix impl: * make 'increasing' levels for Contrast, Color, Brightness, Saturation ops * remove recursion from faster blending mix * add config striing parsing for AugMix --- timm/data/auto_augment.py | 146 +++++++++++++++++++++++++++++--------- 1 file changed, 112 insertions(+), 34 deletions(-) diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index cc1b716c..864ca6e0 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -177,6 +177,14 @@ def _enhance_level_to_arg(level, _hparams): return (level / _MAX_LEVEL) * 1.8 + 0.1, +def _enhance_increasing_level_to_arg(level, _hparams): + # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend + # range [0.1, 1.9] + level = (level / _MAX_LEVEL) * .9 + level = 1.0 + _randomly_negate(level) + return level, + + def _shear_level_to_arg(level, _hparams): # range [-0.3, 0.3] level = (level / _MAX_LEVEL) * 0.3 @@ -247,12 +255,16 @@ LEVEL_TO_ARG = { 'PosterizeIncreasing': _posterize_increasing_level_to_arg, 'PosterizeOriginal': _posterize_original_level_to_arg, 'Solarize': _solarize_level_to_arg, - 'SolarizeIncreasing': _solarize_level_to_arg, + 'SolarizeIncreasing': _solarize_increasing_level_to_arg, 'SolarizeAdd': _solarize_add_level_to_arg, 'Color': _enhance_level_to_arg, + 'ColorIncreasing': _enhance_increasing_level_to_arg, 'Contrast': _enhance_level_to_arg, + 'ContrastIncreasing': _enhance_increasing_level_to_arg, 'Brightness': _enhance_level_to_arg, + 'BrightnessIncreasing': _enhance_increasing_level_to_arg, 'Sharpness': _enhance_level_to_arg, + 'SharpnessIncreasing': _enhance_increasing_level_to_arg, 'ShearX': _shear_level_to_arg, 'ShearY': _shear_level_to_arg, 'TranslateX': _translate_abs_level_to_arg, @@ -274,9 +286,13 @@ NAME_TO_OP = { 'SolarizeIncreasing': 
solarize, 'SolarizeAdd': solarize_add, 'Color': color, + 'ColorIncreasing': color, 'Contrast': contrast, + 'ContrastIncreasing': contrast, 'Brightness': brightness, + 'BrightnessIncreasing': brightness, 'Sharpness': sharpness, + 'SharpnessIncreasing': sharpness, 'ShearX': shear_x, 'ShearY': shear_y, 'TranslateX': translate_x_abs, @@ -527,6 +543,27 @@ _RAND_TRANSFORMS = [ ] +_RAND_INCREASING_TRANSFORMS = [ + 'AutoContrast', + 'Equalize', + 'Invert', + 'Rotate', + 'PosterizeIncreasing', + 'SolarizeIncreasing', + 'SolarizeAdd', + 'ColorIncreasing', + 'ContrastIncreasing', + 'BrightnessIncreasing', + 'SharpnessIncreasing', + 'ShearX', + 'ShearY', + 'TranslateXRel', + 'TranslateYRel', + #'Cutout' # FIXME I implement this as random erasing separately +] + + + # These experimental weights are based loosely on the relative improvements mentioned in paper. # They may not result in increased performance, but could likely be tuned to so. _RAND_CHOICE_WEIGHTS_0 = { @@ -626,9 +663,10 @@ def rand_augment_transform(config_str, hparams): _AUGMIX_TRANSFORMS = [ 'AutoContrast', - 'Contrast', # not in paper - 'Brightness', # not in paper - 'Sharpness', # not in paper + 'ColorIncreasing', # not in paper + 'ContrastIncreasing', # not in paper + 'BrightnessIncreasing', # not in paper + 'SharpnessIncreasing', # not in paper 'Equalize', 'Rotate', 'PosterizeIncreasing', @@ -653,21 +691,38 @@ class AugMixAugment: self.alpha = alpha self.width = width self.depth = depth - self.recursive = True - - def _apply_recursive(self, img, ws, prod=1.): - alpha = ws[-1] / prod - if len(ws) > 1: - img = self._apply_recursive(img, ws[:-1], prod * (1 - alpha)) - - depth = self.depth if self.depth > 0 else np.random.randint(1, 4) - ops = np.random.choice(self.ops, depth, replace=True) - img_aug = img # no ops are in-place, deep copy not necessary - for op in ops: - img_aug = op(img_aug) - return Image.blend(img, img_aug, alpha) + self.blended = True + + def _calc_blended_weights(self, ws, m): + ws = ws * m 
+ cump = 1. + rws = [] + for w in ws[::-1]: + alpha = w / cump + cump *= (1 - alpha) + rws.append(alpha) + return np.array(rws[::-1], dtype=np.float32) + + def _apply_blended(self, img, ws, m): + # This is my first crack and implementing a slightly faster mixed augmentation. Instead + # of accumulating the mix for each chain in a Numpy array and then blending with original, + # it recomputes the blending coefficients and applies one PIL image blend per chain. + # TODO I've verified the results are in the right ballpark but they differ by more than rounding. + img_orig = img.copy() + ws = self._calc_blended_weights(ws, m) + for w in ws: + depth = self.depth if self.depth > 0 else np.random.randint(1, 4) + ops = np.random.choice(self.ops, depth, replace=True) + img_aug = img_orig # no ops are in-place, deep copy not necessary + for op in ops: + img_aug = op(img_aug) + img = Image.blend(img, img_aug, w) + return img def _apply_basic(self, img, ws, m): + # This is a literal adaptation of the paper/official implementation without normalizations and + # PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the + # typical augmentation transforms, could use a GPU / Kornia implementation. w, h = img.size c = len(img.getbands()) mixed = np.zeros((w, h, c), dtype=np.float32) @@ -686,30 +741,53 @@ class AugMixAugment: def __call__(self, img): mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width)) m = np.float32(np.random.beta(self.alpha, self.alpha)) - if self.recursive: - mixing_weights *= m - mixed = self._apply_recursive(img, mixing_weights) + if self.blended: + mixed = self._apply_blended(img, mixing_weights, m) else: mixed = self._apply_basic(img, mixing_weights, m) return mixed def augment_and_mix_transform(config_str, hparams): - """Perform AugMix augmentations and compute mixture. 
- Args: - image: Raw input image as float32 np.ndarray of shape (h, w, c) - severity: Severity of underlying augmentation operators (between 1 to 10). - width: Width of augmentation chain - depth: Depth of augmentation chain. -1 enables stochastic depth uniformly - from [1, 3] - alpha: Probability coefficient for Beta and Dirichlet distributions. - Returns: - mixed: Augmented and mixed image. + """ Create AugMix PyTorch transform + + :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by + dashes ('-'). The first section defines the specific variant of AugMix (currently only 'augmix'). The remaining + sections, not order specific, determine + 'm' - integer magnitude (severity) of augmentation mix (default: 3) + 'w' - integer width of augmentation chain (default: 3) + 'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1) + 'mstd' - float std deviation of magnitude noise applied (default: 0) + Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2 + + :param hparams: Other hparams (kwargs) for the Augmentation transforms + + :return: A PyTorch compatible Transform """ - # FIXME parse args from config str - severity = 3 + magnitude = 3 width = 3 depth = -1 alpha = 1.
- ops = augmix_ops(magnitude=severity, hparams=hparams) - return AugMixAugment(ops, alpha, width, depth) + config = config_str.split('-') + assert config[0] == 'augmix' + config = config[1:] + for c in config: + cs = re.split(r'(\d.*)', c) + if len(cs) < 2: + continue + key, val = cs[:2] + if key == 'mstd': + # noise param injected via hparams for now + hparams.setdefault('magnitude_std', float(val)) + elif key == 'm': + magnitude = int(val) + elif key == 'w': + width = int(val) + elif key == 'd': + depth = int(val) + elif key == 'a': + alpha = float(val) + else: + assert False, 'Unknown AugMix config section' + ops = augmix_ops(magnitude=magnitude, hparams=hparams) + return AugMixAugment(ops, alpha=alpha, width=width, depth=depth)