pull/212/merge
Santiago Castro 3 years ago committed by GitHub
commit eab434babe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -316,12 +316,15 @@ class MultiEpochsDataLoader(torch.utils.data.DataLoader):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._DataLoader__initialized = False self._DataLoader__initialized = False
self.batch_sampler = _RepeatSampler(self.batch_sampler) if self.batch_sampler is None:
self.sampler = _RepeatSampler(self.sampler)
else:
self.batch_sampler = _RepeatSampler(self.batch_sampler)
self._DataLoader__initialized = True self._DataLoader__initialized = True
self.iterator = super().__iter__() self.iterator = super().__iter__()
def __len__(self): def __len__(self):
return len(self.batch_sampler.sampler) return len(self.sampler) if self.batch_sampler is None else len(self.batch_sampler.sampler)
def __iter__(self): def __iter__(self):
for i in range(len(self)): for i in range(len(self)):

Loading…
Cancel
Save