Update scripts to support torch.compile(). Make --results_file arg more consistent across benchmark/validate/inference. Fix #1570

pull/1520/head
Ross Wightman 1 year ago
parent c59d88339b
commit bc8776085a

@ -56,13 +56,7 @@ try:
except ImportError as e: except ImportError as e:
has_functorch = False has_functorch = False
try: has_compile = hasattr(torch, 'compile')
import torch._dynamo
has_dynamo = True
except ImportError:
has_dynamo = False
pass
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
@ -81,8 +75,10 @@ parser.add_argument('--detail', action='store_true', default=False,
help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False') help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False')
parser.add_argument('--no-retry', action='store_true', default=False, parser.add_argument('--no-retry', action='store_true', default=False,
help='Do not decay batch size and retry on error.') help='Do not decay batch size and retry on error.')
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', parser.add_argument('--results-file', default='', type=str,
help='Output csv file for validation results (summary)') help='Output csv file for validation results (summary)')
parser.add_argument('--results-format', default='csv', type=str,
help='Format for results file one of (csv, json) (default: csv).')
parser.add_argument('--num-warm-iter', default=10, type=int, parser.add_argument('--num-warm-iter', default=10, type=int,
metavar='N', help='Number of warmup iterations (default: 10)') metavar='N', help='Number of warmup iterations (default: 10)')
parser.add_argument('--num-bench-iter', default=40, type=int, parser.add_argument('--num-bench-iter', default=40, type=int,
@ -113,8 +109,6 @@ parser.add_argument('--precision', default='float32', type=str,
help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)') help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
parser.add_argument('--fuser', default='', type=str, parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--dynamo-backend', default=None, type=str,
help="Select dynamo backend. Default: None")
parser.add_argument('--fast-norm', default=False, action='store_true', parser.add_argument('--fast-norm', default=False, action='store_true',
help='enable experimental fast-norm') help='enable experimental fast-norm')
@ -122,10 +116,11 @@ parser.add_argument('--fast-norm', default=False, action='store_true',
scripting_group = parser.add_mutually_exclusive_group() scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true', scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference') help='convert model torchscript for inference')
scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
help="Enable compilation w/ specified backend (default: inductor).")
scripting_group.add_argument('--aot-autograd', default=False, action='store_true', scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
help="Enable AOT Autograd optimization.") help="Enable AOT Autograd optimization.")
scripting_group.add_argument('--dynamo', default=False, action='store_true',
help="Enable Dynamo optimization.")
# train optimizer parameters # train optimizer parameters
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
@ -218,9 +213,8 @@ class BenchmarkRunner:
detail=False, detail=False,
device='cuda', device='cuda',
torchscript=False, torchscript=False,
torchcompile=None,
aot_autograd=False, aot_autograd=False,
dynamo=False,
dynamo_backend=None,
precision='float32', precision='float32',
fuser='', fuser='',
num_warm_iter=10, num_warm_iter=10,
@ -259,20 +253,19 @@ class BenchmarkRunner:
self.input_size = data_config['input_size'] self.input_size = data_config['input_size']
self.batch_size = kwargs.pop('batch_size', 256) self.batch_size = kwargs.pop('batch_size', 256)
self.scripted = False self.compiled = False
if torchscript: if torchscript:
self.model = torch.jit.script(self.model) self.model = torch.jit.script(self.model)
self.scripted = True self.compiled = True
elif dynamo: elif torchcompile:
assert has_dynamo, "torch._dynamo is needed for --dynamo" assert has_compile, 'A version of torch w/ torch.compile() is required, possibly a nightly.'
torch._dynamo.reset() torch._dynamo.reset()
if dynamo_backend is not None: self.model = torch.compile(self.model, backend=torchcompile)
self.model = torch._dynamo.optimize(dynamo_backend)(self.model) self.compiled = True
else:
self.model = torch._dynamo.optimize()(self.model)
elif aot_autograd: elif aot_autograd:
assert has_functorch, "functorch is needed for --aot-autograd" assert has_functorch, "functorch is needed for --aot-autograd"
self.model = memory_efficient_fusion(self.model) self.model = memory_efficient_fusion(self.model)
self.compiled = True
self.example_inputs = None self.example_inputs = None
self.num_warm_iter = num_warm_iter self.num_warm_iter = num_warm_iter
@ -344,7 +337,7 @@ class InferenceBenchmarkRunner(BenchmarkRunner):
param_count=round(self.param_count / 1e6, 2), param_count=round(self.param_count / 1e6, 2),
) )
retries = 0 if self.scripted else 2 # skip profiling if model is scripted retries = 0 if self.compiled else 2 # skip profiling if model is scripted
while retries: while retries:
retries -= 1 retries -= 1
try: try:
@ -642,7 +635,6 @@ def main():
model_cfgs = [(n, None) for n in model_names] model_cfgs = [(n, None) for n in model_names]
if len(model_cfgs): if len(model_cfgs):
results_file = args.results_file or './benchmark.csv'
_logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names))) _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
results = [] results = []
try: try:
@ -663,22 +655,30 @@ def main():
sort_key = 'infer_gmacs' sort_key = 'infer_gmacs'
results = filter(lambda x: sort_key in x, results) results = filter(lambda x: sort_key in x, results)
results = sorted(results, key=lambda x: x[sort_key], reverse=True) results = sorted(results, key=lambda x: x[sort_key], reverse=True)
if len(results):
write_results(results_file, results)
else: else:
results = benchmark(args) results = benchmark(args)
if args.results_file:
write_results(args.results_file, results, format=args.results_format)
# output results in JSON to stdout w/ delimiter for runner script # output results in JSON to stdout w/ delimiter for runner script
print(f'--result\n{json.dumps(results, indent=4)}') print(f'--result\n{json.dumps(results, indent=4)}')
def write_results(results_file, results): def write_results(results_file, results, format='csv'):
with open(results_file, mode='w') as cf: with open(results_file, mode='w') as cf:
dw = csv.DictWriter(cf, fieldnames=results[0].keys()) if format == 'json':
dw.writeheader() json.dump(results, cf, indent=4)
for r in results: else:
dw.writerow(r) if not isinstance(results, (list, tuple)):
cf.flush() results = [results]
if not results:
return
dw = csv.DictWriter(cf, fieldnames=results[0].keys())
dw.writeheader()
for r in results:
dw.writerow(r)
cf.flush()
if __name__ == '__main__': if __name__ == '__main__':

@ -8,6 +8,7 @@ Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
import os import os
import time import time
import argparse import argparse
import json
import logging import logging
from contextlib import suppress from contextlib import suppress
from functools import partial from functools import partial
@ -41,11 +42,7 @@ try:
except ImportError as e: except ImportError as e:
has_functorch = False has_functorch = False
try: has_compile = hasattr(torch, 'compile')
import torch._dynamo
has_dynamo = True
except ImportError:
has_dynamo = False
_FMT_EXT = { _FMT_EXT = {
@ -60,14 +57,16 @@ _logger = logging.getLogger('inference')
parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference') parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference')
parser.add_argument('data', metavar='DIR', parser.add_argument('data', nargs='?', metavar='DIR', const=None,
help='path to dataset') help='path to dataset (*deprecated*, use --data-dir)')
parser.add_argument('--dataset', '-d', metavar='NAME', default='', parser.add_argument('--data-dir', metavar='DIR',
help='dataset type (default: ImageFolder/ImageTar if empty)') help='path to dataset (root dir)')
parser.add_argument('--dataset', metavar='NAME', default='',
help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
parser.add_argument('--split', metavar='NAME', default='validation', parser.add_argument('--split', metavar='NAME', default='validation',
help='dataset split (default: validation)') help='dataset split (default: validation)')
parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92', parser.add_argument('--model', '-m', metavar='MODEL', default='resnet50',
help='model architecture (default: dpn92)') help='model architecture (default: resnet50)')
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
help='number of data loading workers (default: 2)') help='number of data loading workers (default: 2)')
parser.add_argument('-b', '--batch-size', default=256, type=int, parser.add_argument('-b', '--batch-size', default=256, type=int,
@ -112,16 +111,14 @@ parser.add_argument('--amp-dtype', default='float16', type=str,
help='lower precision AMP dtype (default: float16)') help='lower precision AMP dtype (default: float16)')
parser.add_argument('--fuser', default='', type=str, parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--dynamo-backend', default=None, type=str,
help="Select dynamo backend. Default: None")
scripting_group = parser.add_mutually_exclusive_group() scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', default=False, action='store_true', scripting_group.add_argument('--torchscript', default=False, action='store_true',
help='torch.jit.script the full model') help='torch.jit.script the full model')
scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
help="Enable compilation w/ specified backend (default: inductor).")
scripting_group.add_argument('--aot-autograd', default=False, action='store_true', scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
help="Enable AOT Autograd support.") help="Enable AOT Autograd support.")
scripting_group.add_argument('--dynamo', default=False, action='store_true',
help="Enable Dynamo optimization.")
parser.add_argument('--results-dir',type=str, default=None, parser.add_argument('--results-dir',type=str, default=None,
help='folder for output results') help='folder for output results')
@ -160,7 +157,6 @@ def main():
device = torch.device(args.device) device = torch.device(args.device)
# resolve AMP arguments based on PyTorch / Apex availability # resolve AMP arguments based on PyTorch / Apex availability
use_amp = None
amp_autocast = suppress amp_autocast = suppress
if args.amp: if args.amp:
assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).' assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
@ -201,22 +197,20 @@ def main():
if args.torchscript: if args.torchscript:
model = torch.jit.script(model) model = torch.jit.script(model)
elif args.torchcompile:
assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
torch._dynamo.reset()
model = torch.compile(model, backend=args.torchcompile)
elif args.aot_autograd: elif args.aot_autograd:
assert has_functorch, "functorch is needed for --aot-autograd" assert has_functorch, "functorch is needed for --aot-autograd"
model = memory_efficient_fusion(model) model = memory_efficient_fusion(model)
elif args.dynamo:
assert has_dynamo, "torch._dynamo is needed for --dynamo"
torch._dynamo.reset()
if args.dynamo_backend is not None:
model = torch._dynamo.optimize(args.dynamo_backend)(model)
else:
model = torch._dynamo.optimize()(model)
if args.num_gpu > 1: if args.num_gpu > 1:
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))) model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
root_dir = args.data or args.data_dir
dataset = create_dataset( dataset = create_dataset(
root=args.data, root=root_dir,
name=args.dataset, name=args.dataset,
split=args.split, split=args.split,
class_map=args.class_map, class_map=args.class_map,
@ -304,6 +298,9 @@ def main():
for fmt in args.results_format: for fmt in args.results_format:
save_results(df, results_filename, fmt) save_results(df, results_filename, fmt)
print(f'--result')
print(json.dumps(dict(filename=results_filename)))
def save_results(df, results_filename, results_format='csv', filename_col='filename'): def save_results(df, results_filename, results_format='csv', filename_col='filename'):
results_filename += _FMT_EXT[results_format] results_filename += _FMT_EXT[results_format]

@ -66,12 +66,7 @@ try:
except ImportError as e: except ImportError as e:
has_functorch = False has_functorch = False
try: has_compile = hasattr(torch, 'compile')
import torch._dynamo
has_dynamo = True
except ImportError:
has_dynamo = False
pass
_logger = logging.getLogger('train') _logger = logging.getLogger('train')
@ -88,10 +83,12 @@ parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
# Dataset parameters # Dataset parameters
group = parser.add_argument_group('Dataset parameters') group = parser.add_argument_group('Dataset parameters')
# Keep this argument outside of the dataset group because it is positional. # Keep this argument outside of the dataset group because it is positional.
parser.add_argument('data_dir', metavar='DIR', parser.add_argument('data', nargs='?', metavar='DIR', const=None,
help='path to dataset') help='path to dataset (positional is *deprecated*, use --data-dir)')
group.add_argument('--dataset', '-d', metavar='NAME', default='', parser.add_argument('--data-dir', metavar='DIR',
help='dataset type (default: ImageFolder/ImageTar if empty)') help='path to dataset (root dir)')
parser.add_argument('--dataset', metavar='NAME', default='',
help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
group.add_argument('--train-split', metavar='NAME', default='train', group.add_argument('--train-split', metavar='NAME', default='train',
help='dataset train split (default: train)') help='dataset train split (default: train)')
group.add_argument('--val-split', metavar='NAME', default='validation', group.add_argument('--val-split', metavar='NAME', default='validation',
@ -143,16 +140,14 @@ group.add_argument('--grad-checkpointing', action='store_true', default=False,
help='Enable gradient checkpointing through model blocks/stages') help='Enable gradient checkpointing through model blocks/stages')
group.add_argument('--fast-norm', default=False, action='store_true', group.add_argument('--fast-norm', default=False, action='store_true',
help='enable experimental fast-norm') help='enable experimental fast-norm')
parser.add_argument('--dynamo-backend', default=None, type=str,
help="Select dynamo backend. Default: None")
scripting_group = group.add_mutually_exclusive_group() scripting_group = group.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true', scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
help='torch.jit.script the full model') help='torch.jit.script the full model')
scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
help="Enable compilation w/ specified backend (default: inductor).")
scripting_group.add_argument('--aot-autograd', default=False, action='store_true', scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
help="Enable AOT Autograd support.") help="Enable AOT Autograd support.")
scripting_group.add_argument('--dynamo', default=False, action='store_true',
help="Enable Dynamo optimization.")
# Optimizer parameters # Optimizer parameters
group = parser.add_argument_group('Optimizer parameters') group = parser.add_argument_group('Optimizer parameters')
@ -377,6 +372,8 @@ def main():
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
if args.data and not args.data_dir:
args.data_dir = args.data
args.prefetcher = not args.no_prefetcher args.prefetcher = not args.no_prefetcher
device = utils.init_distributed_device(args) device = utils.init_distributed_device(args)
if args.distributed: if args.distributed:
@ -485,18 +482,16 @@ def main():
assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
model = torch.jit.script(model) model = torch.jit.script(model)
elif args.torchcompile:
# FIXME dynamo might need move below DDP wrapping? TBD
assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
torch._dynamo.reset()
model = torch.compile(model, backend=args.torchcompile)
elif args.aot_autograd: elif args.aot_autograd:
assert has_functorch, "functorch is needed for --aot-autograd" assert has_functorch, "functorch is needed for --aot-autograd"
model = memory_efficient_fusion(model) model = memory_efficient_fusion(model)
elif args.dynamo:
# FIXME dynamo might need move below DDP wrapping? TBD
assert has_dynamo, "torch._dynamo is needed for --dynamo"
if args.dynamo_backend is not None:
model = torch._dynamo.optimize(args.dynamo_backend)(model)
else:
model = torch._dynamo.optimize()(model)
if args.lr is None: if not args.lr:
global_batch_size = args.batch_size * args.world_size global_batch_size = args.batch_size * args.world_size
batch_ratio = global_batch_size / args.lr_base_size batch_ratio = global_batch_size / args.lr_base_size
if not args.lr_base_scale: if not args.lr_base_scale:

@ -26,12 +26,11 @@ from timm.data import create_dataset, create_loader, resolve_data_config, RealLa
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser,\ from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser,\
decay_batch_step, check_batch_size_retry decay_batch_step, check_batch_size_retry
has_apex = False
try: try:
from apex import amp from apex import amp
has_apex = True has_apex = True
except ImportError: except ImportError:
pass has_apex = False
has_native_amp = False has_native_amp = False
try: try:
@ -46,21 +45,18 @@ try:
except ImportError as e: except ImportError as e:
has_functorch = False has_functorch = False
try: has_compile = hasattr(torch, 'compile')
import torch._dynamo
has_dynamo = True
except ImportError:
has_dynamo = False
pass
_logger = logging.getLogger('validate') _logger = logging.getLogger('validate')
parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
parser.add_argument('data', metavar='DIR', parser.add_argument('data', nargs='?', metavar='DIR', const=None,
help='path to dataset') help='path to dataset (*deprecated*, use --data-dir)')
parser.add_argument('--dataset', '-d', metavar='NAME', default='', parser.add_argument('--data-dir', metavar='DIR',
help='dataset type (default: ImageFolder/ImageTar if empty)') help='path to dataset (root dir)')
parser.add_argument('--dataset', metavar='NAME', default='',
help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
parser.add_argument('--split', metavar='NAME', default='validation', parser.add_argument('--split', metavar='NAME', default='validation',
help='dataset split (default: validation)') help='dataset split (default: validation)')
parser.add_argument('--dataset-download', action='store_true', default=False, parser.add_argument('--dataset-download', action='store_true', default=False,
@ -125,19 +121,19 @@ parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--fast-norm', default=False, action='store_true', parser.add_argument('--fast-norm', default=False, action='store_true',
help='enable experimental fast-norm') help='enable experimental fast-norm')
parser.add_argument('--dynamo-backend', default=None, type=str,
help="Select dynamo backend. Default: None")
scripting_group = parser.add_mutually_exclusive_group() scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', default=False, action='store_true', scripting_group.add_argument('--torchscript', default=False, action='store_true',
help='torch.jit.script the full model') help='torch.jit.script the full model')
scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
help="Enable compilation w/ specified backend (default: inductor).")
scripting_group.add_argument('--aot-autograd', default=False, action='store_true', scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
help="Enable AOT Autograd support.") help="Enable AOT Autograd support.")
scripting_group.add_argument('--dynamo', default=False, action='store_true',
help="Enable Dynamo optimization.")
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
help='Output csv file for validation results (summary)') help='Output csv file for validation results (summary)')
parser.add_argument('--results-format', default='csv', type=str,
help='Format for results file one of (csv, json) (default: csv).')
parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME', parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
help='Real labels JSON file for imagenet evaluation') help='Real labels JSON file for imagenet evaluation')
parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME', parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
@ -218,16 +214,13 @@ def validate(args):
if args.torchscript: if args.torchscript:
assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
model = torch.jit.script(model) model = torch.jit.script(model)
elif args.torchcompile:
assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
torch._dynamo.reset()
model = torch.compile(model, backend=args.torchcompile)
elif args.aot_autograd: elif args.aot_autograd:
assert has_functorch, "functorch is needed for --aot-autograd" assert has_functorch, "functorch is needed for --aot-autograd"
model = memory_efficient_fusion(model) model = memory_efficient_fusion(model)
elif args.dynamo:
assert has_dynamo, "torch._dynamo is needed for --dynamo"
torch._dynamo.reset()
if args.dynamo_backend is not None:
model = torch._dynamo.optimize(args.dynamo_backend)(model)
else:
model = torch._dynamo.optimize()(model)
if use_amp == 'apex': if use_amp == 'apex':
model = amp.initialize(model, opt_level='O1') model = amp.initialize(model, opt_level='O1')
@ -407,7 +400,6 @@ def main():
model_cfgs = [(n, None) for n in model_names if n] model_cfgs = [(n, None) for n in model_names if n]
if len(model_cfgs): if len(model_cfgs):
results_file = args.results_file or './results-all.csv'
_logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names))) _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
results = [] results = []
try: try:
@ -424,24 +416,34 @@ def main():
except KeyboardInterrupt as e: except KeyboardInterrupt as e:
pass pass
results = sorted(results, key=lambda x: x['top1'], reverse=True) results = sorted(results, key=lambda x: x['top1'], reverse=True)
if len(results):
write_results(results_file, results)
else: else:
if args.retry: if args.retry:
results = _try_run(args, args.batch_size) results = _try_run(args, args.batch_size)
else: else:
results = validate(args) results = validate(args)
if args.results_file:
write_results(args.results_file, results, format=args.results_format)
# output results in JSON to stdout w/ delimiter for runner script # output results in JSON to stdout w/ delimiter for runner script
print(f'--result\n{json.dumps(results, indent=4)}') print(f'--result\n{json.dumps(results, indent=4)}')
def write_results(results_file, results): def write_results(results_file, results, format='csv'):
with open(results_file, mode='w') as cf: with open(results_file, mode='w') as cf:
dw = csv.DictWriter(cf, fieldnames=results[0].keys()) if format == 'json':
dw.writeheader() json.dump(results, cf, indent=4)
for r in results: else:
dw.writerow(r) if not isinstance(results, (list, tuple)):
cf.flush() results = [results]
if not results:
return
dw = csv.DictWriter(cf, fieldnames=results[0].keys())
dw.writeheader()
for r in results:
dw.writerow(r)
cf.flush()
if __name__ == '__main__': if __name__ == '__main__':

Loading…
Cancel
Save