diff --git a/inference.py b/inference.py index 445dc4e5..5fcf1e60 100755 --- a/inference.py +++ b/inference.py @@ -114,7 +114,7 @@ def main(): _logger.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format( batch_idx, len(loader), batch_time=batch_time)) - topk_ids = np.concatenate(topk_ids, axis=0).squeeze() + topk_ids = np.concatenate(topk_ids, axis=0) with open(os.path.join(args.output_dir, './topk_ids.csv'), 'w') as out_file: filenames = loader.dataset.filenames(basename=True)