|
|
|
@ -8,6 +8,7 @@ Hacked together by / Copyright 2020 Ross Wightman
|
|
|
|
|
"""
|
|
|
|
|
import math
|
|
|
|
|
import os
|
|
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.distributed as dist
|
|
|
|
@ -38,12 +39,12 @@ from .shared_count import SharedCount
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Pipeline tunables, each overridable through an environment variable.
# os.environ.get() hands back a str whenever the variable is set, while the fallbacks here
# are ints, so cast to keep the type the same in both cases.
MAX_TP_SIZE = int(os.environ.get('TFDS_TP_SIZE', 8))  # maximum TF threadpool size, for jpeg decodes and queuing activities
SHUFFLE_SIZE = int(os.environ.get('TFDS_SHUFFLE_SIZE', 8192))  # samples to shuffle in DS queue
PREFETCH_SIZE = int(os.environ.get('TFDS_PREFETCH_SIZE', 2048))  # samples to prefetch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def even_split_indices(split, n, num_samples):
    """Build `n` contiguous, near-equal sub-split strings covering `num_samples` items.

    Args:
        split: base split name placed in front of every `[start:stop]` slice.
        n: number of partitions to produce.
        num_samples: total number of samples being divided.

    Returns:
        A list of `n` strings of the form `"{split}[{start}:{stop}]"`. Boundaries are
        `round(i * num_samples / n)`, so consecutive slices share an edge and together
        span `0..num_samples`.
    """
    bounds = [round(i * num_samples / n) for i in range(n + 1)]
    # pair every boundary with its successor to form the [start:stop] slices
    return [f"{split}[{start}:{stop}]" for start, stop in zip(bounds, bounds[1:])]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -59,20 +60,20 @@ class ParserTfds(Parser):
|
|
|
|
|
""" Wrap Tensorflow Datasets for use in PyTorch
|
|
|
|
|
|
|
|
|
|
There are several things to be aware of:
|
|
|
|
|
* To prevent excessive examples being dropped per epoch w/ distributed training or multiplicity of
|
|
|
|
|
* To prevent excessive samples being dropped per epoch w/ distributed training or multiplicity of
|
|
|
|
|
dataloader workers, the train iterator wraps to avoid returning partial batches that trigger drop_last
|
|
|
|
|
https://github.com/pytorch/pytorch/issues/33413
|
|
|
|
|
* With PyTorch IterableDatasets, each worker in each replica operates in isolation, the final batch
|
|
|
|
|
from each worker could be a different size. For training this is worked around by option above, for
|
|
|
|
|
validation extra examples are inserted iff distributed mode is enabled so that the batches being reduced
|
|
|
|
|
validation extra samples are inserted iff distributed mode is enabled so that the batches being reduced
|
|
|
|
|
across replicas are of same size. This will slightly alter the results, distributed validation will not be
|
|
|
|
|
100% correct. This is similar to common handling in DistributedSampler for normal Datasets but a bit worse
|
|
|
|
|
since there are up to N * J extra examples with IterableDatasets.
|
|
|
|
|
since there are up to N * J extra samples with IterableDatasets.
|
|
|
|
|
* The sharding (splitting of dataset into TFRecord) files imposes limitations on the number of
|
|
|
|
|
replicas and dataloader workers you can use. For really small datasets that only contain a few shards
|
|
|
|
|
you may have to train non-distributed w/ 1-2 dataloader workers. This is likely not a huge concern as the
|
|
|
|
|
benefit of distributed training or fast dataloading should be much less for small datasets.
|
|
|
|
|
* This wrapper is currently configured to return individual, decompressed image examples from the TFDS
|
|
|
|
|
* This wrapper is currently configured to return individual, decompressed image samples from the TFDS
|
|
|
|
|
dataset. The augmentation (transforms) and batching is still done in PyTorch. It would be possible
|
|
|
|
|
to specify TF augmentation fn and return augmented batches w/ some modifications to other downstream
|
|
|
|
|
components.
|
|
|
|
@ -104,7 +105,7 @@ class ParserTfds(Parser):
|
|
|
|
|
name: tfds dataset name (eg `imagenet2012`)
|
|
|
|
|
split: tfds dataset split (can use all TFDS split strings eg `train[:10%]`)
|
|
|
|
|
is_training: training mode, shuffle enabled, dataset len rounded by batch_size
|
|
|
|
|
batch_size: batch_size to use to ensure total examples % batch_size == 0 in training across all dist nodes
|
|
|
|
|
batch_size: batch_size to use to ensure total samples % batch_size == 0 in training across all dist nodes
|
|
|
|
|
download: download and build TFDS dataset if set, otherwise must use tfds CLI
|
|
|
|
|
repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1)
|
|
|
|
|
seed: common seed for shard shuffle across all distributed/worker instances
|
|
|
|
@ -143,7 +144,7 @@ class ParserTfds(Parser):
|
|
|
|
|
self.builder.download_and_prepare()
|
|
|
|
|
self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
|
|
|
|
|
self.split_info = self.builder.info.splits[split]
|
|
|
|
|
self.num_examples = self.split_info.num_examples
|
|
|
|
|
self.num_samples = self.split_info.num_examples
|
|
|
|
|
|
|
|
|
|
# Distributed world state
|
|
|
|
|
self.dist_rank = 0
|
|
|
|
@ -154,6 +155,7 @@ class ParserTfds(Parser):
|
|
|
|
|
|
|
|
|
|
# Attributes that are updated in _lazy_init, including the tf.data pipeline itself
|
|
|
|
|
self.global_num_workers = 1
|
|
|
|
|
self.num_workers = 1
|
|
|
|
|
self.worker_info = None
|
|
|
|
|
self.worker_seed = 0 # seed unique to each work instance
|
|
|
|
|
self.subsplit = None # set when data is distributed across workers using sub-splits
|
|
|
|
@ -167,6 +169,16 @@ class ParserTfds(Parser):
|
|
|
|
|
def set_epoch(self, count):
    """Record `count` as the current epoch by writing it to `self.epoch_count.value`."""
    self.epoch_count.value = count
|
|
|
|
|
|
|
|
|
|
def set_loader_cfg(
        self,
        num_workers: Optional[int] = None,
):
    """Record the loader's worker count before the dataset pipeline is built.

    Args:
        num_workers: number of dataloader workers per process; `None` leaves the
            current values untouched.

    Does nothing once `self.ds` has been created. Otherwise updates `self.num_workers`
    and recomputes `self.global_num_workers` as `dist_num_replicas * num_workers`.
    """
    # pipeline already exists, or nothing was supplied -> keep the current state
    if self.ds is not None or num_workers is None:
        return
    self.num_workers = num_workers
    self.global_num_workers = self.dist_num_replicas * num_workers
|
|
|
|
|
|
|
|
|
|
def _lazy_init(self):
|
|
|
|
|
""" Lazily initialize the dataset.
|
|
|
|
|
|
|
|
|
@ -186,9 +198,9 @@ class ParserTfds(Parser):
|
|
|
|
|
if worker_info is not None:
|
|
|
|
|
self.worker_info = worker_info
|
|
|
|
|
self.worker_seed = worker_info.seed
|
|
|
|
|
num_workers = worker_info.num_workers
|
|
|
|
|
self.global_num_workers = self.dist_num_replicas * num_workers
|
|
|
|
|
global_worker_id = self.dist_rank * num_workers + worker_info.id
|
|
|
|
|
self.num_workers = worker_info.num_workers
|
|
|
|
|
self.global_num_workers = self.dist_num_replicas * self.num_workers
|
|
|
|
|
global_worker_id = self.dist_rank * self.num_workers + worker_info.id
|
|
|
|
|
|
|
|
|
|
""" Data sharding
|
|
|
|
|
InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used.
|
|
|
|
@ -198,17 +210,17 @@ class ParserTfds(Parser):
|
|
|
|
|
I am currently using a mix of InputContext shard assignment and fine-grained sub-splits for distributing
|
|
|
|
|
the data across workers. For training InputContext is used to assign shards to nodes unless num_shards
|
|
|
|
|
in dataset < total number of workers. Otherwise sub-split API is used for datasets without enough shards or
|
|
|
|
|
for validation where we can't drop examples and need to minimize uneven splits to avoid padding.
|
|
|
|
|
for validation where we can't drop samples and need to minimize uneven splits to avoid padding.
|
|
|
|
|
"""
|
|
|
|
|
should_subsplit = self.global_num_workers > 1 and (
|
|
|
|
|
self.split_info.num_shards < self.global_num_workers or not self.is_training)
|
|
|
|
|
if should_subsplit:
|
|
|
|
|
# split the dataset w/o using sharding for more even examples / worker, can result in less optimal
|
|
|
|
|
# split the dataset w/o using sharding for more even samples / worker, can result in less optimal
|
|
|
|
|
# read patterns for distributed training (overlap across shards) so better to use InputContext there
|
|
|
|
|
if has_buggy_even_splits:
|
|
|
|
|
# my even_split workaround doesn't work on subsplits, upgrade tfds!
|
|
|
|
|
if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo):
|
|
|
|
|
subsplits = even_split_indices(self.split, self.global_num_workers, self.num_examples)
|
|
|
|
|
subsplits = even_split_indices(self.split, self.global_num_workers, self.num_samples)
|
|
|
|
|
self.subsplit = subsplits[global_worker_id]
|
|
|
|
|
else:
|
|
|
|
|
subsplits = tfds.even_splits(self.split, self.global_num_workers)
|
|
|
|
@ -235,7 +247,7 @@ class ParserTfds(Parser):
|
|
|
|
|
# avoid overloading threading w/ combo of TF ds threads + PyTorch workers
|
|
|
|
|
options = tf.data.Options()
|
|
|
|
|
thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading'
|
|
|
|
|
getattr(options, thread_member).private_threadpool_size = max(1, self.max_threadpool_size // num_workers)
|
|
|
|
|
getattr(options, thread_member).private_threadpool_size = max(1, self.max_threadpool_size // self.num_workers)
|
|
|
|
|
getattr(options, thread_member).max_intra_op_parallelism = 1
|
|
|
|
|
ds = ds.with_options(options)
|
|
|
|
|
if self.is_training or self.repeats > 1:
|
|
|
|
@ -243,60 +255,65 @@ class ParserTfds(Parser):
|
|
|
|
|
# see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
|
|
|
|
|
ds = ds.repeat() # allow wrap around and break iteration manually
|
|
|
|
|
if self.is_training:
|
|
|
|
|
ds = ds.shuffle(min(self.num_examples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed)
|
|
|
|
|
ds = ds.prefetch(min(self.num_examples // self.global_num_workers, self.prefetch_size))
|
|
|
|
|
ds = ds.shuffle(min(self.num_samples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed)
|
|
|
|
|
ds = ds.prefetch(min(self.num_samples // self.global_num_workers, self.prefetch_size))
|
|
|
|
|
self.ds = tfds.as_numpy(ds)
|
|
|
|
|
self.init_count += 1
|
|
|
|
|
|
|
|
|
|
def _num_samples_per_worker(self):
    """Return the number of samples a single worker should emit.

    The total (`num_samples`, multiplied by `repeats` when that is > 1) is divided by the
    larger of `global_num_workers` and `dist_num_replicas`. The quotient is rounded up when
    training or when more than one replica is in use, and when training with a known
    `batch_size` it is further rounded up to a whole number of batches.
    """
    share_count = max(self.global_num_workers, self.dist_num_replicas)
    per_worker = max(1, self.repeats) * self.num_samples / share_count
    if self.is_training or self.dist_num_replicas > 1:
        per_worker = math.ceil(per_worker)
    if self.is_training and self.batch_size is not None:
        # round up to the nearest multiple of batch_size
        per_worker = math.ceil(per_worker / self.batch_size) * self.batch_size
    return int(per_worker)
|
|
|
|
|
|
|
|
|
|
def __iter__(self):
    """Yield `(input, target)` pairs read from the underlying dataset iterable.

    Builds the pipeline through `_lazy_init()` when it does not exist yet or when
    `reinit_each_iter` is set. Each element of `self.ds` is indexed by `input_name` /
    `target_name`, and either field is converted with `Image.fromarray` when the matching
    `*_img_mode` is set.

    In training, iteration stops once the per-worker target from
    `_num_samples_per_worker()` has been reached. Outside training, when more than one
    replica is in use and a sub-split is assigned, the last pair is yielded again until
    the target is reached.
    """
    if self.ds is None or self.reinit_each_iter:
        self._lazy_init()

    # Compute a rounded up sample count that is used to:
    #   1. make batches even across workers & replicas in distributed validation.
    #     This adds extra samples and will slightly alter validation results.
    #   2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
    #     batches are produced (underlying tfds iter wraps around)
    target_sample_count = self._num_samples_per_worker()

    # Iterate until exhausted or sample count hits target when training (ds.repeat enabled)
    # NOTE(review): the early break below only fires in training; if the pipeline was built
    # with repeat() for a non-training run, nothing here appears to end iteration - confirm.
    sample_count = 0
    for sample in self.ds:
        input_data = sample[self.input_name]
        if self.input_img_mode:
            input_data = Image.fromarray(input_data, mode=self.input_img_mode)
        target_data = sample[self.target_name]
        if self.target_img_mode:
            target_data = Image.fromarray(target_data, mode=self.target_img_mode)
        yield input_data, target_data
        sample_count += 1
        if self.is_training and sample_count >= target_sample_count:
            # Need to break out of loop when repeat() is enabled for training w/ oversampling
            # this results in extra samples per epoch but seems more desirable than dropping
            # up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes)
            break

    # Pad across distributed nodes (make counts equal by adding samples)
    if not self.is_training and self.dist_num_replicas > 1 and self.subsplit is not None and \
            0 < sample_count < target_sample_count:
        # Validation batch padding only done for distributed training where results are reduced across nodes.
        # For single process case, it won't matter if workers return different batch sizes.
        # If using input_context or % based splits, sample count can vary significantly across workers and this
        # approach should not be used (hence disabled if self.subsplit isn't set).
        while sample_count < target_sample_count:
            yield input_data, target_data  # yield prev sample again
            sample_count += 1
|
|
|
|
|
|
|
|
|
|
def __len__(self):
    """Return an estimated sample count: per-worker samples times the worker count.

    The figure comes from `_num_samples_per_worker()` multiplied by `self.num_workers`,
    with the multiplier clamped to at least 1.
    """
    # this is just an estimate and does not factor in extra samples added to pad batches based on
    # complete worker & replica info (not available until init in dataloader).
    # A num_workers of 0 (presumably loading in the main process - confirm against the caller)
    # must not collapse the estimate to zero, hence the clamp.
    num_samples = self._num_samples_per_worker() * max(1, self.num_workers)
    return num_samples
|
|
|
|
|
|
|
|
|
|
def _filename(self, index, basename=False, absolute=False):
    """Per-index filename lookup; not available here and always fails with "Not supported".

    Args:
        index: unused.
        basename: unused.
        absolute: unused.

    Raises:
        AssertionError: always (when assertions are enabled).
    """
    assert False, "Not supported"  # no random access to samples
|
|
|
|
|
|
|
|
|
|
def filenames(self, basename=False, absolute=False):
|
|
|
|
|
""" Return all filenames in dataset, overrides base"""
|
|
|
|
@ -304,7 +321,7 @@ class ParserTfds(Parser):
|
|
|
|
|
self._lazy_init()
|
|
|
|
|
names = []
|
|
|
|
|
for sample in self.ds:
|
|
|
|
|
if len(names) > self.num_examples:
|
|
|
|
|
if len(names) > self.num_samples:
|
|
|
|
|
break # safety for ds.repeat() case
|
|
|
|
|
if 'file_name' in sample:
|
|
|
|
|
name = sample['file_name']
|
|
|
|
|