|
|
@ -145,6 +145,8 @@ parser.add_argument('--amp', action='store_true', default=False,
|
|
|
|
help='use NVIDIA amp for mixed precision training')
|
|
|
|
help='use NVIDIA amp for mixed precision training')
|
|
|
|
parser.add_argument('--sync-bn', action='store_true',
|
|
|
|
parser.add_argument('--sync-bn', action='store_true',
|
|
|
|
help='enabling apex sync BN.')
|
|
|
|
help='enabling apex sync BN.')
|
|
|
|
|
|
|
|
parser.add_argument('--reduce-bn', action='store_true',
|
|
|
|
|
|
|
|
help='average BN running stats across all distributed nodes between train and validation.')
|
|
|
|
parser.add_argument('--no-prefetcher', action='store_true', default=False,
|
|
|
|
parser.add_argument('--no-prefetcher', action='store_true', default=False,
|
|
|
|
help='disable fast prefetcher')
|
|
|
|
help='disable fast prefetcher')
|
|
|
|
parser.add_argument('--output', default='', type=str, metavar='PATH',
|
|
|
|
parser.add_argument('--output', default='', type=str, metavar='PATH',
|
|
|
@ -256,7 +258,7 @@ def main():
|
|
|
|
if args.local_rank == 0:
|
|
|
|
if args.local_rank == 0:
|
|
|
|
logging.info('Restoring NVIDIA AMP state from checkpoint')
|
|
|
|
logging.info('Restoring NVIDIA AMP state from checkpoint')
|
|
|
|
amp.load_state_dict(resume_state['amp'])
|
|
|
|
amp.load_state_dict(resume_state['amp'])
|
|
|
|
resume_state = None # clear it
|
|
|
|
del resume_state
|
|
|
|
|
|
|
|
|
|
|
|
model_ema = None
|
|
|
|
model_ema = None
|
|
|
|
if args.model_ema:
|
|
|
|
if args.model_ema:
|
|
|
@ -388,9 +390,17 @@ def main():
|
|
|
|
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
|
|
|
|
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
|
|
|
|
use_amp=use_amp, model_ema=model_ema)
|
|
|
|
use_amp=use_amp, model_ema=model_ema)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if args.distributed and args.reduce_bn:
|
|
|
|
|
|
|
|
if args.local_rank == 0:
|
|
|
|
|
|
|
|
logging.info("Averaging bn running means and vars")
|
|
|
|
|
|
|
|
reduce_bn(model, args.world_size)
|
|
|
|
|
|
|
|
|
|
|
|
eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
|
|
|
|
eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
|
|
|
|
|
|
|
|
|
|
|
|
if model_ema is not None and not args.model_ema_force_cpu:
|
|
|
|
if model_ema is not None and not args.model_ema_force_cpu:
|
|
|
|
|
|
|
|
if args.distributed and args.reduce_bn:
|
|
|
|
|
|
|
|
reduce_bn(model_ema, args.world_size)
|
|
|
|
|
|
|
|
|
|
|
|
ema_eval_metrics = validate(
|
|
|
|
ema_eval_metrics = validate(
|
|
|
|
model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
|
|
|
|
model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
|
|
|
|
eval_metrics = ema_eval_metrics
|
|
|
|
eval_metrics = ema_eval_metrics
|
|
|
|