|
|
|
@ -30,6 +30,8 @@ parser.add_argument('-b', '--batch-size', default=256, type=int,
|
|
|
|
|
metavar='N', help='mini-batch size (default: 256)')
|
|
|
|
|
parser.add_argument('--img-size', default=None, type=int,
|
|
|
|
|
metavar='N', help='Input image dimension, uses model default if empty')
|
|
|
|
|
parser.add_argument('--crop-pct', default=None, type=float,
|
|
|
|
|
metavar='N', help='Input image center crop pct')
|
|
|
|
|
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
|
|
|
|
|
help='Override mean pixel value of dataset')
|
|
|
|
|
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
|
|
|
|
@ -81,6 +83,7 @@ def validate(args):
|
|
|
|
|
|
|
|
|
|
criterion = nn.CrossEntropyLoss().cuda()
|
|
|
|
|
|
|
|
|
|
crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
|
|
|
|
|
loader = create_loader(
|
|
|
|
|
Dataset(args.data, load_bytes=args.tf_preprocessing),
|
|
|
|
|
input_size=data_config['input_size'],
|
|
|
|
@ -90,7 +93,7 @@ def validate(args):
|
|
|
|
|
mean=data_config['mean'],
|
|
|
|
|
std=data_config['std'],
|
|
|
|
|
num_workers=args.workers,
|
|
|
|
|
crop_pct=1.0 if test_time_pool else data_config['crop_pct'],
|
|
|
|
|
crop_pct=crop_pct,
|
|
|
|
|
tf_preprocessing=args.tf_preprocessing)
|
|
|
|
|
|
|
|
|
|
batch_time = AverageMeter()
|
|
|
|
@ -124,16 +127,19 @@ def validate(args):
|
|
|
|
|
'Test: [{0:>4d}/{1}] '
|
|
|
|
|
'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
|
|
|
|
|
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
|
|
|
|
|
'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
|
|
|
|
|
'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
|
|
|
|
|
'Prec@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
|
|
|
|
|
'Prec@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
|
|
|
|
|
i, len(loader), batch_time=batch_time,
|
|
|
|
|
rate_avg=input.size(0) / batch_time.avg,
|
|
|
|
|
loss=losses, top1=top1, top5=top5))
|
|
|
|
|
|
|
|
|
|
results = OrderedDict(
|
|
|
|
|
top1=round(top1.avg, 3), top1_err=round(100 - top1.avg, 3),
|
|
|
|
|
top5=round(top5.avg, 3), top5_err=round(100 - top5.avg, 3),
|
|
|
|
|
param_count=round(param_count / 1e6, 2))
|
|
|
|
|
top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
|
|
|
|
|
top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4),
|
|
|
|
|
param_count=round(param_count / 1e6, 2),
|
|
|
|
|
img_size=data_config['input_size'][-1],
|
|
|
|
|
cropt_pct=crop_pct,
|
|
|
|
|
interpolation=data_config['interpolation'])
|
|
|
|
|
|
|
|
|
|
logging.info(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
|
|
|
|
|
results['top1'], results['top1_err'], results['top5'], results['top5_err']))
|
|
|
|
@ -155,7 +161,7 @@ def main():
|
|
|
|
|
if args.model == 'all':
|
|
|
|
|
# validate all models in a list of names with pretrained checkpoints
|
|
|
|
|
args.pretrained = True
|
|
|
|
|
model_names = list_models()
|
|
|
|
|
model_names = list_models(pretrained=True)
|
|
|
|
|
model_cfgs = [(n, '') for n in model_names]
|
|
|
|
|
elif not is_model(args.model):
|
|
|
|
|
# model name doesn't exist, try as wildcard filter
|
|
|
|
@ -170,7 +176,8 @@ def main():
|
|
|
|
|
args.model = m
|
|
|
|
|
args.checkpoint = c
|
|
|
|
|
result = OrderedDict(model=args.model)
|
|
|
|
|
result.update(validate(args))
|
|
|
|
|
r = validate(args)
|
|
|
|
|
result.update(r)
|
|
|
|
|
if args.checkpoint:
|
|
|
|
|
result['checkpoint'] = args.checkpoint
|
|
|
|
|
dw = csv.DictWriter(cf, fieldnames=result.keys())
|
|
|
|
|