From ffa90e04d31e96c577abe7e3327f72b463957d04 Mon Sep 17 00:00:00 2001 From: comar Date: Mon, 10 May 2021 16:20:19 +0900 Subject: [PATCH 1/2] fix: the exception not using default topk argument --- inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/inference.py b/inference.py index 89efb1fb..445dc4e5 100755 --- a/inference.py +++ b/inference.py @@ -119,8 +119,8 @@ def main(): with open(os.path.join(args.output_dir, './topk_ids.csv'), 'w') as out_file: filenames = loader.dataset.filenames(basename=True) for filename, label in zip(filenames, topk_ids): - out_file.write('{0},{1},{2},{3},{4},{5}\n'.format( - filename, label[0], label[1], label[2], label[3], label[4])) + out_file.write('{0},{1}\n'.format( + filename, ','.join([ str(v) for v in label]))) if __name__ == '__main__': From d7e1e7144a62f0368787c081eb999d27cb97d159 Mon Sep 17 00:00:00 2001 From: comar Date: Mon, 10 May 2021 16:52:40 +0900 Subject: [PATCH 2/2] fix: the exeption when topk is 1 --- inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/inference.py b/inference.py index 445dc4e5..5fcf1e60 100755 --- a/inference.py +++ b/inference.py @@ -114,7 +114,7 @@ def main(): _logger.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format( batch_idx, len(loader), batch_time=batch_time)) - topk_ids = np.concatenate(topk_ids, axis=0).squeeze() + topk_ids = np.concatenate(topk_ids, axis=0) with open(os.path.join(args.output_dir, './topk_ids.csv'), 'w') as out_file: filenames = loader.dataset.filenames(basename=True)