From 7a92caa560c65fa641c761b5cd511714e2562946 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 25 Jul 2019 10:51:03 -0700 Subject: [PATCH] Add basic image folder style dataset to read directly out of tar files, example in validate.py --- timm/data/__init__.py | 2 +- timm/data/dataset.py | 51 +++++++++++++++++++++++++++++++++++++++++++ validate.py | 11 +++++++--- 3 files changed, 60 insertions(+), 4 deletions(-) diff --git a/timm/data/__init__.py b/timm/data/__init__.py index 94997658..74872aac 100644 --- a/timm/data/__init__.py +++ b/timm/data/__init__.py @@ -1,6 +1,6 @@ from .constants import * from .config import resolve_data_config -from .dataset import Dataset +from .dataset import Dataset, DatasetTar from .transforms import * from .loader import create_loader from .mixup import mixup_target, FastCollateMixup diff --git a/timm/data/dataset.py b/timm/data/dataset.py index 57cccfc2..47437d5e 100644 --- a/timm/data/dataset.py +++ b/timm/data/dataset.py @@ -7,6 +7,7 @@ import torch.utils.data as data import os import re import torch +import tarfile from PIL import Image @@ -89,3 +90,53 @@ class Dataset(data.Dataset): return [os.path.basename(x[0]) for x in self.imgs] else: return [x[0] for x in self.imgs] + + +def _extract_tar_info(tarfile): + class_to_idx = {} + files = [] + labels = [] + for ti in tarfile.getmembers(): + if not ti.isfile(): + continue + dirname, basename = os.path.split(ti.path) + label = os.path.basename(dirname) + class_to_idx[label] = None + ext = os.path.splitext(basename)[1] + if ext.lower() in IMG_EXTENSIONS: + files.append(ti) + labels.append(label) + for idx, c in enumerate(sorted(class_to_idx.keys(), key=natural_key)): + class_to_idx[c] = idx + tarinfo_and_targets = zip(files, [class_to_idx[l] for l in labels]) + tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path)) + return tarinfo_and_targets + + +class DatasetTar(data.Dataset): + + def __init__(self, root, load_bytes=False, transform=None): + + assert os.path.isfile(root) + self.root = root + with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later + self.imgs = _extract_tar_info(tf) + self.tarfile = None # lazy init in __getitem__ + self.load_bytes = load_bytes + self.transform = transform + + def __getitem__(self, index): + if self.tarfile is None: + self.tarfile = tarfile.open(self.root) + tarinfo, target = self.imgs[index] + iob = self.tarfile.extractfile(tarinfo) + img = iob.read() if self.load_bytes else Image.open(iob).convert('RGB') + if self.transform is not None: + img = self.transform(img) + if target is None: + target = torch.zeros(1).long() + return img, target + + def __len__(self): + return len(self.imgs) + diff --git a/validate.py b/validate.py index d90d54fd..21dfcf89 100644 --- a/validate.py +++ b/validate.py @@ -14,7 +14,7 @@ import torch.nn.parallel from collections import OrderedDict from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models -from timm.data import Dataset, create_loader, resolve_data_config +from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging torch.backends.cudnn.benchmark = True @@ -24,7 +24,7 @@ parser.add_argument('data', metavar='DIR', help='path to dataset') parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92', help='model architecture (default: dpn92)') -parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', +parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 2)') parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N', help='mini-batch size (default: 256)') @@ -91,9 +91,14 @@ def validate(args): criterion = nn.CrossEntropyLoss().cuda() + if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data): + dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing) + else: + dataset = Dataset(args.data, load_bytes=args.tf_preprocessing) + crop_pct = 1.0 if test_time_pool else data_config['crop_pct'] loader = create_loader( - Dataset(args.data, load_bytes=args.tf_preprocessing), + dataset, input_size=data_config['input_size'], batch_size=args.batch_size, use_prefetcher=args.prefetcher,