diff --git a/timm/data/loader.py b/timm/data/loader.py index 9d87ed3f..7020deb7 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -316,12 +316,15 @@ class MultiEpochsDataLoader(torch.utils.data.DataLoader): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._DataLoader__initialized = False - self.batch_sampler = _RepeatSampler(self.batch_sampler) + if self.batch_sampler is None: + self.sampler = _RepeatSampler(self.sampler) + else: + self.batch_sampler = _RepeatSampler(self.batch_sampler) self._DataLoader__initialized = True self.iterator = super().__iter__() def __len__(self): - return len(self.batch_sampler.sampler) + return len(self.sampler) if self.batch_sampler is None else len(self.batch_sampler.sampler) def __iter__(self): for i in range(len(self)):