Allow using class_map functionality w/ IterableDataset (TFDS/WDS) to remap class labels

pull/1628/head
Ross Wightman 1 year ago
parent 01fdf44438
commit c061d5e401

@ -88,6 +88,7 @@ class IterableImageDataset(data.IterableDataset):
root,
reader=None,
split='train',
class_map=None,
is_training=False,
batch_size=None,
seed=42,
@ -102,6 +103,7 @@ class IterableImageDataset(data.IterableDataset):
reader,
root=root,
split=split,
class_map=class_map,
is_training=is_training,
batch_size=batch_size,
seed=seed,

@ -157,6 +157,7 @@ def create_dataset(
root,
reader=name,
split=split,
class_map=class_map,
is_training=is_training,
download=download,
batch_size=batch_size,
@ -169,6 +170,7 @@ def create_dataset(
root,
reader=name,
split=split,
class_map=class_map,
is_training=is_training,
batch_size=batch_size,
repeats=repeats,

@ -34,6 +34,7 @@ except ImportError as e:
print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
exit(1)
from .class_map import load_class_map
from .reader import Reader
from .shared_count import SharedCount
@ -94,6 +95,7 @@ class ReaderTfds(Reader):
root,
name,
split='train',
class_map=None,
is_training=False,
batch_size=None,
download=False,
@ -151,7 +153,12 @@ class ReaderTfds(Reader):
# NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag
if download:
self.builder.download_and_prepare()
self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
self.remap_class = False
if class_map:
self.class_to_idx = load_class_map(class_map)
self.remap_class = True
else:
self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
self.split_info = self.builder.info.splits[split]
self.num_samples = self.split_info.num_examples
@ -299,6 +306,8 @@ class ReaderTfds(Reader):
target_data = sample[self.target_name]
if self.target_img_mode:
target_data = Image.fromarray(target_data, mode=self.target_img_mode)
elif self.remap_class:
target_data = self.class_to_idx[target_data]
yield input_data, target_data
sample_count += 1
if self.is_training and sample_count >= target_sample_count:

@ -29,6 +29,7 @@ except ImportError:
wds = None
expand_urls = None
from .class_map import load_class_map
from .reader import Reader
from .shared_count import SharedCount
@ -42,13 +43,13 @@ def _load_info(root, basename='info'):
info_yaml = os.path.join(root, basename + '.yaml')
err_str = ''
try:
with wds.gopen.gopen(info_json) as f:
with wds.gopen(info_json) as f:
info_dict = json.load(f)
return info_dict
except Exception as e:
err_str = str(e)
try:
with wds.gopen.gopen(info_yaml) as f:
with wds.gopen(info_yaml) as f:
info_dict = yaml.safe_load(f)
return info_dict
except Exception:
@ -110,8 +111,8 @@ def _parse_split_info(split: str, info: Dict):
filenames=split_filenames,
)
else:
if split not in info['splits']:
raise RuntimeError(f"split {split} not found in info ({info['splits'].keys()})")
if 'splits' not in info or split not in info['splits']:
raise RuntimeError(f"split {split} not found in info ({info.get('splits', {}).keys()})")
split = split
split_info = info['splits'][split]
split_info = _info_convert(split_info)
@ -290,6 +291,7 @@ class ReaderWds(Reader):
batch_size=None,
repeats=0,
seed=42,
class_map=None,
input_name='jpg',
input_image='RGB',
target_name='cls',
@ -320,6 +322,12 @@ class ReaderWds(Reader):
self.num_samples = self.split_info.num_samples
if not self.num_samples:
raise RuntimeError(f'Invalid split definition, no samples found.')
self.remap_class = False
if class_map:
self.class_to_idx = load_class_map(class_map)
self.remap_class = True
else:
self.class_to_idx = {}
# Distributed world state
self.dist_rank = 0
@ -431,7 +439,10 @@ class ReaderWds(Reader):
i = 0
# _logger.info(f'start {i}, {self.worker_id}') # FIXME temporary debug
for sample in ds:
yield sample[self.image_key], sample[self.target_key]
target = sample[self.target_key]
if self.remap_class:
target = self.class_to_idx[target]
yield sample[self.image_key], target
i += 1
# _logger.info(f'end {i}, {self.worker_id}') # FIXME temporary debug

Loading…
Cancel
Save