From ba65dfe2c6681404f35a9409f802aba2a226b761 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 9 Nov 2021 22:34:15 -0800 Subject: [PATCH 1/2] Dataset work * support some torchvision datasets * improvements to TFDS wrapper for subsplit handling (fix #942), shuffle seed * add class-map support to train (fix #957) --- timm/data/dataset.py | 20 +++-- timm/data/dataset_factory.py | 122 ++++++++++++++++++++++++++-- timm/data/parsers/class_map.py | 15 ++-- timm/data/parsers/parser_factory.py | 2 +- timm/data/parsers/parser_tfds.py | 69 ++++++++-------- train.py | 21 +++-- validate.py | 4 +- 7 files changed, 190 insertions(+), 63 deletions(-) diff --git a/timm/data/dataset.py b/timm/data/dataset.py index e719f3f6..d3603a23 100644 --- a/timm/data/dataset.py +++ b/timm/data/dataset.py @@ -23,15 +23,17 @@ class ImageDataset(data.Dataset): self, root, parser=None, - class_map='', + class_map=None, load_bytes=False, transform=None, + target_transform=None, ): if parser is None or isinstance(parser, str): parser = create_parser(parser or '', root=root, class_map=class_map) self.parser = parser self.load_bytes = load_bytes self.transform = transform + self.target_transform = target_transform self._consecutive_errors = 0 def __getitem__(self, index): @@ -49,7 +51,9 @@ class ImageDataset(data.Dataset): if self.transform is not None: img = self.transform(img) if target is None: - target = torch.tensor(-1, dtype=torch.long) + target = -1 + elif self.target_transform is not None: + target = self.target_transform(target) return img, target def __len__(self): @@ -71,26 +75,28 @@ class IterableImageDataset(data.IterableDataset): split='train', is_training=False, batch_size=None, - class_map='', - load_bytes=False, repeats=0, + download=False, transform=None, + target_transform=None, ): assert parser is not None if isinstance(parser, str): self.parser = create_parser( - parser, root=root, split=split, is_training=is_training, batch_size=batch_size, repeats=repeats) + parser, root=root, split=split, is_training=is_training, + batch_size=batch_size, repeats=repeats, download=download) else: self.parser = parser self.transform = transform + self.target_transform = target_transform self._consecutive_errors = 0 def __iter__(self): for img, target in self.parser: if self.transform is not None: img = self.transform(img) - if target is None: - target = torch.tensor(-1, dtype=torch.long) + if self.target_transform is not None: + target = self.target_transform(target) yield img, target def __len__(self): diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py index ccc99d5c..03b03cf5 100644 --- a/timm/data/dataset_factory.py +++ b/timm/data/dataset_factory.py @@ -1,7 +1,26 @@ import os +from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST,\ + Places365, ImageNet, ImageFolder +try: + from torchvision.datasets import INaturalist + has_inaturalist = True +except ImportError: + has_inaturalist = False + from .dataset import IterableImageDataset, ImageDataset +_TORCH_BASIC_DS = dict( + cifar10=CIFAR10, + cifar100=CIFAR100, + mnist=MNIST, + qmist=QMNIST, + kmnist=KMNIST, + fashion_mnist=FashionMNIST, +) +_TRAIN_SYNONYM = {'train', 'training'} +_EVAL_SYNONYM = {'val', 'valid', 'validation', 'eval', 'evaluation'} + def _search_split(root, split): # look for sub-folder with name of split in root and use that if it exists @@ -9,22 +28,107 @@ def _search_split(root, split): try_root = os.path.join(root, split_name) if os.path.exists(try_root): return try_root - if split_name == 'validation': - try_root = os.path.join(root, 'val') - if os.path.exists(try_root): - return try_root + + def _try(syn): + for s in syn: + try_root = os.path.join(root, s) + if os.path.exists(try_root): + return try_root + return root + if split_name in _TRAIN_SYNONYM: + root = _try(_TRAIN_SYNONYM) + elif split_name in _EVAL_SYNONYM: + root = _try(_EVAL_SYNONYM) return root -def create_dataset(name, root, split='validation', search_split=True, is_training=False, batch_size=None, **kwargs): +def create_dataset( + name, + root, + split='validation', + search_split=True, + class_map=None, + load_bytes=False, + is_training=False, + download=False, + batch_size=None, + repeats=0, + **kwargs +): + """ Dataset factory method + + In parenthesis after each arg are the type of dataset supported for each arg, one of: + * folder - default, timm folder (or tar) based ImageDataset + * torch - torchvision based datasets + * TFDS - Tensorflow-datasets wrapper in IterabeDataset interface via IterableImageDataset + * all - any of the above + + Args: + name: dataset name, empty is okay for folder based datasets + root: root folder of dataset (all) + split: dataset split (all) + search_split: search for split specific child fold from root so one can specify + `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder) + class_map: specify class -> index mapping via text file or dict (folder) + load_bytes: load data, return images as undecoded bytes (folder) + download: download dataset if not present and supported (TFDS, torch) + is_training: create dataset in train mode, this is different from the split. + For Iterable / TDFS it enables shuffle, ignored for other datasets. (TFDS) + batch_size: batch size hint for (TFDS) + repeats: dataset repeats per iteration i.e. epoch (TFDS) + **kwargs: other args to pass to dataset + + Returns: + Dataset object + """ name = name.lower() - if name.startswith('tfds'): + if name.startswith('torch/'): + name = name.split('/', 2)[-1] + torch_kwargs = dict(root=root, download=download, **kwargs) + if name in _TORCH_BASIC_DS: + ds_class = _TORCH_BASIC_DS[name] + use_train = split in _TRAIN_SYNONYM + ds = ds_class(train=use_train, **torch_kwargs) + elif name == 'inaturalist' or name == 'inat': + assert has_inaturalist, 'Please update to PyTorch 1.10, torchvision 0.11+ for Inaturalist' + target_type = 'full' + split_split = split.split('/') + if len(split_split) > 1: + target_type = split_split[0].split('_') + if len(target_type) == 1: + target_type = target_type[0] + split = split_split[-1] + if split in _TRAIN_SYNONYM: + split = '2021_train' + elif split in _EVAL_SYNONYM: + split = '2021_valid' + ds = INaturalist(version=split, target_type=target_type, **torch_kwargs) + elif name == 'places365': + if split in _TRAIN_SYNONYM: + split = 'train-standard' + elif split in _EVAL_SYNONYM: + split = 'val' + ds = Places365(split=split, **torch_kwargs) + elif name == 'imagenet': + if split in _EVAL_SYNONYM: + split = 'val' + ds = ImageNet(split=split, **torch_kwargs) + elif name == 'image_folder' or name == 'folder': + # in case torchvision ImageFolder is preferred over timm ImageDataset for some reason + if search_split and os.path.isdir(root): + # look for split specific sub-folder in root + root = _search_split(root, split) + ds = ImageFolder(root, **kwargs) + else: + assert False, f"Unknown torchvision dataset {name}" + elif name.startswith('tfds/'): ds = IterableImageDataset( - root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs) + root, parser=name, split=split, is_training=is_training, + download=download, batch_size=batch_size, repeats=repeats, **kwargs) else: # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future - kwargs.pop('repeats', 0) # FIXME currently only Iterable dataset support the repeat multiplier if search_split and os.path.isdir(root): + # look for split specific sub-folder in root root = _search_split(root, split) - ds = ImageDataset(root, parser=name, **kwargs) + ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs) return ds diff --git a/timm/data/parsers/class_map.py b/timm/data/parsers/class_map.py index 9ef4d1fa..6b6fe453 100644 --- a/timm/data/parsers/class_map.py +++ b/timm/data/parsers/class_map.py @@ -1,16 +1,19 @@ import os -def load_class_map(filename, root=''): - class_map_path = filename +def load_class_map(map_or_filename, root=''): + if isinstance(map_or_filename, dict): + assert dict, 'class_map dict must be non-empty' + return map_or_filename + class_map_path = map_or_filename if not os.path.exists(class_map_path): - class_map_path = os.path.join(root, filename) - assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename - class_map_ext = os.path.splitext(filename)[-1].lower() + class_map_path = os.path.join(root, class_map_path) + assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % map_or_filename + class_map_ext = os.path.splitext(map_or_filename)[-1].lower() if class_map_ext == '.txt': with open(class_map_path) as f: class_to_idx = {v.strip(): k for k, v in enumerate(f)} else: - assert False, 'Unsupported class map extension' + assert False, f'Unsupported class map file extension ({class_map_ext}).' return class_to_idx diff --git a/timm/data/parsers/parser_factory.py b/timm/data/parsers/parser_factory.py index 419ffe89..892090ad 100644 --- a/timm/data/parsers/parser_factory.py +++ b/timm/data/parsers/parser_factory.py @@ -17,7 +17,7 @@ def create_parser(name, root, split='train', **kwargs): # explicitly select other options shortly if prefix == 'tfds': from .parser_tfds import ParserTfds # defer tensorflow import - parser = ParserTfds(root, name, split=split, shuffle=kwargs.pop('shuffle', False), **kwargs) + parser = ParserTfds(root, name, split=split, **kwargs) else: assert os.path.exists(root) # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/parsers/parser_tfds.py index 2ff90b09..2b0cd731 100644 --- a/timm/data/parsers/parser_tfds.py +++ b/timm/data/parsers/parser_tfds.py @@ -57,23 +57,28 @@ class ParserTfds(Parser): components. """ - def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None, repeats=0): + def __init__( + self, root, name, split='train', is_training=False, batch_size=None, + download=False, repeats=0, seed=42): super().__init__() self.root = root self.split = split - self.shuffle = shuffle self.is_training = is_training if self.is_training: assert batch_size is not None,\ "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper" self.batch_size = batch_size self.repeats = repeats + self.common_seed = seed # seed across all worker / dist nodes + self.worker_seed = 0 # seed specific to each work instance self.subsplit = None self.builder = tfds.builder(name, data_dir=root) - # NOTE: please use tfds command line app to download & prepare datasets, I don't want to call - # download_and_prepare() by default here as it's caused issues generating unwanted paths. - self.num_samples = self.builder.info.splits[split].num_examples + # NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag + if download: + self.builder.download_and_prepare() + self.split_info = self.builder.info.splits[split] + self.num_samples = self.split_info.num_examples self.ds = None # initialized lazily on each dataloader worker process self.worker_info = None @@ -97,17 +102,18 @@ class ParserTfds(Parser): worker_info = torch.utils.data.get_worker_info() # setup input context to split dataset across distributed processes - split = self.split - num_workers = 1 + global_num_workers = num_workers = 1 + global_worker_id = 1 if worker_info is not None: self.worker_info = worker_info + self.worker_seed = worker_info.seed num_workers = worker_info.num_workers global_num_workers = self.dist_num_replicas * num_workers worker_id = worker_info.id + global_worker_id = self.dist_rank * num_workers + worker_id - # FIXME I need to spend more time figuring out the best way to distribute/split data across - # combo of distributed replicas + dataloader worker processes - """ + # FIXME verify best sharding approach + """ Data sharding InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used. My understanding is that using split, the underling TFRecord files will shuffle (shuffle_files=True) between the splits each iteration, but that understanding could be wrong. @@ -116,44 +122,39 @@ class ParserTfds(Parser): * InputContext for distributed and sub-splits for worker processes * sub-splits for both """ - # split_size = self.num_samples // num_workers - # start = worker_id * split_size - # if worker_id == num_workers - 1: - # split = split + '[{}:]'.format(start) - # else: - # split = split + '[{}:{}]'.format(start, start + split_size) - if not self.is_training and '[' not in self.split: - # If not training, and split doesn't define a subsplit, manually split the dataset - # for more even samples / worker - self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[ - self.dist_rank * num_workers + worker_id] - - if self.subsplit is None: + can_subsplit = '[' not in self.split # can't subsplit a subsplit + should_subsplit = global_num_workers > 1 and ( + self.split_info.num_shards < global_num_workers or not self.is_training) + if can_subsplit and should_subsplit: + # manually split the dataset w/o sharding for more even samples / worker + self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[global_worker_id] + + input_context = None + if global_num_workers > 1 and self.subsplit is None: + # set input context to divide shards among distributed replicas input_context = tf.distribute.InputContext( - num_input_pipelines=self.dist_num_replicas * num_workers, - input_pipeline_id=self.dist_rank * num_workers + worker_id, + num_input_pipelines=global_num_workers, + input_pipeline_id=global_worker_id, num_replicas_in_sync=self.dist_num_replicas # FIXME does this arg have any impact? ) - else: - input_context = None - read_config = tfds.ReadConfig( - shuffle_seed=42, + shuffle_seed=self.common_seed, shuffle_reshuffle_each_iteration=True, input_context=input_context) ds = self.builder.as_dataset( - split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config) + split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config) # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers options = tf.data.Options() - options.experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers) - options.experimental_threading.max_intra_op_parallelism = 1 + thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading' + getattr(options, thread_member).private_threadpool_size = max(1, MAX_TP_SIZE // num_workers) + getattr(options, thread_member).max_intra_op_parallelism = 1 ds = ds.with_options(options) if self.is_training or self.repeats > 1: # to prevent excessive drop_last batch behaviour w/ IterableDatasets # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading ds = ds.repeat() # allow wrap around and break iteration manually - if self.shuffle: - ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=0) + if self.is_training: + ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=self.worker_seed) ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE)) self.ds = tfds.as_numpy(ds) diff --git a/train.py b/train.py index 332dec0c..10d839be 100755 --- a/train.py +++ b/train.py @@ -70,7 +70,7 @@ parser.add_argument('-c', '--config', default='', type=str, metavar='FILE', parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') -# Dataset / Model parameters +# Dataset parameters parser.add_argument('data_dir', metavar='DIR', help='path to dataset') parser.add_argument('--dataset', '-d', metavar='NAME', default='', @@ -79,6 +79,12 @@ parser.add_argument('--train-split', metavar='NAME', default='train', help='dataset train split (default: train)') parser.add_argument('--val-split', metavar='NAME', default='validation', help='dataset validation split (default: validation)') +parser.add_argument('--dataset-download', action='store_true', default=False, + help='Allow download of dataset for torch/ and tfds/ datasets that support it.') +parser.add_argument('--class-map', default='', type=str, metavar='FILENAME', + help='path to class to idx mapping file (default: "")') + +# Model parameters parser.add_argument('--model', default='resnet50', type=str, metavar='MODEL', help='Name of model to train (default: "resnet50"') parser.add_argument('--pretrained', action='store_true', default=False, @@ -484,11 +490,16 @@ def main(): # create the train and eval datasets dataset_train = create_dataset( - args.dataset, - root=args.data_dir, split=args.train_split, is_training=True, - batch_size=args.batch_size, repeats=args.epoch_repeats) + args.dataset, root=args.data_dir, split=args.train_split, is_training=True, + class_map=args.class_map, + download=args.dataset_download, + batch_size=args.batch_size, + repeats=args.epoch_repeats) dataset_eval = create_dataset( - args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size) + args.dataset, root=args.data_dir, split=args.val_split, is_training=False, + class_map=args.class_map, + download=args.dataset_download, + batch_size=args.batch_size) # setup mixup / cutmix collate_fn = None diff --git a/validate.py b/validate.py index 2e18841f..a99e5b5c 100755 --- a/validate.py +++ b/validate.py @@ -48,6 +48,8 @@ parser.add_argument('--dataset', '-d', metavar='NAME', default='', help='dataset type (default: ImageFolder/ImageTar if empty)') parser.add_argument('--split', metavar='NAME', default='validation', help='dataset split (default: validation)') +parser.add_argument('--dataset-download', action='store_true', default=False, + help='Allow download of dataset for torch/ and tfds/ datasets that support it.') parser.add_argument('--model', '-m', metavar='NAME', default='dpn92', help='model architecture (default: dpn92)') parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', @@ -175,7 +177,7 @@ def validate(args): dataset = create_dataset( root=args.data, name=args.dataset, split=args.split, - load_bytes=args.tf_preprocessing, class_map=args.class_map) + download=args.dataset_download, load_bytes=args.tf_preprocessing, class_map=args.class_map) if args.valid_labels: with open(args.valid_labels, 'r') as f: From 9ec3210c2d03e96de1f3cd48b5ba659911cd173a Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 10 Nov 2021 15:52:09 -0800 Subject: [PATCH 2/2] More TFDS parser cleanup, support improved TFDS even_split impl (on tfds-nightly only currently). --- timm/data/parsers/parser_tfds.py | 140 ++++++++++++++++++++----------- 1 file changed, 92 insertions(+), 48 deletions(-) diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/parsers/parser_tfds.py index 2b0cd731..67db6891 100644 --- a/timm/data/parsers/parser_tfds.py +++ b/timm/data/parsers/parser_tfds.py @@ -6,8 +6,6 @@ https://www.tensorflow.org/datasets/catalog/overview#image_classification Hacked together by / Copyright 2020 Ross Wightman """ -import os -import io import math import torch import torch.distributed as dist @@ -17,6 +15,13 @@ try: import tensorflow as tf tf.config.set_visible_devices([], 'GPU') # Hands off my GPU! (or pip install tensorflow-cpu) import tensorflow_datasets as tfds + try: + tfds.even_splits('', 1, drop_remainder=False) # non-buggy even_splits has drop_remainder arg + has_buggy_even_splits = False + except TypeError: + print("Warning: This version of tfds doesn't have the latest even_splits impl. " + "Please update or use tfds-nightly for better fine-grained split behaviour.") + has_buggy_even_splits = True except ImportError as e: print(e) print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.") @@ -25,7 +30,7 @@ from .parser import Parser MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities -SHUFFLE_SIZE = 20480 # samples to shuffle in DS queue +SHUFFLE_SIZE = 16384 # samples to shuffle in DS queue PREFETCH_SIZE = 2048 # samples to prefetch @@ -58,8 +63,34 @@ class ParserTfds(Parser): """ def __init__( - self, root, name, split='train', is_training=False, batch_size=None, - download=False, repeats=0, seed=42): + self, + root, + name, + split='train', + is_training=False, + batch_size=None, + download=False, + repeats=0, + seed=42, + prefetch_size=None, + shuffle_size=None, + max_threadpool_size=None + ): + """ Tensorflow-datasets Wrapper + + Args: + root: root data dir (ie your TFDS_DATA_DIR. not dataset specific sub-dir) + name: tfds dataset name (eg `imagenet2012`) + split: tfds dataset split (can use all TFDS split strings eg `train[:10%]`) + is_training: training mode, shuffle enabled, dataset len rounded by batch_size + batch_size: batch_size to use to unsure total samples % batch_size == 0 in training across all dis nodes + download: download and build TFDS dataset if set, otherwise must use tfds CLI + repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1) + seed: common seed for shard shuffle across all distributed/worker instances + prefetch_size: override default tf.data prefetch buffer size + shuffle_size: override default tf.data shuffle buffer size + max_threadpool_size: override default threadpool size for tf.data + """ super().__init__() self.root = root self.split = split @@ -69,25 +100,33 @@ class ParserTfds(Parser): "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper" self.batch_size = batch_size self.repeats = repeats - self.common_seed = seed # seed across all worker / dist nodes - self.worker_seed = 0 # seed specific to each work instance - self.subsplit = None + self.common_seed = seed # a seed that's fixed across all worker / distributed instances + self.prefetch_size = prefetch_size or PREFETCH_SIZE + self.shuffle_size = shuffle_size or SHUFFLE_SIZE + self.max_threadpool_size = max_threadpool_size or MAX_TP_SIZE + # TFDS builder and split information self.builder = tfds.builder(name, data_dir=root) # NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag if download: self.builder.download_and_prepare() self.split_info = self.builder.info.splits[split] self.num_samples = self.split_info.num_examples - self.ds = None # initialized lazily on each dataloader worker process - self.worker_info = None + # Distributed world state self.dist_rank = 0 self.dist_num_replicas = 1 if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1: self.dist_rank = dist.get_rank() self.dist_num_replicas = dist.get_world_size() + # Attributes that are updated in _lazy_init, including the tf.data pipeline itself + self.global_num_workers = 1 + self.worker_info = None + self.worker_seed = 0 # seed unique to each work instance + self.subsplit = None # set when data is distributed across workers using sub-splits + self.ds = None # initialized lazily on each dataloader worker process + def _lazy_init(self): """ Lazily initialize the dataset. @@ -102,38 +141,44 @@ class ParserTfds(Parser): worker_info = torch.utils.data.get_worker_info() # setup input context to split dataset across distributed processes - global_num_workers = num_workers = 1 - global_worker_id = 1 + num_workers = 1 + global_worker_id = 0 if worker_info is not None: self.worker_info = worker_info self.worker_seed = worker_info.seed num_workers = worker_info.num_workers - global_num_workers = self.dist_num_replicas * num_workers - worker_id = worker_info.id - global_worker_id = self.dist_rank * num_workers + worker_id + self.global_num_workers = self.dist_num_replicas * num_workers + global_worker_id = self.dist_rank * num_workers + worker_info.id - # FIXME verify best sharding approach """ Data sharding InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used. My understanding is that using split, the underling TFRecord files will shuffle (shuffle_files=True) between the splits each iteration, but that understanding could be wrong. - Possible split options include: - * InputContext for both distributed & worker processes (current) - * InputContext for distributed and sub-splits for worker processes - * sub-splits for both + + I am currently using a mix of InputContext shard assignment and fine-grained sub-splits for distributing + the data across workers. For training InputContext is used to assign shards to nodes unless num_shards + in dataset < total number of workers. Otherwise sub-split API is used for datasets without enough shards or + for validation where we can't drop samples and need to avoid minimize uneven splits to avoid padding. """ - can_subsplit = '[' not in self.split # can't subsplit a subsplit - should_subsplit = global_num_workers > 1 and ( - self.split_info.num_shards < global_num_workers or not self.is_training) - if can_subsplit and should_subsplit: - # manually split the dataset w/o sharding for more even samples / worker - self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[global_worker_id] + should_subsplit = self.global_num_workers > 1 and ( + self.split_info.num_shards < self.global_num_workers or not self.is_training) + if should_subsplit: + # split the dataset w/o using sharding for more even samples / worker, can result in less optimal + # read patterns for distributed training (overlap across shards) so better to use InputContext there + if has_buggy_even_splits: + # my even_split workaround doesn't work on subsplits, upgrade tfds! + if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo): + subsplits = even_split_indices(self.split, self.global_num_workers, self.num_samples) + self.subsplit = subsplits[global_worker_id] + else: + subsplits = tfds.even_splits(self.split, self.global_num_workers) + self.subsplit = subsplits[global_worker_id] input_context = None - if global_num_workers > 1 and self.subsplit is None: + if self.global_num_workers > 1 and self.subsplit is None: # set input context to divide shards among distributed replicas input_context = tf.distribute.InputContext( - num_input_pipelines=global_num_workers, + num_input_pipelines=self.global_num_workers, input_pipeline_id=global_worker_id, num_replicas_in_sync=self.dist_num_replicas # FIXME does this arg have any impact? ) @@ -143,10 +188,10 @@ class ParserTfds(Parser): input_context=input_context) ds = self.builder.as_dataset( split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config) - # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers + # avoid overloading threading w/ combo of TF ds threads + PyTorch workers options = tf.data.Options() thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading' - getattr(options, thread_member).private_threadpool_size = max(1, MAX_TP_SIZE // num_workers) + getattr(options, thread_member).private_threadpool_size = max(1, self.max_threadpool_size // num_workers) getattr(options, thread_member).max_intra_op_parallelism = 1 ds = ds.with_options(options) if self.is_training or self.repeats > 1: @@ -154,22 +199,25 @@ class ParserTfds(Parser): # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading ds = ds.repeat() # allow wrap around and break iteration manually if self.is_training: - ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=self.worker_seed) - ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE)) + ds = ds.shuffle(min(self.num_samples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed) + ds = ds.prefetch(min(self.num_samples // self.global_num_workers, self.prefetch_size)) self.ds = tfds.as_numpy(ds) def __iter__(self): if self.ds is None: self._lazy_init() - # compute a rounded up sample count that is used to: + + # Compute a rounded up sample count that is used to: # 1. make batches even cross workers & replicas in distributed validation. # This adds extra samples and will slightly alter validation results. # 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size # batches are produced (underlying tfds iter wraps around) - target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self._num_pipelines) + target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self.global_num_workers) if self.is_training: # round up to nearest batch_size per worker-replica target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size + + # Iterate until exhausted or sample count hits target when training (ds.repeat enabled) sample_count = 0 for sample in self.ds: img = Image.fromarray(sample['image'], mode='RGB') @@ -180,21 +228,17 @@ class ParserTfds(Parser): # this results in extra samples per epoch but seems more desirable than dropping # up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes) break - if not self.is_training and self.dist_num_replicas and 0 < sample_count < target_sample_count: + + # Pad across distributed nodes (make counts equal by adding samples) + if not self.is_training and self.dist_num_replicas > 1 and self.subsplit is not None and \ + 0 < sample_count < target_sample_count: # Validation batch padding only done for distributed training where results are reduced across nodes. # For single process case, it won't matter if workers return different batch sizes. - # FIXME if using input_context or % based subsplits, sample count can vary by more than +/- 1 and this - # approach is not optimal - yield img, sample['label'] # yield prev sample again - sample_count += 1 - - @property - def _num_workers(self): - return 1 if self.worker_info is None else self.worker_info.num_workers - - @property - def _num_pipelines(self): - return self._num_workers * self.dist_num_replicas + # If using input_context or % based splits, sample count can vary significantly across workers and this + # approach should not be used (hence disabled if self.subsplit isn't set). + while sample_count < target_sample_count: + yield img, sample['label'] # yield prev sample again + sample_count += 1 def __len__(self): # this is just an estimate and does not factor in extra samples added to pad batches based on @@ -202,7 +246,7 @@ class ParserTfds(Parser): return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas) def _filename(self, index, basename=False, absolute=False): - assert False, "Not supported" # no random access to samples + assert False, "Not supported" # no random access to samples def filenames(self, basename=False, absolute=False): """ Return all filenames in dataset, overrides base"""