|
|
@ -140,6 +140,7 @@ def create_loader(
|
|
|
|
pin_memory=False,
|
|
|
|
pin_memory=False,
|
|
|
|
fp16=False,
|
|
|
|
fp16=False,
|
|
|
|
tf_preprocessing=False,
|
|
|
|
tf_preprocessing=False,
|
|
|
|
|
|
|
|
use_multi_epochs_loader=False
|
|
|
|
):
|
|
|
|
):
|
|
|
|
re_num_splits = 0
|
|
|
|
re_num_splits = 0
|
|
|
|
if re_split:
|
|
|
|
if re_split:
|
|
|
@ -175,7 +176,12 @@ def create_loader(
|
|
|
|
if collate_fn is None:
|
|
|
|
if collate_fn is None:
|
|
|
|
collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
|
|
|
|
collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
|
|
|
|
|
|
|
|
|
|
|
|
loader = torch.utils.data.DataLoader(
|
|
|
|
loader_class = torch.utils.data.DataLoader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if use_multi_epochs_loader:
|
|
|
|
|
|
|
|
loader_class = MultiEpochsDataLoader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
loader = loader_class(
|
|
|
|
dataset,
|
|
|
|
dataset,
|
|
|
|
batch_size=batch_size,
|
|
|
|
batch_size=batch_size,
|
|
|
|
shuffle=sampler is None and is_training,
|
|
|
|
shuffle=sampler is None and is_training,
|
|
|
@ -198,3 +204,35 @@ def create_loader(
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
return loader
|
|
|
|
return loader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MultiEpochsDataLoader(torch.utils.data.DataLoader):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
|
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
|
|
|
self._DataLoader__initialized = False
|
|
|
|
|
|
|
|
self.batch_sampler = _RepeatSampler(self.batch_sampler)
|
|
|
|
|
|
|
|
self._DataLoader__initialized = True
|
|
|
|
|
|
|
|
self.iterator = super().__iter__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __len__(self):
|
|
|
|
|
|
|
|
return len(self.batch_sampler.sampler)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __iter__(self):
|
|
|
|
|
|
|
|
for i in range(len(self)):
|
|
|
|
|
|
|
|
yield next(self.iterator)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _RepeatSampler(object):
|
|
|
|
|
|
|
|
""" Sampler that repeats forever.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
|
|
sampler (Sampler)
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, sampler):
|
|
|
|
|
|
|
|
self.sampler = sampler
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __iter__(self):
|
|
|
|
|
|
|
|
while True:
|
|
|
|
|
|
|
|
yield from iter(self.sampler)
|
|
|
|