Support HF datasets and TFDS w/ a sub-path by fixing split, fix #1598 ... add class mapping support to HF datasets in case class label isn't in info.

pull/1593/head
Ross Wightman 2 years ago
parent 35fb00c779
commit d1bfa9a000

@@ -151,7 +151,7 @@ def create_dataset(
     elif name.startswith('hfds/'):
         # NOTE right now, HF datasets default arrow format is a random-access Dataset,
         # There will be a IterableDataset variant too, TBD
-        ds = ImageDataset(root, reader=name, split=split, **kwargs)
+        ds = ImageDataset(root, reader=name, split=split, class_map=class_map, **kwargs)
     elif name.startswith('tfds/'):
         ds = IterableImageDataset(
             root,

@@ -6,7 +6,7 @@ from .reader_image_in_tar import ReaderImageInTar

 def create_reader(name, root, split='train', **kwargs):
     name = name.lower()
-    name = name.split('/', 2)
+    name = name.split('/', 1)
     prefix = ''
     if len(name) > 1:
         prefix = name[0]
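
For context, a quick sketch (not part of the commit) of what the maxsplit change does: with split('/', 1) everything after the reader prefix stays together as one name, so hub-style dataset names with an org/sub-path pass through intact instead of being broken into extra pieces. The dataset name below is a placeholder.

# Illustration only; 'hfds/some-org/some-dataset' is a placeholder name.
'hfds/some-org/some-dataset'.split('/', 2)   # ['hfds', 'some-org', 'some-dataset']
'hfds/some-org/some-dataset'.split('/', 1)   # ['hfds', 'some-org/some-dataset']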

@@ -13,13 +13,14 @@ try:
 except ImportError as e:
     print("Please install Hugging Face datasets package `pip install datasets`.")
     exit(1)
+from .class_map import load_class_map
 from .reader import Reader


-def get_class_labels(info):
+def get_class_labels(info, label_key='label'):
     if 'label' not in info.features:
         return {}
-    class_label = info.features['label']
+    class_label = info.features[label_key]
     class_to_idx = {n: class_label.str2int(n) for n in class_label.names}
     return class_to_idx
@@ -32,6 +33,7 @@ class ReaderHfds(Reader):
             name,
             split='train',
             class_map=None,
+            label_key='label',
             download=False,
     ):
         """
@@ -43,12 +45,17 @@ class ReaderHfds(Reader):
             name,  # 'name' maps to path arg in hf datasets
             split=split,
             cache_dir=self.root,  # timm doesn't expect hidden cache dir for datasets, specify a path
-            #use_auth_token=True,
         )
         # leave decode for caller, plus we want easy access to original path names...
         self.dataset = self.dataset.cast_column('image', datasets.Image(decode=False))

-        self.class_to_idx = get_class_labels(self.dataset.info)
+        self.label_key = label_key
+        self.remap_class = False
+        if class_map:
+            self.class_to_idx = load_class_map(class_map)
+            self.remap_class = True
+        else:
+            self.class_to_idx = get_class_labels(self.dataset.info, self.label_key)
         self.split_info = self.dataset.info.splits[split]
         self.num_samples = self.split_info.num_examples
@@ -60,7 +67,10 @@ class ReaderHfds(Reader):
         else:
             assert 'path' in image and image['path']
             image = open(image['path'], 'rb')
-        return image, item['label']
+        label = item[self.label_key]
+        if self.remap_class:
+            label = self.class_to_idx[label]
+        return image, label

     def __len__(self):
         return len(self.dataset)
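
A minimal usage sketch of the change from the caller's side, assuming a hub dataset with an org sub-path; the dataset name and class-map file below are placeholders, not part of this commit.

# Sketch only: 'hfds/some-org/some-dataset' and 'my_class_map.txt' are placeholder names.
from timm.data import create_dataset

# my_class_map.txt: one class name per line, line order defines the integer index
ds = create_dataset(
    'hfds/some-org/some-dataset',  # 'hfds/' prefix selects the Hugging Face datasets reader
    root='./hf_data',              # forwarded as the datasets cache_dir
    split='train',
    class_map='my_class_map.txt',  # labels are remapped through this mapping at __getitem__ time
)
image, label = ds[0]

With no class_map given, the reader falls back to the class labels declared in the dataset's info, as before.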
