@ -10,6 +10,7 @@ try:
from apex . parallel import convert_syncbn_model
has_apex = True
except ImportError :
from torch . nn . parallel import DistributedDataParallel as DDP
has_apex = False
from timm . data import Dataset , create_loader , resolve_data_config , FastCollateMixup , mixup_target
@ -169,6 +170,7 @@ def main():
bn_eps = args . bn_eps ,
checkpoint_path = args . initial_checkpoint )
if args . local_rank == 0 :
logging . info ( ' Model %s created, param count: %d ' %
( args . model , sum ( [ m . numel ( ) for m in model . parameters ( ) ] ) ) )
@ -187,24 +189,23 @@ def main():
args . amp = False
model = nn . DataParallel ( model , device_ids = list ( range ( args . num_gpu ) ) ) . cuda ( )
else :
if args . distributed and args . sync_bn and has_apex :
model = convert_syncbn_model ( model )
model . cuda ( )
optimizer = create_optimizer ( args , model )
if optimizer_state is not None :
optimizer . load_state_dict ( optimizer_state )
use_amp = False
if has_apex and args . amp :
model , optimizer = amp . initialize ( model , optimizer , opt_level = ' O1 ' )
use_amp = True
logging . info ( ' AMP enabled ' )
else :
use_amp = False
logging . info ( ' AMP disabled ' )
if args . local_rank == 0 :
logging . info ( ' NVIDIA APEX {} . AMP {} . ' . format (
' installed ' if has_apex else ' not installed ' , ' on ' if use_amp else ' off ' ) )
model_ema = None
if args . model_ema :
# Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
model_ema = ModelEma (
model ,
decay = args . model_ema_decay ,
@ -212,11 +213,23 @@ def main():
resume = args . resume )
if args . distributed :
if args . sync_bn :
try :
if has_apex :
model = convert_syncbn_model ( model )
else :
model = torch . nn . SyncBatchNorm . convert_sync_batchnorm ( model )
if args . local_rank == 0 :
logging . info ( ' Converted model to use Synchronized BatchNorm. ' )
except Exception as e :
logging . error ( ' Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1 ' )
if has_apex :
model = DDP ( model , delay_allreduce = True )
if model_ema is not None and not args . model_ema_force_cpu :
# must also distribute EMA model to allow validation
model_ema . ema = DDP ( model_ema . ema , delay_allreduce = True )
model_ema . ema_has_module = True
else :
if args . local_rank == 0 :
logging . info ( " Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP. " )
model = DDP ( model , device_ids = [ args . local_rank ] ) # can use device str in Torch >= 1.1
# NOTE: EMA model does not need to be wrapped by DDP
lr_scheduler , num_epochs = create_scheduler ( args , optimizer )
if start_epoch > 0 :