|
|
@ -1,7 +1,13 @@
|
|
|
|
|
|
|
|
""" Auto Augment
|
|
|
|
|
|
|
|
Implementation adapted from:
|
|
|
|
|
|
|
|
https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
|
|
|
|
|
|
|
|
Papers: https://arxiv.org/abs/1805.09501 and https://arxiv.org/abs/1906.11172
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Hacked together by Ross Wightman
|
|
|
|
|
|
|
|
"""
|
|
|
|
import random
|
|
|
|
import random
|
|
|
|
import math
|
|
|
|
import math
|
|
|
|
from torchvision import transforms
|
|
|
|
from PIL import Image, ImageOps, ImageEnhance
|
|
|
|
from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageDraw
|
|
|
|
|
|
|
|
import PIL
|
|
|
|
import PIL
|
|
|
|
import numpy as np
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
@ -131,8 +137,11 @@ def solarize_add(img, add, thresh=128, **__):
|
|
|
|
return img
|
|
|
|
return img
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def posterize(img, bits, **__):
|
|
|
|
def posterize(img, bits_to_keep, **__):
|
|
|
|
return ImageOps.posterize(img, 4 - bits)
|
|
|
|
if bits_to_keep >= 8:
|
|
|
|
|
|
|
|
return img
|
|
|
|
|
|
|
|
bits_to_keep = max(1, bits_to_keep) # prevent all 0 images
|
|
|
|
|
|
|
|
return ImageOps.posterize(img, bits_to_keep)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def contrast(img, factor, **__):
|
|
|
|
def contrast(img, factor, **__):
|
|
|
@ -157,16 +166,19 @@ def _randomly_negate(v):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _rotate_level_to_arg(level):
|
|
|
|
def _rotate_level_to_arg(level):
|
|
|
|
|
|
|
|
# range [-30, 30]
|
|
|
|
level = (level / _MAX_LEVEL) * 30.
|
|
|
|
level = (level / _MAX_LEVEL) * 30.
|
|
|
|
level = _randomly_negate(level)
|
|
|
|
level = _randomly_negate(level)
|
|
|
|
return (level,)
|
|
|
|
return (level,)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _enhance_level_to_arg(level):
|
|
|
|
def _enhance_level_to_arg(level):
|
|
|
|
|
|
|
|
# range [0.1, 1.9]
|
|
|
|
return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
|
|
|
|
return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _shear_level_to_arg(level):
|
|
|
|
def _shear_level_to_arg(level):
|
|
|
|
|
|
|
|
# range [-0.3, 0.3]
|
|
|
|
level = (level / _MAX_LEVEL) * 0.3
|
|
|
|
level = (level / _MAX_LEVEL) * 0.3
|
|
|
|
level = _randomly_negate(level)
|
|
|
|
level = _randomly_negate(level)
|
|
|
|
return (level,)
|
|
|
|
return (level,)
|
|
|
@ -179,6 +191,7 @@ def _translate_abs_level_to_arg(level, translate_const):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _translate_rel_level_to_arg(level):
|
|
|
|
def _translate_rel_level_to_arg(level):
|
|
|
|
|
|
|
|
# range [-0.45, 0.45]
|
|
|
|
level = (level / _MAX_LEVEL) * 0.45
|
|
|
|
level = (level / _MAX_LEVEL) * 0.45
|
|
|
|
level = _randomly_negate(level)
|
|
|
|
level = _randomly_negate(level)
|
|
|
|
return (level,)
|
|
|
|
return (level,)
|
|
|
@ -190,9 +203,12 @@ def level_to_arg(hparams):
|
|
|
|
'Equalize': lambda level: (),
|
|
|
|
'Equalize': lambda level: (),
|
|
|
|
'Invert': lambda level: (),
|
|
|
|
'Invert': lambda level: (),
|
|
|
|
'Rotate': _rotate_level_to_arg,
|
|
|
|
'Rotate': _rotate_level_to_arg,
|
|
|
|
'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4),),
|
|
|
|
# FIXME these are both different from original impl as I believe there is a bug,
|
|
|
|
'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),),
|
|
|
|
# not sure what is the correct alternative, hence 2 options that look better
|
|
|
|
'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),),
|
|
|
|
'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4) + 4,), # range [4, 8]
|
|
|
|
|
|
|
|
'Posterize2': lambda level: (4 - int((level / _MAX_LEVEL) * 4),), # range [4, 0]
|
|
|
|
|
|
|
|
'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),), # range [0, 256]
|
|
|
|
|
|
|
|
'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),), # range [0, 110]
|
|
|
|
'Color': _enhance_level_to_arg,
|
|
|
|
'Color': _enhance_level_to_arg,
|
|
|
|
'Contrast': _enhance_level_to_arg,
|
|
|
|
'Contrast': _enhance_level_to_arg,
|
|
|
|
'Brightness': _enhance_level_to_arg,
|
|
|
|
'Brightness': _enhance_level_to_arg,
|
|
|
@ -212,6 +228,7 @@ NAME_TO_OP = {
|
|
|
|
'Invert': invert,
|
|
|
|
'Invert': invert,
|
|
|
|
'Rotate': rotate,
|
|
|
|
'Rotate': rotate,
|
|
|
|
'Posterize': posterize,
|
|
|
|
'Posterize': posterize,
|
|
|
|
|
|
|
|
'Posterize2': posterize,
|
|
|
|
'Solarize': solarize,
|
|
|
|
'Solarize': solarize,
|
|
|
|
'SolarizeAdd': solarize_add,
|
|
|
|
'SolarizeAdd': solarize_add,
|
|
|
|
'Color': color,
|
|
|
|
'Color': color,
|
|
|
@ -252,10 +269,8 @@ class AutoAugmentOp:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT):
|
|
|
|
def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT):
|
|
|
|
"""Autoaugment policy that was used in AutoAugment Paper."""
|
|
|
|
# ImageNet policy from TPU EfficientNet impl, cannot find
|
|
|
|
# Each tuple is an augmentation operation of the form
|
|
|
|
# a paper reference.
|
|
|
|
# (operation, probability, magnitude). Each element in policy is a
|
|
|
|
|
|
|
|
# sub-policy that will be applied sequentially on the image.
|
|
|
|
|
|
|
|
policy = [
|
|
|
|
policy = [
|
|
|
|
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
|
|
|
|
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
|
|
|
|
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
|
|
|
|
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
|
|
|
@ -287,6 +302,48 @@ def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT):
|
|
|
|
return pc
|
|
|
|
return pc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def auto_augment_policy_original(hparams=_HPARAMS_DEFAULT):
|
|
|
|
|
|
|
|
# ImageNet policy from https://arxiv.org/abs/1805.09501
|
|
|
|
|
|
|
|
policy = [
|
|
|
|
|
|
|
|
[('Posterize', 0.4, 8), ('Rotate', 0.6, 9)],
|
|
|
|
|
|
|
|
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
|
|
|
|
|
|
|
|
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
|
|
|
|
|
|
|
|
[('Posterize', 0.6, 7), ('Posterize', 0.6, 6)],
|
|
|
|
|
|
|
|
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
|
|
|
|
|
|
|
|
[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
|
|
|
|
|
|
|
|
[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
|
|
|
|
|
|
|
|
[('Posterize', 0.8, 5), ('Equalize', 1.0, 2)],
|
|
|
|
|
|
|
|
[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
|
|
|
|
|
|
|
|
[('Equalize', 0.6, 8), ('Posterize', 0.4, 6)],
|
|
|
|
|
|
|
|
[('Rotate', 0.8, 8), ('Color', 0.4, 0)],
|
|
|
|
|
|
|
|
[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
|
|
|
|
|
|
|
|
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
|
|
|
|
|
|
|
|
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
|
|
|
|
|
|
|
|
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
|
|
|
|
|
|
|
|
[('Rotate', 0.8, 8), ('Color', 1.0, 2)],
|
|
|
|
|
|
|
|
[('Color', 0.8, 8), ('Solarize', 0.8, 7)],
|
|
|
|
|
|
|
|
[('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
|
|
|
|
|
|
|
|
[('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
|
|
|
|
|
|
|
|
[('Color', 0.4, 0), ('Equalize', 0.6, 3)],
|
|
|
|
|
|
|
|
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
|
|
|
|
|
|
|
|
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
|
|
|
|
|
|
|
|
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
|
|
|
|
|
|
|
|
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
|
|
|
|
|
|
|
|
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
|
|
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
pc = [[AutoAugmentOp(*a, hparams) for a in sp] for sp in policy]
|
|
|
|
|
|
|
|
return pc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def auto_augment_policy(name='v0', hparams=_HPARAMS_DEFAULT):
|
|
|
|
|
|
|
|
if name == 'original':
|
|
|
|
|
|
|
|
return auto_augment_policy_original(hparams)
|
|
|
|
|
|
|
|
elif name == 'v0':
|
|
|
|
|
|
|
|
return auto_augment_policy_v0(hparams)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
assert False, 'Unknown AA policy (%s)' % name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AutoAugment:
|
|
|
|
class AutoAugment:
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, policy):
|
|
|
|
def __init__(self, policy):
|
|
|
|