Testing TFDS shuffle across epochs

pull/1239/head
Ross Wightman 2 years ago
parent 69e90dcd8c
commit ff0f709c20

@ -162,6 +162,8 @@ class ParserTfds(Parser):
self.worker_seed = 0 # seed unique to each work instance self.worker_seed = 0 # seed unique to each work instance
self.subsplit = None # set when data is distributed across workers using sub-splits self.subsplit = None # set when data is distributed across workers using sub-splits
self.ds = None # initialized lazily on each dataloader worker process self.ds = None # initialized lazily on each dataloader worker process
self.init_count = 0
self.reinit_each_iter = False # self.is_training # FIXME evaluating shuffle across epochs
def _lazy_init(self): def _lazy_init(self):
""" Lazily initialize the dataset. """ Lazily initialize the dataset.
@ -174,12 +176,10 @@ class ParserTfds(Parser):
instances once it has been initialized. Do not call any dataset methods that can call _lazy_init instances once it has been initialized. Do not call any dataset methods that can call _lazy_init
before it is passed to dataloader. before it is passed to dataloader.
""" """
worker_info = torch.utils.data.get_worker_info()
# setup input context to split dataset across distributed processes # setup input context to split dataset across distributed processes
num_workers = 1 if self.worker_info is None:
global_worker_id = 0 worker_info = torch.utils.data.get_worker_info()
if worker_info is not None: assert worker_info is not None
self.worker_info = worker_info self.worker_info = worker_info
self.worker_seed = worker_info.seed self.worker_seed = worker_info.seed
num_workers = worker_info.num_workers num_workers = worker_info.num_workers
@ -209,6 +209,9 @@ class ParserTfds(Parser):
else: else:
subsplits = tfds.even_splits(self.split, self.global_num_workers) subsplits = tfds.even_splits(self.split, self.global_num_workers)
self.subsplit = subsplits[global_worker_id] self.subsplit = subsplits[global_worker_id]
else:
num_workers = self.worker_info.num_workers
global_worker_id = self.dist_rank * num_workers + self.worker_info.id
input_context = None input_context = None
if self.global_num_workers > 1 and self.subsplit is None: if self.global_num_workers > 1 and self.subsplit is None:
@ -219,8 +222,8 @@ class ParserTfds(Parser):
num_replicas_in_sync=self.dist_num_replicas # FIXME does this arg have any impact? num_replicas_in_sync=self.dist_num_replicas # FIXME does this arg have any impact?
) )
read_config = tfds.ReadConfig( read_config = tfds.ReadConfig(
shuffle_seed=self.common_seed, shuffle_seed=self.common_seed + self.init_count,
shuffle_reshuffle_each_iteration=True, shuffle_reshuffle_each_iteration=not self.reinit_each_iter,
input_context=input_context) input_context=input_context)
ds = self.builder.as_dataset( ds = self.builder.as_dataset(
split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config) split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config)
@ -235,12 +238,15 @@ class ParserTfds(Parser):
# see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
ds = ds.repeat() # allow wrap around and break iteration manually ds = ds.repeat() # allow wrap around and break iteration manually
if self.is_training: if self.is_training:
ds = ds.shuffle(min(self.num_examples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed) ds = ds.shuffle(
min(self.num_examples, self.shuffle_size) // self.global_num_workers,
seed=self.worker_seed + self.init_count)
ds = ds.prefetch(min(self.num_examples // self.global_num_workers, self.prefetch_size)) ds = ds.prefetch(min(self.num_examples // self.global_num_workers, self.prefetch_size))
self.ds = tfds.as_numpy(ds) self.ds = tfds.as_numpy(ds)
self.init_count += 1
def __iter__(self): def __iter__(self):
if self.ds is None: if self.ds is None or self.reinit_each_iter:
self._lazy_init() self._lazy_init()
# Compute a rounded up sample count that is used to: # Compute a rounded up sample count that is used to:

Loading…
Cancel
Save