diff --git a/timm/data/parsers/parser_factory.py b/timm/data/parsers/parser_factory.py
index ce9aa35f..419ffe89 100644
--- a/timm/data/parsers/parser_factory.py
+++ b/timm/data/parsers/parser_factory.py
@@ -2,7 +2,7 @@ import os
 
 from .parser_image_folder import ParserImageFolder
 from .parser_image_tar import ParserImageTar
-from .parser_image_class_in_tar import ParserImageClassInTar
+from .parser_image_in_tar import ParserImageInTar
 
 
 def create_parser(name, root, split='train', **kwargs):
@@ -23,7 +23,7 @@ def create_parser(name, root, split='train', **kwargs):
     # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
     # FIXME support split here, in parser?
     if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
-        parser = ParserImageTar(root, **kwargs)
+        parser = ParserImageInTar(root, **kwargs)
     else:
         parser = ParserImageFolder(root, **kwargs)
     return parser
diff --git a/timm/data/parsers/parser_image_class_in_tar.py b/timm/data/parsers/parser_image_class_in_tar.py
deleted file mode 100644
index f43ff359..00000000
--- a/timm/data/parsers/parser_image_class_in_tar.py
+++ /dev/null
@@ -1,107 +0,0 @@
-import os
-import tarfile
-import pickle
-from glob import glob
-import numpy as np
-
-from timm.utils.misc import natural_key
-
-from .parser import Parser
-from .class_map import load_class_map
-from .constants import IMG_EXTENSIONS
-
-
-def extract_tarinfos(root, class_name_to_idx=None, cache_filename=None, extensions=None):
-    tar_filenames = glob(os.path.join(root, '*.tar'), recursive=True)
-    assert len(tar_filenames)
-    num_tars = len(tar_filenames)
-
-    cache_path = ''
-    if cache_filename is not None:
-        cache_path = os.path.join(root, cache_filename)
-    if os.path.exists(cache_path):
-        with open(cache_path, 'rb') as pf:
-            tarinfo_map = pickle.load(pf)
-    else:
-        tarinfo_map = {}
-        for fi, fn in enumerate(tar_filenames):
-            if fi % 1000 == 0:
-                print(f'DEBUG: tar {fi}/{num_tars}')
-            # cannot keep this open across processes, reopen later
-            name = os.path.splitext(os.path.basename(fn))[0]
-            with tarfile.open(fn) as tf:
-                if extensions is None:
-                    # assume all files are valid samples
-                    class_tarinfos = tf.getmembers()
-                else:
-                    class_tarinfos = [m for m in tf.getmembers() if os.path.splitext(m.name)[1].lower() in extensions]
-                tarinfo_map[name] = dict(tarinfos=class_tarinfos)
-            print(f'DEBUG: {len(class_tarinfos)} images for class {name}')
-        tarinfo_map = {k: v for k, v in sorted(tarinfo_map.items(), key=lambda k: natural_key(k[0]))}
-        if cache_path:
-            with open(cache_path, 'wb') as pf:
-                pickle.dump(tarinfo_map, pf, protocol=pickle.HIGHEST_PROTOCOL)
-
-    tarinfos = []
-    targets = []
-    build_class_map = False
-    if class_name_to_idx is None:
-        class_name_to_idx = {}
-        build_class_map = True
-    for i, (name, metadata) in enumerate(tarinfo_map.items()):
-        class_idx = i
-        if build_class_map:
-            class_name_to_idx[name] = i
-        else:
-            if name not in class_name_to_idx:
-                # only samples with class in class mapping are added
-                continue
-            class_idx = class_name_to_idx[name]
-        num_samples = len(metadata['tarinfos'])
-        tarinfos.extend(metadata['tarinfos'])
-        targets.extend([class_idx] * num_samples)
-
-    return tarinfos, np.array(targets), class_name_to_idx
-
-
-class ParserImageClassInTar(Parser):
-    """ Multi-tarfile dataset parser where there is one .tar file per class
-    """
-
-    CACHE_FILENAME = '_tarinfos.pickle'
-
-    def __init__(self, root, class_map=''):
-        super().__init__()
-
-        class_name_to_idx = None
-        if class_map:
-            class_name_to_idx = load_class_map(class_map, root)
-        assert os.path.isdir(root)
-        self.root = root
-        self.tarinfos, self.targets, self.class_name_to_idx = extract_tarinfos(
-            self.root, class_name_to_idx=class_name_to_idx,
-            cache_filename=self.CACHE_FILENAME, extensions=IMG_EXTENSIONS)
-        self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()}
-        self.tarfiles = {}  # to open lazily
-        self.cache_tarfiles = False
-
-    def __len__(self):
-        return len(self.tarinfos)
-
-    def __getitem__(self, index):
-        tarinfo = self.tarinfos[index]
-        target = self.targets[index]
-        class_name = self.class_idx_to_name[target]
-        if self.cache_tarfiles:
-            tf = self.tarfiles.setdefault(
-                class_name, tarfile.open(os.path.join(self.root, class_name + '.tar')))
-        else:
-            tf = tarfile.open(os.path.join(self.root, class_name + '.tar'))
-        fileobj = tf.extractfile(tarinfo)
-        return fileobj, target
-
-    def _filename(self, index, basename=False, absolute=False):
-        filename = self.tarinfos[index].name
-        if basename:
-            filename = os.path.basename(filename)
-        return filename
diff --git a/timm/data/parsers/parser_image_folder.py b/timm/data/parsers/parser_image_folder.py
index d2a597d9..ed349009 100644
--- a/timm/data/parsers/parser_image_folder.py
+++ b/timm/data/parsers/parser_image_folder.py
@@ -1,6 +1,11 @@
+""" A dataset parser that reads images from folders
+
+Folders are scanned recursively to find image files. Labels are based
+on the folder hierarchy, just leaf folders by default.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
 import os
-import io
-import torch
 
 from timm.utils.misc import natural_key
 
diff --git a/timm/data/parsers/parser_image_in_tar.py b/timm/data/parsers/parser_image_in_tar.py
new file mode 100644
index 00000000..fd561bcb
--- /dev/null
+++ b/timm/data/parsers/parser_image_in_tar.py
@@ -0,0 +1,219 @@
+""" A dataset parser that reads tarfile based datasets
+
+This parser can read and extract image samples from:
+* a single tar of image files
+* a folder of multiple tarfiles containing image files
+* a tar of tars containing image files
+
+Labels are based on the combined folder and/or tar name structure.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import os
+import tarfile
+import pickle
+import logging
+import numpy as np
+from glob import glob
+from typing import List, Dict
+
+from timm.utils.misc import natural_key
+
+from .parser import Parser
+from .class_map import load_class_map
+from .constants import IMG_EXTENSIONS
+
+
+_logger = logging.getLogger(__name__)
+CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
+
+
+class TarState:
+
+    def __init__(self, tf: tarfile.TarFile = None, ti: tarfile.TarInfo = None):
+        self.tf: tarfile.TarFile = tf
+        self.ti: tarfile.TarInfo = ti
+        self.children: Dict[str, TarState] = {}  # child states (tars within tars)
+
+    def reset(self):
+        self.tf = None
+
+
+def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTENSIONS):
+    sample_count = 0
+    for i, ti in enumerate(tf):
+        if not ti.isfile():
+            continue
+        dirname, basename = os.path.split(ti.path)
+        name, ext = os.path.splitext(basename)
+        ext = ext.lower()
+        if ext == '.tar':
+            with tarfile.open(fileobj=tf.extractfile(ti), mode='r|') as ctf:
+                child_info = dict(
+                    name=ti.name, path=os.path.join(parent_info['path'], name), ti=ti, children=[], samples=[])
+                sample_count += _extract_tarinfo(ctf, child_info, extensions=extensions)
+                _logger.debug(f'{i}/?. Extracted child tarinfos from {ti.name}. {len(child_info["samples"])} images.')
+            parent_info['children'].append(child_info)
+        elif ext in extensions:
+            parent_info['samples'].append(ti)
+            sample_count += 1
+    return sample_count
+
+
+def extract_tarinfos(root, class_name_to_idx=None, cache_tarinfo=None, extensions=IMG_EXTENSIONS, sort=True):
+    root_is_tar = False
+    if os.path.isfile(root):
+        assert os.path.splitext(root)[-1].lower() == '.tar'
+        tar_filenames = [root]
+        root, root_name = os.path.split(root)
+        root_name = os.path.splitext(root_name)[0]
+        root_is_tar = True
+    else:
+        root_name = root.strip(os.path.sep).split(os.path.sep)[-1]
+        tar_filenames = glob(os.path.join(root, '*.tar'), recursive=True)
+    num_tars = len(tar_filenames)
+    tar_bytes = sum([os.path.getsize(f) for f in tar_filenames])
+    assert num_tars, f'No .tar files found at specified path ({root}).'
+
+    _logger.info(f'Scanning {tar_bytes/1024**2:.2f}MB of tar files...')
+    info = dict(tartrees=[])
+    cache_path = ''
+    if cache_tarinfo is None:
+        cache_tarinfo = True if tar_bytes > 10*1024**3 else False  # FIXME magic number, 10GB
+    if cache_tarinfo:
+        cache_filename = '_' + root_name + CACHE_FILENAME_SUFFIX
+        cache_path = os.path.join(root, cache_filename)
+    if os.path.exists(cache_path):
+        _logger.info(f'Reading tar info from cache file {cache_path}.')
+        with open(cache_path, 'rb') as pf:
+            info = pickle.load(pf)
+        assert len(info['tartrees']) == num_tars, "Cached tartree len doesn't match number of tarfiles"
+    else:
+        for i, fn in enumerate(tar_filenames):
+            path = '' if root_is_tar else os.path.splitext(os.path.basename(fn))[0]
+            with tarfile.open(fn, mode='r|') as tf:  # tarinfo scans done in streaming mode
+                parent_info = dict(name=os.path.relpath(fn, root), path=path, ti=None, children=[], samples=[])
+                num_samples = _extract_tarinfo(tf, parent_info, extensions=extensions)
+                num_children = len(parent_info["children"])
+                _logger.debug(
+                    f'{i}/{num_tars}. Extracted tarinfos from {fn}. {num_children} children, {num_samples} samples.')
+            info['tartrees'].append(parent_info)
+        if cache_path:
+            _logger.info(f'Writing tar info to cache file {cache_path}.')
+            with open(cache_path, 'wb') as pf:
+                pickle.dump(info, pf)
+
+    samples = []
+    labels = []
+    build_class_map = False
+    if class_name_to_idx is None:
+        build_class_map = True
+
+    # Flatten tartree info into lists of samples and targets w/ targets based on label id via
+    # class map arg or from unique paths.
+    # NOTE: currently only flattening up to two-levels, filesystem .tars and then one level of sub-tar children
+    # this covers my current use cases and keeps things a little easier to test for now.
+    tarfiles = []
+
+    def _label_from_paths(*path, leaf_only=True):
+        path = os.path.join(*path).strip(os.path.sep)
+        return path.split(os.path.sep)[-1] if leaf_only else path.replace(os.path.sep, '_')
+
+    def _add_samples(info, fn):
+        added = 0
+        for s in info['samples']:
+            label = _label_from_paths(info['path'], os.path.dirname(s.path))
+            if not build_class_map and label not in class_name_to_idx:
+                continue
+            samples.append((s, fn, info['ti']))
+            labels.append(label)
+            added += 1
+        return added
+
+    _logger.info(f'Collecting samples and building tar states.')
+    for parent_info in info['tartrees']:
+        # if tartree has children, we assume all samples are at the child level
+        tar_name = None if root_is_tar else parent_info['name']
+        tar_state = TarState()
+        parent_added = 0
+        for child_info in parent_info['children']:
+            child_added = _add_samples(child_info, fn=tar_name)
+            if child_added:
+                tar_state.children[child_info['name']] = TarState(ti=child_info['ti'])
+            parent_added += child_added
+        parent_added += _add_samples(parent_info, fn=tar_name)
+        if parent_added:
+            tarfiles.append((tar_name, tar_state))
+    del info
+
+    if build_class_map:
+        # build class index
+        sorted_labels = list(sorted(set(labels), key=natural_key))
+        class_name_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
+
+    _logger.info(f'Mapping targets and sorting samples.')
+    samples_and_targets = [(s, class_name_to_idx[l]) for s, l in zip(samples, labels) if l in class_name_to_idx]
+    if sort:
+        samples_and_targets = sorted(samples_and_targets, key=lambda k: natural_key(k[0][0].path))
+
+    _logger.info(f'Finished processing {len(samples_and_targets)} samples across {len(tarfiles)} tar files.')
+    return samples_and_targets, class_name_to_idx, tarfiles
+
+
+class ParserImageInTar(Parser):
+    """ Multi-tarfile dataset parser where there is one .tar file per class
+    """
+
+    def __init__(self, root, class_map='', cache_tarfiles=True, cache_tarinfo=None):
+        super().__init__()
+
+        class_name_to_idx = None
+        if class_map:
+            class_name_to_idx = load_class_map(class_map, root)
+        self.root = root
+        self.samples_and_targets, self.class_name_to_idx, tarfiles = extract_tarinfos(
+            self.root,
+            class_name_to_idx=class_name_to_idx,
+            cache_tarinfo=cache_tarinfo,
+            extensions=IMG_EXTENSIONS)
+        self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()}
+        if len(tarfiles) == 1 and tarfiles[0][0] is None:
+            self.root_is_tar = True
+            self.tar_state = tarfiles[0][1]
+        else:
+            self.root_is_tar = False
+            self.tar_state = dict(tarfiles)
+        self.cache_tarfiles = cache_tarfiles
+
+    def __len__(self):
+        return len(self.samples_and_targets)
+
+    def __getitem__(self, index):
+        sample, target = self.samples_and_targets[index]
+        sample_ti, parent_fn, child_ti = sample
+        parent_abs = os.path.join(self.root, parent_fn) if parent_fn else self.root
+
+        tf = None
+        cache_state = None
+        if self.cache_tarfiles:
+            cache_state = self.tar_state if self.root_is_tar else self.tar_state[parent_fn]
+            tf = cache_state.tf
+        if tf is None:
+            tf = tarfile.open(parent_abs)
+            if self.cache_tarfiles:
+                cache_state.tf = tf
+        if child_ti is not None:
+            ctf = cache_state.children[child_ti.name].tf if self.cache_tarfiles else None
+            if ctf is None:
+                ctf = tarfile.open(fileobj=tf.extractfile(child_ti))
+                if self.cache_tarfiles:
+                    cache_state.children[child_ti.name].tf = ctf
+            tf = ctf
+
+        return tf.extractfile(sample_ti), target
+
+    def _filename(self, index, basename=False, absolute=False):
+        filename = self.samples_and_targets[index][0][0].name
+        if basename:
+            filename = os.path.basename(filename)
+        return filename
diff --git a/timm/data/parsers/parser_image_tar.py b/timm/data/parsers/parser_image_tar.py
index 657b56f9..467537f4 100644
--- a/timm/data/parsers/parser_image_tar.py
+++ b/timm/data/parsers/parser_image_tar.py
@@ -1,3 +1,10 @@
+""" A dataset parser that reads single tarfile based datasets
+
+This parser can read datasets consisting of a single tarfile containing images.
+I am planning to deprecate it in favour of ParserImageInTar.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
 import os
 import tarfile
 
@@ -31,6 +38,8 @@ def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
 
 class ParserImageTar(Parser):
     """ Single tarfile dataset where classes are mapped to folders within tar
+    NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can
+    operate on folders of tars or tars in tars.
     """
     def __init__(self, root, class_map=''):
         super().__init__()
diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/parsers/parser_tfds.py
index 39a9243a..15361cb5 100644
--- a/timm/data/parsers/parser_tfds.py
+++ b/timm/data/parsers/parser_tfds.py
@@ -37,14 +37,14 @@ class ParserTfds(Parser):
       dataloader workers, the train iterator wraps to avoid returning partial batches that trigger drop_last
       https://github.com/pytorch/pytorch/issues/33413
     * With PyTorch IterableDatasets, each worker in each replica operates in isolation, the final batch
-      from each worker could be a different size. For training this is avoid by option above, for
-      validation extra samples are inserted iff distributed mode is enabled so the batches being reduced
-      across replicas are of same size. This will slightlyalter the results, distributed validation will not be
+      from each worker could be a different size. For training this is worked around by option above, for
+      validation extra samples are inserted iff distributed mode is enabled so that the batches being reduced
+      across replicas are of same size. This will slightly alter the results, distributed validation will not be
       100% correct. This is similar to common handling in DistributedSampler for normal Datasets but a bit worse
-      since there are to N * J extra samples.
+      since there are up to N * J extra samples with IterableDatasets.
     * The sharding (splitting of dataset into TFRecord) files imposes limitations on the number of replicas and
       dataloader workers you can use. For really small datasets that only contain a few shards
-      you may have to train non-distributed w/ 1-2 dataloader workers. This may not be a huge concern as the
+      you may have to train non-distributed w/ 1-2 dataloader workers. This is likely not a huge concern as the
       benefit of distributed training or fast dataloading should be much less for small datasets.
     * This wrapper is currently configured to return individual, decompressed image samples from the TFDS
       dataset. The augmentation (transforms) and batching is still done in PyTorch. It would be possible
@@ -64,8 +64,8 @@ class ParserTfds(Parser):
         self.batch_size = batch_size
 
         self.builder = tfds.builder(name, data_dir=root)
-        # NOTE: please use tfds command line app to download & prepare datasets, I don't want to trigger
-        # it by default here as it's caused issues generating unwanted paths in data directories.
+        # NOTE: please use tfds command line app to download & prepare datasets, I don't want to call
+        # download_and_prepare() by default here as it's caused issues generating unwanted paths.
         self.num_samples = self.builder.info.splits[split].num_examples
         self.ds = None  # initialized lazily on each dataloader worker process
 
@@ -102,7 +102,7 @@ class ParserTfds(Parser):
         """ InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used.
            My understanding is that using split, the underling TFRecord files will shuffle (shuffle_files=True)
-           between the splits each iteration but that could be wrong.
+           between the splits each iteration, but that understanding could be wrong.
            Possible split options include:
            * InputContext for both distributed & worker processes (current)
            * InputContext for distributed and sub-splits for worker processes
@@ -154,7 +154,7 @@ class ParserTfds(Parser):
             sample_count += 1
             if self.is_training and sample_count >= target_sample_count:
                 # Need to break out of loop when repeat() is enabled for training w/ oversampling
-                # this results in 'extra' samples per epoch but seems more desirable than dropping
+                # this results in extra samples per epoch but seems more desirable than dropping
                 # up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes)
                 break
         if not self.is_training and self.dist_num_replicas and 0 < sample_count < target_sample_count:
diff --git a/train.py b/train.py
index a4010e1f..94c417b4 100755
--- a/train.py
+++ b/train.py
@@ -283,7 +283,7 @@ def _parse_args():
 
 
 def main():
-    setup_default_logging(log_path='./train.log')
+    setup_default_logging()
     args, args_text = _parse_args()
     args.prefetcher = not args.no_prefetcher
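
For anyone reviewing who wants to try the new parser out of tree, here is a minimal usage sketch (not part of the patch). The folder path is a hypothetical directory of per-class .tar files, and decoding with PIL stands in for what the dataset wrapper would normally do with the returned file object; only calls visible in this diff (constructor, `__getitem__`, `_filename`) are used.

```python
# Minimal usage sketch (not part of the diff). '/data/imagenet-tars' is a hypothetical
# folder of per-class .tar files; PIL decoding stands in for the dataset wrapper.
from PIL import Image

from timm.data.parsers.parser_image_in_tar import ParserImageInTar

parser = ParserImageInTar('/data/imagenet-tars')   # e.g. n01440764.tar, n01443537.tar, ...
print(f'{len(parser)} samples across {len(parser.class_name_to_idx)} classes')

fileobj, target = parser[0]                        # __getitem__ returns (file object, class index)
img = Image.open(fileobj).convert('RGB')
print(parser._filename(0, basename=True), target, img.size)
```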