@@ -10,6 +10,7 @@ try:
     from apex.parallel import convert_syncbn_model
     has_apex = True
 except ImportError:
+    from torch.nn.parallel import DistributedDataParallel as DDP
     has_apex = False
 
 from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
@@ -169,8 +170,9 @@ def main():
         bn_eps=args.bn_eps,
         checkpoint_path=args.initial_checkpoint)
 
-    logging.info('Model %s created, param count: %d' %
-                 (args.model, sum([m.numel() for m in model.parameters()])))
+    if args.local_rank == 0:
+        logging.info('Model %s created, param count: %d' %
+                     (args.model, sum([m.numel() for m in model.parameters()])))
 
     data_config = resolve_data_config(model, args, verbose=args.local_rank == 0)
 
@@ -187,24 +189,23 @@ def main():
             args.amp = False
         model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
     else:
-        if args.distributed and args.sync_bn and has_apex:
-            model = convert_syncbn_model(model)
         model.cuda()
 
     optimizer = create_optimizer(args, model)
     if optimizer_state is not None:
         optimizer.load_state_dict(optimizer_state)
 
+    use_amp = False
     if has_apex and args.amp:
         model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
         use_amp = True
-        logging.info('AMP enabled')
-    else:
-        use_amp = False
-        logging.info('AMP disabled')
+    if args.local_rank == 0:
+        logging.info('NVIDIA APEX {}. AMP {}.'.format(
+            'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))
 
     model_ema = None
     if args.model_ema:
+        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
         model_ema = ModelEma(
             model,
             decay=args.model_ema_decay,
@@ -212,11 +213,23 @@ def main():
             resume=args.resume)
 
     if args.distributed:
-        model = DDP(model, delay_allreduce=True)
-        if model_ema is not None and not args.model_ema_force_cpu:
-            # must also distribute EMA model to allow validation
-            model_ema.ema = DDP(model_ema.ema, delay_allreduce=True)
-            model_ema.ema_has_module = True
+        if args.sync_bn:
+            try:
+                if has_apex:
+                    model = convert_syncbn_model(model)
+                else:
+                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+                if args.local_rank == 0:
+                    logging.info('Converted model to use Synchronized BatchNorm.')
+            except Exception as e:
+                logging.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
+        if has_apex:
+            model = DDP(model, delay_allreduce=True)
+        else:
+            if args.local_rank == 0:
+                logging.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
+            model = DDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
+        # NOTE: EMA model does not need to be wrapped by DDP
 
     lr_scheduler, num_epochs = create_scheduler(args, optimizer)
     if start_epoch > 0: