diff --git a/timm/data/__init__.py b/timm/data/__init__.py index 61ad83ea..49c4bc60 100644 --- a/timm/data/__init__.py +++ b/timm/data/__init__.py @@ -4,3 +4,5 @@ from .dataset import Dataset, DatasetTar from .transforms import * from .loader import create_loader, create_transform from .mixup import mixup_target, FastCollateMixup +from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ + rand_augment_transform, auto_augment_transform diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index 04c0b60a..9d711cd1 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -7,11 +7,13 @@ Hacked together by Ross Wightman """ import random import math +import re from PIL import Image, ImageOps, ImageEnhance import PIL import numpy as np + _PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]]) _FILL = (128, 128, 128) @@ -25,11 +27,11 @@ _HPARAMS_DEFAULT = dict( img_mean=_FILL, ) -_RANDOM_INTERPOLATION = (Image.NEAREST, Image.BILINEAR, Image.BICUBIC) +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) def _interpolation(kwargs): - interpolation = kwargs.pop('resample', Image.NEAREST) + interpolation = kwargs.pop('resample', Image.BILINEAR) if isinstance(interpolation, (list, tuple)): return random.choice(interpolation) else: @@ -140,7 +142,6 @@ def solarize_add(img, add, thresh=128, **__): def posterize(img, bits_to_keep, **__): if bits_to_keep >= 8: return img - bits_to_keep = max(1, bits_to_keep) # prevent all 0 images return ImageOps.posterize(img, bits_to_keep) @@ -165,61 +166,89 @@ def _randomly_negate(v): return -v if random.random() > 0.5 else v -def _rotate_level_to_arg(level): +def _rotate_level_to_arg(level, _hparams): # range [-30, 30] level = (level / _MAX_LEVEL) * 30. level = _randomly_negate(level) - return (level,) + return level, -def _enhance_level_to_arg(level): +def _enhance_level_to_arg(level, _hparams): # range [0.1, 1.9] - return ((level / _MAX_LEVEL) * 1.8 + 0.1,) + return (level / _MAX_LEVEL) * 1.8 + 0.1, -def _shear_level_to_arg(level): +def _shear_level_to_arg(level, _hparams): # range [-0.3, 0.3] level = (level / _MAX_LEVEL) * 0.3 level = _randomly_negate(level) - return (level,) + return level, -def _translate_abs_level_to_arg(level, translate_const): +def _translate_abs_level_to_arg(level, hparams): + translate_const = hparams['translate_const'] level = (level / _MAX_LEVEL) * float(translate_const) level = _randomly_negate(level) - return (level,) + return level, -def _translate_rel_level_to_arg(level): +def _translate_rel_level_to_arg(level, _hparams): # range [-0.45, 0.45] level = (level / _MAX_LEVEL) * 0.45 level = _randomly_negate(level) - return (level,) - - -def level_to_arg(hparams): - return { - 'AutoContrast': lambda level: (), - 'Equalize': lambda level: (), - 'Invert': lambda level: (), - 'Rotate': _rotate_level_to_arg, - # FIXME these are both different from original impl as I believe there is a bug, - # not sure what is the correct alternative, hence 2 options that look better - 'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4) + 4,), # range [4, 8] - 'Posterize2': lambda level: (4 - int((level / _MAX_LEVEL) * 4),), # range [4, 0] - 'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),), # range [0, 256] - 'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),), # range [0, 110] - 'Color': _enhance_level_to_arg, - 'Contrast': _enhance_level_to_arg, - 'Brightness': _enhance_level_to_arg, - 'Sharpness': _enhance_level_to_arg, - 'ShearX': _shear_level_to_arg, - 'ShearY': _shear_level_to_arg, - 'TranslateX': lambda level: _translate_abs_level_to_arg(level, hparams['translate_const']), - 'TranslateY': lambda level: _translate_abs_level_to_arg(level, hparams['translate_const']), - 'TranslateXRel': lambda level: _translate_rel_level_to_arg(level), - 'TranslateYRel': lambda level: _translate_rel_level_to_arg(level), - } + return level, + + +def _posterize_original_level_to_arg(level, _hparams): + # As per original AutoAugment paper description + # range [4, 8], 'keep 4 up to 8 MSB of image' + return int((level / _MAX_LEVEL) * 4) + 4, + + +def _posterize_research_level_to_arg(level, _hparams): + # As per Tensorflow models research and UDA impl + # range [4, 0], 'keep 4 down to 0 MSB of original image' + return 4 - int((level / _MAX_LEVEL) * 4), + + +def _posterize_tpu_level_to_arg(level, _hparams): + # As per Tensorflow TPU EfficientNet impl + # range [0, 4], 'keep 0 up to 4 MSB of original image' + return int((level / _MAX_LEVEL) * 4), + + +def _solarize_level_to_arg(level, _hparams): + # range [0, 256] + return int((level / _MAX_LEVEL) * 256), + + +def _solarize_add_level_to_arg(level, _hparams): + # range [0, 110] + return int((level / _MAX_LEVEL) * 110), + + +LEVEL_TO_ARG = { + 'AutoContrast': None, + 'Equalize': None, + 'Invert': None, + 'Rotate': _rotate_level_to_arg, + # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers + 'PosterizeOriginal': _posterize_original_level_to_arg, + 'PosterizeResearch': _posterize_research_level_to_arg, + 'PosterizeTpu': _posterize_tpu_level_to_arg, + 'Solarize': _solarize_level_to_arg, + 'SolarizeAdd': _solarize_add_level_to_arg, + 'Color': _enhance_level_to_arg, + 'Contrast': _enhance_level_to_arg, + 'Brightness': _enhance_level_to_arg, + 'Sharpness': _enhance_level_to_arg, + 'ShearX': _shear_level_to_arg, + 'ShearY': _shear_level_to_arg, + 'TranslateX': _translate_abs_level_to_arg, + 'TranslateY': _translate_abs_level_to_arg, + 'TranslateXRel': _translate_rel_level_to_arg, + 'TranslateYRel': _translate_rel_level_to_arg, +} NAME_TO_OP = { @@ -227,8 +256,9 @@ NAME_TO_OP = { 'Equalize': equalize, 'Invert': invert, 'Rotate': rotate, - 'Posterize': posterize, - 'Posterize2': posterize, + 'PosterizeOriginal': posterize, + 'PosterizeResearch': posterize, + 'PosterizeTpu': posterize, 'Solarize': solarize, 'SolarizeAdd': solarize_add, 'Color': color, @@ -246,35 +276,37 @@ NAME_TO_OP = { class AutoAugmentOp: - def __init__(self, name, prob, magnitude, hparams={}): + def __init__(self, name, prob=0.5, magnitude=10, hparams=None): + hparams = hparams or _HPARAMS_DEFAULT self.aug_fn = NAME_TO_OP[name] - self.level_fn = level_to_arg(hparams)[name] + self.level_fn = LEVEL_TO_ARG[name] self.prob = prob self.magnitude = magnitude - # If std deviation of magnitude is > 0, we introduce some randomness - # in the usually fixed policy and sample magnitude from normal dist - # with mean magnitude and std-dev of magnitude_std. - # NOTE This is being tested as it's not in paper or reference impl. - self.magnitude_std = 0.5 # FIXME add arg/hparam - self.kwargs = { - 'fillcolor': hparams['img_mean'] if 'img_mean' in hparams else _FILL, - 'resample': hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION - } + self.hparams = hparams.copy() + self.kwargs = dict( + fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL, + resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION, + ) + + # If magnitude_noise is > 0, we introduce some randomness + # in the usually fixed policy and sample magnitude from a normal distribution + # with mean `magnitude` and std-dev of `magnitude_noise`. + # NOTE This is my own hack, being tested, not in papers or reference impls. + self.magnitude_noise = self.hparams.get('magnitude_noise', 0) def __call__(self, img): - if self.prob < random.random(): + if random.random() > self.prob: return img magnitude = self.magnitude - if self.magnitude_std and self.magnitude_std > 0: - magnitude = random.gauss(magnitude, self.magnitude_std) - magnitude = min(_MAX_LEVEL, max(0, magnitude)) - level_args = self.level_fn(magnitude) + if self.magnitude_noise and self.magnitude_noise > 0: + magnitude = random.gauss(magnitude, self.magnitude_noise) + magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range + level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple() return self.aug_fn(img, *level_args, **self.kwargs) -def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT): - # ImageNet policy from TPU EfficientNet impl, cannot find - # a paper reference. +def auto_augment_policy_v0(hparams): + # ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference. policy = [ [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)], [('Color', 0.4, 9), ('Equalize', 0.6, 3)], @@ -288,7 +320,7 @@ def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT): [('Invert', 0.4, 9), ('Rotate', 0.6, 0)], [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)], [('Color', 0.4, 7), ('Equalize', 0.6, 0)], - [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)], + [('PosterizeTpu', 0.4, 6), ('AutoContrast', 0.4, 7)], [('Solarize', 0.6, 8), ('Color', 0.6, 9)], [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)], [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)], @@ -298,27 +330,93 @@ def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT): [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)], [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)], [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)], - [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)], + [('PosterizeTpu', 0.8, 2), ('Solarize', 0.6, 10)], # This results in black image with Tpu posterize [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)], [('Color', 0.8, 6), ('Rotate', 0.4, 5)], ] - pc = [[AutoAugmentOp(*a, hparams) for a in sp] for sp in policy] + pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] return pc -def auto_augment_policy_original(hparams=_HPARAMS_DEFAULT): +def auto_augment_policy_v0r(hparams): + # ImageNet v0 policy from TPU EfficientNet impl, with research variation of Posterize + policy = [ + [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)], + [('Color', 0.4, 9), ('Equalize', 0.6, 3)], + [('Color', 0.4, 1), ('Rotate', 0.6, 8)], + [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)], + [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)], + [('Color', 0.2, 0), ('Equalize', 0.8, 8)], + [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)], + [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)], + [('Color', 0.6, 1), ('Equalize', 1.0, 2)], + [('Invert', 0.4, 9), ('Rotate', 0.6, 0)], + [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)], + [('Color', 0.4, 7), ('Equalize', 0.6, 0)], + [('PosterizeResearch', 0.4, 6), ('AutoContrast', 0.4, 7)], + [('Solarize', 0.6, 8), ('Color', 0.6, 9)], + [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)], + [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)], + [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)], + [('ShearY', 0.8, 0), ('Color', 0.6, 4)], + [('Color', 1.0, 0), ('Rotate', 0.6, 2)], + [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)], + [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)], + [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)], + [('PosterizeResearch', 0.8, 2), ('Solarize', 0.6, 10)], + [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)], + [('Color', 0.8, 6), ('Rotate', 0.4, 5)], + ] + pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] + return pc + + +def auto_augment_policy_original(hparams): # ImageNet policy from https://arxiv.org/abs/1805.09501 policy = [ - [('Posterize', 0.4, 8), ('Rotate', 0.6, 9)], + [('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)], + [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)], + [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], + [('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)], + [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)], + [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)], + [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)], + [('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)], + [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)], + [('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)], + [('Rotate', 0.8, 8), ('Color', 0.4, 0)], + [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)], + [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)], + [('Invert', 0.6, 4), ('Equalize', 1.0, 8)], + [('Color', 0.6, 4), ('Contrast', 1.0, 8)], + [('Rotate', 0.8, 8), ('Color', 1.0, 2)], + [('Color', 0.8, 8), ('Solarize', 0.8, 7)], + [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)], + [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)], + [('Color', 0.4, 0), ('Equalize', 0.6, 3)], + [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)], + [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)], + [('Invert', 0.6, 4), ('Equalize', 1.0, 8)], + [('Color', 0.6, 4), ('Contrast', 1.0, 8)], + [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], + ] + pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] + return pc + + +def auto_augment_policy_originalr(hparams): + # ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation + policy = [ + [('PosterizeResearch', 0.4, 8), ('Rotate', 0.6, 9)], [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)], [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], - [('Posterize', 0.6, 7), ('Posterize', 0.6, 6)], + [('PosterizeResearch', 0.6, 7), ('PosterizeResearch', 0.6, 6)], [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)], [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)], [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)], - [('Posterize', 0.8, 5), ('Equalize', 1.0, 2)], + [('PosterizeResearch', 0.8, 5), ('Equalize', 1.0, 2)], [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)], - [('Equalize', 0.6, 8), ('Posterize', 0.4, 6)], + [('Equalize', 0.6, 8), ('PosterizeResearch', 0.4, 6)], [('Rotate', 0.8, 8), ('Color', 0.4, 0)], [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)], [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)], @@ -335,15 +433,20 @@ def auto_augment_policy_original(hparams=_HPARAMS_DEFAULT): [('Color', 0.6, 4), ('Contrast', 1.0, 8)], [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], ] - pc = [[AutoAugmentOp(*a, hparams) for a in sp] for sp in policy] + pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] return pc -def auto_augment_policy(name='v0', hparams=_HPARAMS_DEFAULT): +def auto_augment_policy(name='v0', hparams=None): + hparams = hparams or _HPARAMS_DEFAULT if name == 'original': return auto_augment_policy_original(hparams) + elif name == 'originalr': + return auto_augment_policy_originalr(hparams) elif name == 'v0': return auto_augment_policy_v0(hparams) + elif name == 'v0r': + return auto_augment_policy_v0r(hparams) else: assert False, 'Unknown AA policy (%s)' % name @@ -358,3 +461,78 @@ class AutoAugment: for op in sub_policy: img = op(img) return img + + +def auto_augment_transform(config_str, hparams): + config = config_str.split('-') + policy_name = config[0] + config = config[1:] + for c in config: + cs = re.split(r'(\d.*)', c) + if len(cs) >= 2: + key, val = cs[:2] + if key == 'noise': + # noise param injected via hparams for now + hparams.setdefault('magnitude_noise', float(val)) + aa_policy = auto_augment_policy(policy_name, hparams=hparams) + return AutoAugment(aa_policy) + + +_RAND_TRANSFORMS = [ + 'AutoContrast', + 'Equalize', + 'Invert', + 'Rotate', + 'PosterizeTpu', + 'Solarize', + 'SolarizeAdd', + 'Color', + 'Contrast', + 'Brightness', + 'Sharpness', + 'ShearX', + 'ShearY', + 'TranslateXRel', + 'TranslateYRel', + #'Cutout' # FIXME I implement this as random erasing separately +] + + +def rand_augment_ops(magnitude=10, hparams=None, transforms=None): + hparams = hparams or _HPARAMS_DEFAULT + transforms = transforms or _RAND_TRANSFORMS + return [AutoAugmentOp( + name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms] + + +class RandAugment: + def __init__(self, ops, num_layers=2): + self.ops = ops + self.num_layers = num_layers + + def __call__(self, img): + for _ in range(self.num_layers): + op = random.choice(self.ops) + img = op(img) + return img + + +def rand_augment_transform(config_str, hparams): + magnitude = 10 + num_layers = 2 + config = config_str.split('-') + assert config[0] == 'rand' + config = config[1:] + for c in config: + cs = re.split(r'(\d.*)', c) + if len(cs) >= 2: + key, val = cs[:2] + if key == 'noise': + # noise param injected via hparams for now + hparams.setdefault('magnitude_noise', float(val)) + elif key == 'm': + magnitude = int(val) + elif key == 'n': + num_layers = int(val) + ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams) + return RandAugment(ra_ops, num_layers) diff --git a/timm/data/transforms.py b/timm/data/transforms.py index 33911638..ac03b098 100644 --- a/timm/data/transforms.py +++ b/timm/data/transforms.py @@ -9,7 +9,7 @@ import numpy as np from .constants import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .random_erasing import RandomErasing -from .auto_augment import AutoAugment, auto_augment_policy +from .auto_augment import auto_augment_transform, rand_augment_transform class ToNumpy: @@ -179,6 +179,7 @@ def transforms_imagenet_train( transforms.RandomHorizontalFlip() ] if auto_augment: + assert isinstance(auto_augment, str) if isinstance(img_size, tuple): img_size_min = min(img_size) else: @@ -189,8 +190,10 @@ def transforms_imagenet_train( ) if interpolation and interpolation != 'random': aa_params['interpolation'] = _pil_interp(interpolation) - aa_policy = auto_augment_policy(auto_augment, aa_params) - tfl += [AutoAugment(aa_policy)] + if 'rand' in auto_augment: + tfl += [rand_augment_transform(auto_augment, aa_params)] + else: + tfl += [auto_augment_transform(auto_augment, aa_params)] else: # color jitter is enabled when not using AA if isinstance(color_jitter, (list, tuple)):