Working on an implementation of AugMix with JensenShannonDivergence loss that's compatible with my AutoAugment and RandAugment impl

pull/74/head
Ross Wightman 5 years ago
parent ff8688ca3d
commit 232ab7fb12

@ -2,7 +2,8 @@ from .constants import *
from .config import resolve_data_config from .config import resolve_data_config
from .dataset import Dataset, DatasetTar from .dataset import Dataset, DatasetTar
from .transforms import * from .transforms import *
from .loader import create_loader, create_transform from .loader import create_loader
from .mixup import mixup_target, FastCollateMixup from .transforms_factory import create_transform
from .mixup import mixup_batch, FastCollateMixup
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
rand_augment_transform, auto_augment_transform rand_augment_transform, auto_augment_transform

@ -8,12 +8,11 @@ Hacked together by Ross Wightman
import random import random
import math import math
import re import re
from PIL import Image, ImageOps, ImageEnhance from PIL import Image, ImageOps, ImageEnhance, ImageChops
import PIL import PIL
import numpy as np import numpy as np
_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]]) _PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
_FILL = (128, 128, 128) _FILL = (128, 128, 128)
@ -192,36 +191,47 @@ def _translate_abs_level_to_arg(level, hparams):
return level, return level,
def _translate_rel_level_to_arg(level, _hparams): def _translate_rel_level_to_arg(level, hparams):
# range [-0.45, 0.45] # default range [-0.45, 0.45]
level = (level / _MAX_LEVEL) * 0.45 translate_pct = hparams.get('translate_pct', 0.45)
level = (level / _MAX_LEVEL) * translate_pct
level = _randomly_negate(level) level = _randomly_negate(level)
return level, return level,
def _posterize_original_level_to_arg(level, _hparams): def _posterize_level_to_arg(level, _hparams):
# As per original AutoAugment paper description # As per Tensorflow TPU EfficientNet impl
# range [4, 8], 'keep 4 up to 8 MSB of image' # range [0, 4], 'keep 0 up to 4 MSB of original image'
return int((level / _MAX_LEVEL) * 4) + 4, # intensity/severity of augmentation decreases with level
return int((level / _MAX_LEVEL) * 4),
def _posterize_research_level_to_arg(level, _hparams): def _posterize_increasing_level_to_arg(level, hparams):
# As per Tensorflow models research and UDA impl # As per Tensorflow models research and UDA impl
# range [4, 0], 'keep 4 down to 0 MSB of original image' # range [4, 0], 'keep 4 down to 0 MSB of original image',
return 4 - int((level / _MAX_LEVEL) * 4), # intensity/severity of augmentation increases with level
return 4 - _posterize_level_to_arg(level, hparams)[0],
def _posterize_tpu_level_to_arg(level, _hparams): def _posterize_original_level_to_arg(level, _hparams):
# As per Tensorflow TPU EfficientNet impl # As per original AutoAugment paper description
# range [0, 4], 'keep 0 up to 4 MSB of original image' # range [4, 8], 'keep 4 up to 8 MSB of image'
return int((level / _MAX_LEVEL) * 4), # intensity/severity of augmentation decreases with level
return int((level / _MAX_LEVEL) * 4) + 4,
def _solarize_level_to_arg(level, _hparams): def _solarize_level_to_arg(level, _hparams):
# range [0, 256] # range [0, 256]
# intensity/severity of augmentation decreases with level
return int((level / _MAX_LEVEL) * 256), return int((level / _MAX_LEVEL) * 256),
def _solarize_increasing_level_to_arg(level, _hparams):
# range [0, 256]
# intensity/severity of augmentation increases with level
return 256 - _solarize_level_to_arg(level, _hparams)[0],
def _solarize_add_level_to_arg(level, _hparams): def _solarize_add_level_to_arg(level, _hparams):
# range [0, 110] # range [0, 110]
return int((level / _MAX_LEVEL) * 110), return int((level / _MAX_LEVEL) * 110),
@ -233,10 +243,11 @@ LEVEL_TO_ARG = {
'Invert': None, 'Invert': None,
'Rotate': _rotate_level_to_arg, 'Rotate': _rotate_level_to_arg,
# There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
'Posterize': _posterize_level_to_arg,
'PosterizeIncreasing': _posterize_increasing_level_to_arg,
'PosterizeOriginal': _posterize_original_level_to_arg, 'PosterizeOriginal': _posterize_original_level_to_arg,
'PosterizeResearch': _posterize_research_level_to_arg,
'PosterizeTpu': _posterize_tpu_level_to_arg,
'Solarize': _solarize_level_to_arg, 'Solarize': _solarize_level_to_arg,
'SolarizeIncreasing': _solarize_increasing_level_to_arg,
'SolarizeAdd': _solarize_add_level_to_arg, 'SolarizeAdd': _solarize_add_level_to_arg,
'Color': _enhance_level_to_arg, 'Color': _enhance_level_to_arg,
'Contrast': _enhance_level_to_arg, 'Contrast': _enhance_level_to_arg,
@ -256,10 +267,11 @@ NAME_TO_OP = {
'Equalize': equalize, 'Equalize': equalize,
'Invert': invert, 'Invert': invert,
'Rotate': rotate, 'Rotate': rotate,
'Posterize': posterize,
'PosterizeIncreasing': posterize,
'PosterizeOriginal': posterize, 'PosterizeOriginal': posterize,
'PosterizeResearch': posterize,
'PosterizeTpu': posterize,
'Solarize': solarize, 'Solarize': solarize,
'SolarizeIncreasing': solarize,
'SolarizeAdd': solarize_add, 'SolarizeAdd': solarize_add,
'Color': color, 'Color': color,
'Contrast': contrast, 'Contrast': contrast,
@ -274,7 +286,7 @@ NAME_TO_OP = {
} }
class AutoAugmentOp: class AugmentOp:
def __init__(self, name, prob=0.5, magnitude=10, hparams=None): def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
hparams = hparams or _HPARAMS_DEFAULT hparams = hparams or _HPARAMS_DEFAULT
@ -295,12 +307,12 @@ class AutoAugmentOp:
self.magnitude_std = self.hparams.get('magnitude_std', 0) self.magnitude_std = self.hparams.get('magnitude_std', 0)
def __call__(self, img): def __call__(self, img):
if random.random() > self.prob: if self.prob < 1.0 and random.random() > self.prob:
return img return img
magnitude = self.magnitude magnitude = self.magnitude
if self.magnitude_std and self.magnitude_std > 0: if self.magnitude_std and self.magnitude_std > 0:
magnitude = random.gauss(magnitude, self.magnitude_std) magnitude = random.gauss(magnitude, self.magnitude_std)
magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple() level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
return self.aug_fn(img, *level_args, **self.kwargs) return self.aug_fn(img, *level_args, **self.kwargs)
@ -320,7 +332,7 @@ def auto_augment_policy_v0(hparams):
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)], [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)], [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
[('Color', 0.4, 7), ('Equalize', 0.6, 0)], [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
[('PosterizeTpu', 0.4, 6), ('AutoContrast', 0.4, 7)], [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
[('Solarize', 0.6, 8), ('Color', 0.6, 9)], [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)], [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
[('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)], [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
@ -330,16 +342,17 @@ def auto_augment_policy_v0(hparams):
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)], [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)], [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)], [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
[('PosterizeTpu', 0.8, 2), ('Solarize', 0.6, 10)], # This results in black image with Tpu posterize [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)], # This results in black image with Tpu posterize
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)], [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
[('Color', 0.8, 6), ('Rotate', 0.4, 5)], [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
] ]
pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
return pc return pc
def auto_augment_policy_v0r(hparams): def auto_augment_policy_v0r(hparams):
# ImageNet v0 policy from TPU EfficientNet impl, with research variation of Posterize # ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize used
# in Google research implementation (number of bits discarded increases with magnitude)
policy = [ policy = [
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)], [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
[('Color', 0.4, 9), ('Equalize', 0.6, 3)], [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
@ -353,7 +366,7 @@ def auto_augment_policy_v0r(hparams):
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)], [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)], [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
[('Color', 0.4, 7), ('Equalize', 0.6, 0)], [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
[('PosterizeResearch', 0.4, 6), ('AutoContrast', 0.4, 7)], [('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)],
[('Solarize', 0.6, 8), ('Color', 0.6, 9)], [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)], [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
[('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)], [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
@ -363,11 +376,11 @@ def auto_augment_policy_v0r(hparams):
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)], [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)], [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)], [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
[('PosterizeResearch', 0.8, 2), ('Solarize', 0.6, 10)], [('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)],
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)], [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
[('Color', 0.8, 6), ('Rotate', 0.4, 5)], [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
] ]
pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
return pc return pc
@ -400,23 +413,23 @@ def auto_augment_policy_original(hparams):
[('Color', 0.6, 4), ('Contrast', 1.0, 8)], [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
] ]
pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
return pc return pc
def auto_augment_policy_originalr(hparams): def auto_augment_policy_originalr(hparams):
# ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation # ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation
policy = [ policy = [
[('PosterizeResearch', 0.4, 8), ('Rotate', 0.6, 9)], [('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)],
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)], [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
[('PosterizeResearch', 0.6, 7), ('PosterizeResearch', 0.6, 6)], [('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)],
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)], [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)], [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)], [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
[('PosterizeResearch', 0.8, 5), ('Equalize', 1.0, 2)], [('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)],
[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)], [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
[('Equalize', 0.6, 8), ('PosterizeResearch', 0.4, 6)], [('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)],
[('Rotate', 0.8, 8), ('Color', 0.4, 0)], [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)], [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)], [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
@ -433,7 +446,7 @@ def auto_augment_policy_originalr(hparams):
[('Color', 0.6, 4), ('Contrast', 1.0, 8)], [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
] ]
pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
return pc return pc
@ -499,7 +512,7 @@ _RAND_TRANSFORMS = [
'Equalize', 'Equalize',
'Invert', 'Invert',
'Rotate', 'Rotate',
'PosterizeTpu', 'Posterize',
'Solarize', 'Solarize',
'SolarizeAdd', 'SolarizeAdd',
'Color', 'Color',
@ -530,7 +543,7 @@ _RAND_CHOICE_WEIGHTS_0 = {
'Contrast': .005, 'Contrast': .005,
'Brightness': .005, 'Brightness': .005,
'Equalize': .005, 'Equalize': .005,
'PosterizeTpu': 0, 'Posterize': 0,
'Invert': 0, 'Invert': 0,
} }
@ -547,7 +560,7 @@ def _select_rand_weights(weight_idx=0, transforms=None):
def rand_augment_ops(magnitude=10, hparams=None, transforms=None): def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
hparams = hparams or _HPARAMS_DEFAULT hparams = hparams or _HPARAMS_DEFAULT
transforms = transforms or _RAND_TRANSFORMS transforms = transforms or _RAND_TRANSFORMS
return [AutoAugmentOp( return [AugmentOp(
name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms] name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
@ -609,3 +622,94 @@ def rand_augment_transform(config_str, hparams):
ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams) ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams)
choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx) choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
# Default op names used for AugMix (a few extras beyond the paper's set are noted below)
_AUGMIX_TRANSFORMS = [
    'AutoContrast',
    'Contrast',  # not in paper
    'Brightness',  # not in paper
    'Sharpness',  # not in paper
    'Equalize',
    'Rotate',
    'PosterizeIncreasing',
    'SolarizeIncreasing',
    'ShearX',
    'ShearY',
    'TranslateXRel',
    'TranslateYRel',
]


def augmix_ops(magnitude=10, hparams=None, transforms=None):
    """Create the list of augmentation ops used by AugMix.

    Args:
        magnitude: severity/magnitude passed to every op.
        hparams: augmentation hyper-param dict; defaults to _HPARAMS_DEFAULT.
        transforms: iterable of op names; defaults to _AUGMIX_TRANSFORMS.

    Returns:
        List of AugmentOp instances, each created with prob=1.0.
    """
    hparams = hparams or _HPARAMS_DEFAULT
    transforms = transforms or _AUGMIX_TRANSFORMS
    return [AugmentOp(
        name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms]
class AugMixAugment:
    """AugMix transform: mixes several randomly-augmented versions of an image.

    From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and
    Uncertainty' - https://arxiv.org/abs/1912.02781

    Args:
        ops: list of callables (e.g. AugmentOp) mapping a PIL Image to a PIL Image.
        alpha: concentration parameter for the Dirichlet (per-chain weights) and
            Beta (clean vs. augmented blend factor) distributions.
        width: number of parallel augmentation chains mixed together.
        depth: ops per chain; <= 0 samples uniformly from [1, 3] per chain.
    """
    def __init__(self, ops, alpha=1., width=3, depth=-1):
        self.ops = ops
        self.alpha = alpha
        self.width = width
        self.depth = depth
        # recursive mode blends chains pairwise via PIL Image.blend; basic mode
        # accumulates weighted chains in a float32 numpy buffer
        self.recursive = True

    def _sample_chain(self, img):
        # Apply a randomly sampled chain of ops. No ops are in-place, so a
        # deep copy of img is not necessary.
        chain_depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
        img_aug = img
        for op in np.random.choice(self.ops, chain_depth, replace=True):
            img_aug = op(img_aug)
        return img_aug

    def _apply_recursive(self, img, ws, prod=1.):
        # Recursively blend so chain i ends up with overall weight ws[i];
        # alpha is the conditional weight of this chain given remaining mass prod.
        alpha = ws[-1] / prod
        if len(ws) > 1:
            img = self._apply_recursive(img, ws[:-1], prod * (1 - alpha))
        return Image.blend(img, self._sample_chain(img), alpha)

    def _apply_basic(self, img, ws, m):
        # Adaptation of the paper implementation: accumulate weighted augmented
        # chains in float32, then blend the mix with the clean image by factor m.
        width, height = img.size  # PIL size is (width, height)
        num_channels = len(img.getbands())
        # FIX: numpy image arrays are (height, width, channels); the previous
        # (w, h, c) buffer shape only worked for square images.
        mixed = np.zeros((height, width, num_channels), dtype=np.float32)
        for mix_weight in ws:  # FIX: renamed loop var, previously shadowed width
            img_aug = self._sample_chain(img)
            mixed += mix_weight * np.asarray(img_aug, dtype=np.float32)
        np.clip(mixed, 0, 255., out=mixed)
        mixed = Image.fromarray(mixed.astype(np.uint8))
        return Image.blend(img, mixed, m)

    def __call__(self, img):
        mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width))
        m = np.float32(np.random.beta(self.alpha, self.alpha))
        if self.recursive:
            # fold the clean/augmented blend factor into the per-chain weights
            mixing_weights *= m
            mixed = self._apply_recursive(img, mixing_weights)
        else:
            mixed = self._apply_basic(img, mixing_weights, m)
        return mixed
def augment_and_mix_transform(config_str, hparams):
    """Create an AugMix transform from a config string.

    The config string consists of 'augmix' plus dash-separated sections,
    e.g. 'augmix-m5-w4-d2'. Recognized sections:
        'm' - integer severity/magnitude of the underlying ops (default 3)
        'w' - integer width (number of mixed augmentation chains, default 3)
        'd' - integer depth of each chain; -1 samples uniformly from [1, 3] (default -1)
        'a' - float alpha for the Beta and Dirichlet distributions (default 1.)
        'mstd' - float std deviation of per-op magnitude noise (stored in hparams)

    A plain 'augmix' string yields the previous hard-coded defaults (3, 3, -1, 1.).

    Args:
        config_str: configuration string as described above.
        hparams: other hparams (kwargs) for the augmentation ops.

    Returns:
        An AugMixAugment callable transform.
    """
    magnitude = 3
    width = 3
    depth = -1
    alpha = 1.
    config = config_str.split('-')
    assert config[0] == 'augmix'
    for c in config[1:]:
        cs = re.split(r'(\d.*)', c)
        if len(cs) < 2:
            continue
        key, val = cs[:2]
        if key == 'mstd':
            # noise param injected via hparams for per-op magnitude sampling
            hparams.setdefault('magnitude_std', float(val))
        elif key == 'm':
            magnitude = int(val)
        elif key == 'w':
            width = int(val)
        elif key == 'd':
            depth = int(val)
        elif key == 'a':
            alpha = float(val)
        else:
            assert False, 'Unknown AugMix config section'
    ops = augmix_ops(magnitude=magnitude, hparams=hparams)
    return AugMixAugment(ops, alpha=alpha, width=width, depth=depth)

@ -140,3 +140,41 @@ class DatasetTar(data.Dataset):
def __len__(self): def __len__(self):
return len(self.imgs) return len(self.imgs)
class AugMixDataset(torch.utils.data.Dataset):
    """Dataset wrapper to perform AugMix or other clean/augmentation mixes"""

    def __init__(self, dataset, num_aug=2):
        self.augmentation = None
        self.normalize = None
        self.dataset = dataset
        if self.dataset.transform is not None:
            self._set_transforms(self.dataset.transform)
        self.num_aug = num_aug

    def _set_transforms(self, x):
        assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms'
        # split into: shared base transform (applied by the wrapped dataset),
        # per-sample augmentation, and final normalization
        self.dataset.transform, self.augmentation, self.normalize = x

    @property
    def transform(self):
        return self.dataset.transform

    @transform.setter
    def transform(self, x):
        self._set_transforms(x)

    def _normalize(self, x):
        # normalize may be None when the 3-way transforms haven't been set yet
        if self.normalize is None:
            return x
        return self.normalize(x)

    def __getitem__(self, i):
        x, y = self.dataset[i]  # base transform already applied by wrapped dataset
        # first element is the clean sample, followed by num_aug augmented versions
        augmented = [self._normalize(self.augmentation(x)) for _ in range(self.num_aug)]
        return (self._normalize(x), *augmented), y

    def __len__(self):
        return len(self.dataset)

@ -1,17 +1,46 @@
import torch.utils.data import torch.utils.data
from .transforms import * import numpy as np
from .transforms_factory import create_transform
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .distributed_sampler import OrderedDistributedSampler from .distributed_sampler import OrderedDistributedSampler
from .random_erasing import RandomErasing
from .mixup import FastCollateMixup from .mixup import FastCollateMixup
def fast_collate(batch): def fast_collate(batch):
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64) """ A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)"""
batch_size = len(targets) assert isinstance(batch[0], tuple)
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) batch_size = len(batch)
for i in range(batch_size): if isinstance(batch[0][0], tuple):
tensor[i] += torch.from_numpy(batch[i][0]) # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
# such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position
return tensor, targets inner_tuple_size = len(batch[0][0])
flattened_batch_size = batch_size * inner_tuple_size
targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
for i in range(batch_size):
assert len(batch[i][0]) == inner_tuple_size # all input tensor tuples must be same length
for j in range(inner_tuple_size):
targets[i + j * batch_size] = batch[i][1]
tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
return tensor, targets
elif isinstance(batch[0][0], np.ndarray):
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
assert len(targets) == batch_size
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
for i in range(batch_size):
tensor[i] += torch.from_numpy(batch[i][0])
return tensor, targets
elif isinstance(batch[0][0], torch.Tensor):
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
assert len(targets) == batch_size
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
for i in range(batch_size):
tensor[i].copy_(batch[i][0])
return tensor, targets
else:
assert False
class PrefetchLoader: class PrefetchLoader:
@ -87,49 +116,6 @@ class PrefetchLoader:
self.loader.collate_fn.mixup_enabled = x self.loader.collate_fn.mixup_enabled = x
def create_transform(
        input_size,
        is_training=False,
        use_prefetcher=False,
        color_jitter=0.4,
        auto_augment=None,
        interpolation='bilinear',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        crop_pct=None,
        tf_preprocessing=False):
    """Build a train or eval transform pipeline for the given input config.

    Dispatches to TfPreprocessTransform when both tf_preprocessing and
    use_prefetcher are set, otherwise to transforms_imagenet_train /
    transforms_imagenet_eval depending on is_training.
    """
    if isinstance(input_size, tuple):
        # keep the last two dims as the spatial size; assumes (..., H, W) — TODO confirm
        img_size = input_size[-2:]
    else:
        img_size = input_size

    if tf_preprocessing and use_prefetcher:
        # Tensorflow-style preprocessing, imported lazily to avoid a hard dependency
        from timm.data.tf_preprocessing import TfPreprocessTransform
        transform = TfPreprocessTransform(
            is_training=is_training, size=img_size, interpolation=interpolation)
    else:
        if is_training:
            transform = transforms_imagenet_train(
                img_size,
                color_jitter=color_jitter,
                auto_augment=auto_augment,
                interpolation=interpolation,
                use_prefetcher=use_prefetcher,
                mean=mean,
                std=std)
        else:
            transform = transforms_imagenet_eval(
                img_size,
                interpolation=interpolation,
                use_prefetcher=use_prefetcher,
                mean=mean,
                std=std,
                crop_pct=crop_pct)

    return transform
def create_loader( def create_loader(
dataset, dataset,
input_size, input_size,
@ -150,6 +136,7 @@ def create_loader(
collate_fn=None, collate_fn=None,
fp16=False, fp16=False,
tf_preprocessing=False, tf_preprocessing=False,
separate_transforms=False,
): ):
dataset.transform = create_transform( dataset.transform = create_transform(
input_size, input_size,
@ -162,6 +149,7 @@ def create_loader(
std=std, std=std,
crop_pct=crop_pct, crop_pct=crop_pct,
tf_preprocessing=tf_preprocessing, tf_preprocessing=tf_preprocessing,
separate=separate_transforms,
) )
sampler = None sampler = None

@ -15,6 +15,15 @@ def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
return lam*y1 + (1. - lam)*y2 return lam*y1 + (1. - lam)*y2
def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False):
    """Apply batch-wise mixup (https://arxiv.org/abs/1710.09412).

    Each sample is mixed with the sample at the mirrored batch position
    (`input.flip(0)`) using a lam drawn from Beta(alpha, alpha).

    Args:
        input: batch tensor of inputs.
        target: integer class targets; converted to mixed smoothed one-hot.
        alpha: Beta distribution concentration parameter.
        num_classes: number of classes for the one-hot target.
        smoothing: label smoothing amount passed through to mixup_target.
        disable: if True, skip mixing (lam = 1.) but still smooth/one-hot targets.

    Returns:
        Tuple of (mixed input, mixed target).
    """
    lam = 1.
    if not disable:
        lam = np.random.beta(alpha, alpha)
    # NOTE: the Tensor.add_(Number, Tensor) overload is deprecated in newer
    # PyTorch; scale the flipped copy explicitly instead.
    input = input.mul(lam).add_(input.flip(0).mul_(1. - lam))
    target = mixup_target(target, num_classes, lam, smoothing)
    return input, target
class FastCollateMixup: class FastCollateMixup:
def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000): def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000):

@ -1,5 +1,4 @@
import torch import torch
from torchvision import transforms
import torchvision.transforms.functional as F import torchvision.transforms.functional as F
from PIL import Image from PIL import Image
import warnings import warnings
@ -7,10 +6,6 @@ import math
import random import random
import numpy as np import numpy as np
from .constants import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .random_erasing import RandomErasing
from .auto_augment import auto_augment_transform, rand_augment_transform
class ToNumpy: class ToNumpy:
@ -161,97 +156,3 @@ class RandomResizedCropAndInterpolation:
return format_string return format_string
def transforms_imagenet_train(
        img_size=224,
        scale=(0.08, 1.0),
        color_jitter=0.4,
        auto_augment=None,
        interpolation='random',
        random_erasing=0.4,
        random_erasing_mode='const',
        use_prefetcher=False,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD
):
    """Build the ImageNet training transform pipeline.

    Order: random-resized crop (w/ selectable interpolation) -> horizontal flip
    -> auto-augment OR color jitter -> ToNumpy (prefetcher) or tensor conversion
    + normalization -> optional random erasing.
    """
    tfl = [
        RandomResizedCropAndInterpolation(
            img_size, scale=scale, interpolation=interpolation),
        transforms.RandomHorizontalFlip()
    ]
    if auto_augment:
        assert isinstance(auto_augment, str)
        if isinstance(img_size, tuple):
            img_size_min = min(img_size)
        else:
            img_size_min = img_size
        aa_params = dict(
            # absolute translate bound scaled to the min image dimension
            translate_const=int(img_size_min * 0.45),
            # mean converted from 0..1 floats to a 0..255 int tuple for PIL fill
            img_mean=tuple([min(255, round(255 * x)) for x in mean]),
        )
        if interpolation and interpolation != 'random':
            aa_params['interpolation'] = _pil_interp(interpolation)
        if auto_augment.startswith('rand'):
            tfl += [rand_augment_transform(auto_augment, aa_params)]
        else:
            tfl += [auto_augment_transform(auto_augment, aa_params)]
    else:
        # color jitter is enabled when not using AA
        if isinstance(color_jitter, (list, tuple)):
            # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
            # or 4 if also augmenting hue
            assert len(color_jitter) in (3, 4)
        else:
            # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
            color_jitter = (float(color_jitter),) * 3
        tfl += [transforms.ColorJitter(*color_jitter)]

    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        tfl += [ToNumpy()]
    else:
        tfl += [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std))
        ]
    if random_erasing > 0.:
        # random erasing runs on the tensor/array result, on CPU
        tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu'))
    return transforms.Compose(tfl)
def transforms_imagenet_eval(
        img_size=224,
        crop_pct=None,
        interpolation='bilinear',
        use_prefetcher=False,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD):
    """Build the ImageNet eval pipeline: resize by 1/crop_pct, center crop to
    img_size, then ToNumpy (prefetcher) or tensor conversion + normalization."""
    crop_pct = crop_pct or DEFAULT_CROP_PCT

    if isinstance(img_size, tuple):
        assert len(img_size) == 2
        if img_size[-1] == img_size[-2]:
            # fall-back to older behaviour so Resize scales to shortest edge if target is square
            scale_size = int(math.floor(img_size[0] / crop_pct))
        else:
            scale_size = tuple([int(x / crop_pct) for x in img_size])
    else:
        scale_size = int(math.floor(img_size / crop_pct))

    tfl = [
        transforms.Resize(scale_size, _pil_interp(interpolation)),
        transforms.CenterCrop(img_size),
    ]
    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        tfl += [ToNumpy()]
    else:
        tfl += [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std))
        ]

    return transforms.Compose(tfl)

@ -0,0 +1,164 @@
import math
import torch
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor
from timm.data.random_erasing import RandomErasing
def transforms_imagenet_train(
        img_size=224,
        scale=(0.08, 1.0),
        color_jitter=0.4,
        auto_augment=None,
        interpolation='random',
        random_erasing=0.4,
        random_erasing_mode='const',
        use_prefetcher=False,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        separate=False,
):
    """Build the ImageNet training transform pipeline.

    The pipeline is built in three stages: primary (crop + flip), secondary
    (auto-augment / augmix / color jitter), and final (tensor conversion,
    normalization, random erasing). When `separate` is True the three stages
    are returned as a tuple of three Compose objects (used e.g. by
    AugMixDataset) instead of a single Compose.
    """
    primary_tfl = [
        RandomResizedCropAndInterpolation(
            img_size, scale=scale, interpolation=interpolation),
        transforms.RandomHorizontalFlip()
    ]

    secondary_tfl = []
    if auto_augment:
        assert isinstance(auto_augment, str)
        if isinstance(img_size, tuple):
            img_size_min = min(img_size)
        else:
            img_size_min = img_size
        aa_params = dict(
            # absolute translate bound scaled to the min image dimension
            translate_const=int(img_size_min * 0.45),
            # mean converted from 0..1 floats to a 0..255 int tuple for PIL fill
            img_mean=tuple([min(255, round(255 * x)) for x in mean]),
        )
        if interpolation and interpolation != 'random':
            aa_params['interpolation'] = _pil_interp(interpolation)
        if auto_augment.startswith('rand'):
            secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
        elif auto_augment.startswith('augmix'):
            # augmix uses relative translate with a smaller bound than AA default
            aa_params['translate_pct'] = 0.3
            secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)]
        else:
            secondary_tfl += [auto_augment_transform(auto_augment, aa_params)]
    elif color_jitter is not None:
        # color jitter is enabled when not using AA
        if isinstance(color_jitter, (list, tuple)):
            # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
            # or 4 if also augmenting hue
            assert len(color_jitter) in (3, 4)
        else:
            # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
            color_jitter = (float(color_jitter),) * 3
        secondary_tfl += [transforms.ColorJitter(*color_jitter)]

    final_tfl = []
    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        final_tfl += [ToNumpy()]
    else:
        final_tfl += [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std))
        ]
        if random_erasing > 0.:
            # random erasing runs on the normalized tensor, on CPU
            final_tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu'))

    if separate:
        return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)
    else:
        return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
def transforms_imagenet_eval(
        img_size=224,
        crop_pct=None,
        interpolation='bilinear',
        use_prefetcher=False,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD):
    """Build the standard ImageNet validation pipeline.

    Resizes to img_size / crop_pct then center-crops to img_size; output is a
    uint8 numpy array in prefetcher mode, otherwise a normalized float tensor.
    """
    crop_pct = crop_pct or DEFAULT_CROP_PCT

    if not isinstance(img_size, tuple):
        scale_size = int(math.floor(img_size / crop_pct))
    else:
        assert len(img_size) == 2
        if img_size[-1] != img_size[-2]:
            scale_size = tuple([int(x / crop_pct) for x in img_size])
        else:
            # fall-back to older behaviour so Resize scales to shortest edge if target is square
            scale_size = int(math.floor(img_size[0] / crop_pct))

    pipeline = [
        transforms.Resize(scale_size, _pil_interp(interpolation)),
        transforms.CenterCrop(img_size),
    ]
    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        pipeline.append(ToNumpy())
    else:
        pipeline.extend([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std)),
        ])

    return transforms.Compose(pipeline)
def create_transform(
        input_size,
        is_training=False,
        use_prefetcher=False,
        color_jitter=0.4,
        auto_augment=None,
        interpolation='bilinear',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        crop_pct=None,
        tf_preprocessing=False,
        separate=False):
    """Factory for train/eval transform pipelines.

    Dispatches to TfPreprocessTransform when both tf_preprocessing and
    use_prefetcher are set, otherwise to transforms_imagenet_train /
    transforms_imagenet_eval. `separate` (3-stage tuple output) is only
    supported by the training path.
    """
    img_size = input_size[-2:] if isinstance(input_size, tuple) else input_size

    if tf_preprocessing and use_prefetcher:
        assert not separate, "Separate transforms not supported for TF preprocessing"
        # Tensorflow-style preprocessing, imported lazily to avoid a hard dependency
        from timm.data.tf_preprocessing import TfPreprocessTransform
        return TfPreprocessTransform(
            is_training=is_training, size=img_size, interpolation=interpolation)

    if is_training:
        return transforms_imagenet_train(
            img_size,
            color_jitter=color_jitter,
            auto_augment=auto_augment,
            interpolation=interpolation,
            use_prefetcher=use_prefetcher,
            mean=mean,
            std=std,
            separate=separate)

    assert not separate, "Separate transforms not supported for validation preprocessing"
    return transforms_imagenet_eval(
        img_size,
        interpolation=interpolation,
        use_prefetcher=use_prefetcher,
        mean=mean,
        std=std,
        crop_pct=crop_pct)

@ -1 +1,2 @@
from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from .jsd import JsdCrossEntropy

@ -0,0 +1,34 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .cross_entropy import LabelSmoothingCrossEntropy
class JsdCrossEntropy(nn.Module):
    """Jensen-Shannon Divergence + Cross-Entropy Loss (for AugMix training).

    Expects `output` to contain `num_splits` equal-size groups of logits
    concatenated along the batch dim, the first group being the clean images.
    Cross-entropy is computed on the clean split only; a Jensen-Shannon
    consistency term over all splits is added, weighted by `alpha`.

    Args:
        num_splits: number of concatenated groups in the batch (clean + augmented).
        alpha: weight of the JSD consistency term.
        smoothing: label smoothing amount; None or 0 uses plain CrossEntropyLoss.
    """
    def __init__(self, num_splits=3, alpha=12, smoothing=0.1):
        super().__init__()
        self.num_splits = num_splits
        self.alpha = alpha
        if smoothing is not None and smoothing > 0:
            self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing)
        else:
            self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    # NOTE: defining forward (previously __call__) keeps nn.Module's hook
    # machinery intact; callers invoking the module instance are unaffected.
    def forward(self, output, target):
        split_size = output.shape[0] // self.num_splits
        assert split_size * self.num_splits == output.shape[0]
        logits_split = torch.split(output, split_size)

        # Cross-entropy is only computed on clean images
        loss = self.cross_entropy_loss(logits_split[0], target[:split_size])
        probs = [F.softmax(logits, dim=1) for logits in logits_split]

        # Clamp mixture distribution to avoid exploding KL divergence
        logp_mixture = torch.clamp(torch.stack(probs).mean(dim=0), 1e-7, 1).log()
        loss += self.alpha * sum([F.kl_div(
            logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs)
        return loss

@ -1,7 +1,6 @@
import argparse import argparse
import time import time
import logging
import yaml import yaml
from datetime import datetime from datetime import datetime
@ -14,13 +13,16 @@ except ImportError:
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
has_apex = False has_apex = False
from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_batch
from timm.models import create_model, resume_checkpoint from timm.models import create_model, resume_checkpoint
from timm.utils import * from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer from timm.optim import create_optimizer
from timm.scheduler import create_scheduler from timm.scheduler import create_scheduler
#FIXME
from timm.data.dataset import AugMixDataset
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchvision.utils import torchvision.utils
@ -160,6 +162,10 @@ parser.add_argument('--tta', type=int, default=0, metavar='N',
parser.add_argument("--local_rank", default=0, type=int) parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--jsd', action='store_true', default=False,
help='')
def _parse_args(): def _parse_args():
# Do we have a config file to parse? # Do we have a config file to parse?
args_config, remaining = config_parser.parse_known_args() args_config, remaining = config_parser.parse_known_args()
@ -311,8 +317,14 @@ def main():
collate_fn = None collate_fn = None
if args.prefetcher and args.mixup > 0: if args.prefetcher and args.mixup > 0:
assert not args.jsd
collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes) collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)
separate_transforms = False
if args.jsd:
dataset_train = AugMixDataset(dataset_train)
separate_transforms = True
loader_train = create_loader( loader_train = create_loader(
dataset_train, dataset_train,
input_size=data_config['input_size'], input_size=data_config['input_size'],
@ -330,6 +342,7 @@ def main():
num_workers=args.workers, num_workers=args.workers,
distributed=args.distributed, distributed=args.distributed,
collate_fn=collate_fn, collate_fn=collate_fn,
separate_transforms=separate_transforms,
) )
eval_dir = os.path.join(args.data, 'val') eval_dir = os.path.join(args.data, 'val')
@ -354,7 +367,10 @@ def main():
crop_pct=data_config['crop_pct'], crop_pct=data_config['crop_pct'],
) )
if args.mixup > 0.: if args.jsd:
train_loss_fn = JsdCrossEntropy(smoothing=args.smoothing).cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
elif args.mixup > 0.:
# smoothing is handled with mixup label transform # smoothing is handled with mixup label transform
train_loss_fn = SoftTargetCrossEntropy().cuda() train_loss_fn = SoftTargetCrossEntropy().cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda()
@ -452,11 +468,10 @@ def train_epoch(
if not args.prefetcher: if not args.prefetcher:
input, target = input.cuda(), target.cuda() input, target = input.cuda(), target.cuda()
if args.mixup > 0.: if args.mixup > 0.:
lam = 1. input, target = mixup_batch(
if not args.mixup_off_epoch or epoch < args.mixup_off_epoch: input, target,
lam = np.random.beta(args.mixup, args.mixup) alpha=args.mixup, num_classes=args.num_classes, smoothing=args.smoothing,
input = input.mul(lam).add_(1 - lam, input.flip(0)) disable=args.mixup_off_epoch and epoch >= args.mixup_off_epoch)
target = mixup_target(target, args.num_classes, lam, args.smoothing)
output = model(input) output = model(input)

Loading…
Cancel
Save