@ -103,15 +103,16 @@ class RepeatAugSampler(Sampler):
g = torch . Generator ( )
g = torch . Generator ( )
g . manual_seed ( self . epoch )
g . manual_seed ( self . epoch )
if self . shuffle :
if self . shuffle :
indices = torch . randperm ( len ( self . dataset ) , generator = g ) . tolist ( )
indices = torch . randperm ( len ( self . dataset ) , generator = g )
else :
else :
indices = list ( range ( len ( self . dataset ) ) )
indices = torch . arange ( start = 0 , end = len ( self . dataset ) )
# produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2....]
# produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2....]
indices = [ x for x in indices for _ in range ( self . num_repeats ) ]
indices = torch . repeat_interleave ( indices , repeats = self . num_repeats , dim = 0 )
# add extra samples to make it evenly divisible
# add extra samples to make it evenly divisible
padding_size = self . total_size - len ( indices )
padding_size = self . total_size - len ( indices )
indices + = indices [ : padding_size ]
if padding_size > 0 :
indices = torch . cat ( [ indices , indices [ : padding_size ] ] , dim = 0 )
assert len ( indices ) == self . total_size
assert len ( indices ) == self . total_size
# subsample per rank
# subsample per rank
@ -125,4 +126,4 @@ class RepeatAugSampler(Sampler):
return self . num_selected_samples
return self . num_selected_samples
def set_epoch ( self , epoch ) :
def set_epoch ( self , epoch ) :
self . epoch = epoch
self . epoch = epoch