Update inference script for new loader style

pull/1/head
Ross Wightman 6 years ago
parent 58571e992e
commit 1e23727f2f

@ -10,10 +10,10 @@ import time
import argparse import argparse
import numpy as np import numpy as np
import torch import torch
import torch.utils.data as data
from models import create_model, transforms_imagenet_eval from models import create_model
from dataset import Dataset from data import Dataset, create_loader, get_model_meanstd
from utils import AverageMeter
parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference') parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference')
@ -70,14 +70,15 @@ def main():
else: else:
model = model.cuda() model = model.cuda()
dataset = Dataset( data_mean, data_std = get_model_meanstd(args.model)
args.data, loader = create_loader(
transforms_imagenet_eval(args.model, args.img_size)) Dataset(args.data),
img_size=args.img_size,
loader = data.DataLoader( batch_size=args.batch_size,
dataset, use_prefetcher=True,
batch_size=args.batch_size, shuffle=False, mean=data_mean,
num_workers=args.workers, pin_memory=True) std=data_std,
num_workers=args.workers)
model.eval() model.eval()
@ -103,31 +104,12 @@ def main():
top5_ids = np.concatenate(top5_ids, axis=0).squeeze() top5_ids = np.concatenate(top5_ids, axis=0).squeeze()
with open(os.path.join(args.output_dir, './top5_ids.csv'), 'w') as out_file: with open(os.path.join(args.output_dir, './top5_ids.csv'), 'w') as out_file:
filenames = dataset.filenames() filenames = loader.dataset.filenames()
for filename, label in zip(filenames, top5_ids): for filename, label in zip(filenames, top5_ids):
filename = os.path.basename(filename) filename = os.path.basename(filename)
out_file.write('{0},{1},{2},{3},{4},{5}\n'.format( out_file.write('{0},{1},{2},{3},{4},{5}\n'.format(
filename, label[0], label[1], label[2], label[3], label[4])) filename, label[0], label[1], label[2], label[3], label[4]))
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
if __name__ == '__main__': if __name__ == '__main__':
main() main()

Loading…
Cancel
Save