@@ -29,6 +29,11 @@ SHUFFLE_SIZE = 16834  # samples to shuffle in DS queue
 PREFETCH_SIZE = 4096  # samples to prefetch
 
 
+def even_split_indices(split, n, num_samples):
+    partitions = [round(i * num_samples / n) for i in range(n + 1)]
+    return [f"{split}[{partitions[i]}:{partitions[i+1]}]" for i in range(n)]
+
+
 class ParserTfds(Parser):
     """ Wrap Tensorflow Datasets for use in PyTorch
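For reference, a quick sketch of what the new even_split_indices helper produces; the split name and sample count below are made up, not taken from the patch:

    subsplits = even_split_indices('validation', 4, 50000)
    # -> ['validation[0:12500]', 'validation[12500:25000]',
    #     'validation[25000:37500]', 'validation[37500:50000]']

Each global reader (distributed replica x dataloader worker) later picks one of these TFDS subsplit strings by its index.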
@@ -63,6 +68,7 @@ class ParserTfds(Parser):
                 "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
         self.batch_size = batch_size
         self.repeats = repeats
+        self.subsplit = None
 
         self.builder = tfds.builder(name, data_dir=root)
         # NOTE: please use tfds command line app to download & prepare datasets, I don't want to call
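As the NOTE above says, the dataset is expected to be downloaded and prepared before this wrapper touches it. A minimal sketch of doing that outside the wrapper, assuming a recent tensorflow_datasets install; the dataset name and data_dir are placeholders:

    import tensorflow_datasets as tfds

    builder = tfds.builder('cifar10', data_dir='/data/tfds')  # placeholder name / path
    builder.download_and_prepare()          # or use the TFDS CLI, e.g. `tfds build cifar10`
    print(builder.info.splits['train'].num_examples)  # 50000 for cifar10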
@@ -96,6 +102,7 @@ class ParserTfds(Parser):
         if worker_info is not None:
             self.worker_info = worker_info
             num_workers = worker_info.num_workers
+            global_num_workers = self.dist_num_replicas * num_workers
             worker_id = worker_info.id
 
             # FIXME I need to spend more time figuring out the best way to distribute/split data across
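The indexing used in the next hunk is easier to see with concrete (made-up) numbers: every (distributed replica, dataloader worker) pair gets a unique global pipeline index.

    dist_num_replicas, num_workers = 2, 3                    # 2 processes x 3 workers, illustrative only
    global_num_workers = dist_num_replicas * num_workers     # 6 independent readers
    for dist_rank in range(dist_num_replicas):
        for worker_id in range(num_workers):
            print(dist_rank * num_workers + worker_id)       # 0,1,2 then 3,4,5, unique per reader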
@@ -115,15 +122,27 @@ class ParserTfds(Parser):
             #     split = split + '[{}:]'.format(start)
             # else:
             #     split = split + '[{}:{}]'.format(start, start + split_size)
 
-            input_context = tf.distribute.InputContext(
-                num_input_pipelines=self.dist_num_replicas * num_workers,
-                input_pipeline_id=self.dist_rank * num_workers + worker_id,
-                num_replicas_in_sync=self.dist_num_replicas  # FIXME does this have any impact?
-            )
-
-        read_config = tfds.ReadConfig(input_context=input_context)
-        ds = self.builder.as_dataset(split=split, shuffle_files=self.shuffle, read_config=read_config)
+            if not self.is_training and '[' not in self.split:
+                # If not training, and split doesn't define a subsplit, manually split the dataset
+                # for more even samples / worker
+                self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[
+                    self.dist_rank * num_workers + worker_id]
+
+        if self.subsplit is None:
+            input_context = tf.distribute.InputContext(
+                num_input_pipelines=self.dist_num_replicas * num_workers,
+                input_pipeline_id=self.dist_rank * num_workers + worker_id,
+                num_replicas_in_sync=self.dist_num_replicas  # FIXME does this arg have any impact?
+            )
+        else:
+            input_context = None
+
+        read_config = tfds.ReadConfig(
+            shuffle_seed=42,
+            shuffle_reshuffle_each_iteration=True,
+            input_context=input_context)
+        ds = self.builder.as_dataset(
+            split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config)
         # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers
         ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
         ds.options().experimental_threading.max_intra_op_parallelism = 1
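One motivation worth spelling out for the eval-path change: file-level sharding via InputContext can hand readers noticeably different sample counts when the record files are unevenly sized, while the index-based subsplit differs by at most one sample per reader. A tiny check with made-up numbers:

    num_samples, global_num_workers = 50000, 6     # illustrative counts only
    parts = [round(i * num_samples / global_num_workers) for i in range(global_num_workers + 1)]
    per_reader = [parts[i + 1] - parts[i] for i in range(global_num_workers)]
    print(per_reader)                              # [8333, 8334, 8333, 8333, 8334, 8333]
    assert max(per_reader) - min(per_reader) <= 1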
@@ -161,8 +180,8 @@ class ParserTfds(Parser):
         if not self.is_training and self.dist_num_replicas and 0 < sample_count < target_sample_count:
             # Validation batch padding only done for distributed training where results are reduced across nodes.
             # For single process case, it won't matter if workers return different batch sizes.
-            # FIXME this needs more testing, possible for sharding / split api to cause differences of > 1?
-            assert target_sample_count - sample_count == 1  # should only be off by 1 or sharding is not optimal
+            # FIXME if using input_context or % based subsplits, sample count can vary by more than +/- 1 and this
+            # approach is not optimal
             yield img, sample['label']  # yield prev sample again
             sample_count += 1
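The padding at the end relies on target_sample_count being the per-reader quota; its computation is not part of this hunk, so the sketch below is an assumption about its shape rather than a copy of the patch (numbers made up):

    import math
    num_samples, global_num_workers = 50000, 6
    target_sample_count = math.ceil(num_samples / global_num_workers)  # 8334 per reader, assumed formula
    # A reader whose slice came up one short (8333) re-yields its last (img, label) once,
    # so every distributed rank contributes the same sample count to the reduced metrics.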