Add log and continue handler for WDS errors, fix args.num_gpu for validation script fallback

pull/1414/head
Ross Wightman 2 years ago
parent 7eeaf521a0
commit ab16a358bb

@@ -6,8 +6,10 @@ import math
 import os
 import io
 import json
-import yaml
+import logging
 import random
+import yaml
 from dataclasses import dataclass
 from itertools import islice
 from functools import partial
@@ -25,6 +27,8 @@ except ImportError:
 from .parser import Parser
 from timm.bits import get_global_device, is_global_device
 
+_logger = logging.getLogger(__name__)
+
 SHUFFLE_SIZE = 8192
@@ -110,6 +114,12 @@ def _parse_split_info(split: str, info: Dict):
     return split_info
 
 
+def log_and_continue(exn):
+    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
+    _logger.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
+    return True
+
+
 def _decode(sample, image_key='jpg', image_format='RGB', target_key='cls', alt_label=''):
     """ Custom sample decode
     * decode and convert PIL Image
@@ -135,7 +145,8 @@ def _decode(sample, image_key='jpg', image_format='RGB', target_key='cls', alt_l
         img = img.convert(image_format)
 
     # json passed through in undecoded state
-    return dict(jpg=img, cls=class_label, json=sample.get('json', None))
+    decoded = dict(jpg=img, cls=class_label, json=sample.get('json', None))
+    return decoded
 
 
 def _decode_samples(
@@ -144,7 +155,7 @@ def _decode_samples(
         image_format='RGB',
         target_key='cls',
         alt_label='',
-        handler=wds.reraise_exception):
+        handler=log_and_continue):
     """Decode samples with skip."""
     for sample in data:
         try:
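Note: in webdataset's handler convention, a stage's handler receives the exception and returns True to skip the offending item (which is what log_and_continue does), while wds.reraise_exception, the previous default, simply re-raises and aborts the worker. A minimal sketch of that skip loop follows; the names and body are illustrative assumptions, not the file's actual code outside this hunk.

    def _decode_samples(data, handler=log_and_continue, **decode_kwargs):
        # Assumed shape of the decode-with-skip loop; illustrative only.
        for sample in data:
            try:
                decoded = _decode(sample, **decode_kwargs)  # may raise on a corrupt image or missing key
            except Exception as exn:
                if handler(exn):
                    continue  # log_and_continue: warn and move on to the next sample
                break         # a falsy handler result ends iteration (reraise_exception raises instead)
            if decoded is not None:
                yield decoded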
@@ -251,7 +262,7 @@ class ParserWebdataset(Parser):
                 wds.detshuffle(self.shard_shuffle_size, seed=self.common_seed),
                 self._split_by_node_and_worker,
                 # at this point, we have an iterator over the shards assigned to each worker
-                wds.tarfile_to_samples(),
+                wds.tarfile_to_samples(handler=log_and_continue),
                 wds.shuffle(
                     self.sample_shuffle_size,
                     rng=random.Random(self.worker_seed)),  # this is why we lazy-init whole DataPipeline
@@ -260,7 +271,7 @@ class ParserWebdataset(Parser):
             pipeline.extend([
                 self._split_by_node_and_worker,
                 # at this point, we have an iterator over the shards assigned to each worker
-                wds.tarfile_to_samples(),
+                wds.tarfile_to_samples(handler=log_and_continue),
             ])
         pipeline.extend([
             partial(
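Note: both tarfile_to_samples() call sites now pass handler=log_and_continue, so a truncated or unreadable tar member is logged and skipped instead of killing the worker. A minimal, self-contained sketch of the same wiring, assuming the webdataset v2 pipeline API; the shard pattern and buffer sizes are made up for illustration:

    import logging
    import random
    import webdataset as wds

    _logger = logging.getLogger(__name__)

    def log_and_continue(exn):
        """Log the webdataset error and keep iterating."""
        _logger.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
        return True  # True => webdataset skips the failing item and continues

    pipeline = wds.DataPipeline(
        wds.SimpleShardList('shards-{000000..000009}.tar'),  # hypothetical shard spec
        wds.tarfile_to_samples(handler=log_and_continue),    # bad tar members -> warning, not crash
        wds.shuffle(1000, rng=random.Random(42)),
        wds.decode('pil', handler=log_and_continue),          # decode errors are skipped too
    )
    first = next(iter(pipeline))  # samples that fail to load or decode never reach this point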

@@ -262,6 +262,7 @@ def main():
         batch_size = start_batch_size
         args.model = m
         args.checkpoint = c
+        args.num_gpu = 1  # FIXME support data-parallel?
         result = OrderedDict(model=args.model)
         r = {}
         while not r and batch_size >= 1:
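Note: this hunk pins args.num_gpu to 1 before entering the bulk-validation fallback, presumably because the retry path references args.num_gpu when shrinking the batch size (the FIXME flags that data-parallel is not handled). An illustrative reconstruction of that retry loop is sketched below; the function name, validate_fn, and the exact shrink policy are assumptions, not the script's verbatim code.

    from collections import OrderedDict

    def run_with_fallback(args, validate_fn, start_batch_size=256):
        # Sketch of the batch-size fallback loop the hunk sits inside.
        args.num_gpu = 1  # FIXME support data-parallel?  (fallback path assumes a single GPU)
        batch_size = start_batch_size
        result = OrderedDict(model=args.model)
        r = {}
        while not r and batch_size >= 1:
            try:
                args.batch_size = batch_size
                r = validate_fn(args)  # assumed to return a metrics dict, or raise on OOM
            except RuntimeError:
                if batch_size <= args.num_gpu:
                    raise  # cannot shrink the per-step batch any further
                batch_size = max(batch_size // 2, args.num_gpu)  # halve and retry
                r = {}
        result.update(r)
        return result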
