From ab16a358bb5ae7d7e2cdd78c90dcdd01d972963a Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 16 Mar 2022 11:44:29 -0700
Subject: [PATCH] Add log and continue handler for WDS errors, fix
 args.num_gpu for validation script fallback

---
 timm/data/parsers/parser_wds.py | 21 ++++++++++++++++-----
 validate.py                     |  1 +
 2 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/timm/data/parsers/parser_wds.py b/timm/data/parsers/parser_wds.py
index 0bf3fb7a..7011d967 100644
--- a/timm/data/parsers/parser_wds.py
+++ b/timm/data/parsers/parser_wds.py
@@ -6,8 +6,10 @@ import math
 import os
 import io
 import json
-import yaml
+import logging
 import random
+import yaml
+
 from dataclasses import dataclass
 from itertools import islice
 from functools import partial
@@ -25,6 +27,8 @@ except ImportError:
 
 from .parser import Parser
 from timm.bits import get_global_device, is_global_device
 
 
+_logger = logging.getLogger(__name__)
+
 SHUFFLE_SIZE = 8192
@@ -110,6 +114,12 @@ def _parse_split_info(split: str, info: Dict):
     return split_info
 
 
+def log_and_continue(exn):
+    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
+    _logger.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
+    return True
+
+
 def _decode(sample, image_key='jpg', image_format='RGB', target_key='cls', alt_label=''):
     """ Custom sample decode
     * decode and convert PIL Image
@@ -135,7 +145,8 @@ def _decode(sample, image_key='jpg', image_format='RGB', target_key='cls', alt_l
         img = img.convert(image_format)
 
     # json passed through in undecoded state
-    return dict(jpg=img, cls=class_label, json=sample.get('json', None))
+    decoded = dict(jpg=img, cls=class_label, json=sample.get('json', None))
+    return decoded
 
 
 def _decode_samples(
@@ -144,7 +155,7 @@
         image_format='RGB',
         target_key='cls',
         alt_label='',
-        handler=wds.reraise_exception):
+        handler=log_and_continue):
     """Decode samples with skip."""
     for sample in data:
         try:
@@ -251,7 +262,7 @@ class ParserWebdataset(Parser):
                 wds.detshuffle(self.shard_shuffle_size, seed=self.common_seed),
                 self._split_by_node_and_worker,
                 # at this point, we have an iterator over the shards assigned to each worker
-                wds.tarfile_to_samples(),
+                wds.tarfile_to_samples(handler=log_and_continue),
                 wds.shuffle(
                     self.sample_shuffle_size,
                     rng=random.Random(self.worker_seed)),  # this is why we lazy-init whole DataPipeline
@@ -260,7 +271,7 @@ class ParserWebdataset(Parser):
             pipeline.extend([
                 self._split_by_node_and_worker,
                 # at this point, we have an iterator over the shards assigned to each worker
-                wds.tarfile_to_samples(),
+                wds.tarfile_to_samples(handler=log_and_continue),
             ])
         pipeline.extend([
             partial(
diff --git a/validate.py b/validate.py
index d2eca03e..aa2555fc 100755
--- a/validate.py
+++ b/validate.py
@@ -262,6 +262,7 @@ def main():
             batch_size = start_batch_size
             args.model = m
             args.checkpoint = c
+            args.num_gpu = 1  # FIXME support data-parallel?
             result = OrderedDict(model=args.model)
             r = {}
             while not r and batch_size >= 1:
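
Usage sketch: webdataset error handlers receive the raised exception and return True to skip
the offending item or False to stop iteration; log_and_continue logs a warning and returns
True. Below is a minimal, self-contained example of wiring such a handler into a webdataset
pipeline, assuming webdataset's v2 DataPipeline API; the shard pattern is hypothetical.

    import logging

    import webdataset as wds

    logging.basicConfig(level=logging.WARNING)
    _logger = logging.getLogger(__name__)


    def log_and_continue(exn):
        """Warn about the error, then tell webdataset to skip the item."""
        _logger.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
        return True  # True -> continue with the next sample/shard


    # hypothetical shard spec, for illustration only
    urls = 'train-{000000..000009}.tar'

    pipeline = wds.DataPipeline(
        wds.SimpleShardList(urls),
        wds.tarfile_to_samples(handler=log_and_continue),  # skip unreadable tar members
        wds.decode('pil', handler=log_and_continue),       # skip undecodable images
    )

    for sample in pipeline:
        pass  # consume decoded sample dicts here

This is the same pattern the patch applies inside ParserWebdataset: passing the handler to
wds.tarfile_to_samples() and making it the default for _decode_samples() means a truncated
shard or corrupt image logs a warning and is skipped instead of aborting a long training run.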