diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/parsers/parser_tfds.py
index 0b12a4db..2ff90b09 100644
--- a/timm/data/parsers/parser_tfds.py
+++ b/timm/data/parsers/parser_tfds.py
@@ -25,8 +25,8 @@ from .parser import Parser
 
 
 MAX_TP_SIZE = 8  # maximum TF threadpool size, only doing jpeg decodes and queuing activities
-SHUFFLE_SIZE = 16834  # samples to shuffle in DS queue
-PREFETCH_SIZE = 4096  # samples to prefetch
+SHUFFLE_SIZE = 20480  # samples to shuffle in DS queue
+PREFETCH_SIZE = 2048  # samples to prefetch
 
 
 def even_split_indices(split, n, num_samples):
@@ -144,14 +144,16 @@ class ParserTfds(Parser):
         ds = self.builder.as_dataset(
             split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config)
         # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers
-        ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
-        ds.options().experimental_threading.max_intra_op_parallelism = 1
+        options = tf.data.Options()
+        options.experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
+        options.experimental_threading.max_intra_op_parallelism = 1
+        ds = ds.with_options(options)
         if self.is_training or self.repeats > 1:
             # to prevent excessive drop_last batch behaviour w/ IterableDatasets
             # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
             ds = ds.repeat()  # allow wrap around and break iteration manually
         if self.shuffle:
-            ds = ds.shuffle(min(self.num_samples // self._num_pipelines, SHUFFLE_SIZE), seed=0)
+            ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=0)
         ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE))
         self.ds = tfds.as_numpy(ds)
 
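For context (not part of the patch): building a fresh `tf.data.Options()` and attaching it with `Dataset.with_options()` is the supported way to apply dataset options, whereas mutating the object returned by `ds.options()` is not a reliable way to do so in recent TF releases. The shuffle change also turns `SHUFFLE_SIZE` from a per-pipeline cap into a global budget divided across pipelines (`self._num_pipelines` in the diff), so the total number of buffered samples no longer grows with the worker and rank count. Below is a minimal sketch of that arithmetic, assuming an ImageNet-1k-sized split and 16 total pipelines; both figures are hypothetical and only illustrate the sizing change.

```python
# Hypothetical figures: ~1.28M train samples spread over 4 ranks * 4 dataloader workers.
num_samples = 1281167
num_pipelines = 16           # stands in for ParserTfds._num_pipelines

SHUFFLE_SIZE_OLD = 16834
SHUFFLE_SIZE_NEW = 20480

# Old sizing: each pipeline gets its own buffer capped at SHUFFLE_SIZE,
# so total buffered samples grow with the number of pipelines.
old_per_pipeline = min(num_samples // num_pipelines, SHUFFLE_SIZE_OLD)   # 16834
old_total = old_per_pipeline * num_pipelines                             # 269344 samples held across workers

# New sizing: SHUFFLE_SIZE is a global budget split across pipelines,
# so total buffered samples stay near SHUFFLE_SIZE regardless of worker count.
new_per_pipeline = min(num_samples, SHUFFLE_SIZE_NEW) // num_pipelines   # 1280
new_total = new_per_pipeline * num_pipelines                             # 20480 samples total

print(old_per_pipeline, old_total, new_per_pipeline, new_total)
```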