import csv
import os
import pkgutil
import re
from typing import Dict, List, Optional, Union

from .dataset_info import DatasetInfo


_NUM_CLASSES_TO_SUBSET = {
    1000: 'imagenet-1k',
    11821: 'imagenet-12k',
    21841: 'imagenet-22k',
    21843: 'imagenet-21k-goog',
    11221: 'imagenet-21k-miil',
}

_SUBSETS = {
    'imagenet1k': 'imagenet_synsets.txt',
    'imagenet12k': 'imagenet12k_synsets.txt',
    'imagenet22k': 'imagenet22k_synsets.txt',
    'imagenet21k': 'imagenet21k_goog_synsets.txt',
    'imagenet21kgoog': 'imagenet21k_goog_synsets.txt',
    'imagenet21kmiil': 'imagenet21k_miil_synsets.txt',
}
_LEMMA_FILE = 'imagenet_synset_to_lemma.txt'
_DEFINITION_FILE = 'imagenet_synset_to_definition.txt'
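
# NOTE (added comment, inferred from the loading code below): the bundled `_info`
# data files are plain text. Each *_synsets.txt lists one WordNet synset id per
# line, so a class index is simply a line number. The lemma and definition files
# are tab-separated `synset<TAB>text` pairs, parsed with the csv module below.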


def infer_imagenet_subset(model_or_cfg) -> Optional[str]:
    """Infer the ImageNet subset tag (e.g. 'imagenet-1k') from a model or config.

    Accepts either a model object (reads `num_classes`, falling back to its
    `pretrained_cfg`) or a plain config dict. Returns None if the number of
    classes does not match a known subset.
    """
    if isinstance(model_or_cfg, dict):
        num_classes = model_or_cfg.get('num_classes', None)
    else:
        num_classes = getattr(model_or_cfg, 'num_classes', None)
        if not num_classes:
            pretrained_cfg = getattr(model_or_cfg, 'pretrained_cfg', {})
            # FIXME at some point pretrained_cfg should include dataset-tag,
            # which will be more robust than a guess based on num_classes
            num_classes = pretrained_cfg.get('num_classes', None)
    if not num_classes or num_classes not in _NUM_CLASSES_TO_SUBSET:
        return None
    return _NUM_CLASSES_TO_SUBSET[num_classes]
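
# Illustrative sketch (not part of the module): given the mapping above,
# infer_imagenet_subset() can be fed either a config dict or a model object.
#
#     infer_imagenet_subset({'num_classes': 1000})   # -> 'imagenet-1k'
#     infer_imagenet_subset({'num_classes': 21843})  # -> 'imagenet-21k-goog'
#     infer_imagenet_subset({'num_classes': 10})     # -> None (no matching subset)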


class ImageNetInfo(DatasetInfo):
    """Label and description metadata for the ImageNet family of datasets."""

    def __init__(self, subset: str = 'imagenet-1k'):
        super().__init__()
        subset = re.sub(r'[-_\s]', '', subset.lower())
        assert subset in _SUBSETS, f'Unknown imagenet subset {subset}.'

        # WordNet synsets (part-of-speech + offset) are the unique class label names for ImageNet classifiers
        synset_file = _SUBSETS[subset]
        synset_data = pkgutil.get_data(__name__, os.path.join('_info', synset_file))
        self._synsets = synset_data.decode('utf-8').splitlines()

        # WordNet lemmas (canonical dictionary form of a word) and definitions are used to build
        # the class descriptions. If detailed=True both are used, otherwise just the lemmas.
        lemma_data = pkgutil.get_data(__name__, os.path.join('_info', _LEMMA_FILE))
        reader = csv.reader(lemma_data.decode('utf-8').splitlines(), delimiter='\t')
        self._lemmas = dict(reader)
        definition_data = pkgutil.get_data(__name__, os.path.join('_info', _DEFINITION_FILE))
        reader = csv.reader(definition_data.decode('utf-8').splitlines(), delimiter='\t')
        self._definitions = dict(reader)
    def num_classes(self):
        return len(self._synsets)

    def label_names(self):
        return self._synsets

    def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
        if as_dict:
            return {label: self.label_name_to_description(label, detailed=detailed) for label in self._synsets}
        else:
            return [self.label_name_to_description(label, detailed=detailed) for label in self._synsets]

    def index_to_label_name(self, index) -> str:
        assert 0 <= index < len(self._synsets), \
            f'Index ({index}) out of range for dataset with {len(self._synsets)} classes.'
        return self._synsets[index]

    def index_to_description(self, index: int, detailed: bool = False) -> str:
        label = self.index_to_label_name(index)
        return self.label_name_to_description(label, detailed=detailed)

    def label_name_to_description(self, label: str, detailed: bool = False) -> str:
        if detailed:
            description = f'{self._lemmas[label]}: {self._definitions[label]}'
        else:
            description = f'{self._lemmas[label]}'
        return description
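
# Illustrative usage sketch (assumes the package's bundled `_info` data files are
# installed alongside this module; the example synset id is only indicative):
#
#     info = ImageNetInfo('imagenet-1k')
#     info.num_classes()                            # -> 1000
#     label = info.index_to_label_name(0)           # -> a WordNet synset id, e.g. 'n01440764'
#     info.label_name_to_description(label)         # -> lemma(s) for that synset
#     info.index_to_description(0, detailed=True)   # -> 'lemma: definition'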