From feaa3abc515a51ce3b7545e8d34d52cbab24b1b6 Mon Sep 17 00:00:00 2001 From: Santiago Castro Date: Fri, 7 Aug 2020 16:25:14 -0400 Subject: [PATCH 1/2] Fix MultiEpochsDataLoader when there's no batching --- timm/data/loader.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/timm/data/loader.py b/timm/data/loader.py index 317f77df..359cf315 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -230,12 +230,15 @@ class MultiEpochsDataLoader(torch.utils.data.DataLoader): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._DataLoader__initialized = False - self.batch_sampler = _RepeatSampler(self.batch_sampler) + if self.batch_sampler is None: + self.sampler = RepeatSampler(self.sampler) + else: + self.batch_sampler = RepeatSampler(self.batch_sampler) self._DataLoader__initialized = True self.iterator = super().__iter__() def __len__(self): - return len(self.batch_sampler.sampler) + return len(self.sampler) if self.batch_sampler is None else len(self.batch_sampler.sampler) def __iter__(self): for i in range(len(self)): From 0f75d48a6685e29a52e93b2764fb6c44515f553b Mon Sep 17 00:00:00 2001 From: Santiago Castro Date: Fri, 7 Aug 2020 16:26:00 -0400 Subject: [PATCH 2/2] Fix class name --- timm/data/loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/data/loader.py b/timm/data/loader.py index 359cf315..bed9a254 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -231,9 +231,9 @@ class MultiEpochsDataLoader(torch.utils.data.DataLoader): super().__init__(*args, **kwargs) self._DataLoader__initialized = False if self.batch_sampler is None: - self.sampler = RepeatSampler(self.sampler) + self.sampler = _RepeatSampler(self.sampler) else: - self.batch_sampler = RepeatSampler(self.batch_sampler) + self.batch_sampler = _RepeatSampler(self.batch_sampler) self._DataLoader__initialized = True self.iterator = super().__iter__()