More AutoAugment work. Ready to roll...

pull/33/head
Ross Wightman 5 years ago
parent 25d2088d9e
commit b750b76f67

@ -1,7 +1,13 @@
""" Auto Augment
Implementation adapted from:
https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
Papers: https://arxiv.org/abs/1805.09501 and https://arxiv.org/abs/1906.11172
Hacked together by Ross Wightman
"""
import random
import math
from torchvision import transforms
from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageDraw
from PIL import Image, ImageOps, ImageEnhance
import PIL
import numpy as np
@ -131,8 +137,11 @@ def solarize_add(img, add, thresh=128, **__):
return img
def posterize(img, bits, **__):
return ImageOps.posterize(img, 4 - bits)
def posterize(img, bits_to_keep, **__):
if bits_to_keep >= 8:
return img
bits_to_keep = max(1, bits_to_keep) # prevent all 0 images
return ImageOps.posterize(img, bits_to_keep)
def contrast(img, factor, **__):
@ -157,16 +166,19 @@ def _randomly_negate(v):
def _rotate_level_to_arg(level):
# range [-30, 30]
level = (level / _MAX_LEVEL) * 30.
level = _randomly_negate(level)
return (level,)
def _enhance_level_to_arg(level):
# range [0.1, 1.9]
return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
def _shear_level_to_arg(level):
# range [-0.3, 0.3]
level = (level / _MAX_LEVEL) * 0.3
level = _randomly_negate(level)
return (level,)
@ -179,6 +191,7 @@ def _translate_abs_level_to_arg(level, translate_const):
def _translate_rel_level_to_arg(level):
# range [-0.45, 0.45]
level = (level / _MAX_LEVEL) * 0.45
level = _randomly_negate(level)
return (level,)
@ -190,9 +203,12 @@ def level_to_arg(hparams):
'Equalize': lambda level: (),
'Invert': lambda level: (),
'Rotate': _rotate_level_to_arg,
'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4),),
'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),),
'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),),
# FIXME these are both different from original impl as I believe there is a bug,
# not sure what is the correct alternative, hence 2 options that look better
'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4) + 4,), # range [4, 8]
'Posterize2': lambda level: (4 - int((level / _MAX_LEVEL) * 4),), # range [4, 0]
'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),), # range [0, 256]
'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),), # range [0, 110]
'Color': _enhance_level_to_arg,
'Contrast': _enhance_level_to_arg,
'Brightness': _enhance_level_to_arg,
@ -212,6 +228,7 @@ NAME_TO_OP = {
'Invert': invert,
'Rotate': rotate,
'Posterize': posterize,
'Posterize2': posterize,
'Solarize': solarize,
'SolarizeAdd': solarize_add,
'Color': color,
@ -252,10 +269,8 @@ class AutoAugmentOp:
def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT):
"""Autoaugment policy that was used in AutoAugment Paper."""
# Each tuple is an augmentation operation of the form
# (operation, probability, magnitude). Each element in policy is a
# sub-policy that will be applied sequentially on the image.
# ImageNet policy from TPU EfficientNet impl, cannot find
# a paper reference.
policy = [
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
@ -287,6 +302,48 @@ def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT):
return pc
def auto_augment_policy_original(hparams=_HPARAMS_DEFAULT):
# ImageNet policy from https://arxiv.org/abs/1805.09501
policy = [
[('Posterize', 0.4, 8), ('Rotate', 0.6, 9)],
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
[('Posterize', 0.6, 7), ('Posterize', 0.6, 6)],
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
[('Posterize', 0.8, 5), ('Equalize', 1.0, 2)],
[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
[('Equalize', 0.6, 8), ('Posterize', 0.4, 6)],
[('Rotate', 0.8, 8), ('Color', 0.4, 0)],
[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
[('Rotate', 0.8, 8), ('Color', 1.0, 2)],
[('Color', 0.8, 8), ('Solarize', 0.8, 7)],
[('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
[('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
[('Color', 0.4, 0), ('Equalize', 0.6, 3)],
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
]
pc = [[AutoAugmentOp(*a, hparams) for a in sp] for sp in policy]
return pc
def auto_augment_policy(name='v0', hparams=_HPARAMS_DEFAULT):
if name == 'original':
return auto_augment_policy_original(hparams)
elif name == 'v0':
return auto_augment_policy_v0(hparams)
else:
assert False, 'Unknown AA policy (%s)' % name
class AutoAugment:
def __init__(self, policy):

@ -92,6 +92,7 @@ def create_transform(
is_training=False,
use_prefetcher=False,
color_jitter=0.4,
auto_augment=None,
interpolation='bilinear',
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
@ -109,21 +110,14 @@ def create_transform(
is_training=is_training, size=img_size, interpolation=interpolation)
else:
if is_training:
if True:
transform = transforms_imagenet_aa(
img_size,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
mean=mean,
std=std)
else:
transform = transforms_imagenet_train(
img_size,
color_jitter=color_jitter,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
mean=mean,
std=std)
transform = transforms_imagenet_train(
img_size,
color_jitter=color_jitter,
auto_augment=auto_augment,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
mean=mean,
std=std)
else:
transform = transforms_imagenet_eval(
img_size,
@ -146,6 +140,7 @@ def create_loader(
rand_erase_mode='const',
rand_erase_count=1,
color_jitter=0.4,
auto_augment=None,
interpolation='bilinear',
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
@ -161,6 +156,7 @@ def create_loader(
is_training=is_training,
use_prefetcher=use_prefetcher,
color_jitter=color_jitter,
auto_augment=auto_augment,
interpolation=interpolation,
mean=mean,
std=std,

@ -9,7 +9,7 @@ import numpy as np
from .constants import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .random_erasing import RandomErasing
from .auto_augment import AutoAugment, auto_augment_policy_v0
from .auto_augment import AutoAugment, auto_augment_policy
class ToNumpy:
@ -57,10 +57,10 @@ def _pil_interp(method):
return Image.BILINEAR
RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
class RandomResizedCropAndInterpolation(object):
class RandomResizedCropAndInterpolation:
"""Crop the given PIL Image to random size and aspect ratio with random interpolation.
A crop of random size (default: of 0.08 to 1.0) of the original size and a random
@ -85,7 +85,7 @@ class RandomResizedCropAndInterpolation(object):
warnings.warn("range should be of kind (min, max)")
if interpolation == 'random':
self.interpolation = RANDOM_INTERPOLATION
self.interpolation = _RANDOM_INTERPOLATION
else:
self.interpolation = _pil_interp(interpolation)
self.scale = scale
@ -161,52 +161,11 @@ class RandomResizedCropAndInterpolation(object):
return format_string
def transforms_imagenet_aa(
img_size=224,
scale=(0.08, 1.0),
interpolation='random',
random_erasing=0.4,
random_erasing_mode='const',
use_prefetcher=False,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD
):
aa_params = dict(
cutout_max_pad_fraction=0.75,
cutout_const=100,
translate_const=img_size[-1] // 2 - 1,
img_mean=tuple([min(255, round(255*x)) for x in mean]),
)
if interpolation and interpolation != 'random':
aa_params['interpolation'] = _pil_interp(interpolation)
aa_policy = auto_augment_policy_v0(aa_params)
tfl = [
RandomResizedCropAndInterpolation(
img_size, scale=scale, interpolation=interpolation),
transforms.RandomHorizontalFlip(),
AutoAugment(aa_policy)
]
if use_prefetcher:
# prefetcher and collate will handle tensor conversion and norm
tfl += [ToNumpy()]
else:
tfl += [
transforms.ToTensor(),
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std))
]
if random_erasing > 0.:
tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu'))
return transforms.Compose(tfl)
def transforms_imagenet_train(
img_size=224,
scale=(0.08, 1.0),
color_jitter=0.4,
auto_augment=None,
interpolation='random',
random_erasing=0.4,
random_erasing_mode='const',
@ -214,20 +173,30 @@ def transforms_imagenet_train(
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD
):
if isinstance(color_jitter, (list, tuple)):
# color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
# or 4 if also augmenting hue
assert len(color_jitter) in (3, 4)
else:
# if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
color_jitter = (float(color_jitter),) * 3
tfl = [
RandomResizedCropAndInterpolation(
img_size, scale=scale, interpolation=interpolation),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(*color_jitter),
transforms.RandomHorizontalFlip()
]
if auto_augment:
aa_params = dict(
translate_const=img_size[-1] // 2 - 1,
img_mean=tuple([min(255, round(255 * x)) for x in mean]),
)
if interpolation and interpolation != 'random':
aa_params['interpolation'] = _pil_interp(interpolation)
aa_policy = auto_augment_policy(auto_augment, aa_params)
tfl += [AutoAugment(aa_policy)]
else:
# color jitter is enabled when not using AA
if isinstance(color_jitter, (list, tuple)):
# color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
# or 4 if also augmenting hue
assert len(color_jitter) in (3, 4)
else:
# if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
color_jitter = (float(color_jitter),) * 3
tfl += [transforms.ColorJitter(*color_jitter)]
if use_prefetcher:
# prefetcher and collate will handle tensor conversion and norm

@ -89,6 +89,8 @@ parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RA
# Augmentation parameters
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
help='Color jitter factor (default: 0.4)')
parser.add_argument('--aa', type=str, default=None, metavar='NAME',
help='Use AutoAugment policy. "v0" or "original". (default: None)'),
parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
help='Random erase prob (default: 0.)')
parser.add_argument('--remode', type=str, default='const',
@ -287,6 +289,7 @@ def main():
rand_erase_mode=args.remode,
rand_erase_count=args.recount,
color_jitter=args.color_jitter,
auto_augment=args.aa,
interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],

Loading…
Cancel
Save