Testing TFDS shuffle across epochs

pull/1239/head
Ross Wightman 2 years ago
parent 69e90dcd8c
commit ff0f709c20

@ -162,6 +162,8 @@ class ParserTfds(Parser):
self.worker_seed = 0 # seed unique to each work instance
self.subsplit = None # set when data is distributed across workers using sub-splits
self.ds = None # initialized lazily on each dataloader worker process
self.init_count = 0
self.reinit_each_iter = False # self.is_training # FIXME evaluating shuffle across epochs
def _lazy_init(self):
""" Lazily initialize the dataset.
@ -174,12 +176,10 @@ class ParserTfds(Parser):
instances once it has been initialized. Do not call any dataset methods that can call _lazy_init
before it is passed to dataloader.
"""
worker_info = torch.utils.data.get_worker_info()
# setup input context to split dataset across distributed processes
num_workers = 1
global_worker_id = 0
if worker_info is not None:
if self.worker_info is None:
worker_info = torch.utils.data.get_worker_info()
assert worker_info is not None
self.worker_info = worker_info
self.worker_seed = worker_info.seed
num_workers = worker_info.num_workers
@ -209,6 +209,9 @@ class ParserTfds(Parser):
else:
subsplits = tfds.even_splits(self.split, self.global_num_workers)
self.subsplit = subsplits[global_worker_id]
else:
num_workers = self.worker_info.num_workers
global_worker_id = self.dist_rank * num_workers + self.worker_info.id
input_context = None
if self.global_num_workers > 1 and self.subsplit is None:
@ -219,8 +222,8 @@ class ParserTfds(Parser):
num_replicas_in_sync=self.dist_num_replicas # FIXME does this arg have any impact?
)
read_config = tfds.ReadConfig(
shuffle_seed=self.common_seed,
shuffle_reshuffle_each_iteration=True,
shuffle_seed=self.common_seed + self.init_count,
shuffle_reshuffle_each_iteration=not self.reinit_each_iter,
input_context=input_context)
ds = self.builder.as_dataset(
split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config)
@ -235,12 +238,15 @@ class ParserTfds(Parser):
# see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
ds = ds.repeat() # allow wrap around and break iteration manually
if self.is_training:
ds = ds.shuffle(min(self.num_examples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed)
ds = ds.shuffle(
min(self.num_examples, self.shuffle_size) // self.global_num_workers,
seed=self.worker_seed + self.init_count)
ds = ds.prefetch(min(self.num_examples // self.global_num_workers, self.prefetch_size))
self.ds = tfds.as_numpy(ds)
self.init_count += 1
def __iter__(self):
if self.ds is None:
if self.ds is None or self.reinit_each_iter:
self._lazy_init()
# Compute a rounded up sample count that is used to:

Loading…
Cancel
Save