Add log and continue handler for WDS errors, fix args.num_gpu for validation script fallback

pull/1414/head
Ross Wightman 2 years ago
parent 7eeaf521a0
commit ab16a358bb

@@ -6,8 +6,10 @@ import math
 import os
 import io
 import json
-import yaml
+import logging
+import random
+import yaml
 from dataclasses import dataclass
 from itertools import islice
 from functools import partial
@@ -25,6 +27,8 @@ except ImportError:
 from .parser import Parser
 from timm.bits import get_global_device, is_global_device
 
+_logger = logging.getLogger(__name__)
+
 SHUFFLE_SIZE = 8192
@@ -110,6 +114,12 @@ def _parse_split_info(split: str, info: Dict):
     return split_info
 
 
+def log_and_continue(exn):
+    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
+    _logger.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
+    return True
+
+
 def _decode(sample, image_key='jpg', image_format='RGB', target_key='cls', alt_label=''):
     """ Custom sample decode
     * decode and convert PIL Image
@@ -135,7 +145,8 @@ def _decode(sample, image_key='jpg', image_format='RGB', target_key='cls', alt_label=''):
         img = img.convert(image_format)
 
     # json passed through in undecoded state
-    return dict(jpg=img, cls=class_label, json=sample.get('json', None))
+    decoded = dict(jpg=img, cls=class_label, json=sample.get('json', None))
+    return decoded
 
 
 def _decode_samples(
@@ -144,7 +155,7 @@ def _decode_samples(
         image_format='RGB',
         target_key='cls',
         alt_label='',
-        handler=wds.reraise_exception):
+        handler=log_and_continue):
     """Decode samples with skip."""
     for sample in data:
         try:
@@ -251,7 +262,7 @@ class ParserWebdataset(Parser):
                 wds.detshuffle(self.shard_shuffle_size, seed=self.common_seed),
                 self._split_by_node_and_worker,
                 # at this point, we have an iterator over the shards assigned to each worker
-                wds.tarfile_to_samples(),
+                wds.tarfile_to_samples(handler=log_and_continue),
                 wds.shuffle(
                     self.sample_shuffle_size,
                     rng=random.Random(self.worker_seed)),  # this is why we lazy-init whole DataPipeline
@@ -260,7 +271,7 @@
             pipeline.extend([
                 self._split_by_node_and_worker,
                 # at this point, we have an iterator over the shards assigned to each worker
-                wds.tarfile_to_samples(),
+                wds.tarfile_to_samples(handler=log_and_continue),
             ])
         pipeline.extend([
             partial(
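
As a side note, here is a minimal sketch (not part of this commit) of how a handler like log_and_continue is wired into a webdataset pipeline; the shard pattern and pipeline stages below are illustrative assumptions rather than the parser's actual configuration:

import logging
import webdataset as wds

_logger = logging.getLogger(__name__)

def log_and_continue(exn):
    # returning True tells webdataset to skip the offending sample/shard and keep iterating
    _logger.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
    return True

# hypothetical shard spec; corrupt tar members are logged and skipped instead of raising
pipeline = wds.DataPipeline(
    wds.SimpleShardList('train-{000000..000099}.tar'),
    wds.tarfile_to_samples(handler=log_and_continue),
)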

@@ -262,6 +262,7 @@ def main():
             batch_size = start_batch_size
             args.model = m
             args.checkpoint = c
+            args.num_gpu = 1  # FIXME support data-parallel?
             result = OrderedDict(model=args.model)
             r = {}
             while not r and batch_size >= 1:
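
For context, a rough sketch (assumed names, not the script's actual code) of the checkpoint-list fallback loop this hunk touches: each iteration rebuilds args for one model, so attributes the single-model CLI path would normally supply, such as num_gpu, must be filled in before re-entering validation:

# hypothetical outline; model_cfgs and validate() stand in for the script's own names
for m, c in model_cfgs:
    args.model = m
    args.checkpoint = c
    args.num_gpu = 1  # FIXME support data-parallel?
    result = validate(args)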
