|
|
|
@ -230,12 +230,15 @@ class MultiEpochsDataLoader(torch.utils.data.DataLoader):
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
self._DataLoader__initialized = False
|
|
|
|
|
self.batch_sampler = _RepeatSampler(self.batch_sampler)
|
|
|
|
|
if self.batch_sampler is None:
|
|
|
|
|
self.sampler = RepeatSampler(self.sampler)
|
|
|
|
|
else:
|
|
|
|
|
self.batch_sampler = RepeatSampler(self.batch_sampler)
|
|
|
|
|
self._DataLoader__initialized = True
|
|
|
|
|
self.iterator = super().__iter__()
|
|
|
|
|
|
|
|
|
|
def __len__(self):
|
|
|
|
|
return len(self.batch_sampler.sampler)
|
|
|
|
|
return len(self.sampler) if self.batch_sampler is None else len(self.batch_sampler.sampler)
|
|
|
|
|
|
|
|
|
|
def __iter__(self):
|
|
|
|
|
for i in range(len(self)):
|
|
|
|
|