From 232ab7fb12ba082e6d4039c7a7c7f2701caa0a71 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 19 Dec 2019 18:16:18 -0800 Subject: [PATCH 1/7] Working on an implementation of AugMix with JensenShannonDivergence loss that's compatible with my AutoAugment and RandAugment impl --- timm/data/__init__.py | 5 +- timm/data/auto_augment.py | 182 +++++++++++++++++++++++++------- timm/data/dataset.py | 38 +++++++ timm/data/loader.py | 90 +++++++--------- timm/data/mixup.py | 9 ++ timm/data/transforms.py | 99 ----------------- timm/data/transforms_factory.py | 164 ++++++++++++++++++++++++++++ timm/loss/__init__.py | 1 + timm/loss/jsd.py | 34 ++++++ train.py | 33 ++++-- 10 files changed, 455 insertions(+), 200 deletions(-) create mode 100644 timm/data/transforms_factory.py create mode 100644 timm/loss/jsd.py diff --git a/timm/data/__init__.py b/timm/data/__init__.py index 49c4bc60..66c16257 100644 --- a/timm/data/__init__.py +++ b/timm/data/__init__.py @@ -2,7 +2,8 @@ from .constants import * from .config import resolve_data_config from .dataset import Dataset, DatasetTar from .transforms import * -from .loader import create_loader, create_transform -from .mixup import mixup_target, FastCollateMixup +from .loader import create_loader +from .transforms_factory import create_transform +from .mixup import mixup_batch, FastCollateMixup from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ rand_augment_transform, auto_augment_transform diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index d730c266..cc1b716c 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -8,12 +8,11 @@ Hacked together by Ross Wightman import random import math import re -from PIL import Image, ImageOps, ImageEnhance +from PIL import Image, ImageOps, ImageEnhance, ImageChops import PIL import numpy as np - _PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]]) _FILL = (128, 128, 128) @@ -192,36 +191,47 @@ def _translate_abs_level_to_arg(level, hparams): return level, -def _translate_rel_level_to_arg(level, _hparams): - # range [-0.45, 0.45] - level = (level / _MAX_LEVEL) * 0.45 +def _translate_rel_level_to_arg(level, hparams): + # default range [-0.45, 0.45] + translate_pct = hparams.get('translate_pct', 0.45) + level = (level / _MAX_LEVEL) * translate_pct level = _randomly_negate(level) return level, -def _posterize_original_level_to_arg(level, _hparams): - # As per original AutoAugment paper description - # range [4, 8], 'keep 4 up to 8 MSB of image' - return int((level / _MAX_LEVEL) * 4) + 4, +def _posterize_level_to_arg(level, _hparams): + # As per Tensorflow TPU EfficientNet impl + # range [0, 4], 'keep 0 up to 4 MSB of original image' + # intensity/severity of augmentation decreases with level + return int((level / _MAX_LEVEL) * 4), -def _posterize_research_level_to_arg(level, _hparams): +def _posterize_increasing_level_to_arg(level, hparams): # As per Tensorflow models research and UDA impl - # range [4, 0], 'keep 4 down to 0 MSB of original image' - return 4 - int((level / _MAX_LEVEL) * 4), + # range [4, 0], 'keep 4 down to 0 MSB of original image', + # intensity/severity of augmentation increases with level + return 4 - _posterize_level_to_arg(level, hparams)[0], -def _posterize_tpu_level_to_arg(level, _hparams): - # As per Tensorflow TPU EfficientNet impl - # range [0, 4], 'keep 0 up to 4 MSB of original image' - return int((level / _MAX_LEVEL) * 4), +def _posterize_original_level_to_arg(level, _hparams): + # As per original AutoAugment paper description + # range [4, 8], 'keep 4 up to 8 MSB of image' + # intensity/severity of augmentation decreases with level + return int((level / _MAX_LEVEL) * 4) + 4, def _solarize_level_to_arg(level, _hparams): # range [0, 256] + # intensity/severity of augmentation decreases with level return int((level / _MAX_LEVEL) * 256), +def _solarize_increasing_level_to_arg(level, _hparams): + # range [0, 256] + # intensity/severity of augmentation increases with level + return 256 - _solarize_level_to_arg(level, _hparams)[0], + + def _solarize_add_level_to_arg(level, _hparams): # range [0, 110] return int((level / _MAX_LEVEL) * 110), @@ -233,10 +243,11 @@ LEVEL_TO_ARG = { 'Invert': None, 'Rotate': _rotate_level_to_arg, # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers + 'Posterize': _posterize_level_to_arg, + 'PosterizeIncreasing': _posterize_increasing_level_to_arg, 'PosterizeOriginal': _posterize_original_level_to_arg, - 'PosterizeResearch': _posterize_research_level_to_arg, - 'PosterizeTpu': _posterize_tpu_level_to_arg, 'Solarize': _solarize_level_to_arg, + 'SolarizeIncreasing': _solarize_level_to_arg, 'SolarizeAdd': _solarize_add_level_to_arg, 'Color': _enhance_level_to_arg, 'Contrast': _enhance_level_to_arg, @@ -256,10 +267,11 @@ NAME_TO_OP = { 'Equalize': equalize, 'Invert': invert, 'Rotate': rotate, + 'Posterize': posterize, + 'PosterizeIncreasing': posterize, 'PosterizeOriginal': posterize, - 'PosterizeResearch': posterize, - 'PosterizeTpu': posterize, 'Solarize': solarize, + 'SolarizeIncreasing': solarize, 'SolarizeAdd': solarize_add, 'Color': color, 'Contrast': contrast, @@ -274,7 +286,7 @@ NAME_TO_OP = { } -class AutoAugmentOp: +class AugmentOp: def __init__(self, name, prob=0.5, magnitude=10, hparams=None): hparams = hparams or _HPARAMS_DEFAULT @@ -295,12 +307,12 @@ class AutoAugmentOp: self.magnitude_std = self.hparams.get('magnitude_std', 0) def __call__(self, img): - if random.random() > self.prob: + if not self.prob >= 1.0 or random.random() > self.prob: return img magnitude = self.magnitude if self.magnitude_std and self.magnitude_std > 0: magnitude = random.gauss(magnitude, self.magnitude_std) - magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range + magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple() return self.aug_fn(img, *level_args, **self.kwargs) @@ -320,7 +332,7 @@ def auto_augment_policy_v0(hparams): [('Invert', 0.4, 9), ('Rotate', 0.6, 0)], [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)], [('Color', 0.4, 7), ('Equalize', 0.6, 0)], - [('PosterizeTpu', 0.4, 6), ('AutoContrast', 0.4, 7)], + [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)], [('Solarize', 0.6, 8), ('Color', 0.6, 9)], [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)], [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)], @@ -330,16 +342,17 @@ def auto_augment_policy_v0(hparams): [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)], [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)], [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)], - [('PosterizeTpu', 0.8, 2), ('Solarize', 0.6, 10)], # This results in black image with Tpu posterize + [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)], # This results in black image with Tpu posterize [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)], [('Color', 0.8, 6), ('Rotate', 0.4, 5)], ] - pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] + pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] return pc def auto_augment_policy_v0r(hparams): - # ImageNet v0 policy from TPU EfficientNet impl, with research variation of Posterize + # ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize used + # in Google research implementation (number of bits discarded increases with magnitude) policy = [ [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)], [('Color', 0.4, 9), ('Equalize', 0.6, 3)], @@ -353,7 +366,7 @@ def auto_augment_policy_v0r(hparams): [('Invert', 0.4, 9), ('Rotate', 0.6, 0)], [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)], [('Color', 0.4, 7), ('Equalize', 0.6, 0)], - [('PosterizeResearch', 0.4, 6), ('AutoContrast', 0.4, 7)], + [('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)], [('Solarize', 0.6, 8), ('Color', 0.6, 9)], [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)], [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)], @@ -363,11 +376,11 @@ def auto_augment_policy_v0r(hparams): [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)], [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)], [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)], - [('PosterizeResearch', 0.8, 2), ('Solarize', 0.6, 10)], + [('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)], [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)], [('Color', 0.8, 6), ('Rotate', 0.4, 5)], ] - pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] + pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] return pc @@ -400,23 +413,23 @@ def auto_augment_policy_original(hparams): [('Color', 0.6, 4), ('Contrast', 1.0, 8)], [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], ] - pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] + pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] return pc def auto_augment_policy_originalr(hparams): # ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation policy = [ - [('PosterizeResearch', 0.4, 8), ('Rotate', 0.6, 9)], + [('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)], [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)], [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], - [('PosterizeResearch', 0.6, 7), ('PosterizeResearch', 0.6, 6)], + [('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)], [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)], [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)], [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)], - [('PosterizeResearch', 0.8, 5), ('Equalize', 1.0, 2)], + [('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)], [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)], - [('Equalize', 0.6, 8), ('PosterizeResearch', 0.4, 6)], + [('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)], [('Rotate', 0.8, 8), ('Color', 0.4, 0)], [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)], [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)], @@ -433,7 +446,7 @@ def auto_augment_policy_originalr(hparams): [('Color', 0.6, 4), ('Contrast', 1.0, 8)], [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], ] - pc = [[AutoAugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] + pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] return pc @@ -499,7 +512,7 @@ _RAND_TRANSFORMS = [ 'Equalize', 'Invert', 'Rotate', - 'PosterizeTpu', + 'Posterize', 'Solarize', 'SolarizeAdd', 'Color', @@ -530,7 +543,7 @@ _RAND_CHOICE_WEIGHTS_0 = { 'Contrast': .005, 'Brightness': .005, 'Equalize': .005, - 'PosterizeTpu': 0, + 'Posterize': 0, 'Invert': 0, } @@ -547,7 +560,7 @@ def _select_rand_weights(weight_idx=0, transforms=None): def rand_augment_ops(magnitude=10, hparams=None, transforms=None): hparams = hparams or _HPARAMS_DEFAULT transforms = transforms or _RAND_TRANSFORMS - return [AutoAugmentOp( + return [AugmentOp( name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms] @@ -609,3 +622,94 @@ def rand_augment_transform(config_str, hparams): ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams) choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx) return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) + + +_AUGMIX_TRANSFORMS = [ + 'AutoContrast', + 'Contrast', # not in paper + 'Brightness', # not in paper + 'Sharpness', # not in paper + 'Equalize', + 'Rotate', + 'PosterizeIncreasing', + 'SolarizeIncreasing', + 'ShearX', + 'ShearY', + 'TranslateXRel', + 'TranslateYRel', +] + + +def augmix_ops(magnitude=10, hparams=None, transforms=None): + hparams = hparams or _HPARAMS_DEFAULT + transforms = transforms or _AUGMIX_TRANSFORMS + return [AugmentOp( + name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms] + + +class AugMixAugment: + def __init__(self, ops, alpha=1., width=3, depth=-1): + self.ops = ops + self.alpha = alpha + self.width = width + self.depth = depth + self.recursive = True + + def _apply_recursive(self, img, ws, prod=1.): + alpha = ws[-1] / prod + if len(ws) > 1: + img = self._apply_recursive(img, ws[:-1], prod * (1 - alpha)) + + depth = self.depth if self.depth > 0 else np.random.randint(1, 4) + ops = np.random.choice(self.ops, depth, replace=True) + img_aug = img # no ops are in-place, deep copy not necessary + for op in ops: + img_aug = op(img_aug) + return Image.blend(img, img_aug, alpha) + + def _apply_basic(self, img, ws, m): + w, h = img.size + c = len(img.getbands()) + mixed = np.zeros((w, h, c), dtype=np.float32) + for w in ws: + depth = self.depth if self.depth > 0 else np.random.randint(1, 4) + ops = np.random.choice(self.ops, depth, replace=True) + img_aug = img # no ops are in-place, deep copy not necessary + for op in ops: + img_aug = op(img_aug) + img_aug = np.asarray(img_aug, dtype=np.float32) + mixed += w * img_aug + np.clip(mixed, 0, 255., out=mixed) + mixed = Image.fromarray(mixed.astype(np.uint8)) + return Image.blend(img, mixed, m) + + def __call__(self, img): + mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width)) + m = np.float32(np.random.beta(self.alpha, self.alpha)) + if self.recursive: + mixing_weights *= m + mixed = self._apply_recursive(img, mixing_weights) + else: + mixed = self._apply_basic(img, mixing_weights, m) + return mixed + + +def augment_and_mix_transform(config_str, hparams): + """Perform AugMix augmentations and compute mixture. + Args: + image: Raw input image as float32 np.ndarray of shape (h, w, c) + severity: Severity of underlying augmentation operators (between 1 to 10). + width: Width of augmentation chain + depth: Depth of augmentation chain. -1 enables stochastic depth uniformly + from [1, 3] + alpha: Probability coefficient for Beta and Dirichlet distributions. + Returns: + mixed: Augmented and mixed image. + """ + # FIXME parse args from config str + severity = 3 + width = 3 + depth = -1 + alpha = 1. + ops = augmix_ops(magnitude=severity, hparams=hparams) + return AugMixAugment(ops, alpha, width, depth) diff --git a/timm/data/dataset.py b/timm/data/dataset.py index 47437d5e..3220883a 100644 --- a/timm/data/dataset.py +++ b/timm/data/dataset.py @@ -140,3 +140,41 @@ class DatasetTar(data.Dataset): def __len__(self): return len(self.imgs) + +class AugMixDataset(torch.utils.data.Dataset): + """Dataset wrapper to perform AugMix or other clean/augmentation mixes""" + + def __init__(self, dataset, num_aug=2): + self.augmentation = None + self.normalize = None + self.dataset = dataset + if self.dataset.transform is not None: + self._set_transforms(self.dataset.transform) + self.num_aug = num_aug + + def _set_transforms(self, x): + assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms' + self.dataset.transform = x[0] + self.augmentation = x[1] + self.normalize = x[2] + + @property + def transform(self): + return self.dataset.transform + + @transform.setter + def transform(self, x): + self._set_transforms(x) + + def _normalize(self, x): + return x if self.normalize is None else self.normalize(x) + + def __getitem__(self, i): + x, y = self.dataset[i] + x_list = [self._normalize(x)] + for n in range(self.num_aug): + x_list.append(self._normalize(self.augmentation(x))) + return tuple(x_list), y + + def __len__(self): + return len(self.dataset) diff --git a/timm/data/loader.py b/timm/data/loader.py index 8c27f1bb..06d431ee 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -1,17 +1,46 @@ import torch.utils.data -from .transforms import * +import numpy as np + +from .transforms_factory import create_transform +from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .distributed_sampler import OrderedDistributedSampler +from .random_erasing import RandomErasing from .mixup import FastCollateMixup def fast_collate(batch): - targets = torch.tensor([b[1] for b in batch], dtype=torch.int64) - batch_size = len(targets) - tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) - for i in range(batch_size): - tensor[i] += torch.from_numpy(batch[i][0]) - - return tensor, targets + """ A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)""" + assert isinstance(batch[0], tuple) + batch_size = len(batch) + if isinstance(batch[0][0], tuple): + # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position + # such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position + inner_tuple_size = len(batch[0][0][0]) + flattened_batch_size = batch_size * inner_tuple_size + targets = torch.zeros(flattened_batch_size, dtype=torch.int64) + tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8) + for i in range(batch_size): + assert len(batch[i][0]) == inner_tuple_size # all input tensor tuples must be same length + for j in range(inner_tuple_size): + targets[i + j * batch_size] = batch[i][1] + tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j]) + return tensor, targets + elif isinstance(batch[0][0], np.ndarray): + targets = torch.tensor([b[1] for b in batch], dtype=torch.int64) + assert len(targets) == batch_size + tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) + for i in range(batch_size): + tensor[i] += torch.from_numpy(batch[i][0]) + return tensor, targets + elif isinstance(batch[0][0], torch.Tensor): + targets = torch.tensor([b[1] for b in batch], dtype=torch.int64) + assert len(targets) == batch_size + tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) + for i in range(batch_size): + tensor[i].copy_(batch[i][0]) + return tensor, targets + else: + assert False class PrefetchLoader: @@ -87,49 +116,6 @@ class PrefetchLoader: self.loader.collate_fn.mixup_enabled = x -def create_transform( - input_size, - is_training=False, - use_prefetcher=False, - color_jitter=0.4, - auto_augment=None, - interpolation='bilinear', - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, - crop_pct=None, - tf_preprocessing=False): - - if isinstance(input_size, tuple): - img_size = input_size[-2:] - else: - img_size = input_size - - if tf_preprocessing and use_prefetcher: - from timm.data.tf_preprocessing import TfPreprocessTransform - transform = TfPreprocessTransform( - is_training=is_training, size=img_size, interpolation=interpolation) - else: - if is_training: - transform = transforms_imagenet_train( - img_size, - color_jitter=color_jitter, - auto_augment=auto_augment, - interpolation=interpolation, - use_prefetcher=use_prefetcher, - mean=mean, - std=std) - else: - transform = transforms_imagenet_eval( - img_size, - interpolation=interpolation, - use_prefetcher=use_prefetcher, - mean=mean, - std=std, - crop_pct=crop_pct) - - return transform - - def create_loader( dataset, input_size, @@ -150,6 +136,7 @@ def create_loader( collate_fn=None, fp16=False, tf_preprocessing=False, + separate_transforms=False, ): dataset.transform = create_transform( input_size, @@ -162,6 +149,7 @@ def create_loader( std=std, crop_pct=crop_pct, tf_preprocessing=tf_preprocessing, + separate=separate_transforms, ) sampler = None diff --git a/timm/data/mixup.py b/timm/data/mixup.py index 83d51ccb..4678472d 100644 --- a/timm/data/mixup.py +++ b/timm/data/mixup.py @@ -15,6 +15,15 @@ def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'): return lam*y1 + (1. - lam)*y2 +def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False): + lam = 1. + if not disable: + lam = np.random.beta(alpha, alpha) + input = input.mul(lam).add_(1 - lam, input.flip(0)) + target = mixup_target(target, num_classes, lam, smoothing) + return input, target + + class FastCollateMixup: def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000): diff --git a/timm/data/transforms.py b/timm/data/transforms.py index 41f2a63e..b3b08e30 100644 --- a/timm/data/transforms.py +++ b/timm/data/transforms.py @@ -1,5 +1,4 @@ import torch -from torchvision import transforms import torchvision.transforms.functional as F from PIL import Image import warnings @@ -7,10 +6,6 @@ import math import random import numpy as np -from .constants import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .random_erasing import RandomErasing -from .auto_augment import auto_augment_transform, rand_augment_transform - class ToNumpy: @@ -161,97 +156,3 @@ class RandomResizedCropAndInterpolation: return format_string -def transforms_imagenet_train( - img_size=224, - scale=(0.08, 1.0), - color_jitter=0.4, - auto_augment=None, - interpolation='random', - random_erasing=0.4, - random_erasing_mode='const', - use_prefetcher=False, - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD -): - tfl = [ - RandomResizedCropAndInterpolation( - img_size, scale=scale, interpolation=interpolation), - transforms.RandomHorizontalFlip() - ] - if auto_augment: - assert isinstance(auto_augment, str) - if isinstance(img_size, tuple): - img_size_min = min(img_size) - else: - img_size_min = img_size - aa_params = dict( - translate_const=int(img_size_min * 0.45), - img_mean=tuple([min(255, round(255 * x)) for x in mean]), - ) - if interpolation and interpolation != 'random': - aa_params['interpolation'] = _pil_interp(interpolation) - if auto_augment.startswith('rand'): - tfl += [rand_augment_transform(auto_augment, aa_params)] - else: - tfl += [auto_augment_transform(auto_augment, aa_params)] - else: - # color jitter is enabled when not using AA - if isinstance(color_jitter, (list, tuple)): - # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation - # or 4 if also augmenting hue - assert len(color_jitter) in (3, 4) - else: - # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue - color_jitter = (float(color_jitter),) * 3 - tfl += [transforms.ColorJitter(*color_jitter)] - - if use_prefetcher: - # prefetcher and collate will handle tensor conversion and norm - tfl += [ToNumpy()] - else: - tfl += [ - transforms.ToTensor(), - transforms.Normalize( - mean=torch.tensor(mean), - std=torch.tensor(std)) - ] - if random_erasing > 0.: - tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu')) - return transforms.Compose(tfl) - - -def transforms_imagenet_eval( - img_size=224, - crop_pct=None, - interpolation='bilinear', - use_prefetcher=False, - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD): - crop_pct = crop_pct or DEFAULT_CROP_PCT - - if isinstance(img_size, tuple): - assert len(img_size) == 2 - if img_size[-1] == img_size[-2]: - # fall-back to older behaviour so Resize scales to shortest edge if target is square - scale_size = int(math.floor(img_size[0] / crop_pct)) - else: - scale_size = tuple([int(x / crop_pct) for x in img_size]) - else: - scale_size = int(math.floor(img_size / crop_pct)) - - tfl = [ - transforms.Resize(scale_size, _pil_interp(interpolation)), - transforms.CenterCrop(img_size), - ] - if use_prefetcher: - # prefetcher and collate will handle tensor conversion and norm - tfl += [ToNumpy()] - else: - tfl += [ - transforms.ToTensor(), - transforms.Normalize( - mean=torch.tensor(mean), - std=torch.tensor(std)) - ] - - return transforms.Compose(tfl) diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py new file mode 100644 index 00000000..b70ae76f --- /dev/null +++ b/timm/data/transforms_factory.py @@ -0,0 +1,164 @@ +import math + +import torch +from torchvision import transforms + +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT +from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform +from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor +from timm.data.random_erasing import RandomErasing + + +def transforms_imagenet_train( + img_size=224, + scale=(0.08, 1.0), + color_jitter=0.4, + auto_augment=None, + interpolation='random', + random_erasing=0.4, + random_erasing_mode='const', + use_prefetcher=False, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + separate=False, +): + + primary_tfl = [ + RandomResizedCropAndInterpolation( + img_size, scale=scale, interpolation=interpolation), + transforms.RandomHorizontalFlip() + ] + + secondary_tfl = [] + if auto_augment: + assert isinstance(auto_augment, str) + if isinstance(img_size, tuple): + img_size_min = min(img_size) + else: + img_size_min = img_size + aa_params = dict( + translate_const=int(img_size_min * 0.45), + img_mean=tuple([min(255, round(255 * x)) for x in mean]), + ) + if interpolation and interpolation != 'random': + aa_params['interpolation'] = _pil_interp(interpolation) + if auto_augment.startswith('rand'): + secondary_tfl += [rand_augment_transform(auto_augment, aa_params)] + elif auto_augment.startswith('augmix'): + aa_params['translate_pct'] = 0.3 + secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)] + else: + secondary_tfl += [auto_augment_transform(auto_augment, aa_params)] + elif color_jitter is not None: + # color jitter is enabled when not using AA + if isinstance(color_jitter, (list, tuple)): + # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation + # or 4 if also augmenting hue + assert len(color_jitter) in (3, 4) + else: + # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue + color_jitter = (float(color_jitter),) * 3 + secondary_tfl += [transforms.ColorJitter(*color_jitter)] + + final_tfl = [] + if use_prefetcher: + # prefetcher and collate will handle tensor conversion and norm + final_tfl += [ToNumpy()] + else: + final_tfl += [ + transforms.ToTensor(), + transforms.Normalize( + mean=torch.tensor(mean), + std=torch.tensor(std)) + ] + if random_erasing > 0.: + final_tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu')) + + if separate: + return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl) + else: + return transforms.Compose(primary_tfl + secondary_tfl + final_tfl) + + +def transforms_imagenet_eval( + img_size=224, + crop_pct=None, + interpolation='bilinear', + use_prefetcher=False, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD): + crop_pct = crop_pct or DEFAULT_CROP_PCT + + if isinstance(img_size, tuple): + assert len(img_size) == 2 + if img_size[-1] == img_size[-2]: + # fall-back to older behaviour so Resize scales to shortest edge if target is square + scale_size = int(math.floor(img_size[0] / crop_pct)) + else: + scale_size = tuple([int(x / crop_pct) for x in img_size]) + else: + scale_size = int(math.floor(img_size / crop_pct)) + + tfl = [ + transforms.Resize(scale_size, _pil_interp(interpolation)), + transforms.CenterCrop(img_size), + ] + if use_prefetcher: + # prefetcher and collate will handle tensor conversion and norm + tfl += [ToNumpy()] + else: + tfl += [ + transforms.ToTensor(), + transforms.Normalize( + mean=torch.tensor(mean), + std=torch.tensor(std)) + ] + + return transforms.Compose(tfl) + + +def create_transform( + input_size, + is_training=False, + use_prefetcher=False, + color_jitter=0.4, + auto_augment=None, + interpolation='bilinear', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + crop_pct=None, + tf_preprocessing=False, + separate=False): + + if isinstance(input_size, tuple): + img_size = input_size[-2:] + else: + img_size = input_size + + if tf_preprocessing and use_prefetcher: + assert not separate, "Separate transforms not supported for TF preprocessing" + from timm.data.tf_preprocessing import TfPreprocessTransform + transform = TfPreprocessTransform( + is_training=is_training, size=img_size, interpolation=interpolation) + else: + if is_training: + transform = transforms_imagenet_train( + img_size, + color_jitter=color_jitter, + auto_augment=auto_augment, + interpolation=interpolation, + use_prefetcher=use_prefetcher, + mean=mean, + std=std, + separate=separate) + else: + assert not separate, "Separate transforms not supported for validation preprocessing" + transform = transforms_imagenet_eval( + img_size, + interpolation=interpolation, + use_prefetcher=use_prefetcher, + mean=mean, + std=std, + crop_pct=crop_pct) + + return transform diff --git a/timm/loss/__init__.py b/timm/loss/__init__.py index f436ccc7..b781472f 100644 --- a/timm/loss/__init__.py +++ b/timm/loss/__init__.py @@ -1 +1,2 @@ from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy +from .jsd import JsdCrossEntropy \ No newline at end of file diff --git a/timm/loss/jsd.py b/timm/loss/jsd.py new file mode 100644 index 00000000..0f99c699 --- /dev/null +++ b/timm/loss/jsd.py @@ -0,0 +1,34 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .cross_entropy import LabelSmoothingCrossEntropy + + +class JsdCrossEntropy(nn.Module): + """ Jenson-Shannon Divergence + Cross-Entropy Loss + + """ + def __init__(self, num_splits=3, alpha=12, smoothing=0.1): + super().__init__() + self.num_splits = num_splits + self.alpha = alpha + if smoothing is not None and smoothing > 0: + self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing) + else: + self.cross_entropy_loss = torch.nn.CrossEntropyLoss() + + def __call__(self, output, target): + split_size = output.shape[0] // self.num_splits + assert split_size * self.num_splits == output.shape[0] + logits_split = torch.split(output, split_size) + + # Cross-entropy is only computed on clean images + loss = self.cross_entropy_loss(logits_split[0], target[:split_size]) + probs = [F.softmax(logits, dim=1) for logits in logits_split] + + # Clamp mixture distribution to avoid exploding KL divergence + logp_mixture = torch.clamp(torch.stack(probs).mean(axis=0), 1e-7, 1).log() + loss += self.alpha * sum([F.kl_div( + logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs) + return loss diff --git a/train.py b/train.py index a47f1b4d..9910a059 100644 --- a/train.py +++ b/train.py @@ -1,7 +1,6 @@ import argparse import time -import logging import yaml from datetime import datetime @@ -14,13 +13,16 @@ except ImportError: from torch.nn.parallel import DistributedDataParallel as DDP has_apex = False -from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target +from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_batch from timm.models import create_model, resume_checkpoint from timm.utils import * -from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy +from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy from timm.optim import create_optimizer from timm.scheduler import create_scheduler +#FIXME +from timm.data.dataset import AugMixDataset + import torch import torch.nn as nn import torchvision.utils @@ -160,6 +162,10 @@ parser.add_argument('--tta', type=int, default=0, metavar='N', parser.add_argument("--local_rank", default=0, type=int) +parser.add_argument('--jsd', action='store_true', default=False, + help='') + + def _parse_args(): # Do we have a config file to parse? args_config, remaining = config_parser.parse_known_args() @@ -311,8 +317,14 @@ def main(): collate_fn = None if args.prefetcher and args.mixup > 0: + assert not args.jsd collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes) + separate_transforms = False + if args.jsd: + dataset_train = AugMixDataset(dataset_train) + separate_transforms = True + loader_train = create_loader( dataset_train, input_size=data_config['input_size'], @@ -330,6 +342,7 @@ def main(): num_workers=args.workers, distributed=args.distributed, collate_fn=collate_fn, + separate_transforms=separate_transforms, ) eval_dir = os.path.join(args.data, 'val') @@ -354,7 +367,10 @@ def main(): crop_pct=data_config['crop_pct'], ) - if args.mixup > 0.: + if args.jsd: + train_loss_fn = JsdCrossEntropy(smoothing=args.smoothing).cuda() + validate_loss_fn = nn.CrossEntropyLoss().cuda() + elif args.mixup > 0.: # smoothing is handled with mixup label transform train_loss_fn = SoftTargetCrossEntropy().cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() @@ -452,11 +468,10 @@ def train_epoch( if not args.prefetcher: input, target = input.cuda(), target.cuda() if args.mixup > 0.: - lam = 1. - if not args.mixup_off_epoch or epoch < args.mixup_off_epoch: - lam = np.random.beta(args.mixup, args.mixup) - input = input.mul(lam).add_(1 - lam, input.flip(0)) - target = mixup_target(target, args.num_classes, lam, args.smoothing) + input, target = mixup_batch( + input, target, + alpha=args.mixup, num_classes=args.num_classes, smoothing=args.smoothing, + disable=args.mixup_off_epoch and epoch >= args.mixup_off_epoch) output = model(input) From 3afc2a4dc0db55919a897f8e2af8aeb315b10703 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 20 Dec 2019 22:32:05 -0800 Subject: [PATCH 2/7] Some cleanup/improvements to AugMix impl: * make 'increasing' levels for Contrast, Color, Brightness, Saturation ops * remove recursion from faster blending mix * add config striing parsing for AugMix --- timm/data/auto_augment.py | 146 +++++++++++++++++++++++++++++--------- 1 file changed, 112 insertions(+), 34 deletions(-) diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index cc1b716c..864ca6e0 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -177,6 +177,14 @@ def _enhance_level_to_arg(level, _hparams): return (level / _MAX_LEVEL) * 1.8 + 0.1, +def _enhance_increasing_level_to_arg(level, _hparams): + # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend + # range [0.1, 1.9] + level = (level / _MAX_LEVEL) * .9 + level = 1.0 + _randomly_negate(level) + return level, + + def _shear_level_to_arg(level, _hparams): # range [-0.3, 0.3] level = (level / _MAX_LEVEL) * 0.3 @@ -247,12 +255,16 @@ LEVEL_TO_ARG = { 'PosterizeIncreasing': _posterize_increasing_level_to_arg, 'PosterizeOriginal': _posterize_original_level_to_arg, 'Solarize': _solarize_level_to_arg, - 'SolarizeIncreasing': _solarize_level_to_arg, + 'SolarizeIncreasing': _solarize_increasing_level_to_arg, 'SolarizeAdd': _solarize_add_level_to_arg, 'Color': _enhance_level_to_arg, + 'ColorIncreasing': _enhance_increasing_level_to_arg, 'Contrast': _enhance_level_to_arg, + 'ContrastIncreasing': _enhance_increasing_level_to_arg, 'Brightness': _enhance_level_to_arg, + 'BrightnessIncreasing': _enhance_increasing_level_to_arg, 'Sharpness': _enhance_level_to_arg, + 'SharpnessIncreasing': _enhance_increasing_level_to_arg, 'ShearX': _shear_level_to_arg, 'ShearY': _shear_level_to_arg, 'TranslateX': _translate_abs_level_to_arg, @@ -274,9 +286,13 @@ NAME_TO_OP = { 'SolarizeIncreasing': solarize, 'SolarizeAdd': solarize_add, 'Color': color, + 'ColorIncreasing': color, 'Contrast': contrast, + 'ContrastIncreasing': contrast, 'Brightness': brightness, + 'BrightnessIncreasing': brightness, 'Sharpness': sharpness, + 'SharpnessIncreasing': sharpness, 'ShearX': shear_x, 'ShearY': shear_y, 'TranslateX': translate_x_abs, @@ -527,6 +543,27 @@ _RAND_TRANSFORMS = [ ] +_RAND_INCREASING_TRANSFORMS = [ + 'AutoContrast', + 'Equalize', + 'Invert', + 'Rotate', + 'PosterizeIncreasing', + 'SolarizeIncreasing', + 'SolarizeAdd', + 'ColorIncreasing', + 'ContrastIncreasing', + 'BrightnessIncreasing', + 'SharpnessIncreasing', + 'ShearX', + 'ShearY', + 'TranslateXRel', + 'TranslateYRel', + #'Cutout' # FIXME I implement this as random erasing separately +] + + + # These experimental weights are based loosely on the relative improvements mentioned in paper. # They may not result in increased performance, but could likely be tuned to so. _RAND_CHOICE_WEIGHTS_0 = { @@ -626,9 +663,10 @@ def rand_augment_transform(config_str, hparams): _AUGMIX_TRANSFORMS = [ 'AutoContrast', - 'Contrast', # not in paper - 'Brightness', # not in paper - 'Sharpness', # not in paper + 'ColorIncreasing', # not in paper + 'ContrastIncreasing', # not in paper + 'BrightnessIncreasing', # not in paper + 'SharpnessIncreasing', # not in paper 'Equalize', 'Rotate', 'PosterizeIncreasing', @@ -653,21 +691,38 @@ class AugMixAugment: self.alpha = alpha self.width = width self.depth = depth - self.recursive = True - - def _apply_recursive(self, img, ws, prod=1.): - alpha = ws[-1] / prod - if len(ws) > 1: - img = self._apply_recursive(img, ws[:-1], prod * (1 - alpha)) - - depth = self.depth if self.depth > 0 else np.random.randint(1, 4) - ops = np.random.choice(self.ops, depth, replace=True) - img_aug = img # no ops are in-place, deep copy not necessary - for op in ops: - img_aug = op(img_aug) - return Image.blend(img, img_aug, alpha) + self.blended = True + + def _calc_blended_weights(self, ws, m): + ws = ws * m + cump = 1. + rws = [] + for w in ws[::-1]: + alpha = w / cump + cump *= (1 - alpha) + rws.append(alpha) + return np.array(rws[::-1], dtype=np.float32) + + def _apply_blended(self, img, ws, m): + # This is my first crack and implementing a slightly faster mixed augmentation. Instead + # of accumulating the mix for each chain in a Numpy array and then blending with original, + # it recomputes the blending coefficients and applies one PIL image blend per chain. + # TODO I've verified the results are in the right ballpark but they differ by more than rounding. + img_orig = img.copy() + ws = self._calc_blended_weights(ws, m) + for w in ws: + depth = self.depth if self.depth > 0 else np.random.randint(1, 4) + ops = np.random.choice(self.ops, depth, replace=True) + img_aug = img_orig # no ops are in-place, deep copy not necessary + for op in ops: + img_aug = op(img_aug) + img = Image.blend(img, img_aug, w) + return img def _apply_basic(self, img, ws, m): + # This is a literal adaptation of the paper/official implementation without normalizations and + # PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the + # typical augmentation transforms, could use a GPU / Kornia implementation. w, h = img.size c = len(img.getbands()) mixed = np.zeros((w, h, c), dtype=np.float32) @@ -686,30 +741,53 @@ class AugMixAugment: def __call__(self, img): mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width)) m = np.float32(np.random.beta(self.alpha, self.alpha)) - if self.recursive: - mixing_weights *= m - mixed = self._apply_recursive(img, mixing_weights) + if self.blended: + mixed = self._apply_blended(img, mixing_weights, m) else: mixed = self._apply_basic(img, mixing_weights, m) return mixed def augment_and_mix_transform(config_str, hparams): - """Perform AugMix augmentations and compute mixture. - Args: - image: Raw input image as float32 np.ndarray of shape (h, w, c) - severity: Severity of underlying augmentation operators (between 1 to 10). - width: Width of augmentation chain - depth: Depth of augmentation chain. -1 enables stochastic depth uniformly - from [1, 3] - alpha: Probability coefficient for Beta and Dirichlet distributions. - Returns: - mixed: Augmented and mixed image. + """ Create AugMix PyTorch transform + + :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by + dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining + sections, not order sepecific determine + 'm' - integer magnitude (severity) of augmentation mix (default: 3) + 'w' - integer width of augmentation chain (default: 3) + 'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1) + 'mstd' - float std deviation of magnitude noise applied (default: 0) + Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2 + + :param hparams: Other hparams (kwargs) for the Augmentation transforms + + :return: A PyTorch compatible Transform """ - # FIXME parse args from config str - severity = 3 + magnitude = 3 width = 3 depth = -1 alpha = 1. - ops = augmix_ops(magnitude=severity, hparams=hparams) - return AugMixAugment(ops, alpha, width, depth) + config = config_str.split('-') + assert config[0] == 'augmix' + config = config[1:] + for c in config: + cs = re.split(r'(\d.*)', c) + if len(cs) < 2: + continue + key, val = cs[:2] + if key == 'mstd': + # noise param injected via hparams for now + hparams.setdefault('magnitude_std', float(val)) + elif key == 'm': + magnitude = int(val) + elif key == 'w': + width = int(val) + elif key == 'd': + depth = int(val) + elif key == 'a': + alpha = float(val) + else: + assert False, 'Unknown AugMix config section' + ops = augmix_ops(magnitude=magnitude, hparams=hparams) + return AugMixAugment(ops, alpha=alpha, width=width, depth=depth) From 3cc0f91e232e4c71d37834e81079ee7ad5fbeadd Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 5 Jan 2020 14:27:27 -0800 Subject: [PATCH 3/7] Fix augmix variable name scope overlap, default non-blended mode --- timm/data/auto_augment.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index 864ca6e0..ec2602b3 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -691,7 +691,7 @@ class AugMixAugment: self.alpha = alpha self.width = width self.depth = depth - self.blended = True + self.blended = False def _calc_blended_weights(self, ws, m): ws = ws * m @@ -703,13 +703,13 @@ class AugMixAugment: rws.append(alpha) return np.array(rws[::-1], dtype=np.float32) - def _apply_blended(self, img, ws, m): + def _apply_blended(self, img, mixing_weights, m): # This is my first crack and implementing a slightly faster mixed augmentation. Instead # of accumulating the mix for each chain in a Numpy array and then blending with original, # it recomputes the blending coefficients and applies one PIL image blend per chain. # TODO I've verified the results are in the right ballpark but they differ by more than rounding. img_orig = img.copy() - ws = self._calc_blended_weights(ws, m) + ws = self._calc_blended_weights(mixing_weights, m) for w in ws: depth = self.depth if self.depth > 0 else np.random.randint(1, 4) ops = np.random.choice(self.ops, depth, replace=True) @@ -719,21 +719,19 @@ class AugMixAugment: img = Image.blend(img, img_aug, w) return img - def _apply_basic(self, img, ws, m): + def _apply_basic(self, img, mixing_weights, m): # This is a literal adaptation of the paper/official implementation without normalizations and # PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the # typical augmentation transforms, could use a GPU / Kornia implementation. - w, h = img.size - c = len(img.getbands()) - mixed = np.zeros((w, h, c), dtype=np.float32) - for w in ws: + img_shape = img.size[0], img.size[1], len(img.getbands()) + mixed = np.zeros(img_shape, dtype=np.float32) + for mw in mixing_weights: depth = self.depth if self.depth > 0 else np.random.randint(1, 4) ops = np.random.choice(self.ops, depth, replace=True) img_aug = img # no ops are in-place, deep copy not necessary for op in ops: img_aug = op(img_aug) - img_aug = np.asarray(img_aug, dtype=np.float32) - mixed += w * img_aug + mixed += mw * np.asarray(img_aug, dtype=np.float32) np.clip(mixed, 0, 255., out=mixed) mixed = Image.fromarray(mixed.astype(np.uint8)) return Image.blend(img, mixed, m) From 2e955cfd0cb750526db2bc02363f42334fc88626 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 5 Jan 2020 14:31:48 -0800 Subject: [PATCH 4/7] Update RandomErasing with some improved arg names, tweak to aspect range --- timm/data/random_erasing.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/timm/data/random_erasing.py b/timm/data/random_erasing.py index 5eed1387..2b6b61a5 100644 --- a/timm/data/random_erasing.py +++ b/timm/data/random_erasing.py @@ -23,13 +23,13 @@ class RandomErasing: This variant of RandomErasing is intended to be applied to either a batch or single image tensor after it has been normalized by dataset mean and std. Args: - probability: The probability that the Random Erasing operation will be performed. - sl: Minimum proportion of erased area against input image. - sh: Maximum proportion of erased area against input image. + probability: Probability that the Random Erasing operation will be performed. + min_area: Minimum percentage of erased area wrt input image area. + max_area: Maximum percentage of erased area wrt input image area. min_aspect: Minimum aspect ratio of erased area. mode: pixel color mode, one of 'const', 'rand', or 'pixel' 'const' - erase block is constant color of 0 for all channels - 'rand' - erase block is same per-cannel random (normal) color + 'rand' - erase block is same per-channel random (normal) color 'pixel' - erase block is per-pixel random (normal) color max_count: maximum number of erasing blocks per image, area per box is scaled by count. per-image count is randomly chosen between 1 and this value. @@ -37,14 +37,15 @@ class RandomErasing: def __init__( self, - probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3, - mode='const', max_count=1, device='cuda'): + probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None, + mode='const', min_count=1, max_count=None, device='cuda'): self.probability = probability - self.sl = sl - self.sh = sh - self.min_aspect = min_aspect - self.min_count = 1 - self.max_count = max_count + self.min_area = min_area + self.max_area = max_area + max_aspect = max_aspect or 1 / min_aspect + self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) + self.min_count = min_count + self.max_count = max_count or min_count mode = mode.lower() self.rand_color = False self.per_pixel = False @@ -64,9 +65,8 @@ class RandomErasing: random.randint(self.min_count, self.max_count) for _ in range(count): for attempt in range(10): - target_area = random.uniform(self.sl, self.sh) * area / count - log_ratio = (math.log(self.min_aspect), math.log(1 / self.min_aspect)) - aspect_ratio = math.exp(random.uniform(*log_ratio)) + target_area = random.uniform(self.min_area, self.max_area) * area / count + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) h = int(round(math.sqrt(target_area * aspect_ratio))) w = int(round(math.sqrt(target_area / aspect_ratio))) if w < img_w and h < img_h: From 75471198915d53e1237d3fb1ff95f7abed5032ea Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 5 Jan 2020 19:58:59 -0800 Subject: [PATCH 5/7] Add SplitBatchNorm. AugMix, Rand/AutoAugment, Split (Aux) BatchNorm, Jensen-Shannon Divergence, RandomErasing all working together --- timm/data/__init__.py | 2 +- timm/data/auto_augment.py | 25 ++++++++----- timm/data/dataset.py | 11 +++--- timm/data/loader.py | 50 ++++++++++++++++---------- timm/data/random_erasing.py | 7 ++-- timm/data/transforms_factory.py | 19 +++++++--- timm/loss/jsd.py | 2 +- timm/models/__init__.py | 1 + timm/models/split_batchnorm.py | 64 +++++++++++++++++++++++++++++++++ train.py | 52 ++++++++++++++++----------- 10 files changed, 173 insertions(+), 60 deletions(-) create mode 100644 timm/models/split_batchnorm.py diff --git a/timm/data/__init__.py b/timm/data/__init__.py index 66c16257..ee2240b4 100644 --- a/timm/data/__init__.py +++ b/timm/data/__init__.py @@ -1,6 +1,6 @@ from .constants import * from .config import resolve_data_config -from .dataset import Dataset, DatasetTar +from .dataset import Dataset, DatasetTar, AugMixDataset from .transforms import * from .loader import create_loader from .transforms_factory import create_transform diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index ec2602b3..8d7b36f9 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -323,7 +323,7 @@ class AugmentOp: self.magnitude_std = self.hparams.get('magnitude_std', 0) def __call__(self, img): - if not self.prob >= 1.0 or random.random() > self.prob: + if self.prob < 1.0 and random.random() > self.prob: return img magnitude = self.magnitude if self.magnitude_std and self.magnitude_std > 0: @@ -539,7 +539,7 @@ _RAND_TRANSFORMS = [ 'ShearY', 'TranslateXRel', 'TranslateYRel', - #'Cutout' # FIXME I implement this as random erasing separately + #'Cutout' # NOTE I've implement this as random erasing separately ] @@ -559,7 +559,7 @@ _RAND_INCREASING_TRANSFORMS = [ 'ShearY', 'TranslateXRel', 'TranslateYRel', - #'Cutout' # FIXME I implement this as random erasing separately + #'Cutout' # NOTE I've implement this as random erasing separately ] @@ -627,6 +627,7 @@ def rand_augment_transform(config_str, hparams): 'n' - integer num layers (number of transform ops selected per image) 'w' - integer probabiliy weight index (index of a set of weights to influence choice of op) 'mstd' - float std deviation of magnitude noise applied + 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 @@ -637,6 +638,7 @@ def rand_augment_transform(config_str, hparams): magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) num_layers = 2 # default to 2 ops per image weight_idx = None # default to no probability weights for op choice + transforms = _RAND_TRANSFORMS config = config_str.split('-') assert config[0] == 'rand' config = config[1:] @@ -648,6 +650,9 @@ def rand_augment_transform(config_str, hparams): if key == 'mstd': # noise param injected via hparams for now hparams.setdefault('magnitude_std', float(val)) + elif key == 'inc': + if bool(val): + transforms = _RAND_INCREASING_TRANSFORMS elif key == 'm': magnitude = int(val) elif key == 'n': @@ -656,7 +661,7 @@ def rand_augment_transform(config_str, hparams): weight_idx = int(val) else: assert False, 'Unknown RandAugment config section' - ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams) + ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms) choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx) return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) @@ -686,12 +691,12 @@ def augmix_ops(magnitude=10, hparams=None, transforms=None): class AugMixAugment: - def __init__(self, ops, alpha=1., width=3, depth=-1): + def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False): self.ops = ops self.alpha = alpha self.width = width self.depth = depth - self.blended = False + self.blended = blended def _calc_blended_weights(self, ws, m): ws = ws * m @@ -707,7 +712,7 @@ class AugMixAugment: # This is my first crack and implementing a slightly faster mixed augmentation. Instead # of accumulating the mix for each chain in a Numpy array and then blending with original, # it recomputes the blending coefficients and applies one PIL image blend per chain. - # TODO I've verified the results are in the right ballpark but they differ by more than rounding. + # TODO the results appear in the right ballpark but they differ by more than rounding. img_orig = img.copy() ws = self._calc_blended_weights(mixing_weights, m) for w in ws: @@ -755,6 +760,7 @@ def augment_and_mix_transform(config_str, hparams): 'm' - integer magnitude (severity) of augmentation mix (default: 3) 'w' - integer width of augmentation chain (default: 3) 'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1) + 'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0) 'mstd' - float std deviation of magnitude noise applied (default: 0) Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2 @@ -766,6 +772,7 @@ def augment_and_mix_transform(config_str, hparams): width = 3 depth = -1 alpha = 1. + blended = False config = config_str.split('-') assert config[0] == 'augmix' config = config[1:] @@ -785,7 +792,9 @@ def augment_and_mix_transform(config_str, hparams): depth = int(val) elif key == 'a': alpha = float(val) + elif key == 'b': + blended = bool(val) else: assert False, 'Unknown AugMix config section' ops = augmix_ops(magnitude=magnitude, hparams=hparams) - return AugMixAugment(ops, alpha=alpha, width=width, depth=depth) + return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended) diff --git a/timm/data/dataset.py b/timm/data/dataset.py index 3220883a..fc252d9e 100644 --- a/timm/data/dataset.py +++ b/timm/data/dataset.py @@ -144,13 +144,13 @@ class DatasetTar(data.Dataset): class AugMixDataset(torch.utils.data.Dataset): """Dataset wrapper to perform AugMix or other clean/augmentation mixes""" - def __init__(self, dataset, num_aug=2): + def __init__(self, dataset, num_splits=2): self.augmentation = None self.normalize = None self.dataset = dataset if self.dataset.transform is not None: self._set_transforms(self.dataset.transform) - self.num_aug = num_aug + self.num_splits = num_splits def _set_transforms(self, x): assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms' @@ -170,9 +170,10 @@ class AugMixDataset(torch.utils.data.Dataset): return x if self.normalize is None else self.normalize(x) def __getitem__(self, i): - x, y = self.dataset[i] - x_list = [self._normalize(x)] - for n in range(self.num_aug): + x, y = self.dataset[i] # all splits share the same dataset base transform + x_list = [self._normalize(x)] # first split only normalizes (this is the 'clean' split) + # run the full augmentation on the remaining splits + for _ in range(self.num_splits - 1): x_list.append(self._normalize(self.augmentation(x))) return tuple(x_list), y diff --git a/timm/data/loader.py b/timm/data/loader.py index 06d431ee..e2ec8797 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -15,7 +15,7 @@ def fast_collate(batch): if isinstance(batch[0][0], tuple): # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position # such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position - inner_tuple_size = len(batch[0][0][0]) + inner_tuple_size = len(batch[0][0]) flattened_batch_size = batch_size * inner_tuple_size targets = torch.zeros(flattened_batch_size, dtype=torch.int64) tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8) @@ -46,13 +46,14 @@ def fast_collate(batch): class PrefetchLoader: def __init__(self, - loader, - rand_erase_prob=0., - rand_erase_mode='const', - rand_erase_count=1, - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, - fp16=False): + loader, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + fp16=False, + re_prob=0., + re_mode='const', + re_count=1, + re_num_splits=0): self.loader = loader self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1) self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1) @@ -60,9 +61,9 @@ class PrefetchLoader: if fp16: self.mean = self.mean.half() self.std = self.std.half() - if rand_erase_prob > 0.: + if re_prob > 0.: self.random_erasing = RandomErasing( - probability=rand_erase_prob, mode=rand_erase_mode, max_count=rand_erase_count) + probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits) else: self.random_erasing = None @@ -122,11 +123,13 @@ def create_loader( batch_size, is_training=False, use_prefetcher=True, - rand_erase_prob=0., - rand_erase_mode='const', - rand_erase_count=1, + re_prob=0., + re_mode='const', + re_count=1, + re_split=False, color_jitter=0.4, auto_augment=None, + num_aug_splits=0, interpolation='bilinear', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, @@ -136,8 +139,11 @@ def create_loader( collate_fn=None, fp16=False, tf_preprocessing=False, - separate_transforms=False, ): + re_num_splits = 0 + if re_split: + # apply RE to second half of batch if no aug split otherwise line up with aug split + re_num_splits = num_aug_splits or 2 dataset.transform = create_transform( input_size, is_training=is_training, @@ -149,7 +155,11 @@ def create_loader( std=std, crop_pct=crop_pct, tf_preprocessing=tf_preprocessing, - separate=separate_transforms, + re_prob=re_prob, + re_mode=re_mode, + re_count=re_count, + re_num_splits=re_num_splits, + separate=num_aug_splits > 0, ) sampler = None @@ -176,11 +186,13 @@ def create_loader( if use_prefetcher: loader = PrefetchLoader( loader, - rand_erase_prob=rand_erase_prob if is_training else 0., - rand_erase_mode=rand_erase_mode, - rand_erase_count=rand_erase_count, mean=mean, std=std, - fp16=fp16) + fp16=fp16, + re_prob=re_prob if is_training else 0., + re_mode=re_mode, + re_count=re_count, + re_num_splits=re_num_splits + ) return loader diff --git a/timm/data/random_erasing.py b/timm/data/random_erasing.py index 2b6b61a5..589b2f0b 100644 --- a/timm/data/random_erasing.py +++ b/timm/data/random_erasing.py @@ -38,7 +38,7 @@ class RandomErasing: def __init__( self, probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None, - mode='const', min_count=1, max_count=None, device='cuda'): + mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'): self.probability = probability self.min_area = min_area self.max_area = max_area @@ -46,6 +46,7 @@ class RandomErasing: self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) self.min_count = min_count self.max_count = max_count or min_count + self.num_splits = num_splits mode = mode.lower() self.rand_color = False self.per_pixel = False @@ -82,6 +83,8 @@ class RandomErasing: self._erase(input, *input.size(), input.dtype) else: batch_size, chan, img_h, img_w = input.size() - for i in range(batch_size): + # skip first slice of batch if num_splits is set (for clean portion of samples) + batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0 + for i in range(batch_start, batch_size): self._erase(input[i], chan, img_h, img_w, input.dtype) return input diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py index b70ae76f..faf55b70 100644 --- a/timm/data/transforms_factory.py +++ b/timm/data/transforms_factory.py @@ -15,11 +15,13 @@ def transforms_imagenet_train( color_jitter=0.4, auto_augment=None, interpolation='random', - random_erasing=0.4, - random_erasing_mode='const', use_prefetcher=False, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + re_prob=0., + re_mode='const', + re_count=1, + re_num_splits=0, separate=False, ): @@ -71,8 +73,9 @@ def transforms_imagenet_train( mean=torch.tensor(mean), std=torch.tensor(std)) ] - if random_erasing > 0.: - final_tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu')) + if re_prob > 0.: + final_tfl.append( + RandomErasing(re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device='cpu')) if separate: return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl) @@ -126,6 +129,10 @@ def create_transform( interpolation='bilinear', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + re_prob=0., + re_mode='const', + re_count=1, + re_num_splits=0, crop_pct=None, tf_preprocessing=False, separate=False): @@ -150,6 +157,10 @@ def create_transform( use_prefetcher=use_prefetcher, mean=mean, std=std, + re_prob=re_prob, + re_mode=re_mode, + re_count=re_count, + re_num_splits=re_num_splits, separate=separate) else: assert not separate, "Separate transforms not supported for validation preprocessing" diff --git a/timm/loss/jsd.py b/timm/loss/jsd.py index 0f99c699..ad6ca1e5 100644 --- a/timm/loss/jsd.py +++ b/timm/loss/jsd.py @@ -6,7 +6,7 @@ from .cross_entropy import LabelSmoothingCrossEntropy class JsdCrossEntropy(nn.Module): - """ Jenson-Shannon Divergence + Cross-Entropy Loss + """ Jensen-Shannon Divergence + Cross-Entropy Loss """ def __init__(self, num_splits=3, alpha=12, smoothing=0.1): diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 7119c4f5..a0be7bd0 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -20,3 +20,4 @@ from .registry import * from .factory import create_model from .helpers import load_checkpoint, resume_checkpoint from .test_time_pool import TestTimePoolHead, apply_test_time_pool +from .split_batchnorm import convert_splitbn_model diff --git a/timm/models/split_batchnorm.py b/timm/models/split_batchnorm.py new file mode 100644 index 00000000..0ed30d77 --- /dev/null +++ b/timm/models/split_batchnorm.py @@ -0,0 +1,64 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SplitBatchNorm2d(torch.nn.BatchNorm2d): + + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, + track_running_stats=True, num_splits=1): + super().__init__(num_features, eps, momentum, affine, track_running_stats) + assert num_splits >= 2, 'Should have at least one aux BN layer (num_splits at least 2)' + self.num_splits = num_splits + self.aux_bn = nn.ModuleList([ + nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)]) + + def forward(self, input: torch.Tensor): + if self.training: # aux BN only relevant while training + split_size = input.shape[0] // self.num_splits + assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits" + split_input = input.split(split_size) + x = [super().forward(split_input[0])] + for i, a in enumerate(self.aux_bn): + x.append(a(split_input[i + 1])) + return torch.cat(x, dim=0) + else: + return super().forward(input) + + +def convert_splitbn_model(module, num_splits=2): + """ + Recursively traverse module and its children to replace all instances of + ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchnorm2d`. + Args: + module (torch.nn.Module): input module + num_splits: number of separate batchnorm layers to split input across + Example:: + >>> # model is an instance of torch.nn.Module + >>> import apex + >>> sync_bn_model = timm.models.convert_splitbn_model(model, num_splits=2) + """ + mod = module + if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): + return module + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): + mod = SplitBatchNorm2d( + module.num_features, module.eps, module.momentum, module.affine, + module.track_running_stats, num_splits=num_splits) + mod.running_mean = module.running_mean + mod.running_var = module.running_var + mod.num_batches_tracked = module.num_batches_tracked + if module.affine: + mod.weight.data = module.weight.data.clone().detach() + mod.bias.data = module.bias.data.clone().detach() + for aux in mod.aux_bn: + aux.running_mean = module.running_mean.clone() + aux.running_var = module.running_var.clone() + aux.num_batches_tracked = module.num_batches_tracked.clone() + if module.affine: + aux.weight.data = module.weight.data.clone().detach() + aux.bias.data = module.bias.data.clone().detach() + for name, child in module.named_children(): + mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits)) + del module + return mod diff --git a/train.py b/train.py index 9910a059..bb6db08d 100644 --- a/train.py +++ b/train.py @@ -13,16 +13,13 @@ except ImportError: from torch.nn.parallel import DistributedDataParallel as DDP has_apex = False -from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_batch -from timm.models import create_model, resume_checkpoint +from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_batch, AugMixDataset +from timm.models import create_model, resume_checkpoint, convert_splitbn_model from timm.utils import * from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy from timm.optim import create_optimizer from timm.scheduler import create_scheduler -#FIXME -from timm.data.dataset import AugMixDataset - import torch import torch.nn as nn import torchvision.utils @@ -71,6 +68,8 @@ parser.add_argument('--drop', type=float, default=0.0, metavar='DROP', help='Dropout rate (default: 0.)') parser.add_argument('--drop-connect', type=float, default=0.0, metavar='DROP', help='Drop connect rate (default: 0.)') +parser.add_argument('--jsd', action='store_true', default=False, + help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.') # Optimizer parameters parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', help='Optimizer (default: "sgd"') @@ -106,18 +105,24 @@ parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', help='Color jitter factor (default: 0.4)') parser.add_argument('--aa', type=str, default=None, metavar='NAME', help='Use AutoAugment policy. "v0" or "original". (default: None)'), +parser.add_argument('--aug-splits', type=int, default=0, + help='Number of augmentation splits (default: 0, valid: 0 or >=2)') parser.add_argument('--reprob', type=float, default=0., metavar='PCT', help='Random erase prob (default: 0.)') parser.add_argument('--remode', type=str, default='const', help='Random erase mode (default: "const")') parser.add_argument('--recount', type=int, default=1, help='Random erase count (default: 1)') +parser.add_argument('--resplit', action='store_true', default=False, + help='Do not random erase first (clean) augmentation split') parser.add_argument('--mixup', type=float, default=0.0, help='mixup alpha, mixup enabled if > 0. (default: 0.)') parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', help='turn off mixup after this epoch, disabled if 0 (default: 0)') parser.add_argument('--smoothing', type=float, default=0.1, help='label smoothing (default: 0.1)') +parser.add_argument('--train-interpolation', type=str, default='random', + help='Training interpolation (random, bilinear, bicubic default: "random")') # Batch norm parameters (only works with gen_efficientnet based models currently) parser.add_argument('--bn-tf', action='store_true', default=False, help='Use Tensorflow BatchNorm defaults for models that support it (default: False)') @@ -129,6 +134,8 @@ parser.add_argument('--sync-bn', action='store_true', help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') parser.add_argument('--dist-bn', type=str, default='', help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') +parser.add_argument('--split-bn', action='store_true', + help='Enable separate BN layers per augmentation split.') # Model Exponential Moving Average parser.add_argument('--model-ema', action='store_true', default=False, help='Enable tracking moving average of model weights') @@ -162,10 +169,6 @@ parser.add_argument('--tta', type=int, default=0, metavar='N', parser.add_argument("--local_rank", default=0, type=int) -parser.add_argument('--jsd', action='store_true', default=False, - help='') - - def _parse_args(): # Do we have a config file to parse? args_config, remaining = config_parser.parse_known_args() @@ -233,6 +236,14 @@ def main(): data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) + num_aug_splits = 0 + if args.aug_splits: + num_aug_splits = max(args.aug_splits, 2) # split of 1 makes no sense + + if args.split_bn: + assert num_aug_splits > 1 or args.resplit + model = convert_splitbn_model(model, max(num_aug_splits, 2)) + if args.num_gpu > 1: if args.amp: logging.warning( @@ -279,6 +290,7 @@ def main(): if args.distributed: if args.sync_bn: + assert not args.split_bn try: if has_apex: model = convert_syncbn_model(model) @@ -317,13 +329,11 @@ def main(): collate_fn = None if args.prefetcher and args.mixup > 0: - assert not args.jsd + assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes) - separate_transforms = False - if args.jsd: - dataset_train = AugMixDataset(dataset_train) - separate_transforms = True + if num_aug_splits > 1: + dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) loader_train = create_loader( dataset_train, @@ -331,18 +341,19 @@ def main(): batch_size=args.batch_size, is_training=True, use_prefetcher=args.prefetcher, - rand_erase_prob=args.reprob, - rand_erase_mode=args.remode, - rand_erase_count=args.recount, + re_prob=args.reprob, + re_mode=args.remode, + re_count=args.recount, + re_split=args.resplit, color_jitter=args.color_jitter, auto_augment=args.aa, - interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'], + num_aug_splits=num_aug_splits, + interpolation=args.train_interpolation, mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, distributed=args.distributed, collate_fn=collate_fn, - separate_transforms=separate_transforms, ) eval_dir = os.path.join(args.data, 'val') @@ -368,7 +379,8 @@ def main(): ) if args.jsd: - train_loss_fn = JsdCrossEntropy(smoothing=args.smoothing).cuda() + assert num_aug_splits > 1 # JSD only valid with aug splits set + train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() elif args.mixup > 0.: # smoothing is handled with mixup label transform From 833066b540ca097e111dfcd4bf493cb4f0902e15 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 5 Jan 2020 20:07:03 -0800 Subject: [PATCH 6/7] A few minor things in SplitBN --- timm/models/split_batchnorm.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/timm/models/split_batchnorm.py b/timm/models/split_batchnorm.py index 0ed30d77..327c35ba 100644 --- a/timm/models/split_batchnorm.py +++ b/timm/models/split_batchnorm.py @@ -6,9 +6,9 @@ import torch.nn.functional as F class SplitBatchNorm2d(torch.nn.BatchNorm2d): def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, - track_running_stats=True, num_splits=1): + track_running_stats=True, num_splits=2): super().__init__(num_features, eps, momentum, affine, track_running_stats) - assert num_splits >= 2, 'Should have at least one aux BN layer (num_splits at least 2)' + assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)' self.num_splits = num_splits self.aux_bn = nn.ModuleList([ nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)]) @@ -35,8 +35,7 @@ def convert_splitbn_model(module, num_splits=2): num_splits: number of separate batchnorm layers to split input across Example:: >>> # model is an instance of torch.nn.Module - >>> import apex - >>> sync_bn_model = timm.models.convert_splitbn_model(model, num_splits=2) + >>> model = timm.models.convert_splitbn_model(model, num_splits=2) """ mod = module if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): From 3eb4a96edaecd272305fdf97fe50b22939153687 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 11 Jan 2020 12:02:05 -0800 Subject: [PATCH 7/7] Update AugMix, JSD, etc comments and references --- timm/data/auto_augment.py | 25 +++++++++++++++++++++---- timm/data/transforms_factory.py | 11 ++++++++++- timm/loss/jsd.py | 5 +++++ timm/models/split_batchnorm.py | 14 +++++++++++++- train.py | 5 +++-- 5 files changed, 52 insertions(+), 8 deletions(-) diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index 8d7b36f9..e355eef5 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -1,7 +1,19 @@ -""" AutoAugment and RandAugment -Implementation adapted from: +""" AutoAugment, RandAugment, and AugMix for PyTorch + +This code implements the searched ImageNet policies with various tweaks and improvements and +does not include any of the search code. + +AA and RA Implementation adapted from: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py -Papers: https://arxiv.org/abs/1805.09501, https://arxiv.org/abs/1906.11172, and https://arxiv.org/abs/1909.13719 + +AugMix adapted from: + https://github.com/google-research/augmix + +Papers: + AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501 + Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172 + RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 + AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781 Hacked together by Ross Wightman """ @@ -691,12 +703,17 @@ def augmix_ops(magnitude=10, hparams=None, transforms=None): class AugMixAugment: + """ AugMix Transform + Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py + From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - + https://arxiv.org/abs/1912.02781 + """ def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False): self.ops = ops self.alpha = alpha self.width = width self.depth = depth - self.blended = blended + self.blended = blended # blended mode is faster but not well tested def _calc_blended_weights(self, ws, m): ws = ws * m diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py index faf55b70..767dd157 100644 --- a/timm/data/transforms_factory.py +++ b/timm/data/transforms_factory.py @@ -1,3 +1,6 @@ +""" Transforms Factory +Factory methods for building image transforms for use with TIMM (PyTorch Image Models) +""" import math import torch @@ -24,7 +27,13 @@ def transforms_imagenet_train( re_num_splits=0, separate=False, ): - + """ + If separate==True, the transforms are returned as a tuple of 3 separate transforms + for use in a mixing dataset that passes + * all data through the first (primary) transform, called the 'clean' data + * a portion of the data through the secondary transform + * normalizes and converts the branches above with the third, final transform + """ primary_tfl = [ RandomResizedCropAndInterpolation( img_size, scale=scale, interpolation=interpolation), diff --git a/timm/loss/jsd.py b/timm/loss/jsd.py index ad6ca1e5..0f8eb696 100644 --- a/timm/loss/jsd.py +++ b/timm/loss/jsd.py @@ -8,6 +8,11 @@ from .cross_entropy import LabelSmoothingCrossEntropy class JsdCrossEntropy(nn.Module): """ Jensen-Shannon Divergence + Cross-Entropy Loss + Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py + From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - + https://arxiv.org/abs/1912.02781 + + Hacked together by Ross Wightman """ def __init__(self, num_splits=3, alpha=12, smoothing=0.1): super().__init__() diff --git a/timm/models/split_batchnorm.py b/timm/models/split_batchnorm.py index 327c35ba..ad01cfeb 100644 --- a/timm/models/split_batchnorm.py +++ b/timm/models/split_batchnorm.py @@ -1,6 +1,18 @@ +""" Split BatchNorm + +A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through +a separate BN layer. The first split is passed through the parent BN layers with weight/bias +keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn' +namespace. + +This allows easily removing the auxiliary BN layers after training to efficiently +achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2, +'Disentangled Learning via An Auxiliary BN' + +Hacked together by Ross Wightman +""" import torch import torch.nn as nn -import torch.nn.functional as F class SplitBatchNorm2d(torch.nn.BatchNorm2d): diff --git a/train.py b/train.py index bb6db08d..4db2d7c3 100644 --- a/train.py +++ b/train.py @@ -237,8 +237,9 @@ def main(): data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) num_aug_splits = 0 - if args.aug_splits: - num_aug_splits = max(args.aug_splits, 2) # split of 1 makes no sense + if args.aug_splits > 0: + assert args.aug_splits > 1, 'A split of 1 makes no sense' + num_aug_splits = args.aug_splits if args.split_bn: assert num_aug_splits > 1 or args.resplit