From 232ab7fb12ba082e6d4039c7a7c7f2701caa0a71 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 19 Dec 2019 18:16:18 -0800 Subject: [PATCH] Working on an implementation of AugMix with JensenShannonDivergence loss that's compatible with my AutoAugment and RandAugment impl --- timm/data/__init__.py | 5 +- timm/data/auto_augment.py | 182 +++++++++++++++++++++++++------- timm/data/dataset.py | 38 +++++++ timm/data/loader.py | 90 +++++++--------- timm/data/mixup.py | 9 ++ timm/data/transforms.py | 99 ----------------- timm/data/transforms_factory.py | 164 ++++++++++++++++++++++++++++ timm/loss/__init__.py | 1 + timm/loss/jsd.py | 34 ++++++ train.py | 33 ++++-- 10 files changed, 455 insertions(+), 200 deletions(-) create mode 100644 timm/data/transforms_factory.py create mode 100644 timm/loss/jsd.py diff --git a/timm/data/__init__.py b/timm/data/__init__.py index 49c4bc60..66c16257 100644 --- a/timm/data/__init__.py +++ b/timm/data/__init__.py @@ -2,7 +2,8 @@ from .constants import * from .config import resolve_data_config from .dataset import Dataset, DatasetTar from .transforms import * -from .loader import create_loader, create_transform -from .mixup import mixup_target, FastCollateMixup +from .loader import create_loader +from .transforms_factory import create_transform +from .mixup import mixup_batch, FastCollateMixup from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ rand_augment_transform, auto_augment_transform diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index d730c266..cc1b716c 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -8,12 +8,11 @@ Hacked together by Ross Wightman import random import math import re -from PIL import Image, ImageOps, ImageEnhance +from PIL import Image, ImageOps, ImageEnhance, ImageChops import PIL import numpy as np - _PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]]) _FILL = (128, 128, 128) @@ -192,36 +191,47 @@ def 
_translate_abs_level_to_arg(level, hparams): return level, -def _translate_rel_level_to_arg(level, _hparams): - # range [-0.45, 0.45] - level = (level / _MAX_LEVEL) * 0.45 +def _translate_rel_level_to_arg(level, hparams): + # default range [-0.45, 0.45] + translate_pct = hparams.get('translate_pct', 0.45) + level = (level / _MAX_LEVEL) * translate_pct level = _randomly_negate(level) return level, -def _posterize_original_level_to_arg(level, _hparams): - # As per original AutoAugment paper description - # range [4, 8], 'keep 4 up to 8 MSB of image' - return int((level / _MAX_LEVEL) * 4) + 4, +def _posterize_level_to_arg(level, _hparams): + # As per Tensorflow TPU EfficientNet impl + # range [0, 4], 'keep 0 up to 4 MSB of original image' + # intensity/severity of augmentation decreases with level + return int((level / _MAX_LEVEL) * 4), -def _posterize_research_level_to_arg(level, _hparams): +def _posterize_increasing_level_to_arg(level, hparams): # As per Tensorflow models research and UDA impl - # range [4, 0], 'keep 4 down to 0 MSB of original image' - return 4 - int((level / _MAX_LEVEL) * 4), + # range [4, 0], 'keep 4 down to 0 MSB of original image', + # intensity/severity of augmentation increases with level + return 4 - _posterize_level_to_arg(level, hparams)[0], -def _posterize_tpu_level_to_arg(level, _hparams): - # As per Tensorflow TPU EfficientNet impl - # range [0, 4], 'keep 0 up to 4 MSB of original image' - return int((level / _MAX_LEVEL) * 4), +def _posterize_original_level_to_arg(level, _hparams): + # As per original AutoAugment paper description + # range [4, 8], 'keep 4 up to 8 MSB of image' + # intensity/severity of augmentation decreases with level + return int((level / _MAX_LEVEL) * 4) + 4, def _solarize_level_to_arg(level, _hparams): # range [0, 256] + # intensity/severity of augmentation decreases with level return int((level / _MAX_LEVEL) * 256), +def _solarize_increasing_level_to_arg(level, _hparams): + # range [0, 256] + # intensity/severity 
of augmentation increases with level + return 256 - _solarize_level_to_arg(level, _hparams)[0], + + def _solarize_add_level_to_arg(level, _hparams): # range [0, 110] return int((level / _MAX_LEVEL) * 110), @@ -233,10 +243,11 @@ LEVEL_TO_ARG = { 'Invert': None, 'Rotate': _rotate_level_to_arg, # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers + 'Posterize': _posterize_level_to_arg, + 'PosterizeIncreasing': _posterize_increasing_level_to_arg, 'PosterizeOriginal': _posterize_original_level_to_arg, - 'PosterizeResearch': _posterize_research_level_to_arg, - 'PosterizeTpu': _posterize_tpu_level_to_arg, 'Solarize': _solarize_level_to_arg, + 'SolarizeIncreasing': _solarize_increasing_level_to_arg, 'SolarizeAdd': _solarize_add_level_to_arg, 'Color': _enhance_level_to_arg, 'Contrast': _enhance_level_to_arg, @@ -256,10 +267,11 @@ NAME_TO_OP = { 'Equalize': equalize, 'Invert': invert, 'Rotate': rotate, + 'Posterize': posterize, + 'PosterizeIncreasing': posterize, 'PosterizeOriginal': posterize, - 'PosterizeResearch': posterize, - 'PosterizeTpu': posterize, 'Solarize': solarize, + 'SolarizeIncreasing': solarize, 'SolarizeAdd': solarize_add, 'Color': color, 'Contrast': contrast, @@ -274,7 +286,7 @@ NAME_TO_OP = { } -class AutoAugmentOp: +class AugmentOp: def __init__(self, name, prob=0.5, magnitude=10, hparams=None): hparams = hparams or _HPARAMS_DEFAULT @@ -295,12 +307,12 @@ class AutoAugmentOp: self.magnitude_std = self.hparams.get('magnitude_std', 0) def __call__(self, img): - if random.random() > self.prob: + if self.prob < 1.0 and random.random() > self.prob: return img magnitude = self.magnitude if self.magnitude_std and self.magnitude_std > 0: magnitude = random.gauss(magnitude, self.magnitude_std) - magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range + magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None 
else tuple() return self.aug_fn(img, *level_args, **self.kwargs) @@ -320,7 +332,7 @@ def auto_augment_policy_v0(hparams): [('Invert', 0.4, 9), ('Rotate', 0.6, 0)], [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)], [('Color', 0.4, 7), ('Equalize', 0.6, 0)], - [('PosterizeTpu', 0.4, 6), ('AutoContrast', 0.4, 7)], + [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)], [('Solarize', 0.6, 8), ('Color', 0.6, 9)], [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)], [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)], @@ -330,16 +342,17 @@ def auto_augment_policy_v0(hparams): [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)], [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)], [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)], - [('PosterizeTpu', 0.8, 2), ('Solarize', 0.6, 10)], # This results in black image with Tpu posterize + [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)], # This results in black image with Tpu posterize [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)], [('Color', 0.8, 6), ('Rotate', 0.4, 5)], ] - pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] + pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] return pc def auto_augment_policy_v0r(hparams): - # ImageNet v0 policy from TPU EfficientNet impl, with research variation of Posterize + # ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize used + # in Google research implementation (number of bits discarded increases with magnitude) policy = [ [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)], [('Color', 0.4, 9), ('Equalize', 0.6, 3)], @@ -353,7 +366,7 @@ def auto_augment_policy_v0r(hparams): [('Invert', 0.4, 9), ('Rotate', 0.6, 0)], [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)], [('Color', 0.4, 7), ('Equalize', 0.6, 0)], - [('PosterizeResearch', 0.4, 6), ('AutoContrast', 0.4, 7)], + [('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)], [('Solarize', 0.6, 8), ('Color', 0.6, 9)], [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)], [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)], @@ -363,11 
+376,11 @@ def auto_augment_policy_v0r(hparams): [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)], [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)], [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)], - [('PosterizeResearch', 0.8, 2), ('Solarize', 0.6, 10)], + [('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)], [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)], [('Color', 0.8, 6), ('Rotate', 0.4, 5)], ] - pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] + pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] return pc @@ -400,23 +413,23 @@ def auto_augment_policy_original(hparams): [('Color', 0.6, 4), ('Contrast', 1.0, 8)], [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], ] - pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] + pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] return pc def auto_augment_policy_originalr(hparams): # ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation policy = [ - [('PosterizeResearch', 0.4, 8), ('Rotate', 0.6, 9)], + [('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)], [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)], [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], - [('PosterizeResearch', 0.6, 7), ('PosterizeResearch', 0.6, 6)], + [('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)], [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)], [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)], [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)], - [('PosterizeResearch', 0.8, 5), ('Equalize', 1.0, 2)], + [('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)], [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)], - [('Equalize', 0.6, 8), ('PosterizeResearch', 0.4, 6)], + [('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)], [('Rotate', 0.8, 8), ('Color', 0.4, 0)], [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)], [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)], @@ -433,7 +446,7 @@ def auto_augment_policy_originalr(hparams): [('Color', 0.6, 4), ('Contrast', 
1.0, 8)], [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], ] - pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] + pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] return pc @@ -499,7 +512,7 @@ _RAND_TRANSFORMS = [ 'Equalize', 'Invert', 'Rotate', - 'PosterizeTpu', + 'Posterize', 'Solarize', 'SolarizeAdd', 'Color', @@ -530,7 +543,7 @@ _RAND_CHOICE_WEIGHTS_0 = { 'Contrast': .005, 'Brightness': .005, 'Equalize': .005, - 'PosterizeTpu': 0, + 'Posterize': 0, 'Invert': 0, } @@ -547,7 +560,7 @@ def _select_rand_weights(weight_idx=0, transforms=None): def rand_augment_ops(magnitude=10, hparams=None, transforms=None): hparams = hparams or _HPARAMS_DEFAULT transforms = transforms or _RAND_TRANSFORMS - return [AutoAugmentOp( + return [AugmentOp( name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms] @@ -609,3 +622,94 @@ def rand_augment_transform(config_str, hparams): ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams) choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx) return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) + + +_AUGMIX_TRANSFORMS = [ + 'AutoContrast', + 'Contrast', # not in paper + 'Brightness', # not in paper + 'Sharpness', # not in paper + 'Equalize', + 'Rotate', + 'PosterizeIncreasing', + 'SolarizeIncreasing', + 'ShearX', + 'ShearY', + 'TranslateXRel', + 'TranslateYRel', +] + + +def augmix_ops(magnitude=10, hparams=None, transforms=None): + hparams = hparams or _HPARAMS_DEFAULT + transforms = transforms or _AUGMIX_TRANSFORMS + return [AugmentOp( + name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms] + + +class AugMixAugment: + def __init__(self, ops, alpha=1., width=3, depth=-1): + self.ops = ops + self.alpha = alpha + self.width = width + self.depth = depth + self.recursive = True + + def _apply_recursive(self, img, ws, prod=1.): + alpha = ws[-1] / prod + if len(ws) > 1: + img = self._apply_recursive(img, 
ws[:-1], prod * (1 - alpha)) + + depth = self.depth if self.depth > 0 else np.random.randint(1, 4) + ops = np.random.choice(self.ops, depth, replace=True) + img_aug = img # no ops are in-place, deep copy not necessary + for op in ops: + img_aug = op(img_aug) + return Image.blend(img, img_aug, alpha) + + def _apply_basic(self, img, ws, m): + w, h = img.size + c = len(img.getbands()) + mixed = np.zeros((h, w, c), dtype=np.float32) + for w in ws: + depth = self.depth if self.depth > 0 else np.random.randint(1, 4) + ops = np.random.choice(self.ops, depth, replace=True) + img_aug = img # no ops are in-place, deep copy not necessary + for op in ops: + img_aug = op(img_aug) + img_aug = np.asarray(img_aug, dtype=np.float32) + mixed += w * img_aug + np.clip(mixed, 0, 255., out=mixed) + mixed = Image.fromarray(mixed.astype(np.uint8)) + return Image.blend(img, mixed, m) + + def __call__(self, img): + mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width)) + m = np.float32(np.random.beta(self.alpha, self.alpha)) + if self.recursive: + mixing_weights *= m + mixed = self._apply_recursive(img, mixing_weights) + else: + mixed = self._apply_basic(img, mixing_weights, m) + return mixed + + +def augment_and_mix_transform(config_str, hparams): + """Perform AugMix augmentations and compute mixture. + Args: + image: Raw input image as float32 np.ndarray of shape (h, w, c) + severity: Severity of underlying augmentation operators (between 1 to 10). + width: Width of augmentation chain + depth: Depth of augmentation chain. -1 enables stochastic depth uniformly + from [1, 3] + alpha: Probability coefficient for Beta and Dirichlet distributions. + Returns: + mixed: Augmented and mixed image. + """ + # FIXME parse args from config str + severity = 3 + width = 3 + depth = -1 + alpha = 1. 
+ ops = augmix_ops(magnitude=severity, hparams=hparams) + return AugMixAugment(ops, alpha, width, depth) diff --git a/timm/data/dataset.py b/timm/data/dataset.py index 47437d5e..3220883a 100644 --- a/timm/data/dataset.py +++ b/timm/data/dataset.py @@ -140,3 +140,41 @@ class DatasetTar(data.Dataset): def __len__(self): return len(self.imgs) + +class AugMixDataset(torch.utils.data.Dataset): + """Dataset wrapper to perform AugMix or other clean/augmentation mixes""" + + def __init__(self, dataset, num_aug=2): + self.augmentation = None + self.normalize = None + self.dataset = dataset + if self.dataset.transform is not None: + self._set_transforms(self.dataset.transform) + self.num_aug = num_aug + + def _set_transforms(self, x): + assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms' + self.dataset.transform = x[0] + self.augmentation = x[1] + self.normalize = x[2] + + @property + def transform(self): + return self.dataset.transform + + @transform.setter + def transform(self, x): + self._set_transforms(x) + + def _normalize(self, x): + return x if self.normalize is None else self.normalize(x) + + def __getitem__(self, i): + x, y = self.dataset[i] + x_list = [self._normalize(x)] + for n in range(self.num_aug): + x_list.append(self._normalize(self.augmentation(x))) + return tuple(x_list), y + + def __len__(self): + return len(self.dataset) diff --git a/timm/data/loader.py b/timm/data/loader.py index 8c27f1bb..06d431ee 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -1,17 +1,46 @@ import torch.utils.data -from .transforms import * +import numpy as np + +from .transforms_factory import create_transform +from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .distributed_sampler import OrderedDistributedSampler +from .random_erasing import RandomErasing from .mixup import FastCollateMixup def fast_collate(batch): - targets = torch.tensor([b[1] for b in batch], dtype=torch.int64) - batch_size = 
len(targets) - tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) - for i in range(batch_size): - tensor[i] += torch.from_numpy(batch[i][0]) - - return tensor, targets + """ A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)""" + assert isinstance(batch[0], tuple) + batch_size = len(batch) + if isinstance(batch[0][0], tuple): + # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position + # such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position + inner_tuple_size = len(batch[0][0][0]) + flattened_batch_size = batch_size * inner_tuple_size + targets = torch.zeros(flattened_batch_size, dtype=torch.int64) + tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8) + for i in range(batch_size): + assert len(batch[i][0]) == inner_tuple_size # all input tensor tuples must be same length + for j in range(inner_tuple_size): + targets[i + j * batch_size] = batch[i][1] + tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j]) + return tensor, targets + elif isinstance(batch[0][0], np.ndarray): + targets = torch.tensor([b[1] for b in batch], dtype=torch.int64) + assert len(targets) == batch_size + tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) + for i in range(batch_size): + tensor[i] += torch.from_numpy(batch[i][0]) + return tensor, targets + elif isinstance(batch[0][0], torch.Tensor): + targets = torch.tensor([b[1] for b in batch], dtype=torch.int64) + assert len(targets) == batch_size + tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) + for i in range(batch_size): + tensor[i].copy_(batch[i][0]) + return tensor, targets + else: + assert False class PrefetchLoader: @@ -87,49 +116,6 @@ class PrefetchLoader: self.loader.collate_fn.mixup_enabled = x -def create_transform( - input_size, - is_training=False, - use_prefetcher=False, - 
color_jitter=0.4, - auto_augment=None, - interpolation='bilinear', - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, - crop_pct=None, - tf_preprocessing=False): - - if isinstance(input_size, tuple): - img_size = input_size[-2:] - else: - img_size = input_size - - if tf_preprocessing and use_prefetcher: - from timm.data.tf_preprocessing import TfPreprocessTransform - transform = TfPreprocessTransform( - is_training=is_training, size=img_size, interpolation=interpolation) - else: - if is_training: - transform = transforms_imagenet_train( - img_size, - color_jitter=color_jitter, - auto_augment=auto_augment, - interpolation=interpolation, - use_prefetcher=use_prefetcher, - mean=mean, - std=std) - else: - transform = transforms_imagenet_eval( - img_size, - interpolation=interpolation, - use_prefetcher=use_prefetcher, - mean=mean, - std=std, - crop_pct=crop_pct) - - return transform - - def create_loader( dataset, input_size, @@ -150,6 +136,7 @@ def create_loader( collate_fn=None, fp16=False, tf_preprocessing=False, + separate_transforms=False, ): dataset.transform = create_transform( input_size, @@ -162,6 +149,7 @@ def create_loader( std=std, crop_pct=crop_pct, tf_preprocessing=tf_preprocessing, + separate=separate_transforms, ) sampler = None diff --git a/timm/data/mixup.py b/timm/data/mixup.py index 83d51ccb..4678472d 100644 --- a/timm/data/mixup.py +++ b/timm/data/mixup.py @@ -15,6 +15,15 @@ def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'): return lam*y1 + (1. - lam)*y2 +def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False): + lam = 1. 
+ if not disable: + lam = np.random.beta(alpha, alpha) + input = input.mul(lam).add_(1 - lam, input.flip(0)) + target = mixup_target(target, num_classes, lam, smoothing) + return input, target + + class FastCollateMixup: def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000): diff --git a/timm/data/transforms.py b/timm/data/transforms.py index 41f2a63e..b3b08e30 100644 --- a/timm/data/transforms.py +++ b/timm/data/transforms.py @@ -1,5 +1,4 @@ import torch -from torchvision import transforms import torchvision.transforms.functional as F from PIL import Image import warnings @@ -7,10 +6,6 @@ import math import random import numpy as np -from .constants import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .random_erasing import RandomErasing -from .auto_augment import auto_augment_transform, rand_augment_transform - class ToNumpy: @@ -161,97 +156,3 @@ class RandomResizedCropAndInterpolation: return format_string -def transforms_imagenet_train( - img_size=224, - scale=(0.08, 1.0), - color_jitter=0.4, - auto_augment=None, - interpolation='random', - random_erasing=0.4, - random_erasing_mode='const', - use_prefetcher=False, - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD -): - tfl = [ - RandomResizedCropAndInterpolation( - img_size, scale=scale, interpolation=interpolation), - transforms.RandomHorizontalFlip() - ] - if auto_augment: - assert isinstance(auto_augment, str) - if isinstance(img_size, tuple): - img_size_min = min(img_size) - else: - img_size_min = img_size - aa_params = dict( - translate_const=int(img_size_min * 0.45), - img_mean=tuple([min(255, round(255 * x)) for x in mean]), - ) - if interpolation and interpolation != 'random': - aa_params['interpolation'] = _pil_interp(interpolation) - if auto_augment.startswith('rand'): - tfl += [rand_augment_transform(auto_augment, aa_params)] - else: - tfl += [auto_augment_transform(auto_augment, aa_params)] - else: - # color jitter is enabled when not using AA - if 
isinstance(color_jitter, (list, tuple)): - # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation - # or 4 if also augmenting hue - assert len(color_jitter) in (3, 4) - else: - # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue - color_jitter = (float(color_jitter),) * 3 - tfl += [transforms.ColorJitter(*color_jitter)] - - if use_prefetcher: - # prefetcher and collate will handle tensor conversion and norm - tfl += [ToNumpy()] - else: - tfl += [ - transforms.ToTensor(), - transforms.Normalize( - mean=torch.tensor(mean), - std=torch.tensor(std)) - ] - if random_erasing > 0.: - tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu')) - return transforms.Compose(tfl) - - -def transforms_imagenet_eval( - img_size=224, - crop_pct=None, - interpolation='bilinear', - use_prefetcher=False, - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD): - crop_pct = crop_pct or DEFAULT_CROP_PCT - - if isinstance(img_size, tuple): - assert len(img_size) == 2 - if img_size[-1] == img_size[-2]: - # fall-back to older behaviour so Resize scales to shortest edge if target is square - scale_size = int(math.floor(img_size[0] / crop_pct)) - else: - scale_size = tuple([int(x / crop_pct) for x in img_size]) - else: - scale_size = int(math.floor(img_size / crop_pct)) - - tfl = [ - transforms.Resize(scale_size, _pil_interp(interpolation)), - transforms.CenterCrop(img_size), - ] - if use_prefetcher: - # prefetcher and collate will handle tensor conversion and norm - tfl += [ToNumpy()] - else: - tfl += [ - transforms.ToTensor(), - transforms.Normalize( - mean=torch.tensor(mean), - std=torch.tensor(std)) - ] - - return transforms.Compose(tfl) diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py new file mode 100644 index 00000000..b70ae76f --- /dev/null +++ b/timm/data/transforms_factory.py @@ -0,0 +1,164 @@ +import math + +import torch +from torchvision import transforms + +from 
timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT +from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform +from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor +from timm.data.random_erasing import RandomErasing + + +def transforms_imagenet_train( + img_size=224, + scale=(0.08, 1.0), + color_jitter=0.4, + auto_augment=None, + interpolation='random', + random_erasing=0.4, + random_erasing_mode='const', + use_prefetcher=False, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + separate=False, +): + + primary_tfl = [ + RandomResizedCropAndInterpolation( + img_size, scale=scale, interpolation=interpolation), + transforms.RandomHorizontalFlip() + ] + + secondary_tfl = [] + if auto_augment: + assert isinstance(auto_augment, str) + if isinstance(img_size, tuple): + img_size_min = min(img_size) + else: + img_size_min = img_size + aa_params = dict( + translate_const=int(img_size_min * 0.45), + img_mean=tuple([min(255, round(255 * x)) for x in mean]), + ) + if interpolation and interpolation != 'random': + aa_params['interpolation'] = _pil_interp(interpolation) + if auto_augment.startswith('rand'): + secondary_tfl += [rand_augment_transform(auto_augment, aa_params)] + elif auto_augment.startswith('augmix'): + aa_params['translate_pct'] = 0.3 + secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)] + else: + secondary_tfl += [auto_augment_transform(auto_augment, aa_params)] + elif color_jitter is not None: + # color jitter is enabled when not using AA + if isinstance(color_jitter, (list, tuple)): + # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation + # or 4 if also augmenting hue + assert len(color_jitter) in (3, 4) + else: + # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue + color_jitter = (float(color_jitter),) * 3 + secondary_tfl += 
[transforms.ColorJitter(*color_jitter)] + + final_tfl = [] + if use_prefetcher: + # prefetcher and collate will handle tensor conversion and norm + final_tfl += [ToNumpy()] + else: + final_tfl += [ + transforms.ToTensor(), + transforms.Normalize( + mean=torch.tensor(mean), + std=torch.tensor(std)) + ] + if random_erasing > 0.: + final_tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu')) + + if separate: + return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl) + else: + return transforms.Compose(primary_tfl + secondary_tfl + final_tfl) + + +def transforms_imagenet_eval( + img_size=224, + crop_pct=None, + interpolation='bilinear', + use_prefetcher=False, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD): + crop_pct = crop_pct or DEFAULT_CROP_PCT + + if isinstance(img_size, tuple): + assert len(img_size) == 2 + if img_size[-1] == img_size[-2]: + # fall-back to older behaviour so Resize scales to shortest edge if target is square + scale_size = int(math.floor(img_size[0] / crop_pct)) + else: + scale_size = tuple([int(x / crop_pct) for x in img_size]) + else: + scale_size = int(math.floor(img_size / crop_pct)) + + tfl = [ + transforms.Resize(scale_size, _pil_interp(interpolation)), + transforms.CenterCrop(img_size), + ] + if use_prefetcher: + # prefetcher and collate will handle tensor conversion and norm + tfl += [ToNumpy()] + else: + tfl += [ + transforms.ToTensor(), + transforms.Normalize( + mean=torch.tensor(mean), + std=torch.tensor(std)) + ] + + return transforms.Compose(tfl) + + +def create_transform( + input_size, + is_training=False, + use_prefetcher=False, + color_jitter=0.4, + auto_augment=None, + interpolation='bilinear', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + crop_pct=None, + tf_preprocessing=False, + separate=False): + + if isinstance(input_size, tuple): + img_size = input_size[-2:] + else: + img_size = input_size + + if tf_preprocessing and 
use_prefetcher: + assert not separate, "Separate transforms not supported for TF preprocessing" + from timm.data.tf_preprocessing import TfPreprocessTransform + transform = TfPreprocessTransform( + is_training=is_training, size=img_size, interpolation=interpolation) + else: + if is_training: + transform = transforms_imagenet_train( + img_size, + color_jitter=color_jitter, + auto_augment=auto_augment, + interpolation=interpolation, + use_prefetcher=use_prefetcher, + mean=mean, + std=std, + separate=separate) + else: + assert not separate, "Separate transforms not supported for validation preprocessing" + transform = transforms_imagenet_eval( + img_size, + interpolation=interpolation, + use_prefetcher=use_prefetcher, + mean=mean, + std=std, + crop_pct=crop_pct) + + return transform diff --git a/timm/loss/__init__.py b/timm/loss/__init__.py index f436ccc7..b781472f 100644 --- a/timm/loss/__init__.py +++ b/timm/loss/__init__.py @@ -1 +1,2 @@ from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy +from .jsd import JsdCrossEntropy \ No newline at end of file diff --git a/timm/loss/jsd.py b/timm/loss/jsd.py new file mode 100644 index 00000000..0f99c699 --- /dev/null +++ b/timm/loss/jsd.py @@ -0,0 +1,34 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .cross_entropy import LabelSmoothingCrossEntropy + + +class JsdCrossEntropy(nn.Module): + """ Jensen-Shannon Divergence + Cross-Entropy Loss + + """ + def __init__(self, num_splits=3, alpha=12, smoothing=0.1): + super().__init__() + self.num_splits = num_splits + self.alpha = alpha + if smoothing is not None and smoothing > 0: + self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing) + else: + self.cross_entropy_loss = torch.nn.CrossEntropyLoss() + + def __call__(self, output, target): + split_size = output.shape[0] // self.num_splits + assert split_size * self.num_splits == output.shape[0] + logits_split = torch.split(output, split_size) + + # Cross-entropy is 
only computed on clean images + loss = self.cross_entropy_loss(logits_split[0], target[:split_size]) + probs = [F.softmax(logits, dim=1) for logits in logits_split] + + # Clamp mixture distribution to avoid exploding KL divergence + logp_mixture = torch.clamp(torch.stack(probs).mean(axis=0), 1e-7, 1).log() + loss += self.alpha * sum([F.kl_div( + logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs) + return loss diff --git a/train.py b/train.py index a47f1b4d..9910a059 100644 --- a/train.py +++ b/train.py @@ -1,7 +1,6 @@ import argparse import time -import logging import yaml from datetime import datetime @@ -14,13 +13,16 @@ except ImportError: from torch.nn.parallel import DistributedDataParallel as DDP has_apex = False -from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target +from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_batch from timm.models import create_model, resume_checkpoint from timm.utils import * -from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy +from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy from timm.optim import create_optimizer from timm.scheduler import create_scheduler +#FIXME +from timm.data.dataset import AugMixDataset + import torch import torch.nn as nn import torchvision.utils @@ -160,6 +162,10 @@ parser.add_argument('--tta', type=int, default=0, metavar='N', parser.add_argument("--local_rank", default=0, type=int) +parser.add_argument('--jsd', action='store_true', default=False, + help='') + + def _parse_args(): # Do we have a config file to parse? 
args_config, remaining = config_parser.parse_known_args() @@ -311,8 +317,14 @@ def main(): collate_fn = None if args.prefetcher and args.mixup > 0: + assert not args.jsd collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes) + separate_transforms = False + if args.jsd: + dataset_train = AugMixDataset(dataset_train) + separate_transforms = True + loader_train = create_loader( dataset_train, input_size=data_config['input_size'], @@ -330,6 +342,7 @@ def main(): num_workers=args.workers, distributed=args.distributed, collate_fn=collate_fn, + separate_transforms=separate_transforms, ) eval_dir = os.path.join(args.data, 'val') @@ -354,7 +367,10 @@ def main(): crop_pct=data_config['crop_pct'], ) - if args.mixup > 0.: + if args.jsd: + train_loss_fn = JsdCrossEntropy(smoothing=args.smoothing).cuda() + validate_loss_fn = nn.CrossEntropyLoss().cuda() + elif args.mixup > 0.: # smoothing is handled with mixup label transform train_loss_fn = SoftTargetCrossEntropy().cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() @@ -452,11 +468,10 @@ def train_epoch( if not args.prefetcher: input, target = input.cuda(), target.cuda() if args.mixup > 0.: - lam = 1. - if not args.mixup_off_epoch or epoch < args.mixup_off_epoch: - lam = np.random.beta(args.mixup, args.mixup) - input = input.mul(lam).add_(1 - lam, input.flip(0)) - target = mixup_target(target, args.num_classes, lam, args.smoothing) + input, target = mixup_batch( + input, target, + alpha=args.mixup, num_classes=args.num_classes, smoothing=args.smoothing, + disable=args.mixup_off_epoch and epoch >= args.mixup_off_epoch) output = model(input)