diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/parsers/parser_tfds.py
index 8fb1de14..132065be 100644
--- a/timm/data/parsers/parser_tfds.py
+++ b/timm/data/parsers/parser_tfds.py
@@ -162,6 +162,8 @@ class ParserTfds(Parser):
         self.worker_seed = 0  # seed unique to each work instance
         self.subsplit = None  # set when data is distributed across workers using sub-splits
         self.ds = None  # initialized lazily on each dataloader worker process
+        self.init_count = 0
+        self.reinit_each_iter = False  # self.is_training  # FIXME evaluating shuffle across epochs
 
     def _lazy_init(self):
         """ Lazily initialize the dataset.
@@ -174,12 +176,10 @@ class ParserTfds(Parser):
         instances once it has been initialized. Do not call any dataset methods that can call _lazy_init
         before it is passed to dataloader.
         """
-        worker_info = torch.utils.data.get_worker_info()
-
-        # setup input context to split dataset across distributed processes
-        num_workers = 1
-        global_worker_id = 0
-        if worker_info is not None:
+        if self.worker_info is None:
+            worker_info = torch.utils.data.get_worker_info()
+            assert worker_info is not None
             self.worker_info = worker_info
             self.worker_seed = worker_info.seed
             num_workers = worker_info.num_workers
@@ -209,6 +209,9 @@ class ParserTfds(Parser):
                 else:
                     subsplits = tfds.even_splits(self.split, self.global_num_workers)
                     self.subsplit = subsplits[global_worker_id]
+        else:
+            num_workers = self.worker_info.num_workers
+            global_worker_id = self.dist_rank * num_workers + self.worker_info.id
 
         input_context = None
         if self.global_num_workers > 1 and self.subsplit is None:
@@ -219,8 +222,8 @@ class ParserTfds(Parser):
                 num_replicas_in_sync=self.dist_num_replicas  # FIXME does this arg have any impact?
             )
         read_config = tfds.ReadConfig(
-            shuffle_seed=self.common_seed,
-            shuffle_reshuffle_each_iteration=True,
+            shuffle_seed=self.common_seed + self.init_count,
+            shuffle_reshuffle_each_iteration=not self.reinit_each_iter,
             input_context=input_context)
         ds = self.builder.as_dataset(
             split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config)
@@ -235,12 +238,15 @@ class ParserTfds(Parser):
             # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
             ds = ds.repeat()  # allow wrap around and break iteration manually
         if self.is_training:
-            ds = ds.shuffle(min(self.num_examples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed)
+            ds = ds.shuffle(
+                min(self.num_examples, self.shuffle_size) // self.global_num_workers,
+                seed=self.worker_seed + self.init_count)
         ds = ds.prefetch(min(self.num_examples // self.global_num_workers, self.prefetch_size))
         self.ds = tfds.as_numpy(ds)
+        self.init_count += 1
 
     def __iter__(self):
-        if self.ds is None:
+        if self.ds is None or self.reinit_each_iter:
             self._lazy_init()
         # Compute a rounded up sample count that is used to:
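
Illustrative sketch (not part of the patch above): the change lets __iter__ rebuild the TF data pipeline when reinit_each_iter is set, and bumps init_count on every _lazy_init so the file-shuffle seed (common_seed + init_count) and the sample-shuffle seed (worker_seed + init_count) differ on each rebuild. The toy class below mimics that seed progression with plain Python instead of tfds/tf.data; ToyParser and its members are hypothetical stand-ins, not timm or TFDS API.

    import random


    class ToyParser:
        """Hypothetical stand-in for ParserTfds to show the seed-per-init idea."""

        def __init__(self, data, common_seed=42, reinit_each_iter=True):
            self.data = list(data)
            self.common_seed = common_seed
            self.reinit_each_iter = reinit_each_iter
            self.init_count = 0
            self.ds = None

        def _lazy_init(self):
            # stand-in for builder.as_dataset(...) + ds.shuffle(...): reshuffle with a
            # seed that changes on every re-init, mirroring `common_seed + init_count`
            rng = random.Random(self.common_seed + self.init_count)
            self.ds = self.data.copy()
            rng.shuffle(self.ds)
            self.init_count += 1

        def __iter__(self):
            # same guard as the patched __iter__: rebuild when reinit_each_iter is set
            if self.ds is None or self.reinit_each_iter:
                self._lazy_init()
            return iter(self.ds)


    parser = ToyParser(range(8))
    print(list(parser))  # "epoch" 0: shuffled with seed common_seed + 0
    print(list(parser))  # "epoch" 1: different order, seed common_seed + 1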