From 7de69cd325230a3dd68e5b9921f3c9832f6a731c Mon Sep 17 00:00:00 2001
From: Sander van Leeuwen
Date: Fri, 23 Apr 2021 11:34:33 +0200
Subject: [PATCH] Added support for distributing workload among GPUs of
 different generations (speed/memory)

---
 timm/data/distributed_sampler.py | 108 +++++++++++++++++++++++++++++++
 timm/data/loader.py              |  10 ++-
 train.py                         |   7 +-
 3 files changed, 121 insertions(+), 4 deletions(-)

diff --git a/timm/data/distributed_sampler.py b/timm/data/distributed_sampler.py
index 9506a880..a005939f 100644
--- a/timm/data/distributed_sampler.py
+++ b/timm/data/distributed_sampler.py
@@ -49,3 +49,111 @@ class OrderedDistributedSampler(Sampler):
 
     def __len__(self):
         return self.num_samples
+
+
+class VariableDistributedSampler(Sampler):
+    """Sampler that distributes the dataset to each GPU according to the workload specified by the caller.
+    It adjusts both the per-rank dataset slice and the per-rank batch size.
+    Note: sampling now takes a contiguous slice of the dataset per rank instead of striding through it.
+    .. note::
+        Dataset is assumed to be of constant size.
+    Arguments:
+        dataset: Dataset used for sampling.
+        gpu_load: GPU workload distribution list; its entries must sum to the world size.
+        batch_size: Average batch size across the overall system.
+        shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
+            indices.
+        seed (int, optional): random seed used to shuffle the sampler if
+            :attr:`shuffle=True`. This number should be identical across all
+            processes in the distributed group. Default: ``0``.
+    """
+
+    def __init__(self, dataset, gpu_load, batch_size, shuffle=True, seed=0):
+        if not dist.is_available():
+            raise RuntimeError("Requires distributed package to be available")
+
+        world_size = dist.get_world_size()
+        rank = dist.get_rank()
+
+        if len(gpu_load) != world_size:
+            raise ValueError("Number of gpu_load entries not equal to world size")
+
+        if not math.isclose(sum(gpu_load), world_size):  # tolerate float rounding in the sum
+            raise ValueError("Sum of gpu_load weights not equal to world size")
+
+        self.dataset = dataset
+        self.num_replicas = world_size
+        self.rank = rank
+        self.epoch = 0
+
+        self.num_samples = [None for _ in range(world_size)]
+        self.index_offset = [None for _ in range(world_size)]
+        self.batch_size = [None for _ in range(world_size)]
+        self.num_batches = [None for _ in range(world_size)]
+
+        # calculate the dataset slice size and batch size for each GPU
+        for i in range(world_size):
+            self.num_samples[i] = int(math.ceil(len(self.dataset) / self.num_replicas * gpu_load[i]))
+            self.batch_size[i] = int(math.ceil(batch_size * gpu_load[i]))
+            self.num_batches[i] = int(math.ceil(self.num_samples[i] / self.batch_size[i]))
+
+        for i in range(1, world_size):
+            if self.num_batches[i] != self.num_batches[i - 1]:  # every rank must run the same number of batches
+                raise ValueError(f"Number of batches mismatch: {self.num_batches}")
+
+        # calculate the dataset offset of each GPU slice
+        self.index_offset[0] = 0
+        for i in range(1, world_size):
+            self.index_offset[i] = self.index_offset[i - 1] + self.num_samples[i - 1]
+
+        self.total_size = sum(self.num_samples)
+
+        if rank == 0:
+            print('VariableDistributedSampler: Number of samples: ', self.num_samples)
+            print('VariableDistributedSampler: Index offsets    : ', self.index_offset)
+            print('VariableDistributedSampler: Batch sizes      : ', self.batch_size)
+            print('VariableDistributedSampler: Number of batches: ', self.num_batches)
+
+        self.shuffle = shuffle
+        self.seed = seed
+
+    def get_batch_size(self):
+        return self.batch_size[self.rank]
+
+    def __iter__(self):
+        if self.shuffle:
+            # deterministically shuffle based on epoch and seed
+            g = torch.Generator()
+            g.manual_seed(self.seed + self.epoch)
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
+        else:
+            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]
+
+        # add extra samples to make it evenly divisible
+        padding_size = self.total_size - len(indices)
+        if padding_size <= len(indices):
+            indices += indices[:padding_size]
+        else:
+            indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
+        assert len(indices) == self.total_size
+
+        # subsample: take this rank's contiguous slice
+        # (previously: indices = indices[self.rank:self.total_size:self.num_replicas])
+        indices = indices[self.index_offset[self.rank]:self.index_offset[self.rank] + self.num_samples[self.rank]]
+        assert len(indices) == self.num_samples[self.rank]
+
+        return iter(indices)
+
+    def __len__(self):
+        return self.num_samples[self.rank]
+
+    def set_epoch(self, epoch: int):
+        r"""
+        Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
+        use a different random ordering for each epoch. Otherwise, the next iteration of this
+        sampler will yield the same ordering.
+        Args:
+            epoch (int): Epoch number.
+        """
+        self.epoch = epoch
+
\ No newline at end of file
diff --git a/timm/data/loader.py b/timm/data/loader.py
index 76144669..b0ba4bde 100644
--- a/timm/data/loader.py
+++ b/timm/data/loader.py
@@ -11,7 +11,7 @@ import numpy as np
 
 from .transforms_factory import create_transform
 from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .distributed_sampler import OrderedDistributedSampler
+from .distributed_sampler import OrderedDistributedSampler, VariableDistributedSampler
 from .random_erasing import RandomErasing
 from .mixup import FastCollateMixup
 
@@ -155,6 +155,7 @@ def create_loader(
         tf_preprocessing=False,
         use_multi_epochs_loader=False,
         persistent_workers=True,
+        gpu_load=None
 ):
     re_num_splits = 0
     if re_split:
@@ -186,7 +187,12 @@ def create_loader(
     sampler = None
     if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
         if is_training:
-            sampler = torch.utils.data.distributed.DistributedSampler(dataset)
+            if gpu_load is not None:
+                # split up the dataset and batch sizes according to the workload weights
+                sampler = VariableDistributedSampler(dataset, batch_size=batch_size, gpu_load=gpu_load)
+                batch_size = sampler.get_batch_size()
+            else:
+                sampler = torch.utils.data.distributed.DistributedSampler(dataset)
         else:
             # This will add extra duplicate entries to result in equal num
             # of samples per-process, will slightly alter validation results
diff --git a/train.py b/train.py
index 4264a164..3465add0 100755
--- a/train.py
+++ b/train.py
@@ -279,7 +279,8 @@ parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                     help='convert model torchscript for inference')
 parser.add_argument('--log-wandb', action='store_true', default=False,
                     help='log training and validation metrics to wandb')
-
+parser.add_argument('--gpu-load', nargs="*", type=float, default=None,  # None equals 1.0 1.0 1.0 1.0 for a 4 GPU system
+                    help='Distribute workload unevenly to GPUs with different performance levels. E.g. --gpu-load 1.2 1.2 0.8 0.8 for a 4 GPU system (2xA6000 + 2x2080ti)')
 
 def _parse_args():
     # Do we have a config file to parse?
@@ -524,7 +525,8 @@ def main():
         distributed=args.distributed,
         collate_fn=collate_fn,
         pin_memory=args.pin_mem,
-        use_multi_epochs_loader=args.use_multi_epochs_loader
+        use_multi_epochs_loader=args.use_multi_epochs_loader,
+        gpu_load=args.gpu_load
     )
 
     loader_eval = create_loader(
@@ -540,6 +542,7 @@ def main():
         distributed=args.distributed,
         crop_pct=data_config['crop_pct'],
         pin_memory=args.pin_mem,
+        gpu_load=args.gpu_load
     )
 
     # setup loss function
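
Note (not part of the patch): the short Python sketch below reproduces the
per-rank arithmetic from VariableDistributedSampler.__init__, so the effect of
a given --gpu-load setting can be checked without launching a distributed run.
The dataset size (ImageNet-1k train) and the average batch size of 120 are
illustrative assumptions, not values taken from the patch.

    import math

    dataset_len = 1281167            # assumed dataset size (ImageNet-1k train)
    gpu_load = [1.2, 1.2, 0.8, 0.8]  # 2xA6000 + 2x2080ti; must sum to world size
    batch_size = 120                 # assumed average batch size across all GPUs
    world_size = len(gpu_load)

    # per-rank slice sizes, batch sizes and batch counts, as computed in __init__
    num_samples = [math.ceil(dataset_len / world_size * w) for w in gpu_load]
    batch_sizes = [math.ceil(batch_size * w) for w in gpu_load]
    num_batches = [math.ceil(n / b) for n, b in zip(num_samples, batch_sizes)]

    # contiguous slice offsets: rank i draws indices[offset[i]:offset[i] + num_samples[i]]
    index_offset = [0]
    for n in num_samples[:-1]:
        index_offset.append(index_offset[-1] + n)

    print(num_samples)   # [384351, 384351, 256234, 256234]
    print(batch_sizes)   # [144, 144, 96, 96]
    print(num_batches)   # [2670, 2670, 2670, 2670] -> equal on all ranks, accepted
    print(index_offset)  # [0, 384351, 768702, 1024936]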
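Example invocation, assuming timm's usual distributed launch helper and a
hypothetical dataset path and model:

    ./distributed_train.sh 4 /data/imagenet --model resnet50 -b 120 --gpu-load 1.2 1.2 0.8 0.8

Caveat: because of the ceil() rounding above, not every combination of batch
size and gpu_load weights yields the same number of batches on every rank, and
__init__ raises ValueError when they differ (e.g. -b 128 with these weights
gives 2496 batches on the fast ranks but 2488 on the slow ones). Choosing a
batch size for which batch_size * weight is an integer, as with -b 120 here,
avoids the mismatch in this configuration.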