From 25d2088d9e621b10d8a8bc15bea50ea02bd05a4d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 31 Aug 2019 23:09:48 -0700 Subject: [PATCH 1/4] Working on auto-augment --- timm/data/auto_augment.py | 299 ++++++++++++++++++++++++++++++++++++++ timm/data/loader.py | 22 ++- timm/data/transforms.py | 43 ++++++ 3 files changed, 357 insertions(+), 7 deletions(-) create mode 100644 timm/data/auto_augment.py diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py new file mode 100644 index 00000000..4a4c51c7 --- /dev/null +++ b/timm/data/auto_augment.py @@ -0,0 +1,299 @@ +import random +import math +from torchvision import transforms +from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageDraw +import PIL +import numpy as np + + +_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]]) + +_FILL = (128, 128, 128) + +# This signifies the max integer that the controller RNN could predict for the +# augmentation scheme. +_MAX_LEVEL = 10. + +_HPARAMS_DEFAULT = dict( + translate_const=250, + img_mean=_FILL, +) + +_RANDOM_INTERPOLATION = (Image.NEAREST, Image.BILINEAR, Image.BICUBIC) + + +def _interpolation(kwargs): + interpolation = kwargs.pop('resample', Image.NEAREST) + if isinstance(interpolation, (list, tuple)): + return random.choice(interpolation) + else: + return interpolation + + +def _check_args_tf(kwargs): + if 'fillcolor' in kwargs and _PIL_VER < (5, 0): + kwargs.pop('fillcolor') + kwargs['resample'] = _interpolation(kwargs) + + +def shear_x(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs) + + +def shear_y(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs) + + +def translate_x_rel(img, pct, **kwargs): + pixels = pct * img.size[0] + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) + + +def translate_y_rel(img, pct, **kwargs): + pixels = pct * img.size[1] + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) + + +def translate_x_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) + + +def translate_y_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) + + +def rotate(img, degrees, **kwargs): + _check_args_tf(kwargs) + if _PIL_VER >= (5, 2): + return img.rotate(degrees, **kwargs) + elif _PIL_VER >= (5, 0): + w, h = img.size + post_trans = (0, 0) + rotn_center = (w / 2.0, h / 2.0) + angle = -math.radians(degrees) + matrix = [ + round(math.cos(angle), 15), + round(math.sin(angle), 15), + 0.0, + round(-math.sin(angle), 15), + round(math.cos(angle), 15), + 0.0, + ] + + def transform(x, y, matrix): + (a, b, c, d, e, f) = matrix + return a * x + b * y + c, d * x + e * y + f + + matrix[2], matrix[5] = transform( + -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix + ) + matrix[2] += rotn_center[0] + matrix[5] += rotn_center[1] + return img.transform(img.size, Image.AFFINE, matrix, **kwargs) + else: + return img.rotate(degrees, resample=kwargs['resample']) + + +def auto_contrast(img, **__): + return ImageOps.autocontrast(img) + + +def invert(img, **__): + return ImageOps.invert(img) + + +def equalize(img, **__): + return ImageOps.equalize(img) + + +def solarize(img, thresh, **__): + return ImageOps.solarize(img, thresh) + + +def solarize_add(img, add, thresh=128, **__): + lut = [] + for i in range(256): + if i < thresh: + lut.append(min(255, i + add)) + else: + lut.append(i) + if img.mode in ("L", "RGB"): + if img.mode == "RGB" and len(lut) == 256: + lut = lut + lut + lut + return img.point(lut) + else: + return img + + +def posterize(img, bits, **__): + return ImageOps.posterize(img, 4 - bits) + + +def contrast(img, factor, **__): + return ImageEnhance.Contrast(img).enhance(factor) + + +def color(img, factor, **__): + return ImageEnhance.Color(img).enhance(factor) + + +def brightness(img, factor, **__): + return ImageEnhance.Brightness(img).enhance(factor) + + +def sharpness(img, factor, **__): + return ImageEnhance.Sharpness(img).enhance(factor) + + +def _randomly_negate(v): + """With 50% prob, negate the value""" + return -v if random.random() > 0.5 else v + + +def _rotate_level_to_arg(level): + level = (level / _MAX_LEVEL) * 30. + level = _randomly_negate(level) + return (level,) + + +def _enhance_level_to_arg(level): + return ((level / _MAX_LEVEL) * 1.8 + 0.1,) + + +def _shear_level_to_arg(level): + level = (level / _MAX_LEVEL) * 0.3 + level = _randomly_negate(level) + return (level,) + + +def _translate_abs_level_to_arg(level, translate_const): + level = (level / _MAX_LEVEL) * float(translate_const) + level = _randomly_negate(level) + return (level,) + + +def _translate_rel_level_to_arg(level): + level = (level / _MAX_LEVEL) * 0.45 + level = _randomly_negate(level) + return (level,) + + +def level_to_arg(hparams): + return { + 'AutoContrast': lambda level: (), + 'Equalize': lambda level: (), + 'Invert': lambda level: (), + 'Rotate': _rotate_level_to_arg, + 'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4),), + 'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),), + 'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),), + 'Color': _enhance_level_to_arg, + 'Contrast': _enhance_level_to_arg, + 'Brightness': _enhance_level_to_arg, + 'Sharpness': _enhance_level_to_arg, + 'ShearX': _shear_level_to_arg, + 'ShearY': _shear_level_to_arg, + 'TranslateX': lambda level: _translate_abs_level_to_arg(level, hparams['translate_const']), + 'TranslateY': lambda level: _translate_abs_level_to_arg(level, hparams['translate_const']), + 'TranslateXRel': lambda level: _translate_rel_level_to_arg(level), + 'TranslateYRel': lambda level: _translate_rel_level_to_arg(level), + } + + +NAME_TO_OP = { + 'AutoContrast': auto_contrast, + 'Equalize': equalize, + 'Invert': invert, + 'Rotate': rotate, + 'Posterize': posterize, + 'Solarize': solarize, + 'SolarizeAdd': solarize_add, + 'Color': color, + 'Contrast': contrast, + 'Brightness': brightness, + 'Sharpness': sharpness, + 'ShearX': shear_x, + 'ShearY': shear_y, + 'TranslateX': translate_x_abs, + 'TranslateY': translate_y_abs, + 'TranslateXRel': translate_x_rel, + 'TranslateYRel': translate_y_rel, +} + + +class AutoAugmentOp: + + def __init__(self, name, prob, magnitude, hparams={}): + self.aug_fn = NAME_TO_OP[name] + self.level_fn = level_to_arg(hparams)[name] + self.prob = prob + self.magnitude = magnitude + self.kwargs = { + 'fillcolor': hparams['img_mean'] if 'img_mean' in hparams else _FILL, + 'resample': hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION + } + self.rand_magnitude = True + + def __call__(self, img): + if self.prob < random.random(): + return img + magnitude = self.magnitude + if self.rand_magnitude: + magnitude = random.normalvariate(magnitude, 0.5) + magnitude = min(_MAX_LEVEL, max(0, magnitude)) + level_args = self.level_fn(magnitude) + return self.aug_fn(img, *level_args, **self.kwargs) + + +def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT): + """Autoaugment policy that was used in AutoAugment Paper.""" + # Each tuple is an augmentation operation of the form + # (operation, probability, magnitude). Each element in policy is a + # sub-policy that will be applied sequentially on the image. + policy = [ + [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)], + [('Color', 0.4, 9), ('Equalize', 0.6, 3)], + [('Color', 0.4, 1), ('Rotate', 0.6, 8)], + [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)], + [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)], + [('Color', 0.2, 0), ('Equalize', 0.8, 8)], + [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)], + [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)], + [('Color', 0.6, 1), ('Equalize', 1.0, 2)], + [('Invert', 0.4, 9), ('Rotate', 0.6, 0)], + [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)], + [('Color', 0.4, 7), ('Equalize', 0.6, 0)], + [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)], + [('Solarize', 0.6, 8), ('Color', 0.6, 9)], + [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)], + [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)], + [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)], + [('ShearY', 0.8, 0), ('Color', 0.6, 4)], + [('Color', 1.0, 0), ('Rotate', 0.6, 2)], + [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)], + [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)], + [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)], + [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)], + [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)], + [('Color', 0.8, 6), ('Rotate', 0.4, 5)], + ] + pc = [[AutoAugmentOp(*a, hparams) for a in sp] for sp in policy] + return pc + + +class AutoAugment: + + def __init__(self, policy): + self.policy = policy + + def __call__(self, img): + sub_policy = random.choice(self.policy) + for op in sub_policy: + img = op(img) + return img diff --git a/timm/data/loader.py b/timm/data/loader.py index 2a416b31..4bef4e88 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -109,13 +109,21 @@ def create_transform( is_training=is_training, size=img_size, interpolation=interpolation) else: if is_training: - transform = transforms_imagenet_train( - img_size, - color_jitter=color_jitter, - interpolation=interpolation, - use_prefetcher=use_prefetcher, - mean=mean, - std=std) + if True: + transform = transforms_imagenet_aa( + img_size, + interpolation=interpolation, + use_prefetcher=use_prefetcher, + mean=mean, + std=std) + else: + transform = transforms_imagenet_train( + img_size, + color_jitter=color_jitter, + interpolation=interpolation, + use_prefetcher=use_prefetcher, + mean=mean, + std=std) else: transform = transforms_imagenet_eval( img_size, diff --git a/timm/data/transforms.py b/timm/data/transforms.py index 93796a04..49a36f85 100644 --- a/timm/data/transforms.py +++ b/timm/data/transforms.py @@ -9,6 +9,7 @@ import numpy as np from .constants import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .random_erasing import RandomErasing +from .auto_augment import AutoAugment, auto_augment_policy_v0 class ToNumpy: @@ -160,6 +161,48 @@ class RandomResizedCropAndInterpolation(object): return format_string +def transforms_imagenet_aa( + img_size=224, + scale=(0.08, 1.0), + interpolation='random', + random_erasing=0.4, + random_erasing_mode='const', + use_prefetcher=False, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD +): + aa_params = dict( + cutout_max_pad_fraction=0.75, + cutout_const=100, + translate_const=img_size[-1] // 2 - 1, + img_mean=tuple([min(255, round(255*x)) for x in mean]), + ) + if interpolation and interpolation != 'random': + aa_params['interpolation'] = _pil_interp(interpolation) + aa_policy = auto_augment_policy_v0(aa_params) + + tfl = [ + RandomResizedCropAndInterpolation( + img_size, scale=scale, interpolation=interpolation), + transforms.RandomHorizontalFlip(), + AutoAugment(aa_policy) + ] + + if use_prefetcher: + # prefetcher and collate will handle tensor conversion and norm + tfl += [ToNumpy()] + else: + tfl += [ + transforms.ToTensor(), + transforms.Normalize( + mean=torch.tensor(mean), + std=torch.tensor(std)) + ] + if random_erasing > 0.: + tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu')) + return transforms.Compose(tfl) + + def transforms_imagenet_train( img_size=224, scale=(0.08, 1.0), From b750b76f6736bb8bb23beb84594a480852a6c619 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 1 Sep 2019 16:55:42 -0700 Subject: [PATCH 2/4] More AutoAugment work. Ready to roll... --- timm/data/auto_augment.py | 79 ++++++++++++++++++++++++++++++++------ timm/data/loader.py | 26 ++++++------- timm/data/transforms.py | 81 ++++++++++++--------------------------- train.py | 3 ++ 4 files changed, 107 insertions(+), 82 deletions(-) diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index 4a4c51c7..a148f771 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -1,7 +1,13 @@ +""" Auto Augment +Implementation adapted from: + https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py +Papers: https://arxiv.org/abs/1805.09501 and https://arxiv.org/abs/1906.11172 + +Hacked together by Ross Wightman +""" import random import math -from torchvision import transforms -from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageDraw +from PIL import Image, ImageOps, ImageEnhance import PIL import numpy as np @@ -131,8 +137,11 @@ def solarize_add(img, add, thresh=128, **__): return img -def posterize(img, bits, **__): - return ImageOps.posterize(img, 4 - bits) +def posterize(img, bits_to_keep, **__): + if bits_to_keep >= 8: + return img + bits_to_keep = max(1, bits_to_keep) # prevent all 0 images + return ImageOps.posterize(img, bits_to_keep) def contrast(img, factor, **__): @@ -157,16 +166,19 @@ def _randomly_negate(v): def _rotate_level_to_arg(level): + # range [-30, 30] level = (level / _MAX_LEVEL) * 30. level = _randomly_negate(level) return (level,) def _enhance_level_to_arg(level): + # range [0.1, 1.9] return ((level / _MAX_LEVEL) * 1.8 + 0.1,) def _shear_level_to_arg(level): + # range [-0.3, 0.3] level = (level / _MAX_LEVEL) * 0.3 level = _randomly_negate(level) return (level,) @@ -179,6 +191,7 @@ def _translate_abs_level_to_arg(level, translate_const): def _translate_rel_level_to_arg(level): + # range [-0.45, 0.45] level = (level / _MAX_LEVEL) * 0.45 level = _randomly_negate(level) return (level,) @@ -190,9 +203,12 @@ def level_to_arg(hparams): 'Equalize': lambda level: (), 'Invert': lambda level: (), 'Rotate': _rotate_level_to_arg, - 'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4),), - 'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),), - 'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),), + # FIXME these are both different from original impl as I believe there is a bug, + # not sure what is the correct alternative, hence 2 options that look better + 'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4) + 4,), # range [4, 8] + 'Posterize2': lambda level: (4 - int((level / _MAX_LEVEL) * 4),), # range [4, 0] + 'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),), # range [0, 256] + 'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),), # range [0, 110] 'Color': _enhance_level_to_arg, 'Contrast': _enhance_level_to_arg, 'Brightness': _enhance_level_to_arg, @@ -212,6 +228,7 @@ NAME_TO_OP = { 'Invert': invert, 'Rotate': rotate, 'Posterize': posterize, + 'Posterize2': posterize, 'Solarize': solarize, 'SolarizeAdd': solarize_add, 'Color': color, @@ -252,10 +269,8 @@ class AutoAugmentOp: def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT): - """Autoaugment policy that was used in AutoAugment Paper.""" - # Each tuple is an augmentation operation of the form - # (operation, probability, magnitude). Each element in policy is a - # sub-policy that will be applied sequentially on the image. + # ImageNet policy from TPU EfficientNet impl, cannot find + # a paper reference. policy = [ [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)], [('Color', 0.4, 9), ('Equalize', 0.6, 3)], @@ -287,6 +302,48 @@ def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT): return pc +def auto_augment_policy_original(hparams=_HPARAMS_DEFAULT): + # ImageNet policy from https://arxiv.org/abs/1805.09501 + policy = [ + [('Posterize', 0.4, 8), ('Rotate', 0.6, 9)], + [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)], + [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], + [('Posterize', 0.6, 7), ('Posterize', 0.6, 6)], + [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)], + [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)], + [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)], + [('Posterize', 0.8, 5), ('Equalize', 1.0, 2)], + [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)], + [('Equalize', 0.6, 8), ('Posterize', 0.4, 6)], + [('Rotate', 0.8, 8), ('Color', 0.4, 0)], + [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)], + [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)], + [('Invert', 0.6, 4), ('Equalize', 1.0, 8)], + [('Color', 0.6, 4), ('Contrast', 1.0, 8)], + [('Rotate', 0.8, 8), ('Color', 1.0, 2)], + [('Color', 0.8, 8), ('Solarize', 0.8, 7)], + [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)], + [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)], + [('Color', 0.4, 0), ('Equalize', 0.6, 3)], + [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)], + [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)], + [('Invert', 0.6, 4), ('Equalize', 1.0, 8)], + [('Color', 0.6, 4), ('Contrast', 1.0, 8)], + [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], + ] + pc = [[AutoAugmentOp(*a, hparams) for a in sp] for sp in policy] + return pc + + +def auto_augment_policy(name='v0', hparams=_HPARAMS_DEFAULT): + if name == 'original': + return auto_augment_policy_original(hparams) + elif name == 'v0': + return auto_augment_policy_v0(hparams) + else: + assert False, 'Unknown AA policy (%s)' % name + + class AutoAugment: def __init__(self, policy): diff --git a/timm/data/loader.py b/timm/data/loader.py index 4bef4e88..8c27f1bb 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -92,6 +92,7 @@ def create_transform( is_training=False, use_prefetcher=False, color_jitter=0.4, + auto_augment=None, interpolation='bilinear', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, @@ -109,21 +110,14 @@ def create_transform( is_training=is_training, size=img_size, interpolation=interpolation) else: if is_training: - if True: - transform = transforms_imagenet_aa( - img_size, - interpolation=interpolation, - use_prefetcher=use_prefetcher, - mean=mean, - std=std) - else: - transform = transforms_imagenet_train( - img_size, - color_jitter=color_jitter, - interpolation=interpolation, - use_prefetcher=use_prefetcher, - mean=mean, - std=std) + transform = transforms_imagenet_train( + img_size, + color_jitter=color_jitter, + auto_augment=auto_augment, + interpolation=interpolation, + use_prefetcher=use_prefetcher, + mean=mean, + std=std) else: transform = transforms_imagenet_eval( img_size, @@ -146,6 +140,7 @@ def create_loader( rand_erase_mode='const', rand_erase_count=1, color_jitter=0.4, + auto_augment=None, interpolation='bilinear', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, @@ -161,6 +156,7 @@ def create_loader( is_training=is_training, use_prefetcher=use_prefetcher, color_jitter=color_jitter, + auto_augment=auto_augment, interpolation=interpolation, mean=mean, std=std, diff --git a/timm/data/transforms.py b/timm/data/transforms.py index 49a36f85..d4f67bf9 100644 --- a/timm/data/transforms.py +++ b/timm/data/transforms.py @@ -9,7 +9,7 @@ import numpy as np from .constants import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .random_erasing import RandomErasing -from .auto_augment import AutoAugment, auto_augment_policy_v0 +from .auto_augment import AutoAugment, auto_augment_policy class ToNumpy: @@ -57,10 +57,10 @@ def _pil_interp(method): return Image.BILINEAR -RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) -class RandomResizedCropAndInterpolation(object): +class RandomResizedCropAndInterpolation: """Crop the given PIL Image to random size and aspect ratio with random interpolation. A crop of random size (default: of 0.08 to 1.0) of the original size and a random @@ -85,7 +85,7 @@ class RandomResizedCropAndInterpolation(object): warnings.warn("range should be of kind (min, max)") if interpolation == 'random': - self.interpolation = RANDOM_INTERPOLATION + self.interpolation = _RANDOM_INTERPOLATION else: self.interpolation = _pil_interp(interpolation) self.scale = scale @@ -161,52 +161,11 @@ class RandomResizedCropAndInterpolation(object): return format_string -def transforms_imagenet_aa( - img_size=224, - scale=(0.08, 1.0), - interpolation='random', - random_erasing=0.4, - random_erasing_mode='const', - use_prefetcher=False, - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD -): - aa_params = dict( - cutout_max_pad_fraction=0.75, - cutout_const=100, - translate_const=img_size[-1] // 2 - 1, - img_mean=tuple([min(255, round(255*x)) for x in mean]), - ) - if interpolation and interpolation != 'random': - aa_params['interpolation'] = _pil_interp(interpolation) - aa_policy = auto_augment_policy_v0(aa_params) - - tfl = [ - RandomResizedCropAndInterpolation( - img_size, scale=scale, interpolation=interpolation), - transforms.RandomHorizontalFlip(), - AutoAugment(aa_policy) - ] - - if use_prefetcher: - # prefetcher and collate will handle tensor conversion and norm - tfl += [ToNumpy()] - else: - tfl += [ - transforms.ToTensor(), - transforms.Normalize( - mean=torch.tensor(mean), - std=torch.tensor(std)) - ] - if random_erasing > 0.: - tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu')) - return transforms.Compose(tfl) - - def transforms_imagenet_train( img_size=224, scale=(0.08, 1.0), color_jitter=0.4, + auto_augment=None, interpolation='random', random_erasing=0.4, random_erasing_mode='const', @@ -214,20 +173,30 @@ def transforms_imagenet_train( mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD ): - if isinstance(color_jitter, (list, tuple)): - # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation - # or 4 if also augmenting hue - assert len(color_jitter) in (3, 4) - else: - # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue - color_jitter = (float(color_jitter),) * 3 - tfl = [ RandomResizedCropAndInterpolation( img_size, scale=scale, interpolation=interpolation), - transforms.RandomHorizontalFlip(), - transforms.ColorJitter(*color_jitter), + transforms.RandomHorizontalFlip() ] + if auto_augment: + aa_params = dict( + translate_const=img_size[-1] // 2 - 1, + img_mean=tuple([min(255, round(255 * x)) for x in mean]), + ) + if interpolation and interpolation != 'random': + aa_params['interpolation'] = _pil_interp(interpolation) + aa_policy = auto_augment_policy(auto_augment, aa_params) + tfl += [AutoAugment(aa_policy)] + else: + # color jitter is enabled when not using AA + if isinstance(color_jitter, (list, tuple)): + # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation + # or 4 if also augmenting hue + assert len(color_jitter) in (3, 4) + else: + # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue + color_jitter = (float(color_jitter),) * 3 + tfl += [transforms.ColorJitter(*color_jitter)] if use_prefetcher: # prefetcher and collate will handle tensor conversion and norm diff --git a/train.py b/train.py index 78c8a12c..f6b6e407 100644 --- a/train.py +++ b/train.py @@ -89,6 +89,8 @@ parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RA # Augmentation parameters parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', help='Color jitter factor (default: 0.4)') +parser.add_argument('--aa', type=str, default=None, metavar='NAME', + help='Use AutoAugment policy. "v0" or "original". (default: None)'), parser.add_argument('--reprob', type=float, default=0., metavar='PCT', help='Random erase prob (default: 0.)') parser.add_argument('--remode', type=str, default='const', @@ -287,6 +289,7 @@ def main(): rand_erase_mode=args.remode, rand_erase_count=args.recount, color_jitter=args.color_jitter, + auto_augment=args.aa, interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], From c06274e5a20f307ad5a7299cfd1866d467edeac7 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 1 Sep 2019 20:32:26 -0700 Subject: [PATCH 3/4] Add note on random selection of magnitude value --- timm/data/auto_augment.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index a148f771..04c0b60a 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -251,18 +251,22 @@ class AutoAugmentOp: self.level_fn = level_to_arg(hparams)[name] self.prob = prob self.magnitude = magnitude + # If std deviation of magnitude is > 0, we introduce some randomness + # in the usually fixed policy and sample magnitude from normal dist + # with mean magnitude and std-dev of magnitude_std. + # NOTE This is being tested as it's not in paper or reference impl. + self.magnitude_std = 0.5 # FIXME add arg/hparam self.kwargs = { 'fillcolor': hparams['img_mean'] if 'img_mean' in hparams else _FILL, 'resample': hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION } - self.rand_magnitude = True def __call__(self, img): if self.prob < random.random(): return img magnitude = self.magnitude - if self.rand_magnitude: - magnitude = random.normalvariate(magnitude, 0.5) + if self.magnitude_std and self.magnitude_std > 0: + magnitude = random.gauss(magnitude, self.magnitude_std) magnitude = min(_MAX_LEVEL, max(0, magnitude)) level_args = self.level_fn(magnitude) return self.aug_fn(img, *level_args, **self.kwargs) From 4002c0d4ce4db3c9d8acf23138c3675b1c5c395f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 1 Sep 2019 22:07:45 -0700 Subject: [PATCH 4/4] Fix AutoAugment abs translate calc --- timm/data/transforms.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/timm/data/transforms.py b/timm/data/transforms.py index d4f67bf9..33911638 100644 --- a/timm/data/transforms.py +++ b/timm/data/transforms.py @@ -179,8 +179,12 @@ def transforms_imagenet_train( transforms.RandomHorizontalFlip() ] if auto_augment: + if isinstance(img_size, tuple): + img_size_min = min(img_size) + else: + img_size_min = img_size aa_params = dict( - translate_const=img_size[-1] // 2 - 1, + translate_const=int(img_size_min * 0.45), img_mean=tuple([min(255, round(255 * x)) for x in mean]), ) if interpolation and interpolation != 'random':