Add crop_pct arg to validate, extra fields to csv output, 'all' filters pretrained

pull/19/head
Ross Wightman 6 years ago
parent 949b7a81c4
commit edb425ea48

@ -30,6 +30,8 @@ parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--crop-pct', default=None, type=float,
metavar='N', help='Input image center crop pct')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
@ -81,6 +83,7 @@ def validate(args):
criterion = nn.CrossEntropyLoss().cuda()
crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
loader = create_loader(
Dataset(args.data, load_bytes=args.tf_preprocessing),
input_size=data_config['input_size'],
@ -90,7 +93,7 @@ def validate(args):
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
crop_pct=1.0 if test_time_pool else data_config['crop_pct'],
crop_pct=crop_pct,
tf_preprocessing=args.tf_preprocessing)
batch_time = AverageMeter()
@ -124,16 +127,19 @@ def validate(args):
'Test: [{0:>4d}/{1}] '
'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
'Prec@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
'Prec@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
i, len(loader), batch_time=batch_time,
rate_avg=input.size(0) / batch_time.avg,
loss=losses, top1=top1, top5=top5))
results = OrderedDict(
top1=round(top1.avg, 3), top1_err=round(100 - top1.avg, 3),
top5=round(top5.avg, 3), top5_err=round(100 - top5.avg, 3),
param_count=round(param_count / 1e6, 2))
top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4),
param_count=round(param_count / 1e6, 2),
img_size=data_config['input_size'][-1],
cropt_pct=crop_pct,
interpolation=data_config['interpolation'])
logging.info(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
results['top1'], results['top1_err'], results['top5'], results['top5_err']))
@ -155,7 +161,7 @@ def main():
if args.model == 'all':
# validate all models in a list of names with pretrained checkpoints
args.pretrained = True
model_names = list_models()
model_names = list_models(pretrained=True)
model_cfgs = [(n, '') for n in model_names]
elif not is_model(args.model):
# model name doesn't exist, try as wildcard filter
@ -170,7 +176,8 @@ def main():
args.model = m
args.checkpoint = c
result = OrderedDict(model=args.model)
result.update(validate(args))
r = validate(args)
result.update(r)
if args.checkpoint:
result['checkpoint'] = args.checkpoint
dw = csv.DictWriter(cf, fieldnames=result.keys())

Loading…
Cancel
Save