Update efficientnet.py and convnext.py to multi-weight, add ImageNet-12k pretrained EfficientNet-B5 and ConvNeXt-Nano.

pull/1593/head
Ross Wightman 2 years ago
parent e7da205345
commit 6a01101905

@ -1,5 +1,6 @@
import dataclasses
import logging
import os
from copy import deepcopy
from typing import Optional, Dict, Callable, Any, Tuple
@ -9,7 +10,7 @@ from torch.hub import load_state_dict_from_url
from timm.models._features import FeatureListNet, FeatureHookNet
from timm.models._features_fx import FeatureGraphNet
from timm.models._helpers import load_state_dict
from timm.models._hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf
from timm.models._manipulate import adapt_input_conv
from timm.models._pretrained import PretrainedCfg
from timm.models._prune import adapt_model_from_file
@ -32,6 +33,7 @@ def _resolve_pretrained_source(pretrained_cfg):
pretrained_url = pretrained_cfg.get('url', None)
pretrained_file = pretrained_cfg.get('file', None)
hf_hub_id = pretrained_cfg.get('hf_hub_id', None)
# resolve where to load pretrained weights from
load_from = ''
pretrained_loc = ''
@ -43,15 +45,20 @@ def _resolve_pretrained_source(pretrained_cfg):
else:
# default source == timm or unspecified
if pretrained_file:
# file load override is the highest priority if set
load_from = 'file'
pretrained_loc = pretrained_file
elif pretrained_url:
load_from = 'url'
pretrained_loc = pretrained_url
elif hf_hub_id and has_hf_hub(necessary=True):
# hf-hub available as alternate weight source in default_cfg
load_from = 'hf-hub'
pretrained_loc = hf_hub_id
else:
# next, HF hub is prioritized unless a valid cached version of weights exists already
cached_url_valid = check_cached_file(pretrained_url) if pretrained_url else False
if hf_hub_id and has_hf_hub(necessary=True) and not cached_url_valid:
# hf-hub available as alternate weight source in default_cfg
load_from = 'hf-hub'
pretrained_loc = hf_hub_id
elif pretrained_url:
load_from = 'url'
pretrained_loc = pretrained_url
if load_from == 'hf-hub' and pretrained_cfg.get('hf_hub_filename', None):
# if a filename override is set, return tuple for location w/ (hub_id, filename)
pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename']
@ -105,7 +112,7 @@ def load_custom_pretrained(
pretrained_loc = download_cached_file(
pretrained_loc,
check_hash=_CHECK_HASH,
progress=_DOWNLOAD_PROGRESS
progress=_DOWNLOAD_PROGRESS,
)
if load_fn is not None:

@ -1,3 +1,4 @@
import hashlib
import json
import logging
import os
@ -67,6 +68,26 @@ def download_cached_file(url, check_hash=True, progress=False):
return cached_file
def check_cached_file(url, check_hash=True):
if isinstance(url, (list, tuple)):
url, filename = url
else:
parts = urlparse(url)
filename = os.path.basename(parts.path)
cached_file = os.path.join(get_cache_dir(), filename)
if os.path.exists(cached_file):
if check_hash:
r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
hash_prefix = r.group(1) if r else None
if hash_prefix:
with open(cached_file, 'rb') as f:
hd = hashlib.sha256(f.read()).hexdigest()
if hd[:len(hash_prefix)] != hash_prefix:
return False
return True
return False
def has_hf_hub(necessary=False):
if not _has_hf_hub and necessary:
# if no HF Hub module installed, and it is necessary to continue, raise error
@ -145,7 +166,9 @@ def save_for_hf(model, save_directory, model_config=None):
hf_config['architecture'] = pretrained_cfg.pop('architecture')
hf_config['num_classes'] = model_config.get('num_classes', model.num_classes)
hf_config['num_features'] = model_config.get('num_features', model.num_features)
hf_config['global_pool'] = model_config.get('global_pool', getattr(model, 'global_pool', None))
global_pool_type = model_config.get('global_pool', getattr(model, 'global_pool', None))
if isinstance(global_pool_type, str) and global_pool_type:
hf_config['global_pool'] = global_pool_type
if 'label' in model_config:
_logger.warning(

@ -19,6 +19,7 @@ class PretrainedCfg:
source: Optional[str] = None # source of cfg / weight location used (url, file, hf-hub)
architecture: Optional[str] = None # architecture variant can be set when not implicit
tag: Optional[str] = None # pretrained tag of source
custom_load: bool = False # use custom model specific model.load_pretrained() (ie for npz files)
# input / data config

@ -7,6 +7,7 @@ import re
import sys
from collections import defaultdict, deque
from copy import deepcopy
from dataclasses import replace
from typing import List, Optional, Union, Tuple
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
@ -20,7 +21,7 @@ _model_to_module = {} # mapping of model names to module names
_model_entrypoints = {} # mapping of model names to architecture entrypoint fns
_model_has_pretrained = set() # set of model names that have pretrained weight url present
_model_default_cfgs = dict() # central repo for model arch -> default cfg objects
_model_pretrained_cfgs = dict() # central repo for model arch + tag -> pretrained cfgs
_model_pretrained_cfgs = dict() # central repo for model arch.tag -> pretrained cfgs
_model_with_tags = defaultdict(list) # shortcut to map each model arch to all model + tag names
@ -48,24 +49,31 @@ def register_model(fn):
if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
# this will catch all models that have entrypoint matching cfg key, but miss any aliasing
# entrypoints or non-matching combos
cfg = mod.default_cfgs[model_name]
if not isinstance(cfg, DefaultCfg):
default_cfg = mod.default_cfgs[model_name]
if not isinstance(default_cfg, DefaultCfg):
# new style default cfg dataclass w/ multiple entries per model-arch
assert isinstance(cfg, dict)
assert isinstance(default_cfg, dict)
# old style cfg dict per model-arch
cfg = PretrainedCfg(**cfg)
cfg = DefaultCfg(tags=deque(['']), cfgs={'': cfg})
pretrained_cfg = PretrainedCfg(**default_cfg)
default_cfg = DefaultCfg(tags=deque(['']), cfgs={'': pretrained_cfg})
for tag_idx, tag in enumerate(cfg.tags):
for tag_idx, tag in enumerate(default_cfg.tags):
is_default = tag_idx == 0
pretrained_cfg = cfg.cfgs[tag]
pretrained_cfg = default_cfg.cfgs[tag]
model_name_tag = '.'.join([model_name, tag]) if tag else model_name
replace_items = dict(architecture=model_name, tag=tag if tag else None)
if pretrained_cfg.hf_hub_id and pretrained_cfg.hf_hub_id == 'timm/':
# auto-complete hub name w/ architecture.tag
replace_items['hf_hub_id'] = pretrained_cfg.hf_hub_id + model_name_tag
pretrained_cfg = replace(pretrained_cfg, **replace_items)
if is_default:
_model_pretrained_cfgs[model_name] = pretrained_cfg
if pretrained_cfg.has_weights:
# add tagless entry if it's default and has weights
_model_has_pretrained.add(model_name)
if tag:
model_name_tag = '.'.join([model_name, tag])
_model_pretrained_cfgs[model_name_tag] = pretrained_cfg
if pretrained_cfg.has_weights:
# add model w/ tag if tag is valid
@ -74,7 +82,7 @@ def register_model(fn):
else:
_model_with_tags[model_name].append(model_name) # has empty tag (to slowly remove these instances)
_model_default_cfgs[model_name] = cfg
_model_default_cfgs[model_name] = default_cfg
return fn

@ -361,7 +361,6 @@ def _create_convnext(variant, pretrained=False, **kwargs):
return model
def _cfg(url='', **kwargs):
return {
'url': url,
@ -375,90 +374,130 @@ def _cfg(url='', **kwargs):
default_cfgs = generate_default_cfgs({
# timm specific variants
'convnext_atto.timm_in1k': _cfg(
'convnext_atto.d2_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
'convnext_atto_ols.timm_in1k': _cfg(
'convnext_atto_ols.a2_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
'convnext_femto.timm_in1k': _cfg(
'convnext_femto.d1_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
'convnext_femto_ols.timm_in1k': _cfg(
'convnext_femto_ols.d1_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
'convnext_pico.timm_in1k': _cfg(
'convnext_pico.d1_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
'convnext_pico_ols.timm_in1k': _cfg(
'convnext_pico_ols.d1_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth',
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_nano.in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_nano.timm_in1k': _cfg(
'convnext_nano.d1h_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth',
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_nano_ols.timm_in1k': _cfg(
'convnext_nano_ols.d1h_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth',
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_tiny_hnf.timm_in1k': _cfg(
'convnext_tiny_hnf.a2h_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_nano.in12k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, num_classes=11821),
'convnext_tiny.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_small.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_base.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_large.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_xlarge.untrained': _cfg(),
'convnext_tiny.fb_in22k_ft_in1k': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_small.fb_in22k_ft_in1k': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_base.fb_in22k_ft_in1k': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_large.fb_in22k_ft_in1k': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_xlarge.fb_in22k_ft_in1k': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_tiny.fb_in22k_ft_in1k_384': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnext_small..fb_in22k_ft_in1k_384': _cfg(
'convnext_small.fb_in22k_ft_in1k_384': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnext_base.fb_in22k_ft_in1k_384': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnext_large.fb_in22k_ft_in1k_384': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnext_xlarge.fb_in22k_ft_in1k_384': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnext_tiny_in22k.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", num_classes=21841),
'convnext_small_in22k.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", num_classes=21841),
'convnext_base_in22k.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841),
'convnext_large_in22k.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", num_classes=21841),
'convnext_xlarge_in22k.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", num_classes=21841),
'convnext_tiny.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
hf_hub_id='timm/',
num_classes=21841),
'convnext_small.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
hf_hub_id='timm/',
num_classes=21841),
'convnext_base.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
hf_hub_id='timm/',
num_classes=21841),
'convnext_large.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
hf_hub_id='timm/',
num_classes=21841),
'convnext_xlarge.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
hf_hub_id='timm/',
num_classes=21841),
})

File diff suppressed because it is too large Load Diff

@ -711,10 +711,10 @@ default_cfgs = generate_default_cfgs({
# patch models, imagenet21k (weights from official Google JAX impl)
'vit_large_patch32_224.v1_in21k': _cfg(
'vit_large_patch32_224.orig_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
num_classes=21843),
'vit_huge_patch14_224.v1_in21k': _cfg(
'vit_huge_patch14_224.orig_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
hf_hub_id='timm/vit_huge_patch14_224_in21k',
custom_load=True, num_classes=21843),

Loading…
Cancel
Save