@@ -55,6 +55,8 @@ parser.add_argument('--gp', default='avg', type=str, metavar='POOL',
                     help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")')
 parser.add_argument('--img-size', type=int, default=None, metavar='N',
                     help='Image patch size (default: None => model default)')
+parser.add_argument('--crop-pct', default=None, type=float,
+                    metavar='N', help='Input image center crop percent (for validation only)')
 parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                     help='Override mean pixel value of dataset')
 parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
@@ -121,6 +123,10 @@ parser.add_argument('--bn-momentum', type=float, default=None,
                     help='BatchNorm momentum override (if not None)')
 parser.add_argument('--bn-eps', type=float, default=None,
                     help='BatchNorm epsilon override (if not None)')
+parser.add_argument('--sync-bn', action='store_true',
+                    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
+parser.add_argument('--dist-bn', type=str, default='',
+                    help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
 # Model Exponential Moving Average
 parser.add_argument('--model-ema', action='store_true', default=False,
                     help='Enable tracking moving average of model weights')
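The new --dist-bn option selects how BatchNorm running statistics are synchronized across processes at the end of each epoch: broadcast from rank 0, or averaged ("reduced") over all ranks. Below is a minimal sketch of that idea, assuming torch.distributed has already been initialized; the helper name sync_bn_stats and its structure are illustrative, not timm's actual distribute_bn implementation.

import torch.distributed as dist


def sync_bn_stats(model, world_size, reduce=False):
    # Illustrative helper: synchronize only the BatchNorm running statistics
    # across the distributed group, leaving learnable parameters untouched.
    if not dist.is_initialized():
        return
    for name, buf in model.named_buffers(recurse=True):
        if 'running_mean' in name or 'running_var' in name:
            if reduce:
                # "reduce" mode: sum the stats over all ranks, then average.
                dist.all_reduce(buf, op=dist.ReduceOp.SUM)
                buf /= float(world_size)
            else:
                # "broadcast" mode: copy rank 0's stats to every other rank.
                dist.broadcast(buf, 0)

Broadcasting keeps every rank identical to rank 0, while reducing folds in what each rank observed, which is why the training-loop change further down passes args.dist_bn == 'reduce' as the third argument.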
@@ -143,10 +149,6 @@ parser.add_argument('--save-images', action='store_true', default=False,
                     help='save images of input batches every log interval for debugging')
 parser.add_argument('--amp', action='store_true', default=False,
                     help='use NVIDIA amp for mixed precision training')
-parser.add_argument('--sync-bn', action='store_true',
-                    help='enabling apex sync BN.')
-parser.add_argument('--reduce-bn', action='store_true',
-                    help='average BN running stats across all distributed nodes between train and validation.')
 parser.add_argument('--no-prefetcher', action='store_true', default=False,
                     help='disable fast prefetcher')
 parser.add_argument('--output', default='', type=str, metavar='PATH',
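These older --sync-bn and --reduce-bn flags are replaced by the versions added further up in this diff. Synchronized BatchNorm itself is usually enabled by converting the model's BatchNorm layers before the model is wrapped for distributed training; the sketch below shows the standard Apex and native PyTorch conversion calls (the wrapper name convert_to_sync_bn and the use_apex switch are illustrative, not train.py's actual wiring).

import torch


def convert_to_sync_bn(model, use_apex=False):
    # Illustrative helper: replace every BatchNorm layer with a synchronized
    # variant so batch statistics are computed across all processes.
    if use_apex:
        from apex.parallel import convert_syncbn_model  # requires NVIDIA Apex
        return convert_syncbn_model(model)
    return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)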
@@ -349,6 +351,7 @@ def main():
         std=data_config['std'],
         num_workers=args.workers,
         distributed=args.distributed,
+        crop_pct=data_config['crop_pct'],
     )

     if args.mixup > 0.:
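The crop_pct now passed to the validation loader follows the common center-crop convention: resize so that a centered crop of the target size covers roughly crop_pct of the image, then crop. Below is a standalone sketch of that interpretation using torchvision; the function name and the 224 / 0.875 defaults are illustrative, not the loader's actual transform code.

from torchvision import transforms


def eval_transform(img_size=224, crop_pct=0.875):
    # Resize the shorter side to img_size / crop_pct, then center-crop to
    # img_size, so the crop covers roughly crop_pct of the resized image.
    scale_size = int(round(img_size / crop_pct))
    return transforms.Compose([
        transforms.Resize(scale_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
    ])

With the conventional 224 and 0.875 values this resizes the short side to 256 before taking the 224-pixel crop.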
@@ -390,16 +393,16 @@
                 lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                 use_amp=use_amp, model_ema=model_ema)

-            if args.distributed and args.reduce_bn:
+            if args.distributed and args.dist_bn and args.dist_bn in ('broadcast', 'reduce'):
                 if args.local_rank == 0:
-                    logging.info("Averaging bn running means and vars")
-                reduce_bn(model, args.world_size)
+                    logging.info("Distributing BatchNorm running means and vars")
+                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

             eval_metrics = validate(model, loader_eval, validate_loss_fn, args)

             if model_ema is not None and not args.model_ema_force_cpu:
-                if args.distributed and args.reduce_bn:
-                    reduce_bn(model_ema, args.world_size)
+                if args.distributed and args.dist_bn and args.dist_bn in ('broadcast', 'reduce'):
+                    distribute_bn(model_ema, args.world_size)
                 ema_eval_metrics = validate(
                     model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')