diff --git a/timm/data/distributed_sampler.py b/timm/data/distributed_sampler.py index fa403d0a..16090189 100644 --- a/timm/data/distributed_sampler.py +++ b/timm/data/distributed_sampler.py @@ -103,15 +103,16 @@ class RepeatAugSampler(Sampler): g = torch.Generator() g.manual_seed(self.epoch) if self.shuffle: - indices = torch.randperm(len(self.dataset), generator=g).tolist() + indices = torch.randperm(len(self.dataset), generator=g) else: - indices = list(range(len(self.dataset))) + indices = torch.arange(start=0, end=len(self.dataset)) # produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2....] - indices = [x for x in indices for _ in range(self.num_repeats)] + indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0) # add extra samples to make it evenly divisible padding_size = self.total_size - len(indices) - indices += indices[:padding_size] + if padding_size > 0: + indices = torch.cat([indices, indices[:padding_size]], dim=0) assert len(indices) == self.total_size # subsample per rank @@ -125,4 +126,4 @@ class RepeatAugSampler(Sampler): return self.num_selected_samples def set_epoch(self, epoch): - self.epoch = epoch \ No newline at end of file + self.epoch = epoch