|
|
@ -106,18 +106,51 @@ def _decode(sample, image_key='jpg', image_format='RGB', target_key='cls', alt_l
|
|
|
|
* cls byte string label to int
|
|
|
|
* cls byte string label to int
|
|
|
|
* pass through JSON byte string (if it exists) without parse
|
|
|
|
* pass through JSON byte string (if it exists) without parse
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
# decode class label, skip if alternate label not valid
|
|
|
|
|
|
|
|
if alt_label:
|
|
|
|
|
|
|
|
# alternative labels are encoded in json metadata
|
|
|
|
|
|
|
|
meta = json.loads(sample['json'])
|
|
|
|
|
|
|
|
class_label = int(meta[alt_label])
|
|
|
|
|
|
|
|
if class_label < 0:
|
|
|
|
|
|
|
|
# skipped labels currently encoded as -1, may change to a null/None value
|
|
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
class_label = int(sample[target_key])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# decode image
|
|
|
|
with io.BytesIO(sample[image_key]) as b:
|
|
|
|
with io.BytesIO(sample[image_key]) as b:
|
|
|
|
img = Image.open(b)
|
|
|
|
img = Image.open(b)
|
|
|
|
img.load()
|
|
|
|
img.load()
|
|
|
|
if image_format:
|
|
|
|
if image_format:
|
|
|
|
img = img.convert(image_format)
|
|
|
|
img = img.convert(image_format)
|
|
|
|
if alt_label:
|
|
|
|
|
|
|
|
# alternative labels are encoded in json metadata
|
|
|
|
# json passed through in undecoded state
|
|
|
|
assert 'json' in sample
|
|
|
|
return dict(jpg=img, cls=class_label, json=sample.get('json', None))
|
|
|
|
meta = json.loads(sample['json'])
|
|
|
|
|
|
|
|
return dict(jpg=img, cls=int(meta[alt_label]), json=meta)
|
|
|
|
|
|
|
|
|
|
|
|
def _decode_samples(
|
|
|
|
|
|
|
|
data,
|
|
|
|
|
|
|
|
image_key='jpg',
|
|
|
|
|
|
|
|
image_format='RGB',
|
|
|
|
|
|
|
|
target_key='cls',
|
|
|
|
|
|
|
|
alt_label='',
|
|
|
|
|
|
|
|
handler=wds.reraise_exception):
|
|
|
|
|
|
|
|
"""Decode samples with skip."""
|
|
|
|
|
|
|
|
for sample in data:
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
result = _decode(
|
|
|
|
|
|
|
|
sample, image_key=image_key, image_format=image_format, target_key=target_key, alt_label=alt_label)
|
|
|
|
|
|
|
|
except Exception as exn:
|
|
|
|
|
|
|
|
if handler(exn):
|
|
|
|
|
|
|
|
continue
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
return dict(jpg=img, cls=int(sample[target_key]), json=sample.get('json', None))
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# null results are skipped
|
|
|
|
|
|
|
|
if result is not None:
|
|
|
|
|
|
|
|
if isinstance(sample, dict) and isinstance(result, dict):
|
|
|
|
|
|
|
|
result["__key__"] = sample.get("__key__")
|
|
|
|
|
|
|
|
yield result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ParserWebdataset(Parser):
|
|
|
|
class ParserWebdataset(Parser):
|
|
|
@ -217,11 +250,11 @@ class ParserWebdataset(Parser):
|
|
|
|
wds.tarfile_to_samples(),
|
|
|
|
wds.tarfile_to_samples(),
|
|
|
|
])
|
|
|
|
])
|
|
|
|
pipeline.extend([
|
|
|
|
pipeline.extend([
|
|
|
|
wds.map(partial(
|
|
|
|
partial(
|
|
|
|
_decode,
|
|
|
|
_decode_samples,
|
|
|
|
image_key=self.image_key,
|
|
|
|
image_key=self.image_key,
|
|
|
|
image_format=self.image_format,
|
|
|
|
image_format=self.image_format,
|
|
|
|
alt_label=self.split_info.alt_label))
|
|
|
|
alt_label=self.split_info.alt_label)
|
|
|
|
])
|
|
|
|
])
|
|
|
|
self.ds = wds.DataPipeline(*pipeline)
|
|
|
|
self.ds = wds.DataPipeline(*pipeline)
|
|
|
|
self.init_count += 1
|
|
|
|
self.init_count += 1
|
|
|
|