diff --git a/timm/utils.py b/timm/utils.py
index 8ed8f195..ee258aed 100644
--- a/timm/utils.py
+++ b/timm/utils.py
@@ -21,11 +21,15 @@ except ImportError:
     from torch import distributed as dist
 
 
-def get_state_dict(model):
+def unwrap_model(model):
     if isinstance(model, ModelEma):
-        return get_state_dict(model.ema)
+        return unwrap_model(model.ema)
     else:
-        return model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
+        return model.module if hasattr(model, 'module') else model
+
+
+def get_state_dict(model):
+    return unwrap_model(model).state_dict()
 
 
 class CheckpointSaver:
@@ -206,6 +210,14 @@ def reduce_tensor(tensor, n):
     return rt
 
 
+def reduce_bn(model, world_size):
+    # ensure every node has the same running bn stats
+    for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True):
+        if ('running_mean' in bn_name) or ('running_var' in bn_name):
+            torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
+            bn_buf /= float(world_size)
+
+
 class ModelEma:
     """ Model Exponential Moving Average
     Keep a moving average of everything in the model state_dict (parameters and buffers).
diff --git a/train.py b/train.py
index b79a342e..41af8fb6 100644
--- a/train.py
+++ b/train.py
@@ -145,6 +145,8 @@ parser.add_argument('--amp', action='store_true', default=False,
                     help='use NVIDIA amp for mixed precision training')
 parser.add_argument('--sync-bn', action='store_true',
                     help='enabling apex sync BN.')
+parser.add_argument('--reduce-bn', action='store_true',
+                    help='average BN running stats across all distributed nodes between train and validation.')
 parser.add_argument('--no-prefetcher', action='store_true', default=False,
                     help='disable fast prefetcher')
 parser.add_argument('--output', default='', type=str, metavar='PATH',
@@ -256,7 +258,7 @@ def main():
             if args.local_rank == 0:
                 logging.info('Restoring NVIDIA AMP state from checkpoint')
             amp.load_state_dict(resume_state['amp'])
-        resume_state = None  # clear it
+        del resume_state
 
     model_ema = None
     if args.model_ema:
@@ -388,9 +390,17 @@ def main():
                 lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                 use_amp=use_amp, model_ema=model_ema)
 
+            if args.distributed and args.reduce_bn:
+                if args.local_rank == 0:
+                    logging.info("Averaging bn running means and vars")
+                reduce_bn(model, args.world_size)
+
             eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
 
             if model_ema is not None and not args.model_ema_force_cpu:
+                if args.distributed and args.reduce_bn:
+                    reduce_bn(model_ema, args.world_size)
+
                 ema_eval_metrics = validate(
                     model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
                 eval_metrics = ema_eval_metrics
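
Note on the unwrap_model() refactor: it factors the ModelEma/.module unwrapping out of get_state_dict() so that reduce_bn() can reuse the same path. A minimal sketch of the behavior, not part of the patch (the ModelEma case is omitted for brevity):

import torch.nn as nn

def unwrap_model(model):
    # resolve the .module attribute added by (Distributed)DataParallel
    return model.module if hasattr(model, 'module') else model

wrapped = nn.DataParallel(nn.Linear(2, 2))
assert unwrap_model(wrapped) is wrapped.module
# state_dict keys carry no 'module.' prefix, so checkpoints load into bare models
assert all(not k.startswith('module.') for k in unwrap_model(wrapped).state_dict())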
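For reviewers, a self-contained sketch (not part of the patch) of what reduce_bn() computes: an all-reduce SUM over each BN running-stat buffer, followed by division by world_size, leaves every rank holding the cross-rank mean. It runs on CPU with the gloo backend; the address and port values are placeholders:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn

def reduce_bn(model, world_size):
    # mirrors the patch: average running_mean/running_var across ranks
    for bn_name, bn_buf in model.named_buffers(recurse=True):
        if ('running_mean' in bn_name) or ('running_var' in bn_name):
            dist.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
            bn_buf /= float(world_size)

def worker(rank, world_size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)
    bn = nn.BatchNorm1d(4)
    # give each rank different running stats, as after un-synced training steps
    with torch.no_grad():
        bn.running_mean.fill_(float(rank))
    reduce_bn(bn, world_size)
    # every rank now holds the mean over ranks: (0.0 + 1.0) / 2 = 0.5
    assert torch.allclose(bn.running_mean, torch.full((4,), 0.5))
    dist.destroy_process_group()

if __name__ == '__main__':
    mp.spawn(worker, args=(2,), nprocs=2)

Unlike --sync-bn, which synchronizes batch statistics during every training forward pass, --reduce-bn only reconciles the accumulated running stats once per epoch, right before the validation pass and again before the EMA validation pass.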