Add SplitBatchNorm. AugMix, Rand/AutoAugment, Split (Aux) BatchNorm, Jensen-Shannon Divergence, RandomErasing all working together

pull/74/head
Ross Wightman 4 years ago
parent 2e955cfd0c
commit 7547119891

@ -1,6 +1,6 @@
from .constants import *
from .config import resolve_data_config
from .dataset import Dataset, DatasetTar
from .dataset import Dataset, DatasetTar, AugMixDataset
from .transforms import *
from .loader import create_loader
from .transforms_factory import create_transform

@ -323,7 +323,7 @@ class AugmentOp:
self.magnitude_std = self.hparams.get('magnitude_std', 0)
def __call__(self, img):
if not self.prob >= 1.0 or random.random() > self.prob:
if self.prob < 1.0 and random.random() > self.prob:
return img
magnitude = self.magnitude
if self.magnitude_std and self.magnitude_std > 0:
@ -539,7 +539,7 @@ _RAND_TRANSFORMS = [
'ShearY',
'TranslateXRel',
'TranslateYRel',
#'Cutout' # FIXME I implement this as random erasing separately
#'Cutout' # NOTE I've implemented this as random erasing separately
]
@ -559,7 +559,7 @@ _RAND_INCREASING_TRANSFORMS = [
'ShearY',
'TranslateXRel',
'TranslateYRel',
#'Cutout' # FIXME I implement this as random erasing separately
#'Cutout' # NOTE I've implemented this as random erasing separately
]
@ -627,6 +627,7 @@ def rand_augment_transform(config_str, hparams):
'n' - integer num layers (number of transform ops selected per image)
'w' - integer probability weight index (index of a set of weights to influence choice of op)
'mstd' - float std deviation of magnitude noise applied
'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
@ -637,6 +638,7 @@ def rand_augment_transform(config_str, hparams):
magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
num_layers = 2 # default to 2 ops per image
weight_idx = None # default to no probability weights for op choice
transforms = _RAND_TRANSFORMS
config = config_str.split('-')
assert config[0] == 'rand'
config = config[1:]
@ -648,6 +650,9 @@ def rand_augment_transform(config_str, hparams):
if key == 'mstd':
# noise param injected via hparams for now
hparams.setdefault('magnitude_std', float(val))
elif key == 'inc':
if int(val):  # int(), not bool(): bool('0') would be truthy since val is a string
transforms = _RAND_INCREASING_TRANSFORMS
elif key == 'm':
magnitude = int(val)
elif key == 'n':
@ -656,7 +661,7 @@ def rand_augment_transform(config_str, hparams):
weight_idx = int(val)
else:
assert False, 'Unknown RandAugment config section'
ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams)
ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
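For reference, a hedged usage sketch of the config string parsing above; the hparams values here are illustrative, not library defaults:

from PIL import Image
from timm.data.auto_augment import rand_augment_transform

# 'inc1' selects _RAND_INCREASING_TRANSFORMS via the new 'inc' key above
tfm = rand_augment_transform(
    'rand-m9-mstd0.5-inc1',
    hparams={'translate_const': 100, 'img_mean': (124, 116, 104)})
img = Image.new('RGB', (224, 224))
out = tfm(img)  # num_layers (default 2) randomly chosen ops, magnitude ~ N(9, 0.5)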
@ -686,12 +691,12 @@ def augmix_ops(magnitude=10, hparams=None, transforms=None):
class AugMixAugment:
def __init__(self, ops, alpha=1., width=3, depth=-1):
def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
self.ops = ops
self.alpha = alpha
self.width = width
self.depth = depth
self.blended = False
self.blended = blended
def _calc_blended_weights(self, ws, m):
ws = ws * m
@ -707,7 +712,7 @@ class AugMixAugment:
# This is my first crack at implementing a slightly faster mixed augmentation. Instead
# of accumulating the mix for each chain in a Numpy array and then blending with original,
# it recomputes the blending coefficients and applies one PIL image blend per chain.
# TODO I've verified the results are in the right ballpark but they differ by more than rounding.
# TODO the results appear in the right ballpark but they differ by more than rounding.
img_orig = img.copy()
ws = self._calc_blended_weights(mixing_weights, m)
for w in ws:
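To make the recurrence above concrete, a standalone numeric sketch (scalar stand-ins for images, mirroring _calc_blended_weights) checking that the sequential blends reproduce the target mix m * sum(w_i * chain_i) + (1 - m) * orig:

import numpy as np

def calc_blended_weights(ws, m):  # mirrors _calc_blended_weights above
    ws = ws * m
    alphas, cume = [], 1.
    for w in ws[::-1]:  # walk backwards; each later blend scales earlier content by (1 - alpha)
        alpha = w / cume
        cume *= (1 - alpha)
        alphas.append(alpha)
    return np.array(alphas[::-1], dtype=np.float32)

ws = np.float32(np.random.dirichlet([1.] * 3))  # per-chain mixing weights, sum to 1
m, orig, chains = 0.6, 1.0, np.random.rand(3)   # scalar stand-ins for PIL images
out = orig
for a, c in zip(calc_blended_weights(ws, m), chains):
    out = (1 - a) * out + a * c                 # what Image.blend(img, img_aug, a) computes
assert np.isclose(out, (1 - m) * orig + m * np.dot(ws, chains), atol=1e-5)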
@ -755,6 +760,7 @@ def augment_and_mix_transform(config_str, hparams):
'm' - integer magnitude (severity) of augmentation mix (default: 3)
'w' - integer width of augmentation chain (default: 3)
'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
'b' - integer (bool), blend each chain of the mix into the end result one at a time instead of in a final blend; uses less CPU (default: 0)
'mstd' - float std deviation of magnitude noise applied (default: 0)
Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
@ -766,6 +772,7 @@ def augment_and_mix_transform(config_str, hparams):
width = 3
depth = -1
alpha = 1.
blended = False
config = config_str.split('-')
assert config[0] == 'augmix'
config = config[1:]
@ -785,7 +792,9 @@ def augment_and_mix_transform(config_str, hparams):
depth = int(val)
elif key == 'a':
alpha = float(val)
elif key == 'b':
blended = bool(int(val))  # int() first: bool('0') would be truthy since val is a string
else:
assert False, 'Unknown AugMix config section'
ops = augmix_ops(magnitude=magnitude, hparams=hparams)
return AugMixAugment(ops, alpha=alpha, width=width, depth=depth)
return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended)
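A hedged usage sketch of the new 'b' flag; the hparams value here is illustrative only:

from timm.data.auto_augment import augment_and_mix_transform

# 'b1' requests the blended (one PIL blend per chain) variant added above
tfm = augment_and_mix_transform(
    'augmix-m5-w4-d2-b1', hparams={'img_mean': (128, 128, 128)})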

@ -144,13 +144,13 @@ class DatasetTar(data.Dataset):
class AugMixDataset(torch.utils.data.Dataset):
"""Dataset wrapper to perform AugMix or other clean/augmentation mixes"""
def __init__(self, dataset, num_aug=2):
def __init__(self, dataset, num_splits=2):
self.augmentation = None
self.normalize = None
self.dataset = dataset
if self.dataset.transform is not None:
self._set_transforms(self.dataset.transform)
self.num_aug = num_aug
self.num_splits = num_splits
def _set_transforms(self, x):
assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms'
@ -170,9 +170,10 @@ class AugMixDataset(torch.utils.data.Dataset):
return x if self.normalize is None else self.normalize(x)
def __getitem__(self, i):
x, y = self.dataset[i]
x_list = [self._normalize(x)]
for n in range(self.num_aug):
x, y = self.dataset[i] # all splits share the same dataset base transform
x_list = [self._normalize(x)] # first split only normalizes (this is the 'clean' split)
# run the full augmentation on the remaining splits
for _ in range(self.num_splits - 1):
x_list.append(self._normalize(self.augmentation(x)))
return tuple(x_list), y
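Putting the wrapper together, a minimal sketch (the dataset path is a placeholder, and the default prefetcher path assumes CUDA): each item comes back as a tuple of num_splits tensors sharing one label, with the first entry the clean split:

from timm.data import Dataset, AugMixDataset, create_loader

dataset_train = Dataset('/path/to/train')            # placeholder path
dataset_train = AugMixDataset(dataset_train, num_splits=3)
loader_train = create_loader(
    dataset_train, input_size=(3, 224, 224), batch_size=32,
    is_training=True, num_aug_splits=3)              # installs the 3-part transform on the wrapper
(x_clean, x_aug1, x_aug2), y = dataset_train[0]      # tuple of splits, single label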

@ -15,7 +15,7 @@ def fast_collate(batch):
if isinstance(batch[0][0], tuple):
# This branch 'deinterleaves' and flattens tuples of input tensors into one tensor, ordered by position
# such that all tensors at tuple position n end up in the nth chunk of a torch.split(tensor, batch_size)
inner_tuple_size = len(batch[0][0][0])
inner_tuple_size = len(batch[0][0])
flattened_batch_size = batch_size * inner_tuple_size
targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
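A self-contained sketch of the layout that comment (and the index fix above) produce: split n of every sample lands in the nth batch_size block, so torch.split recovers the splits:

import torch

batch_size, num_splits, shape = 4, 3, (3, 8, 8)
batch = [(tuple(torch.randint(0, 255, shape, dtype=torch.uint8) for _ in range(num_splits)), 0)
         for _ in range(batch_size)]
inner_tuple_size = len(batch[0][0])                  # == num_splits after the fix
tensor = torch.zeros((batch_size * inner_tuple_size, *shape), dtype=torch.uint8)
for i, (xs, _) in enumerate(batch):
    for n, x in enumerate(xs):
        tensor[n * batch_size + i] = x               # 'deinterleave': block n holds split n
clean, aug1, aug2 = torch.split(tensor, batch_size)
assert torch.equal(clean[2], batch[2][0][0])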
@ -46,13 +46,14 @@ def fast_collate(batch):
class PrefetchLoader:
def __init__(self,
loader,
rand_erase_prob=0.,
rand_erase_mode='const',
rand_erase_count=1,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
fp16=False):
loader,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
fp16=False,
re_prob=0.,
re_mode='const',
re_count=1,
re_num_splits=0):
self.loader = loader
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
@ -60,9 +61,9 @@ class PrefetchLoader:
if fp16:
self.mean = self.mean.half()
self.std = self.std.half()
if rand_erase_prob > 0.:
if re_prob > 0.:
self.random_erasing = RandomErasing(
probability=rand_erase_prob, mode=rand_erase_mode, max_count=rand_erase_count)
probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
else:
self.random_erasing = None
@ -122,11 +123,13 @@ def create_loader(
batch_size,
is_training=False,
use_prefetcher=True,
rand_erase_prob=0.,
rand_erase_mode='const',
rand_erase_count=1,
re_prob=0.,
re_mode='const',
re_count=1,
re_split=False,
color_jitter=0.4,
auto_augment=None,
num_aug_splits=0,
interpolation='bilinear',
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
@ -136,8 +139,11 @@ def create_loader(
collate_fn=None,
fp16=False,
tf_preprocessing=False,
separate_transforms=False,
):
re_num_splits = 0
if re_split:
# apply RE to the second half of the batch if there is no aug split, otherwise line it up with the aug splits
re_num_splits = num_aug_splits or 2
dataset.transform = create_transform(
input_size,
is_training=is_training,
@ -149,7 +155,11 @@ def create_loader(
std=std,
crop_pct=crop_pct,
tf_preprocessing=tf_preprocessing,
separate=separate_transforms,
re_prob=re_prob,
re_mode=re_mode,
re_count=re_count,
re_num_splits=re_num_splits,
separate=num_aug_splits > 0,
)
sampler = None
@ -176,11 +186,13 @@ def create_loader(
if use_prefetcher:
loader = PrefetchLoader(
loader,
rand_erase_prob=rand_erase_prob if is_training else 0.,
rand_erase_mode=rand_erase_mode,
rand_erase_count=rand_erase_count,
mean=mean,
std=std,
fp16=fp16)
fp16=fp16,
re_prob=re_prob if is_training else 0.,
re_mode=re_mode,
re_count=re_count,
re_num_splits=re_num_splits
)
return loader
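How re_num_splits resolves under the new logic above, as a small sketch (resolve_re_num_splits is a hypothetical helper for illustration, not part of the API):

def resolve_re_num_splits(re_split, num_aug_splits):
    # mirrors the branch at the top of create_loader
    return (num_aug_splits or 2) if re_split else 0

assert resolve_re_num_splits(False, 0) == 0  # erase across the whole batch
assert resolve_re_num_splits(True, 0) == 2   # no aug splits: erase the second half only
assert resolve_re_num_splits(True, 3) == 3   # line up erasing with the three aug splits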

@ -38,7 +38,7 @@ class RandomErasing:
def __init__(
self,
probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None,
mode='const', min_count=1, max_count=None, device='cuda'):
mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'):
self.probability = probability
self.min_area = min_area
self.max_area = max_area
@ -46,6 +46,7 @@ class RandomErasing:
self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
self.min_count = min_count
self.max_count = max_count or min_count
self.num_splits = num_splits
mode = mode.lower()
self.rand_color = False
self.per_pixel = False
@ -82,6 +83,8 @@ class RandomErasing:
self._erase(input, *input.size(), input.dtype)
else:
batch_size, chan, img_h, img_w = input.size()
for i in range(batch_size):
# skip first slice of batch if num_splits is set (for clean portion of samples)
batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
for i in range(batch_start, batch_size):
self._erase(input[i], chan, img_h, img_w, input.dtype)
return input
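A quick check of the skip logic (mode='rand' so erased pixels are visibly non-zero); this assumes the class lives at timm.data.random_erasing:

import torch
from timm.data.random_erasing import RandomErasing

re = RandomErasing(probability=1., mode='rand', num_splits=2, device='cpu')
x = torch.zeros(8, 3, 32, 32)
out = re(x)                      # erases in-place, skipping the first (clean) half
assert out[:4].abs().sum() == 0  # clean split untouched
assert out[4:].abs().sum() > 0   # second split has randomized erase regions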

@ -15,11 +15,13 @@ def transforms_imagenet_train(
color_jitter=0.4,
auto_augment=None,
interpolation='random',
random_erasing=0.4,
random_erasing_mode='const',
use_prefetcher=False,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
re_prob=0.,
re_mode='const',
re_count=1,
re_num_splits=0,
separate=False,
):
@ -71,8 +73,9 @@ def transforms_imagenet_train(
mean=torch.tensor(mean),
std=torch.tensor(std))
]
if random_erasing > 0.:
final_tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu'))
if re_prob > 0.:
final_tfl.append(
RandomErasing(re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device='cpu'))
if separate:
return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)
@ -126,6 +129,10 @@ def create_transform(
interpolation='bilinear',
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
re_prob=0.,
re_mode='const',
re_count=1,
re_num_splits=0,
crop_pct=None,
tf_preprocessing=False,
separate=False):
@ -150,6 +157,10 @@ def create_transform(
use_prefetcher=use_prefetcher,
mean=mean,
std=std,
re_prob=re_prob,
re_mode=re_mode,
re_count=re_count,
re_num_splits=re_num_splits,
separate=separate)
else:
assert not separate, "Separate transforms not supported for validation preprocessing"

@ -6,7 +6,7 @@ from .cross_entropy import LabelSmoothingCrossEntropy
class JsdCrossEntropy(nn.Module):
""" Jenson-Shannon Divergence + Cross-Entropy Loss
""" Jensen-Shannon Divergence + Cross-Entropy Loss
"""
def __init__(self, num_splits=3, alpha=12, smoothing=0.1):

@ -20,3 +20,4 @@ from .registry import *
from .factory import create_model
from .helpers import load_checkpoint, resume_checkpoint
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .split_batchnorm import convert_splitbn_model

@ -0,0 +1,64 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class SplitBatchNorm2d(torch.nn.BatchNorm2d):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
track_running_stats=True, num_splits=1):
super().__init__(num_features, eps, momentum, affine, track_running_stats)
assert num_splits >= 2, 'Should have at least one aux BN layer (num_splits at least 2)'
self.num_splits = num_splits
self.aux_bn = nn.ModuleList([
nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)])
def forward(self, input: torch.Tensor):
if self.training: # aux BN only relevant while training
split_size = input.shape[0] // self.num_splits
assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits"
split_input = input.split(split_size)
x = [super().forward(split_input[0])]
for i, a in enumerate(self.aux_bn):
x.append(a(split_input[i + 1]))
return torch.cat(x, dim=0)
else:
return super().forward(input)
def convert_splitbn_model(module, num_splits=2):
"""
Recursively traverse module and its children to replace all instances of
``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchNorm2d`.
Args:
module (torch.nn.Module): input module
num_splits: number of separate batchnorm layers to split input across
Example::
>>> # model is an instance of torch.nn.Module
>>> import timm
>>> split_bn_model = timm.models.convert_splitbn_model(model, num_splits=2)
"""
mod = module
if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
return module
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
mod = SplitBatchNorm2d(
module.num_features, module.eps, module.momentum, module.affine,
module.track_running_stats, num_splits=num_splits)
mod.running_mean = module.running_mean
mod.running_var = module.running_var
mod.num_batches_tracked = module.num_batches_tracked
if module.affine:
mod.weight.data = module.weight.data.clone().detach()
mod.bias.data = module.bias.data.clone().detach()
for aux in mod.aux_bn:
aux.running_mean = module.running_mean.clone()
aux.running_var = module.running_var.clone()
aux.num_batches_tracked = module.num_batches_tracked.clone()
if module.affine:
aux.weight.data = module.weight.data.clone().detach()
aux.bias.data = module.bias.data.clone().detach()
for name, child in module.named_children():
mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits))
del module
return mod
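An end-to-end sketch of the converter above (torchvision's resnet18 is just a convenient stand-in): the batch is laid out as [clean | aug] with a size divisible by num_splits:

import torch
import torchvision.models as tvm

model = convert_splitbn_model(tvm.resnet18(), num_splits=2)
model.train()
x = torch.randn(8, 3, 224, 224)   # rows 0-3: clean split, rows 4-7: augmented split
out = model(x)                    # main BN normalizes x[:4], the aux BN handles x[4:]
model.eval()
out = model(x)                    # eval path falls through to the main BN stats only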

@ -13,16 +13,13 @@ except ImportError:
from torch.nn.parallel import DistributedDataParallel as DDP
has_apex = False
from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_batch
from timm.models import create_model, resume_checkpoint
from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_batch, AugMixDataset
from timm.models import create_model, resume_checkpoint, convert_splitbn_model
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
#FIXME
from timm.data.dataset import AugMixDataset
import torch
import torch.nn as nn
import torchvision.utils
@ -71,6 +68,8 @@ parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
help='Dropout rate (default: 0.)')
parser.add_argument('--drop-connect', type=float, default=0.0, metavar='DROP',
help='Drop connect rate (default: 0.)')
parser.add_argument('--jsd', action='store_true', default=False,
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
# Optimizer parameters
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "sgd")')
@ -106,18 +105,24 @@ parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
help='Color jitter factor (default: 0.4)')
parser.add_argument('--aa', type=str, default=None, metavar='NAME',
help='Use AutoAugment policy. "v0" or "original". (default: None)')
parser.add_argument('--aug-splits', type=int, default=0,
help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
help='Random erase prob (default: 0.)')
parser.add_argument('--remode', type=str, default='const',
help='Random erase mode (default: "const")')
parser.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
parser.add_argument('--mixup', type=float, default=0.0,
help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
help='turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
help='label smoothing (default: 0.1)')
parser.add_argument('--train-interpolation', type=str, default='random',
help='Training interpolation (random, bilinear, bicubic; default: "random")')
# Batch norm parameters (only works with gen_efficientnet based models currently)
parser.add_argument('--bn-tf', action='store_true', default=False,
help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
@ -129,6 +134,8 @@ parser.add_argument('--sync-bn', action='store_true',
help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
parser.add_argument('--dist-bn', type=str, default='',
help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
parser.add_argument('--split-bn', action='store_true',
help='Enable separate BN layers per augmentation split.')
# Model Exponential Moving Average
parser.add_argument('--model-ema', action='store_true', default=False,
help='Enable tracking moving average of model weights')
@ -162,10 +169,6 @@ parser.add_argument('--tta', type=int, default=0, metavar='N',
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--jsd', action='store_true', default=False,
help='')
def _parse_args():
# Do we have a config file to parse?
args_config, remaining = config_parser.parse_known_args()
@ -233,6 +236,14 @@ def main():
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
num_aug_splits = 0
if args.aug_splits:
num_aug_splits = max(args.aug_splits, 2) # split of 1 makes no sense
if args.split_bn:
assert num_aug_splits > 1 or args.resplit
model = convert_splitbn_model(model, max(num_aug_splits, 2))
if args.num_gpu > 1:
if args.amp:
logging.warning(
@ -279,6 +290,7 @@ def main():
if args.distributed:
if args.sync_bn:
assert not args.split_bn
try:
if has_apex:
model = convert_syncbn_model(model)
@ -317,13 +329,11 @@ def main():
collate_fn = None
if args.prefetcher and args.mixup > 0:
assert not args.jsd
assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)
separate_transforms = False
if args.jsd:
dataset_train = AugMixDataset(dataset_train)
separate_transforms = True
if num_aug_splits > 1:
dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
loader_train = create_loader(
dataset_train,
@ -331,18 +341,19 @@ def main():
batch_size=args.batch_size,
is_training=True,
use_prefetcher=args.prefetcher,
rand_erase_prob=args.reprob,
rand_erase_mode=args.remode,
rand_erase_count=args.recount,
re_prob=args.reprob,
re_mode=args.remode,
re_count=args.recount,
re_split=args.resplit,
color_jitter=args.color_jitter,
auto_augment=args.aa,
interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'],
num_aug_splits=num_aug_splits,
interpolation=args.train_interpolation,
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
distributed=args.distributed,
collate_fn=collate_fn,
separate_transforms=separate_transforms,
)
eval_dir = os.path.join(args.data, 'val')
@ -368,7 +379,8 @@ def main():
)
if args.jsd:
train_loss_fn = JsdCrossEntropy(smoothing=args.smoothing).cuda()
assert num_aug_splits > 1 # JSD only valid with aug splits set
train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
elif args.mixup > 0.:
# smoothing is handled with mixup label transform
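For context, a minimal sketch of the Jensen-Shannon consistency objective from the AugMix paper that JsdCrossEntropy implements, assuming the batch layout [clean | aug1 | aug2] produced by the aug splits; the committed class may differ in details (e.g. label smoothing on the CE term):

import torch
import torch.nn.functional as F

def jsd_cross_entropy(output, target, num_splits=3, alpha=12.):
    split_size = output.shape[0] // num_splits
    logits = torch.split(output, split_size)
    loss = F.cross_entropy(logits[0], target[:split_size])  # CE on the clean split only
    probs = [F.softmax(l, dim=1) for l in logits]
    log_m = torch.clamp(torch.stack(probs).mean(0), 1e-7, 1).log()  # log of the mixture M
    jsd = sum(F.kl_div(log_m, p, reduction='batchmean') for p in probs) / num_splits
    return loss + alpha * jsd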
