Move ImageNet metadata (aka info) files to timm/data/_info. Add helper classes to make info available for labelling. Update inference.py for first use.

pull/1662/head
Ross Wightman 1 year ago
parent 89b0452171
commit 0f2803de7a

@ -1,2 +1,3 @@
include timm/models/pruned/*.txt include timm/models/_pruned/*.txt
include timm/data/_info/*.txt
include timm/data/_info/*.json

@ -17,7 +17,7 @@ import numpy as np
import pandas as pd import pandas as pd
import torch import torch
from timm.data import create_dataset, create_loader, resolve_data_config from timm.data import create_dataset, create_loader, resolve_data_config, ImageNetInfo, infer_imagenet_subset
from timm.layers import apply_test_time_pool from timm.layers import apply_test_time_pool
from timm.models import create_model from timm.models import create_model
from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser, ParseKwargs from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser, ParseKwargs
@ -46,6 +46,7 @@ has_compile = hasattr(torch, 'compile')
_FMT_EXT = { _FMT_EXT = {
'json': '.json', 'json': '.json',
'json-record': '.json',
'json-split': '.json', 'json-split': '.json',
'parquet': '.parquet', 'parquet': '.parquet',
'csv': '.csv', 'csv': '.csv',
@ -122,7 +123,7 @@ scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None
scripting_group.add_argument('--aot-autograd', default=False, action='store_true', scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
help="Enable AOT Autograd support.") help="Enable AOT Autograd support.")
parser.add_argument('--results-dir',type=str, default=None, parser.add_argument('--results-dir', type=str, default=None,
help='folder for output results') help='folder for output results')
parser.add_argument('--results-file', type=str, default=None, parser.add_argument('--results-file', type=str, default=None,
help='results filename (relative to results-dir)') help='results filename (relative to results-dir)')
@ -134,14 +135,20 @@ parser.add_argument('--topk', default=1, type=int,
metavar='N', help='Top-k to output to CSV') metavar='N', help='Top-k to output to CSV')
parser.add_argument('--fullname', action='store_true', default=False, parser.add_argument('--fullname', action='store_true', default=False,
help='use full sample name in output (not just basename).') help='use full sample name in output (not just basename).')
parser.add_argument('--filename-col', default='filename', parser.add_argument('--filename-col', type=str, default='filename',
help='name for filename / sample name column') help='name for filename / sample name column')
parser.add_argument('--index-col', default='index', parser.add_argument('--index-col', type=str, default='index',
help='name for output indices column(s)') help='name for output indices column(s)')
parser.add_argument('--output-col', default=None, parser.add_argument('--label-col', type=str, default='label',
help='name for output indices column(s)')
parser.add_argument('--output-col', type=str, default=None,
help='name for logit/probs output column(s)') help='name for logit/probs output column(s)')
parser.add_argument('--output-type', default='prob', parser.add_argument('--output-type', type=str, default='prob',
help='output type colum ("prob" for probabilities, "logit" for raw logits)') help='output type colum ("prob" for probabilities, "logit" for raw logits)')
parser.add_argument('--label-type', type=str, default='description',
help='type of label to output, one of "none", "name", "description", "detailed"')
parser.add_argument('--include-index', action='store_true', default=False,
help='include the class index in results')
parser.add_argument('--exclude-output', action='store_true', default=False, parser.add_argument('--exclude-output', action='store_true', default=False,
help='exclude logits/probs from results, just indices. topk must be set !=0.') help='exclude logits/probs from results, just indices. topk must be set !=0.')
@ -237,10 +244,26 @@ def main():
**data_config, **data_config,
) )
to_label = None
if args.label_type in ('name', 'description', 'detail'):
imagenet_subset = infer_imagenet_subset(model)
if imagenet_subset is not None:
dataset_info = ImageNetInfo(imagenet_subset)
if args.label_type == 'name':
to_label = lambda x: dataset_info.index_to_label_name(x)
elif args.label_type == 'detail':
to_label = lambda x: dataset_info.index_to_description(x, detailed=True)
else:
to_label = lambda x: dataset_info.index_to_description(x)
to_label = np.vectorize(to_label)
else:
_logger.error("Cannot deduce ImageNet subset from model, no labelling will be performed.")
top_k = min(args.topk, args.num_classes) top_k = min(args.topk, args.num_classes)
batch_time = AverageMeter() batch_time = AverageMeter()
end = time.time() end = time.time()
all_indices = [] all_indices = []
all_labels = []
all_outputs = [] all_outputs = []
use_probs = args.output_type == 'prob' use_probs = args.output_type == 'prob'
with torch.no_grad(): with torch.no_grad():
@ -254,7 +277,12 @@ def main():
if top_k: if top_k:
output, indices = output.topk(top_k) output, indices = output.topk(top_k)
all_indices.append(indices.cpu().numpy()) np_indices = indices.cpu().numpy()
if args.include_index:
all_indices.append(np_indices)
if to_label is not None:
np_labels = to_label(np_indices)
all_labels.append(np_labels)
all_outputs.append(output.cpu().numpy()) all_outputs.append(output.cpu().numpy())
@ -267,6 +295,7 @@ def main():
batch_idx, len(loader), batch_time=batch_time)) batch_idx, len(loader), batch_time=batch_time))
all_indices = np.concatenate(all_indices, axis=0) if all_indices else None all_indices = np.concatenate(all_indices, axis=0) if all_indices else None
all_labels = np.concatenate(all_labels, axis=0) if all_labels else None
all_outputs = np.concatenate(all_outputs, axis=0).astype(np.float32) all_outputs = np.concatenate(all_outputs, axis=0).astype(np.float32)
filenames = loader.dataset.filenames(basename=not args.fullname) filenames = loader.dataset.filenames(basename=not args.fullname)
@ -276,6 +305,9 @@ def main():
if all_indices is not None: if all_indices is not None:
for i in range(all_indices.shape[-1]): for i in range(all_indices.shape[-1]):
data_dict[f'{args.index_col}_{i}'] = all_indices[:, i] data_dict[f'{args.index_col}_{i}'] = all_indices[:, i]
if all_labels is not None:
for i in range(all_labels.shape[-1]):
data_dict[f'{args.label_col}_{i}'] = all_labels[:, i]
for i in range(all_outputs.shape[-1]): for i in range(all_outputs.shape[-1]):
data_dict[f'{output_col}_{i}'] = all_outputs[:, i] data_dict[f'{output_col}_{i}'] = all_outputs[:, i]
else: else:
@ -283,6 +315,10 @@ def main():
if all_indices.shape[-1] == 1: if all_indices.shape[-1] == 1:
all_indices = all_indices.squeeze(-1) all_indices = all_indices.squeeze(-1)
data_dict[args.index_col] = list(all_indices) data_dict[args.index_col] = list(all_indices)
if all_labels is not None:
if all_labels.shape[-1] == 1:
all_labels = all_labels.squeeze(-1)
data_dict[args.label_col] = list(all_labels)
if all_outputs.shape[-1] == 1: if all_outputs.shape[-1] == 1:
all_outputs = all_outputs.squeeze(-1) all_outputs = all_outputs.squeeze(-1)
data_dict[output_col] = list(all_outputs) data_dict[output_col] = list(all_outputs)
@ -291,7 +327,7 @@ def main():
results_filename = args.results_file results_filename = args.results_file
if results_filename: if results_filename:
filename_no_ext, ext = os.path.splitext(results_filename)[-1] filename_no_ext, ext = os.path.splitext(results_filename)
if ext and ext in _FMT_EXT.values(): if ext and ext in _FMT_EXT.values():
# if filename provided with one of expected ext, # if filename provided with one of expected ext,
# remove it as it will be added back # remove it as it will be added back
@ -308,7 +344,7 @@ def main():
save_results(df, results_filename, fmt) save_results(df, results_filename, fmt)
print(f'--result') print(f'--result')
print(json.dumps(dict(filename=results_filename))) print(df.set_index(args.filename_col).to_json(orient='index', indent=4))
def save_results(df, results_filename, results_format='csv', filename_col='filename'): def save_results(df, results_filename, results_format='csv', filename_col='filename'):
@ -316,6 +352,8 @@ def save_results(df, results_filename, results_format='csv', filename_col='filen
if results_format == 'parquet': if results_format == 'parquet':
df.set_index(filename_col).to_parquet(results_filename) df.set_index(filename_col).to_parquet(results_filename)
elif results_format == 'json': elif results_format == 'json':
df.set_index(filename_col).to_json(results_filename, indent=4, orient='index')
elif results_format == 'json-records':
df.to_json(results_filename, lines=True, orient='records') df.to_json(results_filename, lines=True, orient='records')
elif results_format == 'json-split': elif results_format == 'json-split':
df.to_json(results_filename, indent=4, orient='split', index=False) df.to_json(results_filename, indent=4, orient='split', index=False)

@ -4,6 +4,8 @@ from .config import resolve_data_config, resolve_model_data_config
from .constants import * from .constants import *
from .dataset import ImageDataset, IterableImageDataset, AugMixDataset from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
from .dataset_factory import create_dataset from .dataset_factory import create_dataset
from .dataset_info import DatasetInfo
from .imagenet_info import ImageNetInfo, infer_imagenet_subset
from .loader import create_loader from .loader import create_loader
from .mixup import Mixup, FastCollateMixup from .mixup import Mixup, FastCollateMixup
from .readers import create_reader from .readers import create_reader

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -0,0 +1,32 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Union
class DatasetInfo(ABC):
def __init__(self):
pass
@abstractmethod
def num_classes(self):
pass
@abstractmethod
def label_names(self):
pass
@abstractmethod
def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
pass
@abstractmethod
def index_to_label_name(self, index) -> str:
pass
@abstractmethod
def index_to_description(self, index: int, detailed: bool = False) -> str:
pass
@abstractmethod
def label_name_to_description(self, label: str, detailed: bool = False) -> str:
pass

@ -0,0 +1,92 @@
import csv
import os
import pkgutil
import re
from typing import Dict, List, Optional, Union
from .dataset_info import DatasetInfo
_NUM_CLASSES_TO_SUBSET = {
1000: 'imagenet-1k',
11821: 'imagenet-12k',
21841: 'imagenet-22k',
21843: 'imagenet-21k-goog',
11221: 'imagenet-21k-miil',
}
_SUBSETS = {
'imagenet1k': 'imagenet_synsets.txt',
'imagenet12k': 'imagenet12k_synsets.txt',
'imagenet22k': 'imagenet22k_synsets.txt',
'imagenet21k': 'imagenet21k_goog_synsets.txt',
'imagenet21kgoog': 'imagenet21k_goog_synsets.txt',
'imagenet21kmiil': 'imagenet21k_miil_synsets.txt',
}
_LEMMA_FILE = 'imagenet_synset_to_lemma.txt'
_DEFINITION_FILE = 'imagenet_synset_to_definition.txt'
def infer_imagenet_subset(model_or_cfg) -> Optional[str]:
if isinstance(model_or_cfg, dict):
num_classes = model_or_cfg.get('num_classes', None)
else:
num_classes = getattr(model_or_cfg, 'num_classes', None)
if not num_classes:
pretrained_cfg = getattr(model_or_cfg, 'pretrained_cfg', {})
# FIXME at some point pretrained_cfg should include dataset-tag,
# which will be more robust than a guess based on num_classes
num_classes = pretrained_cfg.get('num_classes', None)
if not num_classes or num_classes not in _NUM_CLASSES_TO_SUBSET:
return None
return _NUM_CLASSES_TO_SUBSET[num_classes]
class ImageNetInfo(DatasetInfo):
def __init__(self, subset: str = 'imagenet-1k'):
super().__init__()
subset = re.sub(r'[-_\s]', '', subset.lower())
assert subset in _SUBSETS, f'Unknown imagenet subset {subset}.'
# WordNet synsets (part-of-speach + offset) are the unique class label names for ImageNet classifiers
synset_file = _SUBSETS[subset]
synset_data = pkgutil.get_data(__name__, os.path.join('_info', synset_file))
self._synsets = synset_data.decode('utf-8').splitlines()
# WordNet lemmas (canonical dictionary form of word) and definitions are used to build
# the class descriptions. If detailed=True both are used, otherwise just the lemmas.
lemma_data = pkgutil.get_data(__name__, os.path.join('_info', _LEMMA_FILE))
reader = csv.reader(lemma_data.decode('utf-8').splitlines(), delimiter='\t')
self._lemmas = dict(reader)
definition_data = pkgutil.get_data(__name__, os.path.join('_info', _DEFINITION_FILE))
reader = csv.reader(definition_data.decode('utf-8').splitlines(), delimiter='\t')
self._definitions = dict(reader)
def num_classes(self):
return len(self._synsets)
def label_names(self):
return self._synsets
def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
if as_dict:
return {label: self.label_name_to_description(label, detailed=detailed) for label in self._synsets}
else:
return [self.label_name_to_description(label, detailed=detailed) for label in self._synsets]
def index_to_label_name(self, index) -> str:
assert 0 <= index < len(self._synsets), \
f'Index ({index}) out of range for dataset with {len(self._synsets)} classes.'
return self._synsets[index]
def index_to_description(self, index: int, detailed: bool = False) -> str:
label = self.index_to_label_name(index)
return self.label_name_to_description(label, detailed=detailed)
def label_name_to_description(self, label: str, detailed: bool = False) -> str:
if detailed:
description = f'{self._lemmas[label]}: {self._definitions[label]}'
else:
description = f'{self._lemmas[label]}'
return description

@ -7,14 +7,19 @@ Hacked together by / Copyright 2020 Ross Wightman
import os import os
import json import json
import numpy as np import numpy as np
import pkgutil
class RealLabelsImagenet: class RealLabelsImagenet:
def __init__(self, filenames, real_json='real.json', topk=(1, 5)): def __init__(self, filenames, real_json=None, topk=(1, 5)):
with open(real_json) as real_labels: if real_json is not None:
real_labels = json.load(real_labels) with open(real_json) as real_labels:
real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)} real_labels = json.load(real_labels)
else:
real_labels = json.loads(
pkgutil.get_data(__name__, os.path.join('_info', 'imagenet_real_labels.json')).decode('utf-8'))
real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)}
self.real_labels = real_labels self.real_labels = real_labels
self.filenames = filenames self.filenames = filenames
assert len(self.filenames) == len(self.real_labels) assert len(self.filenames) == len(self.real_labels)

Loading…
Cancel
Save