From 2d0ff2f444b849ccc885954fd3033f4e2a82b803 Mon Sep 17 00:00:00 2001 From: Santiago Castro Date: Fri, 7 Aug 2020 16:28:48 -0400 Subject: [PATCH 1/2] Fix MultiEpochsDataLoader when not all is consumed --- timm/data/loader.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/timm/data/loader.py b/timm/data/loader.py index 317f77df..59d62701 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -234,11 +234,23 @@ class MultiEpochsDataLoader(torch.utils.data.DataLoader): self._DataLoader__initialized = True self.iterator = super().__iter__() + self._items_to_consume = 0 + self._last_consumed_item = -1 + def __len__(self): return len(self.batch_sampler.sampler) def __iter__(self): - for i in range(len(self)): + # If a previous call to this function didn't consume all the corresponding items, the next call it's gonna + # continue the previous one, which is incorrect. So we consume the remaining items. + # It's inefficient but I can't come up with a better way to do this. + # Probably avoiding this class completely is more efficient if you're skipping a bunch of items in every epoch. + for _ in range(self._items_to_consume - 1 - self._last_consumed_item): + next(self.iterator) + + self._items_to_consume = len(self) + for i in range(self._items_to_consume): + self._last_consumed_item = i yield next(self.iterator) From 417b8a7bb7e7378b839aa99c8cb3a08b79c36ff6 Mon Sep 17 00:00:00 2001 From: Santiago Castro Date: Fri, 7 Aug 2020 17:08:50 -0400 Subject: [PATCH 2/2] Add a warning --- timm/data/loader.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/timm/data/loader.py b/timm/data/loader.py index 59d62701..ece3bb9d 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -5,6 +5,7 @@ https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#d Hacked together by / Copyright 2020 Ross Wightman """ +import warnings import torch.utils.data import numpy as np @@ -246,6 +247,8 @@ class MultiEpochsDataLoader(torch.utils.data.DataLoader): # It's inefficient but I can't come up with a better way to do this. # Probably avoiding this class completely is more efficient if you're skipping a bunch of items in every epoch. for _ in range(self._items_to_consume - 1 - self._last_consumed_item): + warnings.warn("Consuming the rest of the sampler items from a previous call. " + "Consider not using MultiEpochsDataLoader as it may take a lot of time.") next(self.iterator) self._items_to_consume = len(self)