You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
74 lines
2.3 KiB
74 lines
2.3 KiB
from abc import ABC, abstractmethod
|
|
from typing import Dict, List, Optional, Union
|
|
|
|
|
|
class DatasetInfo(ABC):
    """Abstract interface for querying a dataset's label metadata.

    Subclasses expose the number of classes, the label names, and label
    descriptions, addressable either by class index or by label name.
    """

    def __init__(self):
        pass

    @abstractmethod
    def num_classes(self) -> int:
        """Return the number of classes in the dataset."""

    @abstractmethod
    def label_names(self) -> Union[List[str], Dict[int, str]]:
        """Return the label names, as a sequence or an index => name mapping."""

    @abstractmethod
    def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
        """Return the label descriptions.

        Args:
            detailed: Implementation-defined; presumably requests a longer
                description variant where one exists.
            as_dict: Implementation-defined; presumably selects the
                name => description mapping form over the list form.
        """

    @abstractmethod
    def index_to_label_name(self, index: int) -> str:
        """Return the label name for class index ``index``."""

    @abstractmethod
    def index_to_description(self, index: int, detailed: bool = False) -> str:
        """Return the description for class index ``index``."""

    @abstractmethod
    def label_name_to_description(self, label: str, detailed: bool = False) -> str:
        """Return the description for the label named ``label``."""
|
|
|
|
|
|
class CustomDatasetInfo(DatasetInfo):
    """DatasetInfo that wraps passed values for custom datasets."""

    def __init__(
            self,
            label_names: Union[List[str], Dict[int, str]],
            label_descriptions: Optional[Dict[str, str]] = None,
    ):
        """Store the label names and, optionally, their descriptions.

        Args:
            label_names: Non-empty label index => label name mapping, either
                a list (position is the index) or a dict keyed by index.
            label_descriptions: Optional label name => description mapping.
                When given, it must be a dict with an entry for every label
                name.

        Raises:
            AssertionError: If ``label_names`` is empty, or if
                ``label_descriptions`` is not a dict or lacks a label name.
        """
        super().__init__()
        assert len(label_names) > 0, "label_names must not be empty"
        self._label_names = label_names  # label index => label name mapping
        self._label_descriptions = label_descriptions  # label name => label description mapping
        if self._label_descriptions is not None:
            # validate descriptions (label names required)
            assert isinstance(self._label_descriptions, dict), "label_descriptions must be a dict"
            # Iterating a dict yields its keys (the indices); descriptions are
            # keyed by name, so validate against the dict's values instead.
            if isinstance(self._label_names, dict):
                names = self._label_names.values()
            else:
                names = self._label_names
            for name in names:
                assert name in self._label_descriptions, f"missing description for label {name!r}"

    def num_classes(self) -> int:
        """Return the number of label names held."""
        return len(self._label_names)

    def label_names(self) -> Union[List[str], Dict[int, str]]:
        """Return the label names exactly as passed to the constructor."""
        return self._label_names

    def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
        """Return the stored descriptions mapping unchanged.

        ``detailed`` and ``as_dict`` are accepted for interface compatibility
        but not used. The result is the name => description dict passed to
        the constructor, or ``None`` when no descriptions were provided.
        """
        return self._label_descriptions

    def label_name_to_description(self, label: str, detailed: bool = False) -> str:
        """Return the description for ``label``.

        Falls back to the label name itself when no descriptions are stored.
        ``detailed`` is accepted but not used.
        """
        if self._label_descriptions:
            return self._label_descriptions[label]
        return label  # return label name itself if a description is not present

    def index_to_label_name(self, index) -> str:
        """Return the label name stored for ``index``.

        Raises:
            AssertionError: If ``index`` is out of range for a list of names,
                or is not a key of a dict of names.
        """
        names = self._label_names
        if isinstance(names, dict):
            # Dict keys need not be contiguous from 0, so check membership
            # rather than a positional range.
            assert index in names, f"unknown label index {index!r}"
        else:
            assert 0 <= index < len(names), f"label index {index!r} out of range"
        return names[index]

    def index_to_description(self, index: int, detailed: bool = False) -> str:
        """Return the description of the label stored for ``index``."""
        label = self.index_to_label_name(index)
        return self.label_name_to_description(label, detailed=detailed)
|