diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py new file mode 100644 index 00000000..4a4c51c7 --- /dev/null +++ b/timm/data/auto_augment.py @@ -0,0 +1,299 @@ +import random +import math +from torchvision import transforms +from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageDraw +import PIL +import numpy as np + + +_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]]) + +_FILL = (128, 128, 128) + +# This signifies the max integer that the controller RNN could predict for the +# augmentation scheme. +_MAX_LEVEL = 10. + +_HPARAMS_DEFAULT = dict( + translate_const=250, + img_mean=_FILL, +) + +_RANDOM_INTERPOLATION = (Image.NEAREST, Image.BILINEAR, Image.BICUBIC) + + +def _interpolation(kwargs): + interpolation = kwargs.pop('resample', Image.NEAREST) + if isinstance(interpolation, (list, tuple)): + return random.choice(interpolation) + else: + return interpolation + + +def _check_args_tf(kwargs): + if 'fillcolor' in kwargs and _PIL_VER < (5, 0): + kwargs.pop('fillcolor') + kwargs['resample'] = _interpolation(kwargs) + + +def shear_x(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs) + + +def shear_y(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs) + + +def translate_x_rel(img, pct, **kwargs): + pixels = pct * img.size[0] + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) + + +def translate_y_rel(img, pct, **kwargs): + pixels = pct * img.size[1] + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) + + +def translate_x_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) + + +def translate_y_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) + + +def rotate(img, degrees, **kwargs): + _check_args_tf(kwargs) + if _PIL_VER >= (5, 2): + return img.rotate(degrees, **kwargs) + elif _PIL_VER >= (5, 0): + w, h = img.size + post_trans = (0, 0) + rotn_center = (w / 2.0, h / 2.0) + angle = -math.radians(degrees) + matrix = [ + round(math.cos(angle), 15), + round(math.sin(angle), 15), + 0.0, + round(-math.sin(angle), 15), + round(math.cos(angle), 15), + 0.0, + ] + + def transform(x, y, matrix): + (a, b, c, d, e, f) = matrix + return a * x + b * y + c, d * x + e * y + f + + matrix[2], matrix[5] = transform( + -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix + ) + matrix[2] += rotn_center[0] + matrix[5] += rotn_center[1] + return img.transform(img.size, Image.AFFINE, matrix, **kwargs) + else: + return img.rotate(degrees, resample=kwargs['resample']) + + +def auto_contrast(img, **__): + return ImageOps.autocontrast(img) + + +def invert(img, **__): + return ImageOps.invert(img) + + +def equalize(img, **__): + return ImageOps.equalize(img) + + +def solarize(img, thresh, **__): + return ImageOps.solarize(img, thresh) + + +def solarize_add(img, add, thresh=128, **__): + lut = [] + for i in range(256): + if i < thresh: + lut.append(min(255, i + add)) + else: + lut.append(i) + if img.mode in ("L", "RGB"): + if img.mode == "RGB" and len(lut) == 256: + lut = lut + lut + lut + return img.point(lut) + else: + return img + + +def posterize(img, bits, **__): + return ImageOps.posterize(img, 4 - bits) + + +def contrast(img, factor, **__): + return ImageEnhance.Contrast(img).enhance(factor) + + +def color(img, factor, **__): + return ImageEnhance.Color(img).enhance(factor) + + +def brightness(img, factor, **__): + return ImageEnhance.Brightness(img).enhance(factor) + + +def sharpness(img, factor, **__): + return ImageEnhance.Sharpness(img).enhance(factor) + + +def _randomly_negate(v): + """With 50% prob, negate the value""" + return -v if random.random() > 0.5 else v + + +def _rotate_level_to_arg(level): + level = (level / _MAX_LEVEL) * 30. + level = _randomly_negate(level) + return (level,) + + +def _enhance_level_to_arg(level): + return ((level / _MAX_LEVEL) * 1.8 + 0.1,) + + +def _shear_level_to_arg(level): + level = (level / _MAX_LEVEL) * 0.3 + level = _randomly_negate(level) + return (level,) + + +def _translate_abs_level_to_arg(level, translate_const): + level = (level / _MAX_LEVEL) * float(translate_const) + level = _randomly_negate(level) + return (level,) + + +def _translate_rel_level_to_arg(level): + level = (level / _MAX_LEVEL) * 0.45 + level = _randomly_negate(level) + return (level,) + + +def level_to_arg(hparams): + return { + 'AutoContrast': lambda level: (), + 'Equalize': lambda level: (), + 'Invert': lambda level: (), + 'Rotate': _rotate_level_to_arg, + 'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4),), + 'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),), + 'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),), + 'Color': _enhance_level_to_arg, + 'Contrast': _enhance_level_to_arg, + 'Brightness': _enhance_level_to_arg, + 'Sharpness': _enhance_level_to_arg, + 'ShearX': _shear_level_to_arg, + 'ShearY': _shear_level_to_arg, + 'TranslateX': lambda level: _translate_abs_level_to_arg(level, hparams['translate_const']), + 'TranslateY': lambda level: _translate_abs_level_to_arg(level, hparams['translate_const']), + 'TranslateXRel': lambda level: _translate_rel_level_to_arg(level), + 'TranslateYRel': lambda level: _translate_rel_level_to_arg(level), + } + + +NAME_TO_OP = { + 'AutoContrast': auto_contrast, + 'Equalize': equalize, + 'Invert': invert, + 'Rotate': rotate, + 'Posterize': posterize, + 'Solarize': solarize, + 'SolarizeAdd': solarize_add, + 'Color': color, + 'Contrast': contrast, + 'Brightness': brightness, + 'Sharpness': sharpness, + 'ShearX': shear_x, + 'ShearY': shear_y, + 'TranslateX': translate_x_abs, + 'TranslateY': translate_y_abs, + 'TranslateXRel': translate_x_rel, + 'TranslateYRel': translate_y_rel, +} + + +class AutoAugmentOp: + + def __init__(self, name, prob, magnitude, hparams={}): + self.aug_fn = NAME_TO_OP[name] + self.level_fn = level_to_arg(hparams)[name] + self.prob = prob + self.magnitude = magnitude + self.kwargs = { + 'fillcolor': hparams['img_mean'] if 'img_mean' in hparams else _FILL, + 'resample': hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION + } + self.rand_magnitude = True + + def __call__(self, img): + if self.prob < random.random(): + return img + magnitude = self.magnitude + if self.rand_magnitude: + magnitude = random.normalvariate(magnitude, 0.5) + magnitude = min(_MAX_LEVEL, max(0, magnitude)) + level_args = self.level_fn(magnitude) + return self.aug_fn(img, *level_args, **self.kwargs) + + +def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT): + """Autoaugment policy that was used in AutoAugment Paper.""" + # Each tuple is an augmentation operation of the form + # (operation, probability, magnitude). Each element in policy is a + # sub-policy that will be applied sequentially on the image. + policy = [ + [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)], + [('Color', 0.4, 9), ('Equalize', 0.6, 3)], + [('Color', 0.4, 1), ('Rotate', 0.6, 8)], + [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)], + [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)], + [('Color', 0.2, 0), ('Equalize', 0.8, 8)], + [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)], + [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)], + [('Color', 0.6, 1), ('Equalize', 1.0, 2)], + [('Invert', 0.4, 9), ('Rotate', 0.6, 0)], + [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)], + [('Color', 0.4, 7), ('Equalize', 0.6, 0)], + [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)], + [('Solarize', 0.6, 8), ('Color', 0.6, 9)], + [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)], + [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)], + [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)], + [('ShearY', 0.8, 0), ('Color', 0.6, 4)], + [('Color', 1.0, 0), ('Rotate', 0.6, 2)], + [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)], + [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)], + [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)], + [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)], + [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)], + [('Color', 0.8, 6), ('Rotate', 0.4, 5)], + ] + pc = [[AutoAugmentOp(*a, hparams) for a in sp] for sp in policy] + return pc + + +class AutoAugment: + + def __init__(self, policy): + self.policy = policy + + def __call__(self, img): + sub_policy = random.choice(self.policy) + for op in sub_policy: + img = op(img) + return img diff --git a/timm/data/loader.py b/timm/data/loader.py index 2a416b31..4bef4e88 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -109,13 +109,21 @@ def create_transform( is_training=is_training, size=img_size, interpolation=interpolation) else: if is_training: - transform = transforms_imagenet_train( - img_size, - color_jitter=color_jitter, - interpolation=interpolation, - use_prefetcher=use_prefetcher, - mean=mean, - std=std) + if True: + transform = transforms_imagenet_aa( + img_size, + interpolation=interpolation, + use_prefetcher=use_prefetcher, + mean=mean, + std=std) + else: + transform = transforms_imagenet_train( + img_size, + color_jitter=color_jitter, + interpolation=interpolation, + use_prefetcher=use_prefetcher, + mean=mean, + std=std) else: transform = transforms_imagenet_eval( img_size, diff --git a/timm/data/transforms.py b/timm/data/transforms.py index 93796a04..49a36f85 100644 --- a/timm/data/transforms.py +++ b/timm/data/transforms.py @@ -9,6 +9,7 @@ import numpy as np from .constants import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .random_erasing import RandomErasing +from .auto_augment import AutoAugment, auto_augment_policy_v0 class ToNumpy: @@ -160,6 +161,48 @@ class RandomResizedCropAndInterpolation(object): return format_string +def transforms_imagenet_aa( + img_size=224, + scale=(0.08, 1.0), + interpolation='random', + random_erasing=0.4, + random_erasing_mode='const', + use_prefetcher=False, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD +): + aa_params = dict( + cutout_max_pad_fraction=0.75, + cutout_const=100, + translate_const=img_size[-1] // 2 - 1, + img_mean=tuple([min(255, round(255*x)) for x in mean]), + ) + if interpolation and interpolation != 'random': + aa_params['interpolation'] = _pil_interp(interpolation) + aa_policy = auto_augment_policy_v0(aa_params) + + tfl = [ + RandomResizedCropAndInterpolation( + img_size, scale=scale, interpolation=interpolation), + transforms.RandomHorizontalFlip(), + AutoAugment(aa_policy) + ] + + if use_prefetcher: + # prefetcher and collate will handle tensor conversion and norm + tfl += [ToNumpy()] + else: + tfl += [ + transforms.ToTensor(), + transforms.Normalize( + mean=torch.tensor(mean), + std=torch.tensor(std)) + ] + if random_erasing > 0.: + tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu')) + return transforms.Compose(tfl) + + def transforms_imagenet_train( img_size=224, scale=(0.08, 1.0),