Merge pull request #1662 from rwightman/dataset_info

ImageNet metadata (info) and labelling update
Ross Wightman 2 years ago committed by GitHub
commit 88a5b8491d

@@ -1,2 +1,3 @@
include timm/models/pruned/*.txt
include timm/models/_pruned/*.txt
include timm/data/_info/*.txt
include timm/data/_info/*.json

@@ -17,7 +17,7 @@ import numpy as np
import pandas as pd
import torch
from timm.data import create_dataset, create_loader, resolve_data_config
from timm.data import create_dataset, create_loader, resolve_data_config, ImageNetInfo, infer_imagenet_subset
from timm.layers import apply_test_time_pool
from timm.models import create_model
from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser, ParseKwargs
@@ -46,6 +46,7 @@ has_compile = hasattr(torch, 'compile')
_FMT_EXT = {
'json': '.json',
'json-record': '.json',
'json-split': '.json',
'parquet': '.parquet',
'csv': '.csv',
@@ -122,7 +123,7 @@ scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
help="Enable AOT Autograd support.")
parser.add_argument('--results-dir',type=str, default=None,
parser.add_argument('--results-dir', type=str, default=None,
help='folder for output results')
parser.add_argument('--results-file', type=str, default=None,
help='results filename (relative to results-dir)')
@@ -134,14 +135,20 @@ parser.add_argument('--topk', default=1, type=int,
metavar='N', help='Top-k to output to CSV')
parser.add_argument('--fullname', action='store_true', default=False,
help='use full sample name in output (not just basename).')
parser.add_argument('--filename-col', default='filename',
parser.add_argument('--filename-col', type=str, default='filename',
help='name for filename / sample name column')
parser.add_argument('--index-col', default='index',
parser.add_argument('--index-col', type=str, default='index',
help='name for output indices column(s)')
parser.add_argument('--output-col', default=None,
parser.add_argument('--label-col', type=str, default='label',
help='name for output label column(s)')
parser.add_argument('--output-col', type=str, default=None,
help='name for logit/probs output column(s)')
parser.add_argument('--output-type', default='prob',
parser.add_argument('--output-type', type=str, default='prob',
help='output type column ("prob" for probabilities, "logit" for raw logits)')
parser.add_argument('--label-type', type=str, default='description',
help='type of label to output, one of "none", "name", "description", "detail"')
parser.add_argument('--include-index', action='store_true', default=False,
help='include the class index in results')
parser.add_argument('--exclude-output', action='store_true', default=False,
help='exclude logits/probs from results, just indices. topk must be set !=0.')
@@ -237,10 +244,26 @@ def main():
**data_config,
)
to_label = None
if args.label_type in ('name', 'description', 'detail'):
imagenet_subset = infer_imagenet_subset(model)
if imagenet_subset is not None:
dataset_info = ImageNetInfo(imagenet_subset)
if args.label_type == 'name':
to_label = lambda x: dataset_info.index_to_label_name(x)
elif args.label_type == 'detail':
to_label = lambda x: dataset_info.index_to_description(x, detailed=True)
else:
to_label = lambda x: dataset_info.index_to_description(x)
to_label = np.vectorize(to_label)
else:
_logger.error("Cannot deduce ImageNet subset from model, no labelling will be performed.")
top_k = min(args.topk, args.num_classes)
batch_time = AverageMeter()
end = time.time()
all_indices = []
all_labels = []
all_outputs = []
use_probs = args.output_type == 'prob'
with torch.no_grad():
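A rough sketch of what the new label mapping above does, outside the script (assumes this branch is installed so `ImageNetInfo` and `infer_imagenet_subset` are importable from `timm.data`, and that the subset lookup succeeds):

```python
import numpy as np
from timm.data import ImageNetInfo, infer_imagenet_subset

# a plain dict works as well as a model; only num_classes is inspected
subset = infer_imagenet_subset({'num_classes': 1000})   # -> 'imagenet-1k'
dataset_info = ImageNetInfo(subset)

# the same vectorized mapping the script builds for --label-type description
to_label = np.vectorize(lambda x: dataset_info.index_to_description(x))
print(to_label(np.array([[1, 207]])))  # lemma strings, e.g. goldfish / golden retriever
```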
@@ -254,7 +277,12 @@ def main():
if top_k:
output, indices = output.topk(top_k)
all_indices.append(indices.cpu().numpy())
np_indices = indices.cpu().numpy()
if args.include_index:
all_indices.append(np_indices)
if to_label is not None:
np_labels = to_label(np_indices)
all_labels.append(np_labels)
all_outputs.append(output.cpu().numpy())
@@ -267,6 +295,7 @@ def main():
batch_idx, len(loader), batch_time=batch_time))
all_indices = np.concatenate(all_indices, axis=0) if all_indices else None
all_labels = np.concatenate(all_labels, axis=0) if all_labels else None
all_outputs = np.concatenate(all_outputs, axis=0).astype(np.float32)
filenames = loader.dataset.filenames(basename=not args.fullname)
@@ -276,6 +305,9 @@ def main():
if all_indices is not None:
for i in range(all_indices.shape[-1]):
data_dict[f'{args.index_col}_{i}'] = all_indices[:, i]
if all_labels is not None:
for i in range(all_labels.shape[-1]):
data_dict[f'{args.label_col}_{i}'] = all_labels[:, i]
for i in range(all_outputs.shape[-1]):
data_dict[f'{output_col}_{i}'] = all_outputs[:, i]
else:
@@ -283,6 +315,10 @@ def main():
if all_indices.shape[-1] == 1:
all_indices = all_indices.squeeze(-1)
data_dict[args.index_col] = list(all_indices)
if all_labels is not None:
if all_labels.shape[-1] == 1:
all_labels = all_labels.squeeze(-1)
data_dict[args.label_col] = list(all_labels)
if all_outputs.shape[-1] == 1:
all_outputs = all_outputs.squeeze(-1)
data_dict[output_col] = list(all_outputs)
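To make the column handling above concrete, here is an illustrative `data_dict` for `--topk 2` with `--include-index` set and labelling enabled. Values and class names are made up; column names assume the default `--filename-col`/`--index-col`/`--label-col` values and a `prob` output type:

```python
import pandas as pd

data_dict = {
    'filename': ['a.jpg', 'b.jpg'],
    'index_0': [1, 388], 'index_1': [29, 294],             # only with --include-index
    'label_0': ['goldfish', 'giant panda'], 'label_1': ['axolotl', 'brown bear'],
    'prob_0': [0.91, 0.88], 'prob_1': [0.03, 0.05],
}
df = pd.DataFrame(data=data_dict)
# with --topk 1 the numeric suffix is dropped and each column holds one value per row
```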
@@ -291,7 +327,7 @@ def main():
results_filename = args.results_file
if results_filename:
filename_no_ext, ext = os.path.splitext(results_filename)[-1]
filename_no_ext, ext = os.path.splitext(results_filename)
if ext and ext in _FMT_EXT.values():
# if filename provided with one of expected ext,
# remove it as it will be added back
@@ -308,7 +344,7 @@ def main():
save_results(df, results_filename, fmt)
print(f'--result')
print(json.dumps(dict(filename=results_filename)))
print(df.set_index(args.filename_col).to_json(orient='index', indent=4))
def save_results(df, results_filename, results_format='csv', filename_col='filename'):
@@ -316,6 +352,8 @@ def save_results(df, results_filename, results_format='csv', filename_col='filen
if results_format == 'parquet':
df.set_index(filename_col).to_parquet(results_filename)
elif results_format == 'json':
df.set_index(filename_col).to_json(results_filename, indent=4, orient='index')
elif results_format == 'json-record':
df.to_json(results_filename, lines=True, orient='records')
elif results_format == 'json-split':
df.to_json(results_filename, indent=4, orient='split', index=False)
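The format branches map directly onto pandas writers; a small sketch of the two JSON flavours (file names here are only for illustration):

```python
import pandas as pd

df = pd.DataFrame({
    'filename': ['a.jpg', 'b.jpg'],
    'label': ['goldfish', 'tench'],
    'prob': [0.91, 0.87],
})
# 'json': one object per sample, keyed by filename
df.set_index('filename').to_json('results.json', indent=4, orient='index')
# 'json-record': newline-delimited records, one JSON object per line
df.to_json('results.jsonl', lines=True, orient='records')
```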

@@ -4,6 +4,8 @@ from .config import resolve_data_config, resolve_model_data_config
from .constants import *
from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
from .dataset_factory import create_dataset
from .dataset_info import DatasetInfo
from .imagenet_info import ImageNetInfo, infer_imagenet_subset
from .loader import create_loader
from .mixup import Mixup, FastCollateMixup
from .readers import create_reader

Four large file diffs suppressed (too large to display).

@@ -0,0 +1,32 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Union


class DatasetInfo(ABC):

    def __init__(self):
        pass

    @abstractmethod
    def num_classes(self):
        pass

    @abstractmethod
    def label_names(self):
        pass

    @abstractmethod
    def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
        pass

    @abstractmethod
    def index_to_label_name(self, index) -> str:
        pass

    @abstractmethod
    def index_to_description(self, index: int, detailed: bool = False) -> str:
        pass

    @abstractmethod
    def label_name_to_description(self, label: str, detailed: bool = False) -> str:
        pass
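DatasetInfo is a minimal interface, so other label sets can plug into the same `--label-type` machinery. A hypothetical in-memory implementation (not part of this PR; names are illustrative) might look like:

```python
from typing import Dict, List, Union

from timm.data import DatasetInfo


class SimpleDatasetInfo(DatasetInfo):
    """Hypothetical DatasetInfo backed by plain lists/dicts."""

    def __init__(self, names: List[str], descriptions: Dict[str, str]):
        super().__init__()
        self._names = names
        self._descriptions = descriptions

    def num_classes(self):
        return len(self._names)

    def label_names(self):
        return self._names

    def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
        if as_dict:
            return {n: self._descriptions[n] for n in self._names}
        return [self._descriptions[n] for n in self._names]

    def index_to_label_name(self, index) -> str:
        return self._names[index]

    def index_to_description(self, index: int, detailed: bool = False) -> str:
        return self.label_name_to_description(self._names[index], detailed=detailed)

    def label_name_to_description(self, label: str, detailed: bool = False) -> str:
        # this simple sketch ignores the detailed flag and returns one description per label
        return self._descriptions[label]
```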

@@ -0,0 +1,92 @@
import csv
import os
import pkgutil
import re
from typing import Dict, List, Optional, Union

from .dataset_info import DatasetInfo


_NUM_CLASSES_TO_SUBSET = {
    1000: 'imagenet-1k',
    11821: 'imagenet-12k',
    21841: 'imagenet-22k',
    21843: 'imagenet-21k-goog',
    11221: 'imagenet-21k-miil',
}

_SUBSETS = {
    'imagenet1k': 'imagenet_synsets.txt',
    'imagenet12k': 'imagenet12k_synsets.txt',
    'imagenet22k': 'imagenet22k_synsets.txt',
    'imagenet21k': 'imagenet21k_goog_synsets.txt',
    'imagenet21kgoog': 'imagenet21k_goog_synsets.txt',
    'imagenet21kmiil': 'imagenet21k_miil_synsets.txt',
}

_LEMMA_FILE = 'imagenet_synset_to_lemma.txt'
_DEFINITION_FILE = 'imagenet_synset_to_definition.txt'


def infer_imagenet_subset(model_or_cfg) -> Optional[str]:
    if isinstance(model_or_cfg, dict):
        num_classes = model_or_cfg.get('num_classes', None)
    else:
        num_classes = getattr(model_or_cfg, 'num_classes', None)
        if not num_classes:
            pretrained_cfg = getattr(model_or_cfg, 'pretrained_cfg', {})
            # FIXME at some point pretrained_cfg should include dataset-tag,
            # which will be more robust than a guess based on num_classes
            num_classes = pretrained_cfg.get('num_classes', None)
    if not num_classes or num_classes not in _NUM_CLASSES_TO_SUBSET:
        return None
    return _NUM_CLASSES_TO_SUBSET[num_classes]


class ImageNetInfo(DatasetInfo):

    def __init__(self, subset: str = 'imagenet-1k'):
        super().__init__()
        subset = re.sub(r'[-_\s]', '', subset.lower())
        assert subset in _SUBSETS, f'Unknown imagenet subset {subset}.'

        # WordNet synsets (part-of-speech + offset) are the unique class label names for ImageNet classifiers
        synset_file = _SUBSETS[subset]
        synset_data = pkgutil.get_data(__name__, os.path.join('_info', synset_file))
        self._synsets = synset_data.decode('utf-8').splitlines()

        # WordNet lemmas (canonical dictionary form of word) and definitions are used to build
        # the class descriptions. If detailed=True both are used, otherwise just the lemmas.
        lemma_data = pkgutil.get_data(__name__, os.path.join('_info', _LEMMA_FILE))
        reader = csv.reader(lemma_data.decode('utf-8').splitlines(), delimiter='\t')
        self._lemmas = dict(reader)
        definition_data = pkgutil.get_data(__name__, os.path.join('_info', _DEFINITION_FILE))
        reader = csv.reader(definition_data.decode('utf-8').splitlines(), delimiter='\t')
        self._definitions = dict(reader)

    def num_classes(self):
        return len(self._synsets)

    def label_names(self):
        return self._synsets

    def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
        if as_dict:
            return {label: self.label_name_to_description(label, detailed=detailed) for label in self._synsets}
        else:
            return [self.label_name_to_description(label, detailed=detailed) for label in self._synsets]

    def index_to_label_name(self, index) -> str:
        assert 0 <= index < len(self._synsets), \
            f'Index ({index}) out of range for dataset with {len(self._synsets)} classes.'
        return self._synsets[index]

    def index_to_description(self, index: int, detailed: bool = False) -> str:
        label = self.index_to_label_name(index)
        return self.label_name_to_description(label, detailed=detailed)

    def label_name_to_description(self, label: str, detailed: bool = False) -> str:
        if detailed:
            description = f'{self._lemmas[label]}: {self._definitions[label]}'
        else:
            description = f'{self._lemmas[label]}'
        return description
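Putting the two pieces together, roughly how they are meant to be used (assumes a `resnet50` model definition is available; pretrained weights are not needed since the guess keys off `num_classes`):

```python
import timm
from timm.data import ImageNetInfo, infer_imagenet_subset

model = timm.create_model('resnet50', pretrained=False)   # 1000-class head
subset = infer_imagenet_subset(model)                      # 'imagenet-1k'
info = ImageNetInfo(subset)

print(info.num_classes())                           # 1000
print(info.index_to_label_name(0))                  # WordNet synset id, e.g. 'n01440764'
print(info.index_to_description(0, detailed=True))  # lemma plus definition
```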

@@ -7,14 +7,19 @@ Hacked together by / Copyright 2020 Ross Wightman
import os
import json
import numpy as np
import pkgutil
class RealLabelsImagenet:
def __init__(self, filenames, real_json='real.json', topk=(1, 5)):
with open(real_json) as real_labels:
real_labels = json.load(real_labels)
real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)}
def __init__(self, filenames, real_json=None, topk=(1, 5)):
if real_json is not None:
with open(real_json) as real_labels:
real_labels = json.load(real_labels)
else:
real_labels = json.loads(
pkgutil.get_data(__name__, os.path.join('_info', 'imagenet_real_labels.json')).decode('utf-8'))
real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)}
self.real_labels = real_labels
self.filenames = filenames
assert len(self.filenames) == len(self.real_labels)
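With this change the ImageNet-ReaL label file ships with the package, so the `real_json` argument can usually be omitted. A minimal sketch (the filename list must match the ImageNet-1k validation set, in the order the loader yields samples):

```python
from timm.data import RealLabelsImagenet

# 50k validation image basenames, in loader order
filenames = [f'ILSVRC2012_val_{i + 1:08d}.JPEG' for i in range(50000)]

# real_json=None now falls back to the bundled timm/data/_info/imagenet_real_labels.json
real_labels = RealLabelsImagenet(filenames, topk=(1, 5))
```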

@@ -1,4 +1,5 @@
import os
import pkgutil
from copy import deepcopy
from torch import nn as nn
@@ -108,6 +109,5 @@ def adapt_model_from_string(parent_module, model_string):
def adapt_model_from_file(parent_module, model_variant):
adapt_file = os.path.join(os.path.dirname(__file__), '_pruned', model_variant + '.txt')
with open(adapt_file, 'r') as f:
return adapt_model_from_string(parent_module, f.read().strip())
adapt_data = pkgutil.get_data(__name__, os.path.join('_pruned', model_variant + '.txt'))
return adapt_model_from_string(parent_module, adapt_data.decode('utf-8').strip())
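Same motivation as the reader changes above: `pkgutil.get_data` resolves package-relative resources whether timm is a source checkout or an installed wheel. A small sketch (assumes one of the existing pruned ecaresnet definitions sits under the renamed `_pruned` directory):

```python
import pkgutil

# reads timm/models/_pruned/ecaresnet50d_pruned.txt from the installed package
data = pkgutil.get_data('timm.models', '_pruned/ecaresnet50d_pruned.txt')
print(data.decode('utf-8').strip()[:80])
```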
