Fix alternate label handling in WDS parser to skip invalid alt labels

pull/1239/head
Ross Wightman 2 years ago
parent a444d4b891
commit 229ac6b8d8

@ -106,18 +106,51 @@ def _decode(sample, image_key='jpg', image_format='RGB', target_key='cls', alt_l
* cls byte string label to int * cls byte string label to int
* pass through JSON byte string (if it exists) without parse * pass through JSON byte string (if it exists) without parse
""" """
# decode class label, skip if alternate label not valid
if alt_label:
# alternative labels are encoded in json metadata
meta = json.loads(sample['json'])
class_label = int(meta[alt_label])
if class_label < 0:
# skipped labels currently encoded as -1, may change to a null/None value
return None
else:
class_label = int(sample[target_key])
# decode image
with io.BytesIO(sample[image_key]) as b: with io.BytesIO(sample[image_key]) as b:
img = Image.open(b) img = Image.open(b)
img.load() img.load()
if image_format: if image_format:
img = img.convert(image_format) img = img.convert(image_format)
if alt_label:
# alternative labels are encoded in json metadata # json passed through in undecoded state
assert 'json' in sample return dict(jpg=img, cls=class_label, json=sample.get('json', None))
meta = json.loads(sample['json'])
return dict(jpg=img, cls=int(meta[alt_label]), json=meta)
else: def _decode_samples(
return dict(jpg=img, cls=int(sample[target_key]), json=sample.get('json', None)) data,
image_key='jpg',
image_format='RGB',
target_key='cls',
alt_label='',
handler=wds.reraise_exception):
"""Decode samples with skip."""
for sample in data:
try:
result = _decode(
sample, image_key=image_key, image_format=image_format, target_key=target_key, alt_label=alt_label)
except Exception as exn:
if handler(exn):
continue
else:
break
# null results are skipped
if result is not None:
if isinstance(sample, dict) and isinstance(result, dict):
result["__key__"] = sample.get("__key__")
yield result
class ParserWebdataset(Parser): class ParserWebdataset(Parser):
@ -217,11 +250,11 @@ class ParserWebdataset(Parser):
wds.tarfile_to_samples(), wds.tarfile_to_samples(),
]) ])
pipeline.extend([ pipeline.extend([
wds.map(partial( partial(
_decode, _decode_samples,
image_key=self.image_key, image_key=self.image_key,
image_format=self.image_format, image_format=self.image_format,
alt_label=self.split_info.alt_label)) alt_label=self.split_info.alt_label)
]) ])
self.ds = wds.DataPipeline(*pipeline) self.ds = wds.DataPipeline(*pipeline)
self.init_count += 1 self.init_count += 1

Loading…
Cancel
Save