More TFDS parser cleanup, support improved TFDS even_split impl (on tfds-nightly only currently).

more_datasets
Ross Wightman 3 years ago
parent ba65dfe2c6
commit 9ec3210c2d

@@ -6,8 +6,6 @@ https://www.tensorflow.org/datasets/catalog/overview#image_classification
Hacked together by / Copyright 2020 Ross Wightman
"""
import os
import io
import math
import torch
import torch.distributed as dist
@@ -17,6 +15,13 @@ try:
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU') # Hands off my GPU! (or pip install tensorflow-cpu)
import tensorflow_datasets as tfds
try:
tfds.even_splits('', 1, drop_remainder=False) # non-buggy even_splits has drop_remainder arg
has_buggy_even_splits = False
except TypeError:
print("Warning: This version of tfds doesn't have the latest even_splits impl. "
"Please update or use tfds-nightly for better fine-grained split behaviour.")
has_buggy_even_splits = True
except ImportError as e:
print(e)
print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
@@ -25,7 +30,7 @@ from .parser import Parser
MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities
SHUFFLE_SIZE = 20480 # samples to shuffle in DS queue
SHUFFLE_SIZE = 16384 # samples to shuffle in DS queue
PREFETCH_SIZE = 2048 # samples to prefetch
@@ -58,8 +63,34 @@ class ParserTfds(Parser):
"""
def __init__(
self, root, name, split='train', is_training=False, batch_size=None,
download=False, repeats=0, seed=42):
self,
root,
name,
split='train',
is_training=False,
batch_size=None,
download=False,
repeats=0,
seed=42,
prefetch_size=None,
shuffle_size=None,
max_threadpool_size=None
):
""" Tensorflow-datasets Wrapper
Args:
root: root data dir (i.e. your TFDS_DATA_DIR, not a dataset-specific sub-dir)
name: tfds dataset name (e.g. `imagenet2012`)
split: tfds dataset split (can use all TFDS split strings, e.g. `train[:10%]`)
is_training: training mode, shuffle enabled, dataset len rounded by batch_size
batch_size: batch_size to use to ensure total samples % batch_size == 0 in training across all distributed nodes
download: download and build TFDS dataset if set, otherwise must use tfds CLI
repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1)
seed: common seed for shard shuffle across all distributed/worker instances
prefetch_size: override default tf.data prefetch buffer size
shuffle_size: override default tf.data shuffle buffer size
max_threadpool_size: override default threadpool size for tf.data
"""
super().__init__()
self.root = root
self.split = split
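
A hypothetical construction with the new tuning knobs could look like the sketch below; the root path and override values are placeholders, only the parameter names come from this parser.

# Hypothetical usage of the expanded constructor; '/data/tfds' and the override
# values are made-up examples, not defaults from this commit.
parser = ParserTfds(
    root='/data/tfds',          # TFDS_DATA_DIR, not a dataset-specific sub-dir
    name='imagenet2012',
    split='train',
    is_training=True,
    batch_size=256,
    prefetch_size=4096,         # overrides PREFETCH_SIZE
    shuffle_size=8192,          # overrides SHUFFLE_SIZE
    max_threadpool_size=4,      # overrides MAX_TP_SIZE
)
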
@@ -69,25 +100,33 @@ class ParserTfds(Parser):
"Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
self.batch_size = batch_size
self.repeats = repeats
self.common_seed = seed # seed across all worker / dist nodes
self.worker_seed = 0 # seed specific to each work instance
self.subsplit = None
self.common_seed = seed # a seed that's fixed across all worker / distributed instances
self.prefetch_size = prefetch_size or PREFETCH_SIZE
self.shuffle_size = shuffle_size or SHUFFLE_SIZE
self.max_threadpool_size = max_threadpool_size or MAX_TP_SIZE
# TFDS builder and split information
self.builder = tfds.builder(name, data_dir=root)
# NOTE: the tfds command line app can be used to download & prepare datasets if you don't enable the download flag
if download:
self.builder.download_and_prepare()
self.split_info = self.builder.info.splits[split]
self.num_samples = self.split_info.num_examples
self.ds = None # initialized lazily on each dataloader worker process
self.worker_info = None
# Distributed world state
self.dist_rank = 0
self.dist_num_replicas = 1
if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
self.dist_rank = dist.get_rank()
self.dist_num_replicas = dist.get_world_size()
# Attributes that are updated in _lazy_init, including the tf.data pipeline itself
self.global_num_workers = 1
self.worker_info = None
self.worker_seed = 0 # seed unique to each worker instance
self.subsplit = None # set when data is distributed across workers using sub-splits
self.ds = None # initialized lazily on each dataloader worker process
def _lazy_init(self):
""" Lazily initialize the dataset.
@@ -102,38 +141,44 @@ class ParserTfds(Parser):
worker_info = torch.utils.data.get_worker_info()
# setup input context to split dataset across distributed processes
global_num_workers = num_workers = 1
global_worker_id = 1
num_workers = 1
global_worker_id = 0
if worker_info is not None:
self.worker_info = worker_info
self.worker_seed = worker_info.seed
num_workers = worker_info.num_workers
global_num_workers = self.dist_num_replicas * num_workers
worker_id = worker_info.id
global_worker_id = self.dist_rank * num_workers + worker_id
self.global_num_workers = self.dist_num_replicas * num_workers
global_worker_id = self.dist_rank * num_workers + worker_info.id
# FIXME verify best sharding approach
""" Data sharding
InputContext will assign a subset of the underlying TFRecord files to each 'pipeline' if used.
My understanding is that when using sub-splits, the underlying TFRecord files will shuffle (shuffle_files=True)
between the splits each iteration, but that understanding could be wrong.
Possible split options include:
* InputContext for both distributed & worker processes (current)
* InputContext for distributed and sub-splits for worker processes
* sub-splits for both
I am currently using a mix of InputContext shard assignment and fine-grained sub-splits for distributing
the data across workers. For training, InputContext is used to assign shards to nodes unless num_shards
in the dataset < total number of workers. Otherwise the sub-split API is used for datasets without enough shards or
for validation, where we can't drop samples and need to minimize uneven splits to avoid padding.
"""
can_subsplit = '[' not in self.split # can't subsplit a subsplit
should_subsplit = global_num_workers > 1 and (
self.split_info.num_shards < global_num_workers or not self.is_training)
if can_subsplit and should_subsplit:
# manually split the dataset w/o sharding for more even samples / worker
self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[global_worker_id]
should_subsplit = self.global_num_workers > 1 and (
self.split_info.num_shards < self.global_num_workers or not self.is_training)
if should_subsplit:
# split the dataset w/o using sharding for more even samples / worker; this can result in less optimal
# read patterns for distributed training (overlap across shards), so it's better to use InputContext there
if has_buggy_even_splits:
# my even_split workaround doesn't work on subsplits, upgrade tfds!
if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo):
subsplits = even_split_indices(self.split, self.global_num_workers, self.num_samples)
self.subsplit = subsplits[global_worker_id]
else:
subsplits = tfds.even_splits(self.split, self.global_num_workers)
self.subsplit = subsplits[global_worker_id]
input_context = None
if global_num_workers > 1 and self.subsplit is None:
if self.global_num_workers > 1 and self.subsplit is None:
# set input context to divide shards among distributed replicas
input_context = tf.distribute.InputContext(
num_input_pipelines=global_num_workers,
num_input_pipelines=self.global_num_workers,
input_pipeline_id=global_worker_id,
num_replicas_in_sync=self.dist_num_replicas # FIXME does this arg have any impact?
)
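
To make the worker-to-slice mapping above concrete, here is a small standalone sketch with invented numbers (2 replicas x 4 dataloader workers, a 4-shard dataset); it mirrors the decision logic but is not part of the commit.

# standalone sketch of the sharding decision above, with invented numbers
dist_num_replicas, num_workers = 2, 4                    # 2 replicas x 4 dataloader workers
global_num_workers = dist_num_replicas * num_workers     # 8 global pipelines

dist_rank, worker_id = 1, 2
global_worker_id = dist_rank * num_workers + worker_id   # pipeline 6 of 8

num_shards, is_training = 4, True                        # fewer TFRecord shards than pipelines
if num_shards < global_num_workers or not is_training:
    # too few shards (or validation): carve the split into 8 even sub-splits
    # (tfds.even_splits / even_split_indices above) and read sub-split 6 in this pipeline
    print('use sub-split', global_worker_id)
else:
    # enough shards and training: tf.distribute.InputContext(num_input_pipelines=8,
    # input_pipeline_id=6) assigns whole shards to this pipeline instead
    print('use InputContext pipeline', global_worker_id)
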
@@ -143,10 +188,10 @@ class ParserTfds(Parser):
input_context=input_context)
ds = self.builder.as_dataset(
split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config)
# avoid overloading threading w/ combo fo TF ds threads + PyTorch workers
# avoid overloading threading w/ combo of TF ds threads + PyTorch workers
options = tf.data.Options()
thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading'
getattr(options, thread_member).private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
getattr(options, thread_member).private_threadpool_size = max(1, self.max_threadpool_size // num_workers)
getattr(options, thread_member).max_intra_op_parallelism = 1
ds = ds.with_options(options)
if self.is_training or self.repeats > 1:
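
The threading vs experimental_threading lookup guards against the tf.data options rename across TF 2.x releases; a minimal standalone illustration follows, assuming a TF 2.x install and using MAX_TP_SIZE=8 with 4 workers as example values.

# minimal illustration of the options fallback above (assumes TF 2.x is installed)
import tensorflow as tf

options = tf.data.Options()
# newer TF exposes options.threading; older 2.x releases only had options.experimental_threading
thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading'
getattr(options, thread_member).private_threadpool_size = max(1, 8 // 4)  # e.g. MAX_TP_SIZE // num_workers
getattr(options, thread_member).max_intra_op_parallelism = 1
# a dataset would then pick these up via ds = ds.with_options(options)
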
@@ -154,22 +199,25 @@ class ParserTfds(Parser):
# see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
ds = ds.repeat() # allow wrap around and break iteration manually
if self.is_training:
ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=self.worker_seed)
ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE))
ds = ds.shuffle(min(self.num_samples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed)
ds = ds.prefetch(min(self.num_samples // self.global_num_workers, self.prefetch_size))
self.ds = tfds.as_numpy(ds)
def __iter__(self):
if self.ds is None:
self._lazy_init()
# compute a rounded up sample count that is used to:
# Compute a rounded up sample count that is used to:
# 1. make batches even across workers & replicas in distributed validation.
# This adds extra samples and will slightly alter validation results.
# 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
# batches are produced (underlying tfds iter wraps around)
target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self._num_pipelines)
target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self.global_num_workers)
if self.is_training:
# round up to nearest batch_size per worker-replica
target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size
# Iterate until exhausted or sample count hits target when training (ds.repeat enabled)
sample_count = 0
for sample in self.ds:
img = Image.fromarray(sample['image'], mode='RGB')
@@ -180,21 +228,17 @@ class ParserTfds(Parser):
# this results in extra samples per epoch but seems more desirable than dropping
# up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes)
break
if not self.is_training and self.dist_num_replicas and 0 < sample_count < target_sample_count:
# Pad across distributed nodes (make counts equal by adding samples)
if not self.is_training and self.dist_num_replicas > 1 and self.subsplit is not None and \
0 < sample_count < target_sample_count:
# Validation batch padding is only done for distributed training, where results are reduced across nodes.
# For the single process case, it won't matter if workers return different batch sizes.
# FIXME if using input_context or % based subsplits, sample count can vary by more than +/- 1 and this
# approach is not optimal
yield img, sample['label'] # yield prev sample again
sample_count += 1
@property
def _num_workers(self):
return 1 if self.worker_info is None else self.worker_info.num_workers
@property
def _num_pipelines(self):
return self._num_workers * self.dist_num_replicas
# If using input_context or % based splits, sample count can vary significantly across workers and this
# approach should not be used (hence disabled if self.subsplit isn't set).
while sample_count < target_sample_count:
yield img, sample['label'] # yield prev sample again
sample_count += 1
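
As a concrete illustration of the sample accounting above (all numbers invented): with 50000 samples split across 4 replicas x 2 workers, each pipeline targets ceil(50000 / 8) = 6250 samples, rounded up to 6272 with batch_size=64 in training, while __len__ reports ceil(50000 / 4) = 12500 per replica.

# invented numbers to illustrate the per-pipeline sample accounting above
import math

num_samples = 50000                                     # e.g. a validation split
dist_num_replicas, num_workers = 4, 2
global_num_workers = dist_num_replicas * num_workers    # 8 pipelines
repeats, batch_size = 0, 64

target = math.ceil(max(1, repeats) * num_samples / global_num_workers)        # 6250 per pipeline
target_training = math.ceil(target / batch_size) * batch_size                 # 6272, full batches only
len_estimate = math.ceil(max(1, repeats) * num_samples / dist_num_replicas)   # 12500 per replica
print(target, target_training, len_estimate)
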
def __len__(self):
# this is just an estimate and does not factor in extra samples added to pad batches based on
@@ -202,7 +246,7 @@ class ParserTfds(Parser):
return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas)
def _filename(self, index, basename=False, absolute=False):
assert False, "Not supported" # no random access to samples
assert False, "Not supported" # no random access to samples
def filenames(self, basename=False, absolute=False):
""" Return all filenames in dataset, overrides base"""
