diff --git a/timm/utils.py b/timm/utils.py index ee258aed..59d2bcd0 100644 --- a/timm/utils.py +++ b/timm/utils.py @@ -210,12 +210,17 @@ def reduce_tensor(tensor, n): return rt -def reduce_bn(model, world_size): +def distribute_bn(model, world_size, reduce=False): # ensure every node has the same running bn stats for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True): if ('running_mean' in bn_name) or ('running_var' in bn_name): - torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM) - bn_buf /= float(world_size) + if reduce: + # average bn stats across whole group + torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM) + bn_buf /= float(world_size) + else: + # broadcast bn stats from rank 0 to whole group + torch.distributed.broadcast(bn_buf, 0) class ModelEma: diff --git a/train.py b/train.py index 41af8fb6..c55cbfb3 100644 --- a/train.py +++ b/train.py @@ -55,6 +55,8 @@ parser.add_argument('--gp', default='avg', type=str, metavar='POOL', help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")') parser.add_argument('--img-size', type=int, default=None, metavar='N', help='Image patch size (default: None => model default)') +parser.add_argument('--crop-pct', default=None, type=float, + metavar='N', help='Input image center crop percent (for validation only)') parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', help='Override mean pixel value of dataset') parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', @@ -121,6 +123,10 @@ parser.add_argument('--bn-momentum', type=float, default=None, help='BatchNorm momentum override (if not None)') parser.add_argument('--bn-eps', type=float, default=None, help='BatchNorm epsilon override (if not None)') +parser.add_argument('--sync-bn', action='store_true', + help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') +parser.add_argument('--dist-bn', type=str, default='', + help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') # Model Exponential Moving Average parser.add_argument('--model-ema', action='store_true', default=False, help='Enable tracking moving average of model weights') @@ -143,10 +149,6 @@ parser.add_argument('--save-images', action='store_true', default=False, help='save images of input bathes every log interval for debugging') parser.add_argument('--amp', action='store_true', default=False, help='use NVIDIA amp for mixed precision training') -parser.add_argument('--sync-bn', action='store_true', - help='enabling apex sync BN.') -parser.add_argument('--reduce-bn', action='store_true', - help='average BN running stats across all distributed nodes between train and validation.') parser.add_argument('--no-prefetcher', action='store_true', default=False, help='disable fast prefetcher') parser.add_argument('--output', default='', type=str, metavar='PATH', @@ -349,6 +351,7 @@ def main(): std=data_config['std'], num_workers=args.workers, distributed=args.distributed, + crop_pct=data_config['crop_pct'], ) if args.mixup > 0.: @@ -390,16 +393,16 @@ def main(): lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp, model_ema=model_ema) - if args.distributed and args.reduce_bn: + if args.distributed and args.dist_bn and args.dist_bn in ('broadcast', 'reduce'): if args.local_rank == 0: - logging.info("Averaging bn running means and vars") - reduce_bn(model, args.world_size) + logging.info("Distributing BatchNorm running means and vars") + distribute_bn(model, args.world_size, args.dist_bn == 'reduce') eval_metrics = validate(model, loader_eval, validate_loss_fn, args) if model_ema is not None and not args.model_ema_force_cpu: if args.distributed and args.reduce_bn: - reduce_bn(model_ema, args.world_size) + distribute_bn(model_ema, args.world_size) ema_eval_metrics = validate( model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')