Distributed tweaks

* Support PyTorch native DDP as fallback if APEX not present
* Support SyncBN for both APEX and Torch native (if torch >= 1.1)
* EMA model does not appear to need DDP wrapper, no gradients, updated from wrapped model
pull/16/head
Ross Wightman 6 years ago
parent 6fc886acaf
commit b20bb58284

@ -10,6 +10,7 @@ try:
from apex.parallel import convert_syncbn_model from apex.parallel import convert_syncbn_model
has_apex = True has_apex = True
except ImportError: except ImportError:
from torch.nn.parallel import DistributedDataParallel as DDP
has_apex = False has_apex = False
from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
@ -169,8 +170,9 @@ def main():
bn_eps=args.bn_eps, bn_eps=args.bn_eps,
checkpoint_path=args.initial_checkpoint) checkpoint_path=args.initial_checkpoint)
logging.info('Model %s created, param count: %d' % if args.local_rank == 0:
(args.model, sum([m.numel() for m in model.parameters()]))) logging.info('Model %s created, param count: %d' %
(args.model, sum([m.numel() for m in model.parameters()])))
data_config = resolve_data_config(model, args, verbose=args.local_rank == 0) data_config = resolve_data_config(model, args, verbose=args.local_rank == 0)
@ -187,24 +189,23 @@ def main():
args.amp = False args.amp = False
model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
else: else:
if args.distributed and args.sync_bn and has_apex:
model = convert_syncbn_model(model)
model.cuda() model.cuda()
optimizer = create_optimizer(args, model) optimizer = create_optimizer(args, model)
if optimizer_state is not None: if optimizer_state is not None:
optimizer.load_state_dict(optimizer_state) optimizer.load_state_dict(optimizer_state)
use_amp = False
if has_apex and args.amp: if has_apex and args.amp:
model, optimizer = amp.initialize(model, optimizer, opt_level='O1') model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
use_amp = True use_amp = True
logging.info('AMP enabled') if args.local_rank == 0:
else: logging.info('NVIDIA APEX {}. AMP {}.'.format(
use_amp = False 'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))
logging.info('AMP disabled')
model_ema = None model_ema = None
if args.model_ema: if args.model_ema:
# Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
model_ema = ModelEma( model_ema = ModelEma(
model, model,
decay=args.model_ema_decay, decay=args.model_ema_decay,
@ -212,11 +213,23 @@ def main():
resume=args.resume) resume=args.resume)
if args.distributed: if args.distributed:
model = DDP(model, delay_allreduce=True) if args.sync_bn:
if model_ema is not None and not args.model_ema_force_cpu: try:
# must also distribute EMA model to allow validation if has_apex:
model_ema.ema = DDP(model_ema.ema, delay_allreduce=True) model = convert_syncbn_model(model)
model_ema.ema_has_module = True else:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
if args.local_rank == 0:
logging.info('Converted model to use Synchronized BatchNorm.')
except Exception as e:
logging.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
if has_apex:
model = DDP(model, delay_allreduce=True)
else:
if args.local_rank == 0:
logging.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
model = DDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1
# NOTE: EMA model does not need to be wrapped by DDP
lr_scheduler, num_epochs = create_scheduler(args, optimizer) lr_scheduler, num_epochs = create_scheduler(args, optimizer)
if start_epoch > 0: if start_epoch > 0:

Loading…
Cancel
Save