|
|
|
@ -13,13 +13,14 @@ try:
|
|
|
|
|
except ImportError as e:
|
|
|
|
|
print("Please install Hugging Face datasets package `pip install datasets`.")
|
|
|
|
|
exit(1)
|
|
|
|
|
from .class_map import load_class_map
|
|
|
|
|
from .reader import Reader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_class_labels(info, label_key='label'):
    """Build a class-name -> class-index mapping from a dataset's label feature.

    Args:
        info: Object exposing a ``features`` mapping (the ``info`` attribute of
            a loaded dataset is what callers in this file pass).
        label_key: Name of the feature that holds the class label.

    Returns:
        A dict mapping every entry of the feature's ``names`` to the value of
        ``str2int(name)``. An empty dict is returned when ``label_key`` is not
        one of the dataset's features.
    """
    # Guard on the requested key rather than the literal 'label'; otherwise a
    # dataset with a custom label column (and no 'label' column) would always
    # yield an empty mapping, and a missing custom key would raise KeyError.
    if label_key not in info.features:
        return {}
    class_label = info.features[label_key]
    return {n: class_label.str2int(n) for n in class_label.names}
|
|
|
|
|
|
|
|
|
@ -32,6 +33,7 @@ class ReaderHfds(Reader):
|
|
|
|
|
name,
|
|
|
|
|
split='train',
|
|
|
|
|
class_map=None,
|
|
|
|
|
label_key='label',
|
|
|
|
|
download=False,
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
@ -43,12 +45,17 @@ class ReaderHfds(Reader):
|
|
|
|
|
name, # 'name' maps to path arg in hf datasets
|
|
|
|
|
split=split,
|
|
|
|
|
cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path
|
|
|
|
|
#use_auth_token=True,
|
|
|
|
|
)
|
|
|
|
|
# leave decode for caller, plus we want easy access to original path names...
|
|
|
|
|
self.dataset = self.dataset.cast_column('image', datasets.Image(decode=False))
|
|
|
|
|
|
|
|
|
|
self.class_to_idx = get_class_labels(self.dataset.info)
|
|
|
|
|
self.label_key = label_key
|
|
|
|
|
self.remap_class = False
|
|
|
|
|
if class_map:
|
|
|
|
|
self.class_to_idx = load_class_map(class_map)
|
|
|
|
|
self.remap_class = True
|
|
|
|
|
else:
|
|
|
|
|
self.class_to_idx = get_class_labels(self.dataset.info, self.label_key)
|
|
|
|
|
self.split_info = self.dataset.info.splits[split]
|
|
|
|
|
self.num_samples = self.split_info.num_examples
|
|
|
|
|
|
|
|
|
@ -60,7 +67,10 @@ class ReaderHfds(Reader):
|
|
|
|
|
else:
|
|
|
|
|
assert 'path' in image and image['path']
|
|
|
|
|
image = open(image['path'], 'rb')
|
|
|
|
|
return image, item['label']
|
|
|
|
|
label = item[self.label_key]
|
|
|
|
|
if self.remap_class:
|
|
|
|
|
label = self.class_to_idx[label]
|
|
|
|
|
return image, label
|
|
|
|
|
|
|
|
|
|
def __len__(self):
    """Return the length reported by the wrapped ``self.dataset``."""
    return len(self.dataset)
|
|
|
|
|