You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/timm/data/imagenet_info.py

93 lines
3.7 KiB

import csv
import os
import pkgutil
import re
from typing import Dict, List, Optional, Union
from .dataset_info import DatasetInfo
# Maps a classifier's number of output classes to the name of the ImageNet subset it is
# guessed to cover. Consulted by infer_imagenet_subset(); class counts not listed here
# make that function return None.
_NUM_CLASSES_TO_SUBSET = {
    1000: 'imagenet-1k',
    11821: 'imagenet-12k',
    21841: 'imagenet-22k',
    21843: 'imagenet-21k-goog',
    11221: 'imagenet-21k-miil',
}

# Maps a normalized subset name to the synset list file for that subset. Keys are in the
# form produced by ImageNetInfo.__init__: lower-cased with '-', '_' and whitespace removed,
# so e.g. 'imagenet-21k-goog' is looked up as 'imagenet21kgoog'.
# Note 'imagenet21k' and 'imagenet21kgoog' both resolve to the same file.
_SUBSETS = {
    'imagenet1k': 'imagenet_synsets.txt',
    'imagenet12k': 'imagenet12k_synsets.txt',
    'imagenet22k': 'imagenet22k_synsets.txt',
    'imagenet21k': 'imagenet21k_goog_synsets.txt',
    'imagenet21kgoog': 'imagenet21k_goog_synsets.txt',
    'imagenet21kmiil': 'imagenet21k_miil_synsets.txt',
}

# Tab-delimited resource files (loaded from the '_info' directory by ImageNetInfo) whose
# rows are read as key/value pairs and looked up by synset label.
_LEMMA_FILE = 'imagenet_synset_to_lemma.txt'
_DEFINITION_FILE = 'imagenet_synset_to_definition.txt'
def infer_imagenet_subset(model_or_cfg) -> Optional[str]:
    """Guess which ImageNet subset a model or config targets from its class count.

    Args:
        model_or_cfg: Either a config dict carrying a ``num_classes`` key, or an object
            (e.g. a model) exposing a ``num_classes`` attribute. When that value is
            missing or falsy, a ``pretrained_cfg`` mapping attribute on the object is
            consulted for ``num_classes`` instead.

    Returns:
        The subset name registered in ``_NUM_CLASSES_TO_SUBSET`` for the class count,
        or ``None`` if no class count could be found or it is not a known one.
    """
    if isinstance(model_or_cfg, dict):
        num_classes = model_or_cfg.get('num_classes', None)
    else:
        num_classes = getattr(model_or_cfg, 'num_classes', None)

    if not num_classes:
        # FIXME at some point pretrained_cfg should include dataset-tag,
        # which will be more robust than a guess based on num_classes
        # `or {}` also covers a pretrained_cfg attribute that exists but is None,
        # which would otherwise fail on the .get() call below.
        pretrained_cfg = getattr(model_or_cfg, 'pretrained_cfg', None) or {}
        num_classes = pretrained_cfg.get('num_classes', None)

    if not num_classes:
        return None
    # Unknown class counts yield None rather than raising.
    return _NUM_CLASSES_TO_SUBSET.get(num_classes)
class ImageNetInfo(DatasetInfo):
    """Label names and human-readable descriptions for an ImageNet subset.

    Class labels are WordNet synset ids read from a per-subset synset file. Lemma and
    definition tables, read from the same '_info' resource directory, provide the text
    used to describe each label.
    """

    def __init__(self, subset: str = 'imagenet-1k'):
        """Load the synset list for `subset` plus the lemma and definition tables.

        Args:
            subset: Subset name. It is lower-cased and stripped of '-', '_' and
                whitespace before lookup, so 'imagenet-1k', 'ImageNet_1k' and
                'imagenet1k' are equivalent.

        Raises:
            AssertionError: If the normalized name is not a key of `_SUBSETS`.
            FileNotFoundError: If a resource file cannot be loaded.
        """
        super().__init__()
        subset = re.sub(r'[-_\s]', '', subset.lower())
        assert subset in _SUBSETS, f'Unknown imagenet subset {subset}.'

        # WordNet synsets (part-of-speech + offset) are the unique class label names for ImageNet classifiers
        self._synsets = self._read_info_text(_SUBSETS[subset]).splitlines()

        # WordNet lemmas (canonical dictionary form of word) and definitions are used to build
        # the class descriptions. If detailed=True both are used, otherwise just the lemmas.
        self._lemmas = self._read_info_table(_LEMMA_FILE)
        self._definitions = self._read_info_table(_DEFINITION_FILE)

    @staticmethod
    def _read_info_text(filename: str) -> str:
        """Return the UTF-8 decoded contents of `filename` from the '_info' resource dir."""
        # pkgutil.get_data() expects a relative resource name that uses '/' as the
        # separator on every platform, so the path is not built with os.path.join.
        resource = f'_info/{filename}'
        data = pkgutil.get_data(__name__, resource)
        if data is None:
            # get_data() returns None when the loader cannot provide the resource.
            raise FileNotFoundError(f'Unable to load resource {resource!r} for {__name__}.')
        return data.decode('utf-8')

    @classmethod
    def _read_info_table(cls, filename: str) -> Dict[str, str]:
        """Parse a two-column, tab-delimited resource file into a dict (first column -> second)."""
        rows = csv.reader(cls._read_info_text(filename).splitlines(), delimiter='\t')
        return dict(rows)

    def num_classes(self) -> int:
        """Return the number of classes (synsets) in this subset."""
        return len(self._synsets)

    def label_names(self) -> List[str]:
        """Return the synset ids, ordered by class index (the internal list, not a copy)."""
        return self._synsets

    def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
        """Return descriptions for every class.

        Args:
            detailed: If True, descriptions include the definition as well as the lemma.
            as_dict: If True, return a dict keyed by synset id; otherwise a list ordered
                by class index.
        """
        if as_dict:
            return {label: self.label_name_to_description(label, detailed=detailed) for label in self._synsets}
        else:
            return [self.label_name_to_description(label, detailed=detailed) for label in self._synsets]

    def index_to_label_name(self, index) -> str:
        """Return the synset id for class `index`; asserts 0 <= index < num_classes()."""
        assert 0 <= index < len(self._synsets), \
            f'Index ({index}) out of range for dataset with {len(self._synsets)} classes.'
        return self._synsets[index]

    def index_to_description(self, index: int, detailed: bool = False) -> str:
        """Return the description for class `index` (see `label_name_to_description`)."""
        label = self.index_to_label_name(index)
        return self.label_name_to_description(label, detailed=detailed)

    def label_name_to_description(self, label: str, detailed: bool = False) -> str:
        """Return the lemma for synset id `label`, or 'lemma: definition' when `detailed`.

        Raises:
            KeyError: If `label` is missing from the lemma (or, when detailed, definition) table.
        """
        if detailed:
            description = f'{self._lemmas[label]}: {self._definitions[label]}'
        else:
            description = f'{self._lemmas[label]}'
        return description