|
|
|
@ -393,7 +393,7 @@ def main():
|
|
|
|
|
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
|
|
|
|
|
use_amp=use_amp, model_ema=model_ema)
|
|
|
|
|
|
|
|
|
|
if args.distributed and args.dist_bn and args.dist_bn in ('broadcast', 'reduce'):
|
|
|
|
|
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
|
|
|
|
|
if args.local_rank == 0:
|
|
|
|
|
logging.info("Distributing BatchNorm running means and vars")
|
|
|
|
|
distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
|
|
|
|
@ -401,8 +401,8 @@ def main():
|
|
|
|
|
eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
|
|
|
|
|
|
|
|
|
|
if model_ema is not None and not args.model_ema_force_cpu:
|
|
|
|
|
if args.distributed and args.reduce_bn:
|
|
|
|
|
distribute_bn(model_ema, args.world_size)
|
|
|
|
|
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
|
|
|
|
|
distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
|
|
|
|
|
|
|
|
|
|
ema_eval_metrics = validate(
|
|
|
|
|
model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
|
|
|
|
|