From c1a84ecb220ab2cf611b8e1ef6e35bb80152db12 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 11 Jun 2019 21:59:56 -0700 Subject: [PATCH] dataset not passed through PrefetchLoader for inference script. Fix #10 * also, make top5 configurable for lower class count cases --- data/loader.py | 4 ++++ inference.py | 15 +++++++++------ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/data/loader.py b/data/loader.py index 71ff3c9f..f6710a3b 100644 --- a/data/loader.py +++ b/data/loader.py @@ -61,6 +61,10 @@ class PrefetchLoader: def sampler(self): return self.loader.sampler + @property + def dataset(self): + return self.loader.dataset + @property def mixup_enabled(self): if isinstance(self.loader.collate_fn, FastCollateMixup): diff --git a/inference.py b/inference.py index ac57bbe2..d6c7e48a 100644 --- a/inference.py +++ b/inference.py @@ -48,6 +48,8 @@ parser.add_argument('--num-gpu', type=int, default=1, help='Number of GPUS to use') parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true', help='disable test time pool') +parser.add_argument('--topk', default=5, type=int, + metavar='N', help='Top-k to output to CSV') def main(): @@ -85,15 +87,16 @@ def main(): model.eval() + k = min(args.topk, args.num_classes) batch_time = AverageMeter() end = time.time() - top5_ids = [] + topk_ids = [] with torch.no_grad(): for batch_idx, (input, _) in enumerate(loader): input = input.cuda() labels = model(input) - top5 = labels.topk(5)[1] - top5_ids.append(top5.cpu().numpy()) + topk = labels.topk(k)[1] + topk_ids.append(topk.cpu().numpy()) # measure elapsed time batch_time.update(time.time() - end) @@ -104,11 +107,11 @@ def main(): 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format( batch_idx, len(loader), batch_time=batch_time)) - top5_ids = np.concatenate(top5_ids, axis=0).squeeze() + topk_ids = np.concatenate(topk_ids, axis=0).squeeze() - with open(os.path.join(args.output_dir, './top5_ids.csv'), 'w') as out_file: + with open(os.path.join(args.output_dir, './topk_ids.csv'), 'w') as out_file: filenames = loader.dataset.filenames() - for filename, label in zip(filenames, top5_ids): + for filename, label in zip(filenames, topk_ids): filename = os.path.basename(filename) out_file.write('{0},{1},{2},{3},{4},{5}\n'.format( filename, label[0], label[1], label[2], label[3], label[4]))