You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
108 lines
3.9 KiB
108 lines
3.9 KiB
4 years ago
|
import os
|
||
|
import tarfile
|
||
|
import pickle
|
||
|
from glob import glob
|
||
|
import numpy as np
|
||
|
|
||
|
from timm.utils.misc import natural_key
|
||
|
|
||
|
from .parser import Parser
|
||
|
from .class_map import load_class_map
|
||
|
from .constants import IMG_EXTENSIONS
|
||
|
|
||
|
|
||
|
def extract_tarinfos(root, class_name_to_idx=None, cache_filename=None, extensions=None):
|
||
|
tar_filenames = glob(os.path.join(root, '*.tar'), recursive=True)
|
||
|
assert len(tar_filenames)
|
||
|
num_tars = len(tar_filenames)
|
||
|
|
||
|
cache_path = ''
|
||
|
if cache_filename is not None:
|
||
|
cache_path = os.path.join(root, cache_filename)
|
||
|
if os.path.exists(cache_path):
|
||
|
with open(cache_path, 'rb') as pf:
|
||
|
tarinfo_map = pickle.load(pf)
|
||
|
else:
|
||
|
tarinfo_map = {}
|
||
|
for fi, fn in enumerate(tar_filenames):
|
||
|
if fi % 1000 == 0:
|
||
|
print(f'DEBUG: tar {fi}/{num_tars}')
|
||
|
# cannot keep this open across processes, reopen later
|
||
|
name = os.path.splitext(os.path.basename(fn))[0]
|
||
|
with tarfile.open(fn) as tf:
|
||
|
if extensions is None:
|
||
|
# assume all files are valid samples
|
||
|
class_tarinfos = tf.getmembers()
|
||
|
else:
|
||
|
class_tarinfos = [m for m in tf.getmembers() if os.path.splitext(m.name)[1].lower() in extensions]
|
||
|
tarinfo_map[name] = dict(tarinfos=class_tarinfos)
|
||
|
print(f'DEBUG: {len(class_tarinfos)} images for class {name}')
|
||
|
tarinfo_map = {k: v for k, v in sorted(tarinfo_map.items(), key=lambda k: natural_key(k[0]))}
|
||
|
if cache_path:
|
||
|
with open(cache_path, 'wb') as pf:
|
||
|
pickle.dump(tarinfo_map, pf, protocol=pickle.HIGHEST_PROTOCOL)
|
||
|
|
||
|
tarinfos = []
|
||
|
targets = []
|
||
|
build_class_map = False
|
||
|
if class_name_to_idx is None:
|
||
|
class_name_to_idx = {}
|
||
|
build_class_map = True
|
||
|
for i, (name, metadata) in enumerate(tarinfo_map.items()):
|
||
|
class_idx = i
|
||
|
if build_class_map:
|
||
|
class_name_to_idx[name] = i
|
||
|
else:
|
||
|
if name not in class_name_to_idx:
|
||
|
# only samples with class in class mapping are added
|
||
|
continue
|
||
|
class_idx = class_name_to_idx[name]
|
||
|
num_samples = len(metadata['tarinfos'])
|
||
|
tarinfos.extend(metadata['tarinfos'])
|
||
|
targets.extend([class_idx] * num_samples)
|
||
|
|
||
|
return tarinfos, np.array(targets), class_name_to_idx
|
||
|
|
||
|
|
||
|
class ParserImageClassInTar(Parser):
|
||
|
""" Multi-tarfile dataset parser where there is one .tar file per class
|
||
|
"""
|
||
|
|
||
|
CACHE_FILENAME = '_tarinfos.pickle'
|
||
|
|
||
|
def __init__(self, root, class_map=''):
|
||
|
super().__init__()
|
||
|
|
||
|
class_name_to_idx = None
|
||
|
if class_map:
|
||
|
class_name_to_idx = load_class_map(class_map, root)
|
||
|
assert os.path.isdir(root)
|
||
|
self.root = root
|
||
|
self.tarinfos, self.targets, self.class_name_to_idx = extract_tarinfos(
|
||
|
self.root, class_name_to_idx=class_name_to_idx,
|
||
|
cache_filename=self.CACHE_FILENAME, extensions=IMG_EXTENSIONS)
|
||
|
self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()}
|
||
|
self.tarfiles = {} # to open lazily
|
||
|
self.cache_tarfiles = False
|
||
|
|
||
|
def __len__(self):
|
||
|
return len(self.tarinfos)
|
||
|
|
||
|
def __getitem__(self, index):
|
||
|
tarinfo = self.tarinfos[index]
|
||
|
target = self.targets[index]
|
||
|
class_name = self.class_idx_to_name[target]
|
||
|
if self.cache_tarfiles:
|
||
|
tf = self.tarfiles.setdefault(
|
||
|
class_name, tarfile.open(os.path.join(self.root, class_name + '.tar')))
|
||
|
else:
|
||
|
tf = tarfile.open(os.path.join(self.root, class_name + '.tar'))
|
||
|
fileobj = tf.extractfile(tarinfo)
|
||
|
return fileobj, target
|
||
|
|
||
|
def _filename(self, index, basename=False, absolute=False):
|
||
|
filename = self.tarinfos[index].name
|
||
|
if basename:
|
||
|
filename = os.path.basename(filename)
|
||
|
return filename
|