diff --git a/timm/data/dataset.py b/timm/data/dataset.py
index e7f67925..17c08e4d 100644
--- a/timm/data/dataset.py
+++ b/timm/data/dataset.py
@@ -88,6 +88,7 @@ class IterableImageDataset(data.IterableDataset):
             root,
             reader=None,
             split='train',
+            class_map=None,
             is_training=False,
             batch_size=None,
             seed=42,
@@ -102,6 +103,7 @@ class IterableImageDataset(data.IterableDataset):
             reader,
             root=root,
             split=split,
+            class_map=class_map,
             is_training=is_training,
             batch_size=batch_size,
             seed=seed,
diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py
index a4c18e39..6f0dcfcd 100644
--- a/timm/data/dataset_factory.py
+++ b/timm/data/dataset_factory.py
@@ -157,6 +157,7 @@ def create_dataset(
             root,
             reader=name,
             split=split,
+            class_map=class_map,
             is_training=is_training,
             download=download,
             batch_size=batch_size,
@@ -169,6 +170,7 @@ def create_dataset(
             root,
             reader=name,
             split=split,
+            class_map=class_map,
             is_training=is_training,
             batch_size=batch_size,
             repeats=repeats,
diff --git a/timm/data/readers/reader_tfds.py b/timm/data/readers/reader_tfds.py
index 25aab471..012a27a9 100644
--- a/timm/data/readers/reader_tfds.py
+++ b/timm/data/readers/reader_tfds.py
@@ -34,6 +34,7 @@ except ImportError as e:
     print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
     exit(1)
 
+from .class_map import load_class_map
 from .reader import Reader
 from .shared_count import SharedCount
 
@@ -94,6 +95,7 @@ class ReaderTfds(Reader):
             root,
             name,
             split='train',
+            class_map=None,
             is_training=False,
             batch_size=None,
             download=False,
@@ -151,7 +153,12 @@ class ReaderTfds(Reader):
         # NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag
         if download:
             self.builder.download_and_prepare()
-        self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
+        self.remap_class = False
+        if class_map:
+            self.class_to_idx = load_class_map(class_map)
+            self.remap_class = True
+        else:
+            self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
         self.split_info = self.builder.info.splits[split]
         self.num_samples = self.split_info.num_examples
 
@@ -299,6 +306,8 @@ class ReaderTfds(Reader):
             target_data = sample[self.target_name]
             if self.target_img_mode:
                 target_data = Image.fromarray(target_data, mode=self.target_img_mode)
+            elif self.remap_class:
+                target_data = self.class_to_idx[target_data]
             yield input_data, target_data
             sample_count += 1
             if self.is_training and sample_count >= target_sample_count:
diff --git a/timm/data/readers/reader_wds.py b/timm/data/readers/reader_wds.py
index 36890eed..3bf99d26 100644
--- a/timm/data/readers/reader_wds.py
+++ b/timm/data/readers/reader_wds.py
@@ -29,6 +29,7 @@ except ImportError:
     wds = None
     expand_urls = None
 
+from .class_map import load_class_map
 from .reader import Reader
 from .shared_count import SharedCount
 
@@ -42,13 +43,13 @@ def _load_info(root, basename='info'):
     info_yaml = os.path.join(root, basename + '.yaml')
     err_str = ''
     try:
-        with wds.gopen.gopen(info_json) as f:
+        with wds.gopen(info_json) as f:
             info_dict = json.load(f)
         return info_dict
     except Exception as e:
         err_str = str(e)
     try:
-        with wds.gopen.gopen(info_yaml) as f:
+        with wds.gopen(info_yaml) as f:
             info_dict = yaml.safe_load(f)
         return info_dict
     except Exception:
@@ -110,8 +111,8 @@ def _parse_split_info(split: str, info: Dict):
             filenames=split_filenames,
         )
     else:
-        if split not in info['splits']:
-            raise RuntimeError(f"split {split} not found in info ({info['splits'].keys()})")
+        if 'splits' not in info or split not in info['splits']:
+            raise RuntimeError(f"split {split} not found in info ({info.get('splits', {}).keys()})")
         split = split
         split_info = info['splits'][split]
         split_info = _info_convert(split_info)
@@ -290,6 +291,7 @@ class ReaderWds(Reader):
             batch_size=None,
             repeats=0,
             seed=42,
+            class_map=None,
             input_name='jpg',
             input_image='RGB',
             target_name='cls',
@@ -320,6 +322,12 @@ class ReaderWds(Reader):
         self.num_samples = self.split_info.num_samples
         if not self.num_samples:
             raise RuntimeError(f'Invalid split definition, no samples found.')
+        self.remap_class = False
+        if class_map:
+            self.class_to_idx = load_class_map(class_map)
+            self.remap_class = True
+        else:
+            self.class_to_idx = {}
 
         # Distributed world state
         self.dist_rank = 0
@@ -431,7 +439,10 @@ class ReaderWds(Reader):
         i = 0
         # _logger.info(f'start {i}, {self.worker_id}')  # FIXME temporary debug
         for sample in ds:
-            yield sample[self.image_key], sample[self.target_key]
+            target = sample[self.target_key]
+            if self.remap_class:
+                target = self.class_to_idx[target]
+            yield sample[self.image_key], target
             i += 1
         # _logger.info(f'end {i}, {self.worker_id}')  # FIXME temporary debug
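
Usage sketch (illustrative, not a definitive recipe): with this change the class_map argument of timm's create_dataset is forwarded to the TFDS/WDS readers rather than being ignored. Dataset name, root, and the mapping below are placeholder values; per the diff, targets are remapped via class_to_idx[target], so the map's keys must match the raw target values the reader yields.

from timm.data import create_dataset

# Hypothetical relabeling of a 3-class dataset; keys must equal the raw targets
# emitted by the reader, values are the indices you want to train against.
remap = {0: 2, 1: 0, 2: 1}

ds = create_dataset(
    'tfds/beans',        # any 'tfds/...' name; 'wds/...' datasets work the same way
    root='/data/tfds',   # placeholder data root
    split='train',
    is_training=True,
    batch_size=64,
    class_map=remap,     # a dict, or a class-map file accepted by load_class_map
)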