From 1e23727f2fd43459e746b56b5251d1b73601252a Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 5 Apr 2019 11:58:16 -0700 Subject: [PATCH] Update inference script for new loader style --- inference.py | 44 +++++++++++++------------------------------- 1 file changed, 13 insertions(+), 31 deletions(-) diff --git a/inference.py b/inference.py index 0afb529e..d6d4d385 100644 --- a/inference.py +++ b/inference.py @@ -10,10 +10,10 @@ import time import argparse import numpy as np import torch -import torch.utils.data as data -from models import create_model, transforms_imagenet_eval -from dataset import Dataset +from models import create_model +from data import Dataset, create_loader, get_model_meanstd +from utils import AverageMeter parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference') @@ -70,14 +70,15 @@ def main(): else: model = model.cuda() - dataset = Dataset( - args.data, - transforms_imagenet_eval(args.model, args.img_size)) - - loader = data.DataLoader( - dataset, - batch_size=args.batch_size, shuffle=False, - num_workers=args.workers, pin_memory=True) + data_mean, data_std = get_model_meanstd(args.model) + loader = create_loader( + Dataset(args.data), + img_size=args.img_size, + batch_size=args.batch_size, + use_prefetcher=True, + mean=data_mean, + std=data_std, + num_workers=args.workers) model.eval() @@ -103,31 +104,12 @@ def main(): top5_ids = np.concatenate(top5_ids, axis=0).squeeze() with open(os.path.join(args.output_dir, './top5_ids.csv'), 'w') as out_file: - filenames = dataset.filenames() + filenames = loader.dataset.filenames() for filename, label in zip(filenames, top5_ids): filename = os.path.basename(filename) out_file.write('{0},{1},{2},{3},{4},{5}\n'.format( filename, label[0], label[1], label[2], label[3], label[4])) -class AverageMeter(object): - """Computes and stores the average and current value""" - - def __init__(self): - self.reset() - - def reset(self): - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 - - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count - - if __name__ == '__main__': main()