|
|
|
@ -52,7 +52,7 @@ class ParserTfds(Parser):
|
|
|
|
|
components.
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None):
|
|
|
|
|
def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None, repeats=0):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.root = root
|
|
|
|
|
self.split = split
|
|
|
|
@ -62,6 +62,7 @@ class ParserTfds(Parser):
|
|
|
|
|
assert batch_size is not None,\
|
|
|
|
|
"Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
|
|
|
|
|
self.batch_size = batch_size
|
|
|
|
|
self.repeats = repeats
|
|
|
|
|
|
|
|
|
|
self.builder = tfds.builder(name, data_dir=root)
|
|
|
|
|
# NOTE: please use tfds command line app to download & prepare datasets, I don't want to call
|
|
|
|
@ -126,7 +127,7 @@ class ParserTfds(Parser):
|
|
|
|
|
# avoid overloading threading w/ combo fo TF ds threads + PyTorch workers
|
|
|
|
|
ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
|
|
|
|
|
ds.options().experimental_threading.max_intra_op_parallelism = 1
|
|
|
|
|
if self.is_training:
|
|
|
|
|
if self.is_training or self.repeats > 1:
|
|
|
|
|
# to prevent excessive drop_last batch behaviour w/ IterableDatasets
|
|
|
|
|
# see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
|
|
|
|
|
ds = ds.repeat() # allow wrap around and break iteration manually
|
|
|
|
@ -143,7 +144,7 @@ class ParserTfds(Parser):
|
|
|
|
|
# This adds extra samples and will slightly alter validation results.
|
|
|
|
|
# 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
|
|
|
|
|
# batches are produced (underlying tfds iter wraps around)
|
|
|
|
|
target_sample_count = math.ceil(self.num_samples / self._num_pipelines)
|
|
|
|
|
target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self._num_pipelines)
|
|
|
|
|
if self.is_training:
|
|
|
|
|
# round up to nearest batch_size per worker-replica
|
|
|
|
|
target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size
|
|
|
|
@ -176,7 +177,7 @@ class ParserTfds(Parser):
|
|
|
|
|
def __len__(self):
|
|
|
|
|
# this is just an estimate and does not factor in extra samples added to pad batches based on
|
|
|
|
|
# complete worker & replica info (not available until init in dataloader).
|
|
|
|
|
return math.ceil(self.num_samples / self.dist_num_replicas)
|
|
|
|
|
return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas)
|
|
|
|
|
|
|
|
|
|
def _filename(self, index, basename=False, absolute=False):
|
|
|
|
|
assert False, "Not supported" # no random access to samples
|
|
|
|
|