Add more augmentation arguments, including a no_aug disable flag. Fix #209

pull/175/head
Ross Wightman 5 years ago
parent e3f58fc90c
commit fa28067704

@ -131,10 +131,15 @@ def create_loader(
batch_size, batch_size,
is_training=False, is_training=False,
use_prefetcher=True, use_prefetcher=True,
no_aug=False,
re_prob=0., re_prob=0.,
re_mode='const', re_mode='const',
re_count=1, re_count=1,
re_split=False, re_split=False,
scale=None,
ratio=None,
hflip=0.5,
vflip=0.,
color_jitter=0.4, color_jitter=0.4,
auto_augment=None, auto_augment=None,
num_aug_splits=0, num_aug_splits=0,
@ -158,6 +163,11 @@ def create_loader(
input_size, input_size,
is_training=is_training, is_training=is_training,
use_prefetcher=use_prefetcher, use_prefetcher=use_prefetcher,
no_aug=no_aug,
scale=scale,
ratio=ratio,
hflip=hflip,
vflip=vflip,
color_jitter=color_jitter, color_jitter=color_jitter,
auto_augment=auto_augment, auto_augment=auto_augment,
interpolation=interpolation, interpolation=interpolation,
@ -200,12 +210,13 @@ def create_loader(
drop_last=is_training, drop_last=is_training,
) )
if use_prefetcher: if use_prefetcher:
prefetch_re_prob = re_prob if is_training and not no_aug else 0.
loader = PrefetchLoader( loader = PrefetchLoader(
loader, loader,
mean=mean, mean=mean,
std=std, std=std,
fp16=fp16, fp16=fp16,
re_prob=re_prob if is_training else 0., re_prob=prefetch_re_prob,
re_mode=re_mode, re_mode=re_mode,
re_count=re_count, re_count=re_count,
re_num_splits=re_num_splits re_num_splits=re_num_splits

@ -14,9 +14,39 @@ from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation,
from timm.data.random_erasing import RandomErasing from timm.data.random_erasing import RandomErasing
def transforms_noaug_train(
        img_size=224,
        interpolation='bilinear',
        use_prefetcher=False,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
):
    """Build the training transform pipeline used when augmentation is disabled.

    The pipeline resizes the image, center-crops it to ``img_size`` and then
    either converts it to a numpy array (prefetcher path) or converts it to a
    normalized tensor.

    Args:
        img_size: Target size passed to both the resize and the center crop.
        interpolation: Name of the interpolation mode, resolved through
            ``_pil_interp``. The value ``'random'`` is replaced by
            ``'bilinear'``.
        use_prefetcher: When true, the pipeline ends with ``ToNumpy()`` and
            leaves tensor conversion and normalization to the prefetcher and
            collate function. When false, ``ToTensor`` and ``Normalize`` are
            appended here.
        mean: Per-channel mean used by ``Normalize`` (non-prefetcher path only).
        std: Per-channel std used by ``Normalize`` (non-prefetcher path only).

    Returns:
        A ``transforms.Compose`` wrapping the assembled transform list.
    """
    # Random interpolation is not supported with no-aug; use bilinear instead.
    interp_name = 'bilinear' if interpolation == 'random' else interpolation

    geometry_stage = [
        transforms.Resize(img_size, _pil_interp(interp_name)),
        transforms.CenterCrop(img_size),
    ]

    if use_prefetcher:
        # Prefetcher and collate will handle tensor conversion and norm.
        output_stage = [ToNumpy()]
    else:
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(
            mean=torch.tensor(mean),
            std=torch.tensor(std))
        output_stage = [to_tensor, normalize]

    return transforms.Compose(geometry_stage + output_stage)
def transforms_imagenet_train( def transforms_imagenet_train(
img_size=224, img_size=224,
scale=(0.08, 1.0), scale=None,
ratio=None,
hflip=0.5,
vflip=0.,
color_jitter=0.4, color_jitter=0.4,
auto_augment=None, auto_augment=None,
interpolation='random', interpolation='random',
@ -36,11 +66,14 @@ def transforms_imagenet_train(
* a portion of the data through the secondary transform * a portion of the data through the secondary transform
* normalizes and converts the branches above with the third, final transform * normalizes and converts the branches above with the third, final transform
""" """
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
ratio = tuple(ratio or (3./4., 4./3.)) # default imagenet ratio range
primary_tfl = [ primary_tfl = [
RandomResizedCropAndInterpolation( RandomResizedCropAndInterpolation(img_size, scale=scale, ratio=ratio, interpolation=interpolation)]
img_size, scale=scale, interpolation=interpolation), if hflip > 0.:
transforms.RandomHorizontalFlip() primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
] if vflip > 0.:
primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]
secondary_tfl = [] secondary_tfl = []
if auto_augment: if auto_augment:
@ -135,6 +168,11 @@ def create_transform(
input_size, input_size,
is_training=False, is_training=False,
use_prefetcher=False, use_prefetcher=False,
no_aug=False,
scale=None,
ratio=None,
hflip=0.5,
vflip=0.,
color_jitter=0.4, color_jitter=0.4,
auto_augment=None, auto_augment=None,
interpolation='bilinear', interpolation='bilinear',
@ -159,9 +197,21 @@ def create_transform(
transform = TfPreprocessTransform( transform = TfPreprocessTransform(
is_training=is_training, size=img_size, interpolation=interpolation) is_training=is_training, size=img_size, interpolation=interpolation)
else: else:
if is_training: if is_training and no_aug:
assert not separate, "Cannot perform split augmentation with no_aug"
transform = transforms_noaug_train(
img_size,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
mean=mean,
std=std)
elif is_training:
transform = transforms_imagenet_train( transform = transforms_imagenet_train(
img_size, img_size,
scale=scale,
ratio=ratio,
hflip=hflip,
vflip=vflip,
color_jitter=color_jitter, color_jitter=color_jitter,
auto_augment=auto_augment, auto_augment=auto_augment,
interpolation=interpolation, interpolation=interpolation,

@ -51,6 +51,7 @@ parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
# Dataset / Model parameters # Dataset / Model parameters
parser.add_argument('data', metavar='DIR', parser.add_argument('data', metavar='DIR',
help='path to dataset') help='path to dataset')
@ -82,16 +83,7 @@ parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
help='input batch size for training (default: 32)') help='input batch size for training (default: 32)')
parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N', parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
help='ratio of validation batch size to training batch size (default: 1)') help='ratio of validation batch size to training batch size (default: 1)')
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
help='Dropout rate (default: 0.)')
parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
help='Drop path rate (default: None)')
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
help='Drop block rate (default: None)')
parser.add_argument('--jsd', action='store_true', default=False,
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
# Optimizer parameters # Optimizer parameters
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "sgd"') help='Optimizer (default: "sgd"')
@ -101,6 +93,7 @@ parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='SGD momentum (default: 0.9)') help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.0001, parser.add_argument('--weight-decay', type=float, default=0.0001,
help='weight decay (default: 0.0001)') help='weight decay (default: 0.0001)')
# Learning rate schedule parameters # Learning rate schedule parameters
parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER', parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "step"') help='LR scheduler (default: "step"')
@ -134,13 +127,26 @@ parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
help='patience epochs for Plateau LR scheduler (default: 10') help='patience epochs for Plateau LR scheduler (default: 10')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)') help='LR decay rate (default: 0.1)')
# Augmentation parameters
# Augmentation & regularization parameters
parser.add_argument('--no-aug', action='store_true', default=False,
help='Disable all training augmentation, override other train aug args')
parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
help='Random resize scale (default: 0.08 1.0)')
parser.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO',
help='Random resize aspect ratio (default: 0.75 1.33)')
parser.add_argument('--hflip', type=float, default=0.5,
help='Horizontal flip training aug probability')
parser.add_argument('--vflip', type=float, default=0.,
help='Vertical flip training aug probability')
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
help='Color jitter factor (default: 0.4)') help='Color jitter factor (default: 0.4)')
parser.add_argument('--aa', type=str, default=None, metavar='NAME', parser.add_argument('--aa', type=str, default=None, metavar='NAME',
help='Use AutoAugment policy. "v0" or "original". (default: None)'), help='Use AutoAugment policy. "v0" or "original". (default: None)'),
parser.add_argument('--aug-splits', type=int, default=0, parser.add_argument('--aug-splits', type=int, default=0,
help='Number of augmentation splits (default: 0, valid: 0 or >=2)') help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
parser.add_argument('--jsd', action='store_true', default=False,
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
parser.add_argument('--reprob', type=float, default=0., metavar='PCT', parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
help='Random erase prob (default: 0.)') help='Random erase prob (default: 0.)')
parser.add_argument('--remode', type=str, default='const', parser.add_argument('--remode', type=str, default='const',
@ -150,13 +156,22 @@ parser.add_argument('--recount', type=int, default=1,
parser.add_argument('--resplit', action='store_true', default=False, parser.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split') help='Do not random erase first (clean) augmentation split')
parser.add_argument('--mixup', type=float, default=0.0, parser.add_argument('--mixup', type=float, default=0.0,
help='mixup alpha, mixup enabled if > 0. (default: 0.)') help='Mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
help='turn off mixup after this epoch, disabled if 0 (default: 0)') help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1, parser.add_argument('--smoothing', type=float, default=0.1,
help='label smoothing (default: 0.1)') help='Label smoothing (default: 0.1)')
parser.add_argument('--train-interpolation', type=str, default='random', parser.add_argument('--train-interpolation', type=str, default='random',
help='Training interpolation (random, bilinear, bicubic default: "random")') help='Training interpolation (random, bilinear, bicubic default: "random")')
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
help='Dropout rate (default: 0.)')
parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
help='Drop path rate (default: None)')
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
help='Drop block rate (default: None)')
# Batch norm parameters (only works with gen_efficientnet based models currently) # Batch norm parameters (only works with gen_efficientnet based models currently)
parser.add_argument('--bn-tf', action='store_true', default=False, parser.add_argument('--bn-tf', action='store_true', default=False,
help='Use Tensorflow BatchNorm defaults for models that support it (default: False)') help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
@ -170,6 +185,7 @@ parser.add_argument('--dist-bn', type=str, default='',
help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
parser.add_argument('--split-bn', action='store_true', parser.add_argument('--split-bn', action='store_true',
help='Enable separate BN layers per augmentation split.') help='Enable separate BN layers per augmentation split.')
# Model Exponential Moving Average # Model Exponential Moving Average
parser.add_argument('--model-ema', action='store_true', default=False, parser.add_argument('--model-ema', action='store_true', default=False,
help='Enable tracking moving average of model weights') help='Enable tracking moving average of model weights')
@ -177,6 +193,7 @@ parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.') help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
parser.add_argument('--model-ema-decay', type=float, default=0.9998, parser.add_argument('--model-ema-decay', type=float, default=0.9998,
help='decay factor for model weights moving average (default: 0.9998)') help='decay factor for model weights moving average (default: 0.9998)')
# Misc # Misc
parser.add_argument('--seed', type=int, default=42, metavar='S', parser.add_argument('--seed', type=int, default=42, metavar='S',
help='random seed (default: 42)') help='random seed (default: 42)')
@ -378,20 +395,28 @@ def main():
if num_aug_splits > 1: if num_aug_splits > 1:
dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
train_interpolation = args.train_interpolation
if args.no_aug or not train_interpolation:
train_interpolation = data_config['interpolation']
loader_train = create_loader( loader_train = create_loader(
dataset_train, dataset_train,
input_size=data_config['input_size'], input_size=data_config['input_size'],
batch_size=args.batch_size, batch_size=args.batch_size,
is_training=True, is_training=True,
use_prefetcher=args.prefetcher, use_prefetcher=args.prefetcher,
no_aug=args.no_aug,
re_prob=args.reprob, re_prob=args.reprob,
re_mode=args.remode, re_mode=args.remode,
re_count=args.recount, re_count=args.recount,
re_split=args.resplit, re_split=args.resplit,
scale=args.scale,
ratio=args.ratio,
hflip=args.hflip,
vflip=args.vflip,
color_jitter=args.color_jitter, color_jitter=args.color_jitter,
auto_augment=args.aa, auto_augment=args.aa,
num_aug_splits=num_aug_splits, num_aug_splits=num_aug_splits,
interpolation=args.train_interpolation, interpolation=train_interpolation,
mean=data_config['mean'], mean=data_config['mean'],
std=data_config['std'], std=data_config['std'],
num_workers=args.workers, num_workers=args.workers,

Loading…
Cancel
Save