From a0b26574976c5ec1b6c498a9336e46707e0d1571 Mon Sep 17 00:00:00 2001
From: Hyeongchan Kim
Date: Mon, 3 Jan 2022 07:01:06 +0900
Subject: [PATCH] Use `torch.repeat_interleave()` to generate repeated indices faster (#1058)

* update: use numpy to generate repeated indices faster
* update: use torch.repeat_interleave() instead of np.repeat()
* refactor: remove unused import, numpy
* refactor: torch.range to torch.arange
* update: tensor to list before appending the extra samples
* update: concatenate the paddings with torch.cat
---
 timm/data/distributed_sampler.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/timm/data/distributed_sampler.py b/timm/data/distributed_sampler.py
index fa403d0a..16090189 100644
--- a/timm/data/distributed_sampler.py
+++ b/timm/data/distributed_sampler.py
@@ -103,15 +103,16 @@ class RepeatAugSampler(Sampler):
         g = torch.Generator()
         g.manual_seed(self.epoch)
         if self.shuffle:
-            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+            indices = torch.randperm(len(self.dataset), generator=g)
         else:
-            indices = list(range(len(self.dataset)))
+            indices = torch.arange(start=0, end=len(self.dataset))
 
         # produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2....]
-        indices = [x for x in indices for _ in range(self.num_repeats)]
+        indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0)
         # add extra samples to make it evenly divisible
         padding_size = self.total_size - len(indices)
-        indices += indices[:padding_size]
+        if padding_size > 0:
+            indices = torch.cat([indices, indices[:padding_size]], dim=0)
         assert len(indices) == self.total_size
 
         # subsample per rank
@@ -125,4 +126,4 @@ class RepeatAugSampler(Sampler):
         return self.num_selected_samples
 
     def set_epoch(self, epoch):
-        self.epoch = epoch
\ No newline at end of file
+        self.epoch = epoch
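
Editor's note: the snippet below is a minimal standalone sketch, not part of the patch, illustrating why the tensor ops above are drop-in replacements for the old list-based code. The concrete values (num_repeats = 3, total_size = 14) are made up for illustration; the real sampler derives them from the dataset size and world size.

import torch

num_repeats = 3
indices = torch.arange(start=0, end=4)  # tensor([0, 1, 2, 3])

# Old approach: pure-Python list comprehension over a Python list.
repeated_list = [x for x in indices.tolist() for _ in range(num_repeats)]

# New approach: one vectorized op producing [0, 0, 0, 1, 1, 1, 2, 2, 2, ...].
repeated = torch.repeat_interleave(indices, repeats=num_repeats, dim=0)
assert repeated.tolist() == repeated_list

# Padding to an evenly divisible length, mirroring the torch.cat change.
total_size = 14
padding_size = total_size - len(repeated)
if padding_size > 0:
    repeated = torch.cat([repeated, repeated[:padding_size]], dim=0)
assert len(repeated) == total_size

The same reasoning applies inside RepeatAugSampler.__iter__: keeping `indices` as a tensor lets both the repeat and the padding happen as single tensor operations instead of per-element Python loops.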