You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/timm/data/parsers/parser_wds.py

262 lines
9.3 KiB

""" Dataset parser interface for webdataset
Hacked together by / Copyright 2022 Ross Wightman
"""
import math
import os
import io
import json
import yaml
import random
from dataclasses import dataclass
from itertools import islice
from functools import partial
from typing import Dict, Tuple
import torch
from PIL import Image
try:
import webdataset as wds
from webdataset.shardlists import expand_urls
except ImportError:
wds = None
expand_urls = None
from .parser import Parser
from timm.bits import get_global_device, is_global_device
# Default size handed to `wds.shuffle` for sample-level shuffling; ParserWebdataset
# falls back to this value when no `shuffle_size` is passed to its constructor.
SHUFFLE_SIZE = 8192
def _load_info(root, basename='info'):
info_json = os.path.join(root, basename + '.json')
info_yaml = os.path.join(root, basename + '.yaml')
info_dict = {}
if os.path.exists(info_json):
with open(info_json, 'r') as f:
info_dict = json.load(f)
elif os.path.exists(info_yaml):
with open(info_yaml, 'r') as f:
info_dict = yaml.safe_load(f)
return info_dict
@dataclass
class SplitInfo:
    """ Description of one dataset split: its sample count and the shards holding it.
    """
    # total number of samples in the split
    num_samples: int
    # shard filenames making up the split (joined onto the dataset root by the parser)
    filenames: Tuple[str, ...]
    # per-shard sample counts, paired by position with `filenames`; empty when unknown
    shard_lengths: Tuple[int, ...] = ()
    # split name, empty when the split was given only as a shard pattern
    name: str = ''
def _parse_split_info(split: str, info: Dict):
def _info_convert(dict_info):
return SplitInfo(
num_samples=dict_info['num_samples'],
filenames=tuple(dict_info['filenames']),
shard_lengths=tuple(dict_info['shard_lengths']),
name=dict_info['name'],
)
if 'tar' in split or '..' in split:
# split in WDS string braceexpand format, sample count can be included with a | separator
# ex: `dataset-split-{0000..9999}.tar|100000` for 9999 shards, covering 100,000 samples
split = split.split('|')
num_samples = 0
split_name = ''
if len(split) > 1:
num_samples = int(split[1])
split = split[0]
if '::' not in split:
split_parts = split.split('-', 3)
split_idx = len(split_parts) - 1
if split_idx and 'splits' in info and split_parts[split_idx] in info['splits']:
split_name = split_parts[split_idx]
split_filenames = expand_urls(split)
if split_name:
split_info = info['splits'][split_name]
if not num_samples:
_fc = {f: c for f, c in zip(split_info['filenames'], split_info['shard_lengths'])}
num_samples = sum(_fc[f] for f in split_filenames)
split_info['filenames'] = tuple(_fc.keys())
split_info['shard_lengths'] = tuple(_fc.values())
split_info['num_samples'] = num_samples
split_info = _info_convert(split_info)
else:
split_info = SplitInfo(
name=split_name,
num_samples=num_samples,
filenames=split_filenames,
)
else:
if split not in info['splits']:
raise RuntimeError(f"split {split} not found in info ({info['splits'].keys()})")
split = split
split_info = info['splits'][split]
split_info = _info_convert(split_info)
return split_info
def _decode(sample, image_key='jpg', image_format='RGB', target_key='cls'):
""" Custom sample decode
* decode and convert PIL Image
* cls byte string label to int
* pass through JSON byte string (if it exists) without parse
"""
with io.BytesIO(sample[image_key]) as b:
img = Image.open(b)
img.load()
if image_format:
img = img.convert(image_format)
return dict(jpg=img, cls=int(sample[target_key]), json=sample.get('json', None))
class ParserWebdataset(Parser):
    """ Parser that streams (image, target) pairs from webdataset tar shards.

    Shards listed for the requested split are divided across distributed ranks and
    dataloader workers. The webdataset pipeline itself is built lazily on first use
    so that sample shuffling can be seeded with the per-worker seed.
    """

    def __init__(
            self,
            root,
            name,
            split,
            is_training=False,
            batch_size=None,
            repeats=0,
            seed=42,
            input_name='image',
            input_image='RGB',
            target_name=None,
            target_image='',
            prefetch_size=None,
            shuffle_size=None,
    ):
        """
        Args:
            root: directory holding the info file and the shards of the dataset.
            name: currently unused by this parser.
            split: split name or shard pattern, resolved by `_parse_split_info`.
            is_training: enable shard + sample shuffling, and truncate the per-worker
                sample count to a multiple of `batch_size`.
            batch_size: batch size used for the truncation above (training only).
            repeats: only affects the value reported by `__len__`.
            seed: seed shared by all workers / ranks, used for shard shuffling.
            input_image: PIL mode images are converted to, falsy for no conversion.
            shuffle_size: sample shuffle buffer size, SHUFFLE_SIZE if not set.
            input_name, target_name, target_image, prefetch_size: currently unused
                by this parser.

        Raises:
            RuntimeError: webdataset is not installed, or the split has no samples.
        """
        super().__init__()
        if wds is None:
            # fail early with a clear message instead of an attribute error on None later
            raise RuntimeError('Please install the webdataset package to use ParserWebdataset.')
        self.root = root
        self.is_training = is_training
        self.batch_size = batch_size
        self.repeats = repeats
        self.common_seed = seed  # a seed that's fixed across all worker / distributed instances
        self.shard_shuffle_size = 500
        self.sample_shuffle_size = shuffle_size or SHUFFLE_SIZE

        self.image_key = 'jpg'
        self.image_format = input_image
        self.target_key = 'cls'
        self.filename_key = 'filename'
        self.key_ext = '.JPEG'  # extension to add to key for original filenames (DS specific, default ImageNet)

        self.info = _load_info(self.root)
        self.split_info = _parse_split_info(split, self.info)
        self.num_samples = self.split_info.num_samples
        if not self.num_samples:
            raise RuntimeError('Invalid split definition, no samples found.')

        # Distributed world state
        self.dist_rank = 0
        self.dist_num_replicas = 1
        if is_global_device():
            dev_env = get_global_device()
            if dev_env.distributed and dev_env.world_size > 1:
                self.dist_rank = dev_env.global_rank
                self.dist_num_replicas = dev_env.world_size
        else:
            # FIXME warn if we fallback to torch distributed?
            import torch.distributed as dist
            if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
                self.dist_rank = dist.get_rank()
                self.dist_num_replicas = dist.get_world_size()

        # Attributes that are updated in _lazy_init
        self.worker_id = 0
        self.worker_seed = seed  # seed unique to each worker instance
        self.num_workers = 1
        self.global_worker_id = 0
        self.global_num_workers = 1
        self.init_count = 0

        # DataPipeline is lazy init, majority of WDS DataPipeline could be init here, BUT, shuffle seed
        # is not handled in manner where it can be deterministic for each worker AND initialized up front
        self.ds = None

    def _lazy_init(self):
        """ Lazily initialize worker (in worker processes)

        Picks up the dataloader worker id / seed / count when running inside a worker,
        derives the global (rank x worker) position, and builds the data pipeline.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            self.worker_id = worker_info.id
            self.worker_seed = worker_info.seed
            self.num_workers = worker_info.num_workers
        self.global_num_workers = self.dist_num_replicas * self.num_workers
        self.global_worker_id = self.dist_rank * self.num_workers + self.worker_id

        # init data pipeline
        abs_shard_filenames = [os.path.join(self.root, f) for f in self.split_info.filenames]
        pipeline = [wds.SimpleShardList(abs_shard_filenames)]
        # at this point we have an iterator over all the shards
        if self.is_training:
            pipeline.extend([
                wds.detshuffle(self.shard_shuffle_size, seed=self.common_seed),
                self._split_by_node_and_worker,
                # at this point, we have an iterator over the shards assigned to each worker
                wds.tarfile_to_samples(),
                wds.shuffle(
                    self.sample_shuffle_size,
                    rng=random.Random(self.worker_seed)),  # this is why we lazy-init whole DataPipeline
            ])
        else:
            pipeline.extend([
                self._split_by_node_and_worker,
                # at this point, we have an iterator over the shards assigned to each worker
                wds.tarfile_to_samples(),
            ])
        pipeline.extend([
            wds.map(partial(_decode, image_key=self.image_key, image_format=self.image_format))
        ])
        self.ds = wds.DataPipeline(*pipeline)
        self.init_count += 1

    def _split_by_node_and_worker(self, src):
        """ Yield every `global_num_workers`-th item of `src`, starting at `global_worker_id`.

        With a single global worker all items are passed through.
        """
        if self.global_num_workers > 1:
            for s in islice(src, self.global_worker_id, self.global_num_workers):
                yield s
        else:
            for s in src:
                yield s

    def __iter__(self):
        """ Yield (image, target) pairs for this worker's share of one epoch.
        """
        if not self.init_count:
            self._lazy_init()

        num_worker_samples = math.ceil(self.num_samples / self.global_num_workers)
        if self.is_training and self.batch_size is not None:
            # keep only whole batches per worker when training
            num_worker_samples = (num_worker_samples // self.batch_size) * self.batch_size
        ds = self.ds.with_epoch(num_worker_samples)
        for sample in ds:
            yield sample[self.image_key], sample[self.target_key]

    def __len__(self):
        """ Number of samples per distributed replica, scaled by `repeats` when > 1.
        """
        return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas)

    def _filename(self, index, basename=False, absolute=False):
        # no random access to examples
        # explicit raise (not `assert`) so the check is not stripped under `python -O`
        raise AssertionError("Not supported")

    def filenames(self, basename=False, absolute=False):
        """ Return sample filenames, overrides base

        Names come from the sample's 'filename' entry when present, otherwise from the
        sample '__key__' with `key_ext` appended. Only shards assigned to this process by
        `_split_by_node_and_worker` are read, and at most `num_samples` names are
        returned. `basename` and `absolute` are accepted but not used.

        Raises:
            AssertionError: a sample carries neither name field.
        """
        if not self.init_count:
            self._lazy_init()

        # Names are read from raw (undecoded) samples. The decode stage at the end of
        # self.ds only keeps the image / target / json entries, so neither name field
        # survives it; skipping it also avoids decoding every image just to list names.
        abs_shard_filenames = [os.path.join(self.root, f) for f in self.split_info.filenames]
        raw_ds = wds.DataPipeline(
            wds.SimpleShardList(abs_shard_filenames),
            self._split_by_node_and_worker,
            wds.tarfile_to_samples(),
        )

        names = []
        for sample in raw_ds:
            if self.filename_key in sample:
                name = sample[self.filename_key]
            elif '__key__' in sample:
                name = sample['__key__'] + self.key_ext
            else:
                # explicit raise (not `assert`) so the check is not stripped under `python -O`
                raise AssertionError("No supported name field present")
            names.append(name)
            if len(names) >= self.num_samples:
                break  # safety for ds.repeat() case
        return names