diff --git a/timm/data/__init__.py b/timm/data/__init__.py
index b31b3c6b..4b95fbd1 100644
--- a/timm/data/__init__.py
+++ b/timm/data/__init__.py
@@ -4,7 +4,7 @@ from .config import resolve_data_config, resolve_model_data_config
 from .constants import *
 from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
 from .dataset_factory import create_dataset
-from .dataset_info import DatasetInfo
+from .dataset_info import DatasetInfo, CustomDatasetInfo
 from .imagenet_info import ImageNetInfo, infer_imagenet_subset
 from .loader import create_loader
 from .mixup import Mixup, FastCollateMixup
diff --git a/timm/data/dataset_info.py b/timm/data/dataset_info.py
index 107c3318..58e46196 100644
--- a/timm/data/dataset_info.py
+++ b/timm/data/dataset_info.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union
 
 
 class DatasetInfo(ABC):
@@ -29,4 +29,45 @@ class DatasetInfo(ABC):
 
     @abstractmethod
     def label_name_to_description(self, label: str, detailed: bool = False) -> str:
-        pass
\ No newline at end of file
+        pass
+
+
+class CustomDatasetInfo(DatasetInfo):
+    """ DatasetInfo that wraps passed values for custom datasets."""
+
+    def __init__(
+            self,
+            label_names: Union[List[str], Dict[int, str]],
+            label_descriptions: Optional[Dict[str, str]] = None
+    ):
+        super().__init__()
+        assert len(label_names) > 0
+        self._label_names = label_names  # label index => label name mapping
+        self._label_descriptions = label_descriptions  # label name => label description mapping
+        if self._label_descriptions is not None:
+            # validate descriptions (label names required)
+            assert isinstance(self._label_descriptions, dict)
+            for n in self._label_names:
+                assert n in self._label_descriptions
+
+    def num_classes(self):
+        return len(self._label_names)
+
+    def label_names(self):
+        return self._label_names
+
+    def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
+        return self._label_descriptions
+
+    def label_name_to_description(self, label: str, detailed: bool = False) -> str:
+        if self._label_descriptions:
+            return self._label_descriptions[label]
+        return label  # return label name itself if a description is not present
+
+    def index_to_label_name(self, index: int) -> str:
+        assert 0 <= index < len(self._label_names)
+        return self._label_names[index]
+
+    def index_to_description(self, index: int, detailed: bool = False) -> str:
+        label = self.index_to_label_name(index)
+        return self.label_name_to_description(label, detailed=detailed)
diff --git a/timm/models/_hub.py b/timm/models/_hub.py
index 378d646c..61594753 100644
--- a/timm/models/_hub.py
+++ b/timm/models/_hub.py
@@ -16,6 +16,7 @@ except ImportError:
     from torch.hub import _get_torch_home as get_dir
 
 from timm import __version__
+from timm.layers import ClassifierHead, NormMlpClassifierHead
 from timm.models._pretrained import filter_pretrained_cfg
 
 try:
@@ -96,7 +97,7 @@ def has_hf_hub(necessary=False):
     return _has_hf_hub
 
 
-def hf_split(hf_id):
+def hf_split(hf_id: str):
     # FIXME I may change @ -> # and be parsed as fragment in a URI model name scheme
     rev_split = hf_id.split('@')
     assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.'
@@ -127,19 +128,26 @@ def load_model_config_from_hf(model_id: str):
         hf_config = {}
         hf_config['architecture'] = pretrained_cfg.pop('architecture')
         hf_config['num_features'] = pretrained_cfg.pop('num_features', None)
-        if 'labels' in pretrained_cfg:
-            hf_config['label_name'] = pretrained_cfg.pop('labels')
+        if 'labels' in pretrained_cfg:  # deprecated name for 'label_names'
+            pretrained_cfg['label_names'] = pretrained_cfg.pop('labels')
         hf_config['pretrained_cfg'] = pretrained_cfg
 
     # NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now
     pretrained_cfg = hf_config['pretrained_cfg']
     pretrained_cfg['hf_hub_id'] = model_id  # insert hf_hub id for pretrained weight load during model creation
     pretrained_cfg['source'] = 'hf-hub'
+
+    # model should be created with base config num_classes if it exists
     if 'num_classes' in hf_config:
-        # model should be created with parent num_classes if they exist
         pretrained_cfg['num_classes'] = hf_config['num_classes']
 
-    model_name = hf_config['architecture']
+    # label meta-data in base config overrides saved pretrained_cfg on load
+    if 'label_names' in hf_config:
+        pretrained_cfg['label_names'] = hf_config.pop('label_names')
+    if 'label_descriptions' in hf_config:
+        pretrained_cfg['label_descriptions'] = hf_config.pop('label_descriptions')
+
+    model_name = hf_config['architecture']
     return pretrained_cfg, model_name
 
 
@@ -150,7 +158,7 @@ def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
     return state_dict
 
 
-def save_config_for_hf(model, config_path, model_config=None):
+def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None):
     model_config = model_config or {}
     hf_config = {}
     pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
@@ -164,22 +172,22 @@ def save_config_for_hf(model, config_path, model_config=None):
 
     if 'labels' in model_config:
         _logger.warning(
-            "'labels' as a config field for timm models is deprecated. Please use 'label_name' and 'display_name'. "
-            "Using provided 'label' field as 'label_name'.")
-        model_config['label_name'] = model_config.pop('labels')
+            "'labels' as a config field for timm models is deprecated. Please use 'label_names' and 'label_descriptions'."
+            " Renaming provided 'labels' field to 'label_names'.")
+        model_config.setdefault('label_names', model_config.pop('labels'))
 
-    label_name = model_config.pop('label_name', None)
-    if label_name:
-        assert isinstance(label_name, (dict, list, tuple))
+    label_names = model_config.pop('label_names', None)
+    if label_names:
+        assert isinstance(label_names, (dict, list, tuple))
         # map label id (classifier index) -> unique label name (ie synset for ImageNet, MID for OpenImages)
         # can be a dict id: name if there are id gaps, or tuple/list if no gaps.
-        hf_config['label_name'] = model_config['label_name']
+        hf_config['label_names'] = label_names
 
-    display_name = model_config.pop('display_name', None)
-    if display_name:
-        assert isinstance(display_name, dict)
-        # map label_name -> user interface display name
-        hf_config['display_name'] = model_config['display_name']
+    label_descriptions = model_config.pop('label_descriptions', None)
+    if label_descriptions:
+        assert isinstance(label_descriptions, dict)
+        # maps label names -> descriptions
+        hf_config['label_descriptions'] = label_descriptions
 
     hf_config['pretrained_cfg'] = pretrained_cfg
     hf_config.update(model_config)
@@ -188,7 +196,7 @@ def save_config_for_hf(model, config_path, model_config=None):
         json.dump(hf_config, f, indent=2)
 
 
-def save_for_hf(model, save_directory, model_config=None):
+def save_for_hf(model, save_directory: str, model_config: Optional[dict] = None):
     assert has_hf_hub(True)
     save_directory = Path(save_directory)
     save_directory.mkdir(exist_ok=True, parents=True)
@@ -249,7 +257,7 @@ def push_to_hf_hub(
     )
 
 
-def generate_readme(model_card, model_name):
+def generate_readme(model_card: dict, model_name: str):
     readme_text = "---\n"
     readme_text += "tags:\n- image-classification\n- timm\n"
     readme_text += "library_tag: timm\n"
diff --git a/timm/models/_pretrained.py b/timm/models/_pretrained.py
index dca81eb0..11e4cff5 100644
--- a/timm/models/_pretrained.py
+++ b/timm/models/_pretrained.py
@@ -34,9 +34,11 @@ class PretrainedCfg:
     mean: Tuple[float, ...] = (0.485, 0.456, 0.406)
     std: Tuple[float, ...] = (0.229, 0.224, 0.225)
 
-    # head config
+    # head / classifier config and meta-data
     num_classes: int = 1000
     label_offset: Optional[int] = None
+    label_names: Optional[Tuple[str, ...]] = None
+    label_descriptions: Optional[Dict[str, str]] = None
 
     # model attributes that vary with above or required for pretrained adaptation
     pool_size: Optional[Tuple[int, ...]] = None
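
A quick usage sketch for the new `CustomDatasetInfo` added in `timm/data/dataset_info.py` above. It wraps a label-index -> label-name mapping plus an optional name -> description dict, validating that every name has a description when one is supplied. The class names and descriptions here are made-up placeholders:

```python
from timm.data import CustomDatasetInfo

# hypothetical 3-class dataset; names and descriptions are placeholders
info = CustomDatasetInfo(
    label_names=['cat', 'dog', 'bird'],  # index -> name; a Dict[int, str] also works if ids have gaps
    label_descriptions={
        'cat': 'domestic cat',
        'dog': 'domestic dog',
        'bird': 'any bird species',
    },
)

assert info.num_classes() == 3
print(info.index_to_label_name(0))            # 'cat'
print(info.index_to_description(0))           # 'domestic cat'
print(info.label_name_to_description('dog'))  # 'domestic dog'
```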
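On the save side, `save_config_for_hf` now pulls `label_names` and `label_descriptions` out of `model_config` and writes them to the top-level hub `config.json`. A hedged sketch of how a caller might supply them, assuming the existing `model_config` keyword of `push_to_hf_hub`; the repo id and labels are placeholders:

```python
import timm
from timm.models import push_to_hf_hub

model = timm.create_model('resnet50', pretrained=True, num_classes=3)
# ... fine-tune on a custom 3-class dataset ...

push_to_hf_hub(
    model,
    'your-user/resnet50-custom',  # placeholder repo id
    model_config={
        # label id (classifier index) -> unique label name; a dict can be used if there are id gaps
        'label_names': ['cat', 'dog', 'bird'],
        # label name -> human-readable description
        'label_descriptions': {
            'cat': 'domestic cat',
            'dog': 'domestic dog',
            'bird': 'any bird species',
        },
    },
)
```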
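On the load side, `load_model_config_from_hf` copies base-config `label_names` / `label_descriptions` into `pretrained_cfg`, so after model creation they can be read back and wrapped in a `CustomDatasetInfo`. A sketch, assuming `model.pretrained_cfg` is exposed as a dict and the hub repo's `config.json` carries the label metadata pushed above:

```python
import timm
from timm.data import CustomDatasetInfo

# placeholder hub id from the push example above
model = timm.create_model('hf-hub:your-user/resnet50-custom', pretrained=True)

cfg = model.pretrained_cfg  # dict view of the resolved pretrained config
if cfg.get('label_names'):
    info = CustomDatasetInfo(
        label_names=cfg['label_names'],
        label_descriptions=cfg.get('label_descriptions'),
    )
    print(info.index_to_label_name(0))
```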