From f42f1df26c5b44321a9cc65aca3f728a89d7479d Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 18 Mar 2021 23:16:14 -0700
Subject: [PATCH] Improve evenness of per-worker split for validation set with TFDS

---
 timm/data/parsers/parser_tfds.py | 41 +++++++++++++++++++++++---------
 1 file changed, 30 insertions(+), 11 deletions(-)

diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/parsers/parser_tfds.py
index 0c2e10c0..0b12a4db 100644
--- a/timm/data/parsers/parser_tfds.py
+++ b/timm/data/parsers/parser_tfds.py
@@ -29,6 +29,11 @@ SHUFFLE_SIZE = 16834  # samples to shuffle in DS queue
 PREFETCH_SIZE = 4096  # samples to prefetch
 
 
+def even_split_indices(split, n, num_samples):
+    partitions = [round(i * num_samples / n) for i in range(n + 1)]
+    return [f"{split}[{partitions[i]}:{partitions[i+1]}]" for i in range(n)]
+
+
 class ParserTfds(Parser):
     """ Wrap Tensorflow Datasets for use in PyTorch
 
@@ -63,6 +68,7 @@ class ParserTfds(Parser):
             "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
         self.batch_size = batch_size
         self.repeats = repeats
+        self.subsplit = None
 
         self.builder = tfds.builder(name, data_dir=root)
         # NOTE: please use tfds command line app to download & prepare datasets, I don't want to call
@@ -96,6 +102,7 @@ class ParserTfds(Parser):
         if worker_info is not None:
             self.worker_info = worker_info
             num_workers = worker_info.num_workers
+            global_num_workers = self.dist_num_replicas * num_workers
             worker_id = worker_info.id
 
             # FIXME I need to spend more time figuring out the best way to distribute/split data across
@@ -115,15 +122,27 @@ class ParserTfds(Parser):
             #     split = split + '[{}:]'.format(start)
             # else:
             #     split = split + '[{}:{}]'.format(start, start + split_size)
-
-        input_context = tf.distribute.InputContext(
-            num_input_pipelines=self.dist_num_replicas * num_workers,
-            input_pipeline_id=self.dist_rank * num_workers + worker_id,
-            num_replicas_in_sync=self.dist_num_replicas  # FIXME does this have any impact?
-        )
-
-        read_config = tfds.ReadConfig(input_context=input_context)
-        ds = self.builder.as_dataset(split=split, shuffle_files=self.shuffle, read_config=read_config)
+            if not self.is_training and '[' not in self.split:
+                # If not training, and split doesn't define a subsplit, manually split the dataset
+                # for more even samples / worker
+                self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[
+                    self.dist_rank * num_workers + worker_id]
+
+        if self.subsplit is None:
+            input_context = tf.distribute.InputContext(
+                num_input_pipelines=self.dist_num_replicas * num_workers,
+                input_pipeline_id=self.dist_rank * num_workers + worker_id,
+                num_replicas_in_sync=self.dist_num_replicas  # FIXME does this arg have any impact?
+            )
+        else:
+            input_context = None
+
+        read_config = tfds.ReadConfig(
+            shuffle_seed=42,
+            shuffle_reshuffle_each_iteration=True,
+            input_context=input_context)
+        ds = self.builder.as_dataset(
+            split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config)
         # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers
         ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
         ds.options().experimental_threading.max_intra_op_parallelism = 1
@@ -161,8 +180,8 @@ class ParserTfds(Parser):
         if not self.is_training and self.dist_num_replicas and 0 < sample_count < target_sample_count:
             # Validation batch padding only done for distributed training where results are reduced across nodes.
             # For single process case, it won't matter if workers return different batch sizes.
-            # FIXME this needs more testing, possible for sharding / split api to cause differences of > 1?
-            assert target_sample_count - sample_count == 1  # should only be off by 1 or sharding is not optimal
+            # FIXME if using input_context or % based subsplits, sample count can vary by more than +/- 1 and this
+            # approach is not optimal
             yield img, sample['label']  # yield prev sample again
             sample_count += 1
 
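Note (illustrative, not part of the patch): the sketch below shows what the new even_split_indices helper produces and how each worker picks its slice. The dataset size (50000), replica count (2), and workers-per-replica (4) are made-up example values.

    # Standalone sketch of the helper added in this patch: it builds TFDS subsplit
    # strings whose sizes differ by at most one sample across all global workers.
    def even_split_indices(split, n, num_samples):
        # n + 1 partition boundaries, rounded to the nearest sample index
        partitions = [round(i * num_samples / n) for i in range(n + 1)]
        return [f"{split}[{partitions[i]}:{partitions[i+1]}]" for i in range(n)]

    # Hypothetical example: 50000 validation samples, 2 replicas x 4 workers = 8 readers
    splits = even_split_indices('validation', 8, 50000)
    print(splits[0], splits[-1])  # validation[0:6250] validation[43750:50000]

    # Each global worker reads only its own contiguous slice, e.g. replica rank 1,
    # local worker id 2 -> splits[1 * 4 + 2] == 'validation[37500:43750]'

Because the boundaries are rounded per index rather than computed with fixed-size shards, the per-worker slices stay within one sample of each other, which is the "evenness" the commit subject refers to.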