Add ported TensorFlow MaxViT weights. Add a few more CLIP ViT fine-tunes. Tweak some model tag names. Improve model tag name sorting. Update HF hub push config layout.

pull/1582/head
Ross Wightman authored and committed
parent dbe7531aa3
commit 72cfa57761

@@ -1,4 +1,4 @@
 from .version import __version__
-from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
+from .models import create_model, list_models, list_pretrained, is_model, list_modules, model_entrypoint, \
     is_scriptable, is_exportable, set_scriptable, set_exportable, \
     is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value

@@ -70,5 +70,6 @@ from .layers import TestTimePoolHead, apply_test_time_pool
 from .layers import convert_splitbn_model, convert_sync_batchnorm
 from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
 from .layers import set_fast_norm
-from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
-    is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
+from ._pretrained import PretrainedCfg, filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag
+from .registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, list_modules,\
+    is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value

@@ -1,5 +1,6 @@
+import copy
 from collections import deque, defaultdict
-from dataclasses import dataclass, field, replace
+from dataclasses import dataclass, field, replace, asdict
 from typing import Any, Deque, Dict, Tuple, Optional, Union

@@ -8,13 +9,13 @@ class PretrainedCfg:
     """
     """
     # weight locations
-    url: str = ''
-    file: str = ''
-    hf_hub_id: str = ''
-    hf_hub_filename: str = ''
-    source: str = ''  # source of cfg / weight location used (url, file, hf-hub)
-    architecture: str = ''  # architecture variant can be set when not implicit
+    url: Optional[Union[str, Tuple[str, str]]] = None
+    file: Optional[str] = None
+    hf_hub_id: Optional[str] = None
+    hf_hub_filename: Optional[str] = None
+    source: Optional[str] = None  # source of cfg / weight location used (url, file, hf-hub)
+    architecture: Optional[str] = None  # architecture variant can be set when not implicit
     custom_load: bool = False  # use custom model specific model.load_pretrained() (ie for npz files)

     # input / data config

@@ -31,22 +32,40 @@ class PretrainedCfg:
     # head config
     num_classes: int = 1000
-    label_offset: int = 0
+    label_offset: Optional[int] = None

     # model attributes that vary with above or required for pretrained adaptation
     pool_size: Optional[Tuple[int, ...]] = None
     test_pool_size: Optional[Tuple[int, ...]] = None
-    first_conv: str = ''
-    classifier: str = ''
-    license: str = ''
-    source_url: str = ''
-    paper: str = ''
-    notes: str = ''
+    first_conv: Optional[str] = None
+    classifier: Optional[str] = None
+    license: Optional[str] = None
+    source_url: Optional[str] = None
+    paper: Optional[str] = None
+    notes: Optional[str] = None

     @property
     def has_weights(self):
-        return self.url.startswith('http') or self.file or self.hf_hub_id
+        return self.url or self.file or self.hf_hub_id
+
+    def to_dict(self, remove_source=False, remove_null=True):
+        return filter_pretrained_cfg(
+            asdict(self),
+            remove_source=remove_source,
+            remove_null=remove_null
+        )
+
+
+def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True):
+    filtered_cfg = {}
+    for k, v in cfg.items():
+        if remove_source and k in {'url', 'file', 'hf_hub_id', 'hf_hub_id', 'hf_hub_filename', 'source'}:
+            continue
+        if remove_null and v is None:
+            continue
+        filtered_cfg[k] = v
+    return filtered_cfg


 @dataclass

@@ -71,7 +90,7 @@ def split_model_name_tag(model_name: str, no_tag=''):
     return model_name, tag


-def generate_defaults(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):
+def generate_default_cfgs(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):
     out = defaultdict(DefaultCfg)
     default_set = set()  # no tag and tags ending with * are prioritized as default

@@ -82,21 +101,22 @@ def generate_defaults(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):
         model, tag = split_model_name_tag(k)
         is_default_set = model in default_set
-        priority = not tag or (tag.endswith('*') and not is_default_set)
+        priority = (has_weights and not tag) or (tag.endswith('*') and not is_default_set)
         tag = tag.strip('*')

         default_cfg = out[model]
-        if has_weights:
-            default_cfg.is_pretrained = True

         if priority:
             default_cfg.tags.appendleft(tag)
             default_set.add(model)
-        elif has_weights and not default_set:
+        elif has_weights and not default_cfg.is_pretrained:
             default_cfg.tags.appendleft(tag)
         else:
             default_cfg.tags.append(tag)
+
+        if has_weights:
+            default_cfg.is_pretrained = True

         default_cfg.cfgs[tag] = v

     return out

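For orientation, a minimal usage sketch of the reworked pretrained-config helpers above (illustrative, not part of the commit; 'some_model' and its tags are made-up names):

    from timm.models._pretrained import (
        PretrainedCfg, filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag)

    # 'arch.tag' names split into (architecture, pretrained tag)
    split_model_name_tag('some_model.in1k')   # -> ('some_model', 'in1k')
    split_model_name_tag('some_model')        # -> ('some_model', '')

    # null fields and weight-source fields can now be stripped before serializing a config
    cfg = PretrainedCfg(hf_hub_id='timm/some_model.in1k', num_classes=1000)
    assert cfg.has_weights                    # any of url / file / hf_hub_id set
    cfg.to_dict(remove_source=True)           # no url/file/hf_hub_*/source keys, no None values

    # generate_default_cfgs groups tagged cfgs per architecture; a tag backed by weights is
    # pushed to the front of the tag deque and marks the architecture as pretrained
    default_cfgs = generate_default_cfgs({
        'some_model.in1k': cfg,
        'some_model.untrained': PretrainedCfg(),
    })
    list(default_cfgs['some_model'].tags)     # -> ['in1k', 'untrained']
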
@@ -21,7 +21,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
 from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \
     create_conv2d, get_act_layer, make_divisible, to_ntuple
-from ._pretrained import generate_defaults
+from ._pretrained import generate_default_cfgs
 from .registry import register_model

@@ -373,7 +373,7 @@ def _cfg(url='', **kwargs):
     }


-default_cfgs = generate_defaults({
+default_cfgs = generate_default_cfgs({
     # timm specific variants
     'convnext_atto.timm_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',

@@ -575,7 +575,7 @@ def build_model_with_cfg(
     )

     # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
-    pretrained_cfg = dataclasses.asdict(pretrained_cfg)
+    pretrained_cfg = pretrained_cfg.to_dict()

     _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter)

@@ -15,9 +15,11 @@ except ImportError:
     from torch.hub import _get_torch_home as get_dir

 from timm import __version__
+from timm.models._pretrained import filter_pretrained_cfg

 try:
-    from huggingface_hub import (create_repo, get_hf_file_metadata,
+    from huggingface_hub import (
+        create_repo, get_hf_file_metadata,
         hf_hub_download, hf_hub_url,
         repo_type_and_id_from_hf_id, upload_folder)
     from huggingface_hub.utils import EntryNotFoundError

@@ -46,6 +48,9 @@ def get_cache_dir(child_dir=''):

 def download_cached_file(url, check_hash=True, progress=False):
-    parts = urlparse(url)
-    filename = os.path.basename(parts.path)
+    if isinstance(url, (list, tuple)):
+        url, filename = url
+    else:
+        parts = urlparse(url)
+        filename = os.path.basename(parts.path)
     cached_file = os.path.join(get_cache_dir(), filename)

@@ -90,10 +95,27 @@ def _download_from_hf(model_id: str, filename: str):
 def load_model_config_from_hf(model_id: str):
     assert has_hf_hub(True)
     cached_file = _download_from_hf(model_id, 'config.json')
-    pretrained_cfg = load_cfg_from_json(cached_file)
+
+    hf_config = load_cfg_from_json(cached_file)
+    if 'pretrained_cfg' not in hf_config:
+        # old form, pull pretrain_cfg out of the base dict
+        pretrained_cfg = hf_config
+        hf_config = {}
+        hf_config['architecture'] = pretrained_cfg.pop('architecture')
+        hf_config['num_features'] = pretrained_cfg.pop('num_features', None)
+        if 'labels' in pretrained_cfg:
+            hf_config['label_name'] = pretrained_cfg.pop('labels')
+        hf_config['pretrained_cfg'] = pretrained_cfg
+
+    # NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now
+    pretrained_cfg = hf_config['pretrained_cfg']
     pretrained_cfg['hf_hub_id'] = model_id  # insert hf_hub id for pretrained weight load during model creation
     pretrained_cfg['source'] = 'hf-hub'
-    model_name = pretrained_cfg.get('architecture')
+
+    if 'num_classes' in hf_config:
+        # model should be created with parent num_classes if they exist
+        pretrained_cfg['num_classes'] = hf_config['num_classes']
+
+    model_name = hf_config['architecture']
     return pretrained_cfg, model_name

@@ -114,10 +136,34 @@ def save_for_hf(model, save_directory, model_config=None):
     torch.save(model.state_dict(), weights_path)

     config_path = save_directory / 'config.json'
-    hf_config = model.pretrained_cfg
-    hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes)
-    hf_config['num_features'] = model_config.pop('num_features', model.num_features)
-    hf_config['labels'] = model_config.pop('labels', [f"LABEL_{i}" for i in range(hf_config['num_classes'])])
+    hf_config = {}
+    pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
+    # set some values at root config level
+    hf_config['architecture'] = pretrained_cfg.pop('architecture')
+    hf_config['num_classes'] = model_config.get('num_classes', model.num_classes)
+    hf_config['num_features'] = model_config.get('num_features', model.num_features)
+    hf_config['global_pool'] = model_config.get('global_pool', getattr(model, 'global_pool', None))
+
+    if 'label' in model_config:
+        _logger.warning(
+            "'label' as a config field for timm models is deprecated. Please use 'label_name' and 'display_name'. "
+            "Using provided 'label' field as 'label_name'.")
+        model_config['label_name'] = model_config.pop('label')
+
+    label_name = model_config.pop('label_name', None)
+    if label_name:
+        assert isinstance(label_name, (dict, list, tuple))
+        # map label id (classifier index) -> unique label name (ie synset for ImageNet, MID for OpenImages)
+        # can be a dict id: name if there are id gaps, or tuple/list if no gaps.
+        hf_config['label_name'] = model_config['label_name']
+
+    display_name = model_config.pop('display_name', None)
+    if display_name:
+        assert isinstance(display_name, dict)
+        # map label_name -> user interface display name
+        hf_config['display_name'] = model_config['display_name']
+
+    hf_config['pretrained_cfg'] = pretrained_cfg
     hf_config.update(model_config)

     with config_path.open('w') as f:

@@ -134,7 +180,7 @@ def push_to_hf_hub(
         create_pr: bool = False,
         model_config: Optional[dict] = None,
 ):
-    # Create repo if doesn't exist yet
+    # Create repo if it doesn't exist yet
     repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)

     # Infer complete repo_id from repo_url

@@ -154,10 +200,11 @@ def push_to_hf_hub(
         # Save model weights and config.
         save_for_hf(model, tmpdir, model_config=model_config)

-        # Add readme if does not exist
+        # Add readme if it does not exist
         if not has_readme:
+            model_name = repo_id.split('/')[-1]
             readme_path = Path(tmpdir) / "README.md"
-            readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_id}'
+            readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {model_name}'
             readme_path.write_text(readme_text)

         # Upload model and return

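For reference, a rough sketch of the config.json layout that the updated save_for_hf() writes, following the code above; the 'resnet50' architecture and all field values here are made up for illustration:

    # Illustrative only: approximate layout of config.json after this change.
    hf_config = {
        'architecture': 'resnet50',         # now at the root, popped out of pretrained_cfg
        'num_classes': 1000,
        'num_features': 2048,
        'global_pool': 'avg',
        'pretrained_cfg': {                 # filtered: weight-source keys and null values removed
            'input_size': (3, 224, 224),
            'crop_pct': 0.875,
            'interpolation': 'bicubic',
            'num_classes': 1000,
        },
        # optional: 'label_name' (class index -> unique name) and 'display_name' (name -> UI string)
    }

load_model_config_from_hf() above still accepts the previous flat layout, pulling architecture, num_features, and labels out of the old-style dict.
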
@@ -54,7 +54,7 @@ from .layers import Mlp, ConvMlp, DropPath, ClassifierHead, trunc_normal_tf_, La
 from .layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d
 from .layers import SelectAdaptivePool2d, create_pool2d
 from .layers import to_2tuple, extend_tuple, make_divisible, _assert
-from ._pretrained import generate_defaults
+from ._pretrained import generate_default_cfgs
 from .registry import register_model
 from .vision_transformer_relpos import RelPosMlp, RelPosBias  # FIXME move these to common location

@@ -1859,7 +1859,7 @@ def _cfg(url='', **kwargs):
     }


-default_cfgs = generate_defaults({
+default_cfgs = generate_default_cfgs({
     # Fiddling with configs / defaults / still pretraining
     'coatnet_pico_rw_224': _cfg(url=''),
     'coatnet_nano_rw_224': _cfg(

@@ -1941,86 +1941,67 @@ default_cfgs = generate_defaults({
     'maxxvit_rmlp_large_rw_224': _cfg(url=''),

-    # Trying to be like the MaxViT paper configs
+    # MaxViT models ported from official Tensorflow impl
     'maxvit_tiny_tf_224.in1k': _cfg(
-        url='',
-        #file='maxvit_tiny_tf_224_in1k.pth',
+        hf_hub_id='timm/maxvit_tiny_tf_224.in1k',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
     'maxvit_tiny_tf_384.in1k': _cfg(
-        url='',
-        #file='maxvit_tiny_tf_384_in1k.pth',
+        hf_hub_id='timm/maxvit_tiny_tf_384.in1k',
         input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
     'maxvit_tiny_tf_512.in1k': _cfg(
-        url='',
-        #file='maxvit_tiny_tf_512_in1k.pth',
+        hf_hub_id='timm/maxvit_tiny_tf_512.in1k',
         input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
     'maxvit_small_tf_224.in1k': _cfg(
-        url='',
-        #file='maxvit_small_tf_224_in1k.pth',
+        hf_hub_id='timm/maxvit_small_tf_224.in1k',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
     'maxvit_small_tf_384.in1k': _cfg(
-        url='',
-        #file='maxvit_small_tf_384_in1k.pth',
+        hf_hub_id='timm/maxvit_small_tf_384.in1k',
         input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
     'maxvit_small_tf_512.in1k': _cfg(
-        url='',
-        #file='maxvit_small_tf_512_in1k.pth',
+        hf_hub_id='timm/maxvit_small_tf_512.in1k',
         input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
     'maxvit_base_tf_224.in1k': _cfg(
-        url='',
-        #file='maxvit_base_tf_224_in1k.pth',
+        hf_hub_id='timm/maxvit_base_tf_224.in1k',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
     'maxvit_base_tf_384.in1k': _cfg(
-        url='',
-        #file='maxvit_base_tf_384_in1k.pth',
+        hf_hub_id='timm/maxvit_base_tf_384.in1k',
         input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
     'maxvit_base_tf_512.in1k': _cfg(
-        url='',
-        #file='maxvit_base_tf_512_in1k.pth',
+        hf_hub_id='timm/maxvit_base_tf_512.in1k',
         input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
     'maxvit_large_tf_224.in1k': _cfg(
-        url='',
-        #file='maxvit_large_tf_224_in1k.pth',
+        hf_hub_id='timm/maxvit_large_tf_224.in1k',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
     'maxvit_large_tf_384.in1k': _cfg(
-        url='',
-        #file='maxvit_large_tf_384_in1k.pth',
+        hf_hub_id='timm/maxvit_large_tf_384.in1k',
         input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
     'maxvit_large_tf_512.in1k': _cfg(
-        url='',
-        #file='maxvit_large_tf_512_in1k.pth',
+        hf_hub_id='timm/maxvit_large_tf_512.in1k',
         input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
     'maxvit_base_tf_224.in21k': _cfg(
-        url='',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
-    'maxvit_base_tf_384.in21k_ft1k': _cfg(
-        url='',
-        #file='maxvit_base_tf_384_in21k_ft_in1k.pth',
+        url=''),
+    'maxvit_base_tf_384.in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/maxvit_base_tf_384.in21k_ft_in1k',
         input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
-    'maxvit_base_tf_512.in21k_ft1k': _cfg(
-        url='',
-        #file='maxvit_base_tf_512_in21k_ft_in1k.pth',
+    'maxvit_base_tf_512.in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/maxvit_base_tf_512.in21k_ft_in1k',
         input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
     'maxvit_large_tf_224.in21k': _cfg(
         url=''),
-    'maxvit_large_tf_384.in21k_ft1k': _cfg(
-        url='',
-        #file='maxvit_large_tf_384_in21k_ft_in1k.pth',
+    'maxvit_large_tf_384.in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/maxvit_large_tf_384.in21k_ft_in1k',
         input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
-    'maxvit_large_tf_512.in21k_ft1k': _cfg(
-        url='',
-        #file='maxvit_large_tf_512_in21k_ft_in1k.pth',
+    'maxvit_large_tf_512.in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/maxvit_large_tf_512.in21k_ft_in1k',
         input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
     'maxvit_xlarge_tf_224.in21k': _cfg(
         url=''),
-    'maxvit_xlarge_tf_384.in21k_ft1k': _cfg(
-        url='',
-        #file='maxvit_xlarge_tf_384_in21k_ft_in1k.pth',
+    'maxvit_xlarge_tf_384.in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/maxvit_xlarge_tf_384.in21k_ft_in1k',
         input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
-    'maxvit_xlarge_tf_512.in21k_ft1k': _cfg(
-        url='',
-        #file='maxvit_xlarge_tf_512_in21k_ft_in1k.pth',
+    'maxvit_xlarge_tf_512.in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/maxvit_xlarge_tf_512.in21k_ft_in1k',
         input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
 })

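The ported MaxViT configs above point at hf_hub_id weights, so (assuming huggingface_hub is installed and the timm/ hub repos are populated) they load like any other tagged model; a minimal sketch:

    import timm

    # the tag selects the specific weights; the cfg's hf_hub_id is used for the download
    model = timm.create_model('maxvit_tiny_tf_224.in1k', pretrained=True).eval()

    # in21k fine-tune tags were renamed from '*.in21k_ft1k' to '*.in21k_ft_in1k'
    model_384 = timm.create_model('maxvit_base_tf_384.in21k_ft_in1k', pretrained=True).eval()
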
@@ -7,7 +7,7 @@ import re
 import sys
 from collections import defaultdict, deque
 from copy import deepcopy
-from typing import Optional, Tuple
+from typing import List, Optional, Union, Tuple

 from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag

@@ -84,7 +84,7 @@ def _natural_key(string_):
 def list_models(
-        filter: str = '',
+        filter: Union[str, List[str]] = '',
         module: str = '',
         pretrained=False,
         exclude_filters: str = '',

@@ -114,7 +114,12 @@
     else:
         all_models = _model_entrypoints.keys()

-    # FIXME wildcard filter tag as well as model arch name
+    if include_tags:
+        # expand model names to include names w/ pretrained tags
+        models_with_tags = []
+        for m in all_models:
+            models_with_tags.extend(_model_with_tags[m])
+        all_models = models_with_tags

     if filter:
         models = []

@@ -134,13 +139,6 @@
     if len(exclude_models):
         models = set(models).difference(exclude_models)

-    if include_tags:
-        # expand model names to include names w/ pretrained tags
-        models_with_tags = []
-        for m in models:
-            models_with_tags.extend(_model_with_tags[m])
-        models = models_with_tags
-
     if pretrained:
         models = _model_has_pretrained.intersection(models)

@@ -150,6 +148,18 @@
     return list(sorted(models, key=_natural_key))


+def list_pretrained(
+        filter: Union[str, List[str]] = '',
+        exclude_filters: str = '',
+):
+    return list_models(
+        filter=filter,
+        pretrained=True,
+        exclude_filters=exclude_filters,
+        include_tags=True,
+    )
+
+
 def is_model(model_name):
     """ Check if a model name exists
     """

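A short sketch of how the reworked listing behaves (illustrative; the exact names returned depend on the registered configs above):

    import timm

    # list_models() still returns architecture names; include_tags=True expands them to 'arch.tag'
    timm.list_models('maxvit_*_tf_*')                     # ['maxvit_base_tf_224', ...]
    timm.list_models('maxvit_*_tf_*', include_tags=True)  # ['maxvit_base_tf_224.in1k', ...]

    # list_pretrained() is shorthand for list_models(pretrained=True, include_tags=True),
    # so it only returns 'arch.tag' names that actually have weights
    timm.list_pretrained('vit_medium_patch16_gap_*')
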
@@ -34,7 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from .helpers import build_model_with_cfg, named_apply, adapt_input_conv, checkpoint_seq
 from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
-from ._pretrained import generate_defaults
+from ._pretrained import generate_default_cfgs
 from .registry import register_model

 _logger = logging.getLogger(__name__)

@@ -492,6 +492,7 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
             model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
     model.patch_embed.proj.weight.copy_(embed_conv_w)
     model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
-    model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
+    if model.cls_token is not None:
+        model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
     pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
     if pos_embed_w.shape != model.pos_embed.shape:

@@ -630,51 +631,74 @@ def _cfg(url='', **kwargs):
     }


-default_cfgs = generate_defaults({
-    # patch models (weights from official Google JAX impl)
-    'vit_tiny_patch16_224.augreg_in21k_ft_1k': _cfg(
+default_cfgs = generate_default_cfgs({
+    # How to train your ViT (augreg) weights, pretrained on 21k FT on in1k
+    'vit_tiny_patch16_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
         custom_load=True),
-    'vit_tiny_patch16_384.augreg_in21k_ft_1k': _cfg(
+    'vit_tiny_patch16_384.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
         custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_small_patch32_224.augreg_in21k_ft_1k': _cfg(
+    'vit_small_patch32_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
         custom_load=True),
-    'vit_small_patch32_384.augreg_in21k_ft_1k': _cfg(
+    'vit_small_patch32_384.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
         custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_small_patch16_224.augreg_in21k_ft_1k': _cfg(
+    'vit_small_patch16_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
         custom_load=True),
-    'vit_small_patch16_384.augreg_in21k_ft_1k': _cfg(
+    'vit_small_patch16_384.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
         custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_base_patch32_224.augreg_in21k_ft_1k': _cfg(
+    'vit_base_patch32_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
         custom_load=True),
-    'vit_base_patch32_384.augreg_in21k_ft_1k': _cfg(
+    'vit_base_patch32_384.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
         custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_base_patch16_224.augreg_in21k_ft_1k': _cfg(
+    'vit_base_patch16_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
         custom_load=True),
-    'vit_base_patch16_384.augreg_in21k_ft_1k': _cfg(
+    'vit_base_patch16_384.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
         custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_base_patch8_224.augreg_in21k_ft_1k': _cfg(
+    'vit_base_patch8_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
         custom_load=True),
-    'vit_large_patch32_384.v1_in21k_ft_1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
-        input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_large_patch16_224.augreg_in21k_ft_1k': _cfg(
+    'vit_large_patch16_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
         custom_load=True),
-    'vit_large_patch16_384.augreg_in21k_ft_1k': _cfg(
+    'vit_large_patch16_384.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
         custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
+
+    # re-finetuned augreg 21k FT on in1k weights
+    'vit_base_patch16_224.augreg2_in21k_ft_in1k': _cfg(
+        file='b16_augreg-a-8.pth'),
+    'vit_base_patch16_384.augreg2_in21k_ft_in1k': _cfg(
+        url=''),
+    'vit_base_patch8_224.augreg2_in21k_ft_in1k': _cfg(
+        url=''),
+
+    # patch models (weights from official Google JAX impl) pretrained on in21k FT on in1k
+    'vit_base_patch16_224.orig_in21k_ft_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth'),
+    'vit_base_patch16_384.orig_in21k_ft_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth'),
+    'vit_large_patch32_384.orig_in21k_ft_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
+        input_size=(3, 384, 384), crop_pct=1.0),
+
+    # How to train your ViT (augreg) weights trained on in1k
+    'vit_base_patch16_224.augreg_in1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
+        custom_load=True),
+    'vit_base_patch16_384.augreg_in1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
+        custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
+
     'vit_large_patch14_224.untrained': _cfg(url=''),
     'vit_huge_patch14_224.untrained': _cfg(url=''),
     'vit_giant_patch14_224.untrained': _cfg(url=''),

@@ -682,6 +706,15 @@ default_cfgs = generate_defaults({
     # patch models, imagenet21k (weights from official Google JAX impl)
+    'vit_large_patch32_224.v1_in21k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
+        num_classes=21843),
+    'vit_huge_patch14_224.v1_in21k': _cfg(
+        url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
+        hf_hub_id='timm/vit_huge_patch14_224_in21k',
+        custom_load=True, num_classes=21843),
+
+    # How to train your ViT (augreg) weights, pretrained on in21k
     'vit_tiny_patch16_224.augreg_in21k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
         custom_load=True, num_classes=21843),

@@ -700,16 +733,9 @@ default_cfgs = generate_defaults({
     'vit_base_patch8_224.augreg_in21k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
         custom_load=True, num_classes=21843),
-    'vit_large_patch32_224.v1_in21k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
-        num_classes=21843),
     'vit_large_patch16_224.augreg_in21k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
         custom_load=True, num_classes=21843),
-    'vit_huge_patch14_224.v1_in21k': _cfg(
-        url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
-        hf_hub_id='timm/vit_huge_patch14_224_in21k',
-        custom_load=True, num_classes=21843),

     # SAM trained models (https://arxiv.org/abs/2106.01548)
     'vit_base_patch32_224.sam': _cfg(

@@ -736,7 +762,7 @@ default_cfgs = generate_defaults({
     'vit_base_patch16_224_miil.in21k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth',
         mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221),
-    'vit_base_patch16_224_miil.in21k_ft_1k': _cfg(
+    'vit_base_patch16_224_miil.in21k_ft_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_1k_miil_84_4-2deb18e3.pth',
         mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear'),

@@ -744,14 +770,15 @@ default_cfgs = generate_defaults({
     'vit_base_patch16_rpn_224.in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth'),
     'vit_medium_patch16_gap_240.in12k': _cfg(
-        url='',
+        hf_hub_id='timm/vit_medium_patch16_gap_240.in12k',
         input_size=(3, 240, 240), crop_pct=0.95, num_classes=11821),
-    'vit_medium_patch16_gap_256.in12k_ft_1k': _cfg(
-        url='',
+    'vit_medium_patch16_gap_256.in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_medium_patch16_gap_256.in12k_ft_in1k',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_medium_patch16_gap_384.in12k_ft_1k': _cfg(
-        url='',
-        input_size=(3, 384, 384), crop_pct=0.95),
+    'vit_medium_patch16_gap_384.in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_medium_patch16_gap_384.in12k_ft_in1k',
+        input_size=(3, 384, 384), crop_pct=0.95, crop_mode='squash'),
+    'vit_base_patch16_gap_224': _cfg(),

     # CLIP pretrained image tower and related fine-tuned weights
     'vit_base_patch32_clip_224.laion2b': _cfg(

@@ -781,15 +808,16 @@ default_cfgs = generate_defaults({
     'vit_base_patch32_clip_384.laion2b_ft_in1k': _cfg(
         hf_hub_id='timm/vit_base_patch32_clip_384.laion2b_ft_in1k',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)),
+    'vit_base_patch32_clip_448.laion2b_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_base_patch32_clip_448.laion2b_ft_in1k',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448)),
     'vit_base_patch16_clip_224.laion2b_ft_in1k': _cfg(
         hf_hub_id='timm/vit_base_patch16_clip_224.laion2b_ft_in1k',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
     'vit_base_patch16_clip_384.laion2b_ft_in1k': _cfg(
         hf_hub_id='timm/vit_base_patch16_clip_384.laion2b_ft_in1k',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)),
-    'vit_base_patch32_clip_448.laion2b_ft_in1k': _cfg(
-        hf_hub_id='timm/vit_base_patch32_clip_448.laion2b_ft_in1k',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448)),
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
     'vit_large_patch14_clip_224.laion2b_ft_in1k': _cfg(
         hf_hub_id='timm/vit_large_patch14_clip_224.laion2b_ft_in1k',
         mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0),

@@ -816,10 +844,11 @@ default_cfgs = generate_defaults({
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448)),
     'vit_base_patch16_clip_224.laion2b_ft_in12k_in1k': _cfg(
         hf_hub_id='timm/vit_base_patch16_clip_224.laion2b_ft_in12k_in1k',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95),
     'vit_base_patch16_clip_384.laion2b_ft_in12k_in1k': _cfg(
         hf_hub_id='timm/vit_base_patch16_clip_384.laion2b_ft_in12k_in1k',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)),
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
     'vit_large_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg(
         hf_hub_id='timm/vit_large_patch14_clip_224.laion2b_ft_in12k_in1k',
         mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0),

@@ -866,7 +895,8 @@ default_cfgs = generate_defaults({
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
     'vit_base_patch16_clip_384.openai_ft_in1k': _cfg(
         hf_hub_id='timm/vit_base_patch16_clip_384.openai_ft_in1k',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)),
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
     'vit_large_patch14_clip_224.openai_ft_in1k': _cfg(
         hf_hub_id='timm/vit_large_patch14_clip_224.openai_ft_in1k',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),

@@ -876,10 +906,15 @@ default_cfgs = generate_defaults({
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
     'vit_base_patch32_clip_384.openai_ft_in12k_in1k': _cfg(
         hf_hub_id='timm/vit_base_patch32_clip_384.openai_ft_in12k_in1k',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)),
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
     'vit_base_patch16_clip_224.openai_ft_in12k_in1k': _cfg(
-        #hf_hub_id='timm/vit_base_patch16_clip_224.openai_ft_in12k_in1k',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
+        hf_hub_id='timm/vit_base_patch16_clip_224.openai_ft_in12k_in1k',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95),
+    'vit_base_patch16_clip_384.openai_ft_in12k_in1k': _cfg(
+        hf_hub_id='timm/vit_base_patch16_clip_384.openai_ft_in12k_in1k',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
     'vit_large_patch14_clip_224.openai_ft_in12k_in1k': _cfg(
         hf_hub_id='timm/vit_large_patch14_clip_224.openai_ft_in12k_in1k',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),

@@ -1118,37 +1153,48 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs):

 @register_model
 def vit_medium_patch16_gap_240(pretrained=False, **kwargs):
-    """ ViT-Base (ViT-M/16) w/o class token, w/ avg-pool @ 240x240
+    """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 240x240
     """
     model_kwargs = dict(
         patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
-        global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
+        global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
     model = _create_vision_transformer('vit_medium_patch16_gap_240', pretrained=pretrained, **model_kwargs)
     return model


 @register_model
 def vit_medium_patch16_gap_256(pretrained=False, **kwargs):
-    """ ViT-Base (ViT-M/16) w/o class token, w/ avg-pool @ 256x256
+    """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 256x256
     """
     model_kwargs = dict(
         patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
-        global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
+        global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
     model = _create_vision_transformer('vit_medium_patch16_gap_256', pretrained=pretrained, **model_kwargs)
     return model


 @register_model
 def vit_medium_patch16_gap_384(pretrained=False, **kwargs):
-    """ ViT-Base (ViT-M/16) w/o class token, w/ avg-pool @ 384x384
+    """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 384x384
     """
     model_kwargs = dict(
         patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
-        global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
+        global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
     model = _create_vision_transformer('vit_medium_patch16_gap_384', pretrained=pretrained, **model_kwargs)
     return model


+@register_model
+def vit_base_patch16_gap_224(pretrained=False, **kwargs):
+    """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 256x256
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False,
+        global_pool=kwargs.get('global_pool', 'avg'), fc_norm=False, **kwargs)
+    model = _create_vision_transformer('vit_base_patch16_gap_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def vit_base_patch32_clip_224(pretrained=False, **kwargs):
     """ ViT-B/32 CLIP image tower @ 224x224

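The augreg tag renames above keep the same weight URLs, so only the names used at create time change; a brief illustrative sketch of the new tags in use (assumes network access and, for hub-hosted entries, huggingface_hub):

    import timm

    # augreg in21k->in1k fine-tunes now use an explicit '*_ft_in1k' suffix
    vit = timm.create_model('vit_base_patch16_224.augreg_in21k_ft_in1k', pretrained=True)

    # newly added CLIP image-tower fine-tunes resolve through their hf_hub_id entries
    clip_ft = timm.create_model('vit_base_patch16_clip_384.openai_ft_in12k_in1k', pretrained=True)
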
@@ -20,7 +20,7 @@ import torch
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from ._pretrained import generate_defaults
+from ._pretrained import generate_default_cfgs
 from .layers import StdConv2dSame, StdConv2d, to_2tuple
 from .resnet import resnet26d, resnet50d
 from .resnetv2 import ResNetV2, create_resnetv2_stem

@@ -39,31 +39,31 @@ def _cfg(url='', **kwargs):
     }


-default_cfgs = generate_defaults({
+default_cfgs = generate_default_cfgs({
     # hybrid in-1k models (weights from official JAX impl where they exist)
-    'vit_tiny_r_s16_p8_224.augreg_in21k_ft_1k': _cfg(
+    'vit_tiny_r_s16_p8_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
         custom_load=True,
         first_conv='patch_embed.backbone.conv'),
-    'vit_tiny_r_s16_p8_384.augreg_in21k_ft_1k': _cfg(
+    'vit_tiny_r_s16_p8_384.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
         first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0, custom_load=True),
-    'vit_small_r26_s32_224.augreg_in21k_ft_1k': _cfg(
+    'vit_small_r26_s32_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz',
         custom_load=True,
     ),
-    'vit_small_r26_s32_384.augreg_in21k_ft_1k': _cfg(
+    'vit_small_r26_s32_384.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
         input_size=(3, 384, 384), crop_pct=1.0, custom_load=True),
     'vit_base_r26_s32_224.untrained': _cfg(),
-    'vit_base_r50_s16_384.v1_in21k_ft_1k': _cfg(
+    'vit_base_r50_s16_384.v1_in21k_ft_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
         input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_large_r50_s32_224.augreg_in21k_ft_1k': _cfg(
+    'vit_large_r50_s32_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
         custom_load=True,
     ),
-    'vit_large_r50_s32_384.augreg_in21k_ft_1k': _cfg(
+    'vit_large_r50_s32_384.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
         input_size=(3, 384, 384), crop_pct=1.0, custom_load=True,
     ),
