More work on dataset / parser split and imagenet21k (tar) support

pull/323/head
Ross Wightman 4 years ago
parent ce69de70d3
commit e35e9760a6

@@ -5,32 +5,50 @@ Hacked together by / Copyright 2020 Ross Wightman
import torch.utils.data as data
import os
import torch
+import logging

-from .parsers import ParserImageFolder, ParserImageTar
+from PIL import Image
+
+from .parsers import ParserImageFolder, ParserImageTar, ParserImageClassInTar
+
+_logger = logging.getLogger(__name__)
+
+
+_ERROR_RETRY = 50


class ImageDataset(data.Dataset):

    def __init__(
            self,
-            img_root,
+            root,
            parser=None,
            class_map='',
            load_bytes=False,
            transform=None,
    ):
-        self.img_root = img_root
        if parser is None:
-            if os.path.isfile(img_root) and os.path.splitext(img_root)[1] == '.tar':
-                parser = ParserImageTar(img_root, load_bytes=load_bytes, class_map=class_map)
+            if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
+                parser = ParserImageTar(root, class_map=class_map)
            else:
-                parser = ParserImageFolder(img_root, load_bytes=load_bytes, class_map=class_map)
+                parser = ParserImageFolder(root, class_map=class_map)
        self.parser = parser
        self.load_bytes = load_bytes
        self.transform = transform
+        self._consecutive_errors = 0

    def __getitem__(self, index):
        img, target = self.parser[index]
+        try:
+            img = img.read() if self.load_bytes else Image.open(img).convert('RGB')
+        except Exception as e:
+            _logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
+            self._consecutive_errors += 1
+            if self._consecutive_errors < _ERROR_RETRY:
+                return self.__getitem__((index + 1) % len(self.parser))
+            else:
+                raise e
+        self._consecutive_errors = 0
+
        if self.transform is not None:
            img = self.transform(img)
        if target is None:

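The ImageDataset changes above move image decoding out of the parsers and into the dataset: a parser now returns a raw file object plus target, and __getitem__ decodes with PIL (or returns raw bytes with load_bytes=True), skipping up to _ERROR_RETRY consecutive unreadable samples before raising. A minimal usage sketch, with hypothetical paths and import locations assumed from the package layout shown in this diff:

from timm.data.dataset import ImageDataset
from timm.data.parsers import ParserImageClassInTar

# root is a folder -> ParserImageFolder; root is a single .tar -> ParserImageTar
ds = ImageDataset('/data/imagenet/train')

# or pass a parser explicitly, e.g. the new one-tar-per-class parser
ds_21k = ImageDataset('/data/imagenet21k', parser=ParserImageClassInTar('/data/imagenet21k'))
img, target = ds_21k[0]  # decoded PIL RGB image (raw bytes if load_bytes=True)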
@@ -1,4 +1,4 @@
from .parser import Parser
from .parser_image_folder import ParserImageFolder
from .parser_image_tar import ParserImageTar
-from .parser_in21k_tar import ParserIn21kTar
+from .parser_image_class_in_tar import ParserImageClassInTar

@@ -1,3 +1,4 @@
+import os


def load_class_map(filename, root=''):

@@ -1,3 +1 @@
IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')

@@ -0,0 +1,107 @@
import os
import tarfile
import pickle
from glob import glob
import numpy as np

from timm.utils.misc import natural_key

from .parser import Parser
from .class_map import load_class_map
from .constants import IMG_EXTENSIONS


def extract_tarinfos(root, class_name_to_idx=None, cache_filename=None, extensions=None):
    tar_filenames = glob(os.path.join(root, '*.tar'), recursive=True)
    assert len(tar_filenames)
    num_tars = len(tar_filenames)

    cache_path = ''
    if cache_filename is not None:
        cache_path = os.path.join(root, cache_filename)
    if os.path.exists(cache_path):
        with open(cache_path, 'rb') as pf:
            tarinfo_map = pickle.load(pf)
    else:
        tarinfo_map = {}
        for fi, fn in enumerate(tar_filenames):
            if fi % 1000 == 0:
                print(f'DEBUG: tar {fi}/{num_tars}')
            # cannot keep this open across processes, reopen later
            name = os.path.splitext(os.path.basename(fn))[0]
            with tarfile.open(fn) as tf:
                if extensions is None:
                    # assume all files are valid samples
                    class_tarinfos = tf.getmembers()
                else:
                    class_tarinfos = [m for m in tf.getmembers() if os.path.splitext(m.name)[1].lower() in extensions]
                tarinfo_map[name] = dict(tarinfos=class_tarinfos)
            print(f'DEBUG: {len(class_tarinfos)} images for class {name}')
        tarinfo_map = {k: v for k, v in sorted(tarinfo_map.items(), key=lambda k: natural_key(k[0]))}
        if cache_path:
            with open(cache_path, 'wb') as pf:
                pickle.dump(tarinfo_map, pf, protocol=pickle.HIGHEST_PROTOCOL)

    tarinfos = []
    targets = []
    build_class_map = False
    if class_name_to_idx is None:
        class_name_to_idx = {}
        build_class_map = True
    for i, (name, metadata) in enumerate(tarinfo_map.items()):
        class_idx = i
        if build_class_map:
            class_name_to_idx[name] = i
        else:
            if name not in class_name_to_idx:
                # only samples with class in class mapping are added
                continue
            class_idx = class_name_to_idx[name]
        num_samples = len(metadata['tarinfos'])
        tarinfos.extend(metadata['tarinfos'])
        targets.extend([class_idx] * num_samples)

    return tarinfos, np.array(targets), class_name_to_idx


class ParserImageClassInTar(Parser):
    """ Multi-tarfile dataset parser where there is one .tar file per class
    """

    CACHE_FILENAME = '_tarinfos.pickle'

    def __init__(self, root, class_map=''):
        super().__init__()

        class_name_to_idx = None
        if class_map:
            class_name_to_idx = load_class_map(class_map, root)
        assert os.path.isdir(root)
        self.root = root
        self.tarinfos, self.targets, self.class_name_to_idx = extract_tarinfos(
            self.root, class_name_to_idx=class_name_to_idx,
            cache_filename=self.CACHE_FILENAME, extensions=IMG_EXTENSIONS)
        self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()}
        self.tarfiles = {}  # to open lazily
        self.cache_tarfiles = False

    def __len__(self):
        return len(self.tarinfos)

    def __getitem__(self, index):
        tarinfo = self.tarinfos[index]
        target = self.targets[index]
        class_name = self.class_idx_to_name[target]
        if self.cache_tarfiles:
            tf = self.tarfiles.setdefault(
                class_name, tarfile.open(os.path.join(self.root, class_name + '.tar')))
        else:
            tf = tarfile.open(os.path.join(self.root, class_name + '.tar'))
        fileobj = tf.extractfile(tarinfo)
        return fileobj, target

    def _filename(self, index, basename=False, absolute=False):
        filename = self.tarinfos[index].name
        if basename:
            filename = os.path.basename(filename)
        return filename

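A sketch of using the new ParserImageClassInTar directly, assuming a root directory containing one tar per class (e.g. n01440764.tar, n01443537.tar, ...) and that the Parser base forwards filename() to _filename(); the tar member index is cached to _tarinfos.pickle in root on the first scan, and cache_tarfiles can be enabled to keep per-class tar handles open when the parser is not shared across worker processes:

from PIL import Image
from timm.data.parsers import ParserImageClassInTar

parser = ParserImageClassInTar('/data/imagenet21k')  # hypothetical path
parser.cache_tarfiles = True               # reuse open tar handles (single-process use)
fileobj, target = parser[0]                # raw extracted file object + class index
img = Image.open(fileobj).convert('RGB')   # decoding is left to ImageDataset / the caller
print(parser.filename(0), target)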
@@ -2,7 +2,6 @@ import os
import io
import torch
-from PIL import Image
from timm.utils.misc import natural_key

from .parser import Parser
@@ -37,25 +36,21 @@ class ParserImageFolder(Parser):
    def __init__(
            self,
            root,
-            load_bytes=False,
            class_map=''):
        super().__init__()

        self.root = root
-        self.load_bytes = load_bytes
        class_to_idx = None
        if class_map:
            class_to_idx = load_class_map(class_map, root)
        self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
        if len(self.samples) == 0:
-            raise RuntimeError(f'Found 0 images in subfolders of {root}. '
-                               f'Supported image extensions are {", ".join(IMG_EXTENSIONS)}')
+            raise RuntimeError(
+                f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}')

    def __getitem__(self, index):
        path, target = self.samples[index]
-        img = open(path, 'rb').read() if self.load_bytes else Image.open(path).convert('RGB')
-        return img, target
+        return open(path, 'rb'), target

    def __len__(self):
        return len(self.samples)

@@ -1,16 +1,13 @@
import os
-import io
-import torch
import tarfile

from .parser import Parser
from .class_map import load_class_map
from .constants import IMG_EXTENSIONS
-from PIL import Image
from timm.utils.misc import natural_key


-def extract_tar_info(tarfile, class_to_idx=None, sort=True):
+def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
    files = []
    labels = []
    for ti in tarfile.getmembers():
@@ -33,8 +30,9 @@ def extract_tar_info(tarfile, class_to_idx=None, sort=True):
class ParserImageTar(Parser):
+    """ Single tarfile dataset where classes are mapped to folders within tar
+    """
-    def __init__(self, root, load_bytes=False, class_map=''):
+    def __init__(self, root, class_map=''):
        super().__init__()

        class_to_idx = None
@@ -42,19 +40,18 @@ class ParserImageTar(Parser):
            class_to_idx = load_class_map(class_map, root)
        assert os.path.isfile(root)
        self.root = root
        with tarfile.open(root) as tf:  # cannot keep this open across processes, reopen later
-            self.samples, self.class_to_idx = extract_tar_info(tf, class_to_idx)
+            self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx)
        self.imgs = self.samples
        self.tarfile = None  # lazy init in __getitem__
-        self.load_bytes = load_bytes

    def __getitem__(self, index):
        if self.tarfile is None:
            self.tarfile = tarfile.open(self.root)
        tarinfo, target = self.samples[index]
-        iob = self.tarfile.extractfile(tarinfo)
-        img = iob.read() if self.load_bytes else Image.open(iob).convert('RGB')
-        return img, target
+        fileobj = self.tarfile.extractfile(tarinfo)
+        return fileobj, target

    def __len__(self):
        return len(self.samples)

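With load_bytes and PIL handling removed from ParserImageFolder and ParserImageTar, the parser contract becomes: __getitem__ returns an open binary file object and a target, while __len__ and _filename back the length check and skip-logging in ImageDataset. A hypothetical custom parser following that contract, just to illustrate the interface (the class name, sample list, and the timm.data.parsers import path are assumptions, not code from this commit):

import os
from timm.data.parsers import Parser

class ParserImageList(Parser):
    """ Toy parser over an explicit list of (path, label) pairs. """
    def __init__(self, samples):
        super().__init__()
        self.samples = samples  # e.g. [('/data/a/img1.jpg', 0), ('/data/b/img2.jpg', 1)]

    def __getitem__(self, index):
        path, target = self.samples[index]
        return open(path, 'rb'), target  # file object, not a decoded image

    def __len__(self):
        return len(self.samples)

    def _filename(self, index, basename=False, absolute=False):
        path = self.samples[index][0]
        return os.path.basename(path) if basename else path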
@@ -1,104 +0,0 @@
import os
import io
import re
import torch
import tarfile
import pickle
from glob import glob
import numpy as np
import torch.utils.data as data

from timm.utils.misc import natural_key

from .constants import IMG_EXTENSIONS


def load_class_map(filename, root=''):
    class_map_path = filename
    if not os.path.exists(class_map_path):
        class_map_path = os.path.join(root, filename)
        assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename
    class_map_ext = os.path.splitext(filename)[-1].lower()
    if class_map_ext == '.txt':
        with open(class_map_path) as f:
            class_to_idx = {v.strip(): k for k, v in enumerate(f)}
    else:
        assert False, 'Unsupported class map extension'
    return class_to_idx


class ParserIn21kTar(data.Dataset):

    CACHE_FILENAME = 'class_info.pickle'

    def __init__(self, root, class_map=''):
        class_to_idx = None
        if class_map:
            class_to_idx = load_class_map(class_map, root)
        assert os.path.isdir(root)
        self.root = root
        tar_filenames = glob(os.path.join(self.root, '*.tar'), recursive=True)
        assert len(tar_filenames)
        num_tars = len(tar_filenames)

        if os.path.exists(self.CACHE_FILENAME):
            with open(self.CACHE_FILENAME, 'rb') as pf:
                class_info = pickle.load(pf)
        else:
            class_info = {}
            for fi, fn in enumerate(tar_filenames):
                if fi % 1000 == 0:
                    print(f'DEBUG: tar {fi}/{num_tars}')
                # cannot keep this open across processes, reopen later
                name = os.path.splitext(os.path.basename(fn))[0]
                img_tarinfos = []
                with tarfile.open(fn) as tf:
                    img_tarinfos.extend(tf.getmembers())
                class_info[name] = dict(img_tarinfos=img_tarinfos)
                print(f'DEBUG: {len(img_tarinfos)} images for synset {name}')
            class_info = {k: v for k, v in sorted(class_info.items())}

            with open('class_info.pickle', 'wb') as pf:
                pickle.dump(class_info, pf, protocol=pickle.HIGHEST_PROTOCOL)

        if class_to_idx is not None:
            out_dict = {}
            for k, v in class_info.items():
                if k in class_to_idx:
                    class_idx = class_to_idx[k]
                    v['class_idx'] = class_idx
                    out_dict[k] = v
            class_info = {k: v for k, v in sorted(out_dict.items(), key=lambda x: x[1]['class_idx'])}
        else:
            for i, (k, v) in enumerate(class_info.items()):
                v['class_idx'] = i

        self.img_infos = []
        self.targets = []
        self.tarnames = []
        for k, v in class_info.items():
            num_samples = len(v['img_tarinfos'])
            self.img_infos.extend(v['img_tarinfos'])
            self.targets.extend([v['class_idx']] * num_samples)
            self.tarnames.extend([k] * num_samples)
        self.targets = np.array(self.targets)  # separate, uniform np array are more memory efficient
        self.tarnames = np.array(self.tarnames)

        self.tarfiles = {}  # to open lazily
        del class_info

    def __len__(self):
        return len(self.img_infos)

    def __getitem__(self, idx):
        img_tarinfo = self.img_infos[idx]
        name = self.tarnames[idx]
        tf = self.tarfiles.setdefault(name, tarfile.open(os.path.join(self.root, name + '.tar')))
        img_bytes = tf.extractfile(img_tarinfo)
        if self.targets:
            target = self.targets[idx]
        else:
            target = None
        return img_bytes, target