@@ -5,6 +5,7 @@ https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#d

 Hacked together by / Copyright 2020 Ross Wightman
 """
+import warnings
 import torch.utils.data
 import numpy as np

@@ -245,11 +246,25 @@ class MultiEpochsDataLoader(torch.utils.data.DataLoader):
         self._DataLoader__initialized = True
         self.iterator = super().__iter__()
+        self._items_to_consume = 0
+        self._last_consumed_item = -1

     def __len__(self):
         return len(self.batch_sampler.sampler)

     def __iter__(self):
-        for i in range(len(self)):
+        # If a previous call to this function didn't consume all the corresponding items, the next call is going
+        # to continue the previous one, which is incorrect. So we consume the remaining items first.
+        # It's inefficient, but I can't come up with a better way to do this.
+        # Avoiding this class completely is probably more efficient if you're skipping a bunch of items every epoch.
+        for _ in range(self._items_to_consume - 1 - self._last_consumed_item):
+            warnings.warn("Consuming the rest of the sampler items from a previous call. "
+                          "Consider not using MultiEpochsDataLoader as it may take a lot of time.")
+            next(self.iterator)
+
+        self._items_to_consume = len(self)
+        for i in range(self._items_to_consume):
+            self._last_consumed_item = i
             yield next(self.iterator)
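For context, here is a minimal, self-contained repro of the failure mode the second hunk guards against. It is a sketch under assumptions: _RepeatSampler mirrors the upstream timm helper that MultiEpochsDataLoader wraps its batch sampler in, and NaiveMultiEpochsDataLoader is a hypothetical name for the pre-patch class; the toy dataset and batch size are arbitrary.

# Sketch: why the patched __iter__ needs the consume-the-rest loop.
import torch.utils.data


class _RepeatSampler:
    # Mirrors timm's _RepeatSampler: repeats the wrapped sampler forever so
    # DataLoader workers are never torn down between epochs.
    def __init__(self, sampler):
        self.sampler = sampler

    def __len__(self):
        return len(self.sampler)

    def __iter__(self):
        while True:
            yield from iter(self.sampler)


class NaiveMultiEpochsDataLoader(torch.utils.data.DataLoader):
    # Hypothetical name for the pre-patch class: one persistent iterator is
    # shared across epochs, with no bookkeeping of how much of an epoch was
    # actually consumed.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._DataLoader__initialized = False
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for _ in range(len(self)):
            yield next(self.iterator)


loader = NaiveMultiEpochsDataLoader(list(range(6)), batch_size=2)
for batch in loader:  # abandon the first epoch after a single batch
    break
print(next(iter(loader)))  # tensor([2, 3]), not tensor([0, 1]): the "new" epoch
                           # resumes mid-stream; the patched __iter__ instead
                           # drains the leftover batches before starting over.

The _items_to_consume / _last_consumed_item pair in the patch records how far the previous epoch got, so the drain loop over range(self._items_to_consume - 1 - self._last_consumed_item) burns exactly the batches an abandoned epoch left behind before a fresh epoch begins.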