@@ -11,6 +11,7 @@ import argparse
 import os
 import csv
 import glob
+import json
 import time
 import logging
 import torch
@@ -263,6 +264,7 @@ def validate(args):
     else:
         top1a, top5a = top1.avg, top5.avg
     results = OrderedDict(
+        model=args.model,
         top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
         top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
         param_count=round(param_count / 1e6, 2),
@@ -276,6 +278,27 @@ def validate(args):
     return results
 
 
+def _try_run(args, initial_batch_size):
+    batch_size = initial_batch_size
+    results = OrderedDict()
+    error_str = 'Unknown'
+    while batch_size >= 1:
+        args.batch_size = batch_size
+        torch.cuda.empty_cache()
+        try:
+            results = validate(args)
+            return results
+        except RuntimeError as e:
+            error_str = str(e)
+            if 'channels_last' in error_str:
+                break
+            batch_size = batch_size // 2
+            _logger.warning(f'"{error_str}" while running validation. Reducing batch size to {batch_size} for retry.')
+    results['error'] = error_str
+    _logger.error(f'{args.model} failed to validate ({error_str}).')
+    return results
+
+
 def main():
     setup_default_logging()
     args = parser.parse_args()
@@ -308,36 +331,25 @@ def main():
         _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
         results = []
         try:
-            start_batch_size = args.batch_size
+            initial_batch_size = args.batch_size
             for m, c in model_cfgs:
-                batch_size = start_batch_size
                 args.model = m
                 args.checkpoint = c
-                result = OrderedDict(model=args.model)
-                r = {}
-                while not r and batch_size >= args.num_gpu:
-                    torch.cuda.empty_cache()
-                    try:
-                        args.batch_size = batch_size
-                        print('Validating with batch size: %d' % args.batch_size)
-                        r = validate(args)
-                    except RuntimeError as e:
-                        if batch_size <= args.num_gpu:
-                            print("Validation failed with no ability to reduce batch size. Exiting.")
-                            raise e
-                        batch_size = max(batch_size // 2, args.num_gpu)
-                        print("Validation failed, reducing batch size by 50%")
-                result.update(r)
+                r = _try_run(args, initial_batch_size)
+                if 'error' in r:
+                    continue
                 if args.checkpoint:
-                    result['checkpoint'] = args.checkpoint
-                results.append(result)
+                    r['checkpoint'] = args.checkpoint
+                results.append(r)
         except KeyboardInterrupt as e:
             pass
         results = sorted(results, key=lambda x: x['top1'], reverse=True)
         if len(results):
             write_results(results_file, results)
     else:
-        validate(args)
+        results = validate(args)
+
+    # output results in JSON to stdout w/ delimiter for runner script
+    print(f'--result\n{json.dumps(results, indent=4)}')
 
 
 def write_results(results_file, results):
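
The '--result' delimiter printed at the end of main() lets a parent runner script discard validation log noise and recover just the JSON payload. A sketch of the consuming side, assuming a hypothetical invocation (the dataset path and flags are placeholders):

import json
import subprocess

proc = subprocess.run(
    ['python', 'validate.py', '/imagenet/val', '--model', 'resnet50'],
    capture_output=True, text=True,
)
# Everything after the last '--result' marker is the JSON results dict.
_, _, tail = proc.stdout.rpartition('--result\n')
results = json.loads(tail)
print(results['top1'], results['top5'])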