Add basic image folder style dataset to read directly out of tar files, example in validate.py

pull/23/head
Ross Wightman 5 years ago
parent d6ac5bbc48
commit 7a92caa560

@ -1,6 +1,6 @@
from .constants import *
from .config import resolve_data_config
from .dataset import Dataset
from .dataset import Dataset, DatasetTar
from .transforms import *
from .loader import create_loader
from .mixup import mixup_target, FastCollateMixup

@ -7,6 +7,7 @@ import torch.utils.data as data
import os
import re
import torch
import tarfile
from PIL import Image
@ -89,3 +90,53 @@ class Dataset(data.Dataset):
return [os.path.basename(x[0]) for x in self.imgs]
else:
return [x[0] for x in self.imgs]
def _extract_tar_info(tarfile):
class_to_idx = {}
files = []
labels = []
for ti in tarfile.getmembers():
if not ti.isfile():
continue
dirname, basename = os.path.split(ti.path)
label = os.path.basename(dirname)
class_to_idx[label] = None
ext = os.path.splitext(basename)[1]
if ext.lower() in IMG_EXTENSIONS:
files.append(ti)
labels.append(label)
for idx, c in enumerate(sorted(class_to_idx.keys(), key=natural_key)):
class_to_idx[c] = idx
tarinfo_and_targets = zip(files, [class_to_idx[l] for l in labels])
tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path))
return tarinfo_and_targets
class DatasetTar(data.Dataset):
def __init__(self, root, load_bytes=False, transform=None):
assert os.path.isfile(root)
self.root = root
with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later
self.imgs = _extract_tar_info(tf)
self.tarfile = None # lazy init in __getitem__
self.load_bytes = load_bytes
self.transform = transform
def __getitem__(self, index):
if self.tarfile is None:
self.tarfile = tarfile.open(self.root)
tarinfo, target = self.imgs[index]
iob = self.tarfile.extractfile(tarinfo)
img = iob.read() if self.load_bytes else Image.open(iob).convert('RGB')
if self.transform is not None:
img = self.transform(img)
if target is None:
target = torch.zeros(1).long()
return img, target
def __len__(self):
return len(self.imgs)

@ -14,7 +14,7 @@ import torch.nn.parallel
from collections import OrderedDict
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.data import Dataset, create_loader, resolve_data_config
from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
torch.backends.cudnn.benchmark = True
@ -24,7 +24,7 @@ parser.add_argument('data', metavar='DIR',
help='path to dataset')
parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92',
help='model architecture (default: dpn92)')
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 2)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)')
@ -91,9 +91,14 @@ def validate(args):
criterion = nn.CrossEntropyLoss().cuda()
if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing)
else:
dataset = Dataset(args.data, load_bytes=args.tf_preprocessing)
crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
loader = create_loader(
Dataset(args.data, load_bytes=args.tf_preprocessing),
dataset,
input_size=data_config['input_size'],
batch_size=args.batch_size,
use_prefetcher=args.prefetcher,

Loading…
Cancel
Save