@@ -27,36 +27,44 @@ except ImportError as e:
    exit(1)
from .parser import Parser
from timm.bits import get_global_device
from timm.bits import get_global_device, is_global_device
MAX_TP_SIZE = 8  # maximum TF threadpool size, only doing jpeg decodes and queuing activities
SHUFFLE_SIZE = 16384  # samples to shuffle in DS queue
PREFETCH_SIZE = 2048  # samples to prefetch
SHUFFLE_SIZE = 8192  # examples to shuffle in DS queue
PREFETCH_SIZE = 2048  # examples to prefetch
def even_split_indices(split, n, num_samples):
    partitions = [round(i * num_samples / n) for i in range(n + 1)]
def even_split_indices(split, n, num_examples):
    partitions = [round(i * num_examples / n) for i in range(n + 1)]
    return [f"{split}[{partitions[i]}:{partitions[i + 1]}]" for i in range(n)]
def get_class_labels(info):
    if 'label' not in info.features:
        return {}
    class_label = info.features['label']
    class_to_idx = {n: class_label.str2int(n) for n in class_label.names}
    return class_to_idx
class ParserTfds(Parser):
    """ Wrap Tensorflow Datasets for use in PyTorch
    There are several things to be aware of:
      * To prevent excessive samples being dropped per epoch w/ distributed training or multiplicity of
      * To prevent excessive examples being dropped per epoch w/ distributed training or multiplicity of
        dataloader workers, the train iterator wraps to avoid returning partial batches that trigger drop_last
        https://github.com/pytorch/pytorch/issues/33413
      * With PyTorch IterableDatasets, each worker in each replica operates in isolation; the final batch
        from each worker could be a different size. For training this is worked around by the option above; for
        validation, extra samples are inserted iff distributed mode is enabled so that the batches being reduced
        validation, extra examples are inserted iff distributed mode is enabled so that the batches being reduced
        across replicas are of the same size. This will slightly alter the results; distributed validation will not be
        100% correct. This is similar to the common handling in DistributedSampler for normal Datasets, but a bit worse
        since there are up to N * J extra samples with IterableDatasets.
        since there are up to N * J extra examples with IterableDatasets.
      * The sharding (splitting of the dataset into TFRecord files) imposes limitations on the number of
        replicas and dataloader workers you can use. For really small datasets that only contain a few shards
        you may have to train non-distributed w/ 1-2 dataloader workers. This is likely not a huge concern as the
        benefit of distributed training or fast dataloading should be much less for small datasets.
      * This wrapper is currently configured to return individual, decompressed image samples from the TFDS
      * This wrapper is currently configured to return individual, decompressed image examples from the TFDS
        dataset. The augmentation (transforms) and batching is still done in PyTorch. It would be possible
        to specify a TF augmentation fn and return augmented batches w/ some modifications to other downstream
        components.
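As a quick illustration of the even_split_indices helper changed above, a worked example with hypothetical numbers: splitting a 10-example 'validation' split across n=4 workers gives partitions [0, 2, 5, 8, 10] (Python's round() uses banker's rounding on the .5 cases), so the returned sub-split strings are:

even_split_indices('validation', 4, 10)
# -> ['validation[0:2]', 'validation[2:5]', 'validation[5:8]', 'validation[8:10]']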
@@ -72,6 +80,10 @@ class ParserTfds(Parser):
            download=False,
            repeats=0,
            seed=42,
            input_name='image',
            input_image='RGB',
            target_name='label',
            target_image='',
            prefetch_size=None,
            shuffle_size=None,
            max_threadpool_size=None
@@ -83,10 +95,13 @@ class ParserTfds(Parser):
            name: tfds dataset name (eg `imagenet2012`)
            split: tfds dataset split (can use all TFDS split strings eg `train[:10%]`)
            is_training: training mode, shuffle enabled, dataset len rounded by batch_size
            batch_size: batch_size to use to ensure total samples % batch_size == 0 in training across all distributed nodes
            batch_size: batch_size to use to ensure total examples % batch_size == 0 in training across all distributed nodes
            download: download and build TFDS dataset if set, otherwise must use tfds CLI
            repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1)
            seed: common seed for shard shuffle across all distributed/worker instances
            input_image: image mode if input is an image (currently PIL mode string)
            target_name: name of Feature to return as target (label)
            target_image: image mode if target is an image (currently PIL mode string)
            prefetch_size: override default tf.data prefetch buffer size
            shuffle_size: override default tf.data shuffle buffer size
            max_threadpool_size: override default threadpool size for tf.data
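To make the new input/target arguments concrete, a minimal construction sketch (the root path and dataset name are illustrative, and the keyword values simply mirror the defaults above): with target_image left empty the target is returned as-is (e.g. an integer label), while setting it to a PIL mode string would decode an image target the same way the input is decoded:

parser = ParserTfds(
    root='/data/tfds',          # illustrative TFDS data dir
    name='imagenet2012',
    split='train',
    is_training=True,
    batch_size=256,
    input_name='image', input_image='RGB',
    target_name='label', target_image='',
)
for input_data, target_data in parser:
    ...  # PIL 'RGB' image, integer class label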
@@ -101,25 +116,39 @@ class ParserTfds(Parser):
        self.batch_size = batch_size
        self.repeats = repeats
        self.common_seed = seed  # a seed that's fixed across all worker/distributed instances
        # Performance settings
        self.prefetch_size = prefetch_size or PREFETCH_SIZE
        self.shuffle_size = shuffle_size or SHUFFLE_SIZE
        self.max_threadpool_size = max_threadpool_size or MAX_TP_SIZE
        # TFDS builder and split information
        self.input_name = input_name  # FIXME support tuples / lists of inputs and targets and full range of Feature
        self.input_image = input_image
        self.target_name = target_name
        self.target_image = target_image
        self.builder = tfds.builder(name, data_dir=root)
        # NOTE: the tfds command line app can be used to download & prepare datasets if you don't enable the download flag
        if download:
            self.builder.download_and_prepare()
        self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
        self.split_info = self.builder.info.splits[split]
        self.num_samples = self.split_info.num_examples
        self.num_examples = self.split_info.num_examples
        # Distributed world state
        self.dist_rank = 0
        self.dist_num_replicas = 1
        dev_env = get_global_device()  # FIXME allow to work without devenv usage?
        if dev_env.distributed and dev_env.world_size > 1:
            self.dist_rank = dev_env.global_rank
            self.dist_num_replicas = dev_env.world_size
        if is_global_device():
            dev_env = get_global_device()
            if dev_env.distributed and dev_env.world_size > 1:
                self.dist_rank = dev_env.global_rank
                self.dist_num_replicas = dev_env.world_size
        else:
            # FIXME warn if we fallback to torch distributed?
            import torch.distributed as dist
            if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
                self.dist_rank = dist.get_rank()
                self.dist_num_replicas = dist.get_world_size()
        # Attributes that are updated in _lazy_init, including the tf.data pipeline itself
        self.global_num_workers = 1
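The is_global_device() fallback added above means the parser can now pick up its distributed world state from a plain torch.distributed setup when timm's device environment isn't initialized. A minimal sketch, assuming the default process group is initialized (e.g. under torchrun with the usual env vars set); the root path and dataset name are again illustrative:

import torch.distributed as dist

dist.init_process_group(backend='nccl')
parser = ParserTfds(root='/data/tfds', name='imagenet2012', split='validation', batch_size=128)
# parser.dist_rank == dist.get_rank() and parser.dist_num_replicas == dist.get_world_size()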
@@ -159,17 +188,17 @@ class ParserTfds(Parser):
        I am currently using a mix of InputContext shard assignment and fine-grained sub-splits for distributing
        the data across workers. For training, InputContext is used to assign shards to nodes unless num_shards
        in dataset < total number of workers. Otherwise the sub-split API is used for datasets without enough shards or
        for validation where we can't drop samples and need to minimize uneven splits to avoid padding.
        for validation where we can't drop examples and need to minimize uneven splits to avoid padding.
        """
        should_subsplit = self.global_num_workers > 1 and (
            self.split_info.num_shards < self.global_num_workers or not self.is_training)
        if should_subsplit:
            # split the dataset w/o using sharding for more even samples / worker, can result in less optimal
            # split the dataset w/o using sharding for more even examples / worker, can result in less optimal
            # read patterns for distributed training (overlap across shards) so better to use InputContext there
            if has_buggy_even_splits:
                # my even_split workaround doesn't work on subsplits, upgrade tfds!
                if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo):
                    subsplits = even_split_indices(self.split, self.global_num_workers, self.num_samples)
                    subsplits = even_split_indices(self.split, self.global_num_workers, self.num_examples)
                    self.subsplit = subsplits[global_worker_id]
            else:
                subsplits = tfds.even_splits(self.split, self.global_num_workers)
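For a sense of when the sub-split path above triggers, an illustrative scenario (the worker counts are hypothetical, and it assumes global_num_workers works out to replicas x dataloader workers): a dataset prepared as 4 TFRecord shards, read by 2 replicas with 4 dataloader workers each, gives 8 global workers for only 4 shards, so should_subsplit is True and each worker reads one even sub-split instead of whole shards:

# num_shards = 4, global_num_workers = 8  ->  should_subsplit is True
subsplits = tfds.even_splits('train', 8)   # 8 equally sized sub-split specs (recent tfds)
# each global worker then reads only subsplits[global_worker_id]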
@@ -200,8 +229,8 @@
        # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
        ds = ds.repeat()  # allow wrap around and break iteration manually
        if self.is_training:
            ds = ds.shuffle(min(self.num_samples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed)
        ds = ds.prefetch(min(self.num_samples // self.global_num_workers, self.prefetch_size))
            ds = ds.shuffle(min(self.num_examples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed)
        ds = ds.prefetch(min(self.num_examples // self.global_num_workers, self.prefetch_size))
        self.ds = tfds.as_numpy(ds)
    def __iter__(self):
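The shuffle/prefetch sizing above splits a global budget across workers rather than giving every worker the full buffer. With illustrative numbers (imagenet2012's 1,281,167 training examples, the new SHUFFLE_SIZE of 8192, and 16 global workers):

# per-worker shuffle buffer:  min(1_281_167, 8192) // 16 = 512 examples
# per-worker prefetch buffer: min(1_281_167 // 16, 2048) = 2048 examples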
@@ -210,44 +239,49 @@
        # Compute a rounded up sample count that is used to:
        # 1. make batches even across workers & replicas in distributed validation.
        #    This adds extra samples and will slightly alter validation results.
        #    This adds extra examples and will slightly alter validation results.
        # 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
        #    batches are produced (underlying tfds iter wraps around)
        target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self.global_num_workers)
        target_example_count = math.ceil(max(1, self.repeats) * self.num_examples / self.global_num_workers)
        if self.is_training:
            # round up to nearest batch_size per worker-replica
            target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size
            target_example_count = math.ceil(target_example_count / self.batch_size) * self.batch_size
        # Iterate until exhausted or sample count hits target when training (ds.repeat enabled)
        sample_count = 0
        for sample in self.ds:
            img = Image.fromarray(sample['image'], mode='RGB')
            yield img, sample['label']
            sample_count += 1
            if self.is_training and sample_count >= target_sample_count:
        example_count = 0
        for example in self.ds:
            input_data = example[self.input_name]
            if self.input_image:
                input_data = Image.fromarray(input_data, mode=self.input_image)
            target_data = example[self.target_name]
            if self.target_image:
                target_data = Image.fromarray(target_data, mode=self.target_image)
            yield input_data, target_data
            example_count += 1
            if self.is_training and example_count >= target_example_count:
                # Need to break out of loop when repeat() is enabled for training w/ oversampling
                # this results in extra samples per epoch but seems more desirable than dropping
                # this results in extra examples per epoch but seems more desirable than dropping
                # up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes)
                break
        # Pad across distributed nodes (make counts equal by adding samples)
        # Pad across distributed nodes (make counts equal by adding examples)
        if not self.is_training and self.dist_num_replicas > 1 and self.subsplit is not None and \
                0 < sample_count < target_sample_count:
                0 < example_count < target_example_count:
            # Validation batch padding only done for distributed training where results are reduced across nodes.
            # For single process case, it won't matter if workers return different batch sizes.
            # If using input_context or % based splits, sample count can vary significantly across workers and this
            # approach should not be used (hence disabled if self.subsplit isn't set).
            while sample_count < target_sample_count:
                yield img, sample['label']  # yield prev sample again
                sample_count += 1
            while example_count < target_example_count:
                yield input_data, target_data  # yield prev sample again
                example_count += 1
    def __len__(self):
        # this is just an estimate and does not factor in extra samples added to pad batches based on
        # this is just an estimate and does not factor in extra examples added to pad batches based on
        # complete worker & replica info (not available until init in dataloader).
        return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas)
        return math.ceil(max(1, self.repeats) * self.num_examples / self.dist_num_replicas)
    def _filename(self, index, basename=False, absolute=False):
        assert False, "Not supported"  # no random access to samples
        assert False, "Not supported"  # no random access to examples
    def filenames(self, basename=False, absolute=False):
        """ Return all filenames in dataset, overrides base """
@@ -255,7 +289,7 @@ class ParserTfds(Parser):
            self._lazy_init()
        names = []
        for sample in self.ds:
            if len(names) > self.num_samples:
            if len(names) > self.num_examples:
                break  # safety for ds.repeat() case
            if 'file_name' in sample:
                name = sample['file_name']