|
|
|
@ -6,8 +6,10 @@ import math
|
|
|
|
|
import os
|
|
|
|
|
import io
|
|
|
|
|
import json
|
|
|
|
|
import yaml
|
|
|
|
|
import logging
|
|
|
|
|
import random
|
|
|
|
|
import yaml
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
from itertools import islice
|
|
|
|
|
from functools import partial
|
|
|
|
@ -25,6 +27,8 @@ except ImportError:
|
|
|
|
|
from .parser import Parser
|
|
|
|
|
from timm.bits import get_global_device, is_global_device
|
|
|
|
|
|
|
|
|
|
# Module-level logger, named after this module via __name__.
_logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
# NOTE(review): presumably a default shuffle buffer size; where it is consumed
# is outside this view -- confirm against the callers before relying on it.
SHUFFLE_SIZE = 8192
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -110,6 +114,12 @@ def _parse_split_info(split: str, info: Dict):
|
|
|
|
|
return split_info
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def log_and_continue(exn):
    """Report an exception as a warning and signal that it should be ignored.

    Meant to be called from an exception handler: the exception is logged
    through the module logger at warning level and nothing is re-raised.

    Args:
        exn: The exception being handled; its ``repr`` is included in the
            logged message.

    Returns:
        Always ``True``.
    """
    message = f'Handling webdataset error ({repr(exn)}). Ignoring.'
    _logger.warning(message)
    return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _decode(sample, image_key='jpg', image_format='RGB', target_key='cls', alt_label=''):
|
|
|
|
|
""" Custom sample decode
|
|
|
|
|
* decode and convert PIL Image
|
|
|
|
@ -135,7 +145,8 @@ def _decode(sample, image_key='jpg', image_format='RGB', target_key='cls', alt_l
|
|
|
|
|
img = img.convert(image_format)
|
|
|
|
|
|
|
|
|
|
# json passed through in undecoded state
|
|
|
|
|
return dict(jpg=img, cls=class_label, json=sample.get('json', None))
|
|
|
|
|
decoded = dict(jpg=img, cls=class_label, json=sample.get('json', None))
|
|
|
|
|
return decoded
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _decode_samples(
|
|
|
|
@ -144,7 +155,7 @@ def _decode_samples(
|
|
|
|
|
image_format='RGB',
|
|
|
|
|
target_key='cls',
|
|
|
|
|
alt_label='',
|
|
|
|
|
handler=wds.reraise_exception):
|
|
|
|
|
handler=log_and_continue):
|
|
|
|
|
"""Decode samples with skip."""
|
|
|
|
|
for sample in data:
|
|
|
|
|
try:
|
|
|
|
@ -251,7 +262,7 @@ class ParserWebdataset(Parser):
|
|
|
|
|
wds.detshuffle(self.shard_shuffle_size, seed=self.common_seed),
|
|
|
|
|
self._split_by_node_and_worker,
|
|
|
|
|
# at this point, we have an iterator over the shards assigned to each worker
|
|
|
|
|
wds.tarfile_to_samples(),
|
|
|
|
|
wds.tarfile_to_samples(handler=log_and_continue),
|
|
|
|
|
wds.shuffle(
|
|
|
|
|
self.sample_shuffle_size,
|
|
|
|
|
rng=random.Random(self.worker_seed)), # this is why we lazy-init whole DataPipeline
|
|
|
|
@ -260,7 +271,7 @@ class ParserWebdataset(Parser):
|
|
|
|
|
pipeline.extend([
|
|
|
|
|
self._split_by_node_and_worker,
|
|
|
|
|
# at this point, we have an iterator over the shards assigned to each worker
|
|
|
|
|
wds.tarfile_to_samples(),
|
|
|
|
|
wds.tarfile_to_samples(handler=log_and_continue),
|
|
|
|
|
])
|
|
|
|
|
pipeline.extend([
|
|
|
|
|
partial(
|
|
|
|
|