From 229ac6b8d88e3523db741ba74b5ac9da5b30720c Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 11 Mar 2022 19:16:04 -0800
Subject: [PATCH] Fix alternate label handling in WDS parser to skip invalid alt labels

---
 timm/data/parsers/parser_wds.py | 53 ++++++++++++++++++++++++++-------
 1 file changed, 43 insertions(+), 10 deletions(-)

diff --git a/timm/data/parsers/parser_wds.py b/timm/data/parsers/parser_wds.py
index 83df8929..47ad6184 100644
--- a/timm/data/parsers/parser_wds.py
+++ b/timm/data/parsers/parser_wds.py
@@ -106,18 +106,51 @@ def _decode(sample, image_key='jpg', image_format='RGB', target_key='cls', alt_l
     * cls byte string label to int
     * pass through JSON byte string (if it exists) without parse
     """
+    # decode class label, skip if alternate label not valid
+    if alt_label:
+        # alternative labels are encoded in json metadata
+        meta = json.loads(sample['json'])
+        class_label = int(meta[alt_label])
+        if class_label < 0:
+            # skipped labels currently encoded as -1, may change to a null/None value
+            return None
+    else:
+        class_label = int(sample[target_key])
+
+    # decode image
     with io.BytesIO(sample[image_key]) as b:
         img = Image.open(b)
         img.load()
     if image_format:
         img = img.convert(image_format)
-    if alt_label:
-        # alternative labels are encoded in json metadata
-        assert 'json' in sample
-        meta = json.loads(sample['json'])
-        return dict(jpg=img, cls=int(meta[alt_label]), json=meta)
-    else:
-        return dict(jpg=img, cls=int(sample[target_key]), json=sample.get('json', None))
+
+    # json passed through in undecoded state
+    return dict(jpg=img, cls=class_label, json=sample.get('json', None))
+
+
+def _decode_samples(
+        data,
+        image_key='jpg',
+        image_format='RGB',
+        target_key='cls',
+        alt_label='',
+        handler=wds.reraise_exception):
+    """Decode samples with skip."""
+    for sample in data:
+        try:
+            result = _decode(
+                sample, image_key=image_key, image_format=image_format, target_key=target_key, alt_label=alt_label)
+        except Exception as exn:
+            if handler(exn):
+                continue
+            else:
+                break
+
+        # null results are skipped
+        if result is not None:
+            if isinstance(sample, dict) and isinstance(result, dict):
+                result["__key__"] = sample.get("__key__")
+            yield result
 
 
 class ParserWebdataset(Parser):
@@ -217,11 +250,11 @@ class ParserWebdataset(Parser):
                 wds.tarfile_to_samples(),
             ])
         pipeline.extend([
-            wds.map(partial(
-                _decode,
+            partial(
+                _decode_samples,
                 image_key=self.image_key,
                 image_format=self.image_format,
-                alt_label=self.split_info.alt_label))
+                alt_label=self.split_info.alt_label)
         ])
         self.ds = wds.DataPipeline(*pipeline)
         self.init_count += 1