More work on dataset / parser split and imagenet21k (tar) support

pull/323/head
Ross Wightman 4 years ago
parent ce69de70d3
commit e35e9760a6
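
For orientation, a minimal usage sketch of the reorganized dataset / parser API introduced here, assuming the timm.data module layout shown in the hunks below. Paths are hypothetical, and ParserImageClassInTar has to be passed in explicitly, since ImageDataset only auto-selects the folder and single-tar parsers:

```python
from timm.data.dataset import ImageDataset
from timm.data.parsers import ParserImageClassInTar

# ImageNet-21k style layout: one tar per class/synset in a single directory
# (hypothetical path), e.g. n01440764.tar, n01443537.tar, ...
root = '/data/imagenet21k'
parser = ParserImageClassInTar(root)

# the dataset decodes whatever file-like object the parser hands back
dataset = ImageDataset(root, parser=parser)
img, target = dataset[0]   # PIL image (no transform given) and integer class index
```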

timm/data/dataset.py
@@ -5,32 +5,50 @@ Hacked together by / Copyright 2020 Ross Wightman
import torch.utils.data as data
import os
import torch
import logging
from .parsers import ParserImageFolder, ParserImageTar
from PIL import Image
from .parsers import ParserImageFolder, ParserImageTar, ParserImageClassInTar
_logger = logging.getLogger(__name__)
_ERROR_RETRY = 50
class ImageDataset(data.Dataset):
def __init__(
self,
img_root,
root,
parser=None,
class_map='',
load_bytes=False,
transform=None,
):
self.img_root = img_root
if parser is None:
if os.path.isfile(img_root) and os.path.splitext(img_root)[1] == '.tar':
parser = ParserImageTar(img_root, load_bytes=load_bytes, class_map=class_map)
if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
parser = ParserImageTar(root, class_map=class_map)
else:
parser = ParserImageFolder(img_root, load_bytes=load_bytes, class_map=class_map)
parser = ParserImageFolder(root, class_map=class_map)
self.parser = parser
self.load_bytes = load_bytes
self.transform = transform
self._consecutive_errors = 0
def __getitem__(self, index):
img, target = self.parser[index]
try:
img = img.read() if self.load_bytes else Image.open(img).convert('RGB')
except Exception as e:
_logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
self._consecutive_errors += 1
if self._consecutive_errors < _ERROR_RETRY:
return self.__getitem__((index + 1) % len(self.parser))
else:
raise e
self._consecutive_errors = 0
if self.transform is not None:
img = self.transform(img)
if target is None:

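Worth noting in the hunk above: image decoding moved out of the parsers and into ImageDataset, which now retries the next index on decode failure (up to _ERROR_RETRY consecutive errors) and, with load_bytes=True, skips PIL entirely and returns raw bytes. A hedged sketch of the raw-bytes path, with a hypothetical folder path:

```python
from timm.data.dataset import ImageDataset

# load_bytes=True makes __getitem__ call .read() on the parser's file object
# instead of decoding with PIL; handy for re-packing or inspecting datasets.
dataset = ImageDataset('/data/imagenet/train', load_bytes=True)  # hypothetical path
raw, target = dataset[0]   # raw JPEG/PNG bytes and integer target
with open(f'sample_0_class{target}.jpg', 'wb') as f:
    f.write(raw)
```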
timm/data/parsers/__init__.py
@@ -1,4 +1,4 @@
from .parser import Parser
from .parser_image_folder import ParserImageFolder
from .parser_image_tar import ParserImageTar
from .parser_in21k_tar import ParserIn21kTar
from .parser_image_class_in_tar import ParserImageClassInTar

timm/data/parsers/class_map.py
@@ -1,3 +1,4 @@
import os
def load_class_map(filename, root=''):

timm/data/parsers/constants.py
@@ -1,3 +1 @@
IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')

timm/data/parsers/parser_image_class_in_tar.py (new file)
@@ -0,0 +1,107 @@
import os
import tarfile
import pickle
from glob import glob
import numpy as np
from timm.utils.misc import natural_key
from .parser import Parser
from .class_map import load_class_map
from .constants import IMG_EXTENSIONS
def extract_tarinfos(root, class_name_to_idx=None, cache_filename=None, extensions=None):
tar_filenames = glob(os.path.join(root, '*.tar'), recursive=True)
assert len(tar_filenames)
num_tars = len(tar_filenames)
cache_path = ''
if cache_filename is not None:
cache_path = os.path.join(root, cache_filename)
if os.path.exists(cache_path):
with open(cache_path, 'rb') as pf:
tarinfo_map = pickle.load(pf)
else:
tarinfo_map = {}
for fi, fn in enumerate(tar_filenames):
if fi % 1000 == 0:
print(f'DEBUG: tar {fi}/{num_tars}')
# cannot keep this open across processes, reopen later
name = os.path.splitext(os.path.basename(fn))[0]
with tarfile.open(fn) as tf:
if extensions is None:
# assume all files are valid samples
class_tarinfos = tf.getmembers()
else:
class_tarinfos = [m for m in tf.getmembers() if os.path.splitext(m.name)[1].lower() in extensions]
tarinfo_map[name] = dict(tarinfos=class_tarinfos)
print(f'DEBUG: {len(class_tarinfos)} images for class {name}')
tarinfo_map = {k: v for k, v in sorted(tarinfo_map.items(), key=lambda k: natural_key(k[0]))}
if cache_path:
with open(cache_path, 'wb') as pf:
pickle.dump(tarinfo_map, pf, protocol=pickle.HIGHEST_PROTOCOL)
tarinfos = []
targets = []
build_class_map = False
if class_name_to_idx is None:
class_name_to_idx = {}
build_class_map = True
for i, (name, metadata) in enumerate(tarinfo_map.items()):
class_idx = i
if build_class_map:
class_name_to_idx[name] = i
else:
if name not in class_name_to_idx:
# only samples with class in class mapping are added
continue
class_idx = class_name_to_idx[name]
num_samples = len(metadata['tarinfos'])
tarinfos.extend(metadata['tarinfos'])
targets.extend([class_idx] * num_samples)
return tarinfos, np.array(targets), class_name_to_idx
class ParserImageClassInTar(Parser):
""" Multi-tarfile dataset parser where there is one .tar file per class
"""
CACHE_FILENAME = '_tarinfos.pickle'
def __init__(self, root, class_map=''):
super().__init__()
class_name_to_idx = None
if class_map:
class_name_to_idx = load_class_map(class_map, root)
assert os.path.isdir(root)
self.root = root
self.tarinfos, self.targets, self.class_name_to_idx = extract_tarinfos(
self.root, class_name_to_idx=class_name_to_idx,
cache_filename=self.CACHE_FILENAME, extensions=IMG_EXTENSIONS)
self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()}
self.tarfiles = {} # to open lazily
self.cache_tarfiles = False
def __len__(self):
return len(self.tarinfos)
def __getitem__(self, index):
tarinfo = self.tarinfos[index]
target = self.targets[index]
class_name = self.class_idx_to_name[target]
if self.cache_tarfiles:
tf = self.tarfiles.setdefault(
class_name, tarfile.open(os.path.join(self.root, class_name + '.tar')))
else:
tf = tarfile.open(os.path.join(self.root, class_name + '.tar'))
fileobj = tf.extractfile(tarinfo)
return fileobj, target
def _filename(self, index, basename=False, absolute=False):
filename = self.tarinfos[index].name
if basename:
filename = os.path.basename(filename)
return filename

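A hedged sketch of driving the new per-class-tar parser directly: it globs root for *.tar files, treats each tar's basename as the class name, and caches the collected TarInfo index in _tarinfos.pickle under root so later runs skip the scan. Directory contents are hypothetical:

```python
from timm.data.parsers import ParserImageClassInTar

# root holds one tar per class, e.g. n01440764.tar, n01443537.tar, ...
parser = ParserImageClassInTar('/data/imagenet21k')   # hypothetical path
print(len(parser), 'samples across', len(parser.class_name_to_idx), 'classes')

fileobj, target = parser[0]        # open file-like object + integer class index
print(parser.filename(0), target)  # member name inside the class tar
```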
timm/data/parsers/parser_image_folder.py
@@ -2,7 +2,6 @@ import os
import io
import torch
from PIL import Image
from timm.utils.misc import natural_key
from .parser import Parser
@@ -37,25 +36,21 @@ class ParserImageFolder(Parser):
def __init__(
self,
root,
load_bytes=False,
class_map=''):
super().__init__()
self.root = root
self.load_bytes = load_bytes
class_to_idx = None
if class_map:
class_to_idx = load_class_map(class_map, root)
self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
if len(self.samples) == 0:
raise RuntimeError(f'Found 0 images in subfolders of {root}. '
f'Supported image extensions are {", ".join(IMG_EXTENSIONS)}')
raise RuntimeError(
f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}')
def __getitem__(self, index):
path, target = self.samples[index]
img = open(path, 'rb').read() if self.load_bytes else Image.open(path).convert('RGB')
return img, target
return open(path, 'rb'), target
def __len__(self):
return len(self.samples)

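After this change ParserImageFolder always returns an open binary file handle rather than a decoded image or raw bytes; decoding (and the load_bytes switch) now lives in ImageDataset. A minimal direct-use sketch with a hypothetical path:

```python
from timm.data.parsers import ParserImageFolder

# standard folder-per-class layout, e.g. /data/imagenet/train/n01440764/*.jpg
parser = ParserImageFolder('/data/imagenet/train')  # hypothetical path
fileobj, target = parser[0]        # open('rb') file handle + class index
print(parser.filename(0), target)
```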
timm/data/parsers/parser_image_tar.py
@@ -1,16 +1,13 @@
import os
import io
import torch
import tarfile
from .parser import Parser
from .class_map import load_class_map
from .constants import IMG_EXTENSIONS
from PIL import Image
from timm.utils.misc import natural_key
def extract_tar_info(tarfile, class_to_idx=None, sort=True):
def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
files = []
labels = []
for ti in tarfile.getmembers():
@@ -33,8 +30,9 @@ def extract_tar_info(tarfile, class_to_idx=None, sort=True):
class ParserImageTar(Parser):
def __init__(self, root, load_bytes=False, class_map=''):
""" Single tarfile dataset where classes are mapped to folders within tar
"""
def __init__(self, root, class_map=''):
super().__init__()
class_to_idx = None
@@ -42,19 +40,18 @@ class ParserImageTar(Parser):
class_to_idx = load_class_map(class_map, root)
assert os.path.isfile(root)
self.root = root
with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later
self.samples, self.class_to_idx = extract_tar_info(tf, class_to_idx)
self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx)
self.imgs = self.samples
self.tarfile = None # lazy init in __getitem__
self.load_bytes = load_bytes
def __getitem__(self, index):
if self.tarfile is None:
self.tarfile = tarfile.open(self.root)
tarinfo, target = self.samples[index]
iob = self.tarfile.extractfile(tarinfo)
img = iob.read() if self.load_bytes else Image.open(iob).convert('RGB')
return img, target
fileobj = self.tarfile.extractfile(tarinfo)
return fileobj, target
def __len__(self):
return len(self.samples)

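ParserImageTar remains what ImageDataset picks automatically when root points at a single .tar whose internal folder names are the classes; like the folder parser it now hands back an undecoded file object. Sketch with a hypothetical path:

```python
from timm.data.dataset import ImageDataset

# one tar containing class-named folders, e.g. n01440764/xxx.jpg inside the archive
dataset = ImageDataset('/data/imagenet_val.tar')   # hypothetical path
img, target = dataset[0]    # ParserImageTar is selected automatically
```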
timm/data/parsers/parser_in21k_tar.py (deleted)
@@ -1,104 +0,0 @@
import os
import io
import re
import torch
import tarfile
import pickle
from glob import glob
import numpy as np
import torch.utils.data as data
from timm.utils.misc import natural_key
from .constants import IMG_EXTENSIONS
def load_class_map(filename, root=''):
class_map_path = filename
if not os.path.exists(class_map_path):
class_map_path = os.path.join(root, filename)
assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename
class_map_ext = os.path.splitext(filename)[-1].lower()
if class_map_ext == '.txt':
with open(class_map_path) as f:
class_to_idx = {v.strip(): k for k, v in enumerate(f)}
else:
assert False, 'Unsupported class map extension'
return class_to_idx
class ParserIn21kTar(data.Dataset):
CACHE_FILENAME = 'class_info.pickle'
def __init__(self, root, class_map=''):
class_to_idx = None
if class_map:
class_to_idx = load_class_map(class_map, root)
assert os.path.isdir(root)
self.root = root
tar_filenames = glob(os.path.join(self.root, '*.tar'), recursive=True)
assert len(tar_filenames)
num_tars = len(tar_filenames)
if os.path.exists(self.CACHE_FILENAME):
with open(self.CACHE_FILENAME, 'rb') as pf:
class_info = pickle.load(pf)
else:
class_info = {}
for fi, fn in enumerate(tar_filenames):
if fi % 1000 == 0:
print(f'DEBUG: tar {fi}/{num_tars}')
# cannot keep this open across processes, reopen later
name = os.path.splitext(os.path.basename(fn))[0]
img_tarinfos = []
with tarfile.open(fn) as tf:
img_tarinfos.extend(tf.getmembers())
class_info[name] = dict(img_tarinfos=img_tarinfos)
print(f'DEBUG: {len(img_tarinfos)} images for synset {name}')
class_info = {k: v for k, v in sorted(class_info.items())}
with open('class_info.pickle', 'wb') as pf:
pickle.dump(class_info, pf, protocol=pickle.HIGHEST_PROTOCOL)
if class_to_idx is not None:
out_dict = {}
for k, v in class_info.items():
if k in class_to_idx:
class_idx = class_to_idx[k]
v['class_idx'] = class_idx
out_dict[k] = v
class_info = {k: v for k, v in sorted(out_dict.items(), key=lambda x: x[1]['class_idx'])}
else:
for i, (k, v) in enumerate(class_info.items()):
v['class_idx'] = i
self.img_infos = []
self.targets = []
self.tarnames = []
for k, v in class_info.items():
num_samples = len(v['img_tarinfos'])
self.img_infos.extend(v['img_tarinfos'])
self.targets.extend([v['class_idx']] * num_samples)
self.tarnames.extend([k] * num_samples)
self.targets = np.array(self.targets) # separate, uniform np array are more memory efficient
self.tarnames = np.array(self.tarnames)
self.tarfiles = {} # to open lazily
del class_info
def __len__(self):
return len(self.img_infos)
def __getitem__(self, idx):
img_tarinfo = self.img_infos[idx]
name = self.tarnames[idx]
tf = self.tarfiles.setdefault(name, tarfile.open(os.path.join(self.root, name + '.tar')))
img_bytes = tf.extractfile(img_tarinfo)
if self.targets:
target = self.targets[idx]
else:
target = None
return img_bytes, target