Missed update dist-bn logic for EMA model

pull/62/head
Ross Wightman 4 years ago
parent a435ea1327
commit 5719b493ad

@ -393,7 +393,7 @@ def main():
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
use_amp=use_amp, model_ema=model_ema)
if args.distributed and args.dist_bn and args.dist_bn in ('broadcast', 'reduce'):
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
if args.local_rank == 0:
logging.info("Distributing BatchNorm running means and vars")
distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
@ -401,8 +401,8 @@ def main():
eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
if model_ema is not None and not args.model_ema_force_cpu:
if args.distributed and args.reduce_bn:
distribute_bn(model_ema, args.world_size)
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
ema_eval_metrics = validate(
model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')

Loading…
Cancel
Save