Fix MultiEpochsDataLoader when not all is consumed

pull/213/head
Santiago Castro 5 years ago committed by GitHub
parent 6e9d6172c8
commit 2d0ff2f444
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -234,11 +234,23 @@ class MultiEpochsDataLoader(torch.utils.data.DataLoader):
self._DataLoader__initialized = True
self.iterator = super().__iter__()
self._items_to_consume = 0
self._last_consumed_item = -1
def __len__(self):
return len(self.batch_sampler.sampler)
def __iter__(self):
for i in range(len(self)):
# If a previous call to this function didn't consume all the corresponding items, the next call it's gonna
# continue the previous one, which is incorrect. So we consume the remaining items.
# It's inefficient but I can't come up with a better way to do this.
# Probably avoiding this class completely is more efficient if you're skipping a bunch of items in every epoch.
for _ in range(self._items_to_consume - 1 - self._last_consumed_item):
next(self.iterator)
self._items_to_consume = len(self)
for i in range(self._items_to_consume):
self._last_consumed_item = i
yield next(self.iterator)

Loading…
Cancel
Save