diff --git a/timm/data/loader.py b/timm/data/loader.py index 359cf315..bed9a254 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -231,9 +231,9 @@ class MultiEpochsDataLoader(torch.utils.data.DataLoader): super().__init__(*args, **kwargs) self._DataLoader__initialized = False if self.batch_sampler is None: - self.sampler = RepeatSampler(self.sampler) + self.sampler = _RepeatSampler(self.sampler) else: - self.batch_sampler = RepeatSampler(self.batch_sampler) + self.batch_sampler = _RepeatSampler(self.batch_sampler) self._DataLoader__initialized = True self.iterator = super().__iter__()