Set num_workers in Iterable WDS/TFDS datasets early so sample estimate is correct

pull/1479/head
Ross Wightman 2 years ago
parent 285771972e
commit f67a7ee8bd

@ -4,6 +4,7 @@ Hacked together by / Copyright 2019, Ross Wightman
""" """
import io import io
import logging import logging
from typing import Optional
import torch import torch
import torch.utils.data as data import torch.utils.data as data
@ -132,6 +133,14 @@ class IterableImageDataset(data.IterableDataset):
if hasattr(self.parser, 'set_epoch'): if hasattr(self.parser, 'set_epoch'):
self.parser.set_epoch(count) self.parser.set_epoch(count)
def set_loader_cfg(
self,
num_workers: Optional[int] = None,
):
# TFDS and WDS readers need # workers for correct # samples estimate before loader processes created
if hasattr(self.parser, 'set_loader_cfg'):
self.parser.set_loader_cfg(num_workers=num_workers)
def filename(self, index, basename=False, absolute=False): def filename(self, index, basename=False, absolute=False):
assert False, 'Filename lookup by index not supported, use filenames().' assert False, 'Filename lookup by index not supported, use filenames().'

@ -16,12 +16,12 @@ import torch
import torch.utils.data import torch.utils.data
import numpy as np import numpy as np
from .transforms_factory import create_transform
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .dataset import IterableImageDataset
from .distributed_sampler import OrderedDistributedSampler, RepeatAugSampler from .distributed_sampler import OrderedDistributedSampler, RepeatAugSampler
from .random_erasing import RandomErasing from .random_erasing import RandomErasing
from .mixup import FastCollateMixup from .mixup import FastCollateMixup
from .transforms_factory import create_transform
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
@ -248,6 +248,11 @@ def create_loader(
separate=num_aug_splits > 0, separate=num_aug_splits > 0,
) )
if isinstance(dataset, IterableImageDataset):
# give Iterable datasets early knowledge of num_workers so that sample estimates
# are correct before worker processes are launched
dataset.set_loader_cfg(num_workers=num_workers)
sampler = None sampler = None
if distributed and not isinstance(dataset, torch.utils.data.IterableDataset): if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
if is_training: if is_training:

@ -49,7 +49,7 @@ class ParserHfds(Parser):
self.class_to_idx = get_class_labels(self.dataset.info) self.class_to_idx = get_class_labels(self.dataset.info)
self.split_info = self.dataset.info.splits[split] self.split_info = self.dataset.info.splits[split]
self.num_examples = self.split_info.num_examples self.num_samples = self.split_info.num_examples
def __getitem__(self, index): def __getitem__(self, index):
item = self.dataset[index] item = self.dataset[index]

@ -8,6 +8,7 @@ Hacked together by / Copyright 2020 Ross Wightman
""" """
import math import math
import os import os
from typing import Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -38,12 +39,12 @@ from .shared_count import SharedCount
MAX_TP_SIZE = os.environ.get('TFDS_TP_SIZE', 8) # maximum TF threadpool size, for jpeg decodes and queuing activities MAX_TP_SIZE = os.environ.get('TFDS_TP_SIZE', 8) # maximum TF threadpool size, for jpeg decodes and queuing activities
SHUFFLE_SIZE = os.environ.get('TFDS_SHUFFLE_SIZE', 8192) # examples to shuffle in DS queue SHUFFLE_SIZE = os.environ.get('TFDS_SHUFFLE_SIZE', 8192) # samples to shuffle in DS queue
PREFETCH_SIZE = os.environ.get('TFDS_PREFETCH_SIZE', 2048) # examples to prefetch PREFETCH_SIZE = os.environ.get('TFDS_PREFETCH_SIZE', 2048) # samples to prefetch
def even_split_indices(split, n, num_examples): def even_split_indices(split, n, num_samples):
partitions = [round(i * num_examples / n) for i in range(n + 1)] partitions = [round(i * num_samples / n) for i in range(n + 1)]
return [f"{split}[{partitions[i]}:{partitions[i + 1]}]" for i in range(n)] return [f"{split}[{partitions[i]}:{partitions[i + 1]}]" for i in range(n)]
@ -59,20 +60,20 @@ class ParserTfds(Parser):
""" Wrap Tensorflow Datasets for use in PyTorch """ Wrap Tensorflow Datasets for use in PyTorch
There several things to be aware of: There several things to be aware of:
* To prevent excessive examples being dropped per epoch w/ distributed training or multiplicity of * To prevent excessive samples being dropped per epoch w/ distributed training or multiplicity of
dataloader workers, the train iterator wraps to avoid returning partial batches that trigger drop_last dataloader workers, the train iterator wraps to avoid returning partial batches that trigger drop_last
https://github.com/pytorch/pytorch/issues/33413 https://github.com/pytorch/pytorch/issues/33413
* With PyTorch IterableDatasets, each worker in each replica operates in isolation, the final batch * With PyTorch IterableDatasets, each worker in each replica operates in isolation, the final batch
from each worker could be a different size. For training this is worked around by option above, for from each worker could be a different size. For training this is worked around by option above, for
validation extra examples are inserted iff distributed mode is enabled so that the batches being reduced validation extra samples are inserted iff distributed mode is enabled so that the batches being reduced
across replicas are of same size. This will slightly alter the results, distributed validation will not be across replicas are of same size. This will slightly alter the results, distributed validation will not be
100% correct. This is similar to common handling in DistributedSampler for normal Datasets but a bit worse 100% correct. This is similar to common handling in DistributedSampler for normal Datasets but a bit worse
since there are up to N * J extra examples with IterableDatasets. since there are up to N * J extra samples with IterableDatasets.
* The sharding (splitting of dataset into TFRecord) files imposes limitations on the number of * The sharding (splitting of dataset into TFRecord) files imposes limitations on the number of
replicas and dataloader workers you can use. For really small datasets that only contain a few shards replicas and dataloader workers you can use. For really small datasets that only contain a few shards
you may have to train non-distributed w/ 1-2 dataloader workers. This is likely not a huge concern as the you may have to train non-distributed w/ 1-2 dataloader workers. This is likely not a huge concern as the
benefit of distributed training or fast dataloading should be much less for small datasets. benefit of distributed training or fast dataloading should be much less for small datasets.
* This wrapper is currently configured to return individual, decompressed image examples from the TFDS * This wrapper is currently configured to return individual, decompressed image samples from the TFDS
dataset. The augmentation (transforms) and batching is still done in PyTorch. It would be possible dataset. The augmentation (transforms) and batching is still done in PyTorch. It would be possible
to specify TF augmentation fn and return augmented batches w/ some modifications to other downstream to specify TF augmentation fn and return augmented batches w/ some modifications to other downstream
components. components.
@ -104,7 +105,7 @@ class ParserTfds(Parser):
name: tfds dataset name (eg `imagenet2012`) name: tfds dataset name (eg `imagenet2012`)
split: tfds dataset split (can use all TFDS split strings eg `train[:10%]`) split: tfds dataset split (can use all TFDS split strings eg `train[:10%]`)
is_training: training mode, shuffle enabled, dataset len rounded by batch_size is_training: training mode, shuffle enabled, dataset len rounded by batch_size
batch_size: batch_size to use to unsure total examples % batch_size == 0 in training across all dis nodes batch_size: batch_size to use to unsure total samples % batch_size == 0 in training across all dis nodes
download: download and build TFDS dataset if set, otherwise must use tfds CLI download: download and build TFDS dataset if set, otherwise must use tfds CLI
repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1) repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1)
seed: common seed for shard shuffle across all distributed/worker instances seed: common seed for shard shuffle across all distributed/worker instances
@ -143,7 +144,7 @@ class ParserTfds(Parser):
self.builder.download_and_prepare() self.builder.download_and_prepare()
self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {} self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
self.split_info = self.builder.info.splits[split] self.split_info = self.builder.info.splits[split]
self.num_examples = self.split_info.num_examples self.num_samples = self.split_info.num_examples
# Distributed world state # Distributed world state
self.dist_rank = 0 self.dist_rank = 0
@ -154,6 +155,7 @@ class ParserTfds(Parser):
# Attributes that are updated in _lazy_init, including the tf.data pipeline itself # Attributes that are updated in _lazy_init, including the tf.data pipeline itself
self.global_num_workers = 1 self.global_num_workers = 1
self.num_workers = 1
self.worker_info = None self.worker_info = None
self.worker_seed = 0 # seed unique to each work instance self.worker_seed = 0 # seed unique to each work instance
self.subsplit = None # set when data is distributed across workers using sub-splits self.subsplit = None # set when data is distributed across workers using sub-splits
@ -167,6 +169,16 @@ class ParserTfds(Parser):
def set_epoch(self, count): def set_epoch(self, count):
self.epoch_count.value = count self.epoch_count.value = count
def set_loader_cfg(
self,
num_workers: Optional[int] = None,
):
if self.ds is not None:
return
if num_workers is not None:
self.num_workers = num_workers
self.global_num_workers = self.dist_num_replicas * self.num_workers
def _lazy_init(self): def _lazy_init(self):
""" Lazily initialize the dataset. """ Lazily initialize the dataset.
@ -186,9 +198,9 @@ class ParserTfds(Parser):
if worker_info is not None: if worker_info is not None:
self.worker_info = worker_info self.worker_info = worker_info
self.worker_seed = worker_info.seed self.worker_seed = worker_info.seed
num_workers = worker_info.num_workers self.num_workers = worker_info.num_workers
self.global_num_workers = self.dist_num_replicas * num_workers self.global_num_workers = self.dist_num_replicas * self.num_workers
global_worker_id = self.dist_rank * num_workers + worker_info.id global_worker_id = self.dist_rank * self.num_workers + worker_info.id
""" Data sharding """ Data sharding
InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used. InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used.
@ -198,17 +210,17 @@ class ParserTfds(Parser):
I am currently using a mix of InputContext shard assignment and fine-grained sub-splits for distributing I am currently using a mix of InputContext shard assignment and fine-grained sub-splits for distributing
the data across workers. For training InputContext is used to assign shards to nodes unless num_shards the data across workers. For training InputContext is used to assign shards to nodes unless num_shards
in dataset < total number of workers. Otherwise sub-split API is used for datasets without enough shards or in dataset < total number of workers. Otherwise sub-split API is used for datasets without enough shards or
for validation where we can't drop examples and need to avoid minimize uneven splits to avoid padding. for validation where we can't drop samples and need to avoid minimize uneven splits to avoid padding.
""" """
should_subsplit = self.global_num_workers > 1 and ( should_subsplit = self.global_num_workers > 1 and (
self.split_info.num_shards < self.global_num_workers or not self.is_training) self.split_info.num_shards < self.global_num_workers or not self.is_training)
if should_subsplit: if should_subsplit:
# split the dataset w/o using sharding for more even examples / worker, can result in less optimal # split the dataset w/o using sharding for more even samples / worker, can result in less optimal
# read patterns for distributed training (overlap across shards) so better to use InputContext there # read patterns for distributed training (overlap across shards) so better to use InputContext there
if has_buggy_even_splits: if has_buggy_even_splits:
# my even_split workaround doesn't work on subsplits, upgrade tfds! # my even_split workaround doesn't work on subsplits, upgrade tfds!
if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo): if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo):
subsplits = even_split_indices(self.split, self.global_num_workers, self.num_examples) subsplits = even_split_indices(self.split, self.global_num_workers, self.num_samples)
self.subsplit = subsplits[global_worker_id] self.subsplit = subsplits[global_worker_id]
else: else:
subsplits = tfds.even_splits(self.split, self.global_num_workers) subsplits = tfds.even_splits(self.split, self.global_num_workers)
@ -235,7 +247,7 @@ class ParserTfds(Parser):
# avoid overloading threading w/ combo of TF ds threads + PyTorch workers # avoid overloading threading w/ combo of TF ds threads + PyTorch workers
options = tf.data.Options() options = tf.data.Options()
thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading' thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading'
getattr(options, thread_member).private_threadpool_size = max(1, self.max_threadpool_size // num_workers) getattr(options, thread_member).private_threadpool_size = max(1, self.max_threadpool_size // self.num_workers)
getattr(options, thread_member).max_intra_op_parallelism = 1 getattr(options, thread_member).max_intra_op_parallelism = 1
ds = ds.with_options(options) ds = ds.with_options(options)
if self.is_training or self.repeats > 1: if self.is_training or self.repeats > 1:
@ -243,60 +255,65 @@ class ParserTfds(Parser):
# see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
ds = ds.repeat() # allow wrap around and break iteration manually ds = ds.repeat() # allow wrap around and break iteration manually
if self.is_training: if self.is_training:
ds = ds.shuffle(min(self.num_examples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed) ds = ds.shuffle(min(self.num_samples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed)
ds = ds.prefetch(min(self.num_examples // self.global_num_workers, self.prefetch_size)) ds = ds.prefetch(min(self.num_samples // self.global_num_workers, self.prefetch_size))
self.ds = tfds.as_numpy(ds) self.ds = tfds.as_numpy(ds)
self.init_count += 1 self.init_count += 1
def _num_samples_per_worker(self):
num_worker_samples = \
max(1, self.repeats) * self.num_samples / max(self.global_num_workers, self.dist_num_replicas)
if self.is_training or self.dist_num_replicas > 1:
num_worker_samples = math.ceil(num_worker_samples)
if self.is_training and self.batch_size is not None:
num_worker_samples = math.ceil(num_worker_samples / self.batch_size) * self.batch_size
return int(num_worker_samples)
def __iter__(self): def __iter__(self):
if self.ds is None or self.reinit_each_iter: if self.ds is None or self.reinit_each_iter:
self._lazy_init() self._lazy_init()
# Compute a rounded up sample count that is used to: # Compute a rounded up sample count that is used to:
# 1. make batches even cross workers & replicas in distributed validation. # 1. make batches even cross workers & replicas in distributed validation.
# This adds extra examples and will slightly alter validation results. # This adds extra samples and will slightly alter validation results.
# 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size # 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
# batches are produced (underlying tfds iter wraps around) # batches are produced (underlying tfds iter wraps around)
target_example_count = math.ceil(max(1, self.repeats) * self.num_examples / self.global_num_workers) target_sample_count = self._num_samples_per_worker()
if self.is_training:
# round up to nearest batch_size per worker-replica
target_example_count = math.ceil(target_example_count / self.batch_size) * self.batch_size
# Iterate until exhausted or sample count hits target when training (ds.repeat enabled) # Iterate until exhausted or sample count hits target when training (ds.repeat enabled)
example_count = 0 sample_count = 0
for example in self.ds: for sample in self.ds:
input_data = example[self.input_name] input_data = sample[self.input_name]
if self.input_img_mode: if self.input_img_mode:
input_data = Image.fromarray(input_data, mode=self.input_img_mode) input_data = Image.fromarray(input_data, mode=self.input_img_mode)
target_data = example[self.target_name] target_data = sample[self.target_name]
if self.target_img_mode: if self.target_img_mode:
target_data = Image.fromarray(target_data, mode=self.target_img_mode) target_data = Image.fromarray(target_data, mode=self.target_img_mode)
yield input_data, target_data yield input_data, target_data
example_count += 1 sample_count += 1
if self.is_training and example_count >= target_example_count: if self.is_training and sample_count >= target_sample_count:
# Need to break out of loop when repeat() is enabled for training w/ oversampling # Need to break out of loop when repeat() is enabled for training w/ oversampling
# this results in extra examples per epoch but seems more desirable than dropping # this results in extra samples per epoch but seems more desirable than dropping
# up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes) # up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes)
break break
# Pad across distributed nodes (make counts equal by adding examples) # Pad across distributed nodes (make counts equal by adding samples)
if not self.is_training and self.dist_num_replicas > 1 and self.subsplit is not None and \ if not self.is_training and self.dist_num_replicas > 1 and self.subsplit is not None and \
0 < example_count < target_example_count: 0 < sample_count < target_sample_count:
# Validation batch padding only done for distributed training where results are reduced across nodes. # Validation batch padding only done for distributed training where results are reduced across nodes.
# For single process case, it won't matter if workers return different batch sizes. # For single process case, it won't matter if workers return different batch sizes.
# If using input_context or % based splits, sample count can vary significantly across workers and this # If using input_context or % based splits, sample count can vary significantly across workers and this
# approach should not be used (hence disabled if self.subsplit isn't set). # approach should not be used (hence disabled if self.subsplit isn't set).
while example_count < target_example_count: while sample_count < target_sample_count:
yield input_data, target_data # yield prev sample again yield input_data, target_data # yield prev sample again
example_count += 1 sample_count += 1
def __len__(self): def __len__(self):
# this is just an estimate and does not factor in extra examples added to pad batches based on num_samples = self._num_samples_per_worker() * self.num_workers
# complete worker & replica info (not available until init in dataloader). return num_samples
return math.ceil(max(1, self.repeats) * self.num_examples / self.dist_num_replicas)
def _filename(self, index, basename=False, absolute=False): def _filename(self, index, basename=False, absolute=False):
assert False, "Not supported" # no random access to examples assert False, "Not supported" # no random access to samples
def filenames(self, basename=False, absolute=False): def filenames(self, basename=False, absolute=False):
""" Return all filenames in dataset, overrides base""" """ Return all filenames in dataset, overrides base"""
@ -304,7 +321,7 @@ class ParserTfds(Parser):
self._lazy_init() self._lazy_init()
names = [] names = []
for sample in self.ds: for sample in self.ds:
if len(names) > self.num_examples: if len(names) > self.num_samples:
break # safety for ds.repeat() case break # safety for ds.repeat() case
if 'file_name' in sample: if 'file_name' in sample:
name = sample['file_name'] name = sample['file_name']

@ -12,7 +12,7 @@ import sys
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
from itertools import islice from itertools import islice
from typing import Dict, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -345,6 +345,16 @@ class ParserWds(Parser):
def set_epoch(self, count): def set_epoch(self, count):
self.epoch_count.value = count self.epoch_count.value = count
def set_loader_cfg(
self,
num_workers: Optional[int] = None,
):
if self.ds is not None:
return
if num_workers is not None:
self.num_workers = num_workers
self.global_num_workers = self.dist_num_replicas * self.num_workers
def _lazy_init(self): def _lazy_init(self):
""" Lazily initialize worker (in worker processes) """ Lazily initialize worker (in worker processes)
""" """
@ -396,25 +406,27 @@ class ParserWds(Parser):
for s in src: for s in src:
yield s yield s
def _num_samples_per_worker(self):
num_worker_samples = self.num_samples / max(self.global_num_workers, self.dist_num_replicas)
if self.is_training or self.dist_num_replicas > 1:
num_worker_samples = math.ceil(num_worker_samples)
if self.is_training and self.batch_size is not None:
num_worker_samples = math.ceil(num_worker_samples / self.batch_size) * self.batch_size
return int(num_worker_samples)
def __iter__(self): def __iter__(self):
if self.ds is None: if self.ds is None:
self._lazy_init() self._lazy_init()
if self.is_training: num_worker_samples = self._num_samples_per_worker()
num_worker_samples = math.floor(self.num_samples / self.global_num_workers) if self.is_training or self.dist_num_replicas > 1:
if self.batch_size is not None: # NOTE: doing distributed validation w/ WDS is messy, hard to meet constraints that
num_worker_samples = (num_worker_samples // self.batch_size) * self.batch_size # same # of batches needed across all replicas w/ seeing each sample once.
# with_epoch() is simple but could miss a shard's worth of samples in some workers,
# and duplicate in others. Best to keep num DL workers low and a divisor of #val shards.
ds = self.ds.with_epoch(num_worker_samples) ds = self.ds.with_epoch(num_worker_samples)
else: else:
if self.dist_num_replicas > 1: ds = self.ds
# doing distributed validation w/ WDS is messy, hard to meet constraints that
# same # of batches needed across all replicas w/ seeing each sample once.
# with_epoch() is simple but could miss a shard's worth of samples in some workers,
# and duplicate in others. Best to keep num DL workers low and a divisor of #val shards.
num_worker_samples = math.ceil(self.num_samples / self.global_num_workers)
ds = self.ds.with_epoch(num_worker_samples)
else:
ds = self.ds
i = 0 i = 0
_logger.info(f'start {i}, {self.worker_id}') # FIXME temporary debug _logger.info(f'start {i}, {self.worker_id}') # FIXME temporary debug
@ -424,7 +436,8 @@ class ParserWds(Parser):
_logger.info(f'end {i}, {self.worker_id}') # FIXME temporary debug _logger.info(f'end {i}, {self.worker_id}') # FIXME temporary debug
def __len__(self): def __len__(self):
return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas) num_samples = self._num_samples_per_worker() * self.num_workers
return num_samples
def _filename(self, index, basename=False, absolute=False): def _filename(self, index, basename=False, absolute=False):
assert False, "Not supported" # no random access to examples assert False, "Not supported" # no random access to examples

Loading…
Cancel
Save