""" Dataset parser interface for webdataset Hacked together by / Copyright 2022 Ross Wightman """ import io import json import logging import math import os import random import sys from dataclasses import dataclass from functools import partial from itertools import islice from typing import Dict, Tuple import torch import torch.distributed as dist import yaml from PIL import Image from torch.utils.data import Dataset, IterableDataset, get_worker_info try: import webdataset as wds from webdataset.filters import _shuffle from webdataset.shardlists import expand_urls from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample except ImportError: wds = None expand_urls = None from .parser import Parser from .shared_count import SharedCount _logger = logging.getLogger(__name__) SHUFFLE_SIZE = os.environ.get('WDS_SHUFFLE_SIZE', 8192) def _load_info(root, basename='info'): info_json = os.path.join(root, basename + '.json') info_yaml = os.path.join(root, basename + '.yaml') err_str = '' try: with wds.gopen.gopen(info_json) as f: info_dict = json.load(f) return info_dict except Exception as e: err_str = str(e) try: with wds.gopen.gopen(info_yaml) as f: info_dict = yaml.safe_load(f) return info_dict except Exception: pass _logger.warning( f'Dataset info file not found at {info_json} or {info_yaml}. Error: {err_str}. ' 'Falling back to provided split and size arg.') return {} @dataclass class SplitInfo: num_samples: int filenames: Tuple[str] shard_lengths: Tuple[int] = () alt_label: str = '' name: str = '' def _parse_split_info(split: str, info: Dict): def _info_convert(dict_info): return SplitInfo( num_samples=dict_info['num_samples'], filenames=tuple(dict_info['filenames']), shard_lengths=tuple(dict_info['shard_lengths']), alt_label=dict_info.get('alt_label', ''), name=dict_info['name'], ) if 'tar' in split or '..' in split: # split in WDS string braceexpand format, sample count can be included with a | separator # ex: `dataset-split-{0000..9999}.tar|100000` for 9999 shards, covering 100,000 samples split = split.split('|') num_samples = 0 split_name = '' if len(split) > 1: num_samples = int(split[1]) split = split[0] if '::' not in split: split_parts = split.split('-', 3) split_idx = len(split_parts) - 1 if split_idx and 'splits' in info and split_parts[split_idx] in info['splits']: split_name = split_parts[split_idx] split_filenames = expand_urls(split) if split_name: split_info = info['splits'][split_name] if not num_samples: _fc = {f: c for f, c in zip(split_info['filenames'], split_info['shard_lengths'])} num_samples = sum(_fc[f] for f in split_filenames) split_info['filenames'] = tuple(_fc.keys()) split_info['shard_lengths'] = tuple(_fc.values()) split_info['num_samples'] = num_samples split_info = _info_convert(split_info) else: split_info = SplitInfo( name=split_name, num_samples=num_samples, filenames=split_filenames, ) else: if split not in info['splits']: raise RuntimeError(f"split {split} not found in info ({info['splits'].keys()})") split = split split_info = info['splits'][split] split_info = _info_convert(split_info) return split_info def log_and_continue(exn): """Call in an exception handler to ignore any exception, isssue a warning, and continue.""" _logger.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.') return True def _decode( sample, image_key='jpg', image_format='RGB', target_key='cls', alt_label='' ): """ Custom sample decode * decode and convert PIL Image * cls byte string label to int * pass through JSON byte string (if it exists) without parse """ # decode class label, skip if alternate label not valid if alt_label: # alternative labels are encoded in json metadata meta = json.loads(sample['json']) class_label = int(meta[alt_label]) if class_label < 0: # skipped labels currently encoded as -1, may change to a null/None value return None else: class_label = int(sample[target_key]) # decode image with io.BytesIO(sample[image_key]) as b: img = Image.open(b) img.load() if image_format: img = img.convert(image_format) # json passed through in undecoded state decoded = dict(jpg=img, cls=class_label, json=sample.get('json', None)) return decoded def _decode_samples( data, image_key='jpg', image_format='RGB', target_key='cls', alt_label='', handler=log_and_continue): """Decode samples with skip.""" for sample in data: try: result = _decode( sample, image_key=image_key, image_format=image_format, target_key=target_key, alt_label=alt_label ) except Exception as exn: if handler(exn): continue else: break # null results are skipped if result is not None: if isinstance(sample, dict) and isinstance(result, dict): result["__key__"] = sample.get("__key__") yield result def pytorch_worker_seed(): """get dataloader worker seed from pytorch""" worker_info = get_worker_info() if worker_info is not None: # favour the seed already created for pytorch dataloader workers if it exists return worker_info.seed # fallback to wds rank based seed return wds.utils.pytorch_worker_seed() if wds is not None: # conditional to avoid mandatory wds import (via inheritance of wds.PipelineStage) class detshuffle2(wds.PipelineStage): def __init__( self, bufsize=1000, initial=100, seed=0, epoch=-1, ): self.bufsize = bufsize self.initial = initial self.seed = seed self.epoch = epoch def run(self, src): if isinstance(self.epoch, SharedCount): epoch = self.epoch.value else: # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train) # situation as different workers may wrap at different times (or not at all). self.epoch += 1 epoch = self.epoch if self.seed < 0: seed = pytorch_worker_seed() + epoch else: seed = self.seed + epoch _logger.info('shuffle', self.seed, epoch, seed) # FIXME temporary rng = random.Random(seed) return _shuffle(src, self.bufsize, self.initial, rng) else: detshuffle2 = None class ResampledShards2(IterableDataset): """An iterable dataset yielding a list of urls.""" def __init__( self, urls, nshards=sys.maxsize, worker_seed=None, deterministic=True, epoch=-1, ): """Sample shards from the shard list with replacement. :param urls: a list of URLs as a Python list or brace notation string """ super().__init__() urls = wds.shardlists.expand_urls(urls) self.urls = urls assert isinstance(self.urls[0], str) self.nshards = nshards self.rng = random.Random() self.worker_seed = pytorch_worker_seed if worker_seed is None else worker_seed self.deterministic = deterministic self.epoch = epoch def __iter__(self): """Return an iterator over the shards.""" if isinstance(self.epoch, SharedCount): epoch = self.epoch.value else: # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train) # situation as different workers may wrap at different times (or not at all). self.epoch += 1 epoch = self.epoch if self.deterministic: # reset seed w/ epoch if deterministic, worker seed should be deterministic due to arg.seed self.rng = random.Random(self.worker_seed() + epoch) for _ in range(self.nshards): index = self.rng.randint(0, len(self.urls) - 1) yield dict(url=self.urls[index]) class ParserWds(Parser): def __init__( self, root, name, split, is_training=False, batch_size=None, repeats=0, seed=42, input_name='jpg', input_image='RGB', target_name='cls', target_image='', prefetch_size=None, shuffle_size=None, ): super().__init__() if wds is None: raise RuntimeError( 'Please install webdataset 0.2.x package `pip install git+https://github.com/webdataset/webdataset`.') self.root = root self.is_training = is_training self.batch_size = batch_size self.repeats = repeats self.common_seed = seed # a seed that's fixed across all worker / distributed instances self.shard_shuffle_size = 500 self.sample_shuffle_size = shuffle_size or SHUFFLE_SIZE self.image_key = input_name self.image_format = input_image self.target_key = target_name self.filename_key = 'filename' self.key_ext = '.JPEG' # extension to add to key for original filenames (DS specific, default ImageNet) self.info = _load_info(self.root) self.split_info = _parse_split_info(split, self.info) self.num_samples = self.split_info.num_samples if not self.num_samples: raise RuntimeError(f'Invalid split definition, no samples found.') # Distributed world state self.dist_rank = 0 self.dist_num_replicas = 1 if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1: self.dist_rank = dist.get_rank() self.dist_num_replicas = dist.get_world_size() # Attributes that are updated in _lazy_init self.worker_info = None self.worker_id = 0 self.worker_seed = seed # seed unique to each worker instance self.num_workers = 1 self.global_worker_id = 0 self.global_num_workers = 1 self.init_count = 0 self.epoch_count = SharedCount() # DataPipeline is lazy init, majority of WDS DataPipeline could be init here, BUT, shuffle seed # is not handled in manner where it can be deterministic for each worker AND initialized up front self.ds = None def set_epoch(self, count): self.epoch_count.value = count def _lazy_init(self): """ Lazily initialize worker (in worker processes) """ if self.worker_info is None: worker_info = torch.utils.data.get_worker_info() if worker_info is not None: self.worker_info = worker_info self.worker_id = worker_info.id self.worker_seed = worker_info.seed self.num_workers = worker_info.num_workers self.global_num_workers = self.dist_num_replicas * self.num_workers self.global_worker_id = self.dist_rank * self.num_workers + self.worker_id # init data pipeline abs_shard_filenames = [os.path.join(self.root, f) for f in self.split_info.filenames] pipeline = [wds.SimpleShardList(abs_shard_filenames)] # at this point we have an iterator over all the shards if self.is_training: pipeline.extend([ detshuffle2(self.shard_shuffle_size, seed=self.common_seed, epoch=self.epoch_count), self._split_by_node_and_worker, # at this point, we have an iterator over the shards assigned to each worker wds.tarfile_to_samples(handler=log_and_continue), wds.shuffle( self.sample_shuffle_size, rng=random.Random(self.worker_seed)), # this is why we lazy-init whole DataPipeline ]) else: pipeline.extend([ self._split_by_node_and_worker, # at this point, we have an iterator over the shards assigned to each worker wds.tarfile_to_samples(handler=log_and_continue), ]) pipeline.extend([ partial( _decode_samples, image_key=self.image_key, image_format=self.image_format, alt_label=self.split_info.alt_label ) ]) self.ds = wds.DataPipeline(*pipeline) def _split_by_node_and_worker(self, src): if self.global_num_workers > 1: for s in islice(src, self.global_worker_id, None, self.global_num_workers): yield s else: for s in src: yield s def __iter__(self): if self.ds is None: self._lazy_init() if self.is_training: num_worker_samples = math.floor(self.num_samples / self.global_num_workers) if self.batch_size is not None: num_worker_samples = (num_worker_samples // self.batch_size) * self.batch_size ds = self.ds.with_epoch(num_worker_samples) else: if self.dist_num_replicas > 1: # doing distributed validation w/ WDS is messy, hard to meet constraints that # same # of batches needed across all replicas w/ seeing each sample once. # with_epoch() is simple but could miss a shard's worth of samples in some workers, # and duplicate in others. Best to keep num DL workers low and a divisor of #val shards. num_worker_samples = math.ceil(self.num_samples / self.global_num_workers) ds = self.ds.with_epoch(num_worker_samples) else: ds = self.ds i = 0 _logger.info('start', i, self.worker_id) # FIXME temporary debug for sample in ds: yield sample[self.image_key], sample[self.target_key] i += 1 _logger.info('end', i, self.worker_id) # FIXME temporary debug def __len__(self): return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas) def _filename(self, index, basename=False, absolute=False): assert False, "Not supported" # no random access to examples def filenames(self, basename=False, absolute=False): """ Return all filenames in dataset, overrides base""" if self.ds is None: self._lazy_init() names = [] for sample in self.ds: if self.filename_key in sample: name = sample[self.filename_key] elif '__key__' in sample: name = sample['__key__'] + self.key_ext else: assert False, "No supported name field present" names.append(name) if len(names) >= self.num_samples: break # safety for ds.repeat() case return names