@@ -26,12 +26,11 @@ from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
 from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser,\
     decay_batch_step, check_batch_size_retry
 
-has_apex = False
 try:
     from apex import amp
     has_apex = True
 except ImportError:
-    pass
+    has_apex = False
 
 has_native_amp = False
 try:
@@ -46,21 +45,18 @@ try:
 except ImportError as e:
     has_functorch = False
 
-try:
-    import torch._dynamo
-    has_dynamo = True
-except ImportError:
-    has_dynamo = False
-    pass
+has_compile = hasattr(torch, 'compile')
 
 _logger = logging.getLogger('validate')
 
 
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
-parser.add_argument('data', metavar='DIR',
-                    help='path to dataset')
-parser.add_argument('--dataset', '-d', metavar='NAME', default='',
-                    help='dataset type (default: ImageFolder/ImageTar if empty)')
+parser.add_argument('data', nargs='?', metavar='DIR', const=None,
+                    help='path to dataset (*deprecated*, use --data-dir)')
+parser.add_argument('--data-dir', metavar='DIR',
+                    help='path to dataset (root dir)')
+parser.add_argument('--dataset', metavar='NAME', default='',
+                    help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
 parser.add_argument('--split', metavar='NAME', default='validation',
                     help='dataset split (default: validation)')
 parser.add_argument('--dataset-download', action='store_true', default=False,
@@ -125,19 +121,19 @@ parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 parser.add_argument('--fast-norm', default=False, action='store_true',
                     help='enable experimental fast-norm')
-parser.add_argument('--dynamo-backend', default=None, type=str,
-                    help="Select dynamo backend. Default: None")
 
 scripting_group = parser.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', default=False, action='store_true',
                              help='torch.jit.script the full model')
+scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
+                             help="Enable compilation w/ specified backend (default: inductor).")
 scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
                              help="Enable AOT Autograd support.")
-scripting_group.add_argument('--dynamo', default=False, action='store_true',
-                             help="Enable Dynamo optimization.")
 
 parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
                     help='Output csv file for validation results (summary)')
+parser.add_argument('--results-format', default='csv', type=str,
+                    help='Format for results file one of (csv, json) (default: csv).')
 parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
                     help='Real labels JSON file for imagenet evaluation')
 parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
@@ -218,16 +214,13 @@ def validate(args):
     if args.torchscript:
         assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
         model = torch.jit.script(model)
+    elif args.torchcompile:
+        assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
+        torch._dynamo.reset()
+        model = torch.compile(model, backend=args.torchcompile)
     elif args.aot_autograd:
         assert has_functorch, "functorch is needed for --aot-autograd"
         model = memory_efficient_fusion(model)
-    elif args.dynamo:
-        assert has_dynamo, "torch._dynamo is needed for --dynamo"
-        torch._dynamo.reset()
-        if args.dynamo_backend is not None:
-            model = torch._dynamo.optimize(args.dynamo_backend)(model)
-        else:
-            model = torch._dynamo.optimize()(model)
 
     if use_amp == 'apex':
         model = amp.initialize(model, opt_level='O1')
@@ -407,7 +400,6 @@ def main():
         model_cfgs = [(n, None) for n in model_names if n]
 
     if len(model_cfgs):
-        results_file = args.results_file or './results-all.csv'
         _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
         results = []
         try:
@@ -424,19 +416,28 @@
         except KeyboardInterrupt as e:
             pass
         results = sorted(results, key=lambda x: x['top1'], reverse=True)
-        if len(results):
-            write_results(results_file, results)
     else:
         if args.retry:
             results = _try_run(args, args.batch_size)
         else:
             results = validate(args)
+
+    if args.results_file:
+        write_results(args.results_file, results, format=args.results_format)
+
     # output results in JSON to stdout w/ delimiter for runner script
     print(f'--result\n{json.dumps(results, indent=4)}')
 
 
-def write_results(results_file, results):
+def write_results(results_file, results, format='csv'):
     with open(results_file, mode='w') as cf:
+        if format == 'json':
+            json.dump(results, cf, indent=4)
+        else:
+            if not isinstance(results, (list, tuple)):
+                results = [results]
+            if not results:
+                return
             dw = csv.DictWriter(cf, fieldnames=results[0].keys())
             dw.writeheader()
             for r in results:
@@ -444,5 +445,6 @@ def write_results(results_file, results):
             cf.flush()
+
 
 
 if __name__ == '__main__':
     main()
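
Below is a minimal sketch (not part of the patch) of the model-preparation path the new --torchcompile flag enables: guard on torch.compile() availability, reset any prior dynamo state, then wrap the model with the chosen backend. The model name, input shape, and eval-only usage are illustrative assumptions, not taken from the diff.

import torch
import timm

# Same capability guard the patch adds to validate.py.
has_compile = hasattr(torch, 'compile')

# Arbitrary example model; any timm model name would do here.
model = timm.create_model('resnet50', pretrained=False).eval()

if has_compile:
    torch._dynamo.reset()  # clear previously compiled graphs, mirroring validate()
    model = torch.compile(model, backend='inductor')  # 'inductor' is the flag's default (const=)

with torch.no_grad():
    _ = model(torch.randn(2, 3, 224, 224))  # dummy batch; the first call triggers compilation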
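
Similarly, a hedged usage sketch of the reworked write_results(): csv stays the default, while format='json' dumps the whole results structure via json.dump. The file names and zeroed metrics are placeholders for illustration, and the function itself is assumed to be in scope from the patched validate.py.

# Assuming `results` in the shape produced by validate()/main(); zeros are placeholders, not real metrics.
results = [{'model': 'resnet50', 'top1': 0.0, 'top5': 0.0, 'param_count': 0.0}]

write_results('example-results.csv', results)                   # default: CSV with a header row
write_results('example-results.json', results, format='json')   # whole list as one JSON document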