diff --git a/data/distributed_sampler.py b/data/distributed_sampler.py new file mode 100644 index 00000000..9506a880 --- /dev/null +++ b/data/distributed_sampler.py @@ -0,0 +1,51 @@ +import math +import torch +from torch.utils.data import Sampler +import torch.distributed as dist + + +class OrderedDistributedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + .. note:: + Dataset is assumed to be of constant size. + Arguments: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. + """ + + def __init__(self, dataset, num_replicas=None, rank=None): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + indices = list(range(len(self.dataset))) + + # add extra samples to make it evenly divisible + indices += indices[:(self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples diff --git a/data/loader.py b/data/loader.py index af900c86..2402b093 100644 --- a/data/loader.py +++ b/data/loader.py @@ -2,6 +2,7 @@ import torch import torch.utils.data from data.random_erasing import RandomErasingTorch from data.transforms import * +from data.distributed_sampler import OrderedDistributedSampler def fast_collate(batch): @@ -102,10 +103,12 @@ def create_loader( sampler = None if distributed: - # FIXME note, doing this for validation isn't technically correct - # There currently is no fixed order distributed sampler that corrects - # for padded entries - sampler = torch.utils.data.distributed.DistributedSampler(dataset) + if is_training: + sampler = torch.utils.data.distributed.DistributedSampler(dataset) + else: + # This will add extra duplicate entries to result in equal num + # of samples per-process, will slightly alter validation results + sampler = OrderedDistributedSampler(dataset) loader = torch.utils.data.DataLoader( dataset,