@@ -27,22 +27,21 @@ import torchvision.utils

 torch.backends.cudnn.benchmark = True

 parser = argparse.ArgumentParser(description='Training')
+# Dataset / Model parameters
 parser.add_argument('data', metavar='DIR',
                     help='path to dataset')
 parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
                     help='Name of model to train (default: "resnet101")')
+parser.add_argument('--pretrained', action='store_true', default=False,
+                    help='Start with pretrained version of specified network (if avail)')
+parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
+                    help='Initialize model from this checkpoint (default: none)')
+parser.add_argument('--resume', default='', type=str, metavar='PATH',
+                    help='Resume full model and optimizer state from checkpoint (default: none)')
 parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
                     help='number of label classes (default: 1000)')
-parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
-                    help='Optimizer (default: "sgd")')
-parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
-                    help='Optimizer Epsilon (default: 1e-8)')
 parser.add_argument('--gp', default='avg', type=str, metavar='POOL',
                     help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")')
-parser.add_argument('--tta', type=int, default=0, metavar='N',
-                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
-parser.add_argument('--pretrained', action='store_true', default=False,
-                    help='Start with pretrained version of specified network (if avail)')
 parser.add_argument('--img-size', type=int, default=None, metavar='N',
                     help='Image patch size (default: None => model default)')
 parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',

@@ -53,8 +52,24 @@ parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                     help='Image resize interpolation type (overrides model)')
 parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
                     help='input batch size for training (default: 32)')
-parser.add_argument('-s', '--initial-batch-size', type=int, default=0, metavar='N',
-                    help='initial input batch size for training (default: 0)')
+parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
+                    help='Dropout rate (default: 0.)')
+# Optimizer parameters
+parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
+                    help='Optimizer (default: "sgd")')
+parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
+                    help='Optimizer Epsilon (default: 1e-8)')
+parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
+                    help='SGD momentum (default: 0.9)')
+parser.add_argument('--weight-decay', type=float, default=0.0001,
+                    help='weight decay (default: 0.0001)')
+# Learning rate schedule parameters
+parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
+                    help='LR scheduler (default: "step")')
+parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
+                    help='learning rate (default: 0.01)')
+parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
+                    help='warmup learning rate (default: 0.0001)')
 parser.add_argument('--epochs', type=int, default=200, metavar='N',
                     help='number of epochs to train (default: 200)')
 parser.add_argument('--start-epoch', default=None, type=int, metavar='N',

@@ -65,40 +80,34 @@ parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
                     help='epochs to warmup LR, if scheduler supports')
 parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                     help='LR decay rate (default: 0.1)')
-parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
-                    help='LR scheduler (default: "step")')
-parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
-                    help='Dropout rate (default: 0.)')
+# Augmentation parameters
+parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
+                    help='Color jitter factor (default: 0.4)')
 parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
                     help='Random erase prob (default: 0.)')
 parser.add_argument('--remode', type=str, default='const',
                     help='Random erase mode (default: "const")')
-parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
-                    help='learning rate (default: 0.01)')
-parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
-                    help='warmup learning rate (default: 0.0001)')
-parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
-                    help='SGD momentum (default: 0.9)')
-parser.add_argument('--weight-decay', type=float, default=0.0001,
-                    help='weight decay (default: 0.0001)')
 parser.add_argument('--mixup', type=float, default=0.0,
                     help='mixup alpha, mixup enabled if > 0. (default: 0.)')
 parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                     help='turn off mixup after this epoch, disabled if 0 (default: 0)')
 parser.add_argument('--smoothing', type=float, default=0.1,
                     help='label smoothing (default: 0.1)')
+# Batch norm parameters (only works with gen_efficientnet based models currently)
 parser.add_argument('--bn-tf', action='store_true', default=False,
                     help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
 parser.add_argument('--bn-momentum', type=float, default=None,
                     help='BatchNorm momentum override (if not None)')
 parser.add_argument('--bn-eps', type=float, default=None,
                     help='BatchNorm epsilon override (if not None)')
+# Model Exponential Moving Average
 parser.add_argument('--model-ema', action='store_true', default=False,
                     help='Enable tracking moving average of model weights')
 parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
                     help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
 parser.add_argument('--model-ema-decay', type=float, default=0.9998,
                     help='decay factor for model weights moving average (default: 0.9998)')
+# Misc
 parser.add_argument('--seed', type=int, default=42, metavar='S',
                     help='random seed (default: 42)')
 parser.add_argument('--log-interval', type=int, default=50, metavar='N',

@@ -109,10 +118,6 @@ parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                     help='how many training processes to use (default: 4)')
 parser.add_argument('--num-gpu', type=int, default=1,
                     help='Number of GPUS to use')
-parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
-                    help='path to init checkpoint (default: none)')
-parser.add_argument('--resume', default='', type=str, metavar='PATH',
-                    help='path to latest checkpoint (default: none)')
 parser.add_argument('--save-images', action='store_true', default=False,
                     help='save images of input batches every log interval for debugging')
 parser.add_argument('--amp', action='store_true', default=False,

@@ -125,6 +130,8 @@ parser.add_argument('--output', default='', type=str, metavar='PATH',
                     help='path to output folder (default: none, current dir)')
 parser.add_argument('--eval-metric', default='prec1', type=str, metavar='EVAL_METRIC',
                     help='Best metric (default: "prec1")')
+parser.add_argument('--tta', type=int, default=0, metavar='N',
+                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
 parser.add_argument("--local_rank", default=0, type=int)

@@ -174,13 +181,13 @@ def main():
         logging.info('Model %s created, param count: %d' %
                      (args.model, sum([m.numel() for m in model.parameters()])))

-    data_config = resolve_data_config(model, args, verbose=args.local_rank == 0)
+    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

     # optionally resume from a checkpoint
-    start_epoch = 0
     optimizer_state = None
+    resume_epoch = None
     if args.resume:
-        optimizer_state, start_epoch = resume_checkpoint(model, args.resume, args.start_epoch)
+        optimizer_state, resume_epoch = resume_checkpoint(model, args.resume)

     if args.num_gpu > 1:
         if args.amp:
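
Note the resolve_data_config signature change above: the argument namespace is now passed in as a plain dict, and the model moves to a keyword argument. A minimal standalone sketch of the new call pattern, assuming the timm package layout this script imports from:

    from timm.models import create_model
    from timm.data import resolve_data_config

    args = parser.parse_args()
    model = create_model(args.model, num_classes=args.num_classes)
    # namespace goes in as a dict now; model is a keyword argument
    data_config = resolve_data_config(vars(args), model=model, verbose=True)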

@@ -232,8 +239,15 @@ def main():
         # NOTE: EMA model does not need to be wrapped by DDP

     lr_scheduler, num_epochs = create_scheduler(args, optimizer)
+    start_epoch = 0
+    if args.start_epoch is not None:
+        # a specified start_epoch will always override the resume epoch
+        start_epoch = args.start_epoch
+    elif resume_epoch is not None:
+        start_epoch = resume_epoch
     if start_epoch > 0:
         lr_scheduler.step(start_epoch)

     if args.local_rank == 0:
         logging.info('Scheduled epochs: {}'.format(num_epochs))
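
Restated as a tiny helper (hypothetical, for illustration only), the start-epoch precedence introduced here is: an explicit --start-epoch wins, then the epoch recovered by --resume, else 0:

    def resolve_start_epoch(arg_start_epoch, resume_epoch):
        # hypothetical helper restating the logic above
        if arg_start_epoch is not None:
            return arg_start_epoch   # explicit CLI value always wins
        if resume_epoch is not None:
            return resume_epoch      # fall back to the checkpoint's epoch
        return 0

    assert resolve_start_epoch(None, None) == 0
    assert resolve_start_epoch(None, 17) == 17
    assert resolve_start_epoch(5, 17) == 5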

@@ -255,6 +269,7 @@ def main():
         use_prefetcher=args.prefetcher,
         rand_erase_prob=args.reprob,
         rand_erase_mode=args.remode,
+        color_jitter=args.color_jitter,
         interpolation='random',  # FIXME cleanly resolve this? data_config['interpolation'],
         mean=data_config['mean'],
         std=data_config['std'],

@@ -327,7 +342,8 @@ def main():
                 eval_metrics = ema_eval_metrics

             if lr_scheduler is not None:
-                lr_scheduler.step(epoch, eval_metrics[eval_metric])
+                # step LR for next epoch
+                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

             update_summary(
                 epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
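
Stepping with epoch + 1 makes the scheduler compute the LR for the epoch about to run rather than the one just finished. A toy scheduler mimicking the step(epoch, metric=None) interface (illustrative only, not timm's implementation) shows why the +1 matters:

    class ToyStepLR:
        # decays LR by 0.1 every 30 epochs, like a 'step' schedule
        def __init__(self, base_lr=0.01, decay_epochs=30, decay_rate=0.1):
            self.base_lr = base_lr
            self.decay_epochs = decay_epochs
            self.decay_rate = decay_rate
            self.lr = base_lr

        def step(self, epoch, metric=None):
            self.lr = self.base_lr * self.decay_rate ** (epoch // self.decay_epochs)

    sched = ToyStepLR()
    sched.step(29 + 1)  # finishing epoch 29: set the LR that epoch 30 will train with
    print(sched.lr)     # 0.001 -- already decayed when the new epoch starts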

@@ -338,9 +354,7 @@ def main():
                 save_metric = eval_metrics[eval_metric]
                 best_metric, best_epoch = saver.save_checkpoint(
                     model, optimizer, args,
-                    epoch=epoch + 1,
-                    model_ema=model_ema,
-                    metric=save_metric)
+                    epoch=epoch, model_ema=model_ema, metric=save_metric)

     except KeyboardInterrupt:
         pass
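
This pairs with the resume change earlier: save_checkpoint now records the loop's current epoch as-is, and the +1 needed to restart on the following epoch appears to move to the resume side instead, with resume_checkpoint returning the epoch to start from.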

@@ -433,9 +447,8 @@ def train_epoch(

         if saver is not None and args.recovery_interval and (
                 last_batch or (batch_idx + 1) % args.recovery_interval == 0):
-            save_epoch = epoch + 1 if last_batch else epoch
             saver.save_recovery(
-                model, optimizer, args, save_epoch, model_ema=model_ema, batch_idx=batch_idx)
+                model, optimizer, args, epoch, model_ema=model_ema, batch_idx=batch_idx)

         if lr_scheduler is not None:
             lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
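
Recovery snapshots get the same simplification: they always record the in-progress epoch, dropping the old last-batch special case, which seems intended to keep recovery and resume epoch bookkeeping consistent.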