Add webdataset (WDS) support, update TFDS to make some naming in parsers more similar. Fix workers=0 compatibility. Add ImageNet22k/12k synset defs.

pull/1239/head
Ross Wightman 3 years ago
parent 3fce010ca8
commit da2796ae82

File diff suppressed because it is too large

File diff suppressed because it is too large

@@ -69,6 +69,7 @@ def create_dataset(
     * folder - default, timm folder (or tar) based ImageDataset
     * torch - torchvision based datasets
     * TFDS - Tensorflow-datasets wrapper in IterabeDataset interface via IterableImageDataset
+    * WDS - Webdataset
     * all - any of the above
 
     Args:
@@ -134,6 +135,10 @@ def create_dataset(
         ds = IterableImageDataset(
             root, parser=name, split=split, is_training=is_training,
             download=download, batch_size=batch_size, repeats=repeats, **kwargs)
+    elif name.startswith('wds/'):
+        ds = IterableImageDataset(
+            root, parser=name, split=split, is_training=is_training,
+            batch_size=batch_size, repeats=repeats, **kwargs)
     else:
         # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
         if search_split and os.path.isdir(root):
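For reference, a minimal usage sketch of the new `wds/` path through `create_dataset`; the dataset name and root below are hypothetical, assuming a local directory of .tar shards plus an info.json/info.yaml as read by the new webdataset parser:

    from timm.data import create_dataset

    # hypothetical ImageNet-style webdataset layout under /data/imagenet-wds
    ds = create_dataset(
        'wds/imagenet', root='/data/imagenet-wds',
        split='train', is_training=True, batch_size=256)
    img, label = next(iter(ds))  # IterableImageDataset yields (image, target) pairs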

@@ -18,6 +18,10 @@ def create_parser(name, root, split='train', **kwargs):
     if prefix == 'tfds':
         from .parser_tfds import ParserTfds  # defer tensorflow import
         parser = ParserTfds(root, name, split=split, **kwargs)
+    elif prefix == 'wds':
+        from .parser_wds import ParserWebdataset
+        kwargs.pop('download', False)
+        parser = ParserWebdataset(root, name, split=split, **kwargs)
     else:
         assert os.path.exists(root)
         # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder

@@ -34,12 +34,12 @@ from .parser import Parser
 from timm.bits import get_global_device, is_global_device
 
 MAX_TP_SIZE = 8  # maximum TF threadpool size, only doing jpeg decodes and queuing activities
-SHUFFLE_SIZE = 8192  # examples to shuffle in DS queue
-PREFETCH_SIZE = 2048  # examples to prefetch
+SHUFFLE_SIZE = 8192  # number of samples to shuffle in DS queue
+PREFETCH_SIZE = 2048  # number of samples to prefetch
 
 
-def even_split_indices(split, n, num_examples):
-    partitions = [round(i * num_examples / n) for i in range(n + 1)]
+def even_split_indices(split, n, num_samples):
+    partitions = [round(i * num_samples / n) for i in range(n + 1)]
     return [f"{split}[{partitions[i]}:{partitions[i + 1]}]" for i in range(n)]
@@ -55,20 +55,20 @@ class ParserTfds(Parser):
     """ Wrap Tensorflow Datasets for use in PyTorch
 
     There several things to be aware of:
-    * To prevent excessive examples being dropped per epoch w/ distributed training or multiplicity of
+    * To prevent excessive samples being dropped per epoch w/ distributed training or multiplicity of
      dataloader workers, the train iterator wraps to avoid returning partial batches that trigger drop_last
      https://github.com/pytorch/pytorch/issues/33413
    * With PyTorch IterableDatasets, each worker in each replica operates in isolation, the final batch
      from each worker could be a different size. For training this is worked around by option above, for
-     validation extra examples are inserted iff distributed mode is enabled so that the batches being reduced
+     validation extra samples are inserted iff distributed mode is enabled so that the batches being reduced
      across replicas are of same size. This will slightly alter the results, distributed validation will not be
      100% correct. This is similar to common handling in DistributedSampler for normal Datasets but a bit worse
-     since there are up to N * J extra examples with IterableDatasets.
+     since there are up to N * J extra samples with IterableDatasets.
    * The sharding (splitting of dataset into TFRecord) files imposes limitations on the number of
      replicas and dataloader workers you can use. For really small datasets that only contain a few shards
      you may have to train non-distributed w/ 1-2 dataloader workers. This is likely not a huge concern as the
      benefit of distributed training or fast dataloading should be much less for small datasets.
-   * This wrapper is currently configured to return individual, decompressed image examples from the TFDS
+   * This wrapper is currently configured to return individual, decompressed image samples from the TFDS
      dataset. The augmentation (transforms) and batching is still done in PyTorch. It would be possible
      to specify TF augmentation fn and return augmented batches w/ some modifications to other downstream
      components.
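To put the N * J bound in concrete terms, a tiny sketch with illustrative numbers (not from the diff):

    N, J = 8, 4        # 8 distributed replicas, 4 dataloader workers each
    max_extra = N * J  # up to 32 padded samples folded into distributed validation metrics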
@@ -100,7 +100,7 @@ class ParserTfds(Parser):
             name: tfds dataset name (eg `imagenet2012`)
             split: tfds dataset split (can use all TFDS split strings eg `train[:10%]`)
             is_training: training mode, shuffle enabled, dataset len rounded by batch_size
-            batch_size: batch_size to use to unsure total examples % batch_size == 0 in training across all dis nodes
+            batch_size: batch_size to use to unsure total samples % batch_size == 0 in training across all dis nodes
             download: download and build TFDS dataset if set, otherwise must use tfds CLI
             repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1)
             seed: common seed for shard shuffle across all distributed/worker instances
@@ -139,7 +139,7 @@ class ParserTfds(Parser):
             self.builder.download_and_prepare()
         self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
         self.split_info = self.builder.info.splits[split]
-        self.num_examples = self.split_info.num_examples
+        self.num_samples = self.split_info.num_examples
 
         # Distributed world state
         self.dist_rank = 0
@@ -157,13 +157,18 @@ class ParserTfds(Parser):
                 self.dist_num_replicas = dist.get_world_size()
 
         # Attributes that are updated in _lazy_init, including the tf.data pipeline itself
-        self.global_num_workers = 1
-        self.worker_info = None
+        self.worker_init = False  # worker info initialized
+        self.worker_id = 0
         self.worker_seed = 0  # seed unique to each work instance
+        self.num_workers = 1
+        self.global_worker_id = 0
+        self.global_num_workers = 1
         self.subsplit = None  # set when data is distributed across workers using sub-splits
         self.ds = None  # initialized lazily on each dataloader worker process
-        self.init_count = 0
-        self.reinit_each_iter = self.is_training  # FIXME need to determine if this is necessary
+        self.init_count = 0  # number of ds TF data pipeline initializations
+        # FIXME need to determine if reinit_each_iter is necessary. I'm don't completely trust behaviour
+        # of `shuffle_reshuffle_each_iteration` when there are multiple workers / nodes across epochs
+        self.reinit_each_iter = self.is_training
 
     def _lazy_init(self):
         """ Lazily initialize the dataset.
@@ -177,14 +182,15 @@ class ParserTfds(Parser):
         before it is passed to dataloader.
         """
         # setup input context to split dataset across distributed processes
-        if self.worker_info is None:
+        if not self.worker_init:
+            # worker init done once, even if data-pipeline is re-initialized
             worker_info = torch.utils.data.get_worker_info()
-            assert worker_info is not None
-            self.worker_info = worker_info
-            self.worker_seed = worker_info.seed
-            num_workers = worker_info.num_workers
-            self.global_num_workers = self.dist_num_replicas * num_workers
-            global_worker_id = self.dist_rank * num_workers + worker_info.id
+            if worker_info is not None:
+                self.worker_id = worker_info.id
+                self.worker_seed = worker_info.seed
+                self.num_workers = worker_info.num_workers
+            self.global_worker_id = self.dist_rank * self.num_workers + self.worker_id
+            self.global_num_workers = self.dist_num_replicas * self.num_workers
 
             """ Data sharding
             InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used.
@@ -194,54 +200,59 @@ class ParserTfds(Parser):
             I am currently using a mix of InputContext shard assignment and fine-grained sub-splits for distributing
             the data across workers. For training InputContext is used to assign shards to nodes unless num_shards
             in dataset < total number of workers. Otherwise sub-split API is used for datasets without enough shards or
-            for validation where we can't drop examples and need to avoid minimize uneven splits to avoid padding.
+            for validation where we can't drop samples and need to avoid minimize uneven splits to avoid padding.
             """
             should_subsplit = self.global_num_workers > 1 and (
                 self.split_info.num_shards < self.global_num_workers or not self.is_training)
             if should_subsplit:
-                # split the dataset w/o using sharding for more even examples / worker, can result in less optimal
+                # split the dataset w/o using sharding for more even samples / worker, can result in less optimal
                 # read patterns for distributed training (overlap across shards) so better to use InputContext there
                 if has_buggy_even_splits:
                     # my even_split workaround doesn't work on subsplits, upgrade tfds!
                     if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo):
-                        subsplits = even_split_indices(self.split, self.global_num_workers, self.num_examples)
-                        self.subsplit = subsplits[global_worker_id]
+                        subsplits = even_split_indices(self.split, self.global_num_workers, self.num_samples)
+                        self.subsplit = subsplits[self.global_worker_id]
                 else:
                     subsplits = tfds.even_splits(self.split, self.global_num_workers)
-                    self.subsplit = subsplits[global_worker_id]
-        else:
-            num_workers = self.worker_info.num_workers
-            global_worker_id = self.dist_rank * num_workers + self.worker_info.id
+                    self.subsplit = subsplits[self.global_worker_id]
+
+            self.worker_init = True
 
+        # initialize TF data pipeline
         input_context = None
         if self.global_num_workers > 1 and self.subsplit is None:
             # set input context to divide shards among distributed replicas
             input_context = tf.distribute.InputContext(
                 num_input_pipelines=self.global_num_workers,
-                input_pipeline_id=global_worker_id,
+                input_pipeline_id=self.global_worker_id,
                 num_replicas_in_sync=self.dist_num_replicas  # FIXME does this arg have any impact?
             )
         read_config = tfds.ReadConfig(
-            shuffle_seed=self.common_seed + self.init_count,
-            shuffle_reshuffle_each_iteration=not self.reinit_each_iter,
+            shuffle_seed=self.common_seed + self.init_count,  # shard shuffling seed
+            shuffle_reshuffle_each_iteration=not self.reinit_each_iter,  # re-shuffle shards per iteration
             input_context=input_context)
         ds = self.builder.as_dataset(
-            split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config)
+            split=self.subsplit or self.split,
+            shuffle_files=self.is_training,  # enable shard shuffling
+            read_config=read_config)
         # avoid overloading threading w/ combo of TF ds threads + PyTorch workers
         options = tf.data.Options()
         thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading'
-        getattr(options, thread_member).private_threadpool_size = max(1, self.max_threadpool_size // num_workers)
+        getattr(options, thread_member).private_threadpool_size = max(1, self.max_threadpool_size // self.num_workers)
         getattr(options, thread_member).max_intra_op_parallelism = 1
         ds = ds.with_options(options)
         if self.is_training or self.repeats > 1:
             # to prevent excessive drop_last batch behaviour w/ IterableDatasets
             # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
             ds = ds.repeat()  # allow wrap around and break iteration manually
         if self.is_training:
+            # shuffle samples
             ds = ds.shuffle(
-                min(self.num_examples, self.shuffle_size) // self.global_num_workers,
+                min(self.num_samples, self.shuffle_size) // self.global_num_workers,
                 seed=self.worker_seed + self.init_count)
-        ds = ds.prefetch(min(self.num_examples // self.global_num_workers, self.prefetch_size))
+        ds = ds.prefetch(min(self.num_samples // self.global_num_workers, self.prefetch_size))
         self.ds = tfds.as_numpy(ds)
         self.init_count += 1
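For reference, the worker bookkeeping used above reduces to simple index arithmetic; a sketch with illustrative values:

    dist_rank, dist_num_replicas = 1, 4                      # distributed world (illustrative)
    worker_id, num_workers = 2, 3                            # dataloader workers per process
    global_num_workers = dist_num_replicas * num_workers     # 12 input pipelines in total
    global_worker_id = dist_rank * num_workers + worker_id   # this worker is pipeline 5 of 12
    # with InputContext, TFRecord shards are divided among the 12 pipelines; otherwise
    # even_splits / even_split_indices carves the split into 12 sub-splits, one per pipeline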
@@ -251,10 +262,10 @@ class ParserTfds(Parser):
         # Compute a rounded up sample count that is used to:
         # 1. make batches even cross workers & replicas in distributed validation.
-        #  This adds extra examples and will slightly alter validation results.
+        #  This adds extra samples and will slightly alter validation results.
         # 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
         #  batches are produced (underlying tfds iter wraps around)
-        target_example_count = math.ceil(max(1, self.repeats) * self.num_examples / self.global_num_workers)
+        target_example_count = math.ceil(max(1, self.repeats) * self.num_samples / self.global_num_workers)
         if self.is_training:
             # round up to nearest batch_size per worker-replica
             target_example_count = math.ceil(target_example_count / self.batch_size) * self.batch_size
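To make the rounding concrete, a worked example with illustrative numbers (not taken from the diff):

    import math
    num_samples, global_num_workers, batch_size, repeats = 1281167, 32, 256, 0  # illustrative
    target = math.ceil(max(1, repeats) * num_samples / global_num_workers)      # 40037
    target = math.ceil(target / batch_size) * batch_size                        # 40192 (157 * 256)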
@@ -272,11 +283,11 @@ class ParserTfds(Parser):
             example_count += 1
             if self.is_training and example_count >= target_example_count:
                 # Need to break out of loop when repeat() is enabled for training w/ oversampling
-                # this results in extra examples per epoch but seems more desirable than dropping
+                # this results in extra samples per epoch but seems more desirable than dropping
                 # up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes)
                 break
 
-        # Pad across distributed nodes (make counts equal by adding examples)
+        # Pad across distributed nodes (make counts equal by adding samples)
         if not self.is_training and self.dist_num_replicas > 1 and self.subsplit is not None and \
                 0 < example_count < target_example_count:
             # Validation batch padding only done for distributed training where results are reduced across nodes.
@@ -288,12 +299,12 @@ class ParserTfds(Parser):
                 example_count += 1
 
     def __len__(self):
-        # this is just an estimate and does not factor in extra examples added to pad batches based on
+        # this is just an estimate and does not factor in extra samples added to pad batches based on
         # complete worker & replica info (not available until init in dataloader).
-        return math.ceil(max(1, self.repeats) * self.num_examples / self.dist_num_replicas)
+        return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas)
 
     def _filename(self, index, basename=False, absolute=False):
-        assert False, "Not supported"  # no random access to examples
+        assert False, "Not supported"  # no random access to samples
 
     def filenames(self, basename=False, absolute=False):
         """ Return all filenames in dataset, overrides base"""
@@ -301,7 +312,7 @@ class ParserTfds(Parser):
             self._lazy_init()
         names = []
         for sample in self.ds:
-            if len(names) > self.num_examples:
+            if len(names) >= self.num_samples:
                 break  # safety for ds.repeat() case
             if 'file_name' in sample:
                 name = sample['file_name']

@@ -0,0 +1,261 @@
""" Dataset parser interface for webdataset

Hacked together by / Copyright 2022 Ross Wightman
"""
import math
import os
import io
import json
import yaml
import random
from dataclasses import dataclass
from itertools import islice
from functools import partial
from typing import Dict, Tuple

import torch
from PIL import Image

try:
    import webdataset as wds
    from webdataset.shardlists import expand_urls
except ImportError:
    wds = None
    expand_urls = None

from .parser import Parser
from timm.bits import get_global_device, is_global_device

SHUFFLE_SIZE = 8192
def _load_info(root, basename='info'):
    info_json = os.path.join(root, basename + '.json')
    info_yaml = os.path.join(root, basename + '.yaml')
    info_dict = {}
    if os.path.exists(info_json):
        with open(info_json, 'r') as f:
            info_dict = json.load(f)
    elif os.path.exists(info_yaml):
        with open(info_yaml, 'r') as f:
            info_dict = yaml.safe_load(f)
    return info_dict
@dataclass
class SplitInfo:
    num_samples: int
    filenames: Tuple[str]
    shard_lengths: Tuple[int] = ()
    name: str = ''
def _parse_split_info(split: str, info: Dict):
    def _info_convert(dict_info):
        return SplitInfo(
            num_samples=dict_info['num_samples'],
            filenames=tuple(dict_info['filenames']),
            shard_lengths=tuple(dict_info['shard_lengths']),
            name=dict_info['name'],
        )

    if 'tar' in split or '..' in split:
        # split in WDS string braceexpand format, sample count can be included with a | separator
        # ex: `dataset-split-{0000..9999}.tar|100000` for 10,000 shards, covering 100,000 samples
        split = split.split('|')
        num_samples = 0
        split_name = ''
        if len(split) > 1:
            num_samples = int(split[1])
        split = split[0]
        if '::' not in split:
            split_parts = split.split('-', 3)
            split_idx = len(split_parts) - 1
            if split_idx and 'splits' in info and split_parts[split_idx] in info['splits']:
                split_name = split_parts[split_idx]

        split_filenames = expand_urls(split)
        if split_name:
            split_info = info['splits'][split_name]
            if not num_samples:
                _fc = {f: c for f, c in zip(split_info['filenames'], split_info['shard_lengths'])}
                num_samples = sum(_fc[f] for f in split_filenames)
                split_info['filenames'] = tuple(_fc.keys())
                split_info['shard_lengths'] = tuple(_fc.values())
                split_info['num_samples'] = num_samples
            split_info = _info_convert(split_info)
        else:
            split_info = SplitInfo(
                name=split_name,
                num_samples=num_samples,
                filenames=split_filenames,
            )
    else:
        if split not in info['splits']:
            raise RuntimeError(f"split {split} not found in info ({info['splits'].keys()})")
        split = split
        split_info = info['splits'][split]
        split_info = _info_convert(split_info)

    return split_info
def _decode(sample, image_key='jpg', image_format='RGB', target_key='cls'):
    """ Custom sample decode
    * decode and convert PIL Image
    * cls byte string label to int
    * pass through JSON byte string (if it exists) without parse
    """
    with io.BytesIO(sample[image_key]) as b:
        img = Image.open(b)
        img.load()
    if image_format:
        img = img.convert(image_format)
    return dict(jpg=img, cls=int(sample[target_key]), json=sample.get('json', None))
class ParserWebdataset(Parser):
    def __init__(
            self,
            root,
            name,
            split,
            is_training=False,
            batch_size=None,
            repeats=0,
            seed=42,
            input_name='image',
            input_image='RGB',
            target_name=None,
            target_image='',
            prefetch_size=None,
            shuffle_size=None,
    ):
        super().__init__()
        self.root = root
        self.is_training = is_training
        self.batch_size = batch_size
        self.repeats = repeats
        self.common_seed = seed  # a seed that's fixed across all worker / distributed instances
        self.shard_shuffle_size = 500
        self.sample_shuffle_size = shuffle_size or SHUFFLE_SIZE

        self.image_key = 'jpg'
        self.image_format = input_image
        self.target_key = 'cls'
        self.filename_key = 'filename'
        self.key_ext = '.JPEG'  # extension to add to key for original filenames (DS specific, default ImageNet)

        self.info = _load_info(self.root)
        self.split_info = _parse_split_info(split, self.info)
        self.num_samples = self.split_info.num_samples
        if not self.num_samples:
            raise RuntimeError(f'Invalid split definition, no samples found.')

        # Distributed world state
        self.dist_rank = 0
        self.dist_num_replicas = 1
        if is_global_device():
            dev_env = get_global_device()
            if dev_env.distributed and dev_env.world_size > 1:
                self.dist_rank = dev_env.global_rank
                self.dist_num_replicas = dev_env.world_size
        else:
            # FIXME warn if we fallback to torch distributed?
            import torch.distributed as dist
            if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
                self.dist_rank = dist.get_rank()
                self.dist_num_replicas = dist.get_world_size()

        # Attributes that are updated in _lazy_init
        self.worker_id = 0
        self.worker_seed = seed  # seed unique to each worker instance
        self.num_workers = 1
        self.global_worker_id = 0
        self.global_num_workers = 1
        self.init_count = 0

        # DataPipeline is lazy init, majority of WDS DataPipeline could be init here, BUT, shuffle seed
        # is not handled in manner where it can be deterministic for each worker AND initialized up front
        self.ds = None
    def _lazy_init(self):
        """ Lazily initialize worker (in worker processes)
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            self.worker_id = worker_info.id
            self.worker_seed = worker_info.seed
            self.num_workers = worker_info.num_workers
        self.global_num_workers = self.dist_num_replicas * self.num_workers
        self.global_worker_id = self.dist_rank * self.num_workers + self.worker_id

        # init data pipeline
        abs_shard_filenames = [os.path.join(self.root, f) for f in self.split_info.filenames]
        pipeline = [wds.SimpleShardList(abs_shard_filenames)]
        # at this point we have an iterator over all the shards
        if self.is_training:
            pipeline.extend([
                wds.detshuffle(self.shard_shuffle_size, seed=self.common_seed),
                self._split_by_node_and_worker,
                # at this point, we have an iterator over the shards assigned to each worker
                wds.tarfile_to_samples(),
                wds.shuffle(
                    self.sample_shuffle_size,
                    rng=random.Random(self.worker_seed)),  # this is why we lazy-init whole DataPipeline
            ])
        else:
            pipeline.extend([
                self._split_by_node_and_worker,
                # at this point, we have an iterator over the shards assigned to each worker
                wds.tarfile_to_samples(),
            ])
        pipeline.extend([
            wds.map(partial(_decode, image_key=self.image_key, image_format=self.image_format))
        ])
        self.ds = wds.DataPipeline(*pipeline)
        self.init_count += 1
    def _split_by_node_and_worker(self, src):
        if self.global_num_workers > 1:
            # strided (round-robin) shard assignment: this worker takes every global_num_workers-th shard
            for s in islice(src, self.global_worker_id, None, self.global_num_workers):
                yield s
        else:
            for s in src:
                yield s
    def __iter__(self):
        if not self.init_count:
            self._lazy_init()

        i = 0
        # per-worker sample budget, rounded down to a multiple of batch_size in training
        num_worker_samples = math.ceil(self.num_samples / self.global_num_workers)
        if self.is_training and self.batch_size is not None:
            num_worker_samples = (num_worker_samples // self.batch_size) * self.batch_size
        ds = self.ds.with_epoch(num_worker_samples)
        for sample in ds:
            yield sample[self.image_key], sample[self.target_key]
            i += 1
        print('end', i)  # FIXME debug

    def __len__(self):
        return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas)
    def _filename(self, index, basename=False, absolute=False):
        assert False, "Not supported"  # no random access to examples

    def filenames(self, basename=False, absolute=False):
        """ Return all filenames in dataset, overrides base"""
        if not self.init_count:
            self._lazy_init()

        names = []
        for sample in self.ds:
            if self.filename_key in sample:
                name = sample[self.filename_key]
            elif '__key__' in sample:
                name = sample['__key__'] + self.key_ext
            else:
                assert False, "No supported name field present"
            names.append(name)
            if len(names) >= self.num_samples:
                break  # safety for ds.repeat() case
        return names
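For context, a sketch of the metadata the new parser expects to find next to the shards; the schema is inferred from _load_info / _parse_split_info and the concrete names and counts below are hypothetical:

    # contents of <root>/info.json, shown as a Python dict
    info = {
        'splits': {
            'train': {
                'name': 'train',
                'num_samples': 10000,
                'filenames': ['dataset-train-0000.tar', 'dataset-train-0001.tar'],
                'shard_lengths': [5000, 5000],
            },
        },
    }
    # alternatively, the split argument itself can carry the shards and sample count:
    #   split='dataset-train-{0000..0001}.tar|10000'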

@@ -607,7 +607,7 @@ def setup_data(args, default_cfg, dev_env: DeviceEnv, mixup_active: bool):
     )
 
     eval_workers = args.workers
-    if 'tfds' in args.dataset:
+    if 'tfds' in args.dataset or 'wds' in args.dataset:
         # FIXME reduces validation padding issues when using TFDS w/ workers and distributed training
         eval_workers = min(2, args.workers)
