Improve evenness of per-worker split for validation set with TFDS

pull/450/head
Ross Wightman 3 years ago
parent cbcb76d72c
commit f42f1df26c

@ -29,6 +29,11 @@ SHUFFLE_SIZE = 16834 # samples to shuffle in DS queue
PREFETCH_SIZE = 4096 # samples to prefetch
def even_split_indices(split, n, num_samples):
partitions = [round(i * num_samples / n) for i in range(n + 1)]
return [f"{split}[{partitions[i]}:{partitions[i+1]}]" for i in range(n)]
class ParserTfds(Parser):
""" Wrap Tensorflow Datasets for use in PyTorch
@ -63,6 +68,7 @@ class ParserTfds(Parser):
"Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
self.batch_size = batch_size
self.repeats = repeats
self.subsplit = None
self.builder = tfds.builder(name, data_dir=root)
# NOTE: please use tfds command line app to download & prepare datasets, I don't want to call
@ -96,6 +102,7 @@ class ParserTfds(Parser):
if worker_info is not None:
self.worker_info = worker_info
num_workers = worker_info.num_workers
global_num_workers = self.dist_num_replicas * num_workers
worker_id = worker_info.id
# FIXME I need to spend more time figuring out the best way to distribute/split data across
@ -115,15 +122,27 @@ class ParserTfds(Parser):
# split = split + '[{}:]'.format(start)
# else:
# split = split + '[{}:{}]'.format(start, start + split_size)
input_context = tf.distribute.InputContext(
num_input_pipelines=self.dist_num_replicas * num_workers,
input_pipeline_id=self.dist_rank * num_workers + worker_id,
num_replicas_in_sync=self.dist_num_replicas # FIXME does this have any impact?
)
read_config = tfds.ReadConfig(input_context=input_context)
ds = self.builder.as_dataset(split=split, shuffle_files=self.shuffle, read_config=read_config)
if not self.is_training and '[' not in self.split:
# If not training, and split doesn't define a subsplit, manually split the dataset
# for more even samples / worker
self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[
self.dist_rank * num_workers + worker_id]
if self.subsplit is None:
input_context = tf.distribute.InputContext(
num_input_pipelines=self.dist_num_replicas * num_workers,
input_pipeline_id=self.dist_rank * num_workers + worker_id,
num_replicas_in_sync=self.dist_num_replicas # FIXME does this arg have any impact?
)
else:
input_context = None
read_config = tfds.ReadConfig(
shuffle_seed=42,
shuffle_reshuffle_each_iteration=True,
input_context=input_context)
ds = self.builder.as_dataset(
split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config)
# avoid overloading threading w/ combo fo TF ds threads + PyTorch workers
ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
ds.options().experimental_threading.max_intra_op_parallelism = 1
@ -161,8 +180,8 @@ class ParserTfds(Parser):
if not self.is_training and self.dist_num_replicas and 0 < sample_count < target_sample_count:
# Validation batch padding only done for distributed training where results are reduced across nodes.
# For single process case, it won't matter if workers return different batch sizes.
# FIXME this needs more testing, possible for sharding / split api to cause differences of > 1?
assert target_sample_count - sample_count == 1 # should only be off by 1 or sharding is not optimal
# FIXME if using input_context or % based subsplits, sample count can vary by more than +/- 1 and this
# approach is not optimal
yield img, sample['label'] # yield prev sample again
sample_count += 1

Loading…
Cancel
Save