|
|
|
@ -125,6 +125,12 @@ class PrefetchLoader:
|
|
|
|
|
self.loader.collate_fn.mixup_enabled = x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _worker_init(worker_id):
|
|
|
|
|
worker_info = torch.utils.data.get_worker_info()
|
|
|
|
|
assert worker_info.id == worker_id
|
|
|
|
|
np.random.seed(worker_info.seed % (2**32-1))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_loader(
|
|
|
|
|
dataset,
|
|
|
|
|
input_size,
|
|
|
|
@ -202,7 +208,6 @@ def create_loader(
|
|
|
|
|
collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
|
|
|
|
|
|
|
|
|
|
loader_class = torch.utils.data.DataLoader
|
|
|
|
|
|
|
|
|
|
if use_multi_epochs_loader:
|
|
|
|
|
loader_class = MultiEpochsDataLoader
|
|
|
|
|
|
|
|
|
@ -214,6 +219,7 @@ def create_loader(
|
|
|
|
|
collate_fn=collate_fn,
|
|
|
|
|
pin_memory=pin_memory,
|
|
|
|
|
drop_last=is_training,
|
|
|
|
|
worker_init_fn=_worker_init,
|
|
|
|
|
persistent_workers=persistent_workers)
|
|
|
|
|
try:
|
|
|
|
|
loader = loader_class(dataset, **loader_args)
|
|
|
|
|