""" Dataset parser interface for webdataset Hacked together by / Copyright 2022 Ross Wightman """ import math import os import io import json import yaml import random from dataclasses import dataclass from itertools import islice from functools import partial from typing import Dict, Tuple import torch from PIL import Image try: import webdataset as wds from webdataset.shardlists import expand_urls except ImportError: wds = None expand_urls = None from .parser import Parser from timm.bits import get_global_device, is_global_device SHUFFLE_SIZE = 8192 def _load_info(root, basename='info'): info_json = os.path.join(root, basename + '.json') info_yaml = os.path.join(root, basename + '.yaml') info_dict = {} if os.path.exists(info_json): with open(info_json, 'r') as f: info_dict = json.load(f) elif os.path.exists(info_yaml): with open(info_yaml, 'r') as f: info_dict = yaml.safe_load(f) return info_dict @dataclass class SplitInfo: num_samples: int filenames: Tuple[str] shard_lengths: Tuple[int] = () name: str = '' def _parse_split_info(split: str, info: Dict): def _info_convert(dict_info): return SplitInfo( num_samples=dict_info['num_samples'], filenames=tuple(dict_info['filenames']), shard_lengths=tuple(dict_info['shard_lengths']), name=dict_info['name'], ) if 'tar' in split or '..' in split: # split in WDS string braceexpand format, sample count can be included with a | separator # ex: `dataset-split-{0000..9999}.tar|100000` for 9999 shards, covering 100,000 samples split = split.split('|') num_samples = 0 split_name = '' if len(split) > 1: num_samples = int(split[1]) split = split[0] if '::' not in split: split_parts = split.split('-', 3) split_idx = len(split_parts) - 1 if split_idx and 'splits' in info and split_parts[split_idx] in info['splits']: split_name = split_parts[split_idx] split_filenames = expand_urls(split) if split_name: split_info = info['splits'][split_name] if not num_samples: _fc = {f: c for f, c in zip(split_info['filenames'], split_info['shard_lengths'])} num_samples = sum(_fc[f] for f in split_filenames) split_info['filenames'] = tuple(_fc.keys()) split_info['shard_lengths'] = tuple(_fc.values()) split_info['num_samples'] = num_samples split_info = _info_convert(split_info) else: split_info = SplitInfo( name=split_name, num_samples=num_samples, filenames=split_filenames, ) else: if split not in info['splits']: raise RuntimeError(f"split {split} not found in info ({info['splits'].keys()})") split = split split_info = info['splits'][split] split_info = _info_convert(split_info) return split_info def _decode(sample, image_key='jpg', image_format='RGB', target_key='cls'): """ Custom sample decode * decode and convert PIL Image * cls byte string label to int * pass through JSON byte string (if it exists) without parse """ with io.BytesIO(sample[image_key]) as b: img = Image.open(b) img.load() if image_format: img = img.convert(image_format) return dict(jpg=img, cls=int(sample[target_key]), json=sample.get('json', None)) class ParserWebdataset(Parser): def __init__( self, root, name, split, is_training=False, batch_size=None, repeats=0, seed=42, input_name='image', input_image='RGB', target_name=None, target_image='', prefetch_size=None, shuffle_size=None, ): super().__init__() self.root = root self.is_training = is_training self.batch_size = batch_size self.repeats = repeats self.common_seed = seed # a seed that's fixed across all worker / distributed instances self.shard_shuffle_size = 500 self.sample_shuffle_size = shuffle_size or SHUFFLE_SIZE self.image_key = 'jpg' 
        self.image_format = input_image
        self.target_key = 'cls'
        self.filename_key = 'filename'
        self.key_ext = '.JPEG'  # extension to add to key for original filenames (DS specific, default ImageNet)

        self.info = _load_info(self.root)
        self.split_info = _parse_split_info(split, self.info)
        self.num_samples = self.split_info.num_samples
        if not self.num_samples:
            raise RuntimeError('Invalid split definition, no samples found.')

        # Distributed world state
        self.dist_rank = 0
        self.dist_num_replicas = 1
        if is_global_device():
            dev_env = get_global_device()
            if dev_env.distributed and dev_env.world_size > 1:
                self.dist_rank = dev_env.global_rank
                self.dist_num_replicas = dev_env.world_size
        else:
            # FIXME warn if we fall back to torch distributed?
            import torch.distributed as dist
            if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
                self.dist_rank = dist.get_rank()
                self.dist_num_replicas = dist.get_world_size()

        # Attributes that are updated in _lazy_init
        self.worker_id = 0
        self.worker_seed = seed  # seed unique to each worker instance
        self.num_workers = 1
        self.global_worker_id = 0
        self.global_num_workers = 1
        self.init_count = 0

        # The DataPipeline is lazily initialized. Most of the WDS pipeline could be built here, BUT the
        # shuffle seed cannot be made deterministic per-worker if it's initialized up front, so we defer.
        self.ds = None

    def _lazy_init(self):
        """ Lazily initialize worker (in worker processes)
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            self.worker_id = worker_info.id
            self.worker_seed = worker_info.seed
            self.num_workers = worker_info.num_workers
        self.global_num_workers = self.dist_num_replicas * self.num_workers
        self.global_worker_id = self.dist_rank * self.num_workers + self.worker_id

        # init data pipeline
        abs_shard_filenames = [os.path.join(self.root, f) for f in self.split_info.filenames]
        pipeline = [wds.SimpleShardList(abs_shard_filenames)]
        # at this point we have an iterator over all the shards
        if self.is_training:
            pipeline.extend([
                wds.detshuffle(self.shard_shuffle_size, seed=self.common_seed),
                self._split_by_node_and_worker,
                # at this point, we have an iterator over the shards assigned to each worker
                wds.tarfile_to_samples(),
                wds.shuffle(
                    self.sample_shuffle_size,
                    rng=random.Random(self.worker_seed)),  # this is why we lazy-init the whole DataPipeline
            ])
        else:
            pipeline.extend([
                self._split_by_node_and_worker,
                # at this point, we have an iterator over the shards assigned to each worker
                wds.tarfile_to_samples(),
            ])
        pipeline.extend([
            wds.map(partial(_decode, image_key=self.image_key, image_format=self.image_format))
        ])
        self.ds = wds.DataPipeline(*pipeline)
        self.init_count += 1

    def _split_by_node_and_worker(self, src):
        if self.global_num_workers > 1:
            # stride the shard list so each global worker gets every global_num_workers-th shard
            for s in islice(src, self.global_worker_id, None, self.global_num_workers):
                yield s
        else:
            for s in src:
                yield s

    def __iter__(self):
        if not self.init_count:
            self._lazy_init()

        i = 0
        num_worker_samples = math.ceil(self.num_samples / self.global_num_workers)
        if self.is_training and self.batch_size is not None:
            num_worker_samples = (num_worker_samples // self.batch_size) * self.batch_size
        ds = self.ds.with_epoch(num_worker_samples)
        for sample in ds:
            yield sample[self.image_key], sample[self.target_key]
            i += 1
        print('end', i)  # FIXME debug

    def __len__(self):
        return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas)

    def _filename(self, index, basename=False, absolute=False):
        assert False, "Not supported"  # no random access to examples

    def filenames(self, basename=False, absolute=False):
""" Return all filenames in dataset, overrides base""" if not self.init_count: self._lazy_init() names = [] for sample in self.ds: if self.filename_key in sample: name = sample[self.filename_key] elif '__key__' in sample: name = sample['__key__'] + self.key_ext else: assert False, "No supported name field present" names.append(name) if len(names) >= self.num_samples: break # safety for ds.repeat() case return names