Fix alternate label handling in WDS parser to skip invalid alt labels

pull/1239/head
Ross Wightman 2 years ago
parent a444d4b891
commit 229ac6b8d8

@ -106,18 +106,51 @@ def _decode(sample, image_key='jpg', image_format='RGB', target_key='cls', alt_l
* cls byte string label to int
* pass through JSON byte string (if it exists) without parse
"""
# decode class label, skip if alternate label not valid
if alt_label:
# alternative labels are encoded in json metadata
meta = json.loads(sample['json'])
class_label = int(meta[alt_label])
if class_label < 0:
# skipped labels currently encoded as -1, may change to a null/None value
return None
else:
class_label = int(sample[target_key])
# decode image
with io.BytesIO(sample[image_key]) as b:
img = Image.open(b)
img.load()
if image_format:
img = img.convert(image_format)
if alt_label:
# alternative labels are encoded in json metadata
assert 'json' in sample
meta = json.loads(sample['json'])
return dict(jpg=img, cls=int(meta[alt_label]), json=meta)
else:
return dict(jpg=img, cls=int(sample[target_key]), json=sample.get('json', None))
# json passed through in undecoded state
return dict(jpg=img, cls=class_label, json=sample.get('json', None))
def _decode_samples(
data,
image_key='jpg',
image_format='RGB',
target_key='cls',
alt_label='',
handler=wds.reraise_exception):
"""Decode samples with skip."""
for sample in data:
try:
result = _decode(
sample, image_key=image_key, image_format=image_format, target_key=target_key, alt_label=alt_label)
except Exception as exn:
if handler(exn):
continue
else:
break
# null results are skipped
if result is not None:
if isinstance(sample, dict) and isinstance(result, dict):
result["__key__"] = sample.get("__key__")
yield result
class ParserWebdataset(Parser):
@ -217,11 +250,11 @@ class ParserWebdataset(Parser):
wds.tarfile_to_samples(),
])
pipeline.extend([
wds.map(partial(
_decode,
partial(
_decode_samples,
image_key=self.image_key,
image_format=self.image_format,
alt_label=self.split_info.alt_label))
alt_label=self.split_info.alt_label)
])
self.ds = wds.DataPipeline(*pipeline)
self.init_count += 1

Loading…
Cancel
Save