diff --git a/timm/data/dataset.py b/timm/data/dataset.py
index e719f3f6..d3603a23 100644
--- a/timm/data/dataset.py
+++ b/timm/data/dataset.py
@@ -23,15 +23,17 @@ class ImageDataset(data.Dataset):
             self,
             root,
             parser=None,
-            class_map='',
+            class_map=None,
             load_bytes=False,
             transform=None,
+            target_transform=None,
     ):
         if parser is None or isinstance(parser, str):
             parser = create_parser(parser or '', root=root, class_map=class_map)
         self.parser = parser
         self.load_bytes = load_bytes
         self.transform = transform
+        self.target_transform = target_transform
         self._consecutive_errors = 0
 
     def __getitem__(self, index):
@@ -49,7 +51,9 @@ class ImageDataset(data.Dataset):
         if self.transform is not None:
             img = self.transform(img)
         if target is None:
-            target = torch.tensor(-1, dtype=torch.long)
+            target = -1
+        elif self.target_transform is not None:
+            target = self.target_transform(target)
         return img, target
 
     def __len__(self):
@@ -71,26 +75,28 @@ class IterableImageDataset(data.IterableDataset):
             split='train',
             is_training=False,
             batch_size=None,
-            class_map='',
-            load_bytes=False,
             repeats=0,
+            download=False,
             transform=None,
+            target_transform=None,
     ):
         assert parser is not None
         if isinstance(parser, str):
             self.parser = create_parser(
-                parser, root=root, split=split, is_training=is_training, batch_size=batch_size, repeats=repeats)
+                parser, root=root, split=split, is_training=is_training,
+                batch_size=batch_size, repeats=repeats, download=download)
         else:
             self.parser = parser
         self.transform = transform
+        self.target_transform = target_transform
         self._consecutive_errors = 0
 
     def __iter__(self):
         for img, target in self.parser:
             if self.transform is not None:
                 img = self.transform(img)
-            if target is None:
-                target = torch.tensor(-1, dtype=torch.long)
+            if self.target_transform is not None:
+                target = self.target_transform(target)
             yield img, target
 
     def __len__(self):
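For reference, a minimal usage sketch of the new `target_transform` hook and the plain `-1` target now returned for unlabeled samples (the folder path and transform below are hypothetical, not part of this patch):

import torch
from timm.data import ImageDataset

def to_long_tensor(target):
    # labels are now returned as plain ints (unlabeled samples yield -1 and skip this hook),
    # so tensor conversion or label remapping can be done here instead of inside the dataset
    return torch.tensor(target, dtype=torch.long)

dataset = ImageDataset('/data/my-image-folder', target_transform=to_long_tensor)
img, target = dataset[0]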
diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py
index ccc99d5c..03b03cf5 100644
--- a/timm/data/dataset_factory.py
+++ b/timm/data/dataset_factory.py
@@ -1,7 +1,26 @@
 import os
 
+from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST,\
+    Places365, ImageNet, ImageFolder
+try:
+    from torchvision.datasets import INaturalist
+    has_inaturalist = True
+except ImportError:
+    has_inaturalist = False
+
 from .dataset import IterableImageDataset, ImageDataset
 
+_TORCH_BASIC_DS = dict(
+    cifar10=CIFAR10,
+    cifar100=CIFAR100,
+    mnist=MNIST,
+    qmnist=QMNIST,
+    kmnist=KMNIST,
+    fashion_mnist=FashionMNIST,
+)
+_TRAIN_SYNONYM = {'train', 'training'}
+_EVAL_SYNONYM = {'val', 'valid', 'validation', 'eval', 'evaluation'}
+
 
 def _search_split(root, split):
     # look for sub-folder with name of split in root and use that if it exists
@@ -9,22 +28,107 @@ def _search_split(root, split):
     try_root = os.path.join(root, split_name)
     if os.path.exists(try_root):
         return try_root
-    if split_name == 'validation':
-        try_root = os.path.join(root, 'val')
-        if os.path.exists(try_root):
-            return try_root
+
+    def _try(syn):
+        for s in syn:
+            try_root = os.path.join(root, s)
+            if os.path.exists(try_root):
+                return try_root
+        return root
+    if split_name in _TRAIN_SYNONYM:
+        root = _try(_TRAIN_SYNONYM)
+    elif split_name in _EVAL_SYNONYM:
+        root = _try(_EVAL_SYNONYM)
     return root
 
 
-def create_dataset(name, root, split='validation', search_split=True, is_training=False, batch_size=None, **kwargs):
+def create_dataset(
+        name,
+        root,
+        split='validation',
+        search_split=True,
+        class_map=None,
+        load_bytes=False,
+        is_training=False,
+        download=False,
+        batch_size=None,
+        repeats=0,
+        **kwargs
+):
+    """ Dataset factory method
+
+    In parentheses after each arg is the type of dataset supported for that arg, one of:
+      * folder - default, timm folder (or tar) based ImageDataset
+      * torch - torchvision based datasets
+      * TFDS - Tensorflow-datasets wrapper in IterableDataset interface via IterableImageDataset
+      * all - any of the above
+
+    Args:
+        name: dataset name, empty is okay for folder based datasets
+        root: root folder of dataset (all)
+        split: dataset split (all)
+        search_split: search for split specific child folder from root so one can specify
+            `imagenet/` instead of `imagenet/val`, etc on cmd line / config. (folder, torch/folder)
+        class_map: specify class -> index mapping via text file or dict (folder)
+        load_bytes: load data, return images as undecoded bytes (folder)
+        download: download dataset if not present and supported (TFDS, torch)
+        is_training: create dataset in train mode, this is different from the split.
+            For Iterable / TFDS it enables shuffle, ignored for other datasets. (TFDS)
+        batch_size: batch size hint (TFDS)
+        repeats: dataset repeats per iteration i.e. epoch (TFDS)
+        **kwargs: other args to pass to dataset
+
+    Returns:
+        Dataset object
+    """
     name = name.lower()
-    if name.startswith('tfds'):
+    if name.startswith('torch/'):
+        name = name.split('/', 2)[-1]
+        torch_kwargs = dict(root=root, download=download, **kwargs)
+        if name in _TORCH_BASIC_DS:
+            ds_class = _TORCH_BASIC_DS[name]
+            use_train = split in _TRAIN_SYNONYM
+            ds = ds_class(train=use_train, **torch_kwargs)
+        elif name == 'inaturalist' or name == 'inat':
+            assert has_inaturalist, 'Please update to PyTorch 1.10, torchvision 0.11+ for iNaturalist'
+            target_type = 'full'
+            split_split = split.split('/')
+            if len(split_split) > 1:
+                target_type = split_split[0].split('_')
+                if len(target_type) == 1:
+                    target_type = target_type[0]
+                split = split_split[-1]
+            if split in _TRAIN_SYNONYM:
+                split = '2021_train'
+            elif split in _EVAL_SYNONYM:
+                split = '2021_valid'
+            ds = INaturalist(version=split, target_type=target_type, **torch_kwargs)
+        elif name == 'places365':
+            if split in _TRAIN_SYNONYM:
+                split = 'train-standard'
+            elif split in _EVAL_SYNONYM:
+                split = 'val'
+            ds = Places365(split=split, **torch_kwargs)
+        elif name == 'imagenet':
+            if split in _EVAL_SYNONYM:
+                split = 'val'
+            ds = ImageNet(split=split, **torch_kwargs)
+        elif name == 'image_folder' or name == 'folder':
+            # in case torchvision ImageFolder is preferred over timm ImageDataset for some reason
+            if search_split and os.path.isdir(root):
+                # look for split specific sub-folder in root
+                root = _search_split(root, split)
+            ds = ImageFolder(root, **kwargs)
+        else:
+            assert False, f"Unknown torchvision dataset {name}"
+    elif name.startswith('tfds/'):
         ds = IterableImageDataset(
-            root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs)
+            root, parser=name, split=split, is_training=is_training,
+            download=download, batch_size=batch_size, repeats=repeats, **kwargs)
     else:
         # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
-        kwargs.pop('repeats', 0)  # FIXME currently only Iterable dataset support the repeat multiplier
         if search_split and os.path.isdir(root):
+            # look for split specific sub-folder in root
            root = _search_split(root, split)
-        ds = ImageDataset(root, parser=name, **kwargs)
+        ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
     return ds
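A hedged usage sketch of the extended factory (the paths, dataset choices, and class_map values below are illustrative assumptions, not taken from the patch):

from timm.data import create_dataset

# torchvision basic dataset, fetched on first use when download is enabled
train_ds = create_dataset('torch/cifar10', root='/data/cifar10', split='train', download=True)

# default folder (or tar) path via timm's ImageDataset; 'validation' and its synonyms resolve
# to an existing sub-folder through _search_split, and class_map may now be a dict or a .txt file
val_ds = create_dataset('', root='/data/my-dataset', split='validation',
                        class_map={'cat': 0, 'dog': 1})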
diff --git a/timm/data/parsers/class_map.py b/timm/data/parsers/class_map.py
index 9ef4d1fa..6b6fe453 100644
--- a/timm/data/parsers/class_map.py
+++ b/timm/data/parsers/class_map.py
@@ -1,16 +1,19 @@
 import os
 
 
-def load_class_map(filename, root=''):
-    class_map_path = filename
+def load_class_map(map_or_filename, root=''):
+    if isinstance(map_or_filename, dict):
+        assert map_or_filename, 'class_map dict must be non-empty'
+        return map_or_filename
+    class_map_path = map_or_filename
     if not os.path.exists(class_map_path):
-        class_map_path = os.path.join(root, filename)
-        assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename
-    class_map_ext = os.path.splitext(filename)[-1].lower()
+        class_map_path = os.path.join(root, class_map_path)
+        assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % map_or_filename
+    class_map_ext = os.path.splitext(map_or_filename)[-1].lower()
     if class_map_ext == '.txt':
         with open(class_map_path) as f:
             class_to_idx = {v.strip(): k for k, v in enumerate(f)}
     else:
-        assert False, 'Unsupported class map extension'
+        assert False, f'Unsupported class map file extension ({class_map_ext}).'
     return class_to_idx
diff --git a/timm/data/parsers/parser_factory.py b/timm/data/parsers/parser_factory.py
index 419ffe89..892090ad 100644
--- a/timm/data/parsers/parser_factory.py
+++ b/timm/data/parsers/parser_factory.py
@@ -17,7 +17,7 @@ def create_parser(name, root, split='train', **kwargs):
     # explicitly select other options shortly
     if prefix == 'tfds':
         from .parser_tfds import ParserTfds  # defer tensorflow import
-        parser = ParserTfds(root, name, split=split, shuffle=kwargs.pop('shuffle', False), **kwargs)
+        parser = ParserTfds(root, name, split=split, **kwargs)
     else:
         assert os.path.exists(root)
         # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/parsers/parser_tfds.py
index 2ff90b09..67db6891 100644
--- a/timm/data/parsers/parser_tfds.py
+++ b/timm/data/parsers/parser_tfds.py
@@ -6,8 +6,6 @@ https://www.tensorflow.org/datasets/catalog/overview#image_classification
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
-import os
-import io
 import math
 import torch
 import torch.distributed as dist
@@ -17,6 +15,13 @@ try:
     import tensorflow as tf
     tf.config.set_visible_devices([], 'GPU')  # Hands off my GPU! (or pip install tensorflow-cpu)
     import tensorflow_datasets as tfds
+    try:
+        tfds.even_splits('', 1, drop_remainder=False)  # non-buggy even_splits has drop_remainder arg
+        has_buggy_even_splits = False
+    except TypeError:
+        print("Warning: This version of tfds doesn't have the latest even_splits impl. "
+              "Please update or use tfds-nightly for better fine-grained split behaviour.")
+        has_buggy_even_splits = True
 except ImportError as e:
     print(e)
     print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
@@ -25,7 +30,7 @@ from .parser import Parser
 
 
 MAX_TP_SIZE = 8  # maximum TF threadpool size, only doing jpeg decodes and queuing activities
-SHUFFLE_SIZE = 20480  # samples to shuffle in DS queue
+SHUFFLE_SIZE = 16384  # samples to shuffle in DS queue
 PREFETCH_SIZE = 2048  # samples to prefetch
 
 
""" - def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None, repeats=0): + def __init__( + self, + root, + name, + split='train', + is_training=False, + batch_size=None, + download=False, + repeats=0, + seed=42, + prefetch_size=None, + shuffle_size=None, + max_threadpool_size=None + ): + """ Tensorflow-datasets Wrapper + + Args: + root: root data dir (ie your TFDS_DATA_DIR. not dataset specific sub-dir) + name: tfds dataset name (eg `imagenet2012`) + split: tfds dataset split (can use all TFDS split strings eg `train[:10%]`) + is_training: training mode, shuffle enabled, dataset len rounded by batch_size + batch_size: batch_size to use to unsure total samples % batch_size == 0 in training across all dis nodes + download: download and build TFDS dataset if set, otherwise must use tfds CLI + repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1) + seed: common seed for shard shuffle across all distributed/worker instances + prefetch_size: override default tf.data prefetch buffer size + shuffle_size: override default tf.data shuffle buffer size + max_threadpool_size: override default threadpool size for tf.data + """ super().__init__() self.root = root self.split = split - self.shuffle = shuffle self.is_training = is_training if self.is_training: assert batch_size is not None,\ "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper" self.batch_size = batch_size self.repeats = repeats - self.subsplit = None + self.common_seed = seed # a seed that's fixed across all worker / distributed instances + self.prefetch_size = prefetch_size or PREFETCH_SIZE + self.shuffle_size = shuffle_size or SHUFFLE_SIZE + self.max_threadpool_size = max_threadpool_size or MAX_TP_SIZE + # TFDS builder and split information self.builder = tfds.builder(name, data_dir=root) - # NOTE: please use tfds command line app to download & prepare datasets, I don't want to call - # download_and_prepare() by default here as it's caused issues generating unwanted paths. - self.num_samples = self.builder.info.splits[split].num_examples - self.ds = None # initialized lazily on each dataloader worker process + # NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag + if download: + self.builder.download_and_prepare() + self.split_info = self.builder.info.splits[split] + self.num_samples = self.split_info.num_examples - self.worker_info = None + # Distributed world state self.dist_rank = 0 self.dist_num_replicas = 1 if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1: self.dist_rank = dist.get_rank() self.dist_num_replicas = dist.get_world_size() + # Attributes that are updated in _lazy_init, including the tf.data pipeline itself + self.global_num_workers = 1 + self.worker_info = None + self.worker_seed = 0 # seed unique to each work instance + self.subsplit = None # set when data is distributed across workers using sub-splits + self.ds = None # initialized lazily on each dataloader worker process + def _lazy_init(self): """ Lazily initialize the dataset. 
@@ -97,78 +141,83 @@ class ParserTfds(Parser):
         worker_info = torch.utils.data.get_worker_info()
 
         # setup input context to split dataset across distributed processes
-        split = self.split
         num_workers = 1
+        global_worker_id = 0
         if worker_info is not None:
             self.worker_info = worker_info
+            self.worker_seed = worker_info.seed
             num_workers = worker_info.num_workers
-            global_num_workers = self.dist_num_replicas * num_workers
-            worker_id = worker_info.id
+            self.global_num_workers = self.dist_num_replicas * num_workers
+            global_worker_id = self.dist_rank * num_workers + worker_info.id
 
-        # FIXME I need to spend more time figuring out the best way to distribute/split data across
-        # combo of distributed replicas + dataloader worker processes
-        """
+        """ Data sharding
         InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used.
         My understanding is that using split, the underling TFRecord files will shuffle (shuffle_files=True)
         between the splits each iteration, but that understanding could be wrong.
-        Possible split options include:
-        * InputContext for both distributed & worker processes (current)
-        * InputContext for distributed and sub-splits for worker processes
-        * sub-splits for both
+
+        I am currently using a mix of InputContext shard assignment and fine-grained sub-splits for distributing
+        the data across workers. For training InputContext is used to assign shards to nodes unless num_shards
+        in dataset < total number of workers. Otherwise sub-split API is used for datasets without enough shards or
+        for validation where we can't drop samples and need to minimize uneven splits to avoid padding.
         """
-        # split_size = self.num_samples // num_workers
-        # start = worker_id * split_size
-        # if worker_id == num_workers - 1:
-        #     split = split + '[{}:]'.format(start)
-        # else:
-        #     split = split + '[{}:{}]'.format(start, start + split_size)
-        if not self.is_training and '[' not in self.split:
-            # If not training, and split doesn't define a subsplit, manually split the dataset
-            # for more even samples / worker
-            self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[
-                self.dist_rank * num_workers + worker_id]
-
-        if self.subsplit is None:
+        should_subsplit = self.global_num_workers > 1 and (
+            self.split_info.num_shards < self.global_num_workers or not self.is_training)
+        if should_subsplit:
+            # split the dataset w/o using sharding for more even samples / worker, can result in less optimal
+            # read patterns for distributed training (overlap across shards) so better to use InputContext there
+            if has_buggy_even_splits:
+                # my even_split workaround doesn't work on subsplits, upgrade tfds!
+                if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo):
+                    subsplits = even_split_indices(self.split, self.global_num_workers, self.num_samples)
+                    self.subsplit = subsplits[global_worker_id]
+            else:
+                subsplits = tfds.even_splits(self.split, self.global_num_workers)
+                self.subsplit = subsplits[global_worker_id]
+
+        input_context = None
+        if self.global_num_workers > 1 and self.subsplit is None:
+            # set input context to divide shards among distributed replicas
             input_context = tf.distribute.InputContext(
-                num_input_pipelines=self.dist_num_replicas * num_workers,
-                input_pipeline_id=self.dist_rank * num_workers + worker_id,
+                num_input_pipelines=self.global_num_workers,
+                input_pipeline_id=global_worker_id,
                 num_replicas_in_sync=self.dist_num_replicas  # FIXME does this arg have any impact?
             )
-        else:
-            input_context = None
         read_config = tfds.ReadConfig(
-            shuffle_seed=42,
+            shuffle_seed=self.common_seed,
             shuffle_reshuffle_each_iteration=True,
             input_context=input_context)
         ds = self.builder.as_dataset(
-            split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config)
-        # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers
+            split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config)
+        # avoid overloading threading w/ combo of TF ds threads + PyTorch workers
         options = tf.data.Options()
-        options.experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
-        options.experimental_threading.max_intra_op_parallelism = 1
+        thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading'
+        getattr(options, thread_member).private_threadpool_size = max(1, self.max_threadpool_size // num_workers)
+        getattr(options, thread_member).max_intra_op_parallelism = 1
         ds = ds.with_options(options)
         if self.is_training or self.repeats > 1:
             # to prevent excessive drop_last batch behaviour w/ IterableDatasets
             # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
             ds = ds.repeat()  # allow wrap around and break iteration manually
-        if self.shuffle:
-            ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=0)
-        ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE))
+        if self.is_training:
+            ds = ds.shuffle(min(self.num_samples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed)
+        ds = ds.prefetch(min(self.num_samples // self.global_num_workers, self.prefetch_size))
         self.ds = tfds.as_numpy(ds)
 
     def __iter__(self):
         if self.ds is None:
             self._lazy_init()
-        # compute a rounded up sample count that is used to:
+
+        # Compute a rounded up sample count that is used to:
         # 1. make batches even cross workers & replicas in distributed validation.
         #    This adds extra samples and will slightly alter validation results.
         # 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
         #    batches are produced (underlying tfds iter wraps around)
-        target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self._num_pipelines)
+        target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self.global_num_workers)
         if self.is_training:
             # round up to nearest batch_size per worker-replica
             target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size
+
+        # Iterate until exhausted or sample count hits target when training (ds.repeat enabled)
         sample_count = 0
         for sample in self.ds:
             img = Image.fromarray(sample['image'], mode='RGB')
@@ -179,21 +228,17 @@ class ParserTfds(Parser):
             # this results in extra samples per epoch but seems more desirable than dropping
             # up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes)
             break
-        if not self.is_training and self.dist_num_replicas and 0 < sample_count < target_sample_count:
+
+        # Pad across distributed nodes (make counts equal by adding samples)
+        if not self.is_training and self.dist_num_replicas > 1 and self.subsplit is not None and \
+                0 < sample_count < target_sample_count:
             # Validation batch padding only done for distributed training where results are reduced across nodes.
             # For single process case, it won't matter if workers return different batch sizes.
-            # FIXME if using input_context or % based subsplits, sample count can vary by more than +/- 1 and this
-            # approach is not optimal
-            yield img, sample['label']  # yield prev sample again
-            sample_count += 1
-
-    @property
-    def _num_workers(self):
-        return 1 if self.worker_info is None else self.worker_info.num_workers
-
-    @property
-    def _num_pipelines(self):
-        return self._num_workers * self.dist_num_replicas
+            # If using input_context or % based splits, sample count can vary significantly across workers and this
+            # approach should not be used (hence disabled if self.subsplit isn't set).
+            while sample_count < target_sample_count:
+                yield img, sample['label']  # yield prev sample again
+                sample_count += 1
 
     def __len__(self):
         # this is just an estimate and does not factor in extra samples added to pad batches based on
@@ -201,7 +246,7 @@ class ParserTfds(Parser):
         return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas)
 
     def _filename(self, index, basename=False, absolute=False):
-        assert False, "Not supported" # no random access to samples
+        assert False, "Not supported"  # no random access to samples
 
     def filenames(self, basename=False, absolute=False):
         """ Return all filenames in dataset, overrides base"""
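To make the sharding decision in _lazy_init above easier to follow, here is a simplified sketch of the same logic as a standalone function (the function name and arguments are illustrative, not the actual implementation; it assumes a recent tensorflow-datasets with the non-buggy even_splits):

import tensorflow_datasets as tfds

def choose_subsplit(split, num_shards, is_training, dist_rank, dist_world_size, worker_id, num_workers):
    """Return a per-worker sub-split, or None to fall back to tf.distribute.InputContext sharding."""
    global_num_workers = dist_world_size * num_workers
    global_worker_id = dist_rank * num_workers + worker_id
    if global_num_workers > 1 and (num_shards < global_num_workers or not is_training):
        # sub-splits keep per-worker sample counts nearly equal, which matters for validation
        # and for datasets with fewer shards than total workers
        return tfds.even_splits(split, global_num_workers)[global_worker_id]
    # training with enough shards: shard-level assignment gives better sequential read patterns
    return None

# e.g. 2 nodes x 4 dataloader workers on a validation split -> worker 1 of 8 gets the second even sub-split
print(choose_subsplit('validation', num_shards=64, is_training=False,
                      dist_rank=0, dist_world_size=2, worker_id=1, num_workers=4))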
diff --git a/train.py b/train.py
index 332dec0c..10d839be 100755
--- a/train.py
+++ b/train.py
@@ -70,7 +70,7 @@ parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
 
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
 
-# Dataset / Model parameters
+# Dataset parameters
 parser.add_argument('data_dir', metavar='DIR',
                     help='path to dataset')
 parser.add_argument('--dataset', '-d', metavar='NAME', default='',
@@ -79,6 +79,12 @@ parser.add_argument('--train-split', metavar='NAME', default='train',
                     help='dataset train split (default: train)')
 parser.add_argument('--val-split', metavar='NAME', default='validation',
                     help='dataset validation split (default: validation)')
+parser.add_argument('--dataset-download', action='store_true', default=False,
+                    help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
+parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
+                    help='path to class to idx mapping file (default: "")')
+
+# Model parameters
 parser.add_argument('--model', default='resnet50', type=str, metavar='MODEL',
                     help='Name of model to train (default: "resnet50"')
 parser.add_argument('--pretrained', action='store_true', default=False,
@@ -484,11 +490,16 @@ def main():
 
     # create the train and eval datasets
     dataset_train = create_dataset(
-        args.dataset,
-        root=args.data_dir, split=args.train_split, is_training=True,
-        batch_size=args.batch_size, repeats=args.epoch_repeats)
+        args.dataset, root=args.data_dir, split=args.train_split, is_training=True,
+        class_map=args.class_map,
+        download=args.dataset_download,
+        batch_size=args.batch_size,
+        repeats=args.epoch_repeats)
     dataset_eval = create_dataset(
-        args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size)
+        args.dataset, root=args.data_dir, split=args.val_split, is_training=False,
+        class_map=args.class_map,
+        download=args.dataset_download,
+        batch_size=args.batch_size)
 
     # setup mixup / cutmix
     collate_fn = None
diff --git a/validate.py b/validate.py
index 2e18841f..a99e5b5c 100755
--- a/validate.py
+++ b/validate.py
@@ -48,6 +48,8 @@ parser.add_argument('--dataset', '-d', metavar='NAME', default='',
                     help='dataset type (default: ImageFolder/ImageTar if empty)')
 parser.add_argument('--split', metavar='NAME', default='validation',
                     help='dataset split (default: validation)')
+parser.add_argument('--dataset-download', action='store_true', default=False,
+                    help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
 parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
                     help='model architecture (default: dpn92)')
 parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
@@ -175,7 +177,7 @@ def validate(args):
 
     dataset = create_dataset(
         root=args.data, name=args.dataset, split=args.split,
-        load_bytes=args.tf_preprocessing, class_map=args.class_map)
+        download=args.dataset_download, load_bytes=args.tf_preprocessing, class_map=args.class_map)
 
     if args.valid_labels:
         with open(args.valid_labels, 'r') as f: