Add alternative label support to WDS for imagenet22k/12k split, add 21k/22k/12k index filters to results/

pull/1239/head
Ross Wightman 3 years ago
parent da2796ae82
commit a444d4b891

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -45,6 +45,7 @@ class SplitInfo:
num_samples: int num_samples: int
filenames: Tuple[str] filenames: Tuple[str]
shard_lengths: Tuple[int] = () shard_lengths: Tuple[int] = ()
alt_label: str = ''
name: str = '' name: str = ''
@ -54,6 +55,7 @@ def _parse_split_info(split: str, info: Dict):
num_samples=dict_info['num_samples'], num_samples=dict_info['num_samples'],
filenames=tuple(dict_info['filenames']), filenames=tuple(dict_info['filenames']),
shard_lengths=tuple(dict_info['shard_lengths']), shard_lengths=tuple(dict_info['shard_lengths']),
alt_label=dict_info.get('alt_label', ''),
name=dict_info['name'], name=dict_info['name'],
) )
@ -98,7 +100,7 @@ def _parse_split_info(split: str, info: Dict):
return split_info return split_info
def _decode(sample, image_key='jpg', image_format='RGB', target_key='cls'): def _decode(sample, image_key='jpg', image_format='RGB', target_key='cls', alt_label=''):
""" Custom sample decode """ Custom sample decode
* decode and convert PIL Image * decode and convert PIL Image
* cls byte string label to int * cls byte string label to int
@ -109,6 +111,12 @@ def _decode(sample, image_key='jpg', image_format='RGB', target_key='cls'):
img.load() img.load()
if image_format: if image_format:
img = img.convert(image_format) img = img.convert(image_format)
if alt_label:
# alternative labels are encoded in json metadata
assert 'json' in sample
meta = json.loads(sample['json'])
return dict(jpg=img, cls=int(meta[alt_label]), json=meta)
else:
return dict(jpg=img, cls=int(sample[target_key]), json=sample.get('json', None)) return dict(jpg=img, cls=int(sample[target_key]), json=sample.get('json', None))
@ -209,7 +217,11 @@ class ParserWebdataset(Parser):
wds.tarfile_to_samples(), wds.tarfile_to_samples(),
]) ])
pipeline.extend([ pipeline.extend([
wds.map(partial(_decode, image_key=self.image_key, image_format=self.image_format)) wds.map(partial(
_decode,
image_key=self.image_key,
image_format=self.image_format,
alt_label=self.split_info.alt_label))
]) ])
self.ds = wds.DataPipeline(*pipeline) self.ds = wds.DataPipeline(*pipeline)
self.init_count += 1 self.init_count += 1

Loading…
Cancel
Save