You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
361 lines
12 KiB
361 lines
12 KiB
""" Auto Augment
|
|
Implementation adapted from:
|
|
https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
|
|
Papers: https://arxiv.org/abs/1805.09501 and https://arxiv.org/abs/1906.11172
|
|
|
|
Hacked together by Ross Wightman
|
|
"""
|
|
import random
|
|
import math
|
|
from PIL import Image, ImageOps, ImageEnhance
|
|
import PIL
|
|
import numpy as np
|
|
|
|
|
|
# (major, minor) of the installed PIL, compared against below to decide which
# keyword arguments the transform/rotate calls may receive.
_PIL_VER = tuple(int(part) for part in PIL.__version__.split('.')[:2])

# Default fill colour handed to the geometric ops as 'fillcolor'.
_FILL = (128, 128, 128)

# This signifies the max integer that the controller RNN could predict for the
# augmentation scheme.
_MAX_LEVEL = 10.0

# Hyper-parameters used by the policy builders when the caller passes none.
_HPARAMS_DEFAULT = {
    'translate_const': 250,
    'img_mean': _FILL,
}

# Candidate resampling filters; a tuple/list given as 'resample' is resolved
# to one randomly chosen member each time an op runs.
_RANDOM_INTERPOLATION = (Image.NEAREST, Image.BILINEAR, Image.BICUBIC)
|
|
|
|
|
|
def _interpolation(kwargs):
    """Remove 'resample' from *kwargs* and resolve it to a single filter.

    A list or tuple is treated as a pool of candidates from which one entry
    is drawn at random; any other value is returned as-is. When the key is
    missing, Image.NEAREST is returned. *kwargs* is modified in place.
    """
    requested = kwargs.pop('resample', Image.NEAREST)
    if not isinstance(requested, (list, tuple)):
        return requested
    return random.choice(requested)
|
|
|
|
|
|
def _check_args_tf(kwargs):
    """Prepare transform keyword arguments in place.

    On PIL older than 5.0 the 'fillcolor' entry is discarded. The 'resample'
    entry is always rewritten to the single filter chosen by _interpolation.
    """
    if _PIL_VER < (5, 0):
        kwargs.pop('fillcolor', None)
    kwargs['resample'] = _interpolation(kwargs)
|
|
|
|
|
|
def shear_x(img, factor, **kwargs):
    """Apply an affine transform with *factor* as the x-shear coefficient.

    The output keeps the size of *img*; *kwargs* are normalised by
    _check_args_tf and forwarded to img.transform.
    """
    _check_args_tf(kwargs)
    coeffs = (1, factor, 0, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
|
|
|
|
|
|
def shear_y(img, factor, **kwargs):
    """Apply an affine transform with *factor* as the y-shear coefficient.

    The output keeps the size of *img*; *kwargs* are normalised by
    _check_args_tf and forwarded to img.transform.
    """
    _check_args_tf(kwargs)
    coeffs = (1, 0, 0, factor, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
|
|
|
|
|
|
def translate_x_rel(img, pct, **kwargs):
    """Translate along x by *pct* times img.size[0] using an affine transform.

    *kwargs* are normalised by _check_args_tf and forwarded to img.transform.
    """
    offset = pct * img.size[0]
    _check_args_tf(kwargs)
    coeffs = (1, 0, offset, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
|
|
|
|
|
|
def translate_y_rel(img, pct, **kwargs):
    """Translate along y by *pct* times img.size[1] using an affine transform.

    *kwargs* are normalised by _check_args_tf and forwarded to img.transform.
    """
    offset = pct * img.size[1]
    _check_args_tf(kwargs)
    coeffs = (1, 0, 0, 0, 1, offset)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
|
|
|
|
|
|
def translate_x_abs(img, pixels, **kwargs):
    """Translate along x by *pixels* using an affine transform.

    *kwargs* are normalised by _check_args_tf and forwarded to img.transform.
    """
    _check_args_tf(kwargs)
    coeffs = (1, 0, pixels, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
|
|
|
|
|
|
def translate_y_abs(img, pixels, **kwargs):
    """Translate along y by *pixels* using an affine transform.

    *kwargs* are normalised by _check_args_tf and forwarded to img.transform.
    """
    _check_args_tf(kwargs)
    coeffs = (1, 0, 0, 0, 1, pixels)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
|
|
|
|
|
|
def rotate(img, degrees, **kwargs):
    """Rotate *img* by *degrees*, choosing a code path by PIL version.

    * PIL >= 5.2: img.rotate is called with all normalised *kwargs*.
    * PIL 5.0 / 5.1: an affine matrix rotating about the image centre is
      built by hand and applied through img.transform with *kwargs*.
    * older PIL: img.rotate is called with only the 'resample' argument.
    """
    _check_args_tf(kwargs)

    if _PIL_VER >= (5, 2):
        return img.rotate(degrees, **kwargs)
    if _PIL_VER < (5, 0):
        return img.rotate(degrees, resample=kwargs['resample'])

    # PIL 5.0 / 5.1: compose the rotation matrix manually.
    width, height = img.size
    center_x, center_y = width / 2.0, height / 2.0
    theta = -math.radians(degrees)
    coeffs = [
        round(math.cos(theta), 15),
        round(math.sin(theta), 15),
        0.0,
        round(-math.sin(theta), 15),
        round(math.cos(theta), 15),
        0.0,
    ]

    # Push the negated centre through the matrix (offsets are still zero at
    # this point), then shift back by the centre so the rotation pivots there.
    shift_x, shift_y = -center_x, -center_y
    a, b, c, d, e, f = coeffs
    coeffs[2] = a * shift_x + b * shift_y + c + center_x
    coeffs[5] = d * shift_x + e * shift_y + f + center_y
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
|
|
|
|
|
|
def auto_contrast(img, **_ignored):
    """Return ImageOps.autocontrast(img); extra keyword arguments are ignored."""
    result = ImageOps.autocontrast(img)
    return result
|
|
|
|
|
|
def invert(img, **_ignored):
    """Return ImageOps.invert(img); extra keyword arguments are ignored."""
    result = ImageOps.invert(img)
    return result
|
|
|
|
|
|
def equalize(img, **_ignored):
    """Return ImageOps.equalize(img); extra keyword arguments are ignored."""
    result = ImageOps.equalize(img)
    return result
|
|
|
|
|
|
def solarize(img, thresh, **_ignored):
    """Return ImageOps.solarize(img, thresh); extra keyword arguments are ignored."""
    result = ImageOps.solarize(img, thresh)
    return result
|
|
|
|
|
|
def solarize_add(img, add, thresh=128, **_ignored):
    """Add *add* to every pixel value below *thresh*, saturating at 255.

    Values at or above *thresh* are left unchanged. Only images whose mode is
    "L" or "RGB" are processed (via img.point with a lookup table); images in
    any other mode are returned untouched. Extra keyword arguments are ignored.
    """
    table = [
        min(255, value + add) if value < thresh else value
        for value in range(256)
    ]
    if img.mode not in ("L", "RGB"):
        return img
    if img.mode == "RGB":
        # One 256-entry table per channel.
        table = table * 3
    return img.point(table)
|
|
|
|
|
|
def posterize(img, bits_to_keep, **_ignored):
    """Posterize *img* down to *bits_to_keep* bits via ImageOps.posterize.

    A request for 8 or more bits returns *img* unchanged. Smaller requests
    are clamped to at least 1 bit. Extra keyword arguments are ignored.
    """
    if bits_to_keep < 8:
        # Clamp to one bit to prevent all 0 images.
        return ImageOps.posterize(img, max(1, bits_to_keep))
    return img
|
|
|
|
|
|
def contrast(img, factor, **_ignored):
    """Return *img* enhanced by ImageEnhance.Contrast with *factor*; extra kwargs are ignored."""
    enhancer = ImageEnhance.Contrast(img)
    return enhancer.enhance(factor)
|
|
|
|
|
|
def color(img, factor, **_ignored):
    """Return *img* enhanced by ImageEnhance.Color with *factor*; extra kwargs are ignored."""
    enhancer = ImageEnhance.Color(img)
    return enhancer.enhance(factor)
|
|
|
|
|
|
def brightness(img, factor, **_ignored):
    """Return *img* enhanced by ImageEnhance.Brightness with *factor*; extra kwargs are ignored."""
    enhancer = ImageEnhance.Brightness(img)
    return enhancer.enhance(factor)
|
|
|
|
|
|
def sharpness(img, factor, **_ignored):
    """Return *img* enhanced by ImageEnhance.Sharpness with *factor*; extra kwargs are ignored."""
    enhancer = ImageEnhance.Sharpness(img)
    return enhancer.enhance(factor)
|
|
|
|
|
|
def _randomly_negate(v):
    """Flip the sign of *v* when a fresh random.random() draw exceeds 0.5."""
    if random.random() > 0.5:
        return -v
    return v
|
|
|
|
|
|
def _rotate_level_to_arg(level):
    """Scale *level* to a rotation angle and randomly flip its sign.

    Returns a 1-tuple; range [-30, 30] for levels in [0, _MAX_LEVEL].
    """
    degrees = _randomly_negate((level / _MAX_LEVEL) * 30.)
    return (degrees,)
|
|
|
|
|
|
def _enhance_level_to_arg(level):
    """Scale *level* to an enhancement factor.

    Returns a 1-tuple; range [0.1, 1.9] for levels in [0, _MAX_LEVEL].
    """
    factor = (level / _MAX_LEVEL) * 1.8 + 0.1
    return (factor,)
|
|
|
|
|
|
def _shear_level_to_arg(level):
    """Scale *level* to a shear factor and randomly flip its sign.

    Returns a 1-tuple; range [-0.3, 0.3] for levels in [0, _MAX_LEVEL].
    """
    shear = _randomly_negate((level / _MAX_LEVEL) * 0.3)
    return (shear,)
|
|
|
|
|
|
def _translate_abs_level_to_arg(level, translate_const):
    """Scale *level* to a translation bounded by *translate_const*, randomly signed.

    Returns a 1-tuple whose magnitude is (level / _MAX_LEVEL) * translate_const.
    """
    shift = _randomly_negate((level / _MAX_LEVEL) * float(translate_const))
    return (shift,)
|
|
|
|
|
|
def _translate_rel_level_to_arg(level):
    """Scale *level* to a relative translation and randomly flip its sign.

    Returns a 1-tuple; range [-0.45, 0.45] for levels in [0, _MAX_LEVEL].
    """
    fraction = _randomly_negate((level / _MAX_LEVEL) * 0.45)
    return (fraction,)
|
|
|
|
|
|
def level_to_arg(hparams):
    """Build the table mapping op names to level-to-argument converters.

    Every converter takes a magnitude level and returns a tuple of extra
    positional arguments for the op of the same name (an empty tuple for ops
    that take no magnitude). hparams['translate_const'] is looked up lazily,
    only when a 'TranslateX' / 'TranslateY' converter is actually called.
    """

    def _no_args(level):
        return ()

    # FIXME these are both different from original impl as I believe there is a bug,
    # not sure what is the correct alternative, hence 2 options that look better
    def _posterize(level):
        # range [4, 8]
        return (int((level / _MAX_LEVEL) * 4) + 4,)

    def _posterize2(level):
        # range [4, 0]
        return (4 - int((level / _MAX_LEVEL) * 4),)

    def _solarize(level):
        # range [0, 256]
        return (int((level / _MAX_LEVEL) * 256),)

    def _solarize_add(level):
        # range [0, 110]
        return (int((level / _MAX_LEVEL) * 110),)

    def _translate_abs(level):
        return _translate_abs_level_to_arg(level, hparams['translate_const'])

    def _translate_rel(level):
        return _translate_rel_level_to_arg(level)

    converters = {
        'AutoContrast': _no_args,
        'Equalize': _no_args,
        'Invert': _no_args,
        'Rotate': _rotate_level_to_arg,
        'Posterize': _posterize,
        'Posterize2': _posterize2,
        'Solarize': _solarize,
        'SolarizeAdd': _solarize_add,
        'Color': _enhance_level_to_arg,
        'Contrast': _enhance_level_to_arg,
        'Brightness': _enhance_level_to_arg,
        'Sharpness': _enhance_level_to_arg,
        'ShearX': _shear_level_to_arg,
        'ShearY': _shear_level_to_arg,
        'TranslateX': _translate_abs,
        'TranslateY': _translate_abs,
        'TranslateXRel': _translate_rel,
        'TranslateYRel': _translate_rel,
    }
    return converters
|
|
|
|
|
|
# Op name (as used in the policy tables) -> function implementing the op.
# 'Posterize' and 'Posterize2' share one implementation and differ only in
# how level_to_arg converts the magnitude.
NAME_TO_OP = dict(
    AutoContrast=auto_contrast,
    Equalize=equalize,
    Invert=invert,
    Rotate=rotate,
    Posterize=posterize,
    Posterize2=posterize,
    Solarize=solarize,
    SolarizeAdd=solarize_add,
    Color=color,
    Contrast=contrast,
    Brightness=brightness,
    Sharpness=sharpness,
    ShearX=shear_x,
    ShearY=shear_y,
    TranslateX=translate_x_abs,
    TranslateY=translate_y_abs,
    TranslateXRel=translate_x_rel,
    TranslateYRel=translate_y_rel,
)
|
|
|
|
|
|
class AutoAugmentOp:
    """One augmentation op, applied with a given probability and magnitude.

    Args:
        name: Key into NAME_TO_OP and the level_to_arg table.
        prob: Probability of applying the op on each call.
        magnitude: Base magnitude level handed to the level converter.
        hparams: Optional mapping of hyper-parameters. Keys read here or by
            the level converters: 'translate_const', 'img_mean',
            'interpolation' and 'magnitude_std'. Keys missing from the
            mapping (or the whole mapping, when None) fall back to
            _HPARAMS_DEFAULT / the built-in defaults.

    Raises:
        KeyError: If *name* is not a known op.
    """

    def __init__(self, name, prob, magnitude, hparams=None):
        # Layer the caller's values over the defaults so that every op
        # (notably TranslateX/TranslateY, which need 'translate_const')
        # works with partial or absent hparams, and no mutable default
        # argument is shared between instances.
        merged = dict(_HPARAMS_DEFAULT)
        if hparams:
            merged.update(hparams)

        self.aug_fn = NAME_TO_OP[name]
        self.level_fn = level_to_arg(merged)[name]
        self.prob = prob
        self.magnitude = magnitude
        # If std deviation of magnitude is > 0, we introduce some randomness
        # in the usually fixed policy and sample magnitude from normal dist
        # with mean magnitude and std-dev of magnitude_std.
        # NOTE This is being tested as it's not in paper or reference impl.
        self.magnitude_std = merged.get('magnitude_std', 0.5)
        self.kwargs = {
            'fillcolor': merged.get('img_mean', _FILL),
            'resample': merged.get('interpolation', _RANDOM_INTERPOLATION),
        }

    def __call__(self, img):
        """Return *img* either unchanged or with the op applied."""
        # random.random() is in [0.0, 1.0): '>=' guarantees prob=0.0 never
        # applies the op and prob=1.0 always does.
        if random.random() >= self.prob:
            return img
        magnitude = self.magnitude
        if self.magnitude_std and self.magnitude_std > 0:
            magnitude = random.gauss(magnitude, self.magnitude_std)
            # Keep the sampled magnitude inside the valid level range.
            magnitude = min(_MAX_LEVEL, max(0, magnitude))
        level_args = self.level_fn(magnitude)
        return self.aug_fn(img, *level_args, **self.kwargs)
|
|
|
|
|
|
def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT):
    """Build the 'v0' policy as a list of sub-policies of AutoAugmentOp.

    Each sub-policy is a list of ops described below as
    (name, probability, magnitude); *hparams* is passed to every op.
    """
    # ImageNet policy from TPU EfficientNet impl, cannot find
    # a paper reference.
    sub_policies = [
        [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
        [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
        [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
        [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
        [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
        [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
        [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
        [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
        [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
        [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
        [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
        [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
        [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
        [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
    ]
    return [
        [AutoAugmentOp(op_name, prob, magnitude, hparams)
         for op_name, prob, magnitude in sub_policy]
        for sub_policy in sub_policies
    ]
|
|
|
|
|
|
def auto_augment_policy_original(hparams=_HPARAMS_DEFAULT):
    """Build the 'original' policy as a list of sub-policies of AutoAugmentOp.

    Each sub-policy is a list of ops described below as
    (name, probability, magnitude); *hparams* is passed to every op.
    """
    # ImageNet policy from https://arxiv.org/abs/1805.09501
    sub_policies = [
        [('Posterize', 0.4, 8), ('Rotate', 0.6, 9)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
        [('Posterize', 0.6, 7), ('Posterize', 0.6, 6)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
        [('Posterize', 0.8, 5), ('Equalize', 1.0, 2)],
        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
        [('Equalize', 0.6, 8), ('Posterize', 0.4, 6)],
        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
    ]
    return [
        [AutoAugmentOp(op_name, prob, magnitude, hparams)
         for op_name, prob, magnitude in sub_policy]
        for sub_policy in sub_policies
    ]
|
|
|
|
|
|
def auto_augment_policy(name='v0', hparams=_HPARAMS_DEFAULT):
    """Return the named AutoAugment policy.

    Args:
        name: 'original' or 'v0'.
        hparams: Hyper-parameter mapping forwarded to the policy builder.

    Returns:
        The list of sub-policies produced by the selected builder.

    Raises:
        AssertionError: If *name* is not a known policy.
    """
    if name == 'original':
        return auto_augment_policy_original(hparams)
    if name == 'v0':
        return auto_augment_policy_v0(hparams)
    # Raised explicitly instead of `assert False`: assert statements are
    # stripped under `python -O`, which would make an unknown name silently
    # return None. The exception type is kept for existing callers. The
    # 1-tuple keeps the formatting correct even if *name* is itself a tuple.
    raise AssertionError('Unknown AA policy (%s)' % (name,))
|
|
|
|
|
|
class AutoAugment:
    """Callable that applies one randomly chosen sub-policy to an image.

    *policy* is a sequence of sub-policies, each an iterable of callables
    taking an image and returning an image.
    """

    def __init__(self, policy):
        self.policy = policy

    def __call__(self, img):
        """Pick a sub-policy with random.choice and run its ops in order."""
        chosen = random.choice(self.policy)
        result = img
        for transform in chosen:
            result = transform(result)
        return result
|