From e35e9760a64ee687960b8ae14c6b77c223d598aa Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 28 Dec 2020 14:39:05 -0800 Subject: [PATCH] More work on dataset / parser split and imagenet21k (tar) support --- timm/data/dataset.py | 30 ++++- timm/data/parsers/__init__.py | 2 +- timm/data/parsers/class_map.py | 1 + timm/data/parsers/constants.py | 2 - .../data/parsers/parser_image_class_in_tar.py | 107 ++++++++++++++++++ timm/data/parsers/parser_image_folder.py | 11 +- timm/data/parsers/parser_image_tar.py | 19 ++-- timm/data/parsers/parser_in21k_tar.py | 104 ----------------- 8 files changed, 144 insertions(+), 132 deletions(-) create mode 100644 timm/data/parsers/parser_image_class_in_tar.py delete mode 100644 timm/data/parsers/parser_in21k_tar.py diff --git a/timm/data/dataset.py b/timm/data/dataset.py index 8013c846..42a46eef 100644 --- a/timm/data/dataset.py +++ b/timm/data/dataset.py @@ -5,32 +5,50 @@ Hacked together by / Copyright 2020 Ross Wightman import torch.utils.data as data import os import torch +import logging -from .parsers import ParserImageFolder, ParserImageTar +from PIL import Image + +from .parsers import ParserImageFolder, ParserImageTar, ParserImageClassInTar + +_logger = logging.getLogger(__name__) + + +_ERROR_RETRY = 50 class ImageDataset(data.Dataset): def __init__( self, - img_root, + root, parser=None, class_map='', load_bytes=False, transform=None, ): - self.img_root = img_root if parser is None: - if os.path.isfile(img_root) and os.path.splitext(img_root)[1] == '.tar': - parser = ParserImageTar(img_root, load_bytes=load_bytes, class_map=class_map) + if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar': + parser = ParserImageTar(root, class_map=class_map) else: - parser = ParserImageFolder(img_root, load_bytes=load_bytes, class_map=class_map) + parser = ParserImageFolder(root, class_map=class_map) self.parser = parser self.load_bytes = load_bytes self.transform = transform + self._consecutive_errors = 0 def __getitem__(self, index): img, target = self.parser[index] + try: + img = img.read() if self.load_bytes else Image.open(img).convert('RGB') + except Exception as e: + _logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}') + self._consecutive_errors += 1 + if self._consecutive_errors < _ERROR_RETRY: + return self.__getitem__((index + 1) % len(self.parser)) + else: + raise e + self._consecutive_errors = 0 if self.transform is not None: img = self.transform(img) if target is None: diff --git a/timm/data/parsers/__init__.py b/timm/data/parsers/__init__.py index c502eec8..4ecb3a22 100644 --- a/timm/data/parsers/__init__.py +++ b/timm/data/parsers/__init__.py @@ -1,4 +1,4 @@ from .parser import Parser from .parser_image_folder import ParserImageFolder from .parser_image_tar import ParserImageTar -from .parser_in21k_tar import ParserIn21kTar \ No newline at end of file +from .parser_image_class_in_tar import ParserImageClassInTar \ No newline at end of file diff --git a/timm/data/parsers/class_map.py b/timm/data/parsers/class_map.py index f5fa7e2a..9ef4d1fa 100644 --- a/timm/data/parsers/class_map.py +++ b/timm/data/parsers/class_map.py @@ -1,3 +1,4 @@ +import os def load_class_map(filename, root=''): diff --git a/timm/data/parsers/constants.py b/timm/data/parsers/constants.py index 6e3be34b..e7ba484e 100644 --- a/timm/data/parsers/constants.py +++ b/timm/data/parsers/constants.py @@ -1,3 +1 @@ IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') - - diff --git a/timm/data/parsers/parser_image_class_in_tar.py b/timm/data/parsers/parser_image_class_in_tar.py new file mode 100644 index 00000000..f43ff359 --- /dev/null +++ b/timm/data/parsers/parser_image_class_in_tar.py @@ -0,0 +1,107 @@ +import os +import tarfile +import pickle +from glob import glob +import numpy as np + +from timm.utils.misc import natural_key + +from .parser import Parser +from .class_map import load_class_map +from .constants import IMG_EXTENSIONS + + +def extract_tarinfos(root, class_name_to_idx=None, cache_filename=None, extensions=None): + tar_filenames = glob(os.path.join(root, '*.tar'), recursive=True) + assert len(tar_filenames) + num_tars = len(tar_filenames) + + cache_path = '' + if cache_filename is not None: + cache_path = os.path.join(root, cache_filename) + if os.path.exists(cache_path): + with open(cache_path, 'rb') as pf: + tarinfo_map = pickle.load(pf) + else: + tarinfo_map = {} + for fi, fn in enumerate(tar_filenames): + if fi % 1000 == 0: + print(f'DEBUG: tar {fi}/{num_tars}') + # cannot keep this open across processes, reopen later + name = os.path.splitext(os.path.basename(fn))[0] + with tarfile.open(fn) as tf: + if extensions is None: + # assume all files are valid samples + class_tarinfos = tf.getmembers() + else: + class_tarinfos = [m for m in tf.getmembers() if os.path.splitext(m.name)[1].lower() in extensions] + tarinfo_map[name] = dict(tarinfos=class_tarinfos) + print(f'DEBUG: {len(class_tarinfos)} images for class {name}') + tarinfo_map = {k: v for k, v in sorted(tarinfo_map.items(), key=lambda k: natural_key(k[0]))} + if cache_path: + with open(cache_path, 'wb') as pf: + pickle.dump(tarinfo_map, pf, protocol=pickle.HIGHEST_PROTOCOL) + + tarinfos = [] + targets = [] + build_class_map = False + if class_name_to_idx is None: + class_name_to_idx = {} + build_class_map = True + for i, (name, metadata) in enumerate(tarinfo_map.items()): + class_idx = i + if build_class_map: + class_name_to_idx[name] = i + else: + if name not in class_name_to_idx: + # only samples with class in class mapping are added + continue + class_idx = class_name_to_idx[name] + num_samples = len(metadata['tarinfos']) + tarinfos.extend(metadata['tarinfos']) + targets.extend([class_idx] * num_samples) + + return tarinfos, np.array(targets), class_name_to_idx + + +class ParserImageClassInTar(Parser): + """ Multi-tarfile dataset parser where there is one .tar file per class + """ + + CACHE_FILENAME = '_tarinfos.pickle' + + def __init__(self, root, class_map=''): + super().__init__() + + class_name_to_idx = None + if class_map: + class_name_to_idx = load_class_map(class_map, root) + assert os.path.isdir(root) + self.root = root + self.tarinfos, self.targets, self.class_name_to_idx = extract_tarinfos( + self.root, class_name_to_idx=class_name_to_idx, + cache_filename=self.CACHE_FILENAME, extensions=IMG_EXTENSIONS) + self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()} + self.tarfiles = {} # to open lazily + self.cache_tarfiles = False + + def __len__(self): + return len(self.tarinfos) + + def __getitem__(self, index): + tarinfo = self.tarinfos[index] + target = self.targets[index] + class_name = self.class_idx_to_name[target] + if self.cache_tarfiles: + tf = self.tarfiles.setdefault( + class_name, tarfile.open(os.path.join(self.root, class_name + '.tar'))) + else: + tf = tarfile.open(os.path.join(self.root, class_name + '.tar')) + fileobj = tf.extractfile(tarinfo) + return fileobj, target + + def _filename(self, index, basename=False, absolute=False): + filename = self.tarinfos[index].name + if basename: + filename = os.path.basename(filename) + return filename diff --git a/timm/data/parsers/parser_image_folder.py b/timm/data/parsers/parser_image_folder.py index 8a61007f..93b16e40 100644 --- a/timm/data/parsers/parser_image_folder.py +++ b/timm/data/parsers/parser_image_folder.py @@ -2,7 +2,6 @@ import os import io import torch -from PIL import Image from timm.utils.misc import natural_key from .parser import Parser @@ -37,25 +36,21 @@ class ParserImageFolder(Parser): def __init__( self, root, - load_bytes=False, class_map=''): super().__init__() self.root = root - self.load_bytes = load_bytes - class_to_idx = None if class_map: class_to_idx = load_class_map(class_map, root) self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx) if len(self.samples) == 0: - raise RuntimeError(f'Found 0 images in subfolders of {root}. ' - f'Supported image extensions are {", ".join(IMG_EXTENSIONS)}') + raise RuntimeError( + f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}') def __getitem__(self, index): path, target = self.samples[index] - img = open(path, 'rb').read() if self.load_bytes else Image.open(path).convert('RGB') - return img, target + return open(path, 'rb'), target def __len__(self): return len(self.samples) diff --git a/timm/data/parsers/parser_image_tar.py b/timm/data/parsers/parser_image_tar.py index 504e71e8..657b56f9 100644 --- a/timm/data/parsers/parser_image_tar.py +++ b/timm/data/parsers/parser_image_tar.py @@ -1,16 +1,13 @@ import os -import io -import torch import tarfile from .parser import Parser from .class_map import load_class_map from .constants import IMG_EXTENSIONS -from PIL import Image from timm.utils.misc import natural_key -def extract_tar_info(tarfile, class_to_idx=None, sort=True): +def extract_tarinfo(tarfile, class_to_idx=None, sort=True): files = [] labels = [] for ti in tarfile.getmembers(): @@ -33,8 +30,9 @@ def extract_tar_info(tarfile, class_to_idx=None, sort=True): class ParserImageTar(Parser): - - def __init__(self, root, load_bytes=False, class_map=''): + """ Single tarfile dataset where classes are mapped to folders within tar + """ + def __init__(self, root, class_map=''): super().__init__() class_to_idx = None @@ -42,19 +40,18 @@ class ParserImageTar(Parser): class_to_idx = load_class_map(class_map, root) assert os.path.isfile(root) self.root = root + with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later - self.samples, self.class_to_idx = extract_tar_info(tf, class_to_idx) + self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx) self.imgs = self.samples self.tarfile = None # lazy init in __getitem__ - self.load_bytes = load_bytes def __getitem__(self, index): if self.tarfile is None: self.tarfile = tarfile.open(self.root) tarinfo, target = self.samples[index] - iob = self.tarfile.extractfile(tarinfo) - img = iob.read() if self.load_bytes else Image.open(iob).convert('RGB') - return img, target + fileobj = self.tarfile.extractfile(tarinfo) + return fileobj, target def __len__(self): return len(self.samples) diff --git a/timm/data/parsers/parser_in21k_tar.py b/timm/data/parsers/parser_in21k_tar.py deleted file mode 100644 index da7e9d26..00000000 --- a/timm/data/parsers/parser_in21k_tar.py +++ /dev/null @@ -1,104 +0,0 @@ -import os -import io -import re -import torch -import tarfile -import pickle -from glob import glob -import numpy as np - -import torch.utils.data as data - -from timm.utils.misc import natural_key - -from .constants import IMG_EXTENSIONS - - -def load_class_map(filename, root=''): - class_map_path = filename - if not os.path.exists(class_map_path): - class_map_path = os.path.join(root, filename) - assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename - class_map_ext = os.path.splitext(filename)[-1].lower() - if class_map_ext == '.txt': - with open(class_map_path) as f: - class_to_idx = {v.strip(): k for k, v in enumerate(f)} - else: - assert False, 'Unsupported class map extension' - return class_to_idx - - -class ParserIn21kTar(data.Dataset): - - CACHE_FILENAME = 'class_info.pickle' - - def __init__(self, root, class_map=''): - - class_to_idx = None - if class_map: - class_to_idx = load_class_map(class_map, root) - assert os.path.isdir(root) - self.root = root - tar_filenames = glob(os.path.join(self.root, '*.tar'), recursive=True) - assert len(tar_filenames) - num_tars = len(tar_filenames) - - if os.path.exists(self.CACHE_FILENAME): - with open(self.CACHE_FILENAME, 'rb') as pf: - class_info = pickle.load(pf) - else: - class_info = {} - for fi, fn in enumerate(tar_filenames): - if fi % 1000 == 0: - print(f'DEBUG: tar {fi}/{num_tars}') - # cannot keep this open across processes, reopen later - name = os.path.splitext(os.path.basename(fn))[0] - img_tarinfos = [] - with tarfile.open(fn) as tf: - img_tarinfos.extend(tf.getmembers()) - class_info[name] = dict(img_tarinfos=img_tarinfos) - print(f'DEBUG: {len(img_tarinfos)} images for synset {name}') - class_info = {k: v for k, v in sorted(class_info.items())} - - with open('class_info.pickle', 'wb') as pf: - pickle.dump(class_info, pf, protocol=pickle.HIGHEST_PROTOCOL) - - if class_to_idx is not None: - out_dict = {} - for k, v in class_info.items(): - if k in class_to_idx: - class_idx = class_to_idx[k] - v['class_idx'] = class_idx - out_dict[k] = v - class_info = {k: v for k, v in sorted(out_dict.items(), key=lambda x: x[1]['class_idx'])} - else: - for i, (k, v) in enumerate(class_info.items()): - v['class_idx'] = i - - self.img_infos = [] - self.targets = [] - self.tarnames = [] - for k, v in class_info.items(): - num_samples = len(v['img_tarinfos']) - self.img_infos.extend(v['img_tarinfos']) - self.targets.extend([v['class_idx']] * num_samples) - self.tarnames.extend([k] * num_samples) - self.targets = np.array(self.targets) # separate, uniform np array are more memory efficient - self.tarnames = np.array(self.tarnames) - - self.tarfiles = {} # to open lazily - del class_info - - def __len__(self): - return len(self.img_infos) - - def __getitem__(self, idx): - img_tarinfo = self.img_infos[idx] - name = self.tarnames[idx] - tf = self.tarfiles.setdefault(name, tarfile.open(os.path.join(self.root, name + '.tar'))) - img_bytes = tf.extractfile(img_tarinfo) - if self.targets: - target = self.targets[idx] - else: - target = None - return img_bytes, target