Add webdataset (WDS) support; update TFDS so parser naming is more consistent across parsers. Fix workers=0 compatibility. Add ImageNet22k/12k synset definitions.
parent 3fce010ca8
commit da2796ae82
Two file diffs suppressed because they are too large.
@@ -0,0 +1,261 @@
""" Dataset parser interface for webdataset

Hacked together by / Copyright 2022 Ross Wightman
"""
import math
import os
import io
import json
import yaml
import random
from dataclasses import dataclass
from itertools import islice
from functools import partial
from typing import Dict, Tuple

import torch
from PIL import Image
try:
    import webdataset as wds
    from webdataset.shardlists import expand_urls
except ImportError:
    wds = None
    expand_urls = None

from .parser import Parser
from timm.bits import get_global_device, is_global_device

SHUFFLE_SIZE = 8192


def _load_info(root, basename='info'):
    info_json = os.path.join(root, basename + '.json')
    info_yaml = os.path.join(root, basename + '.yaml')
    info_dict = {}
    if os.path.exists(info_json):
        with open(info_json, 'r') as f:
            info_dict = json.load(f)
    elif os.path.exists(info_yaml):
        with open(info_yaml, 'r') as f:
            info_dict = yaml.safe_load(f)
    return info_dict


@dataclass
class SplitInfo:
    num_samples: int
    filenames: Tuple[str]
    shard_lengths: Tuple[int] = ()
    name: str = ''


def _parse_split_info(split: str, info: Dict):
    def _info_convert(dict_info):
        return SplitInfo(
            num_samples=dict_info['num_samples'],
            filenames=tuple(dict_info['filenames']),
            shard_lengths=tuple(dict_info['shard_lengths']),
            name=dict_info['name'],
        )

    if 'tar' in split or '..' in split:
        # split is in WDS braceexpand string format, a sample count can be appended with a | separator
        # ex: `dataset-split-{0000..9999}.tar|100000` for 10,000 shards covering 100,000 samples
        split = split.split('|')
        num_samples = 0
        split_name = ''
        if len(split) > 1:
            num_samples = int(split[1])
        split = split[0]
        if '::' not in split:
            split_parts = split.split('-', 3)
            split_idx = len(split_parts) - 1
            if split_idx and 'splits' in info and split_parts[split_idx] in info['splits']:
                split_name = split_parts[split_idx]

        split_filenames = expand_urls(split)
        if split_name:
            split_info = info['splits'][split_name]
            if not num_samples:
                _fc = {f: c for f, c in zip(split_info['filenames'], split_info['shard_lengths'])}
                num_samples = sum(_fc[f] for f in split_filenames)
                split_info['filenames'] = tuple(_fc.keys())
                split_info['shard_lengths'] = tuple(_fc.values())
            split_info['num_samples'] = num_samples
            split_info = _info_convert(split_info)
        else:
            split_info = SplitInfo(
                name=split_name,
                num_samples=num_samples,
                filenames=split_filenames,
            )
    else:
        if split not in info['splits']:
            raise RuntimeError(f"split {split} not found in info ({info['splits'].keys()})")
        split_info = info['splits'][split]
        split_info = _info_convert(split_info)

    return split_info


def _decode(sample, image_key='jpg', image_format='RGB', target_key='cls'):
    """ Custom sample decode
    * decode and convert PIL Image
    * cls byte string label to int
    * pass through JSON byte string (if it exists) without parse
    """
    with io.BytesIO(sample[image_key]) as b:
        img = Image.open(b)
        img.load()
    if image_format:
        img = img.convert(image_format)
    return dict(jpg=img, cls=int(sample[target_key]), json=sample.get('json', None))


class ParserWebdataset(Parser):
    def __init__(
            self,
            root,
            name,
            split,
            is_training=False,
            batch_size=None,
            repeats=0,
            seed=42,
            input_name='image',
            input_image='RGB',
            target_name=None,
            target_image='',
            prefetch_size=None,
            shuffle_size=None,
    ):
        super().__init__()
        self.root = root
        self.is_training = is_training
        self.batch_size = batch_size
        self.repeats = repeats
        self.common_seed = seed  # a seed that's fixed across all worker / distributed instances
        self.shard_shuffle_size = 500
        self.sample_shuffle_size = shuffle_size or SHUFFLE_SIZE

        self.image_key = 'jpg'
        self.image_format = input_image
        self.target_key = 'cls'
        self.filename_key = 'filename'
        self.key_ext = '.JPEG'  # extension to add to key for original filenames (DS specific, default ImageNet)

        self.info = _load_info(self.root)
        self.split_info = _parse_split_info(split, self.info)
        self.num_samples = self.split_info.num_samples
        if not self.num_samples:
            raise RuntimeError('Invalid split definition, no samples found.')

        # Distributed world state
        self.dist_rank = 0
        self.dist_num_replicas = 1
        if is_global_device():
            dev_env = get_global_device()
            if dev_env.distributed and dev_env.world_size > 1:
                self.dist_rank = dev_env.global_rank
                self.dist_num_replicas = dev_env.world_size
        else:
            # FIXME warn if we fallback to torch distributed?
            import torch.distributed as dist
            if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
                self.dist_rank = dist.get_rank()
                self.dist_num_replicas = dist.get_world_size()

        # Attributes that are updated in _lazy_init
        self.worker_id = 0
        self.worker_seed = seed  # seed unique to each worker instance
        self.num_workers = 1
        self.global_worker_id = 0
        self.global_num_workers = 1
        self.init_count = 0

        # The DataPipeline is lazily initialized; most of the WDS DataPipeline could be built here, BUT the
        # shuffle seed cannot be both deterministic per-worker AND initialized up front.
        self.ds = None

    def _lazy_init(self):
        """ Lazily initialize worker (in worker processes)
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            self.worker_id = worker_info.id
            self.worker_seed = worker_info.seed
            self.num_workers = worker_info.num_workers
        self.global_num_workers = self.dist_num_replicas * self.num_workers
        self.global_worker_id = self.dist_rank * self.num_workers + self.worker_id

        # init data pipeline
        abs_shard_filenames = [os.path.join(self.root, f) for f in self.split_info.filenames]
        pipeline = [wds.SimpleShardList(abs_shard_filenames)]
        # at this point we have an iterator over all the shards
        if self.is_training:
            pipeline.extend([
                wds.detshuffle(self.shard_shuffle_size, seed=self.common_seed),
                self._split_by_node_and_worker,
                # at this point, we have an iterator over the shards assigned to each worker
                wds.tarfile_to_samples(),
                wds.shuffle(
                    self.sample_shuffle_size,
                    rng=random.Random(self.worker_seed)),  # this is why we lazy-init the whole DataPipeline
            ])
        else:
            pipeline.extend([
                self._split_by_node_and_worker,
                # at this point, we have an iterator over the shards assigned to each worker
                wds.tarfile_to_samples(),
            ])
        pipeline.extend([
            wds.map(partial(_decode, image_key=self.image_key, image_format=self.image_format))
        ])
        self.ds = wds.DataPipeline(*pipeline)
        self.init_count += 1

    def _split_by_node_and_worker(self, src):
        if self.global_num_workers > 1:
            for s in islice(src, self.global_worker_id, None, self.global_num_workers):
                yield s
        else:
            for s in src:
                yield s

    def __iter__(self):
        if not self.init_count:
            self._lazy_init()

        i = 0
        num_worker_samples = math.ceil(self.num_samples / self.global_num_workers)
        if self.is_training and self.batch_size is not None:
            num_worker_samples = (num_worker_samples // self.batch_size) * self.batch_size
        ds = self.ds.with_epoch(num_worker_samples)
        for sample in ds:
            yield sample[self.image_key], sample[self.target_key]
            i += 1
        print('end', i)  # FIXME debug

    def __len__(self):
        return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas)

    def _filename(self, index, basename=False, absolute=False):
        assert False, "Not supported"  # no random access to examples

    def filenames(self, basename=False, absolute=False):
        """ Return all filenames in dataset, overrides base"""
        if not self.init_count:
            self._lazy_init()

        names = []
        for sample in self.ds:
            if self.filename_key in sample:
                name = sample[self.filename_key]
            elif '__key__' in sample:
                name = sample['__key__'] + self.key_ext
            else:
                assert False, "No supported name field present"
            names.append(name)
            if len(names) >= self.num_samples:
                break  # safety for ds.repeat() case
        return names
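
For reference, a minimal usage sketch (not part of this commit): the import path, dataset root, shard naming, and batch size below are illustrative assumptions, but the `ParserWebdataset` constructor arguments and the `shards|sample_count` split string follow the code above.

# Hypothetical example, assuming the module lands as timm.data.parsers.parser_wds and that
# /data/imagenet-wds contains shards named imagenet-train-{000000..001281}.tar plus an info.json.
from timm.data.parsers.parser_wds import ParserWebdataset

parser = ParserWebdataset(
    root='/data/imagenet-wds',
    name='imagenet',
    split='imagenet-train-{000000..001281}.tar|1281167',  # braceexpand shard spec | sample count
    is_training=True,
    batch_size=256,
)
for img, target in parser:  # yields (PIL.Image, int) pairs produced by _decode
    pass  # hand off to transforms / collation here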
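Sizing note (a worked example under the same assumptions): with 1,281,167 samples, 8 distributed replicas, and 4 dataloader workers each, `global_num_workers` is 32, so `__iter__` caps each worker at `ceil(1281167 / 32) = 40037` samples, rounded down to a batch multiple (`40037 // 256 * 256 = 39936`) when training, while `__len__` reports `ceil(1281167 / 8) = 160146` samples per replica.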