Add worker_init_fn to loader for numpy seed per worker

pull/880/head
Ross Wightman 3 years ago
parent 515121cca1
commit f8a63a3b71

@@ -125,6 +125,12 @@ class PrefetchLoader:
             self.loader.collate_fn.mixup_enabled = x
+def _worker_init(worker_id):
+    worker_info = torch.utils.data.get_worker_info()
+    assert worker_info.id == worker_id
+    np.random.seed(worker_info.seed % (2**32-1))
 def create_loader(
         dataset,
         input_size,
@@ -202,7 +208,6 @@ def create_loader(
     collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
     loader_class = torch.utils.data.DataLoader
     if use_multi_epochs_loader:
         loader_class = MultiEpochsDataLoader
@@ -214,6 +219,7 @@ def create_loader(
         collate_fn=collate_fn,
         pin_memory=pin_memory,
         drop_last=is_training,
+        worker_init_fn=_worker_init,
         persistent_workers=persistent_workers)
     try:
         loader = loader_class(dataset, **loader_args)
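For context on why the commit adds this: PyTorch assigns each DataLoader worker its own torch seed, but it does not reseed NumPy's global RNG, so with a forked start method NumPy-based augmentations can draw identical "random" values in every worker. The sketch below shows the same pattern in isolation; it is a minimal illustration, and the dataset class and script structure are not part of this commit.

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class RandomNoiseDataset(Dataset):
    """Illustrative dataset whose augmentation draws from NumPy's global RNG."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        # Without per-worker seeding, forked workers can all inherit the same
        # NumPy RNG state and return identical "random" samples.
        return np.random.rand(3).astype(np.float32)


def _worker_init(worker_id):
    # Same idea as the commit: derive a NumPy seed from the torch-assigned
    # per-worker seed, reduced into NumPy's 32-bit seed range.
    worker_info = torch.utils.data.get_worker_info()
    assert worker_info.id == worker_id
    np.random.seed(worker_info.seed % (2 ** 32 - 1))


if __name__ == '__main__':
    loader = DataLoader(
        RandomNoiseDataset(),
        batch_size=2,
        num_workers=2,
        worker_init_fn=_worker_init,  # each worker reseeds NumPy independently
    )
    for batch in loader:
        print(batch)

With the worker_init_fn in place each worker's NumPy draws differ; removing it (with num_workers > 0 and a fork start method) reproduces the duplicated-augmentation symptom this commit addresses.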
