From f8a63a3b7173edcf1086ab0a3c4d9226bfc40569 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 23 Sep 2021 15:44:38 -0700
Subject: [PATCH] Add worker_init_fn to loader for numpy seed per worker

---
 timm/data/loader.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/timm/data/loader.py b/timm/data/loader.py
index 99cf132f..7d5aa1e5 100644
--- a/timm/data/loader.py
+++ b/timm/data/loader.py
@@ -125,6 +125,12 @@ class PrefetchLoader:
             self.loader.collate_fn.mixup_enabled = x
 
 
+def _worker_init(worker_id):
+    worker_info = torch.utils.data.get_worker_info()
+    assert worker_info.id == worker_id
+    np.random.seed(worker_info.seed % (2**32-1))
+
+
 def create_loader(
         dataset,
         input_size,
@@ -202,7 +208,6 @@ def create_loader(
     collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
 
     loader_class = torch.utils.data.DataLoader
-
     if use_multi_epochs_loader:
         loader_class = MultiEpochsDataLoader
 
@@ -214,6 +219,7 @@ def create_loader(
         collate_fn=collate_fn,
         pin_memory=pin_memory,
         drop_last=is_training,
+        worker_init_fn=_worker_init,
         persistent_workers=persistent_workers)
     try:
         loader = loader_class(dataset, **loader_args)
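
Note (commentary, not part of the patch above): PyTorch seeds each DataLoader
worker's torch RNG from a fresh base seed, but numpy's global RNG is left
untouched, so workers forked on Linux inherit identical numpy state and can
produce duplicate "random" augmentations. The standalone sketch below
illustrates that failure mode and the same worker_init_fn pattern the patch
uses; RandomDataset is a hypothetical dataset invented for illustration, not
part of timm.

    import numpy as np
    import torch
    from torch.utils.data import Dataset, DataLoader


    class RandomDataset(Dataset):
        """Toy dataset whose samples depend on numpy's global RNG."""

        def __len__(self):
            return 8

        def __getitem__(self, idx):
            # numpy randomness drawn inside the worker process
            return np.random.randint(0, 1000)


    def _worker_init(worker_id):
        # Mirrors the patch: derive a numpy seed from the per-worker
        # torch seed (base_seed + worker_id), clamped into numpy's
        # valid 32-bit seed range.
        worker_info = torch.utils.data.get_worker_info()
        np.random.seed(worker_info.seed % (2**32 - 1))


    if __name__ == '__main__':
        loader = DataLoader(
            RandomDataset(), num_workers=2, worker_init_fn=_worker_init)
        # With the init fn, the two workers draw from different numpy
        # streams; without it, forked workers may emit identical values.
        print([batch.item() for batch in loader])

Because worker_info.seed is already distinct per worker (and regenerated each
epoch), reducing it mod 2**32 - 1 yields a valid, distinct numpy seed without
any extra bookkeeping in the loader itself.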