Merge pull request #62 from rwightman/reduce-bn

Distribute BatchNorm stats
pull/74/head
Ross Wightman 5 years ago committed by GitHub
commit ff8688ca3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -7,12 +7,10 @@ def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='
# NOTE I've seen CUDA illegal memory access errors being caused by the normal_() # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
# paths, flip the order so normal is run on CPU if this becomes a problem # paths, flip the order so normal is run on CPU if this becomes a problem
# Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
# will revert back to doing normal_() on GPU when it's in next release
if per_pixel: if per_pixel:
return torch.empty( return torch.empty(patch_size, dtype=dtype, device=device).normal_()
patch_size, dtype=dtype).normal_().to(device=device)
elif rand_color: elif rand_color:
return torch.empty((patch_size[0], 1, 1), dtype=dtype).normal_().to(device=device) return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_()
else: else:
return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)

@ -21,11 +21,15 @@ except ImportError:
from torch import distributed as dist from torch import distributed as dist
def get_state_dict(model): def unwrap_model(model):
if isinstance(model, ModelEma): if isinstance(model, ModelEma):
return get_state_dict(model.ema) return unwrap_model(model.ema)
else: else:
return model.module.state_dict() if hasattr(model, 'module') else model.state_dict() return model.module if hasattr(model, 'module') else model
def get_state_dict(model):
return unwrap_model(model).state_dict()
class CheckpointSaver: class CheckpointSaver:
@ -206,6 +210,19 @@ def reduce_tensor(tensor, n):
return rt return rt
def distribute_bn(model, world_size, reduce=False):
# ensure every node has the same running bn stats
for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True):
if ('running_mean' in bn_name) or ('running_var' in bn_name):
if reduce:
# average bn stats across whole group
torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
bn_buf /= float(world_size)
else:
# broadcast bn stats from rank 0 to whole group
torch.distributed.broadcast(bn_buf, 0)
class ModelEma: class ModelEma:
""" Model Exponential Moving Average """ Model Exponential Moving Average
Keep a moving average of everything in the model state_dict (parameters and buffers). Keep a moving average of everything in the model state_dict (parameters and buffers).

@ -55,6 +55,8 @@ parser.add_argument('--gp', default='avg', type=str, metavar='POOL',
help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")') help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")')
parser.add_argument('--img-size', type=int, default=None, metavar='N', parser.add_argument('--img-size', type=int, default=None, metavar='N',
help='Image patch size (default: None => model default)') help='Image patch size (default: None => model default)')
parser.add_argument('--crop-pct', default=None, type=float,
metavar='N', help='Input image center crop percent (for validation only)')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset') help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
@ -121,6 +123,10 @@ parser.add_argument('--bn-momentum', type=float, default=None,
help='BatchNorm momentum override (if not None)') help='BatchNorm momentum override (if not None)')
parser.add_argument('--bn-eps', type=float, default=None, parser.add_argument('--bn-eps', type=float, default=None,
help='BatchNorm epsilon override (if not None)') help='BatchNorm epsilon override (if not None)')
parser.add_argument('--sync-bn', action='store_true',
help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
parser.add_argument('--dist-bn', type=str, default='',
help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
# Model Exponential Moving Average # Model Exponential Moving Average
parser.add_argument('--model-ema', action='store_true', default=False, parser.add_argument('--model-ema', action='store_true', default=False,
help='Enable tracking moving average of model weights') help='Enable tracking moving average of model weights')
@ -143,8 +149,6 @@ parser.add_argument('--save-images', action='store_true', default=False,
help='save images of input bathes every log interval for debugging') help='save images of input bathes every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False, parser.add_argument('--amp', action='store_true', default=False,
help='use NVIDIA amp for mixed precision training') help='use NVIDIA amp for mixed precision training')
parser.add_argument('--sync-bn', action='store_true',
help='enabling apex sync BN.')
parser.add_argument('--no-prefetcher', action='store_true', default=False, parser.add_argument('--no-prefetcher', action='store_true', default=False,
help='disable fast prefetcher') help='disable fast prefetcher')
parser.add_argument('--output', default='', type=str, metavar='PATH', parser.add_argument('--output', default='', type=str, metavar='PATH',
@ -256,7 +260,7 @@ def main():
if args.local_rank == 0: if args.local_rank == 0:
logging.info('Restoring NVIDIA AMP state from checkpoint') logging.info('Restoring NVIDIA AMP state from checkpoint')
amp.load_state_dict(resume_state['amp']) amp.load_state_dict(resume_state['amp'])
resume_state = None # clear it del resume_state
model_ema = None model_ema = None
if args.model_ema: if args.model_ema:
@ -347,6 +351,7 @@ def main():
std=data_config['std'], std=data_config['std'],
num_workers=args.workers, num_workers=args.workers,
distributed=args.distributed, distributed=args.distributed,
crop_pct=data_config['crop_pct'],
) )
if args.mixup > 0.: if args.mixup > 0.:
@ -388,9 +393,17 @@ def main():
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
use_amp=use_amp, model_ema=model_ema) use_amp=use_amp, model_ema=model_ema)
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
if args.local_rank == 0:
logging.info("Distributing BatchNorm running means and vars")
distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
eval_metrics = validate(model, loader_eval, validate_loss_fn, args) eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
if model_ema is not None and not args.model_ema_force_cpu: if model_ema is not None and not args.model_ema_force_cpu:
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
ema_eval_metrics = validate( ema_eval_metrics = validate(
model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)') model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
eval_metrics = ema_eval_metrics eval_metrics = ema_eval_metrics

Loading…
Cancel
Save