More inference script changes, arg naming, multiple output fmts at once

pull/1582/head
Ross Wightman 2 years ago committed by Ross Wightman
parent eceeb9409a
commit 05637a4bb0

@ -48,6 +48,13 @@ except ImportError:
has_dynamo = False
_FMT_EXT = {
'json': '.json',
'json-split': '.json',
'parquet': '.parquet',
'csv': '.csv',
}
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('inference')
@ -103,8 +110,6 @@ parser.add_argument('--amp', action='store_true', default=False,
help='use Native AMP for mixed precision training')
parser.add_argument('--amp-dtype', default='float16', type=str,
help='lower precision AMP dtype (default: float16)')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
help='use ema version of weights if present')
parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--dynamo-backend', default=None, type=str,
@ -122,21 +127,23 @@ parser.add_argument('--results-dir',type=str, default=None,
help='folder for output results')
parser.add_argument('--results-file', type=str, default=None,
help='results filename (relative to results-dir)')
parser.add_argument('--results-format', type=str, default='csv',
parser.add_argument('--results-format', type=str, nargs='+', default=['csv'],
help='results format (one of "csv", "json", "json-split", "parquet")')
parser.add_argument('--results-separate-col', action='store_true', default=False,
help='separate output columns per result index.')
parser.add_argument('--topk', default=1, type=int,
metavar='N', help='Top-k to output to CSV')
parser.add_argument('--fullname', action='store_true', default=False,
help='use full sample name in output (not just basename).')
parser.add_argument('--indices-name', default='index',
parser.add_argument('--filename-col', default='filename',
help='name for filename / sample name column')
parser.add_argument('--index-col', default='index',
help='name for output indices column(s)')
parser.add_argument('--outputs-name', default=None,
parser.add_argument('--output-col', default=None,
help='name for logit/probs output column(s)')
parser.add_argument('--outputs-type', default='prob',
parser.add_argument('--output-type', default='prob',
help='output type colum ("prob" for probabilities, "logit" for raw logits)')
parser.add_argument('--separate-columns', action='store_true', default=False,
help='separate output columns per result index.')
parser.add_argument('--exclude-outputs', action='store_true', default=False,
parser.add_argument('--exclude-output', action='store_true', default=False,
help='exclude logits/probs from results, just indices. topk must be set !=0.')
@ -179,9 +186,6 @@ def main():
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
args.num_classes = model.num_classes
if args.checkpoint:
load_checkpoint(model, args.checkpoint, args.use_ema)
_logger.info(
f'Model {args.model} created, param count: {sum([m.numel() for m in model.parameters()])}')
@ -221,11 +225,12 @@ def main():
if test_time_pool:
data_config['crop_pct'] = 1.0
workers = 1 if 'tfds' in args.dataset or 'wds' in args.dataset else args.workers
loader = create_loader(
dataset,
batch_size=args.batch_size,
use_prefetcher=True,
num_workers=args.workers,
num_workers=workers,
**data_config,
)
@ -234,7 +239,7 @@ def main():
end = time.time()
all_indices = []
all_outputs = []
use_probs = args.outputs_type == 'prob'
use_probs = args.output_type == 'prob'
with torch.no_grad():
for batch_idx, (input, _) in enumerate(loader):
@ -262,52 +267,53 @@ def main():
all_outputs = np.concatenate(all_outputs, axis=0).astype(np.float32)
filenames = loader.dataset.filenames(basename=not args.fullname)
outputs_name = args.outputs_name or ('prob' if use_probs else 'logit')
data_dict = {'filename': filenames}
if args.separate_columns and all_outputs.shape[-1] > 1:
output_col = args.output_col or ('prob' if use_probs else 'logit')
data_dict = {args.filename_col: filenames}
if args.results_separate_col and all_outputs.shape[-1] > 1:
if all_indices is not None:
for i in range(all_indices.shape[-1]):
data_dict[f'{args.indices_name}_{i}'] = all_indices[:, i]
data_dict[f'{args.index_col}_{i}'] = all_indices[:, i]
for i in range(all_outputs.shape[-1]):
data_dict[f'{outputs_name}_{i}'] = all_outputs[:, i]
data_dict[f'{output_col}_{i}'] = all_outputs[:, i]
else:
if all_indices is not None:
if all_indices.shape[-1] == 1:
all_indices = all_indices.squeeze(-1)
data_dict[args.indices_name] = list(all_indices)
data_dict[args.index_col] = list(all_indices)
if all_outputs.shape[-1] == 1:
all_outputs = all_outputs.squeeze(-1)
data_dict[outputs_name] = list(all_outputs)
data_dict[output_col] = list(all_outputs)
df = pd.DataFrame(data=data_dict)
results_filename = args.results_file
needs_ext = False
if not results_filename:
if results_filename:
filename_no_ext, ext = os.path.splitext(results_filename)[-1]
if ext and ext in _FMT_EXT.values():
# if filename provided with one of expected ext,
# remove it as it will be added back
results_filename = filename_no_ext
else:
# base default filename on model name + img-size
img_size = data_config["input_size"][1]
results_filename = f'{args.model}-{img_size}'
needs_ext = True
if args.results_dir:
results_filename = os.path.join(args.results_dir, results_filename)
if args.results_format == 'parquet':
if needs_ext:
results_filename += '.parquet'
df = df.set_index('filename')
df.to_parquet(results_filename)
elif args.results_format == 'json':
if needs_ext:
results_filename += '.json'
for fmt in args.results_format:
save_results(df, results_filename, fmt)
def save_results(df, results_filename, results_format='csv', filename_col='filename'):
results_filename += _FMT_EXT[results_format]
if results_format == 'parquet':
df.set_index(filename_col).to_parquet(results_filename)
elif results_format == 'json':
df.to_json(results_filename, lines=True, orient='records')
elif args.results_format == 'json-split':
if needs_ext:
results_filename += '.json'
elif results_format == 'json-split':
df.to_json(results_filename, indent=4, orient='split', index=False)
else:
if needs_ext:
results_filename += '.csv'
df.to_csv(results_filename, index=False)

Loading…
Cancel
Save