Add enhanced ParserImageInTar that can read images from tars within tars, folders with multiple tars, etc. Additional comment cleanup.

pull/323/head
Ross Wightman 4 years ago
parent 55f7dfa9ea
commit 5d4c3d0af3

parser_factory.py

@@ -2,7 +2,7 @@ import os
 from .parser_image_folder import ParserImageFolder
 from .parser_image_tar import ParserImageTar
-from .parser_image_class_in_tar import ParserImageClassInTar
+from .parser_image_in_tar import ParserImageInTar


 def create_parser(name, root, split='train', **kwargs):
@@ -23,7 +23,7 @@ def create_parser(name, root, split='train', **kwargs):
         # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
         # FIXME support split here, in parser?
         if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
-            parser = ParserImageTar(root, **kwargs)
+            parser = ParserImageInTar(root, **kwargs)
         else:
             parser = ParserImageFolder(root, **kwargs)
     return parser
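For orientation, a minimal usage sketch of the changed fallback path; the dataset path is hypothetical and it assumes the factory module is importable as shown. A .tar root now reaches the new ParserImageInTar rather than ParserImageTar.

from timm.data.parsers.parser_factory import create_parser

# empty parser name -> default fallback path from the hunk above
parser = create_parser('', root='/data/imagenet-train.tar')   # hypothetical tar path
fileobj, target = parser[0]   # open file object for the image bytes, plus its class index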

parser_image_class_in_tar.py (deleted file)

@@ -1,107 +0,0 @@
import os
import tarfile
import pickle
from glob import glob
import numpy as np

from timm.utils.misc import natural_key

from .parser import Parser
from .class_map import load_class_map
from .constants import IMG_EXTENSIONS


def extract_tarinfos(root, class_name_to_idx=None, cache_filename=None, extensions=None):
    tar_filenames = glob(os.path.join(root, '*.tar'), recursive=True)
    assert len(tar_filenames)
    num_tars = len(tar_filenames)

    cache_path = ''
    if cache_filename is not None:
        cache_path = os.path.join(root, cache_filename)
    if os.path.exists(cache_path):
        with open(cache_path, 'rb') as pf:
            tarinfo_map = pickle.load(pf)
    else:
        tarinfo_map = {}
        for fi, fn in enumerate(tar_filenames):
            if fi % 1000 == 0:
                print(f'DEBUG: tar {fi}/{num_tars}')
            # cannot keep this open across processes, reopen later
            name = os.path.splitext(os.path.basename(fn))[0]
            with tarfile.open(fn) as tf:
                if extensions is None:
                    # assume all files are valid samples
                    class_tarinfos = tf.getmembers()
                else:
                    class_tarinfos = [m for m in tf.getmembers() if os.path.splitext(m.name)[1].lower() in extensions]
                tarinfo_map[name] = dict(tarinfos=class_tarinfos)
            print(f'DEBUG: {len(class_tarinfos)} images for class {name}')
        tarinfo_map = {k: v for k, v in sorted(tarinfo_map.items(), key=lambda k: natural_key(k[0]))}
        if cache_path:
            with open(cache_path, 'wb') as pf:
                pickle.dump(tarinfo_map, pf, protocol=pickle.HIGHEST_PROTOCOL)

    tarinfos = []
    targets = []
    build_class_map = False
    if class_name_to_idx is None:
        class_name_to_idx = {}
        build_class_map = True
    for i, (name, metadata) in enumerate(tarinfo_map.items()):
        class_idx = i
        if build_class_map:
            class_name_to_idx[name] = i
        else:
            if name not in class_name_to_idx:
                # only samples with class in class mapping are added
                continue
            class_idx = class_name_to_idx[name]
        num_samples = len(metadata['tarinfos'])
        tarinfos.extend(metadata['tarinfos'])
        targets.extend([class_idx] * num_samples)
    return tarinfos, np.array(targets), class_name_to_idx


class ParserImageClassInTar(Parser):
    """ Multi-tarfile dataset parser where there is one .tar file per class
    """

    CACHE_FILENAME = '_tarinfos.pickle'

    def __init__(self, root, class_map=''):
        super().__init__()

        class_name_to_idx = None
        if class_map:
            class_name_to_idx = load_class_map(class_map, root)
        assert os.path.isdir(root)
        self.root = root
        self.tarinfos, self.targets, self.class_name_to_idx = extract_tarinfos(
            self.root, class_name_to_idx=class_name_to_idx,
            cache_filename=self.CACHE_FILENAME, extensions=IMG_EXTENSIONS)
        self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()}
        self.tarfiles = {}  # to open lazily
        self.cache_tarfiles = False

    def __len__(self):
        return len(self.tarinfos)

    def __getitem__(self, index):
        tarinfo = self.tarinfos[index]
        target = self.targets[index]
        class_name = self.class_idx_to_name[target]
        if self.cache_tarfiles:
            tf = self.tarfiles.setdefault(
                class_name, tarfile.open(os.path.join(self.root, class_name + '.tar')))
        else:
            tf = tarfile.open(os.path.join(self.root, class_name + '.tar'))
        fileobj = tf.extractfile(tarinfo)
        return fileobj, target

    def _filename(self, index, basename=False, absolute=False):
        filename = self.tarinfos[index].name
        if basename:
            filename = os.path.basename(filename)
        return filename
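For reference, the one-.tar-per-class layout handled by this deleted parser is still covered by its replacement; a hedged sketch, with hypothetical folder and tar names:

from timm.data.parsers.parser_image_in_tar import ParserImageInTar

# /data/flowers is assumed to contain daisy.tar, rose.tar, ... with images inside each tar;
# the tar names become the class labels, matching the behaviour of the deleted parser
parser = ParserImageInTar('/data/flowers')
print(parser.class_name_to_idx)   # e.g. {'daisy': 0, 'rose': 1}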

parser_image_folder.py

@@ -1,6 +1,11 @@
+""" A dataset parser that reads images from folders
+
+Folders are scanned recursively to find image files. Labels are based
+on the folder hierarchy, just leaf folders by default.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
 import os
-import io
-import torch

 from timm.utils.misc import natural_key

parser_image_in_tar.py (new file)

@@ -0,0 +1,219 @@
""" A dataset parser that reads tarfile based datasets

This parser can read and extract image samples from:
* a single tar of image files
* a folder of multiple tarfiles containing image files
* a tar of tars containing image files

Labels are based on the combined folder and/or tar name structure.

Hacked together by / Copyright 2020 Ross Wightman
"""
import os
import tarfile
import pickle
import logging
import numpy as np
from glob import glob
from typing import List, Dict

from timm.utils.misc import natural_key

from .parser import Parser
from .class_map import load_class_map
from .constants import IMG_EXTENSIONS


_logger = logging.getLogger(__name__)


CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'


class TarState:

    def __init__(self, tf: tarfile.TarFile = None, ti: tarfile.TarInfo = None):
        self.tf: tarfile.TarFile = tf
        self.ti: tarfile.TarInfo = ti
        self.children: Dict[str, TarState] = {}  # child states (tars within tars)

    def reset(self):
        self.tf = None


def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTENSIONS):
    sample_count = 0
    for i, ti in enumerate(tf):
        if not ti.isfile():
            continue
        dirname, basename = os.path.split(ti.path)
        name, ext = os.path.splitext(basename)
        ext = ext.lower()
        if ext == '.tar':
            with tarfile.open(fileobj=tf.extractfile(ti), mode='r|') as ctf:
                child_info = dict(
                    name=ti.name, path=os.path.join(parent_info['path'], name), ti=ti, children=[], samples=[])
                sample_count += _extract_tarinfo(ctf, child_info, extensions=extensions)
                _logger.debug(f'{i}/?. Extracted child tarinfos from {ti.name}. {len(child_info["samples"])} images.')
            parent_info['children'].append(child_info)
        elif ext in extensions:
            parent_info['samples'].append(ti)
            sample_count += 1
    return sample_count


def extract_tarinfos(root, class_name_to_idx=None, cache_tarinfo=None, extensions=IMG_EXTENSIONS, sort=True):
    root_is_tar = False
    if os.path.isfile(root):
        assert os.path.splitext(root)[-1].lower() == '.tar'
        tar_filenames = [root]
        root, root_name = os.path.split(root)
        root_name = os.path.splitext(root_name)[0]
        root_is_tar = True
    else:
        root_name = root.strip(os.path.sep).split(os.path.sep)[-1]
        tar_filenames = glob(os.path.join(root, '*.tar'), recursive=True)
    num_tars = len(tar_filenames)
    tar_bytes = sum([os.path.getsize(f) for f in tar_filenames])
    assert num_tars, f'No .tar files found at specified path ({root}).'

    _logger.info(f'Scanning {tar_bytes/1024**2:.2f}MB of tar files...')
    info = dict(tartrees=[])
    cache_path = ''
    if cache_tarinfo is None:
        cache_tarinfo = True if tar_bytes > 10*1024**3 else False  # FIXME magic number, 10GB
    if cache_tarinfo:
        cache_filename = '_' + root_name + CACHE_FILENAME_SUFFIX
        cache_path = os.path.join(root, cache_filename)
    if os.path.exists(cache_path):
        _logger.info(f'Reading tar info from cache file {cache_path}.')
        with open(cache_path, 'rb') as pf:
            info = pickle.load(pf)
        assert len(info['tartrees']) == num_tars, "Cached tartree len doesn't match number of tarfiles"
    else:
        for i, fn in enumerate(tar_filenames):
            path = '' if root_is_tar else os.path.splitext(os.path.basename(fn))[0]
            with tarfile.open(fn, mode='r|') as tf:  # tarinfo scans done in streaming mode
                parent_info = dict(name=os.path.relpath(fn, root), path=path, ti=None, children=[], samples=[])
                num_samples = _extract_tarinfo(tf, parent_info, extensions=extensions)
                num_children = len(parent_info["children"])
                _logger.debug(
                    f'{i}/{num_tars}. Extracted tarinfos from {fn}. {num_children} children, {num_samples} samples.')
            info['tartrees'].append(parent_info)
        if cache_path:
            _logger.info(f'Writing tar info to cache file {cache_path}.')
            with open(cache_path, 'wb') as pf:
                pickle.dump(info, pf)

    samples = []
    labels = []
    build_class_map = False
    if class_name_to_idx is None:
        build_class_map = True

    # Flatten tartree info into lists of samples and targets w/ targets based on label id via
    # class map arg or from unique paths.
    # NOTE: currently only flattening up to two levels, filesystem .tars and then one level of sub-tar children
    # this covers my current use cases and keeps things a little easier to test for now.
    tarfiles = []

    def _label_from_paths(*path, leaf_only=True):
        path = os.path.join(*path).strip(os.path.sep)
        return path.split(os.path.sep)[-1] if leaf_only else path.replace(os.path.sep, '_')

    def _add_samples(info, fn):
        added = 0
        for s in info['samples']:
            label = _label_from_paths(info['path'], os.path.dirname(s.path))
            if not build_class_map and label not in class_name_to_idx:
                continue
            samples.append((s, fn, info['ti']))
            labels.append(label)
            added += 1
        return added

    _logger.info(f'Collecting samples and building tar states.')
    for parent_info in info['tartrees']:
        # if tartree has children, we assume all samples are at the child level
        tar_name = None if root_is_tar else parent_info['name']
        tar_state = TarState()
        parent_added = 0
        for child_info in parent_info['children']:
            child_added = _add_samples(child_info, fn=tar_name)
            if child_added:
                tar_state.children[child_info['name']] = TarState(ti=child_info['ti'])
            parent_added += child_added
        parent_added += _add_samples(parent_info, fn=tar_name)
        if parent_added:
            tarfiles.append((tar_name, tar_state))
    del info

    if build_class_map:
        # build class index
        sorted_labels = list(sorted(set(labels), key=natural_key))
        class_name_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}

    _logger.info(f'Mapping targets and sorting samples.')
    samples_and_targets = [(s, class_name_to_idx[l]) for s, l in zip(samples, labels) if l in class_name_to_idx]
    if sort:
        samples_and_targets = sorted(samples_and_targets, key=lambda k: natural_key(k[0][0].path))

    _logger.info(f'Finished processing {len(samples_and_targets)} samples across {len(tarfiles)} tar files.')
    return samples_and_targets, class_name_to_idx, tarfiles


class ParserImageInTar(Parser):
    """ Multi-tarfile dataset parser that can read images from a single tar, a folder of tars,
    or tars within tars. Labels come from the combined folder and/or tar name structure.
    """
    def __init__(self, root, class_map='', cache_tarfiles=True, cache_tarinfo=None):
        super().__init__()

        class_name_to_idx = None
        if class_map:
            class_name_to_idx = load_class_map(class_map, root)
        self.root = root
        self.samples_and_targets, self.class_name_to_idx, tarfiles = extract_tarinfos(
            self.root,
            class_name_to_idx=class_name_to_idx,
            cache_tarinfo=cache_tarinfo,
            extensions=IMG_EXTENSIONS)
        self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()}
        if len(tarfiles) == 1 and tarfiles[0][0] is None:
            self.root_is_tar = True
            self.tar_state = tarfiles[0][1]
        else:
            self.root_is_tar = False
            self.tar_state = dict(tarfiles)
        self.cache_tarfiles = cache_tarfiles

    def __len__(self):
        return len(self.samples_and_targets)

    def __getitem__(self, index):
        sample, target = self.samples_and_targets[index]
        sample_ti, parent_fn, child_ti = sample
        parent_abs = os.path.join(self.root, parent_fn) if parent_fn else self.root

        tf = None
        cache_state = None
        if self.cache_tarfiles:
            cache_state = self.tar_state if self.root_is_tar else self.tar_state[parent_fn]
            tf = cache_state.tf
        if tf is None:
            tf = tarfile.open(parent_abs)
            if self.cache_tarfiles:
                cache_state.tf = tf
        if child_ti is not None:
            ctf = cache_state.children[child_ti.name].tf if self.cache_tarfiles else None
            if ctf is None:
                ctf = tarfile.open(fileobj=tf.extractfile(child_ti))
                if self.cache_tarfiles:
                    cache_state.children[child_ti.name].tf = ctf
            tf = ctf
        return tf.extractfile(sample_ti), target

    def _filename(self, index, basename=False, absolute=False):
        filename = self.samples_and_targets[index][0][0].name
        if basename:
            filename = os.path.basename(filename)
        return filename
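As a self-contained sanity check of the tars-within-tars path, a sketch that builds a tiny tar-of-tars and reads it back through the new parser; all names and the fake image bytes are made up, and it assumes a timm checkout that includes this module.

import io
import os
import tarfile
import tempfile

from timm.data.parsers.parser_image_in_tar import ParserImageInTar


def add_bytes(tf, name, data):
    # helper: add an in-memory payload to an open tar as a regular file member
    ti = tarfile.TarInfo(name)
    ti.size = len(data)
    tf.addfile(ti, io.BytesIO(data))


tmp_dir = tempfile.mkdtemp()

# child tar holding one (fake) image under a class folder
child_buf = io.BytesIO()
with tarfile.open(fileobj=child_buf, mode='w') as ctf:
    add_bytes(ctf, 'class_a/img_0.jpg', b'not-really-a-jpeg')

# wrap the child tar inside a parent tar on disk -> the 'tar of tars' layout
root_tar = os.path.join(tmp_dir, 'dataset.tar')
with tarfile.open(root_tar, mode='w') as tf:
    add_bytes(tf, 'class_a.tar', child_buf.getvalue())

parser = ParserImageInTar(root_tar, cache_tarfiles=True)
print(len(parser))              # 1 sample found
fileobj, target = parser[0]
print(fileobj.read(), target)   # b'not-really-a-jpeg' 0

In normal use the parser is reached via create_parser/ImageDataset and the returned file object is decoded with PIL; reading the raw bytes here just avoids needing a real JPEG.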

parser_image_tar.py

@@ -1,3 +1,10 @@
+""" A dataset parser that reads single tarfile based datasets
+
+This parser can read datasets consisting of a single tarfile containing images.
+I am planning to deprecate it in favour of ParserImageInTar.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
 import os
 import tarfile
@@ -31,6 +38,8 @@ def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
 class ParserImageTar(Parser):
     """ Single tarfile dataset where classes are mapped to folders within tar
+    NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can
+    operate on folders of tars or tars in tars.
     """
     def __init__(self, root, class_map=''):
         super().__init__()

parser_tfds.py

@@ -37,14 +37,14 @@ class ParserTfds(Parser):
       dataloader workers, the train iterator wraps to avoid returning partial batches that trigger drop_last
       https://github.com/pytorch/pytorch/issues/33413
     * With PyTorch IterableDatasets, each worker in each replica operates in isolation, the final batch
-      from each worker could be a different size. For training this is avoid by option above, for
-      validation extra samples are inserted iff distributed mode is enabled so the batches being reduced
-      across replicas are of same size. This will slightlyalter the results, distributed validation will not be
+      from each worker could be a different size. For training this is worked around by the option above, for
+      validation extra samples are inserted iff distributed mode is enabled so that the batches being reduced
+      across replicas are of the same size. This will slightly alter the results, distributed validation will not be
       100% correct. This is similar to common handling in DistributedSampler for normal Datasets but a bit worse
-      since there are to N * J extra samples.
+      since there are up to N * J extra samples with IterableDatasets.
     * The sharding (splitting of dataset into TFRecord) files imposes limitations on the number of
       replicas and dataloader workers you can use. For really small datasets that only contain a few shards
-      you may have to train non-distributed w/ 1-2 dataloader workers. This may not be a huge concern as the
+      you may have to train non-distributed w/ 1-2 dataloader workers. This is likely not a huge concern as the
       benefit of distributed training or fast dataloading should be much less for small datasets.
     * This wrapper is currently configured to return individual, decompressed image samples from the TFDS
       dataset. The augmentation (transforms) and batching is still done in PyTorch. It would be possible
@@ -64,8 +64,8 @@ class ParserTfds(Parser):
         self.batch_size = batch_size
         self.builder = tfds.builder(name, data_dir=root)
-        # NOTE: please use tfds command line app to download & prepare datasets, I don't want to trigger
-        # it by default here as it's caused issues generating unwanted paths in data directories.
+        # NOTE: please use tfds command line app to download & prepare datasets, I don't want to call
+        # download_and_prepare() by default here as it's caused issues generating unwanted paths.
         self.num_samples = self.builder.info.splits[split].num_examples
         self.ds = None  # initialized lazily on each dataloader worker process
@@ -102,7 +102,7 @@ class ParserTfds(Parser):
         """
         InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used.
         My understanding is that using split, the underling TFRecord files will shuffle (shuffle_files=True)
-        between the splits each iteration but that could be wrong.
+        between the splits each iteration, but that understanding could be wrong.
         Possible split options include:
         * InputContext for both distributed & worker processes (current)
         * InputContext for distributed and sub-splits for worker processes
@@ -154,7 +154,7 @@ class ParserTfds(Parser):
             sample_count += 1
             if self.is_training and sample_count >= target_sample_count:
                 # Need to break out of loop when repeat() is enabled for training w/ oversampling
-                # this results in 'extra' samples per epoch but seems more desirable than dropping
+                # this results in extra samples per epoch but seems more desirable than dropping
                 # up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes)
                 break
         if not self.is_training and self.dist_num_replicas and 0 < sample_count < target_sample_count:
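To make the 'up to N * J extra samples' note above concrete, a back-of-envelope sketch with illustrative numbers (a rough model of the padding described, not code from this file):

import math

num_samples = 50_000    # e.g. a validation split
num_replicas = 8        # N distributed processes
num_workers = 4         # J dataloader workers per replica

# each of the N * J worker pipelines is padded up to the same length, so the padded
# total exceeds the true sample count by at most N * J - 1 duplicated samples
per_pipeline = math.ceil(num_samples / (num_replicas * num_workers))
padded_total = per_pipeline * num_replicas * num_workers
print(padded_total - num_samples)   # 16 extra samples with these numbers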

train.py

@@ -283,7 +283,7 @@ def _parse_args():
 def main():
-    setup_default_logging(log_path='./train.log')
+    setup_default_logging()
     args, args_text = _parse_args()
     args.prefetcher = not args.no_prefetcher
