From 0161de01276100810f8aef5fac270164f4079de4 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 5 Dec 2019 22:35:08 -0800
Subject: [PATCH 1/4] Switch RandoErasing back to on GPU normal sampling

---
 timm/data/random_erasing.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/timm/data/random_erasing.py b/timm/data/random_erasing.py
index e944f22c..5eed1387 100644
--- a/timm/data/random_erasing.py
+++ b/timm/data/random_erasing.py
@@ -7,12 +7,10 @@ def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='
     # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
     # paths, flip the order so normal is run on CPU if this becomes a problem
     # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
-    # will revert back to doing normal_() on GPU when it's in next release
     if per_pixel:
-        return torch.empty(
-            patch_size, dtype=dtype).normal_().to(device=device)
+        return torch.empty(patch_size, dtype=dtype, device=device).normal_()
     elif rand_color:
-        return torch.empty((patch_size[0], 1, 1), dtype=dtype).normal_().to(device=device)
+        return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_()
     else:
         return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
 

From 3bff2b21dcd0fab2177a5a7ccfa17609d16ec5aa Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 5 Dec 2019 22:35:40 -0800
Subject: [PATCH 2/4] Add support for keeping running bn stats the same across distributed training nodes before eval/save

---
 timm/utils.py | 18 +++++++++++++++---
 train.py      | 12 +++++++++++-
 2 files changed, 26 insertions(+), 4 deletions(-)

diff --git a/timm/utils.py b/timm/utils.py
index 8ed8f195..ee258aed 100644
--- a/timm/utils.py
+++ b/timm/utils.py
@@ -21,11 +21,15 @@ except ImportError:
 from torch import distributed as dist
 
 
-def get_state_dict(model):
+def unwrap_model(model):
     if isinstance(model, ModelEma):
-        return get_state_dict(model.ema)
+        return unwrap_model(model.ema)
     else:
-        return model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
+        return model.module if hasattr(model, 'module') else model
+
+
+def get_state_dict(model):
+    return unwrap_model(model).state_dict()
 
 
 class CheckpointSaver:
@@ -206,6 +210,14 @@ def reduce_tensor(tensor, n):
     return rt
 
 
+def reduce_bn(model, world_size):
+    # ensure every node has the same running bn stats
+    for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True):
+        if ('running_mean' in bn_name) or ('running_var' in bn_name):
+            torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
+            bn_buf /= float(world_size)
+
+
 class ModelEma:
     """ Model Exponential Moving Average
     Keep a moving average of everything in the model state_dict (parameters and buffers).
diff --git a/train.py b/train.py
index b79a342e..41af8fb6 100644
--- a/train.py
+++ b/train.py
@@ -145,6 +145,8 @@ parser.add_argument('--amp', action='store_true', default=False,
                     help='use NVIDIA amp for mixed precision training')
 parser.add_argument('--sync-bn', action='store_true',
                     help='enabling apex sync BN.')
+parser.add_argument('--reduce-bn', action='store_true',
+                    help='average BN running stats across all distributed nodes between train and validation.')
 parser.add_argument('--no-prefetcher', action='store_true', default=False,
                     help='disable fast prefetcher')
 parser.add_argument('--output', default='', type=str, metavar='PATH',
@@ -256,7 +258,7 @@ def main():
             if args.local_rank == 0:
                 logging.info('Restoring NVIDIA AMP state from checkpoint')
             amp.load_state_dict(resume_state['amp'])
-    resume_state = None  # clear it
+    del resume_state
 
     model_ema = None
     if args.model_ema:
@@ -388,9 +390,17 @@ def main():
                 lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                 use_amp=use_amp, model_ema=model_ema)
 
+            if args.distributed and args.reduce_bn:
+                if args.local_rank == 0:
+                    logging.info("Averaging bn running means and vars")
+                reduce_bn(model, args.world_size)
+
             eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
 
             if model_ema is not None and not args.model_ema_force_cpu:
+                if args.distributed and args.reduce_bn:
+                    reduce_bn(model_ema, args.world_size)
+
                 ema_eval_metrics = validate(
                     model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
                 eval_metrics = ema_eval_metrics

From a435ea132721a388f16a63fe581057090620bd99 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 19 Dec 2019 22:56:54 -0800
Subject: [PATCH 3/4] Change reduce_bn to distribute_bn, add ability to choose between broadcast and reduce (mean). Add crop_pct arg to allow selecting validation crop while training.

---
 timm/utils.py | 11 ++++++++---
 train.py      | 19 +++++++++++--------
 2 files changed, 19 insertions(+), 11 deletions(-)

diff --git a/timm/utils.py b/timm/utils.py
index ee258aed..59d2bcd0 100644
--- a/timm/utils.py
+++ b/timm/utils.py
@@ -210,12 +210,17 @@ def reduce_tensor(tensor, n):
     return rt
 
 
-def reduce_bn(model, world_size):
+def distribute_bn(model, world_size, reduce=False):
     # ensure every node has the same running bn stats
     for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True):
         if ('running_mean' in bn_name) or ('running_var' in bn_name):
-            torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
-            bn_buf /= float(world_size)
+            if reduce:
+                # average bn stats across whole group
+                torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
+                bn_buf /= float(world_size)
+            else:
+                # broadcast bn stats from rank 0 to whole group
+                torch.distributed.broadcast(bn_buf, 0)
 
 
 class ModelEma:
diff --git a/train.py b/train.py
index 41af8fb6..c55cbfb3 100644
--- a/train.py
+++ b/train.py
@@ -55,6 +55,8 @@ parser.add_argument('--gp', default='avg', type=str, metavar='POOL',
                     help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")')
 parser.add_argument('--img-size', type=int, default=None, metavar='N',
                     help='Image patch size (default: None => model default)')
+parser.add_argument('--crop-pct', default=None, type=float,
+                    metavar='N', help='Input image center crop percent (for validation only)')
 parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                     help='Override mean pixel value of dataset')
 parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
@@ -121,6 +123,10 @@ parser.add_argument('--bn-momentum', type=float, default=None,
                     help='BatchNorm momentum override (if not None)')
 parser.add_argument('--bn-eps', type=float, default=None,
                     help='BatchNorm epsilon override (if not None)')
+parser.add_argument('--sync-bn', action='store_true',
+                    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
+parser.add_argument('--dist-bn', type=str, default='',
+                    help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
 # Model Exponential Moving Average
 parser.add_argument('--model-ema', action='store_true', default=False,
                     help='Enable tracking moving average of model weights')
@@ -143,10 +149,6 @@ parser.add_argument('--save-images', action='store_true', default=False,
                     help='save images of input bathes every log interval for debugging')
 parser.add_argument('--amp', action='store_true', default=False,
                     help='use NVIDIA amp for mixed precision training')
-parser.add_argument('--sync-bn', action='store_true',
-                    help='enabling apex sync BN.')
-parser.add_argument('--reduce-bn', action='store_true',
-                    help='average BN running stats across all distributed nodes between train and validation.')
 parser.add_argument('--no-prefetcher', action='store_true', default=False,
                     help='disable fast prefetcher')
 parser.add_argument('--output', default='', type=str, metavar='PATH',
@@ -349,6 +351,7 @@ def main():
         std=data_config['std'],
         num_workers=args.workers,
         distributed=args.distributed,
+        crop_pct=data_config['crop_pct'],
     )
 
     if args.mixup > 0.:
@@ -390,16 +393,16 @@ def main():
                 lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                 use_amp=use_amp, model_ema=model_ema)
 
-            if args.distributed and args.reduce_bn:
+            if args.distributed and args.dist_bn and args.dist_bn in ('broadcast', 'reduce'):
                 if args.local_rank == 0:
-                    logging.info("Averaging bn running means and vars")
-                reduce_bn(model, args.world_size)
+                    logging.info("Distributing BatchNorm running means and vars")
+                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
 
             eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
 
             if model_ema is not None and not args.model_ema_force_cpu:
                 if args.distributed and args.reduce_bn:
-                    reduce_bn(model_ema, args.world_size)
+                    distribute_bn(model_ema, args.world_size)
 
                 ema_eval_metrics = validate(
                     model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')

From 5719b493adb4b1f42844bcbdbc58af44c9f3056b Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 19 Dec 2019 23:03:04 -0800
Subject: [PATCH 4/4] Missed update dist-bn logic for EMA model

---
 train.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/train.py b/train.py
index c55cbfb3..a47f1b4d 100644
--- a/train.py
+++ b/train.py
@@ -393,7 +393,7 @@ def main():
                 lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                 use_amp=use_amp, model_ema=model_ema)
 
-            if args.distributed and args.dist_bn and args.dist_bn in ('broadcast', 'reduce'):
+            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                 if args.local_rank == 0:
                     logging.info("Distributing BatchNorm running means and vars")
                 distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
@@ -401,8 +401,8 @@ def main():
             eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
 
             if model_ema is not None and not args.model_ema_force_cpu:
-                if args.distributed and args.reduce_bn:
-                    distribute_bn(model_ema, args.world_size)
+                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
+                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
 
                 ema_eval_metrics = validate(
                     model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')