From 9ec3210c2d03e96de1f3cd48b5ba659911cd173a Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 10 Nov 2021 15:52:09 -0800 Subject: [PATCH] More TFDS parser cleanup, support improved TFDS even_split impl (on tfds-nightly only currently). --- timm/data/parsers/parser_tfds.py | 140 ++++++++++++++++++++----------- 1 file changed, 92 insertions(+), 48 deletions(-) diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/parsers/parser_tfds.py index 2b0cd731..67db6891 100644 --- a/timm/data/parsers/parser_tfds.py +++ b/timm/data/parsers/parser_tfds.py @@ -6,8 +6,6 @@ https://www.tensorflow.org/datasets/catalog/overview#image_classification Hacked together by / Copyright 2020 Ross Wightman """ -import os -import io import math import torch import torch.distributed as dist @@ -17,6 +15,13 @@ try: import tensorflow as tf tf.config.set_visible_devices([], 'GPU') # Hands off my GPU! (or pip install tensorflow-cpu) import tensorflow_datasets as tfds + try: + tfds.even_splits('', 1, drop_remainder=False) # non-buggy even_splits has drop_remainder arg + has_buggy_even_splits = False + except TypeError: + print("Warning: This version of tfds doesn't have the latest even_splits impl. " + "Please update or use tfds-nightly for better fine-grained split behaviour.") + has_buggy_even_splits = True except ImportError as e: print(e) print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.") @@ -25,7 +30,7 @@ from .parser import Parser MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities -SHUFFLE_SIZE = 20480 # samples to shuffle in DS queue +SHUFFLE_SIZE = 16384 # samples to shuffle in DS queue PREFETCH_SIZE = 2048 # samples to prefetch @@ -58,8 +63,34 @@ class ParserTfds(Parser): """ def __init__( - self, root, name, split='train', is_training=False, batch_size=None, - download=False, repeats=0, seed=42): + self, + root, + name, + split='train', + is_training=False, + batch_size=None, + download=False, + repeats=0, + seed=42, + prefetch_size=None, + shuffle_size=None, + max_threadpool_size=None + ): + """ Tensorflow-datasets Wrapper + + Args: + root: root data dir (ie your TFDS_DATA_DIR. not dataset specific sub-dir) + name: tfds dataset name (eg `imagenet2012`) + split: tfds dataset split (can use all TFDS split strings eg `train[:10%]`) + is_training: training mode, shuffle enabled, dataset len rounded by batch_size + batch_size: batch_size to use to unsure total samples % batch_size == 0 in training across all dis nodes + download: download and build TFDS dataset if set, otherwise must use tfds CLI + repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1) + seed: common seed for shard shuffle across all distributed/worker instances + prefetch_size: override default tf.data prefetch buffer size + shuffle_size: override default tf.data shuffle buffer size + max_threadpool_size: override default threadpool size for tf.data + """ super().__init__() self.root = root self.split = split @@ -69,25 +100,33 @@ class ParserTfds(Parser): "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper" self.batch_size = batch_size self.repeats = repeats - self.common_seed = seed # seed across all worker / dist nodes - self.worker_seed = 0 # seed specific to each work instance - self.subsplit = None + self.common_seed = seed # a seed that's fixed across all worker / distributed instances + self.prefetch_size = prefetch_size or PREFETCH_SIZE + self.shuffle_size = shuffle_size or SHUFFLE_SIZE + self.max_threadpool_size = max_threadpool_size or MAX_TP_SIZE + # TFDS builder and split information self.builder = tfds.builder(name, data_dir=root) # NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag if download: self.builder.download_and_prepare() self.split_info = self.builder.info.splits[split] self.num_samples = self.split_info.num_examples - self.ds = None # initialized lazily on each dataloader worker process - self.worker_info = None + # Distributed world state self.dist_rank = 0 self.dist_num_replicas = 1 if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1: self.dist_rank = dist.get_rank() self.dist_num_replicas = dist.get_world_size() + # Attributes that are updated in _lazy_init, including the tf.data pipeline itself + self.global_num_workers = 1 + self.worker_info = None + self.worker_seed = 0 # seed unique to each work instance + self.subsplit = None # set when data is distributed across workers using sub-splits + self.ds = None # initialized lazily on each dataloader worker process + def _lazy_init(self): """ Lazily initialize the dataset. @@ -102,38 +141,44 @@ class ParserTfds(Parser): worker_info = torch.utils.data.get_worker_info() # setup input context to split dataset across distributed processes - global_num_workers = num_workers = 1 - global_worker_id = 1 + num_workers = 1 + global_worker_id = 0 if worker_info is not None: self.worker_info = worker_info self.worker_seed = worker_info.seed num_workers = worker_info.num_workers - global_num_workers = self.dist_num_replicas * num_workers - worker_id = worker_info.id - global_worker_id = self.dist_rank * num_workers + worker_id + self.global_num_workers = self.dist_num_replicas * num_workers + global_worker_id = self.dist_rank * num_workers + worker_info.id - # FIXME verify best sharding approach """ Data sharding InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used. My understanding is that using split, the underling TFRecord files will shuffle (shuffle_files=True) between the splits each iteration, but that understanding could be wrong. - Possible split options include: - * InputContext for both distributed & worker processes (current) - * InputContext for distributed and sub-splits for worker processes - * sub-splits for both + + I am currently using a mix of InputContext shard assignment and fine-grained sub-splits for distributing + the data across workers. For training InputContext is used to assign shards to nodes unless num_shards + in dataset < total number of workers. Otherwise sub-split API is used for datasets without enough shards or + for validation where we can't drop samples and need to avoid minimize uneven splits to avoid padding. """ - can_subsplit = '[' not in self.split # can't subsplit a subsplit - should_subsplit = global_num_workers > 1 and ( - self.split_info.num_shards < global_num_workers or not self.is_training) - if can_subsplit and should_subsplit: - # manually split the dataset w/o sharding for more even samples / worker - self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[global_worker_id] + should_subsplit = self.global_num_workers > 1 and ( + self.split_info.num_shards < self.global_num_workers or not self.is_training) + if should_subsplit: + # split the dataset w/o using sharding for more even samples / worker, can result in less optimal + # read patterns for distributed training (overlap across shards) so better to use InputContext there + if has_buggy_even_splits: + # my even_split workaround doesn't work on subsplits, upgrade tfds! + if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo): + subsplits = even_split_indices(self.split, self.global_num_workers, self.num_samples) + self.subsplit = subsplits[global_worker_id] + else: + subsplits = tfds.even_splits(self.split, self.global_num_workers) + self.subsplit = subsplits[global_worker_id] input_context = None - if global_num_workers > 1 and self.subsplit is None: + if self.global_num_workers > 1 and self.subsplit is None: # set input context to divide shards among distributed replicas input_context = tf.distribute.InputContext( - num_input_pipelines=global_num_workers, + num_input_pipelines=self.global_num_workers, input_pipeline_id=global_worker_id, num_replicas_in_sync=self.dist_num_replicas # FIXME does this arg have any impact? ) @@ -143,10 +188,10 @@ class ParserTfds(Parser): input_context=input_context) ds = self.builder.as_dataset( split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config) - # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers + # avoid overloading threading w/ combo of TF ds threads + PyTorch workers options = tf.data.Options() thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading' - getattr(options, thread_member).private_threadpool_size = max(1, MAX_TP_SIZE // num_workers) + getattr(options, thread_member).private_threadpool_size = max(1, self.max_threadpool_size // num_workers) getattr(options, thread_member).max_intra_op_parallelism = 1 ds = ds.with_options(options) if self.is_training or self.repeats > 1: @@ -154,22 +199,25 @@ class ParserTfds(Parser): # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading ds = ds.repeat() # allow wrap around and break iteration manually if self.is_training: - ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=self.worker_seed) - ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE)) + ds = ds.shuffle(min(self.num_samples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed) + ds = ds.prefetch(min(self.num_samples // self.global_num_workers, self.prefetch_size)) self.ds = tfds.as_numpy(ds) def __iter__(self): if self.ds is None: self._lazy_init() - # compute a rounded up sample count that is used to: + + # Compute a rounded up sample count that is used to: # 1. make batches even cross workers & replicas in distributed validation. # This adds extra samples and will slightly alter validation results. # 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size # batches are produced (underlying tfds iter wraps around) - target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self._num_pipelines) + target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self.global_num_workers) if self.is_training: # round up to nearest batch_size per worker-replica target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size + + # Iterate until exhausted or sample count hits target when training (ds.repeat enabled) sample_count = 0 for sample in self.ds: img = Image.fromarray(sample['image'], mode='RGB') @@ -180,21 +228,17 @@ class ParserTfds(Parser): # this results in extra samples per epoch but seems more desirable than dropping # up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes) break - if not self.is_training and self.dist_num_replicas and 0 < sample_count < target_sample_count: + + # Pad across distributed nodes (make counts equal by adding samples) + if not self.is_training and self.dist_num_replicas > 1 and self.subsplit is not None and \ + 0 < sample_count < target_sample_count: # Validation batch padding only done for distributed training where results are reduced across nodes. # For single process case, it won't matter if workers return different batch sizes. - # FIXME if using input_context or % based subsplits, sample count can vary by more than +/- 1 and this - # approach is not optimal - yield img, sample['label'] # yield prev sample again - sample_count += 1 - - @property - def _num_workers(self): - return 1 if self.worker_info is None else self.worker_info.num_workers - - @property - def _num_pipelines(self): - return self._num_workers * self.dist_num_replicas + # If using input_context or % based splits, sample count can vary significantly across workers and this + # approach should not be used (hence disabled if self.subsplit isn't set). + while sample_count < target_sample_count: + yield img, sample['label'] # yield prev sample again + sample_count += 1 def __len__(self): # this is just an estimate and does not factor in extra samples added to pad batches based on @@ -202,7 +246,7 @@ class ParserTfds(Parser): return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas) def _filename(self, index, basename=False, absolute=False): - assert False, "Not supported" # no random access to samples + assert False, "Not supported" # no random access to samples def filenames(self, basename=False, absolute=False): """ Return all filenames in dataset, overrides base"""