Merge pull request #52 from rwightman/randaugment

RandAugment and more
pull/62/head
Ross Wightman 5 years ago committed by GitHub
commit db04677c94
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -69,6 +69,7 @@ Several (less common) features that I often utilize in my projects are included.
* Training schedules and techniques that provide competitive results (Cosine LR, Random Erasing, Label Smoothing, etc) * Training schedules and techniques that provide competitive results (Cosine LR, Random Erasing, Label Smoothing, etc)
* Mixup (as in https://arxiv.org/abs/1710.09412) - currently implementing/testing * Mixup (as in https://arxiv.org/abs/1710.09412) - currently implementing/testing
* An inference script that dumps output to CSV is provided as an example * An inference script that dumps output to CSV is provided as an example
* AutoAugment (https://arxiv.org/abs/1805.09501) and RandAugment (https://arxiv.org/abs/1909.13719) ImageNet configurations modeled after impl for EfficientNet training (https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py)
## Results ## Results

@ -4,3 +4,5 @@ from .dataset import Dataset, DatasetTar
from .transforms import * from .transforms import *
from .loader import create_loader, create_transform from .loader import create_loader, create_transform
from .mixup import mixup_target, FastCollateMixup from .mixup import mixup_target, FastCollateMixup
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
rand_augment_transform, auto_augment_transform

@ -1,17 +1,19 @@
""" Auto Augment """ AutoAugment and RandAugment
Implementation adapted from: Implementation adapted from:
https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
Papers: https://arxiv.org/abs/1805.09501 and https://arxiv.org/abs/1906.11172 Papers: https://arxiv.org/abs/1805.09501, https://arxiv.org/abs/1906.11172, and https://arxiv.org/abs/1909.13719
Hacked together by Ross Wightman Hacked together by Ross Wightman
""" """
import random import random
import math import math
import re
from PIL import Image, ImageOps, ImageEnhance from PIL import Image, ImageOps, ImageEnhance
import PIL import PIL
import numpy as np import numpy as np
_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]]) _PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
_FILL = (128, 128, 128) _FILL = (128, 128, 128)
@ -25,11 +27,11 @@ _HPARAMS_DEFAULT = dict(
img_mean=_FILL, img_mean=_FILL,
) )
_RANDOM_INTERPOLATION = (Image.NEAREST, Image.BILINEAR, Image.BICUBIC) _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
def _interpolation(kwargs): def _interpolation(kwargs):
interpolation = kwargs.pop('resample', Image.NEAREST) interpolation = kwargs.pop('resample', Image.BILINEAR)
if isinstance(interpolation, (list, tuple)): if isinstance(interpolation, (list, tuple)):
return random.choice(interpolation) return random.choice(interpolation)
else: else:
@ -140,7 +142,6 @@ def solarize_add(img, add, thresh=128, **__):
def posterize(img, bits_to_keep, **__): def posterize(img, bits_to_keep, **__):
if bits_to_keep >= 8: if bits_to_keep >= 8:
return img return img
bits_to_keep = max(1, bits_to_keep) # prevent all 0 images
return ImageOps.posterize(img, bits_to_keep) return ImageOps.posterize(img, bits_to_keep)
@ -165,61 +166,89 @@ def _randomly_negate(v):
return -v if random.random() > 0.5 else v return -v if random.random() > 0.5 else v
def _rotate_level_to_arg(level): def _rotate_level_to_arg(level, _hparams):
# range [-30, 30] # range [-30, 30]
level = (level / _MAX_LEVEL) * 30. level = (level / _MAX_LEVEL) * 30.
level = _randomly_negate(level) level = _randomly_negate(level)
return (level,) return level,
def _enhance_level_to_arg(level): def _enhance_level_to_arg(level, _hparams):
# range [0.1, 1.9] # range [0.1, 1.9]
return ((level / _MAX_LEVEL) * 1.8 + 0.1,) return (level / _MAX_LEVEL) * 1.8 + 0.1,
def _shear_level_to_arg(level): def _shear_level_to_arg(level, _hparams):
# range [-0.3, 0.3] # range [-0.3, 0.3]
level = (level / _MAX_LEVEL) * 0.3 level = (level / _MAX_LEVEL) * 0.3
level = _randomly_negate(level) level = _randomly_negate(level)
return (level,) return level,
def _translate_abs_level_to_arg(level, translate_const): def _translate_abs_level_to_arg(level, hparams):
translate_const = hparams['translate_const']
level = (level / _MAX_LEVEL) * float(translate_const) level = (level / _MAX_LEVEL) * float(translate_const)
level = _randomly_negate(level) level = _randomly_negate(level)
return (level,) return level,
def _translate_rel_level_to_arg(level): def _translate_rel_level_to_arg(level, _hparams):
# range [-0.45, 0.45] # range [-0.45, 0.45]
level = (level / _MAX_LEVEL) * 0.45 level = (level / _MAX_LEVEL) * 0.45
level = _randomly_negate(level) level = _randomly_negate(level)
return (level,) return level,
def level_to_arg(hparams): def _posterize_original_level_to_arg(level, _hparams):
return { # As per original AutoAugment paper description
'AutoContrast': lambda level: (), # range [4, 8], 'keep 4 up to 8 MSB of image'
'Equalize': lambda level: (), return int((level / _MAX_LEVEL) * 4) + 4,
'Invert': lambda level: (),
'Rotate': _rotate_level_to_arg,
# FIXME these are both different from original impl as I believe there is a bug, def _posterize_research_level_to_arg(level, _hparams):
# not sure what is the correct alternative, hence 2 options that look better # As per Tensorflow models research and UDA impl
'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4) + 4,), # range [4, 8] # range [4, 0], 'keep 4 down to 0 MSB of original image'
'Posterize2': lambda level: (4 - int((level / _MAX_LEVEL) * 4),), # range [4, 0] return 4 - int((level / _MAX_LEVEL) * 4),
'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),), # range [0, 256]
'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),), # range [0, 110]
'Color': _enhance_level_to_arg, def _posterize_tpu_level_to_arg(level, _hparams):
'Contrast': _enhance_level_to_arg, # As per Tensorflow TPU EfficientNet impl
'Brightness': _enhance_level_to_arg, # range [0, 4], 'keep 0 up to 4 MSB of original image'
'Sharpness': _enhance_level_to_arg, return int((level / _MAX_LEVEL) * 4),
'ShearX': _shear_level_to_arg,
'ShearY': _shear_level_to_arg,
'TranslateX': lambda level: _translate_abs_level_to_arg(level, hparams['translate_const']), def _solarize_level_to_arg(level, _hparams):
'TranslateY': lambda level: _translate_abs_level_to_arg(level, hparams['translate_const']), # range [0, 256]
'TranslateXRel': lambda level: _translate_rel_level_to_arg(level), return int((level / _MAX_LEVEL) * 256),
'TranslateYRel': lambda level: _translate_rel_level_to_arg(level),
}
def _solarize_add_level_to_arg(level, _hparams):
# range [0, 110]
return int((level / _MAX_LEVEL) * 110),
LEVEL_TO_ARG = {
'AutoContrast': None,
'Equalize': None,
'Invert': None,
'Rotate': _rotate_level_to_arg,
# There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
'PosterizeOriginal': _posterize_original_level_to_arg,
'PosterizeResearch': _posterize_research_level_to_arg,
'PosterizeTpu': _posterize_tpu_level_to_arg,
'Solarize': _solarize_level_to_arg,
'SolarizeAdd': _solarize_add_level_to_arg,
'Color': _enhance_level_to_arg,
'Contrast': _enhance_level_to_arg,
'Brightness': _enhance_level_to_arg,
'Sharpness': _enhance_level_to_arg,
'ShearX': _shear_level_to_arg,
'ShearY': _shear_level_to_arg,
'TranslateX': _translate_abs_level_to_arg,
'TranslateY': _translate_abs_level_to_arg,
'TranslateXRel': _translate_rel_level_to_arg,
'TranslateYRel': _translate_rel_level_to_arg,
}
NAME_TO_OP = { NAME_TO_OP = {
@ -227,8 +256,9 @@ NAME_TO_OP = {
'Equalize': equalize, 'Equalize': equalize,
'Invert': invert, 'Invert': invert,
'Rotate': rotate, 'Rotate': rotate,
'Posterize': posterize, 'PosterizeOriginal': posterize,
'Posterize2': posterize, 'PosterizeResearch': posterize,
'PosterizeTpu': posterize,
'Solarize': solarize, 'Solarize': solarize,
'SolarizeAdd': solarize_add, 'SolarizeAdd': solarize_add,
'Color': color, 'Color': color,
@ -246,35 +276,70 @@ NAME_TO_OP = {
class AutoAugmentOp: class AutoAugmentOp:
def __init__(self, name, prob, magnitude, hparams={}): def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
hparams = hparams or _HPARAMS_DEFAULT
self.aug_fn = NAME_TO_OP[name] self.aug_fn = NAME_TO_OP[name]
self.level_fn = level_to_arg(hparams)[name] self.level_fn = LEVEL_TO_ARG[name]
self.prob = prob self.prob = prob
self.magnitude = magnitude self.magnitude = magnitude
# If std deviation of magnitude is > 0, we introduce some randomness self.hparams = hparams.copy()
# in the usually fixed policy and sample magnitude from normal dist self.kwargs = dict(
# with mean magnitude and std-dev of magnitude_std. fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
# NOTE This is being tested as it's not in paper or reference impl. resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
self.magnitude_std = 0.5 # FIXME add arg/hparam )
self.kwargs = {
'fillcolor': hparams['img_mean'] if 'img_mean' in hparams else _FILL, # If magnitude_std is > 0, we introduce some randomness
'resample': hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION # in the usually fixed policy and sample magnitude from a normal distribution
} # with mean `magnitude` and std-dev of `magnitude_std`.
# NOTE This is my own hack, being tested, not in papers or reference impls.
self.magnitude_std = self.hparams.get('magnitude_std', 0)
def __call__(self, img): def __call__(self, img):
if self.prob < random.random(): if random.random() > self.prob:
return img return img
magnitude = self.magnitude magnitude = self.magnitude
if self.magnitude_std and self.magnitude_std > 0: if self.magnitude_std and self.magnitude_std > 0:
magnitude = random.gauss(magnitude, self.magnitude_std) magnitude = random.gauss(magnitude, self.magnitude_std)
magnitude = min(_MAX_LEVEL, max(0, magnitude)) magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
level_args = self.level_fn(magnitude) level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
return self.aug_fn(img, *level_args, **self.kwargs) return self.aug_fn(img, *level_args, **self.kwargs)
def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT): def auto_augment_policy_v0(hparams):
# ImageNet policy from TPU EfficientNet impl, cannot find # ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference.
# a paper reference. policy = [
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
[('Color', 0.4, 1), ('Rotate', 0.6, 8)],
[('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
[('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
[('Color', 0.2, 0), ('Equalize', 0.8, 8)],
[('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
[('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
[('Color', 0.6, 1), ('Equalize', 1.0, 2)],
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
[('Color', 0.4, 7), ('Equalize', 0.6, 0)],
[('PosterizeTpu', 0.4, 6), ('AutoContrast', 0.4, 7)],
[('Solarize', 0.6, 8), ('Color', 0.6, 9)],
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
[('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
[('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
[('ShearY', 0.8, 0), ('Color', 0.6, 4)],
[('Color', 1.0, 0), ('Rotate', 0.6, 2)],
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
[('PosterizeTpu', 0.8, 2), ('Solarize', 0.6, 10)], # This results in black image with Tpu posterize
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
[('Color', 0.8, 6), ('Rotate', 0.4, 5)],
]
pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
return pc
def auto_augment_policy_v0r(hparams):
# ImageNet v0 policy from TPU EfficientNet impl, with research variation of Posterize
policy = [ policy = [
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)], [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
[('Color', 0.4, 9), ('Equalize', 0.6, 3)], [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
@ -288,7 +353,7 @@ def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT):
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)], [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)], [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
[('Color', 0.4, 7), ('Equalize', 0.6, 0)], [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
[('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)], [('PosterizeResearch', 0.4, 6), ('AutoContrast', 0.4, 7)],
[('Solarize', 0.6, 8), ('Color', 0.6, 9)], [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)], [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
[('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)], [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
@ -298,27 +363,27 @@ def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT):
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)], [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)], [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)], [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
[('Posterize', 0.8, 2), ('Solarize', 0.6, 10)], [('PosterizeResearch', 0.8, 2), ('Solarize', 0.6, 10)],
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)], [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
[('Color', 0.8, 6), ('Rotate', 0.4, 5)], [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
] ]
pc = [[AutoAugmentOp(*a, hparams) for a in sp] for sp in policy] pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
return pc return pc
def auto_augment_policy_original(hparams=_HPARAMS_DEFAULT): def auto_augment_policy_original(hparams):
# ImageNet policy from https://arxiv.org/abs/1805.09501 # ImageNet policy from https://arxiv.org/abs/1805.09501
policy = [ policy = [
[('Posterize', 0.4, 8), ('Rotate', 0.6, 9)], [('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)],
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)], [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
[('Posterize', 0.6, 7), ('Posterize', 0.6, 6)], [('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)],
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)], [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)], [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)], [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
[('Posterize', 0.8, 5), ('Equalize', 1.0, 2)], [('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)],
[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)], [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
[('Equalize', 0.6, 8), ('Posterize', 0.4, 6)], [('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)],
[('Rotate', 0.8, 8), ('Color', 0.4, 0)], [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)], [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)], [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
@ -335,15 +400,53 @@ def auto_augment_policy_original(hparams=_HPARAMS_DEFAULT):
[('Color', 0.6, 4), ('Contrast', 1.0, 8)], [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
] ]
pc = [[AutoAugmentOp(*a, hparams) for a in sp] for sp in policy] pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
return pc return pc
def auto_augment_policy(name='v0', hparams=_HPARAMS_DEFAULT): def auto_augment_policy_originalr(hparams):
# ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation
policy = [
[('PosterizeResearch', 0.4, 8), ('Rotate', 0.6, 9)],
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
[('PosterizeResearch', 0.6, 7), ('PosterizeResearch', 0.6, 6)],
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
[('PosterizeResearch', 0.8, 5), ('Equalize', 1.0, 2)],
[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
[('Equalize', 0.6, 8), ('PosterizeResearch', 0.4, 6)],
[('Rotate', 0.8, 8), ('Color', 0.4, 0)],
[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
[('Rotate', 0.8, 8), ('Color', 1.0, 2)],
[('Color', 0.8, 8), ('Solarize', 0.8, 7)],
[('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
[('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
[('Color', 0.4, 0), ('Equalize', 0.6, 3)],
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
]
pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
return pc
def auto_augment_policy(name='v0', hparams=None):
hparams = hparams or _HPARAMS_DEFAULT
if name == 'original': if name == 'original':
return auto_augment_policy_original(hparams) return auto_augment_policy_original(hparams)
elif name == 'originalr':
return auto_augment_policy_originalr(hparams)
elif name == 'v0': elif name == 'v0':
return auto_augment_policy_v0(hparams) return auto_augment_policy_v0(hparams)
elif name == 'v0r':
return auto_augment_policy_v0r(hparams)
else: else:
assert False, 'Unknown AA policy (%s)' % name assert False, 'Unknown AA policy (%s)' % name
@ -358,3 +461,151 @@ class AutoAugment:
for op in sub_policy: for op in sub_policy:
img = op(img) img = op(img)
return img return img
def auto_augment_transform(config_str, hparams):
"""
Create a AutoAugment transform
:param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
The remaining sections, not order sepecific determine
'mstd' - float std deviation of magnitude noise applied
Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
:param hparams: Other hparams (kwargs) for the AutoAugmentation scheme
:return: A PyTorch compatible Transform
"""
config = config_str.split('-')
policy_name = config[0]
config = config[1:]
for c in config:
cs = re.split(r'(\d.*)', c)
if len(cs) < 2:
continue
key, val = cs[:2]
if key == 'mstd':
# noise param injected via hparams for now
hparams.setdefault('magnitude_std', float(val))
else:
assert False, 'Unknown AutoAugment config section'
aa_policy = auto_augment_policy(policy_name, hparams=hparams)
return AutoAugment(aa_policy)
_RAND_TRANSFORMS = [
'AutoContrast',
'Equalize',
'Invert',
'Rotate',
'PosterizeTpu',
'Solarize',
'SolarizeAdd',
'Color',
'Contrast',
'Brightness',
'Sharpness',
'ShearX',
'ShearY',
'TranslateXRel',
'TranslateYRel',
#'Cutout' # FIXME I implement this as random erasing separately
]
# These experimental weights are based loosely on the relative improvements mentioned in paper.
# They may not result in increased performance, but could likely be tuned to so.
_RAND_CHOICE_WEIGHTS_0 = {
'Rotate': 0.3,
'ShearX': 0.2,
'ShearY': 0.2,
'TranslateXRel': 0.1,
'TranslateYRel': 0.1,
'Color': .025,
'Sharpness': 0.025,
'AutoContrast': 0.025,
'Solarize': .005,
'SolarizeAdd': .005,
'Contrast': .005,
'Brightness': .005,
'Equalize': .005,
'PosterizeTpu': 0,
'Invert': 0,
}
def _select_rand_weights(weight_idx=0, transforms=None):
transforms = transforms or _RAND_TRANSFORMS
assert weight_idx == 0 # only one set of weights currently
rand_weights = _RAND_CHOICE_WEIGHTS_0
probs = [rand_weights[k] for k in transforms]
probs /= np.sum(probs)
return probs
def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
hparams = hparams or _HPARAMS_DEFAULT
transforms = transforms or _RAND_TRANSFORMS
return [AutoAugmentOp(
name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
class RandAugment:
def __init__(self, ops, num_layers=2, choice_weights=None):
self.ops = ops
self.num_layers = num_layers
self.choice_weights = choice_weights
def __call__(self, img):
# no replacement when using weighted choice
ops = np.random.choice(
self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
for op in ops:
img = op(img)
return img
def rand_augment_transform(config_str, hparams):
"""
Create a RandAugment transform
:param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
sections, not order sepecific determine
'm' - integer magnitude of rand augment
'n' - integer num layers (number of transform ops selected per image)
'w' - integer probabiliy weight index (index of a set of weights to influence choice of op)
'mstd' - float std deviation of magnitude noise applied
Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
:param hparams: Other hparams (kwargs) for the RandAugmentation scheme
:return: A PyTorch compatible Transform
"""
magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
num_layers = 2 # default to 2 ops per image
weight_idx = None # default to no probability weights for op choice
config = config_str.split('-')
assert config[0] == 'rand'
config = config[1:]
for c in config:
cs = re.split(r'(\d.*)', c)
if len(cs) < 2:
continue
key, val = cs[:2]
if key == 'mstd':
# noise param injected via hparams for now
hparams.setdefault('magnitude_std', float(val))
elif key == 'm':
magnitude = int(val)
elif key == 'n':
num_layers = int(val)
elif key == 'w':
weight_idx = int(val)
else:
assert False, 'Unknown RandAugment config section'
ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams)
choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)

@ -9,7 +9,7 @@ import numpy as np
from .constants import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .constants import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .random_erasing import RandomErasing from .random_erasing import RandomErasing
from .auto_augment import AutoAugment, auto_augment_policy from .auto_augment import auto_augment_transform, rand_augment_transform
class ToNumpy: class ToNumpy:
@ -179,6 +179,7 @@ def transforms_imagenet_train(
transforms.RandomHorizontalFlip() transforms.RandomHorizontalFlip()
] ]
if auto_augment: if auto_augment:
assert isinstance(auto_augment, str)
if isinstance(img_size, tuple): if isinstance(img_size, tuple):
img_size_min = min(img_size) img_size_min = min(img_size)
else: else:
@ -189,8 +190,10 @@ def transforms_imagenet_train(
) )
if interpolation and interpolation != 'random': if interpolation and interpolation != 'random':
aa_params['interpolation'] = _pil_interp(interpolation) aa_params['interpolation'] = _pil_interp(interpolation)
aa_policy = auto_augment_policy(auto_augment, aa_params) if auto_augment.startswith('rand'):
tfl += [AutoAugment(aa_policy)] tfl += [rand_augment_transform(auto_augment, aa_params)]
else:
tfl += [auto_augment_transform(auto_augment, aa_params)]
else: else:
# color jitter is enabled when not using AA # color jitter is enabled when not using AA
if isinstance(color_jitter, (list, tuple)): if isinstance(color_jitter, (list, tuple)):

@ -25,12 +25,13 @@ def create_model(
""" """
margs = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans) margs = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans)
# Not all models have support for batchnorm params passed as args, only gen_efficientnet variants # Only gen_efficientnet models have support for batchnorm params or drop_connect_rate passed as args
supports_bn_params = is_model_in_modules(model_name, ['gen_efficientnet']) is_efficientnet = is_model_in_modules(model_name, ['gen_efficientnet'])
if not supports_bn_params and any([x in kwargs for x in ['bn_tf', 'bn_momentum', 'bn_eps']]): if not is_efficientnet:
kwargs.pop('bn_tf', None) kwargs.pop('bn_tf', None)
kwargs.pop('bn_momentum', None) kwargs.pop('bn_momentum', None)
kwargs.pop('bn_eps', None) kwargs.pop('bn_eps', None)
kwargs.pop('drop_connect_rate', None)
if is_model(model_name): if is_model(model_name):
create_fn = model_entrypoint(model_name) create_fn = model_entrypoint(model_name)

@ -373,25 +373,37 @@ def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'):
_USE_SWISH_OPT = True _USE_SWISH_OPT = True
if _USE_SWISH_OPT: if _USE_SWISH_OPT:
class SwishAutoFn(torch.autograd.Function): @torch.jit.script
""" Memory Efficient Swish def swish_jit_fwd(x):
From: https://blog.ceshine.net/post/pytorch-memory-swish/ return x.mul(torch.sigmoid(x))
@torch.jit.script
def swish_jit_bwd(x, grad_output):
x_sigmoid = torch.sigmoid(x)
return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
class SwishJitAutoFn(torch.autograd.Function):
""" torch.jit.script optimised Swish
Inspired by conversation btw Jeremy Howard & Adam Pazske
https://twitter.com/jeremyphoward/status/1188251041835315200
""" """
@staticmethod @staticmethod
def forward(ctx, x): def forward(ctx, x):
result = x.mul(torch.sigmoid(x))
ctx.save_for_backward(x) ctx.save_for_backward(x)
return result return swish_jit_fwd(x)
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
x = ctx.saved_variables[0] x = ctx.saved_tensors[0]
sigmoid_x = torch.sigmoid(x) return swish_jit_bwd(x, grad_output)
return grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x)))
def swish(x, inplace=False): def swish(x, inplace=False):
return SwishAutoFn.apply(x) # inplace ignored
return SwishJitAutoFn.apply(x)
else: else:
def swish(x, inplace=False): def swish(x, inplace=False):
return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())

@ -65,6 +65,8 @@ parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
help='input batch size for training (default: 32)') help='input batch size for training (default: 32)')
parser.add_argument('--drop', type=float, default=0.0, metavar='DROP', parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
help='Dropout rate (default: 0.)') help='Dropout rate (default: 0.)')
parser.add_argument('--drop-connect', type=float, default=0.0, metavar='DROP',
help='Drop connect rate (default: 0.)')
# Optimizer parameters # Optimizer parameters
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "sgd"') help='Optimizer (default: "sgd"')
@ -87,7 +89,7 @@ parser.add_argument('--epochs', type=int, default=200, metavar='N',
help='number of epochs to train (default: 2)') help='number of epochs to train (default: 2)')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N', parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
help='manual epoch number (useful on restarts)') help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=int, default=30, metavar='N', parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
help='epoch interval to decay LR') help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N', parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
help='epochs to warmup LR, if scheduler supports') help='epochs to warmup LR, if scheduler supports')
@ -208,6 +210,7 @@ def main():
pretrained=args.pretrained, pretrained=args.pretrained,
num_classes=args.num_classes, num_classes=args.num_classes,
drop_rate=args.drop, drop_rate=args.drop,
drop_connect_rate=args.drop_connect,
global_pool=args.gp, global_pool=args.gp,
bn_tf=args.bn_tf, bn_tf=args.bn_tf,
bn_momentum=args.bn_momentum, bn_momentum=args.bn_momentum,
@ -253,7 +256,7 @@ def main():
if args.local_rank == 0: if args.local_rank == 0:
logging.info('Restoring NVIDIA AMP state from checkpoint') logging.info('Restoring NVIDIA AMP state from checkpoint')
amp.load_state_dict(resume_state['amp']) amp.load_state_dict(resume_state['amp'])
resume_state = None resume_state = None # clear it
model_ema = None model_ema = None
if args.model_ema: if args.model_ema:

Loading…
Cancel
Save