fix: the exeption when topk is 1

pull/625/head
comar 3 years ago
parent ffa90e04d3
commit d7e1e7144a

@ -114,7 +114,7 @@ def main():
_logger.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
batch_idx, len(loader), batch_time=batch_time))
topk_ids = np.concatenate(topk_ids, axis=0).squeeze()
topk_ids = np.concatenate(topk_ids, axis=0)
with open(os.path.join(args.output_dir, './topk_ids.csv'), 'w') as out_file:
filenames = loader.dataset.filenames(basename=True)

Loading…
Cancel
Save