Merge pull request #1662 from rwightman/dataset_info

ImageNet metadata (info) and labelling update

commit 88a5b8491d
MANIFEST.in
@@ -1,2 +1,3 @@
-include timm/models/pruned/*.txt
+include timm/models/_pruned/*.txt
+include timm/data/_info/*.txt
+include timm/data/_info/*.json
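These MANIFEST.in entries ensure the new label metadata files under timm/data/_info/ (and the renamed timm/models/_pruned/ text files) ship with source distributions, so that pkgutil.get_data (used in the code below) can locate them in an installed package.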
(4 file diffs suppressed because they are too large)
timm/data/dataset_info.py
@@ -0,0 +1,32 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Union


class DatasetInfo(ABC):

    def __init__(self):
        pass

    @abstractmethod
    def num_classes(self):
        pass

    @abstractmethod
    def label_names(self):
        pass

    @abstractmethod
    def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
        pass

    @abstractmethod
    def index_to_label_name(self, index) -> str:
        pass

    @abstractmethod
    def index_to_description(self, index: int, detailed: bool = False) -> str:
        pass

    @abstractmethod
    def label_name_to_description(self, label: str, detailed: bool = False) -> str:
        pass
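For orientation, this interface is what downstream label/description helpers consume. A minimal sketch of a concrete subclass for a hypothetical three-class dataset might look like the following (the ToyDatasetInfo name, labels, and descriptions are illustrative, not part of this PR; DatasetInfo is assumed importable from timm.data as added here):

from typing import Dict, List, Union

from timm.data import DatasetInfo


class ToyDatasetInfo(DatasetInfo):
    # Hypothetical fixed label set; a real subclass would load this from data files.
    _LABELS = ['cat', 'dog', 'bird']
    _DESCRIPTIONS = {
        'cat': 'a small domesticated carnivore',
        'dog': 'a domesticated canine',
        'bird': 'a warm-blooded egg-laying vertebrate',
    }

    def num_classes(self):
        return len(self._LABELS)

    def label_names(self):
        return self._LABELS

    def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
        # This toy example has no detailed descriptions, so `detailed` is ignored.
        if as_dict:
            return {label: self._DESCRIPTIONS[label] for label in self._LABELS}
        return [self._DESCRIPTIONS[label] for label in self._LABELS]

    def index_to_label_name(self, index) -> str:
        return self._LABELS[index]

    def index_to_description(self, index: int, detailed: bool = False) -> str:
        return self._DESCRIPTIONS[self._LABELS[index]]

    def label_name_to_description(self, label: str, detailed: bool = False) -> str:
        return self._DESCRIPTIONS[label]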
timm/data/imagenet_info.py
@@ -0,0 +1,92 @@
import csv
import os
import pkgutil
import re
from typing import Dict, List, Optional, Union

from .dataset_info import DatasetInfo


_NUM_CLASSES_TO_SUBSET = {
    1000: 'imagenet-1k',
    11821: 'imagenet-12k',
    21841: 'imagenet-22k',
    21843: 'imagenet-21k-goog',
    11221: 'imagenet-21k-miil',
}

_SUBSETS = {
    'imagenet1k': 'imagenet_synsets.txt',
    'imagenet12k': 'imagenet12k_synsets.txt',
    'imagenet22k': 'imagenet22k_synsets.txt',
    'imagenet21k': 'imagenet21k_goog_synsets.txt',
    'imagenet21kgoog': 'imagenet21k_goog_synsets.txt',
    'imagenet21kmiil': 'imagenet21k_miil_synsets.txt',
}
_LEMMA_FILE = 'imagenet_synset_to_lemma.txt'
_DEFINITION_FILE = 'imagenet_synset_to_definition.txt'


def infer_imagenet_subset(model_or_cfg) -> Optional[str]:
    if isinstance(model_or_cfg, dict):
        num_classes = model_or_cfg.get('num_classes', None)
    else:
        num_classes = getattr(model_or_cfg, 'num_classes', None)
        if not num_classes:
            pretrained_cfg = getattr(model_or_cfg, 'pretrained_cfg', {})
            # FIXME at some point pretrained_cfg should include dataset-tag,
            # which will be more robust than a guess based on num_classes
            num_classes = pretrained_cfg.get('num_classes', None)
    if not num_classes or num_classes not in _NUM_CLASSES_TO_SUBSET:
        return None
    return _NUM_CLASSES_TO_SUBSET[num_classes]


class ImageNetInfo(DatasetInfo):

    def __init__(self, subset: str = 'imagenet-1k'):
        super().__init__()
        subset = re.sub(r'[-_\s]', '', subset.lower())
        assert subset in _SUBSETS, f'Unknown imagenet subset {subset}.'

        # WordNet synsets (part-of-speech + offset) are the unique class label names for ImageNet classifiers
        synset_file = _SUBSETS[subset]
        synset_data = pkgutil.get_data(__name__, os.path.join('_info', synset_file))
        self._synsets = synset_data.decode('utf-8').splitlines()

        # WordNet lemmas (canonical dictionary form of word) and definitions are used to build
        # the class descriptions. If detailed=True both are used, otherwise just the lemmas.
        lemma_data = pkgutil.get_data(__name__, os.path.join('_info', _LEMMA_FILE))
        reader = csv.reader(lemma_data.decode('utf-8').splitlines(), delimiter='\t')
        self._lemmas = dict(reader)
        definition_data = pkgutil.get_data(__name__, os.path.join('_info', _DEFINITION_FILE))
        reader = csv.reader(definition_data.decode('utf-8').splitlines(), delimiter='\t')
        self._definitions = dict(reader)

    def num_classes(self):
        return len(self._synsets)

    def label_names(self):
        return self._synsets

    def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
        if as_dict:
            return {label: self.label_name_to_description(label, detailed=detailed) for label in self._synsets}
        else:
            return [self.label_name_to_description(label, detailed=detailed) for label in self._synsets]

    def index_to_label_name(self, index) -> str:
        assert 0 <= index < len(self._synsets), \
            f'Index ({index}) out of range for dataset with {len(self._synsets)} classes.'
        return self._synsets[index]

    def index_to_description(self, index: int, detailed: bool = False) -> str:
        label = self.index_to_label_name(index)
        return self.label_name_to_description(label, detailed=detailed)

    def label_name_to_description(self, label: str, detailed: bool = False) -> str:
        if detailed:
            description = f'{self._lemmas[label]}: {self._definitions[label]}'
        else:
            description = f'{self._lemmas[label]}'
        return description
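Putting the two new pieces together, a brief usage sketch (assuming ImageNetInfo and infer_imagenet_subset are re-exported from timm.data as part of this PR; the model name and printed values are illustrative):

import timm
from timm.data import ImageNetInfo, infer_imagenet_subset

model = timm.create_model('resnet50', pretrained=True)

# num_classes == 1000 maps to the 'imagenet-1k' subset.
subset = infer_imagenet_subset(model)       # also accepts a plain cfg dict
info = ImageNetInfo(subset or 'imagenet-1k')

print(info.num_classes())                   # 1000
print(info.index_to_label_name(0))          # 'n01440764' (a WordNet synset id)
print(info.index_to_description(0))         # 'tench' (lemma only)
print(info.index_to_description(0, detailed=True))  # 'tench: <WordNet definition>'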