Update efficientnet.py and convnext.py to multi-weight, add ImageNet-12k pretrained EfficientNet-B5 and ConvNeXt-Nano.

pull/1593/head
Ross Wightman 2 years ago
parent e7da205345
commit 6a01101905

@ -1,5 +1,6 @@
import dataclasses import dataclasses
import logging import logging
import os
from copy import deepcopy from copy import deepcopy
from typing import Optional, Dict, Callable, Any, Tuple from typing import Optional, Dict, Callable, Any, Tuple
@ -9,7 +10,7 @@ from torch.hub import load_state_dict_from_url
from timm.models._features import FeatureListNet, FeatureHookNet from timm.models._features import FeatureListNet, FeatureHookNet
from timm.models._features_fx import FeatureGraphNet from timm.models._features_fx import FeatureGraphNet
from timm.models._helpers import load_state_dict from timm.models._helpers import load_state_dict
from timm.models._hub import has_hf_hub, download_cached_file, load_state_dict_from_hf from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf
from timm.models._manipulate import adapt_input_conv from timm.models._manipulate import adapt_input_conv
from timm.models._pretrained import PretrainedCfg from timm.models._pretrained import PretrainedCfg
from timm.models._prune import adapt_model_from_file from timm.models._prune import adapt_model_from_file
@ -32,6 +33,7 @@ def _resolve_pretrained_source(pretrained_cfg):
pretrained_url = pretrained_cfg.get('url', None) pretrained_url = pretrained_cfg.get('url', None)
pretrained_file = pretrained_cfg.get('file', None) pretrained_file = pretrained_cfg.get('file', None)
hf_hub_id = pretrained_cfg.get('hf_hub_id', None) hf_hub_id = pretrained_cfg.get('hf_hub_id', None)
# resolve where to load pretrained weights from # resolve where to load pretrained weights from
load_from = '' load_from = ''
pretrained_loc = '' pretrained_loc = ''
@ -43,15 +45,20 @@ def _resolve_pretrained_source(pretrained_cfg):
else: else:
# default source == timm or unspecified # default source == timm or unspecified
if pretrained_file: if pretrained_file:
# file load override is the highest priority if set
load_from = 'file' load_from = 'file'
pretrained_loc = pretrained_file pretrained_loc = pretrained_file
elif pretrained_url: else:
load_from = 'url' # next, HF hub is prioritized unless a valid cached version of weights exists already
pretrained_loc = pretrained_url cached_url_valid = check_cached_file(pretrained_url) if pretrained_url else False
elif hf_hub_id and has_hf_hub(necessary=True): if hf_hub_id and has_hf_hub(necessary=True) and not cached_url_valid:
# hf-hub available as alternate weight source in default_cfg # hf-hub available as alternate weight source in default_cfg
load_from = 'hf-hub' load_from = 'hf-hub'
pretrained_loc = hf_hub_id pretrained_loc = hf_hub_id
elif pretrained_url:
load_from = 'url'
pretrained_loc = pretrained_url
if load_from == 'hf-hub' and pretrained_cfg.get('hf_hub_filename', None): if load_from == 'hf-hub' and pretrained_cfg.get('hf_hub_filename', None):
# if a filename override is set, return tuple for location w/ (hub_id, filename) # if a filename override is set, return tuple for location w/ (hub_id, filename)
pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename'] pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename']
@ -105,7 +112,7 @@ def load_custom_pretrained(
pretrained_loc = download_cached_file( pretrained_loc = download_cached_file(
pretrained_loc, pretrained_loc,
check_hash=_CHECK_HASH, check_hash=_CHECK_HASH,
progress=_DOWNLOAD_PROGRESS progress=_DOWNLOAD_PROGRESS,
) )
if load_fn is not None: if load_fn is not None:

@ -1,3 +1,4 @@
import hashlib
import json import json
import logging import logging
import os import os
@ -67,6 +68,26 @@ def download_cached_file(url, check_hash=True, progress=False):
return cached_file return cached_file
def check_cached_file(url, check_hash=True):
if isinstance(url, (list, tuple)):
url, filename = url
else:
parts = urlparse(url)
filename = os.path.basename(parts.path)
cached_file = os.path.join(get_cache_dir(), filename)
if os.path.exists(cached_file):
if check_hash:
r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
hash_prefix = r.group(1) if r else None
if hash_prefix:
with open(cached_file, 'rb') as f:
hd = hashlib.sha256(f.read()).hexdigest()
if hd[:len(hash_prefix)] != hash_prefix:
return False
return True
return False
def has_hf_hub(necessary=False): def has_hf_hub(necessary=False):
if not _has_hf_hub and necessary: if not _has_hf_hub and necessary:
# if no HF Hub module installed, and it is necessary to continue, raise error # if no HF Hub module installed, and it is necessary to continue, raise error
@ -145,7 +166,9 @@ def save_for_hf(model, save_directory, model_config=None):
hf_config['architecture'] = pretrained_cfg.pop('architecture') hf_config['architecture'] = pretrained_cfg.pop('architecture')
hf_config['num_classes'] = model_config.get('num_classes', model.num_classes) hf_config['num_classes'] = model_config.get('num_classes', model.num_classes)
hf_config['num_features'] = model_config.get('num_features', model.num_features) hf_config['num_features'] = model_config.get('num_features', model.num_features)
hf_config['global_pool'] = model_config.get('global_pool', getattr(model, 'global_pool', None)) global_pool_type = model_config.get('global_pool', getattr(model, 'global_pool', None))
if isinstance(global_pool_type, str) and global_pool_type:
hf_config['global_pool'] = global_pool_type
if 'label' in model_config: if 'label' in model_config:
_logger.warning( _logger.warning(

@ -19,6 +19,7 @@ class PretrainedCfg:
source: Optional[str] = None # source of cfg / weight location used (url, file, hf-hub) source: Optional[str] = None # source of cfg / weight location used (url, file, hf-hub)
architecture: Optional[str] = None # architecture variant can be set when not implicit architecture: Optional[str] = None # architecture variant can be set when not implicit
tag: Optional[str] = None # pretrained tag of source
custom_load: bool = False # use custom model specific model.load_pretrained() (ie for npz files) custom_load: bool = False # use custom model specific model.load_pretrained() (ie for npz files)
# input / data config # input / data config

@ -7,6 +7,7 @@ import re
import sys import sys
from collections import defaultdict, deque from collections import defaultdict, deque
from copy import deepcopy from copy import deepcopy
from dataclasses import replace
from typing import List, Optional, Union, Tuple from typing import List, Optional, Union, Tuple
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
@ -20,7 +21,7 @@ _model_to_module = {} # mapping of model names to module names
_model_entrypoints = {} # mapping of model names to architecture entrypoint fns _model_entrypoints = {} # mapping of model names to architecture entrypoint fns
_model_has_pretrained = set() # set of model names that have pretrained weight url present _model_has_pretrained = set() # set of model names that have pretrained weight url present
_model_default_cfgs = dict() # central repo for model arch -> default cfg objects _model_default_cfgs = dict() # central repo for model arch -> default cfg objects
_model_pretrained_cfgs = dict() # central repo for model arch + tag -> pretrained cfgs _model_pretrained_cfgs = dict() # central repo for model arch.tag -> pretrained cfgs
_model_with_tags = defaultdict(list) # shortcut to map each model arch to all model + tag names _model_with_tags = defaultdict(list) # shortcut to map each model arch to all model + tag names
@ -48,24 +49,31 @@ def register_model(fn):
if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs: if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
# this will catch all models that have entrypoint matching cfg key, but miss any aliasing # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
# entrypoints or non-matching combos # entrypoints or non-matching combos
cfg = mod.default_cfgs[model_name] default_cfg = mod.default_cfgs[model_name]
if not isinstance(cfg, DefaultCfg): if not isinstance(default_cfg, DefaultCfg):
# new style default cfg dataclass w/ multiple entries per model-arch # new style default cfg dataclass w/ multiple entries per model-arch
assert isinstance(cfg, dict) assert isinstance(default_cfg, dict)
# old style cfg dict per model-arch # old style cfg dict per model-arch
cfg = PretrainedCfg(**cfg) pretrained_cfg = PretrainedCfg(**default_cfg)
cfg = DefaultCfg(tags=deque(['']), cfgs={'': cfg}) default_cfg = DefaultCfg(tags=deque(['']), cfgs={'': pretrained_cfg})
for tag_idx, tag in enumerate(cfg.tags): for tag_idx, tag in enumerate(default_cfg.tags):
is_default = tag_idx == 0 is_default = tag_idx == 0
pretrained_cfg = cfg.cfgs[tag] pretrained_cfg = default_cfg.cfgs[tag]
model_name_tag = '.'.join([model_name, tag]) if tag else model_name
replace_items = dict(architecture=model_name, tag=tag if tag else None)
if pretrained_cfg.hf_hub_id and pretrained_cfg.hf_hub_id == 'timm/':
# auto-complete hub name w/ architecture.tag
replace_items['hf_hub_id'] = pretrained_cfg.hf_hub_id + model_name_tag
pretrained_cfg = replace(pretrained_cfg, **replace_items)
if is_default: if is_default:
_model_pretrained_cfgs[model_name] = pretrained_cfg _model_pretrained_cfgs[model_name] = pretrained_cfg
if pretrained_cfg.has_weights: if pretrained_cfg.has_weights:
# add tagless entry if it's default and has weights # add tagless entry if it's default and has weights
_model_has_pretrained.add(model_name) _model_has_pretrained.add(model_name)
if tag: if tag:
model_name_tag = '.'.join([model_name, tag])
_model_pretrained_cfgs[model_name_tag] = pretrained_cfg _model_pretrained_cfgs[model_name_tag] = pretrained_cfg
if pretrained_cfg.has_weights: if pretrained_cfg.has_weights:
# add model w/ tag if tag is valid # add model w/ tag if tag is valid
@ -74,7 +82,7 @@ def register_model(fn):
else: else:
_model_with_tags[model_name].append(model_name) # has empty tag (to slowly remove these instances) _model_with_tags[model_name].append(model_name) # has empty tag (to slowly remove these instances)
_model_default_cfgs[model_name] = cfg _model_default_cfgs[model_name] = default_cfg
return fn return fn

@ -361,7 +361,6 @@ def _create_convnext(variant, pretrained=False, **kwargs):
return model return model
def _cfg(url='', **kwargs): def _cfg(url='', **kwargs):
return { return {
'url': url, 'url': url,
@ -375,90 +374,130 @@ def _cfg(url='', **kwargs):
default_cfgs = generate_default_cfgs({ default_cfgs = generate_default_cfgs({
# timm specific variants # timm specific variants
'convnext_atto.timm_in1k': _cfg( 'convnext_atto.d2_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95), test_input_size=(3, 288, 288), test_crop_pct=0.95),
'convnext_atto_ols.timm_in1k': _cfg( 'convnext_atto_ols.a2_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95), test_input_size=(3, 288, 288), test_crop_pct=0.95),
'convnext_femto.timm_in1k': _cfg( 'convnext_femto.d1_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95), test_input_size=(3, 288, 288), test_crop_pct=0.95),
'convnext_femto_ols.timm_in1k': _cfg( 'convnext_femto_ols.d1_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95), test_input_size=(3, 288, 288), test_crop_pct=0.95),
'convnext_pico.timm_in1k': _cfg( 'convnext_pico.d1_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95), test_input_size=(3, 288, 288), test_crop_pct=0.95),
'convnext_pico_ols.timm_in1k': _cfg( 'convnext_pico_ols.d1_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth',
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_nano.in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_nano.timm_in1k': _cfg( 'convnext_nano.d1h_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth',
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_nano_ols.timm_in1k': _cfg( 'convnext_nano_ols.d1h_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth',
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_tiny_hnf.timm_in1k': _cfg( 'convnext_tiny_hnf.a2h_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_nano.in12k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, num_classes=11821),
'convnext_tiny.fb_in1k': _cfg( 'convnext_tiny.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0), test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_small.fb_in1k': _cfg( 'convnext_small.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0), test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_base.fb_in1k': _cfg( 'convnext_base.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0), test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_large.fb_in1k': _cfg( 'convnext_large.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0), test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_xlarge.untrained': _cfg(), 'convnext_xlarge.untrained': _cfg(),
'convnext_tiny.fb_in22k_ft_in1k': _cfg( 'convnext_tiny.fb_in22k_ft_in1k': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth', url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0), test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_small.fb_in22k_ft_in1k': _cfg( 'convnext_small.fb_in22k_ft_in1k': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth', url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0), test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_base.fb_in22k_ft_in1k': _cfg( 'convnext_base.fb_in22k_ft_in1k': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth', url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0), test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_large.fb_in22k_ft_in1k': _cfg( 'convnext_large.fb_in22k_ft_in1k': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth', url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0), test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_xlarge.fb_in22k_ft_in1k': _cfg( 'convnext_xlarge.fb_in22k_ft_in1k': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth', url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0), test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_tiny.fb_in22k_ft_in1k_384': _cfg( 'convnext_tiny.fb_in22k_ft_in1k_384': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth', url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnext_small..fb_in22k_ft_in1k_384': _cfg( 'convnext_small.fb_in22k_ft_in1k_384': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth', url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnext_base.fb_in22k_ft_in1k_384': _cfg( 'convnext_base.fb_in22k_ft_in1k_384': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth', url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnext_large.fb_in22k_ft_in1k_384': _cfg( 'convnext_large.fb_in22k_ft_in1k_384': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth', url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnext_xlarge.fb_in22k_ft_in1k_384': _cfg( 'convnext_xlarge.fb_in22k_ft_in1k_384': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth', url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnext_tiny_in22k.fb_in22k': _cfg( 'convnext_tiny.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", num_classes=21841), url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
'convnext_small_in22k.fb_in22k': _cfg( hf_hub_id='timm/',
url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", num_classes=21841), num_classes=21841),
'convnext_base_in22k.fb_in22k': _cfg( 'convnext_small.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841), url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
'convnext_large_in22k.fb_in22k': _cfg( hf_hub_id='timm/',
url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", num_classes=21841), num_classes=21841),
'convnext_xlarge_in22k.fb_in22k': _cfg( 'convnext_base.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", num_classes=21841), url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
hf_hub_id='timm/',
num_classes=21841),
'convnext_large.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
hf_hub_id='timm/',
num_classes=21841),
'convnext_xlarge.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
hf_hub_id='timm/',
num_classes=21841),
}) })

File diff suppressed because it is too large Load Diff

@ -711,10 +711,10 @@ default_cfgs = generate_default_cfgs({
# patch models, imagenet21k (weights from official Google JAX impl) # patch models, imagenet21k (weights from official Google JAX impl)
'vit_large_patch32_224.v1_in21k': _cfg( 'vit_large_patch32_224.orig_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
num_classes=21843), num_classes=21843),
'vit_huge_patch14_224.v1_in21k': _cfg( 'vit_huge_patch14_224.orig_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
hf_hub_id='timm/vit_huge_patch14_224_in21k', hf_hub_id='timm/vit_huge_patch14_224_in21k',
custom_load=True, num_classes=21843), custom_load=True, num_classes=21843),

Loading…
Cancel
Save