@ -52,12 +52,16 @@ parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
help = ' disable test time pool ' )
help = ' disable test time pool ' )
parser . add_argument ( ' --no-prefetcher ' , action = ' store_true ' , default = False ,
parser . add_argument ( ' --no-prefetcher ' , action = ' store_true ' , default = False ,
help = ' disable fast prefetcher ' )
help = ' disable fast prefetcher ' )
parser . add_argument ( ' --pin-mem ' , action = ' store_true ' , default = False ,
help = ' Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. ' )
parser . add_argument ( ' --fp16 ' , action = ' store_true ' , default = False ,
parser . add_argument ( ' --fp16 ' , action = ' store_true ' , default = False ,
help = ' Use half precision (fp16) ' )
help = ' Use half precision (fp16) ' )
parser . add_argument ( ' --tf-preprocessing ' , action = ' store_true ' , default = False ,
parser . add_argument ( ' --tf-preprocessing ' , action = ' store_true ' , default = False ,
help = ' Use Tensorflow preprocessing pipeline (require CPU TF installed ' )
help = ' Use Tensorflow preprocessing pipeline (require CPU TF installed ' )
parser . add_argument ( ' --use-ema ' , dest = ' use_ema ' , action = ' store_true ' ,
parser . add_argument ( ' --use-ema ' , dest = ' use_ema ' , action = ' store_true ' ,
help = ' use ema version of weights if present ' )
help = ' use ema version of weights if present ' )
parser . add_argument ( ' --torchscript ' , dest = ' torchscript ' , action = ' store_true ' ,
help = ' convert model torchscript for inference ' )
def validate ( args ) :
def validate ( args ) :
@ -81,6 +85,10 @@ def validate(args):
data_config = resolve_data_config ( vars ( args ) , model = model )
data_config = resolve_data_config ( vars ( args ) , model = model )
model , test_time_pool = apply_test_time_pool ( model , data_config , args )
model , test_time_pool = apply_test_time_pool ( model , data_config , args )
if args . torchscript :
torch . jit . optimized_execution ( True )
model = torch . jit . script ( model )
if args . num_gpu > 1 :
if args . num_gpu > 1 :
model = torch . nn . DataParallel ( model , device_ids = list ( range ( args . num_gpu ) ) ) . cuda ( )
model = torch . nn . DataParallel ( model , device_ids = list ( range ( args . num_gpu ) ) ) . cuda ( )
else :
else :
@ -107,6 +115,7 @@ def validate(args):
std = data_config [ ' std ' ] ,
std = data_config [ ' std ' ] ,
num_workers = args . workers ,
num_workers = args . workers ,
crop_pct = crop_pct ,
crop_pct = crop_pct ,
pin_memory = args . pin_mem ,
fp16 = args . fp16 ,
fp16 = args . fp16 ,
tf_preprocessing = args . tf_preprocessing )
tf_preprocessing = args . tf_preprocessing )