|
|
|
@ -60,8 +60,8 @@ parser.add_argument('--train-split', metavar='NAME', default='train',
|
|
|
|
|
help='dataset train split (default: train)')
|
|
|
|
|
# Dataset split used for validation; 'validation' is used when the flag is omitted.
parser.add_argument(
    '--val-split',
    default='validation',
    metavar='NAME',
    help='dataset validation split (default: validation)',
)
|
|
|
|
|
# NOTE(review): the help text names "countception" while default='resnet101', and
# a second '--model' registration (default 'resnet50') follows directly below.
# This looks like the pre-change side of a diff hunk — confirm before reusing it.
parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
|
|
|
|
|
help='Name of model to train (default: "countception"')
|
|
|
|
|
# Name of the model to train; 'resnet50' is used when the flag is omitted.
parser.add_argument(
    '--model', default='resnet50', type=str, metavar='MODEL',
    # The closing parenthesis belongs inside the string so that --help prints a
    # balanced "(default: ...)" note.
    help='Name of model to train (default: "resnet50")')
|
|
|
|
|
# Boolean switch, off by default; see the help text for its meaning.
parser.add_argument(
    '--pretrained',
    default=False,
    action='store_true',
    help='Start with pretrained version of specified network (if avail)',
)
|
|
|
|
|
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
|
|
|
|
@ -215,8 +215,6 @@ parser.add_argument('--split-bn', action='store_true',
|
|
|
|
|
# Model Exponential Moving Average
|
|
|
|
|
# Boolean switch, off by default, that turns on weight moving-average tracking.
parser.add_argument(
    '--model-ema',
    default=False,
    action='store_true',
    help='Enable tracking moving average of model weights',
)
|
|
|
|
|
# NOTE(review): a later fragment in this file tests `args.model_ema_force_cpu` in
# one variant of an `if` and omits it in the next, and the enclosing hunk header
# shrinks from 8 to 6 lines. This option is presumably being removed — confirm.
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
|
|
|
|
|
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
|
|
|
|
|
# Float-valued option; 0.9998 is used when the flag is omitted.
parser.add_argument(
    '--model-ema-decay',
    default=0.9998,
    type=float,
    help='decay factor for model weights moving average (default: 0.9998)',
)
|
|
|
|
|
|
|
|
|
@ -252,6 +250,8 @@ parser.add_argument('--use-multi-epochs-loader', action='store_true', default=Fa
|
|
|
|
|
help='use the multi-epochs-loader to save time at the beginning of every epoch')
|
|
|
|
|
# Boolean switch stored on `args.torchscript`; it is only set when the flag is passed.
parser.add_argument(
    '--torchscript',
    action='store_true',
    dest='torchscript',
    help='convert model torchscript for inference',
)
|
|
|
|
|
# Boolean switch, off by default; read as `args.force_cpu`.
parser.add_argument(
    '--force-cpu',
    default=False,
    action='store_true',
    help='Force CPU to be used even if HW accelerator exists.',
)
|
|
|
|
|
# Boolean switch, off by default; see the help text for what gets logged.
parser.add_argument(
    '--log-wandb',
    default=False,
    action='store_true',
    help='log training and validation metrics to wandb',
)
|
|
|
|
|
|
|
|
|
@ -277,7 +277,7 @@ def main():
|
|
|
|
|
setup_default_logging()
|
|
|
|
|
args, args_text = _parse_args()
|
|
|
|
|
|
|
|
|
|
dev_env = initialize_device(amp=args.amp, channels_last=args.channels_last)
|
|
|
|
|
dev_env = initialize_device(force_cpu=args.force_cpu, amp=args.amp, channels_last=args.channels_last)
|
|
|
|
|
if dev_env.distributed:
|
|
|
|
|
_logger.info('Training in distributed mode with multiple processes, 1 device per process. Process %d, total %d.'
|
|
|
|
|
% (dev_env.global_rank, dev_env.world_size))
|
|
|
|
@ -364,7 +364,7 @@ def main():
|
|
|
|
|
services.monitor,
|
|
|
|
|
dev_env)
|
|
|
|
|
|
|
|
|
|
if train_state.model_ema is not None and not args.model_ema_force_cpu:
|
|
|
|
|
if train_state.model_ema is not None:
|
|
|
|
|
if dev_env.distributed and args.dist_bn in ('broadcast', 'reduce'):
|
|
|
|
|
distribute_bn(train_state.model_ema, args.dist_bn == 'reduce', dev_env)
|
|
|
|
|
|
|
|
|
|