|
|
@ -45,6 +45,7 @@ class SplitInfo:
|
|
|
|
num_samples: int
|
|
|
|
num_samples: int
|
|
|
|
filenames: Tuple[str]
|
|
|
|
filenames: Tuple[str]
|
|
|
|
shard_lengths: Tuple[int] = ()
|
|
|
|
shard_lengths: Tuple[int] = ()
|
|
|
|
|
|
|
|
alt_label: str = ''
|
|
|
|
name: str = ''
|
|
|
|
name: str = ''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -54,6 +55,7 @@ def _parse_split_info(split: str, info: Dict):
|
|
|
|
num_samples=dict_info['num_samples'],
|
|
|
|
num_samples=dict_info['num_samples'],
|
|
|
|
filenames=tuple(dict_info['filenames']),
|
|
|
|
filenames=tuple(dict_info['filenames']),
|
|
|
|
shard_lengths=tuple(dict_info['shard_lengths']),
|
|
|
|
shard_lengths=tuple(dict_info['shard_lengths']),
|
|
|
|
|
|
|
|
alt_label=dict_info.get('alt_label', ''),
|
|
|
|
name=dict_info['name'],
|
|
|
|
name=dict_info['name'],
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
@ -98,7 +100,7 @@ def _parse_split_info(split: str, info: Dict):
|
|
|
|
return split_info
|
|
|
|
return split_info
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _decode(sample, image_key='jpg', image_format='RGB', target_key='cls'):
|
|
|
|
def _decode(sample, image_key='jpg', image_format='RGB', target_key='cls', alt_label=''):
|
|
|
|
""" Custom sample decode
|
|
|
|
""" Custom sample decode
|
|
|
|
* decode and convert PIL Image
|
|
|
|
* decode and convert PIL Image
|
|
|
|
* cls byte string label to int
|
|
|
|
* cls byte string label to int
|
|
|
@ -109,7 +111,13 @@ def _decode(sample, image_key='jpg', image_format='RGB', target_key='cls'):
|
|
|
|
img.load()
|
|
|
|
img.load()
|
|
|
|
if image_format:
|
|
|
|
if image_format:
|
|
|
|
img = img.convert(image_format)
|
|
|
|
img = img.convert(image_format)
|
|
|
|
return dict(jpg=img, cls=int(sample[target_key]), json=sample.get('json', None))
|
|
|
|
if alt_label:
|
|
|
|
|
|
|
|
# alternative labels are encoded in json metadata
|
|
|
|
|
|
|
|
assert 'json' in sample
|
|
|
|
|
|
|
|
meta = json.loads(sample['json'])
|
|
|
|
|
|
|
|
return dict(jpg=img, cls=int(meta[alt_label]), json=meta)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
return dict(jpg=img, cls=int(sample[target_key]), json=sample.get('json', None))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ParserWebdataset(Parser):
|
|
|
|
class ParserWebdataset(Parser):
|
|
|
@ -209,7 +217,11 @@ class ParserWebdataset(Parser):
|
|
|
|
wds.tarfile_to_samples(),
|
|
|
|
wds.tarfile_to_samples(),
|
|
|
|
])
|
|
|
|
])
|
|
|
|
pipeline.extend([
|
|
|
|
pipeline.extend([
|
|
|
|
wds.map(partial(_decode, image_key=self.image_key, image_format=self.image_format))
|
|
|
|
wds.map(partial(
|
|
|
|
|
|
|
|
_decode,
|
|
|
|
|
|
|
|
image_key=self.image_key,
|
|
|
|
|
|
|
|
image_format=self.image_format,
|
|
|
|
|
|
|
|
alt_label=self.split_info.alt_label))
|
|
|
|
])
|
|
|
|
])
|
|
|
|
self.ds = wds.DataPipeline(*pipeline)
|
|
|
|
self.ds = wds.DataPipeline(*pipeline)
|
|
|
|
self.init_count += 1
|
|
|
|
self.init_count += 1
|
|
|
|