Improve support for custom dataset label name/description through HF hub export, via pretrained_cfg

pull/1673/head
Ross Wightman 2 years ago
parent 1e0b347227
commit 9c14654a0d

timm/data/__init__.py
@@ -4,7 +4,7 @@ from .config import resolve_data_config, resolve_model_data_config
 from .constants import *
 from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
 from .dataset_factory import create_dataset
-from .dataset_info import DatasetInfo
+from .dataset_info import DatasetInfo, CustomDatasetInfo
 from .imagenet_info import ImageNetInfo, infer_imagenet_subset
 from .loader import create_loader
 from .mixup import Mixup, FastCollateMixup

timm/data/dataset_info.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union


 class DatasetInfo(ABC):
@@ -30,3 +30,44 @@ class DatasetInfo(ABC):
     @abstractmethod
     def label_name_to_description(self, label: str, detailed: bool = False) -> str:
         pass
+
+
+class CustomDatasetInfo(DatasetInfo):
+    """ DatasetInfo that wraps passed values for custom datasets."""
+
+    def __init__(
+            self,
+            label_names: Union[List[str], Dict[int, str]],
+            label_descriptions: Optional[Dict[str, str]] = None
+    ):
+        super().__init__()
+        assert len(label_names) > 0
+        self._label_names = label_names  # label index => label name mapping
+        self._label_descriptions = label_descriptions  # label name => label description mapping
+        if self._label_descriptions is not None:
+            # validate descriptions (label names required)
+            assert isinstance(self._label_descriptions, dict)
+            for n in self._label_names:
+                assert n in self._label_descriptions
+
+    def num_classes(self):
+        return len(self._label_names)
+
+    def label_names(self):
+        return self._label_names
+
+    def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
+        return self._label_descriptions
+
+    def label_name_to_description(self, label: str, detailed: bool = False) -> str:
+        if self._label_descriptions:
+            return self._label_descriptions[label]
+        return label  # return label name itself if a description is not present
+
+    def index_to_label_name(self, index) -> str:
+        assert 0 <= index < len(self._label_names)
+        return self._label_names[index]
+
+    def index_to_description(self, index: int, detailed: bool = False) -> str:
+        label = self.index_to_label_name(index)
+        return self.label_name_to_description(label, detailed=detailed)
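
With CustomDatasetInfo exported from timm.data, label metadata for a custom dataset can be wrapped directly, with no ImageNet-style metadata files needed. A minimal usage sketch (the two-class labels are illustrative, not from the commit):

    from timm.data import CustomDatasetInfo

    info = CustomDatasetInfo(
        label_names=['cat', 'dog'],
        label_descriptions={'cat': 'domestic cat', 'dog': 'domestic dog'},
    )
    assert info.num_classes() == 2
    print(info.index_to_description(1))  # -> 'domestic dog'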

timm/models/_hub.py
@@ -16,6 +16,7 @@ except ImportError:
     from torch.hub import _get_torch_home as get_dir

 from timm import __version__
+from timm.layers import ClassifierHead, NormMlpClassifierHead
 from timm.models._pretrained import filter_pretrained_cfg

 try:
@@ -96,7 +97,7 @@ def has_hf_hub(necessary=False):
     return _has_hf_hub


-def hf_split(hf_id):
+def hf_split(hf_id: str):
     # FIXME I may change @ -> # and be parsed as fragment in a URI model name scheme
     rev_split = hf_id.split('@')
     assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.'
@@ -127,19 +128,26 @@ def load_model_config_from_hf(model_id: str):
         hf_config = {}
         hf_config['architecture'] = pretrained_cfg.pop('architecture')
         hf_config['num_features'] = pretrained_cfg.pop('num_features', None)
-        if 'labels' in pretrained_cfg:
-            hf_config['label_name'] = pretrained_cfg.pop('labels')
+        if 'labels' in pretrained_cfg:  # deprecated name for 'label_names'
+            pretrained_cfg['label_names'] = pretrained_cfg.pop('labels')
         hf_config['pretrained_cfg'] = pretrained_cfg

     # NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now
     pretrained_cfg = hf_config['pretrained_cfg']
     pretrained_cfg['hf_hub_id'] = model_id  # insert hf_hub id for pretrained weight load during model creation
     pretrained_cfg['source'] = 'hf-hub'

-    # model should be created with base config num_classes if its exist
     if 'num_classes' in hf_config:
+        # model should be created with parent num_classes if they exist
         pretrained_cfg['num_classes'] = hf_config['num_classes']
-    model_name = hf_config['architecture']
+
+    # label meta-data in base config overrides saved pretrained_cfg on load
+    if 'label_names' in hf_config:
+        pretrained_cfg['label_names'] = hf_config.pop('label_names')
+    if 'label_descriptions' in hf_config:
+        pretrained_cfg['label_descriptions'] = hf_config.pop('label_descriptions')
+
+    model_name = hf_config['architecture']

     return pretrained_cfg, model_name
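
On the load side, label metadata stored at the top level of a hub config now flows into pretrained_cfg, so it should be readable from the created model. A hedged sketch (the repo id is a placeholder):

    import timm

    model = timm.create_model('hf-hub:your-org/your-model', pretrained=True)
    label_names = model.pretrained_cfg.get('label_names')
    label_descriptions = model.pretrained_cfg.get('label_descriptions')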
@@ -150,7 +158,7 @@ def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
     return state_dict


-def save_config_for_hf(model, config_path, model_config=None):
+def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None):
     model_config = model_config or {}
     hf_config = {}
     pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
@@ -164,22 +172,22 @@ def save_config_for_hf(model, config_path, model_config=None):
     if 'labels' in model_config:
         _logger.warning(
-            "'labels' as a config field for timm models is deprecated. Please use 'label_name' and 'display_name'. "
-            "Using provided 'label' field as 'label_name'.")
-        model_config['label_name'] = model_config.pop('labels')
+            "'labels' as a config field for timm models is deprecated. Please use 'label_names' and 'label_descriptions'."
+            " Renaming provided 'labels' field to 'label_names'.")
+        model_config.setdefault('label_names', model_config.pop('labels'))

-    label_name = model_config.pop('label_name', None)
-    if label_name:
-        assert isinstance(label_name, (dict, list, tuple))
+    label_names = model_config.pop('label_names', None)
+    if label_names:
+        assert isinstance(label_names, (dict, list, tuple))
         # map label id (classifier index) -> unique label name (ie synset for ImageNet, MID for OpenImages)
         # can be a dict id: name if there are id gaps, or tuple/list if no gaps.
-        hf_config['label_name'] = model_config['label_name']
+        hf_config['label_names'] = label_names

-    display_name = model_config.pop('display_name', None)
-    if display_name:
-        assert isinstance(display_name, dict)
-        # map label_name -> user interface display name
-        hf_config['display_name'] = model_config['display_name']
+    label_descriptions = model_config.pop('label_descriptions', None)
+    if label_descriptions:
+        assert isinstance(label_descriptions, dict)
+        # maps label names -> descriptions
+        hf_config['label_descriptions'] = label_descriptions

     hf_config['pretrained_cfg'] = pretrained_cfg
     hf_config.update(model_config)
@@ -188,7 +196,7 @@ def save_config_for_hf(model, config_path, model_config=None):
         json.dump(hf_config, f, indent=2)


-def save_for_hf(model, save_directory, model_config=None):
+def save_for_hf(model, save_directory: str, model_config: Optional[dict] = None):
     assert has_hf_hub(True)
     save_directory = Path(save_directory)
     save_directory.mkdir(exist_ok=True, parents=True)
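
Putting the save path together: save_for_hf (and push_to_hf_hub, which calls it) takes a model_config dict whose 'label_names' and 'label_descriptions' entries end up in the exported config.json via save_config_for_hf. A sketch assuming a fine-tuned two-class model (the repo id and labels are placeholders):

    import timm
    from timm.models import push_to_hf_hub

    model = timm.create_model('resnet18', pretrained=True, num_classes=2)
    # ... fine-tune on a two-class dataset ...
    push_to_hf_hub(
        model,
        'your-org/your-two-class-model',  # placeholder repo id
        model_config={
            'label_names': ['cat', 'dog'],  # classifier index -> label name
            'label_descriptions': {'cat': 'domestic cat', 'dog': 'domestic dog'},  # name -> description
        },
    )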
@@ -249,7 +257,7 @@ def push_to_hf_hub(
     )


-def generate_readme(model_card, model_name):
+def generate_readme(model_card: dict, model_name: str):
     readme_text = "---\n"
     readme_text += "tags:\n- image-classification\n- timm\n"
     readme_text += "library_tag: timm\n"

timm/models/_pretrained.py
@@ -34,9 +34,11 @@ class PretrainedCfg:
     mean: Tuple[float, ...] = (0.485, 0.456, 0.406)
     std: Tuple[float, ...] = (0.229, 0.224, 0.225)

-    # head config
+    # head / classifier config and meta-data
     num_classes: int = 1000
     label_offset: Optional[int] = None
+    label_names: Optional[Tuple[str]] = None
+    label_descriptions: Optional[Dict[str, str]] = None

     # model attributes that vary with above or required for pretrained adaptation
     pool_size: Optional[Tuple[int, ...]] = None
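
The two new PretrainedCfg fields give the label metadata a typed home in the dataclass, so filter_pretrained_cfg can carry it through export and load. A small illustrative construction (the values are examples, not defaults):

    from timm.models._pretrained import PretrainedCfg

    cfg = PretrainedCfg(
        num_classes=2,
        label_names=('cat', 'dog'),
        label_descriptions={'cat': 'domestic cat', 'dog': 'domestic dog'},
    )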
