Working on auto-augment

pull/33/head
Ross Wightman 5 years ago
parent aff194f42c
commit 25d2088d9e

@ -0,0 +1,299 @@
import random
import math
from torchvision import transforms
from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageDraw
import PIL
import numpy as np
_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
_FILL = (128, 128, 128)
# This signifies the max integer that the controller RNN could predict for the
# augmentation scheme.
_MAX_LEVEL = 10.
_HPARAMS_DEFAULT = dict(
translate_const=250,
img_mean=_FILL,
)
_RANDOM_INTERPOLATION = (Image.NEAREST, Image.BILINEAR, Image.BICUBIC)
def _interpolation(kwargs):
interpolation = kwargs.pop('resample', Image.NEAREST)
if isinstance(interpolation, (list, tuple)):
return random.choice(interpolation)
else:
return interpolation
def _check_args_tf(kwargs):
if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
kwargs.pop('fillcolor')
kwargs['resample'] = _interpolation(kwargs)
def shear_x(img, factor, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
def shear_y(img, factor, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
def translate_x_rel(img, pct, **kwargs):
pixels = pct * img.size[0]
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
def translate_y_rel(img, pct, **kwargs):
pixels = pct * img.size[1]
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
def translate_x_abs(img, pixels, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
def translate_y_abs(img, pixels, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
def rotate(img, degrees, **kwargs):
_check_args_tf(kwargs)
if _PIL_VER >= (5, 2):
return img.rotate(degrees, **kwargs)
elif _PIL_VER >= (5, 0):
w, h = img.size
post_trans = (0, 0)
rotn_center = (w / 2.0, h / 2.0)
angle = -math.radians(degrees)
matrix = [
round(math.cos(angle), 15),
round(math.sin(angle), 15),
0.0,
round(-math.sin(angle), 15),
round(math.cos(angle), 15),
0.0,
]
def transform(x, y, matrix):
(a, b, c, d, e, f) = matrix
return a * x + b * y + c, d * x + e * y + f
matrix[2], matrix[5] = transform(
-rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
)
matrix[2] += rotn_center[0]
matrix[5] += rotn_center[1]
return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
else:
return img.rotate(degrees, resample=kwargs['resample'])
def auto_contrast(img, **__):
return ImageOps.autocontrast(img)
def invert(img, **__):
return ImageOps.invert(img)
def equalize(img, **__):
return ImageOps.equalize(img)
def solarize(img, thresh, **__):
return ImageOps.solarize(img, thresh)
def solarize_add(img, add, thresh=128, **__):
lut = []
for i in range(256):
if i < thresh:
lut.append(min(255, i + add))
else:
lut.append(i)
if img.mode in ("L", "RGB"):
if img.mode == "RGB" and len(lut) == 256:
lut = lut + lut + lut
return img.point(lut)
else:
return img
def posterize(img, bits, **__):
return ImageOps.posterize(img, 4 - bits)
def contrast(img, factor, **__):
return ImageEnhance.Contrast(img).enhance(factor)
def color(img, factor, **__):
return ImageEnhance.Color(img).enhance(factor)
def brightness(img, factor, **__):
return ImageEnhance.Brightness(img).enhance(factor)
def sharpness(img, factor, **__):
return ImageEnhance.Sharpness(img).enhance(factor)
def _randomly_negate(v):
"""With 50% prob, negate the value"""
return -v if random.random() > 0.5 else v
def _rotate_level_to_arg(level):
level = (level / _MAX_LEVEL) * 30.
level = _randomly_negate(level)
return (level,)
def _enhance_level_to_arg(level):
return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
def _shear_level_to_arg(level):
level = (level / _MAX_LEVEL) * 0.3
level = _randomly_negate(level)
return (level,)
def _translate_abs_level_to_arg(level, translate_const):
level = (level / _MAX_LEVEL) * float(translate_const)
level = _randomly_negate(level)
return (level,)
def _translate_rel_level_to_arg(level):
level = (level / _MAX_LEVEL) * 0.45
level = _randomly_negate(level)
return (level,)
def level_to_arg(hparams):
return {
'AutoContrast': lambda level: (),
'Equalize': lambda level: (),
'Invert': lambda level: (),
'Rotate': _rotate_level_to_arg,
'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4),),
'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),),
'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),),
'Color': _enhance_level_to_arg,
'Contrast': _enhance_level_to_arg,
'Brightness': _enhance_level_to_arg,
'Sharpness': _enhance_level_to_arg,
'ShearX': _shear_level_to_arg,
'ShearY': _shear_level_to_arg,
'TranslateX': lambda level: _translate_abs_level_to_arg(level, hparams['translate_const']),
'TranslateY': lambda level: _translate_abs_level_to_arg(level, hparams['translate_const']),
'TranslateXRel': lambda level: _translate_rel_level_to_arg(level),
'TranslateYRel': lambda level: _translate_rel_level_to_arg(level),
}
NAME_TO_OP = {
'AutoContrast': auto_contrast,
'Equalize': equalize,
'Invert': invert,
'Rotate': rotate,
'Posterize': posterize,
'Solarize': solarize,
'SolarizeAdd': solarize_add,
'Color': color,
'Contrast': contrast,
'Brightness': brightness,
'Sharpness': sharpness,
'ShearX': shear_x,
'ShearY': shear_y,
'TranslateX': translate_x_abs,
'TranslateY': translate_y_abs,
'TranslateXRel': translate_x_rel,
'TranslateYRel': translate_y_rel,
}
class AutoAugmentOp:
def __init__(self, name, prob, magnitude, hparams={}):
self.aug_fn = NAME_TO_OP[name]
self.level_fn = level_to_arg(hparams)[name]
self.prob = prob
self.magnitude = magnitude
self.kwargs = {
'fillcolor': hparams['img_mean'] if 'img_mean' in hparams else _FILL,
'resample': hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION
}
self.rand_magnitude = True
def __call__(self, img):
if self.prob < random.random():
return img
magnitude = self.magnitude
if self.rand_magnitude:
magnitude = random.normalvariate(magnitude, 0.5)
magnitude = min(_MAX_LEVEL, max(0, magnitude))
level_args = self.level_fn(magnitude)
return self.aug_fn(img, *level_args, **self.kwargs)
def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT):
"""Autoaugment policy that was used in AutoAugment Paper."""
# Each tuple is an augmentation operation of the form
# (operation, probability, magnitude). Each element in policy is a
# sub-policy that will be applied sequentially on the image.
policy = [
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
[('Color', 0.4, 1), ('Rotate', 0.6, 8)],
[('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
[('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
[('Color', 0.2, 0), ('Equalize', 0.8, 8)],
[('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
[('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
[('Color', 0.6, 1), ('Equalize', 1.0, 2)],
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
[('Color', 0.4, 7), ('Equalize', 0.6, 0)],
[('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
[('Solarize', 0.6, 8), ('Color', 0.6, 9)],
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
[('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
[('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
[('ShearY', 0.8, 0), ('Color', 0.6, 4)],
[('Color', 1.0, 0), ('Rotate', 0.6, 2)],
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
[('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
[('Color', 0.8, 6), ('Rotate', 0.4, 5)],
]
pc = [[AutoAugmentOp(*a, hparams) for a in sp] for sp in policy]
return pc
class AutoAugment:
def __init__(self, policy):
self.policy = policy
def __call__(self, img):
sub_policy = random.choice(self.policy)
for op in sub_policy:
img = op(img)
return img

@ -109,13 +109,21 @@ def create_transform(
is_training=is_training, size=img_size, interpolation=interpolation)
else:
if is_training:
transform = transforms_imagenet_train(
img_size,
color_jitter=color_jitter,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
mean=mean,
std=std)
if True:
transform = transforms_imagenet_aa(
img_size,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
mean=mean,
std=std)
else:
transform = transforms_imagenet_train(
img_size,
color_jitter=color_jitter,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
mean=mean,
std=std)
else:
transform = transforms_imagenet_eval(
img_size,

@ -9,6 +9,7 @@ import numpy as np
from .constants import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .random_erasing import RandomErasing
from .auto_augment import AutoAugment, auto_augment_policy_v0
class ToNumpy:
@ -160,6 +161,48 @@ class RandomResizedCropAndInterpolation(object):
return format_string
def transforms_imagenet_aa(
img_size=224,
scale=(0.08, 1.0),
interpolation='random',
random_erasing=0.4,
random_erasing_mode='const',
use_prefetcher=False,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD
):
aa_params = dict(
cutout_max_pad_fraction=0.75,
cutout_const=100,
translate_const=img_size[-1] // 2 - 1,
img_mean=tuple([min(255, round(255*x)) for x in mean]),
)
if interpolation and interpolation != 'random':
aa_params['interpolation'] = _pil_interp(interpolation)
aa_policy = auto_augment_policy_v0(aa_params)
tfl = [
RandomResizedCropAndInterpolation(
img_size, scale=scale, interpolation=interpolation),
transforms.RandomHorizontalFlip(),
AutoAugment(aa_policy)
]
if use_prefetcher:
# prefetcher and collate will handle tensor conversion and norm
tfl += [ToNumpy()]
else:
tfl += [
transforms.ToTensor(),
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std))
]
if random_erasing > 0.:
tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu'))
return transforms.Compose(tfl)
def transforms_imagenet_train(
img_size=224,
scale=(0.08, 1.0),

Loading…
Cancel
Save