Add support to Dataset for class id mapping file, clean up a bit of old logic. Add results file arg for validation and update script.

pull/83/head
Ross Wightman 5 years ago
parent 91534522f9
commit 1daa303744

@ -20,34 +20,40 @@ def natural_key(string_):
def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
if class_to_idx is None:
class_to_idx = dict()
build_class_idx = True
else:
build_class_idx = False
labels = []
filenames = []
for root, subdirs, files in os.walk(folder, topdown=False):
rel_path = os.path.relpath(root, folder) if (root != folder) else ''
label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
if build_class_idx and not subdirs:
class_to_idx[label] = None
for f in files:
base, ext = os.path.splitext(f)
if ext.lower() in types:
filenames.append(os.path.join(root, f))
labels.append(label)
if build_class_idx:
classes = sorted(class_to_idx.keys(), key=natural_key)
for idx, c in enumerate(classes):
class_to_idx[c] = idx
if class_to_idx is None:
# building class index
unique_labels = set(labels)
sorted_labels = list(sorted(unique_labels, key=natural_key))
class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
images_and_targets = zip(filenames, [class_to_idx[l] for l in labels])
if sort:
images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
if build_class_idx:
return images_and_targets, classes, class_to_idx
return images_and_targets, class_to_idx
def load_class_map(filename, root=''):
class_to_idx = {}
class_map_path = filename
if not os.path.exists(class_map_path):
class_map_path = os.path.join(root, filename)
assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename
class_map_ext = os.path.splitext(filename)[-1].lower()
if class_map_ext == '.txt':
with open(class_map_path) as f:
class_to_idx = {v.strip(): k for k, v in enumerate(f)}
else:
return images_and_targets
assert False, 'Unsupported class map extension'
return class_to_idx
class Dataset(data.Dataset):
@ -56,19 +62,25 @@ class Dataset(data.Dataset):
self,
root,
load_bytes=False,
transform=None):
imgs, _, _ = find_images_and_targets(root)
if len(imgs) == 0:
transform=None,
class_map=''):
class_to_idx = None
if class_map:
class_to_idx = load_class_map(class_map, root)
images, class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
if len(images) == 0:
raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
"Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
self.root = root
self.imgs = imgs
self.samples = images
self.imgs = self.samples # torchvision ImageFolder compat
self.class_to_idx = class_to_idx
self.load_bytes = load_bytes
self.transform = transform
def __getitem__(self, index):
path, target = self.imgs[index]
path, target = self.samples[index]
img = open(path, 'rb').read() if self.load_bytes else Image.open(path).convert('RGB')
if self.transform is not None:
img = self.transform(img)
@ -82,18 +94,17 @@ class Dataset(data.Dataset):
def filenames(self, indices=[], basename=False):
if indices:
if basename:
return [os.path.basename(self.imgs[i][0]) for i in indices]
return [os.path.basename(self.samples[i][0]) for i in indices]
else:
return [self.imgs[i][0] for i in indices]
return [self.samples[i][0] for i in indices]
else:
if basename:
return [os.path.basename(x[0]) for x in self.imgs]
return [os.path.basename(x[0]) for x in self.samples]
else:
return [x[0] for x in self.imgs]
return [x[0] for x in self.samples]
def _extract_tar_info(tarfile):
class_to_idx = {}
def _extract_tar_info(tarfile, class_to_idx=None, sort=True):
files = []
labels = []
for ti in tarfile.getmembers():
@ -101,26 +112,31 @@ def _extract_tar_info(tarfile):
continue
dirname, basename = os.path.split(ti.path)
label = os.path.basename(dirname)
class_to_idx[label] = None
ext = os.path.splitext(basename)[1]
if ext.lower() in IMG_EXTENSIONS:
files.append(ti)
labels.append(label)
for idx, c in enumerate(sorted(class_to_idx.keys(), key=natural_key)):
class_to_idx[c] = idx
if class_to_idx is None:
unique_labels = set(labels)
sorted_labels = list(sorted(unique_labels, key=natural_key))
class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
tarinfo_and_targets = zip(files, [class_to_idx[l] for l in labels])
if sort:
tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path))
return tarinfo_and_targets
return tarinfo_and_targets, class_to_idx
class DatasetTar(data.Dataset):
def __init__(self, root, load_bytes=False, transform=None):
def __init__(self, root, load_bytes=False, transform=None, class_map=''):
class_to_idx = None
if class_map:
class_to_idx = load_class_map(class_map, root)
assert os.path.isfile(root)
self.root = root
with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later
self.imgs = _extract_tar_info(tf)
self.samples, self.class_to_idx = _extract_tar_info(tf, class_to_idx)
self.tarfile = None # lazy init in __getitem__
self.load_bytes = load_bytes
self.transform = transform
@ -128,7 +144,7 @@ class DatasetTar(data.Dataset):
def __getitem__(self, index):
if self.tarfile is None:
self.tarfile = tarfile.open(self.root)
tarinfo, target = self.imgs[index]
tarinfo, target = self.samples[index]
iob = self.tarfile.extractfile(tarinfo)
img = iob.read() if self.load_bytes else Image.open(iob).convert('RGB')
if self.transform is not None:
@ -138,7 +154,7 @@ class DatasetTar(data.Dataset):
return img, target
def __len__(self):
return len(self.imgs)
return len(self.samples)
class AugMixDataset(torch.utils.data.Dataset):

@ -45,6 +45,8 @@ parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
parser.add_argument('--num-classes', type=int, default=1000,
help='Number classes in dataset')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
help='path to class to idx mapping file (default: "")')
parser.add_argument('--log-freq', default=10, type=int,
metavar='N', help='batch logging frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
@ -67,6 +69,8 @@ parser.add_argument('--use-ema', dest='use_ema', action='store_true',
help='use ema version of weights if present')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
help='Output csv file for validation results (summary)')
def validate(args):
@ -104,10 +108,12 @@ def validate(args):
criterion = nn.CrossEntropyLoss().cuda()
#from torchvision.datasets import ImageNet
#dataset = ImageNet(args.data, split='val')
if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing)
dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
else:
dataset = Dataset(args.data, load_bytes=args.tf_preprocessing)
dataset = Dataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
loader = create_loader(
@ -201,9 +207,10 @@ def main():
model_cfgs = [(n, '') for n in model_names]
if len(model_cfgs):
results_file = args.results_file or './results-all.csv'
logging.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
header_written = False
with open('./results-all.csv', mode='w') as cf:
results = []
try:
for m, c in model_cfgs:
args.model = m
args.checkpoint = c
@ -212,15 +219,24 @@ def main():
result.update(r)
if args.checkpoint:
result['checkpoint'] = args.checkpoint
dw = csv.DictWriter(cf, fieldnames=result.keys())
if not header_written:
dw.writeheader()
header_written = True
dw.writerow(result)
cf.flush()
results.append(result)
except KeyboardInterrupt as e:
pass
results = sorted(results, key=lambda x: x['top1'], reverse=True)
if len(results):
write_results(results_file, results)
else:
validate(args)
def write_results(results_file, results):
with open(results_file, mode='w') as cf:
dw = csv.DictWriter(cf, fieldnames=results[0].keys())
dw.writeheader()
for r in results:
dw.writerow(r)
cf.flush()
if __name__ == '__main__':
main()

Loading…
Cancel
Save