dataset not passed through PrefetchLoader for inference script. Fix #10

* also, make top5 configurable for lower class count cases
pull/13/head
Ross Wightman 6 years ago
parent 2060e433c0
commit c1a84ecb22

@ -61,6 +61,10 @@ class PrefetchLoader:
def sampler(self): def sampler(self):
return self.loader.sampler return self.loader.sampler
@property
def dataset(self):
return self.loader.dataset
@property @property
def mixup_enabled(self): def mixup_enabled(self):
if isinstance(self.loader.collate_fn, FastCollateMixup): if isinstance(self.loader.collate_fn, FastCollateMixup):

@ -48,6 +48,8 @@ parser.add_argument('--num-gpu', type=int, default=1,
help='Number of GPUS to use') help='Number of GPUS to use')
parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true', parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
help='disable test time pool') help='disable test time pool')
parser.add_argument('--topk', default=5, type=int,
metavar='N', help='Top-k to output to CSV')
def main(): def main():
@ -85,15 +87,16 @@ def main():
model.eval() model.eval()
k = min(args.topk, args.num_classes)
batch_time = AverageMeter() batch_time = AverageMeter()
end = time.time() end = time.time()
top5_ids = [] topk_ids = []
with torch.no_grad(): with torch.no_grad():
for batch_idx, (input, _) in enumerate(loader): for batch_idx, (input, _) in enumerate(loader):
input = input.cuda() input = input.cuda()
labels = model(input) labels = model(input)
top5 = labels.topk(5)[1] topk = labels.topk(k)[1]
top5_ids.append(top5.cpu().numpy()) topk_ids.append(topk.cpu().numpy())
# measure elapsed time # measure elapsed time
batch_time.update(time.time() - end) batch_time.update(time.time() - end)
@ -104,11 +107,11 @@ def main():
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format( 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
batch_idx, len(loader), batch_time=batch_time)) batch_idx, len(loader), batch_time=batch_time))
top5_ids = np.concatenate(top5_ids, axis=0).squeeze() topk_ids = np.concatenate(topk_ids, axis=0).squeeze()
with open(os.path.join(args.output_dir, './top5_ids.csv'), 'w') as out_file: with open(os.path.join(args.output_dir, './topk_ids.csv'), 'w') as out_file:
filenames = loader.dataset.filenames() filenames = loader.dataset.filenames()
for filename, label in zip(filenames, top5_ids): for filename, label in zip(filenames, topk_ids):
filename = os.path.basename(filename) filename = os.path.basename(filename)
out_file.write('{0},{1},{2},{3},{4},{5}\n'.format( out_file.write('{0},{1},{2},{3},{4},{5}\n'.format(
filename, label[0], label[1], label[2], label[3], label[4])) filename, label[0], label[1], label[2], label[3], label[4]))

Loading…
Cancel
Save