Fix tf.data options setting for newer TF versions

pull/637/head
Ross Wightman 4 years ago
parent 94d4b53352
commit d53e91218e

@ -25,8 +25,8 @@ from .parser import Parser
MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities
SHUFFLE_SIZE = 16834 # samples to shuffle in DS queue SHUFFLE_SIZE = 20480 # samples to shuffle in DS queue
PREFETCH_SIZE = 4096 # samples to prefetch PREFETCH_SIZE = 2048 # samples to prefetch
def even_split_indices(split, n, num_samples): def even_split_indices(split, n, num_samples):
@ -144,14 +144,16 @@ class ParserTfds(Parser):
ds = self.builder.as_dataset( ds = self.builder.as_dataset(
split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config) split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config)
# avoid overloading threading w/ combo fo TF ds threads + PyTorch workers # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers
ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers) options = tf.data.Options()
ds.options().experimental_threading.max_intra_op_parallelism = 1 options.experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
options.experimental_threading.max_intra_op_parallelism = 1
ds = ds.with_options(options)
if self.is_training or self.repeats > 1: if self.is_training or self.repeats > 1:
# to prevent excessive drop_last batch behaviour w/ IterableDatasets # to prevent excessive drop_last batch behaviour w/ IterableDatasets
# see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
ds = ds.repeat() # allow wrap around and break iteration manually ds = ds.repeat() # allow wrap around and break iteration manually
if self.shuffle: if self.shuffle:
ds = ds.shuffle(min(self.num_samples // self._num_pipelines, SHUFFLE_SIZE), seed=0) ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=0)
ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE)) ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE))
self.ds = tfds.as_numpy(ds) self.ds = tfds.as_numpy(ds)

Loading…
Cancel
Save