Working on an implementation of AugMix with JensenShannonDivergence loss that's compatible with my AutoAugment and RandAugment impl

pull/74/head
Ross Wightman 4 years ago
parent ff8688ca3d
commit 232ab7fb12

@ -2,7 +2,8 @@ from .constants import *
from .config import resolve_data_config
from .dataset import Dataset, DatasetTar
from .transforms import *
from .loader import create_loader, create_transform
from .mixup import mixup_target, FastCollateMixup
from .loader import create_loader
from .transforms_factory import create_transform
from .mixup import mixup_batch, FastCollateMixup
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
rand_augment_transform, auto_augment_transform

@ -8,12 +8,11 @@ Hacked together by Ross Wightman
import random
import math
import re
from PIL import Image, ImageOps, ImageEnhance
from PIL import Image, ImageOps, ImageEnhance, ImageChops
import PIL
import numpy as np
_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
_FILL = (128, 128, 128)
@ -192,36 +191,47 @@ def _translate_abs_level_to_arg(level, hparams):
return level,
def _translate_rel_level_to_arg(level, _hparams):
# range [-0.45, 0.45]
level = (level / _MAX_LEVEL) * 0.45
def _translate_rel_level_to_arg(level, hparams):
# default range [-0.45, 0.45]
translate_pct = hparams.get('translate_pct', 0.45)
level = (level / _MAX_LEVEL) * translate_pct
level = _randomly_negate(level)
return level,
def _posterize_original_level_to_arg(level, _hparams):
# As per original AutoAugment paper description
# range [4, 8], 'keep 4 up to 8 MSB of image'
return int((level / _MAX_LEVEL) * 4) + 4,
def _posterize_level_to_arg(level, _hparams):
# As per Tensorflow TPU EfficientNet impl
# range [0, 4], 'keep 0 up to 4 MSB of original image'
# intensity/severity of augmentation decreases with level
return int((level / _MAX_LEVEL) * 4),
def _posterize_research_level_to_arg(level, _hparams):
def _posterize_increasing_level_to_arg(level, hparams):
# As per Tensorflow models research and UDA impl
# range [4, 0], 'keep 4 down to 0 MSB of original image'
return 4 - int((level / _MAX_LEVEL) * 4),
# range [4, 0], 'keep 4 down to 0 MSB of original image',
# intensity/severity of augmentation increases with level
return 4 - _posterize_level_to_arg(level, hparams)[0],
def _posterize_tpu_level_to_arg(level, _hparams):
# As per Tensorflow TPU EfficientNet impl
# range [0, 4], 'keep 0 up to 4 MSB of original image'
return int((level / _MAX_LEVEL) * 4),
def _posterize_original_level_to_arg(level, _hparams):
# As per original AutoAugment paper description
# range [4, 8], 'keep 4 up to 8 MSB of image'
# intensity/severity of augmentation decreases with level
return int((level / _MAX_LEVEL) * 4) + 4,
def _solarize_level_to_arg(level, _hparams):
# range [0, 256]
# intensity/severity of augmentation decreases with level
return int((level / _MAX_LEVEL) * 256),
def _solarize_increasing_level_to_arg(level, _hparams):
# range [0, 256]
# intensity/severity of augmentation increases with level
return 256 - _solarize_level_to_arg(level, _hparams)[0],
def _solarize_add_level_to_arg(level, _hparams):
# range [0, 110]
return int((level / _MAX_LEVEL) * 110),
@ -233,10 +243,11 @@ LEVEL_TO_ARG = {
'Invert': None,
'Rotate': _rotate_level_to_arg,
# There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
'Posterize': _posterize_level_to_arg,
'PosterizeIncreasing': _posterize_increasing_level_to_arg,
'PosterizeOriginal': _posterize_original_level_to_arg,
'PosterizeResearch': _posterize_research_level_to_arg,
'PosterizeTpu': _posterize_tpu_level_to_arg,
'Solarize': _solarize_level_to_arg,
'SolarizeIncreasing': _solarize_level_to_arg,
'SolarizeAdd': _solarize_add_level_to_arg,
'Color': _enhance_level_to_arg,
'Contrast': _enhance_level_to_arg,
@ -256,10 +267,11 @@ NAME_TO_OP = {
'Equalize': equalize,
'Invert': invert,
'Rotate': rotate,
'Posterize': posterize,
'PosterizeIncreasing': posterize,
'PosterizeOriginal': posterize,
'PosterizeResearch': posterize,
'PosterizeTpu': posterize,
'Solarize': solarize,
'SolarizeIncreasing': solarize,
'SolarizeAdd': solarize_add,
'Color': color,
'Contrast': contrast,
@ -274,7 +286,7 @@ NAME_TO_OP = {
}
class AutoAugmentOp:
class AugmentOp:
def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
hparams = hparams or _HPARAMS_DEFAULT
@ -295,12 +307,12 @@ class AutoAugmentOp:
self.magnitude_std = self.hparams.get('magnitude_std', 0)
def __call__(self, img):
if random.random() > self.prob:
if not self.prob >= 1.0 or random.random() > self.prob:
return img
magnitude = self.magnitude
if self.magnitude_std and self.magnitude_std > 0:
magnitude = random.gauss(magnitude, self.magnitude_std)
magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
return self.aug_fn(img, *level_args, **self.kwargs)
@ -320,7 +332,7 @@ def auto_augment_policy_v0(hparams):
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
[('Color', 0.4, 7), ('Equalize', 0.6, 0)],
[('PosterizeTpu', 0.4, 6), ('AutoContrast', 0.4, 7)],
[('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
[('Solarize', 0.6, 8), ('Color', 0.6, 9)],
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
[('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
@ -330,16 +342,17 @@ def auto_augment_policy_v0(hparams):
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
[('PosterizeTpu', 0.8, 2), ('Solarize', 0.6, 10)], # This results in black image with Tpu posterize
[('Posterize', 0.8, 2), ('Solarize', 0.6, 10)], # This results in black image with Tpu posterize
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
[('Color', 0.8, 6), ('Rotate', 0.4, 5)],
]
pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
return pc
def auto_augment_policy_v0r(hparams):
# ImageNet v0 policy from TPU EfficientNet impl, with research variation of Posterize
# ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize used
# in Google research implementation (number of bits discarded increases with magnitude)
policy = [
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
@ -353,7 +366,7 @@ def auto_augment_policy_v0r(hparams):
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
[('Color', 0.4, 7), ('Equalize', 0.6, 0)],
[('PosterizeResearch', 0.4, 6), ('AutoContrast', 0.4, 7)],
[('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)],
[('Solarize', 0.6, 8), ('Color', 0.6, 9)],
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
[('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
@ -363,11 +376,11 @@ def auto_augment_policy_v0r(hparams):
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
[('PosterizeResearch', 0.8, 2), ('Solarize', 0.6, 10)],
[('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)],
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
[('Color', 0.8, 6), ('Rotate', 0.4, 5)],
]
pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
return pc
@ -400,23 +413,23 @@ def auto_augment_policy_original(hparams):
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
]
pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
return pc
def auto_augment_policy_originalr(hparams):
# ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation
policy = [
[('PosterizeResearch', 0.4, 8), ('Rotate', 0.6, 9)],
[('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)],
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
[('PosterizeResearch', 0.6, 7), ('PosterizeResearch', 0.6, 6)],
[('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)],
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
[('PosterizeResearch', 0.8, 5), ('Equalize', 1.0, 2)],
[('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)],
[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
[('Equalize', 0.6, 8), ('PosterizeResearch', 0.4, 6)],
[('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)],
[('Rotate', 0.8, 8), ('Color', 0.4, 0)],
[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
@ -433,7 +446,7 @@ def auto_augment_policy_originalr(hparams):
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
]
pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
return pc
@ -499,7 +512,7 @@ _RAND_TRANSFORMS = [
'Equalize',
'Invert',
'Rotate',
'PosterizeTpu',
'Posterize',
'Solarize',
'SolarizeAdd',
'Color',
@ -530,7 +543,7 @@ _RAND_CHOICE_WEIGHTS_0 = {
'Contrast': .005,
'Brightness': .005,
'Equalize': .005,
'PosterizeTpu': 0,
'Posterize': 0,
'Invert': 0,
}
@ -547,7 +560,7 @@ def _select_rand_weights(weight_idx=0, transforms=None):
def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
hparams = hparams or _HPARAMS_DEFAULT
transforms = transforms or _RAND_TRANSFORMS
return [AutoAugmentOp(
return [AugmentOp(
name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
@ -609,3 +622,94 @@ def rand_augment_transform(config_str, hparams):
ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams)
choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
_AUGMIX_TRANSFORMS = [
'AutoContrast',
'Contrast', # not in paper
'Brightness', # not in paper
'Sharpness', # not in paper
'Equalize',
'Rotate',
'PosterizeIncreasing',
'SolarizeIncreasing',
'ShearX',
'ShearY',
'TranslateXRel',
'TranslateYRel',
]
def augmix_ops(magnitude=10, hparams=None, transforms=None):
hparams = hparams or _HPARAMS_DEFAULT
transforms = transforms or _AUGMIX_TRANSFORMS
return [AugmentOp(
name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms]
class AugMixAugment:
def __init__(self, ops, alpha=1., width=3, depth=-1):
self.ops = ops
self.alpha = alpha
self.width = width
self.depth = depth
self.recursive = True
def _apply_recursive(self, img, ws, prod=1.):
alpha = ws[-1] / prod
if len(ws) > 1:
img = self._apply_recursive(img, ws[:-1], prod * (1 - alpha))
depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
ops = np.random.choice(self.ops, depth, replace=True)
img_aug = img # no ops are in-place, deep copy not necessary
for op in ops:
img_aug = op(img_aug)
return Image.blend(img, img_aug, alpha)
def _apply_basic(self, img, ws, m):
w, h = img.size
c = len(img.getbands())
mixed = np.zeros((w, h, c), dtype=np.float32)
for w in ws:
depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
ops = np.random.choice(self.ops, depth, replace=True)
img_aug = img # no ops are in-place, deep copy not necessary
for op in ops:
img_aug = op(img_aug)
img_aug = np.asarray(img_aug, dtype=np.float32)
mixed += w * img_aug
np.clip(mixed, 0, 255., out=mixed)
mixed = Image.fromarray(mixed.astype(np.uint8))
return Image.blend(img, mixed, m)
def __call__(self, img):
mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width))
m = np.float32(np.random.beta(self.alpha, self.alpha))
if self.recursive:
mixing_weights *= m
mixed = self._apply_recursive(img, mixing_weights)
else:
mixed = self._apply_basic(img, mixing_weights, m)
return mixed
def augment_and_mix_transform(config_str, hparams):
"""Perform AugMix augmentations and compute mixture.
Args:
image: Raw input image as float32 np.ndarray of shape (h, w, c)
severity: Severity of underlying augmentation operators (between 1 to 10).
width: Width of augmentation chain
depth: Depth of augmentation chain. -1 enables stochastic depth uniformly
from [1, 3]
alpha: Probability coefficient for Beta and Dirichlet distributions.
Returns:
mixed: Augmented and mixed image.
"""
# FIXME parse args from config str
severity = 3
width = 3
depth = -1
alpha = 1.
ops = augmix_ops(magnitude=severity, hparams=hparams)
return AugMixAugment(ops, alpha, width, depth)

@ -140,3 +140,41 @@ class DatasetTar(data.Dataset):
def __len__(self):
return len(self.imgs)
class AugMixDataset(torch.utils.data.Dataset):
"""Dataset wrapper to perform AugMix or other clean/augmentation mixes"""
def __init__(self, dataset, num_aug=2):
self.augmentation = None
self.normalize = None
self.dataset = dataset
if self.dataset.transform is not None:
self._set_transforms(self.dataset.transform)
self.num_aug = num_aug
def _set_transforms(self, x):
assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms'
self.dataset.transform = x[0]
self.augmentation = x[1]
self.normalize = x[2]
@property
def transform(self):
return self.dataset.transform
@transform.setter
def transform(self, x):
self._set_transforms(x)
def _normalize(self, x):
return x if self.normalize is None else self.normalize(x)
def __getitem__(self, i):
x, y = self.dataset[i]
x_list = [self._normalize(x)]
for n in range(self.num_aug):
x_list.append(self._normalize(self.augmentation(x)))
return tuple(x_list), y
def __len__(self):
return len(self.dataset)

@ -1,17 +1,46 @@
import torch.utils.data
from .transforms import *
import numpy as np
from .transforms_factory import create_transform
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .distributed_sampler import OrderedDistributedSampler
from .random_erasing import RandomErasing
from .mixup import FastCollateMixup
def fast_collate(batch):
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
batch_size = len(targets)
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
for i in range(batch_size):
tensor[i] += torch.from_numpy(batch[i][0])
return tensor, targets
""" A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)"""
assert isinstance(batch[0], tuple)
batch_size = len(batch)
if isinstance(batch[0][0], tuple):
# This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
# such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position
inner_tuple_size = len(batch[0][0][0])
flattened_batch_size = batch_size * inner_tuple_size
targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
for i in range(batch_size):
assert len(batch[i][0]) == inner_tuple_size # all input tensor tuples must be same length
for j in range(inner_tuple_size):
targets[i + j * batch_size] = batch[i][1]
tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
return tensor, targets
elif isinstance(batch[0][0], np.ndarray):
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
assert len(targets) == batch_size
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
for i in range(batch_size):
tensor[i] += torch.from_numpy(batch[i][0])
return tensor, targets
elif isinstance(batch[0][0], torch.Tensor):
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
assert len(targets) == batch_size
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
for i in range(batch_size):
tensor[i].copy_(batch[i][0])
return tensor, targets
else:
assert False
class PrefetchLoader:
@ -87,49 +116,6 @@ class PrefetchLoader:
self.loader.collate_fn.mixup_enabled = x
def create_transform(
input_size,
is_training=False,
use_prefetcher=False,
color_jitter=0.4,
auto_augment=None,
interpolation='bilinear',
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
crop_pct=None,
tf_preprocessing=False):
if isinstance(input_size, tuple):
img_size = input_size[-2:]
else:
img_size = input_size
if tf_preprocessing and use_prefetcher:
from timm.data.tf_preprocessing import TfPreprocessTransform
transform = TfPreprocessTransform(
is_training=is_training, size=img_size, interpolation=interpolation)
else:
if is_training:
transform = transforms_imagenet_train(
img_size,
color_jitter=color_jitter,
auto_augment=auto_augment,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
mean=mean,
std=std)
else:
transform = transforms_imagenet_eval(
img_size,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
mean=mean,
std=std,
crop_pct=crop_pct)
return transform
def create_loader(
dataset,
input_size,
@ -150,6 +136,7 @@ def create_loader(
collate_fn=None,
fp16=False,
tf_preprocessing=False,
separate_transforms=False,
):
dataset.transform = create_transform(
input_size,
@ -162,6 +149,7 @@ def create_loader(
std=std,
crop_pct=crop_pct,
tf_preprocessing=tf_preprocessing,
separate=separate_transforms,
)
sampler = None

@ -15,6 +15,15 @@ def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
return lam*y1 + (1. - lam)*y2
def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False):
lam = 1.
if not disable:
lam = np.random.beta(alpha, alpha)
input = input.mul(lam).add_(1 - lam, input.flip(0))
target = mixup_target(target, num_classes, lam, smoothing)
return input, target
class FastCollateMixup:
def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000):

@ -1,5 +1,4 @@
import torch
from torchvision import transforms
import torchvision.transforms.functional as F
from PIL import Image
import warnings
@ -7,10 +6,6 @@ import math
import random
import numpy as np
from .constants import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .random_erasing import RandomErasing
from .auto_augment import auto_augment_transform, rand_augment_transform
class ToNumpy:
@ -161,97 +156,3 @@ class RandomResizedCropAndInterpolation:
return format_string
def transforms_imagenet_train(
img_size=224,
scale=(0.08, 1.0),
color_jitter=0.4,
auto_augment=None,
interpolation='random',
random_erasing=0.4,
random_erasing_mode='const',
use_prefetcher=False,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD
):
tfl = [
RandomResizedCropAndInterpolation(
img_size, scale=scale, interpolation=interpolation),
transforms.RandomHorizontalFlip()
]
if auto_augment:
assert isinstance(auto_augment, str)
if isinstance(img_size, tuple):
img_size_min = min(img_size)
else:
img_size_min = img_size
aa_params = dict(
translate_const=int(img_size_min * 0.45),
img_mean=tuple([min(255, round(255 * x)) for x in mean]),
)
if interpolation and interpolation != 'random':
aa_params['interpolation'] = _pil_interp(interpolation)
if auto_augment.startswith('rand'):
tfl += [rand_augment_transform(auto_augment, aa_params)]
else:
tfl += [auto_augment_transform(auto_augment, aa_params)]
else:
# color jitter is enabled when not using AA
if isinstance(color_jitter, (list, tuple)):
# color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
# or 4 if also augmenting hue
assert len(color_jitter) in (3, 4)
else:
# if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
color_jitter = (float(color_jitter),) * 3
tfl += [transforms.ColorJitter(*color_jitter)]
if use_prefetcher:
# prefetcher and collate will handle tensor conversion and norm
tfl += [ToNumpy()]
else:
tfl += [
transforms.ToTensor(),
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std))
]
if random_erasing > 0.:
tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu'))
return transforms.Compose(tfl)
def transforms_imagenet_eval(
img_size=224,
crop_pct=None,
interpolation='bilinear',
use_prefetcher=False,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD):
crop_pct = crop_pct or DEFAULT_CROP_PCT
if isinstance(img_size, tuple):
assert len(img_size) == 2
if img_size[-1] == img_size[-2]:
# fall-back to older behaviour so Resize scales to shortest edge if target is square
scale_size = int(math.floor(img_size[0] / crop_pct))
else:
scale_size = tuple([int(x / crop_pct) for x in img_size])
else:
scale_size = int(math.floor(img_size / crop_pct))
tfl = [
transforms.Resize(scale_size, _pil_interp(interpolation)),
transforms.CenterCrop(img_size),
]
if use_prefetcher:
# prefetcher and collate will handle tensor conversion and norm
tfl += [ToNumpy()]
else:
tfl += [
transforms.ToTensor(),
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std))
]
return transforms.Compose(tfl)

@ -0,0 +1,164 @@
import math
import torch
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor
from timm.data.random_erasing import RandomErasing
def transforms_imagenet_train(
img_size=224,
scale=(0.08, 1.0),
color_jitter=0.4,
auto_augment=None,
interpolation='random',
random_erasing=0.4,
random_erasing_mode='const',
use_prefetcher=False,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
separate=False,
):
primary_tfl = [
RandomResizedCropAndInterpolation(
img_size, scale=scale, interpolation=interpolation),
transforms.RandomHorizontalFlip()
]
secondary_tfl = []
if auto_augment:
assert isinstance(auto_augment, str)
if isinstance(img_size, tuple):
img_size_min = min(img_size)
else:
img_size_min = img_size
aa_params = dict(
translate_const=int(img_size_min * 0.45),
img_mean=tuple([min(255, round(255 * x)) for x in mean]),
)
if interpolation and interpolation != 'random':
aa_params['interpolation'] = _pil_interp(interpolation)
if auto_augment.startswith('rand'):
secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
elif auto_augment.startswith('augmix'):
aa_params['translate_pct'] = 0.3
secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)]
else:
secondary_tfl += [auto_augment_transform(auto_augment, aa_params)]
elif color_jitter is not None:
# color jitter is enabled when not using AA
if isinstance(color_jitter, (list, tuple)):
# color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
# or 4 if also augmenting hue
assert len(color_jitter) in (3, 4)
else:
# if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
color_jitter = (float(color_jitter),) * 3
secondary_tfl += [transforms.ColorJitter(*color_jitter)]
final_tfl = []
if use_prefetcher:
# prefetcher and collate will handle tensor conversion and norm
final_tfl += [ToNumpy()]
else:
final_tfl += [
transforms.ToTensor(),
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std))
]
if random_erasing > 0.:
final_tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu'))
if separate:
return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)
else:
return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
def transforms_imagenet_eval(
img_size=224,
crop_pct=None,
interpolation='bilinear',
use_prefetcher=False,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD):
crop_pct = crop_pct or DEFAULT_CROP_PCT
if isinstance(img_size, tuple):
assert len(img_size) == 2
if img_size[-1] == img_size[-2]:
# fall-back to older behaviour so Resize scales to shortest edge if target is square
scale_size = int(math.floor(img_size[0] / crop_pct))
else:
scale_size = tuple([int(x / crop_pct) for x in img_size])
else:
scale_size = int(math.floor(img_size / crop_pct))
tfl = [
transforms.Resize(scale_size, _pil_interp(interpolation)),
transforms.CenterCrop(img_size),
]
if use_prefetcher:
# prefetcher and collate will handle tensor conversion and norm
tfl += [ToNumpy()]
else:
tfl += [
transforms.ToTensor(),
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std))
]
return transforms.Compose(tfl)
def create_transform(
input_size,
is_training=False,
use_prefetcher=False,
color_jitter=0.4,
auto_augment=None,
interpolation='bilinear',
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
crop_pct=None,
tf_preprocessing=False,
separate=False):
if isinstance(input_size, tuple):
img_size = input_size[-2:]
else:
img_size = input_size
if tf_preprocessing and use_prefetcher:
assert not separate, "Separate transforms not supported for TF preprocessing"
from timm.data.tf_preprocessing import TfPreprocessTransform
transform = TfPreprocessTransform(
is_training=is_training, size=img_size, interpolation=interpolation)
else:
if is_training:
transform = transforms_imagenet_train(
img_size,
color_jitter=color_jitter,
auto_augment=auto_augment,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
mean=mean,
std=std,
separate=separate)
else:
assert not separate, "Separate transforms not supported for validation preprocessing"
transform = transforms_imagenet_eval(
img_size,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
mean=mean,
std=std,
crop_pct=crop_pct)
return transform

@ -1 +1,2 @@
from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from .jsd import JsdCrossEntropy

@ -0,0 +1,34 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .cross_entropy import LabelSmoothingCrossEntropy
class JsdCrossEntropy(nn.Module):
""" Jenson-Shannon Divergence + Cross-Entropy Loss
"""
def __init__(self, num_splits=3, alpha=12, smoothing=0.1):
super().__init__()
self.num_splits = num_splits
self.alpha = alpha
if smoothing is not None and smoothing > 0:
self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing)
else:
self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
def __call__(self, output, target):
split_size = output.shape[0] // self.num_splits
assert split_size * self.num_splits == output.shape[0]
logits_split = torch.split(output, split_size)
# Cross-entropy is only computed on clean images
loss = self.cross_entropy_loss(logits_split[0], target[:split_size])
probs = [F.softmax(logits, dim=1) for logits in logits_split]
# Clamp mixture distribution to avoid exploding KL divergence
logp_mixture = torch.clamp(torch.stack(probs).mean(axis=0), 1e-7, 1).log()
loss += self.alpha * sum([F.kl_div(
logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs)
return loss

@ -1,7 +1,6 @@
import argparse
import time
import logging
import yaml
from datetime import datetime
@ -14,13 +13,16 @@ except ImportError:
from torch.nn.parallel import DistributedDataParallel as DDP
has_apex = False
from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_batch
from timm.models import create_model, resume_checkpoint
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
#FIXME
from timm.data.dataset import AugMixDataset
import torch
import torch.nn as nn
import torchvision.utils
@ -160,6 +162,10 @@ parser.add_argument('--tta', type=int, default=0, metavar='N',
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--jsd', action='store_true', default=False,
help='')
def _parse_args():
# Do we have a config file to parse?
args_config, remaining = config_parser.parse_known_args()
@ -311,8 +317,14 @@ def main():
collate_fn = None
if args.prefetcher and args.mixup > 0:
assert not args.jsd
collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)
separate_transforms = False
if args.jsd:
dataset_train = AugMixDataset(dataset_train)
separate_transforms = True
loader_train = create_loader(
dataset_train,
input_size=data_config['input_size'],
@ -330,6 +342,7 @@ def main():
num_workers=args.workers,
distributed=args.distributed,
collate_fn=collate_fn,
separate_transforms=separate_transforms,
)
eval_dir = os.path.join(args.data, 'val')
@ -354,7 +367,10 @@ def main():
crop_pct=data_config['crop_pct'],
)
if args.mixup > 0.:
if args.jsd:
train_loss_fn = JsdCrossEntropy(smoothing=args.smoothing).cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
elif args.mixup > 0.:
# smoothing is handled with mixup label transform
train_loss_fn = SoftTargetCrossEntropy().cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
@ -452,11 +468,10 @@ def train_epoch(
if not args.prefetcher:
input, target = input.cuda(), target.cuda()
if args.mixup > 0.:
lam = 1.
if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
lam = np.random.beta(args.mixup, args.mixup)
input = input.mul(lam).add_(1 - lam, input.flip(0))
target = mixup_target(target, args.num_classes, lam, args.smoothing)
input, target = mixup_batch(
input, target,
alpha=args.mixup, num_classes=args.num_classes, smoothing=args.smoothing,
disable=args.mixup_off_epoch and epoch >= args.mixup_off_epoch)
output = model(input)

Loading…
Cancel
Save