Leaving repeat aug sampler indices as tensor thrashes worker shared process memory

pull/659/merge
Ross Wightman 3 years ago
parent 4df51f3932
commit 02ae11e526

@ -108,11 +108,11 @@ class RepeatAugSampler(Sampler):
indices = torch.arange(start=0, end=len(self.dataset))
# produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2....]
indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0)
indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0).tolist()
# add extra samples to make it evenly divisible
padding_size = self.total_size - len(indices)
if padding_size > 0:
indices = torch.cat([indices, indices[:padding_size]], dim=0)
indices += indices[:padding_size]
assert len(indices) == self.total_size
# subsample per rank

Loading…
Cancel
Save