|
|
|
@ -50,7 +50,11 @@ parser.add_argument('--num-gpu', type=int, default=1,
|
|
|
|
|
help='Number of GPUS to use')
|
|
|
|
|
parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
|
|
|
|
|
help='disable test time pool')
|
|
|
|
|
parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true',
|
|
|
|
|
parser.add_argument('--no-prefetcher', action='store_true', default=False,
|
|
|
|
|
help='disable fast prefetcher')
|
|
|
|
|
parser.add_argument('--fp16', action='store_true', default=False,
|
|
|
|
|
help='Use half precision (fp16)')
|
|
|
|
|
parser.add_argument('--tf-preprocessing', action='store_true', default=False,
|
|
|
|
|
help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
|
|
|
|
|
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
|
|
|
|
|
help='use ema version of weights if present')
|
|
|
|
@ -59,6 +63,7 @@ parser.add_argument('--use-ema', dest='use_ema', action='store_true',
|
|
|
|
|
def validate(args):
|
|
|
|
|
# might as well try to validate something
|
|
|
|
|
args.pretrained = args.pretrained or not args.checkpoint
|
|
|
|
|
args.prefetcher = not args.no_prefetcher
|
|
|
|
|
|
|
|
|
|
# create model
|
|
|
|
|
model = create_model(
|
|
|
|
@ -81,6 +86,9 @@ def validate(args):
|
|
|
|
|
else:
|
|
|
|
|
model = model.cuda()
|
|
|
|
|
|
|
|
|
|
if args.fp16:
|
|
|
|
|
model = model.half()
|
|
|
|
|
|
|
|
|
|
criterion = nn.CrossEntropyLoss().cuda()
|
|
|
|
|
|
|
|
|
|
crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
|
|
|
|
@ -88,12 +96,13 @@ def validate(args):
|
|
|
|
|
Dataset(args.data, load_bytes=args.tf_preprocessing),
|
|
|
|
|
input_size=data_config['input_size'],
|
|
|
|
|
batch_size=args.batch_size,
|
|
|
|
|
use_prefetcher=True,
|
|
|
|
|
use_prefetcher=args.prefetcher,
|
|
|
|
|
interpolation=data_config['interpolation'],
|
|
|
|
|
mean=data_config['mean'],
|
|
|
|
|
std=data_config['std'],
|
|
|
|
|
num_workers=args.workers,
|
|
|
|
|
crop_pct=crop_pct,
|
|
|
|
|
fp16=args.fp16,
|
|
|
|
|
tf_preprocessing=args.tf_preprocessing)
|
|
|
|
|
|
|
|
|
|
batch_time = AverageMeter()
|
|
|
|
@ -105,8 +114,11 @@ def validate(args):
|
|
|
|
|
end = time.time()
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
for i, (input, target) in enumerate(loader):
|
|
|
|
|
target = target.cuda()
|
|
|
|
|
input = input.cuda()
|
|
|
|
|
if args.no_prefetcher:
|
|
|
|
|
target = target.cuda()
|
|
|
|
|
input = input.cuda()
|
|
|
|
|
if args.fp16:
|
|
|
|
|
input = input.half()
|
|
|
|
|
|
|
|
|
|
# compute output
|
|
|
|
|
output = model(input)
|
|
|
|
@ -125,7 +137,7 @@ def validate(args):
|
|
|
|
|
if i % args.log_freq == 0:
|
|
|
|
|
logging.info(
|
|
|
|
|
'Test: [{0:>4d}/{1}] '
|
|
|
|
|
'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
|
|
|
|
|
'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
|
|
|
|
|
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
|
|
|
|
|
'Prec@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
|
|
|
|
|
'Prec@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
|
|
|
|
|