From 02ae11e526e2f58e0b4e7d16d5b9235aa37cf92b Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 6 Jan 2022 22:33:09 -0800
Subject: [PATCH] Leaving repeat aug sampler indices as tensor thrashes worker
 shared process memory

---
 timm/data/distributed_sampler.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/timm/data/distributed_sampler.py b/timm/data/distributed_sampler.py
index 16090189..1cefc31d 100644
--- a/timm/data/distributed_sampler.py
+++ b/timm/data/distributed_sampler.py
@@ -108,11 +108,11 @@ class RepeatAugSampler(Sampler):
         indices = torch.arange(start=0, end=len(self.dataset))
 
         # produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2....]
-        indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0)
+        indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0).tolist()
         # add extra samples to make it evenly divisible
         padding_size = self.total_size - len(indices)
         if padding_size > 0:
-            indices = torch.cat([indices, indices[:padding_size]], dim=0)
+            indices += indices[:padding_size]
         assert len(indices) == self.total_size
 
         # subsample per rank
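
Note (illustration, not part of the patch): below is a minimal runnable sketch of
the index pipeline after this change, with the sampler state flattened into
hypothetical function arguments (dataset_len, num_repeats, total_size, rank,
num_replicas); the function name and signature are invented for this sketch.
Calling .tolist() means DataLoader workers iterate over plain Python ints rather
than holding references into one large index tensor, which is the behavior the
subject line attributes the shared-memory thrashing to.

    # Illustrative sketch only, assuming attribute semantics as in timm's
    # RepeatAugSampler (shuffle omitted for brevity).
    import torch

    def repeat_aug_indices(dataset_len, num_repeats, total_size, rank, num_replicas):
        indices = torch.arange(start=0, end=dataset_len)
        # produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2, ...];
        # .tolist() converts the tensor to plain Python ints up front
        indices = torch.repeat_interleave(indices, repeats=num_repeats, dim=0).tolist()
        # pad by wrapping around so the total divides evenly across ranks;
        # list concatenation replaces the old torch.cat on tensors
        padding_size = total_size - len(indices)
        if padding_size > 0:
            indices += indices[:padding_size]
        assert len(indices) == total_size
        # subsample per rank
        return indices[rank:total_size:num_replicas]

    # e.g. 4 samples, 3 repeats, 2 ranks -> 6 indices per rank:
    # [0, 0, 1, 2, 2, 3]
    print(repeat_aug_indices(4, 3, 12, rank=0, num_replicas=2))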