diff --git a/inference.py b/inference.py index 89efb1fb..445dc4e5 100755 --- a/inference.py +++ b/inference.py @@ -119,8 +119,8 @@ def main(): with open(os.path.join(args.output_dir, './topk_ids.csv'), 'w') as out_file: filenames = loader.dataset.filenames(basename=True) for filename, label in zip(filenames, topk_ids): - out_file.write('{0},{1},{2},{3},{4},{5}\n'.format( - filename, label[0], label[1], label[2], label[3], label[4])) + out_file.write('{0},{1}\n'.format( + filename, ','.join([ str(v) for v in label]))) if __name__ == '__main__':