diff --git a/timm/data/loader.py b/timm/data/loader.py index 8a5f38af..317f77df 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -131,10 +131,15 @@ def create_loader( batch_size, is_training=False, use_prefetcher=True, + no_aug=False, re_prob=0., re_mode='const', re_count=1, re_split=False, + scale=None, + ratio=None, + hflip=0.5, + vflip=0., color_jitter=0.4, auto_augment=None, num_aug_splits=0, @@ -158,6 +163,11 @@ def create_loader( input_size, is_training=is_training, use_prefetcher=use_prefetcher, + no_aug=no_aug, + scale=scale, + ratio=ratio, + hflip=hflip, + vflip=vflip, color_jitter=color_jitter, auto_augment=auto_augment, interpolation=interpolation, @@ -200,12 +210,13 @@ def create_loader( drop_last=is_training, ) if use_prefetcher: + prefetch_re_prob = re_prob if is_training and not no_aug else 0. loader = PrefetchLoader( loader, mean=mean, std=std, fp16=fp16, - re_prob=re_prob if is_training else 0., + re_prob=prefetch_re_prob, re_mode=re_mode, re_count=re_count, re_num_splits=re_num_splits diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py index fd987c85..3ce774ff 100644 --- a/timm/data/transforms_factory.py +++ b/timm/data/transforms_factory.py @@ -14,9 +14,39 @@ from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, from timm.data.random_erasing import RandomErasing +def transforms_noaug_train( + img_size=224, + interpolation='bilinear', + use_prefetcher=False, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, +): + if interpolation == 'random': + # random interpolation not supported with no-aug + interpolation = 'bilinear' + tfl = [ + transforms.Resize(img_size, _pil_interp(interpolation)), + transforms.CenterCrop(img_size) + ] + if use_prefetcher: + # prefetcher and collate will handle tensor conversion and norm + tfl += [ToNumpy()] + else: + tfl += [ + transforms.ToTensor(), + transforms.Normalize( + mean=torch.tensor(mean), + std=torch.tensor(std)) + ] + return 
transforms.Compose(tfl) + + def transforms_imagenet_train( img_size=224, - scale=(0.08, 1.0), + scale=None, + ratio=None, + hflip=0.5, + vflip=0., color_jitter=0.4, auto_augment=None, interpolation='random', @@ -36,11 +66,14 @@ def transforms_imagenet_train( * a portion of the data through the secondary transform * normalizes and converts the branches above with the third, final transform """ + scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range + ratio = tuple(ratio or (3./4., 4./3.)) # default imagenet ratio range primary_tfl = [ - RandomResizedCropAndInterpolation( - img_size, scale=scale, interpolation=interpolation), - transforms.RandomHorizontalFlip() - ] + RandomResizedCropAndInterpolation(img_size, scale=scale, ratio=ratio, interpolation=interpolation)] + if hflip > 0.: + primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)] + if vflip > 0.: + primary_tfl += [transforms.RandomVerticalFlip(p=vflip)] secondary_tfl = [] if auto_augment: @@ -135,6 +168,11 @@ def create_transform( input_size, is_training=False, use_prefetcher=False, + no_aug=False, + scale=None, + ratio=None, + hflip=0.5, + vflip=0., color_jitter=0.4, auto_augment=None, interpolation='bilinear', @@ -159,9 +197,21 @@ def create_transform( transform = TfPreprocessTransform( is_training=is_training, size=img_size, interpolation=interpolation) else: - if is_training: + if is_training and no_aug: + assert not separate, "Cannot perform split augmentation with no_aug" + transform = transforms_noaug_train( + img_size, + interpolation=interpolation, + use_prefetcher=use_prefetcher, + mean=mean, + std=std) + elif is_training: transform = transforms_imagenet_train( img_size, + scale=scale, + ratio=ratio, + hflip=hflip, + vflip=vflip, color_jitter=color_jitter, auto_augment=auto_augment, interpolation=interpolation, diff --git a/train.py b/train.py index bc856f34..365dea6f 100755 --- a/train.py +++ b/train.py @@ -51,6 +51,7 @@ parser.add_argument('-c', '--config', default='', type=str, 
metavar='FILE', parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') + # Dataset / Model parameters parser.add_argument('data', metavar='DIR', help='path to dataset') @@ -82,16 +83,7 @@ parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N', help='input batch size for training (default: 32)') parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N', help='ratio of validation batch size to training batch size (default: 1)') -parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', - help='Dropout rate (default: 0.)') -parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT', - help='Drop connect rate, DEPRECATED, use drop-path (default: None)') -parser.add_argument('--drop-path', type=float, default=None, metavar='PCT', - help='Drop path rate (default: None)') -parser.add_argument('--drop-block', type=float, default=None, metavar='PCT', - help='Drop block rate (default: None)') -parser.add_argument('--jsd', action='store_true', default=False, - help='Enable Jensen-Shannon Divergence + CE loss. 
Use with `--aug-splits`.') + # Optimizer parameters parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', help='Optimizer (default: "sgd"') @@ -101,6 +93,7 @@ parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)') parser.add_argument('--weight-decay', type=float, default=0.0001, help='weight decay (default: 0.0001)') + # Learning rate schedule parameters parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER', help='LR scheduler (default: "step"') @@ -134,13 +127,26 @@ parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', help='patience epochs for Plateau LR scheduler (default: 10') parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', help='LR decay rate (default: 0.1)') -# Augmentation parameters + +# Augmentation & regularization parameters +parser.add_argument('--no-aug', action='store_true', default=False, + help='Disable all training augmentation, override other train aug args') +parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', + help='Random resize scale (default: 0.08 1.0)') +parser.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO', + help='Random resize aspect ratio (default: 0.75 1.33)') +parser.add_argument('--hflip', type=float, default=0.5, + help='Horizontal flip training aug probability') +parser.add_argument('--vflip', type=float, default=0., + help='Vertical flip training aug probability') parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', help='Color jitter factor (default: 0.4)') parser.add_argument('--aa', type=str, default=None, metavar='NAME', help='Use AutoAugment policy. "v0" or "original". 
(default: None)'), parser.add_argument('--aug-splits', type=int, default=0, help='Number of augmentation splits (default: 0, valid: 0 or >=2)') +parser.add_argument('--jsd', action='store_true', default=False, + help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.') parser.add_argument('--reprob', type=float, default=0., metavar='PCT', help='Random erase prob (default: 0.)') parser.add_argument('--remode', type=str, default='const', @@ -150,13 +156,22 @@ parser.add_argument('--recount', type=int, default=1, parser.add_argument('--resplit', action='store_true', default=False, help='Do not random erase first (clean) augmentation split') parser.add_argument('--mixup', type=float, default=0.0, - help='mixup alpha, mixup enabled if > 0. (default: 0.)') + help='Mixup alpha, mixup enabled if > 0. (default: 0.)') parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', - help='turn off mixup after this epoch, disabled if 0 (default: 0)') + help='Turn off mixup after this epoch, disabled if 0 (default: 0)') parser.add_argument('--smoothing', type=float, default=0.1, - help='label smoothing (default: 0.1)') + help='Label smoothing (default: 0.1)') parser.add_argument('--train-interpolation', type=str, default='random', help='Training interpolation (random, bilinear, bicubic default: "random")') +parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', + help='Dropout rate (default: 0.)') +parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT', + help='Drop connect rate, DEPRECATED, use drop-path (default: None)') +parser.add_argument('--drop-path', type=float, default=None, metavar='PCT', + help='Drop path rate (default: None)') +parser.add_argument('--drop-block', type=float, default=None, metavar='PCT', + help='Drop block rate (default: None)') + # Batch norm parameters (only works with gen_efficientnet based models currently) parser.add_argument('--bn-tf', action='store_true', default=False, 
help='Use Tensorflow BatchNorm defaults for models that support it (default: False)') @@ -170,6 +185,7 @@ parser.add_argument('--dist-bn', type=str, default='', help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') parser.add_argument('--split-bn', action='store_true', help='Enable separate BN layers per augmentation split.') + # Model Exponential Moving Average parser.add_argument('--model-ema', action='store_true', default=False, help='Enable tracking moving average of model weights') @@ -177,6 +193,7 @@ parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.') parser.add_argument('--model-ema-decay', type=float, default=0.9998, help='decay factor for model weights moving average (default: 0.9998)') + # Misc parser.add_argument('--seed', type=int, default=42, metavar='S', help='random seed (default: 42)') @@ -378,20 +395,28 @@ def main(): if num_aug_splits > 1: dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) + train_interpolation = args.train_interpolation + if args.no_aug or not train_interpolation: + train_interpolation = data_config['interpolation'] loader_train = create_loader( dataset_train, input_size=data_config['input_size'], batch_size=args.batch_size, is_training=True, use_prefetcher=args.prefetcher, + no_aug=args.no_aug, re_prob=args.reprob, re_mode=args.remode, re_count=args.recount, re_split=args.resplit, + scale=args.scale, + ratio=args.ratio, + hflip=args.hflip, + vflip=args.vflip, color_jitter=args.color_jitter, auto_augment=args.aa, num_aug_splits=num_aug_splits, - interpolation=args.train_interpolation, + interpolation=train_interpolation, mean=data_config['mean'], std=data_config['std'], num_workers=args.workers,