diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py
index 757c2e5d..a4c18e39 100644
--- a/timm/data/dataset_factory.py
+++ b/timm/data/dataset_factory.py
@@ -151,7 +151,7 @@ def create_dataset(
     elif name.startswith('hfds/'):
         # NOTE right now, HF datasets default arrow format is a random-access Dataset,
         # There will be a IterableDataset variant too, TBD
-        ds = ImageDataset(root, reader=name, split=split, **kwargs)
+        ds = ImageDataset(root, reader=name, split=split, class_map=class_map, **kwargs)
     elif name.startswith('tfds/'):
         ds = IterableImageDataset(
             root,
diff --git a/timm/data/readers/reader_factory.py b/timm/data/readers/reader_factory.py
index 58ff56cd..226e3857 100644
--- a/timm/data/readers/reader_factory.py
+++ b/timm/data/readers/reader_factory.py
@@ -6,7 +6,7 @@ from .reader_image_in_tar import ReaderImageInTar
 
 def create_reader(name, root, split='train', **kwargs):
     name = name.lower()
-    name = name.split('/', 2)
+    name = name.split('/', 1)
     prefix = ''
     if len(name) > 1:
         prefix = name[0]
diff --git a/timm/data/readers/reader_hfds.py b/timm/data/readers/reader_hfds.py
index 901cf4bc..62ae5f4d 100644
--- a/timm/data/readers/reader_hfds.py
+++ b/timm/data/readers/reader_hfds.py
@@ -13,13 +13,14 @@ try:
 except ImportError as e:
     print("Please install Hugging Face datasets package `pip install datasets`.")
     exit(1)
+from .class_map import load_class_map
 from .reader import Reader
 
 
-def get_class_labels(info):
+def get_class_labels(info, label_key='label'):
     if 'label' not in info.features:
         return {}
-    class_label = info.features['label']
+    class_label = info.features[label_key]
     class_to_idx = {n: class_label.str2int(n) for n in class_label.names}
     return class_to_idx
 
@@ -32,6 +33,7 @@ class ReaderHfds(Reader):
             name,
             split='train',
             class_map=None,
+            label_key='label',
             download=False,
     ):
         """
@@ -43,12 +45,17 @@
             name,  # 'name' maps to path arg in hf datasets
             split=split,
             cache_dir=self.root,  # timm doesn't expect hidden cache dir for datasets, specify a path
-            #use_auth_token=True,
         )
         # leave decode for caller, plus we want easy access to original path names...
         self.dataset = self.dataset.cast_column('image', datasets.Image(decode=False))
 
-        self.class_to_idx = get_class_labels(self.dataset.info)
+        self.label_key = label_key
+        self.remap_class = False
+        if class_map:
+            self.class_to_idx = load_class_map(class_map)
+            self.remap_class = True
+        else:
+            self.class_to_idx = get_class_labels(self.dataset.info, self.label_key)
         self.split_info = self.dataset.info.splits[split]
         self.num_samples = self.split_info.num_examples
 
@@ -60,7 +67,10 @@
         else:
             assert 'path' in image and image['path']
             image = open(image['path'], 'rb')
-        return image, item['label']
+        label = item[self.label_key]
+        if self.remap_class:
+            label = self.class_to_idx[label]
+        return image, label
 
     def __len__(self):
         return len(self.dataset)
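
Not part of the patch, but for context: a minimal usage sketch of the two options this change wires up. The dataset names, cache path, class-map filename, and label column below are placeholders chosen for illustration, not values taken from the patch.

```python
from timm.data import create_dataset
from timm.data.readers.reader_hfds import ReaderHfds

# class_map now reaches the hfds reader through create_dataset (first hunk).
# The class-map file is assumed to be a local text file, one class name per line.
ds = create_dataset(
    'hfds/cats_vs_dogs',          # hypothetical HF dataset name
    root='/data/hf_cache',        # placeholder cache dir
    split='train',
    class_map='my_class_map.txt', # placeholder class-map file
)

# label_key is only a ReaderHfds constructor arg, so to read a non-default
# label column the reader can be built directly.
reader = ReaderHfds(
    root='/data/hf_cache',
    name='cats_vs_dogs',          # hypothetical HF dataset name
    split='train',
    label_key='labels',           # hypothetical column name in that dataset
)
```

On the reader_factory change: `name.split('/', 1)` splits only at the first '/', so the part after the `hfds`/`tfds` prefix is kept intact even when the dataset identifier itself contains a '/' (e.g. org-scoped HF hub names), whereas `split('/', 2)` would break such names into a third piece.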