@@ -57,23 +57,28 @@ class ParserTfds(Parser):
     components.
     """
-    def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None, repeats=0):
+    def __init__(
+            self, root, name, split='train', is_training=False, batch_size=None,
+            download=False, repeats=0, seed=42):
         super().__init__()
         self.root = root
         self.split = split
-        self.shuffle = shuffle
         self.is_training = is_training
         if self.is_training:
             assert batch_size is not None, \
                 "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
         self.batch_size = batch_size
         self.repeats = repeats
+        self.common_seed = seed  # seed shared across all worker / dist nodes
+        self.worker_seed = 0  # seed specific to each worker instance
         self.subsplit = None
         self.builder = tfds.builder(name, data_dir=root)
-        # NOTE: please use tfds command line app to download & prepare datasets, I don't want to call
-        # download_and_prepare() by default here as it's caused issues generating unwanted paths.
-        self.num_samples = self.builder.info.splits[split].num_examples
+        # NOTE: the tfds command line app can be used to download & prepare datasets if you don't enable the download flag
+        if download:
+            self.builder.download_and_prepare()
+        self.split_info = self.builder.info.splits[split]
+        self.num_samples = self.split_info.num_examples
         self.ds = None  # initialized lazily on each dataloader worker process
         self.worker_info = None
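For orientation, here is a minimal usage sketch of the revised constructor. The import path, dataset name, and root path are illustrative assumptions, not part of this diff:

    from timm.data.parsers.parser_tfds import ParserTfds  # module path assumed

    parser = ParserTfds(
        root='/data/tfds',   # tfds data_dir; illustrative path
        name='cifar10',      # any TFDS dataset already prepared under root
        split='train',
        is_training=True,
        batch_size=128,      # required when is_training=True (see assert above)
        download=False,      # keep False if data was prepared via the tfds CLI
        seed=42)             # stored as self.common_seed, shared across workers
    assert parser.num_samples == parser.split_info.num_examples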
@@ -97,17 +102,18 @@ class ParserTfds(Parser):
         worker_info = torch.utils.data.get_worker_info()
 
         # setup input context to split dataset across distributed processes
-        split = self.split
-        num_workers = 1
+        global_num_workers = num_workers = 1
+        global_worker_id = 0
         if worker_info is not None:
             self.worker_info = worker_info
+            self.worker_seed = worker_info.seed
             num_workers = worker_info.num_workers
             global_num_workers = self.dist_num_replicas * num_workers
             worker_id = worker_info.id
+            global_worker_id = self.dist_rank * num_workers + worker_id
 
-            # FIXME I need to spend more time figuring out the best way to distribute/split data across
-            # combo of distributed replicas + dataloader worker processes
-            """
+            # FIXME verify best sharding approach
+            """ Data sharding
             InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used.
             My understanding is that using split, the underlying TFRecord files will shuffle (shuffle_files=True)
             between the splits each iteration, but that understanding could be wrong.
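To make the id arithmetic in the hunk above concrete (the values below are made up for illustration): with 2 distributed replicas and 4 dataloader workers per replica, global_num_workers is 8 and every worker process lands on a unique pipeline index.

    # standalone sketch of the global worker id mapping computed above
    dist_num_replicas, num_workers = 2, 4
    for dist_rank in range(dist_num_replicas):
        for worker_id in range(num_workers):
            global_worker_id = dist_rank * num_workers + worker_id
            print(f'rank {dist_rank} worker {worker_id} -> global {global_worker_id}')
    # global ids 0..7: each of the 8 worker processes reads a distinct shard/sub-split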
@@ -116,44 +122,39 @@ class ParserTfds(Parser):
             * InputContext for distributed and sub-splits for worker processes
             * sub-splits for both
             """
-            # split_size = self.num_samples // num_workers
-            # start = worker_id * split_size
-            # if worker_id == num_workers - 1:
-            #     split = split + '[{}:]'.format(start)
-            # else:
-            #     split = split + '[{}:{}]'.format(start, start + split_size)
-
-        if not self.is_training and '[' not in self.split:
-            # If not training, and split doesn't define a subsplit, manually split the dataset
-            # for more even samples / worker
-            self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[
-                self.dist_rank * num_workers + worker_id]
+            can_subsplit = '[' not in self.split  # can't subsplit a subsplit
+            should_subsplit = global_num_workers > 1 and (
+                self.split_info.num_shards < global_num_workers or not self.is_training)
+            if can_subsplit and should_subsplit:
+                # manually split the dataset w/o sharding for more even samples / worker
+                self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[global_worker_id]
 
-        if self.subsplit is None:
+        input_context = None
+        if global_num_workers > 1 and self.subsplit is None:
+            # set input context to divide shards among distributed replicas
             input_context = tf.distribute.InputContext(
-                num_input_pipelines=self.dist_num_replicas * num_workers,
-                input_pipeline_id=self.dist_rank * num_workers + worker_id,
+                num_input_pipelines=global_num_workers,
+                input_pipeline_id=global_worker_id,
                 num_replicas_in_sync=self.dist_num_replicas  # FIXME does this arg have any impact?
             )
-        else:
-            input_context = None
 
         read_config = tfds.ReadConfig(
-            shuffle_seed=42,
+            shuffle_seed=self.common_seed,
             shuffle_reshuffle_each_iteration=True,
             input_context=input_context)
         ds = self.builder.as_dataset(
-            split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config)
+            split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config)
         # avoid overloading threading w/ combo of TF ds threads + PyTorch workers
         options = tf.data.Options()
-        options.experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
-        options.experimental_threading.max_intra_op_parallelism = 1
+        thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading'
+        getattr(options, thread_member).private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
+        getattr(options, thread_member).max_intra_op_parallelism = 1
         ds = ds.with_options(options)
         if self.is_training or self.repeats > 1:
             # to prevent excessive drop_last batch behaviour w/ IterableDatasets
             # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
             ds = ds.repeat()  # allow wrap around and break iteration manually
-        if self.shuffle:
-            ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=0)
+        if self.is_training:
+            ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=self.worker_seed)
         ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE))
         self.ds = tfds.as_numpy(ds)
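even_split_indices itself isn't shown in this diff; the following is a hedged sketch of a helper consistent with the call sites above, producing TFDS absolute-index sub-split strings that partition [0, num_samples) as evenly as possible:

    # sketch only: signature and behaviour inferred from the call sites above,
    # not copied from the repo
    def even_split_indices(split, n, num_samples):
        partitions = [round(i * num_samples / n) for i in range(n + 1)]
        return [f'{split}[{partitions[i]}:{partitions[i + 1]}]' for i in range(n)]

    even_split_indices('validation', 4, 50000)
    # -> ['validation[0:12500]', 'validation[12500:25000]',
    #     'validation[25000:37500]', 'validation[37500:50000]']

With that partitioning, each of the global_num_workers processes iterates a disjoint, near-equal slice, which is what keeps per-worker sample counts even for validation where samples can't be dropped.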