|
|
@ -51,6 +51,7 @@ parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
|
|
|
|
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
|
|
|
|
|
|
|
|
|
|
|
|
# Dataset / Model parameters
|
|
|
|
# Dataset / Model parameters
|
|
|
|
parser.add_argument('data', metavar='DIR',
|
|
|
|
parser.add_argument('data', metavar='DIR',
|
|
|
|
help='path to dataset')
|
|
|
|
help='path to dataset')
|
|
|
@ -82,16 +83,7 @@ parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
|
|
|
|
help='input batch size for training (default: 32)')
|
|
|
|
help='input batch size for training (default: 32)')
|
|
|
|
parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
|
|
|
|
parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
|
|
|
|
help='ratio of validation batch size to training batch size (default: 1)')
|
|
|
|
help='ratio of validation batch size to training batch size (default: 1)')
|
|
|
|
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
|
|
|
|
|
|
|
|
help='Dropout rate (default: 0.)')
|
|
|
|
|
|
|
|
parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
|
|
|
|
|
|
|
|
help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
|
|
|
|
|
|
|
|
parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
|
|
|
|
|
|
|
|
help='Drop path rate (default: None)')
|
|
|
|
|
|
|
|
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
|
|
|
|
|
|
|
|
help='Drop block rate (default: None)')
|
|
|
|
|
|
|
|
parser.add_argument('--jsd', action='store_true', default=False,
|
|
|
|
|
|
|
|
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
|
|
|
|
|
|
|
|
# Optimizer parameters
|
|
|
|
# Optimizer parameters
|
|
|
|
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
|
|
|
|
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
|
|
|
|
help='Optimizer (default: "sgd")')
|
|
|
|
help='Optimizer (default: "sgd")')
|
|
|
@ -101,6 +93,7 @@ parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
|
|
|
|
help='SGD momentum (default: 0.9)')
|
|
|
|
help='SGD momentum (default: 0.9)')
|
|
|
|
parser.add_argument('--weight-decay', type=float, default=0.0001,
|
|
|
|
parser.add_argument('--weight-decay', type=float, default=0.0001,
|
|
|
|
help='weight decay (default: 0.0001)')
|
|
|
|
help='weight decay (default: 0.0001)')
|
|
|
|
|
|
|
|
|
|
|
|
# Learning rate schedule parameters
|
|
|
|
# Learning rate schedule parameters
|
|
|
|
parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
|
|
|
|
parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
|
|
|
|
help='LR scheduler (default: "step")')
|
|
|
|
help='LR scheduler (default: "step")')
|
|
|
@ -134,13 +127,26 @@ parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
|
|
|
|
help='patience epochs for Plateau LR scheduler (default: 10)')
|
|
|
|
help='patience epochs for Plateau LR scheduler (default: 10)')
|
|
|
|
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
|
|
|
|
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
|
|
|
|
help='LR decay rate (default: 0.1)')
|
|
|
|
help='LR decay rate (default: 0.1)')
|
|
|
|
# Augmentation parameters
|
|
|
|
|
|
|
|
|
|
|
|
# Augmentation & regularization parameters
|
|
|
|
|
|
|
|
parser.add_argument('--no-aug', action='store_true', default=False,
|
|
|
|
|
|
|
|
help='Disable all training augmentation, override other train aug args')
|
|
|
|
|
|
|
|
parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
|
|
|
|
|
|
|
|
help='Random resize scale (default: 0.08 1.0)')
|
|
|
|
|
|
|
|
parser.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO',
|
|
|
|
|
|
|
|
help='Random resize aspect ratio (default: 0.75 1.33)')
|
|
|
|
|
|
|
|
parser.add_argument('--hflip', type=float, default=0.5,
|
|
|
|
|
|
|
|
help='Horizontal flip training aug probability')
|
|
|
|
|
|
|
|
parser.add_argument('--vflip', type=float, default=0.,
|
|
|
|
|
|
|
|
help='Vertical flip training aug probability')
|
|
|
|
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
|
|
|
|
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
|
|
|
|
help='Color jitter factor (default: 0.4)')
|
|
|
|
help='Color jitter factor (default: 0.4)')
|
|
|
|
parser.add_argument('--aa', type=str, default=None, metavar='NAME',
|
|
|
|
parser.add_argument('--aa', type=str, default=None, metavar='NAME',
|
|
|
|
help='Use AutoAugment policy. "v0" or "original". (default: None)'),
|
|
|
|
help='Use AutoAugment policy. "v0" or "original". (default: None)'),
|
|
|
|
parser.add_argument('--aug-splits', type=int, default=0,
|
|
|
|
parser.add_argument('--aug-splits', type=int, default=0,
|
|
|
|
help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
|
|
|
|
help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
|
|
|
|
|
|
|
|
parser.add_argument('--jsd', action='store_true', default=False,
|
|
|
|
|
|
|
|
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
|
|
|
|
parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
|
|
|
|
parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
|
|
|
|
help='Random erase prob (default: 0.)')
|
|
|
|
help='Random erase prob (default: 0.)')
|
|
|
|
parser.add_argument('--remode', type=str, default='const',
|
|
|
|
parser.add_argument('--remode', type=str, default='const',
|
|
|
@ -150,13 +156,22 @@ parser.add_argument('--recount', type=int, default=1,
|
|
|
|
parser.add_argument('--resplit', action='store_true', default=False,
|
|
|
|
parser.add_argument('--resplit', action='store_true', default=False,
|
|
|
|
help='Do not random erase first (clean) augmentation split')
|
|
|
|
help='Do not random erase first (clean) augmentation split')
|
|
|
|
parser.add_argument('--mixup', type=float, default=0.0,
|
|
|
|
parser.add_argument('--mixup', type=float, default=0.0,
|
|
|
|
help='mixup alpha, mixup enabled if > 0. (default: 0.)')
|
|
|
|
help='Mixup alpha, mixup enabled if > 0. (default: 0.)')
|
|
|
|
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
|
|
|
|
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
|
|
|
|
help='turn off mixup after this epoch, disabled if 0 (default: 0)')
|
|
|
|
help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
|
|
|
|
parser.add_argument('--smoothing', type=float, default=0.1,
|
|
|
|
parser.add_argument('--smoothing', type=float, default=0.1,
|
|
|
|
help='label smoothing (default: 0.1)')
|
|
|
|
help='Label smoothing (default: 0.1)')
|
|
|
|
parser.add_argument('--train-interpolation', type=str, default='random',
|
|
|
|
parser.add_argument('--train-interpolation', type=str, default='random',
|
|
|
|
help='Training interpolation (random, bilinear, bicubic default: "random")')
|
|
|
|
help='Training interpolation (random, bilinear, bicubic default: "random")')
|
|
|
|
|
|
|
|
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
|
|
|
|
|
|
|
|
help='Dropout rate (default: 0.)')
|
|
|
|
|
|
|
|
parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
|
|
|
|
|
|
|
|
help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
|
|
|
|
|
|
|
|
parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
|
|
|
|
|
|
|
|
help='Drop path rate (default: None)')
|
|
|
|
|
|
|
|
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
|
|
|
|
|
|
|
|
help='Drop block rate (default: None)')
|
|
|
|
|
|
|
|
|
|
|
|
# Batch norm parameters (only works with gen_efficientnet based models currently)
|
|
|
|
# Batch norm parameters (only works with gen_efficientnet based models currently)
|
|
|
|
parser.add_argument('--bn-tf', action='store_true', default=False,
|
|
|
|
parser.add_argument('--bn-tf', action='store_true', default=False,
|
|
|
|
help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
|
|
|
|
help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
|
|
|
@ -170,6 +185,7 @@ parser.add_argument('--dist-bn', type=str, default='',
|
|
|
|
help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
|
|
|
|
help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
|
|
|
|
parser.add_argument('--split-bn', action='store_true',
|
|
|
|
parser.add_argument('--split-bn', action='store_true',
|
|
|
|
help='Enable separate BN layers per augmentation split.')
|
|
|
|
help='Enable separate BN layers per augmentation split.')
|
|
|
|
|
|
|
|
|
|
|
|
# Model Exponential Moving Average
|
|
|
|
# Model Exponential Moving Average
|
|
|
|
parser.add_argument('--model-ema', action='store_true', default=False,
|
|
|
|
parser.add_argument('--model-ema', action='store_true', default=False,
|
|
|
|
help='Enable tracking moving average of model weights')
|
|
|
|
help='Enable tracking moving average of model weights')
|
|
|
@ -177,6 +193,7 @@ parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
|
|
|
|
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
|
|
|
|
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
|
|
|
|
parser.add_argument('--model-ema-decay', type=float, default=0.9998,
|
|
|
|
parser.add_argument('--model-ema-decay', type=float, default=0.9998,
|
|
|
|
help='decay factor for model weights moving average (default: 0.9998)')
|
|
|
|
help='decay factor for model weights moving average (default: 0.9998)')
|
|
|
|
|
|
|
|
|
|
|
|
# Misc
|
|
|
|
# Misc
|
|
|
|
parser.add_argument('--seed', type=int, default=42, metavar='S',
|
|
|
|
parser.add_argument('--seed', type=int, default=42, metavar='S',
|
|
|
|
help='random seed (default: 42)')
|
|
|
|
help='random seed (default: 42)')
|
|
|
@ -378,20 +395,28 @@ def main():
|
|
|
|
if num_aug_splits > 1:
|
|
|
|
if num_aug_splits > 1:
|
|
|
|
dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
|
|
|
|
dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train_interpolation = args.train_interpolation
|
|
|
|
|
|
|
|
if args.no_aug or not train_interpolation:
|
|
|
|
|
|
|
|
train_interpolation = data_config['interpolation']
|
|
|
|
loader_train = create_loader(
|
|
|
|
loader_train = create_loader(
|
|
|
|
dataset_train,
|
|
|
|
dataset_train,
|
|
|
|
input_size=data_config['input_size'],
|
|
|
|
input_size=data_config['input_size'],
|
|
|
|
batch_size=args.batch_size,
|
|
|
|
batch_size=args.batch_size,
|
|
|
|
is_training=True,
|
|
|
|
is_training=True,
|
|
|
|
use_prefetcher=args.prefetcher,
|
|
|
|
use_prefetcher=args.prefetcher,
|
|
|
|
|
|
|
|
no_aug=args.no_aug,
|
|
|
|
re_prob=args.reprob,
|
|
|
|
re_prob=args.reprob,
|
|
|
|
re_mode=args.remode,
|
|
|
|
re_mode=args.remode,
|
|
|
|
re_count=args.recount,
|
|
|
|
re_count=args.recount,
|
|
|
|
re_split=args.resplit,
|
|
|
|
re_split=args.resplit,
|
|
|
|
|
|
|
|
scale=args.scale,
|
|
|
|
|
|
|
|
ratio=args.ratio,
|
|
|
|
|
|
|
|
hflip=args.hflip,
|
|
|
|
|
|
|
|
vflip=args.vflip,
|
|
|
|
color_jitter=args.color_jitter,
|
|
|
|
color_jitter=args.color_jitter,
|
|
|
|
auto_augment=args.aa,
|
|
|
|
auto_augment=args.aa,
|
|
|
|
num_aug_splits=num_aug_splits,
|
|
|
|
num_aug_splits=num_aug_splits,
|
|
|
|
interpolation=args.train_interpolation,
|
|
|
|
interpolation=train_interpolation,
|
|
|
|
mean=data_config['mean'],
|
|
|
|
mean=data_config['mean'],
|
|
|
|
std=data_config['std'],
|
|
|
|
std=data_config['std'],
|
|
|
|
num_workers=args.workers,
|
|
|
|
num_workers=args.workers,
|
|
|
|