More AutoAugment work. Ready to roll...

pull/33/head
Ross Wightman 5 years ago
parent 25d2088d9e
commit b750b76f67

@ -1,7 +1,13 @@
""" Auto Augment
Implementation adapted from:
https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
Papers: https://arxiv.org/abs/1805.09501 and https://arxiv.org/abs/1906.11172
Hacked together by Ross Wightman
"""
import random import random
import math import math
from torchvision import transforms from PIL import Image, ImageOps, ImageEnhance
from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageDraw
import PIL import PIL
import numpy as np import numpy as np
@ -131,8 +137,11 @@ def solarize_add(img, add, thresh=128, **__):
return img return img
def posterize(img, bits, **__): def posterize(img, bits_to_keep, **__):
return ImageOps.posterize(img, 4 - bits) if bits_to_keep >= 8:
return img
bits_to_keep = max(1, bits_to_keep) # prevent all 0 images
return ImageOps.posterize(img, bits_to_keep)
def contrast(img, factor, **__): def contrast(img, factor, **__):
@ -157,16 +166,19 @@ def _randomly_negate(v):
def _rotate_level_to_arg(level): def _rotate_level_to_arg(level):
# range [-30, 30]
level = (level / _MAX_LEVEL) * 30. level = (level / _MAX_LEVEL) * 30.
level = _randomly_negate(level) level = _randomly_negate(level)
return (level,) return (level,)
def _enhance_level_to_arg(level): def _enhance_level_to_arg(level):
# range [0.1, 1.9]
return ((level / _MAX_LEVEL) * 1.8 + 0.1,) return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
def _shear_level_to_arg(level): def _shear_level_to_arg(level):
# range [-0.3, 0.3]
level = (level / _MAX_LEVEL) * 0.3 level = (level / _MAX_LEVEL) * 0.3
level = _randomly_negate(level) level = _randomly_negate(level)
return (level,) return (level,)
@ -179,6 +191,7 @@ def _translate_abs_level_to_arg(level, translate_const):
def _translate_rel_level_to_arg(level): def _translate_rel_level_to_arg(level):
# range [-0.45, 0.45]
level = (level / _MAX_LEVEL) * 0.45 level = (level / _MAX_LEVEL) * 0.45
level = _randomly_negate(level) level = _randomly_negate(level)
return (level,) return (level,)
@ -190,9 +203,12 @@ def level_to_arg(hparams):
'Equalize': lambda level: (), 'Equalize': lambda level: (),
'Invert': lambda level: (), 'Invert': lambda level: (),
'Rotate': _rotate_level_to_arg, 'Rotate': _rotate_level_to_arg,
'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4),), # FIXME these are both different from original impl as I believe there is a bug,
'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),), # not sure what is the correct alternative, hence 2 options that look better
'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),), 'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4) + 4,), # range [4, 8]
'Posterize2': lambda level: (4 - int((level / _MAX_LEVEL) * 4),), # range [4, 0]
'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),), # range [0, 256]
'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),), # range [0, 110]
'Color': _enhance_level_to_arg, 'Color': _enhance_level_to_arg,
'Contrast': _enhance_level_to_arg, 'Contrast': _enhance_level_to_arg,
'Brightness': _enhance_level_to_arg, 'Brightness': _enhance_level_to_arg,
@ -212,6 +228,7 @@ NAME_TO_OP = {
'Invert': invert, 'Invert': invert,
'Rotate': rotate, 'Rotate': rotate,
'Posterize': posterize, 'Posterize': posterize,
'Posterize2': posterize,
'Solarize': solarize, 'Solarize': solarize,
'SolarizeAdd': solarize_add, 'SolarizeAdd': solarize_add,
'Color': color, 'Color': color,
@ -252,10 +269,8 @@ class AutoAugmentOp:
def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT): def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT):
"""Autoaugment policy that was used in AutoAugment Paper.""" # ImageNet policy from TPU EfficientNet impl, cannot find
# Each tuple is an augmentation operation of the form # a paper reference.
# (operation, probability, magnitude). Each element in policy is a
# sub-policy that will be applied sequentially on the image.
policy = [ policy = [
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)], [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
[('Color', 0.4, 9), ('Equalize', 0.6, 3)], [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
@ -287,6 +302,48 @@ def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT):
return pc return pc
def auto_augment_policy_original(hparams=_HPARAMS_DEFAULT):
    """Build the ImageNet AutoAugment policy from https://arxiv.org/abs/1805.09501.

    Every entry of the table below is one sub-policy: a pair of
    ``(operation, probability, magnitude)`` tuples. Each tuple is turned into
    an ``AutoAugmentOp`` constructed with the shared ``hparams``.

    Args:
        hparams: hyper-parameters forwarded unchanged to every
            ``AutoAugmentOp`` (defaults to the module-level defaults).

    Returns:
        A list of sub-policies, each a list of ``AutoAugmentOp`` instances,
        in the same order as the table.
    """
    sub_policies = [
        [('Posterize', 0.4, 8), ('Rotate', 0.6, 9)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
        [('Posterize', 0.6, 7), ('Posterize', 0.6, 6)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
        [('Posterize', 0.8, 5), ('Equalize', 1.0, 2)],
        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
        [('Equalize', 0.6, 8), ('Posterize', 0.4, 6)],
        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
    ]
    built = []
    for sub_policy in sub_policies:
        # Instantiate the ops of one sub-policy, preserving their order.
        built.append([
            AutoAugmentOp(op_name, prob, magnitude, hparams)
            for op_name, prob, magnitude in sub_policy
        ])
    return built
def auto_augment_policy(name='v0', hparams=_HPARAMS_DEFAULT):
    """Return the AutoAugment policy selected by ``name``.

    Args:
        name: policy identifier; ``'original'`` selects
            ``auto_augment_policy_original`` and ``'v0'`` selects
            ``auto_augment_policy_v0``.
        hparams: hyper-parameters forwarded unchanged to the selected
            policy builder.

    Returns:
        Whatever the selected policy builder returns for ``hparams``.

    Raises:
        AssertionError: if ``name`` is not a known policy name.
    """
    if name == 'original':
        return auto_augment_policy_original(hparams)
    if name == 'v0':
        return auto_augment_policy_v0(hparams)
    # Raise explicitly instead of ``assert False``: assert statements are
    # stripped under ``python -O``, which would make an unknown name fall
    # through and silently return None. The exception type and message are
    # kept so existing callers observe the same failure.
    raise AssertionError('Unknown AA policy (%s)' % name)
class AutoAugment: class AutoAugment:
def __init__(self, policy): def __init__(self, policy):

@ -92,6 +92,7 @@ def create_transform(
is_training=False, is_training=False,
use_prefetcher=False, use_prefetcher=False,
color_jitter=0.4, color_jitter=0.4,
auto_augment=None,
interpolation='bilinear', interpolation='bilinear',
mean=IMAGENET_DEFAULT_MEAN, mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD, std=IMAGENET_DEFAULT_STD,
@ -109,17 +110,10 @@ def create_transform(
is_training=is_training, size=img_size, interpolation=interpolation) is_training=is_training, size=img_size, interpolation=interpolation)
else: else:
if is_training: if is_training:
if True:
transform = transforms_imagenet_aa(
img_size,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
mean=mean,
std=std)
else:
transform = transforms_imagenet_train( transform = transforms_imagenet_train(
img_size, img_size,
color_jitter=color_jitter, color_jitter=color_jitter,
auto_augment=auto_augment,
interpolation=interpolation, interpolation=interpolation,
use_prefetcher=use_prefetcher, use_prefetcher=use_prefetcher,
mean=mean, mean=mean,
@ -146,6 +140,7 @@ def create_loader(
rand_erase_mode='const', rand_erase_mode='const',
rand_erase_count=1, rand_erase_count=1,
color_jitter=0.4, color_jitter=0.4,
auto_augment=None,
interpolation='bilinear', interpolation='bilinear',
mean=IMAGENET_DEFAULT_MEAN, mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD, std=IMAGENET_DEFAULT_STD,
@ -161,6 +156,7 @@ def create_loader(
is_training=is_training, is_training=is_training,
use_prefetcher=use_prefetcher, use_prefetcher=use_prefetcher,
color_jitter=color_jitter, color_jitter=color_jitter,
auto_augment=auto_augment,
interpolation=interpolation, interpolation=interpolation,
mean=mean, mean=mean,
std=std, std=std,

@ -9,7 +9,7 @@ import numpy as np
from .constants import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .constants import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .random_erasing import RandomErasing from .random_erasing import RandomErasing
from .auto_augment import AutoAugment, auto_augment_policy_v0 from .auto_augment import AutoAugment, auto_augment_policy
class ToNumpy: class ToNumpy:
@ -57,10 +57,10 @@ def _pil_interp(method):
return Image.BILINEAR return Image.BILINEAR
RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
class RandomResizedCropAndInterpolation(object): class RandomResizedCropAndInterpolation:
"""Crop the given PIL Image to random size and aspect ratio with random interpolation. """Crop the given PIL Image to random size and aspect ratio with random interpolation.
A crop of random size (default: of 0.08 to 1.0) of the original size and a random A crop of random size (default: of 0.08 to 1.0) of the original size and a random
@ -85,7 +85,7 @@ class RandomResizedCropAndInterpolation(object):
warnings.warn("range should be of kind (min, max)") warnings.warn("range should be of kind (min, max)")
if interpolation == 'random': if interpolation == 'random':
self.interpolation = RANDOM_INTERPOLATION self.interpolation = _RANDOM_INTERPOLATION
else: else:
self.interpolation = _pil_interp(interpolation) self.interpolation = _pil_interp(interpolation)
self.scale = scale self.scale = scale
@ -161,9 +161,11 @@ class RandomResizedCropAndInterpolation(object):
return format_string return format_string
def transforms_imagenet_aa( def transforms_imagenet_train(
img_size=224, img_size=224,
scale=(0.08, 1.0), scale=(0.08, 1.0),
color_jitter=0.4,
auto_augment=None,
interpolation='random', interpolation='random',
random_erasing=0.4, random_erasing=0.4,
random_erasing_mode='const', random_erasing_mode='const',
@ -171,49 +173,22 @@ def transforms_imagenet_aa(
mean=IMAGENET_DEFAULT_MEAN, mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD std=IMAGENET_DEFAULT_STD
): ):
tfl = [
RandomResizedCropAndInterpolation(
img_size, scale=scale, interpolation=interpolation),
transforms.RandomHorizontalFlip()
]
if auto_augment:
aa_params = dict( aa_params = dict(
cutout_max_pad_fraction=0.75,
cutout_const=100,
translate_const=img_size[-1] // 2 - 1, translate_const=img_size[-1] // 2 - 1,
img_mean=tuple([min(255, round(255 * x)) for x in mean]), img_mean=tuple([min(255, round(255 * x)) for x in mean]),
) )
if interpolation and interpolation != 'random': if interpolation and interpolation != 'random':
aa_params['interpolation'] = _pil_interp(interpolation) aa_params['interpolation'] = _pil_interp(interpolation)
aa_policy = auto_augment_policy_v0(aa_params) aa_policy = auto_augment_policy(auto_augment, aa_params)
tfl += [AutoAugment(aa_policy)]
tfl = [
RandomResizedCropAndInterpolation(
img_size, scale=scale, interpolation=interpolation),
transforms.RandomHorizontalFlip(),
AutoAugment(aa_policy)
]
if use_prefetcher:
# prefetcher and collate will handle tensor conversion and norm
tfl += [ToNumpy()]
else: else:
tfl += [ # color jitter is enabled when not using AA
transforms.ToTensor(),
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std))
]
if random_erasing > 0.:
tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu'))
return transforms.Compose(tfl)
def transforms_imagenet_train(
img_size=224,
scale=(0.08, 1.0),
color_jitter=0.4,
interpolation='random',
random_erasing=0.4,
random_erasing_mode='const',
use_prefetcher=False,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD
):
if isinstance(color_jitter, (list, tuple)): if isinstance(color_jitter, (list, tuple)):
# color jitter should be a 3-tuple/list if spec brightness/contrast/saturation # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
# or 4 if also augmenting hue # or 4 if also augmenting hue
@ -221,13 +196,7 @@ def transforms_imagenet_train(
else: else:
# if it's a scalar, duplicate for brightness, contrast, and saturation, no hue # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
color_jitter = (float(color_jitter),) * 3 color_jitter = (float(color_jitter),) * 3
tfl += [transforms.ColorJitter(*color_jitter)]
tfl = [
RandomResizedCropAndInterpolation(
img_size, scale=scale, interpolation=interpolation),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(*color_jitter),
]
if use_prefetcher: if use_prefetcher:
# prefetcher and collate will handle tensor conversion and norm # prefetcher and collate will handle tensor conversion and norm

@ -89,6 +89,8 @@ parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RA
# Augmentation parameters # Augmentation parameters
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
help='Color jitter factor (default: 0.4)') help='Color jitter factor (default: 0.4)')
# Optional AutoAugment policy name; stays None when the flag is not given.
# (A stray trailing comma after the call previously turned this statement
# into a discarded one-element tuple.)
parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                    help='Use AutoAugment policy. "v0" or "original". (default: None)')
parser.add_argument('--reprob', type=float, default=0., metavar='PCT', parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
help='Random erase prob (default: 0.)') help='Random erase prob (default: 0.)')
parser.add_argument('--remode', type=str, default='const', parser.add_argument('--remode', type=str, default='const',
@ -287,6 +289,7 @@ def main():
rand_erase_mode=args.remode, rand_erase_mode=args.remode,
rand_erase_count=args.recount, rand_erase_count=args.recount,
color_jitter=args.color_jitter, color_jitter=args.color_jitter,
auto_augment=args.aa,
interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'], interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'],
mean=data_config['mean'], mean=data_config['mean'],
std=data_config['std'], std=data_config['std'],

Loading…
Cancel
Save