|
|
|
@ -162,6 +162,8 @@ class ParserTfds(Parser):
|
|
|
|
|
self.worker_seed = 0 # seed unique to each work instance
|
|
|
|
|
self.subsplit = None # set when data is distributed across workers using sub-splits
|
|
|
|
|
self.ds = None # initialized lazily on each dataloader worker process
|
|
|
|
|
self.init_count = 0
|
|
|
|
|
self.reinit_each_iter = False # self.is_training # FIXME evaluating shuffle across epochs
|
|
|
|
|
|
|
|
|
|
def _lazy_init(self):
|
|
|
|
|
""" Lazily initialize the dataset.
|
|
|
|
@ -174,12 +176,10 @@ class ParserTfds(Parser):
|
|
|
|
|
instances once it has been initialized. Do not call any dataset methods that can call _lazy_init
|
|
|
|
|
before it is passed to dataloader.
|
|
|
|
|
"""
|
|
|
|
|
worker_info = torch.utils.data.get_worker_info()
|
|
|
|
|
|
|
|
|
|
# setup input context to split dataset across distributed processes
|
|
|
|
|
num_workers = 1
|
|
|
|
|
global_worker_id = 0
|
|
|
|
|
if worker_info is not None:
|
|
|
|
|
if self.worker_info is None:
|
|
|
|
|
worker_info = torch.utils.data.get_worker_info()
|
|
|
|
|
assert worker_info is not None
|
|
|
|
|
self.worker_info = worker_info
|
|
|
|
|
self.worker_seed = worker_info.seed
|
|
|
|
|
num_workers = worker_info.num_workers
|
|
|
|
@ -209,6 +209,9 @@ class ParserTfds(Parser):
|
|
|
|
|
else:
|
|
|
|
|
subsplits = tfds.even_splits(self.split, self.global_num_workers)
|
|
|
|
|
self.subsplit = subsplits[global_worker_id]
|
|
|
|
|
else:
|
|
|
|
|
num_workers = self.worker_info.num_workers
|
|
|
|
|
global_worker_id = self.dist_rank * num_workers + self.worker_info.id
|
|
|
|
|
|
|
|
|
|
input_context = None
|
|
|
|
|
if self.global_num_workers > 1 and self.subsplit is None:
|
|
|
|
@ -219,8 +222,8 @@ class ParserTfds(Parser):
|
|
|
|
|
num_replicas_in_sync=self.dist_num_replicas # FIXME does this arg have any impact?
|
|
|
|
|
)
|
|
|
|
|
read_config = tfds.ReadConfig(
|
|
|
|
|
shuffle_seed=self.common_seed,
|
|
|
|
|
shuffle_reshuffle_each_iteration=True,
|
|
|
|
|
shuffle_seed=self.common_seed + self.init_count,
|
|
|
|
|
shuffle_reshuffle_each_iteration=not self.reinit_each_iter,
|
|
|
|
|
input_context=input_context)
|
|
|
|
|
ds = self.builder.as_dataset(
|
|
|
|
|
split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config)
|
|
|
|
@ -235,12 +238,15 @@ class ParserTfds(Parser):
|
|
|
|
|
# see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
|
|
|
|
|
ds = ds.repeat() # allow wrap around and break iteration manually
|
|
|
|
|
if self.is_training:
|
|
|
|
|
ds = ds.shuffle(min(self.num_examples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed)
|
|
|
|
|
ds = ds.shuffle(
|
|
|
|
|
min(self.num_examples, self.shuffle_size) // self.global_num_workers,
|
|
|
|
|
seed=self.worker_seed + self.init_count)
|
|
|
|
|
ds = ds.prefetch(min(self.num_examples // self.global_num_workers, self.prefetch_size))
|
|
|
|
|
self.ds = tfds.as_numpy(ds)
|
|
|
|
|
self.init_count += 1
|
|
|
|
|
|
|
|
|
|
def __iter__(self):
|
|
|
|
|
if self.ds is None:
|
|
|
|
|
if self.ds is None or self.reinit_each_iter:
|
|
|
|
|
self._lazy_init()
|
|
|
|
|
|
|
|
|
|
# Compute a rounded up sample count that is used to:
|
|
|
|
|