From 1daa303744763c141a137c589aa6068c174aa669 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sat, 1 Feb 2020 18:07:32 -0800
Subject: [PATCH] Add support to Dataset for class id mapping file, clean up a
 bit of old logic. Add results file arg for validation and update script.

---
 timm/data/dataset.py | 86 ++++++++++++++++++++++++++------------------
 validate.py          | 36 +++++++++++++------
 2 files changed, 77 insertions(+), 45 deletions(-)

diff --git a/timm/data/dataset.py b/timm/data/dataset.py
index fc252d9e..2ce79e7e 100644
--- a/timm/data/dataset.py
+++ b/timm/data/dataset.py
@@ -20,34 +20,40 @@ def natural_key(string_):
 
 
 def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
-    if class_to_idx is None:
-        class_to_idx = dict()
-        build_class_idx = True
-    else:
-        build_class_idx = False
     labels = []
     filenames = []
     for root, subdirs, files in os.walk(folder, topdown=False):
         rel_path = os.path.relpath(root, folder) if (root != folder) else ''
         label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
-        if build_class_idx and not subdirs:
-            class_to_idx[label] = None
         for f in files:
             base, ext = os.path.splitext(f)
             if ext.lower() in types:
                 filenames.append(os.path.join(root, f))
                 labels.append(label)
-    if build_class_idx:
-        classes = sorted(class_to_idx.keys(), key=natural_key)
-        for idx, c in enumerate(classes):
-            class_to_idx[c] = idx
+    if class_to_idx is None:
+        # building class index
+        unique_labels = set(labels)
+        sorted_labels = list(sorted(unique_labels, key=natural_key))
+        class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
     images_and_targets = zip(filenames, [class_to_idx[l] for l in labels])
     if sort:
         images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
-    if build_class_idx:
-        return images_and_targets, classes, class_to_idx
+    return images_and_targets, class_to_idx
+
+
+def load_class_map(filename, root=''):
+    class_to_idx = {}
+    class_map_path = filename
+    if not os.path.exists(class_map_path):
+        class_map_path = os.path.join(root, filename)
+        assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename
+    class_map_ext = os.path.splitext(filename)[-1].lower()
+    if class_map_ext == '.txt':
+        with open(class_map_path) as f:
+            class_to_idx = {v.strip(): k for k, v in enumerate(f)}
     else:
-        return images_and_targets
+        assert False, 'Unsupported class map extension'
+    return class_to_idx
 
 
 class Dataset(data.Dataset):
@@ -56,19 +62,25 @@ class Dataset(data.Dataset):
     def __init__(
             self,
             root,
             load_bytes=False,
-            transform=None):
-
-        imgs, _, _ = find_images_and_targets(root)
-        if len(imgs) == 0:
+            transform=None,
+            class_map=''):
+
+        class_to_idx = None
+        if class_map:
+            class_to_idx = load_class_map(class_map, root)
+        images, class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
+        if len(images) == 0:
             raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                                "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
         self.root = root
-        self.imgs = imgs
+        self.samples = images
+        self.imgs = self.samples  # torchvision ImageFolder compat
+        self.class_to_idx = class_to_idx
         self.load_bytes = load_bytes
         self.transform = transform
 
     def __getitem__(self, index):
-        path, target = self.imgs[index]
+        path, target = self.samples[index]
         img = open(path, 'rb').read() if self.load_bytes else Image.open(path).convert('RGB')
         if self.transform is not None:
             img = self.transform(img)
@@ -82,18 +94,17 @@ class Dataset(data.Dataset):
     def filenames(self, indices=[], basename=False):
         if indices:
             if basename:
-                return [os.path.basename(self.imgs[i][0]) for i in indices]
+                return [os.path.basename(self.samples[i][0]) for i in indices]
             else:
-                return [self.imgs[i][0] for i in indices]
+                return [self.samples[i][0] for i in indices]
         else:
             if basename:
-                return [os.path.basename(x[0]) for x in self.imgs]
+                return [os.path.basename(x[0]) for x in self.samples]
             else:
-                return [x[0] for x in self.imgs]
+                return [x[0] for x in self.samples]
 
 
-def _extract_tar_info(tarfile):
-    class_to_idx = {}
+def _extract_tar_info(tarfile, class_to_idx=None, sort=True):
     files = []
     labels = []
     for ti in tarfile.getmembers():
@@ -101,26 +112,31 @@ class Dataset(data.Dataset):
             continue
         dirname, basename = os.path.split(ti.path)
         label = os.path.basename(dirname)
-        class_to_idx[label] = None
         ext = os.path.splitext(basename)[1]
         if ext.lower() in IMG_EXTENSIONS:
             files.append(ti)
             labels.append(label)
-    for idx, c in enumerate(sorted(class_to_idx.keys(), key=natural_key)):
-        class_to_idx[c] = idx
+    if class_to_idx is None:
+        unique_labels = set(labels)
+        sorted_labels = list(sorted(unique_labels, key=natural_key))
+        class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
     tarinfo_and_targets = zip(files, [class_to_idx[l] for l in labels])
-    tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path))
-    return tarinfo_and_targets
+    if sort:
+        tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path))
+    return tarinfo_and_targets, class_to_idx
 
 
 class DatasetTar(data.Dataset):
 
-    def __init__(self, root, load_bytes=False, transform=None):
+    def __init__(self, root, load_bytes=False, transform=None, class_map=''):
+        class_to_idx = None
+        if class_map:
+            class_to_idx = load_class_map(class_map, root)
         assert os.path.isfile(root)
         self.root = root
         with tarfile.open(root) as tf:  # cannot keep this open across processes, reopen later
-            self.imgs = _extract_tar_info(tf)
+            self.samples, self.class_to_idx = _extract_tar_info(tf, class_to_idx)
         self.tarfile = None  # lazy init in __getitem__
         self.load_bytes = load_bytes
         self.transform = transform
@@ -128,7 +144,7 @@ class DatasetTar(data.Dataset):
     def __getitem__(self, index):
         if self.tarfile is None:
             self.tarfile = tarfile.open(self.root)
-        tarinfo, target = self.imgs[index]
+        tarinfo, target = self.samples[index]
         iob = self.tarfile.extractfile(tarinfo)
         img = iob.read() if self.load_bytes else Image.open(iob).convert('RGB')
         if self.transform is not None:
@@ -138,7 +154,7 @@ class DatasetTar(data.Dataset):
         return img, target
 
     def __len__(self):
-        return len(self.imgs)
+        return len(self.samples)
 
 
 class AugMixDataset(torch.utils.data.Dataset):
diff --git a/validate.py b/validate.py
index 93a82021..7993faaa 100755
--- a/validate.py
+++ b/validate.py
@@ -45,6 +45,8 @@ parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                     help='Image resize interpolation type (overrides model)')
 parser.add_argument('--num-classes', type=int, default=1000,
                     help='Number classes in dataset')
+parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
+                    help='path to class to idx mapping file (default: "")')
 parser.add_argument('--log-freq', default=10, type=int, metavar='N',
                     help='batch logging frequency (default: 10)')
 parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
@@ -67,6 +69,8 @@ parser.add_argument('--use-ema', dest='use_ema', action='store_true',
                     help='use ema version of weights if present')
 parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                     help='convert model torchscript for inference')
+parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
+                    help='Output csv file for validation results (summary)')
 
 
 def validate(args):
@@ -104,10 +108,12 @@ def validate(args):
 
     criterion = nn.CrossEntropyLoss().cuda()
 
+    #from torchvision.datasets import ImageNet
+    #dataset = ImageNet(args.data, split='val')
     if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
-        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing)
+        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
     else:
-        dataset = Dataset(args.data, load_bytes=args.tf_preprocessing)
+        dataset = Dataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
 
     crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
     loader = create_loader(
@@ -201,9 +207,10 @@ def main():
         model_cfgs = [(n, '') for n in model_names]
 
     if len(model_cfgs):
+        results_file = args.results_file or './results-all.csv'
         logging.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
-        header_written = False
-        with open('./results-all.csv', mode='w') as cf:
+        results = []
+        try:
             for m, c in model_cfgs:
                 args.model = m
                 args.checkpoint = c
@@ -212,15 +219,24 @@ def main():
                 result.update(r)
                 if args.checkpoint:
                     result['checkpoint'] = args.checkpoint
-                dw = csv.DictWriter(cf, fieldnames=result.keys())
-                if not header_written:
-                    dw.writeheader()
-                    header_written = True
-                dw.writerow(result)
-                cf.flush()
+                results.append(result)
+        except KeyboardInterrupt as e:
+            pass
+        results = sorted(results, key=lambda x: x['top1'], reverse=True)
+        if len(results):
+            write_results(results_file, results)
     else:
         validate(args)
 
 
+def write_results(results_file, results):
+    with open(results_file, mode='w') as cf:
+        dw = csv.DictWriter(cf, fieldnames=results[0].keys())
+        dw.writeheader()
+        for r in results:
+            dw.writerow(r)
+        cf.flush()
+
+
 if __name__ == '__main__':
     main()
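
For reference, the class map consumed by the new load_class_map() is a plain
text file with one class name per line; the 0-based line number becomes the
class index. A minimal sketch of the Dataset plumbing above, assuming a
hypothetical map file and data root:

    # class_map.txt (hypothetical), one label per line:
    #     dog
    #     cat
    #     fish
    from timm.data.dataset import Dataset, load_class_map

    # builds {'dog': 0, 'cat': 1, 'fish': 2}; the file is looked up as given,
    # then relative to root if it does not exist on its own
    class_to_idx = load_class_map('class_map.txt', root='/data/mydataset')

    # Dataset calls load_class_map itself when class_map is non-empty; the
    # subfolder names under root must match the labels listed in the file
    dataset = Dataset('/data/mydataset', class_map='class_map.txt')
    print(dataset.class_to_idx)  # index ordering pinned by the map file
    img, target = dataset[0]     # PIL image (no transform set), int label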
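
DatasetTar takes the same class_map argument, so a fixed label ordering can
also be enforced when validating straight from a tar archive. A sketch with a
hypothetical archive path:

    from timm.data.dataset import DatasetTar

    # members are indexed once up front; the tar handle itself is opened
    # lazily in __getitem__ since it cannot be shared across worker processes
    dataset = DatasetTar('/data/mydataset-val.tar', class_map='class_map.txt')
    print(len(dataset), dataset.class_to_idx)
    img, target = dataset[0]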
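
On the validate.py side, --class-map feeds whichever dataset gets built, and
--results-file redirects the bulk-validation summary that previously always
went to ./results-all.csv. The bulk path now buffers one result row per model,
sorts the rows by top1 descending, and writes the CSV once at the end
(including after a Ctrl-C) instead of streaming rows as it goes. An
illustrative invocation (how model_cfgs is populated, and thus whether the
bulk branch runs, is decided outside these hunks):

    python validate.py /data/mydataset --model resnet50 --pretrained \
        --class-map class_map.txt --results-file ./my-results.csv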