|
|
|
@ -25,8 +25,8 @@ from .parser import Parser
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities
|
|
|
|
|
SHUFFLE_SIZE = 16834 # samples to shuffle in DS queue
|
|
|
|
|
PREFETCH_SIZE = 4096 # samples to prefetch
|
|
|
|
|
SHUFFLE_SIZE = 20480 # samples to shuffle in DS queue
|
|
|
|
|
PREFETCH_SIZE = 2048 # samples to prefetch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def even_split_indices(split, n, num_samples):
|
|
|
|
@ -144,14 +144,16 @@ class ParserTfds(Parser):
|
|
|
|
|
ds = self.builder.as_dataset(
|
|
|
|
|
split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config)
|
|
|
|
|
# avoid overloading threading w/ combo fo TF ds threads + PyTorch workers
|
|
|
|
|
ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
|
|
|
|
|
ds.options().experimental_threading.max_intra_op_parallelism = 1
|
|
|
|
|
options = tf.data.Options()
|
|
|
|
|
options.experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
|
|
|
|
|
options.experimental_threading.max_intra_op_parallelism = 1
|
|
|
|
|
ds = ds.with_options(options)
|
|
|
|
|
if self.is_training or self.repeats > 1:
|
|
|
|
|
# to prevent excessive drop_last batch behaviour w/ IterableDatasets
|
|
|
|
|
# see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
|
|
|
|
|
ds = ds.repeat() # allow wrap around and break iteration manually
|
|
|
|
|
if self.shuffle:
|
|
|
|
|
ds = ds.shuffle(min(self.num_samples // self._num_pipelines, SHUFFLE_SIZE), seed=0)
|
|
|
|
|
ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=0)
|
|
|
|
|
ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE))
|
|
|
|
|
self.ds = tfds.as_numpy(ds)
|
|
|
|
|
|
|
|
|
|