Add ported Tensorflow MaxVit weights. Add a few more CLIP ViT fine-tunes. Tweak some model tag names. Improve model tag name sorting. Update HF hub push config layout.

pull/1582/head
Ross Wightman 2 years ago committed by Ross Wightman
parent dbe7531aa3
commit 72cfa57761

@ -1,4 +1,4 @@
from .version import __version__
from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
from .models import create_model, list_models, list_pretrained, is_model, list_modules, model_entrypoint, \
is_scriptable, is_exportable, set_scriptable, set_exportable, \
is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value

@ -70,5 +70,6 @@ from .layers import TestTimePoolHead, apply_test_time_pool
from .layers import convert_splitbn_model, convert_sync_batchnorm
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
from .layers import set_fast_norm
from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
from ._pretrained import PretrainedCfg, filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag
from .registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, list_modules,\
is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value

@ -1,5 +1,6 @@
import copy
from collections import deque, defaultdict
from dataclasses import dataclass, field, replace
from dataclasses import dataclass, field, replace, asdict
from typing import Any, Deque, Dict, Tuple, Optional, Union
@ -8,13 +9,13 @@ class PretrainedCfg:
"""
"""
# weight locations
url: str = ''
file: str = ''
hf_hub_id: str = ''
hf_hub_filename: str = ''
url: Optional[Union[str, Tuple[str, str]]] = None
file: Optional[str] = None
hf_hub_id: Optional[str] = None
hf_hub_filename: Optional[str] = None
source: str = '' # source of cfg / weight location used (url, file, hf-hub)
architecture: str = '' # architecture variant can be set when not implicit
source: Optional[str] = None # source of cfg / weight location used (url, file, hf-hub)
architecture: Optional[str] = None # architecture variant can be set when not implicit
custom_load: bool = False # use custom model specific model.load_pretrained() (ie for npz files)
# input / data config
@ -31,22 +32,40 @@ class PretrainedCfg:
# head config
num_classes: int = 1000
label_offset: int = 0
label_offset: Optional[int] = None
# model attributes that vary with above or required for pretrained adaptation
pool_size: Optional[Tuple[int, ...]] = None
test_pool_size: Optional[Tuple[int, ...]] = None
first_conv: str = ''
classifier: str = ''
first_conv: Optional[str] = None
classifier: Optional[str] = None
license: str = ''
source_url: str = ''
paper: str = ''
notes: str = ''
license: Optional[str] = None
source_url: Optional[str] = None
paper: Optional[str] = None
notes: Optional[str] = None
@property
def has_weights(self):
return self.url.startswith('http') or self.file or self.hf_hub_id
return self.url or self.file or self.hf_hub_id
def to_dict(self, remove_source=False, remove_null=True):
return filter_pretrained_cfg(
asdict(self),
remove_source=remove_source,
remove_null=remove_null
)
def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True):
filtered_cfg = {}
for k, v in cfg.items():
if remove_source and k in {'url', 'file', 'hf_hub_id', 'hf_hub_id', 'hf_hub_filename', 'source'}:
continue
if remove_null and v is None:
continue
filtered_cfg[k] = v
return filtered_cfg
@dataclass
@ -71,7 +90,7 @@ def split_model_name_tag(model_name: str, no_tag=''):
return model_name, tag
def generate_defaults(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):
def generate_default_cfgs(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):
out = defaultdict(DefaultCfg)
default_set = set() # no tag and tags ending with * are prioritized as default
@ -82,21 +101,22 @@ def generate_defaults(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):
model, tag = split_model_name_tag(k)
is_default_set = model in default_set
priority = not tag or (tag.endswith('*') and not is_default_set)
priority = (has_weights and not tag) or (tag.endswith('*') and not is_default_set)
tag = tag.strip('*')
default_cfg = out[model]
if has_weights:
default_cfg.is_pretrained = True
if priority:
default_cfg.tags.appendleft(tag)
default_set.add(model)
elif has_weights and not default_set:
elif has_weights and not default_cfg.is_pretrained:
default_cfg.tags.appendleft(tag)
else:
default_cfg.tags.append(tag)
if has_weights:
default_cfg.is_pretrained = True
default_cfg.cfgs[tag] = v
return out

@ -21,7 +21,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \
create_conv2d, get_act_layer, make_divisible, to_ntuple
from ._pretrained import generate_defaults
from ._pretrained import generate_default_cfgs
from .registry import register_model
@ -373,7 +373,7 @@ def _cfg(url='', **kwargs):
}
default_cfgs = generate_defaults({
default_cfgs = generate_default_cfgs({
# timm specific variants
'convnext_atto.timm_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',

@ -575,7 +575,7 @@ def build_model_with_cfg(
)
# FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
pretrained_cfg = dataclasses.asdict(pretrained_cfg)
pretrained_cfg = pretrained_cfg.to_dict()
_update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter)

@ -15,9 +15,11 @@ except ImportError:
from torch.hub import _get_torch_home as get_dir
from timm import __version__
from timm.models._pretrained import filter_pretrained_cfg
try:
from huggingface_hub import (create_repo, get_hf_file_metadata,
from huggingface_hub import (
create_repo, get_hf_file_metadata,
hf_hub_download, hf_hub_url,
repo_type_and_id_from_hf_id, upload_folder)
from huggingface_hub.utils import EntryNotFoundError
@ -46,6 +48,9 @@ def get_cache_dir(child_dir=''):
def download_cached_file(url, check_hash=True, progress=False):
if isinstance(url, (list, tuple)):
url, filename = url
else:
parts = urlparse(url)
filename = os.path.basename(parts.path)
cached_file = os.path.join(get_cache_dir(), filename)
@ -90,10 +95,27 @@ def _download_from_hf(model_id: str, filename: str):
def load_model_config_from_hf(model_id: str):
assert has_hf_hub(True)
cached_file = _download_from_hf(model_id, 'config.json')
pretrained_cfg = load_cfg_from_json(cached_file)
hf_config = load_cfg_from_json(cached_file)
if 'pretrained_cfg' not in hf_config:
# old form, pull pretrain_cfg out of the base dict
pretrained_cfg = hf_config
hf_config = {}
hf_config['architecture'] = pretrained_cfg.pop('architecture')
hf_config['num_features'] = pretrained_cfg.pop('num_features', None)
if 'labels' in pretrained_cfg:
hf_config['label_name'] = pretrained_cfg.pop('labels')
hf_config['pretrained_cfg'] = pretrained_cfg
# NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now
pretrained_cfg = hf_config['pretrained_cfg']
pretrained_cfg['hf_hub_id'] = model_id # insert hf_hub id for pretrained weight load during model creation
pretrained_cfg['source'] = 'hf-hub'
model_name = pretrained_cfg.get('architecture')
if 'num_classes' in hf_config:
# model should be created with parent num_classes if they exist
pretrained_cfg['num_classes'] = hf_config['num_classes']
model_name = hf_config['architecture']
return pretrained_cfg, model_name
@ -114,10 +136,34 @@ def save_for_hf(model, save_directory, model_config=None):
torch.save(model.state_dict(), weights_path)
config_path = save_directory / 'config.json'
hf_config = model.pretrained_cfg
hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes)
hf_config['num_features'] = model_config.pop('num_features', model.num_features)
hf_config['labels'] = model_config.pop('labels', [f"LABEL_{i}" for i in range(hf_config['num_classes'])])
hf_config = {}
pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
# set some values at root config level
hf_config['architecture'] = pretrained_cfg.pop('architecture')
hf_config['num_classes'] = model_config.get('num_classes', model.num_classes)
hf_config['num_features'] = model_config.get('num_features', model.num_features)
hf_config['global_pool'] = model_config.get('global_pool', getattr(model, 'global_pool', None))
if 'label' in model_config:
_logger.warning(
"'label' as a config field for timm models is deprecated. Please use 'label_name' and 'display_name'. "
"Using provided 'label' field as 'label_name'.")
model_config['label_name'] = model_config.pop('label')
label_name = model_config.pop('label_name', None)
if label_name:
assert isinstance(label_name, (dict, list, tuple))
# map label id (classifier index) -> unique label name (ie synset for ImageNet, MID for OpenImages)
# can be a dict id: name if there are id gaps, or tuple/list if no gaps.
hf_config['label_name'] = model_config['label_name']
display_name = model_config.pop('display_name', None)
if display_name:
assert isinstance(display_name, dict)
# map label_name -> user interface display name
hf_config['display_name'] = model_config['display_name']
hf_config['pretrained_cfg'] = pretrained_cfg
hf_config.update(model_config)
with config_path.open('w') as f:
@ -134,7 +180,7 @@ def push_to_hf_hub(
create_pr: bool = False,
model_config: Optional[dict] = None,
):
# Create repo if doesn't exist yet
# Create repo if it doesn't exist yet
repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
# Infer complete repo_id from repo_url
@ -154,10 +200,11 @@ def push_to_hf_hub(
# Save model weights and config.
save_for_hf(model, tmpdir, model_config=model_config)
# Add readme if does not exist
# Add readme if it does not exist
if not has_readme:
model_name = repo_id.split('/')[-1]
readme_path = Path(tmpdir) / "README.md"
readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_id}'
readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {model_name}'
readme_path.write_text(readme_text)
# Upload model and return

@ -54,7 +54,7 @@ from .layers import Mlp, ConvMlp, DropPath, ClassifierHead, trunc_normal_tf_, La
from .layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d
from .layers import SelectAdaptivePool2d, create_pool2d
from .layers import to_2tuple, extend_tuple, make_divisible, _assert
from ._pretrained import generate_defaults
from ._pretrained import generate_default_cfgs
from .registry import register_model
from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move these to common location
@ -1859,7 +1859,7 @@ def _cfg(url='', **kwargs):
}
default_cfgs = generate_defaults({
default_cfgs = generate_default_cfgs({
# Fiddling with configs / defaults / still pretraining
'coatnet_pico_rw_224': _cfg(url=''),
'coatnet_nano_rw_224': _cfg(
@ -1941,86 +1941,67 @@ default_cfgs = generate_defaults({
'maxxvit_rmlp_large_rw_224': _cfg(url=''),
# Trying to be like the MaxViT paper configs
# MaxViT models ported from official Tensorflow impl
'maxvit_tiny_tf_224.in1k': _cfg(
url='',
#file='maxvit_tiny_tf_224_in1k.pth',
hf_hub_id='timm/maxvit_tiny_tf_224.in1k',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'maxvit_tiny_tf_384.in1k': _cfg(
url='',
#file='maxvit_tiny_tf_384_in1k.pth',
hf_hub_id='timm/maxvit_tiny_tf_384.in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
'maxvit_tiny_tf_512.in1k': _cfg(
url='',
#file='maxvit_tiny_tf_512_in1k.pth',
hf_hub_id='timm/maxvit_tiny_tf_512.in1k',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
'maxvit_small_tf_224.in1k': _cfg(
url='',
#file='maxvit_small_tf_224_in1k.pth',
hf_hub_id='timm/maxvit_small_tf_224.in1k',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'maxvit_small_tf_384.in1k': _cfg(
url='',
#file='maxvit_small_tf_384_in1k.pth',
hf_hub_id='timm/maxvit_small_tf_384.in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
'maxvit_small_tf_512.in1k': _cfg(
url='',
#file='maxvit_small_tf_512_in1k.pth',
hf_hub_id='timm/maxvit_small_tf_512.in1k',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
'maxvit_base_tf_224.in1k': _cfg(
url='',
#file='maxvit_base_tf_224_in1k.pth',
hf_hub_id='timm/maxvit_base_tf_224.in1k',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'maxvit_base_tf_384.in1k': _cfg(
url='',
#file='maxvit_base_tf_384_in1k.pth',
hf_hub_id='timm/maxvit_base_tf_384.in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
'maxvit_base_tf_512.in1k': _cfg(
url='',
#file='maxvit_base_tf_512_in1k.pth',
hf_hub_id='timm/maxvit_base_tf_512.in1k',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
'maxvit_large_tf_224.in1k': _cfg(
url='',
#file='maxvit_large_tf_224_in1k.pth',
hf_hub_id='timm/maxvit_large_tf_224.in1k',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'maxvit_large_tf_384.in1k': _cfg(
url='',
#file='maxvit_large_tf_384_in1k.pth',
hf_hub_id='timm/maxvit_large_tf_384.in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
'maxvit_large_tf_512.in1k': _cfg(
url='',
#file='maxvit_large_tf_512_in1k.pth',
hf_hub_id='timm/maxvit_large_tf_512.in1k',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
'maxvit_base_tf_224.in21k': _cfg(
url='',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'maxvit_base_tf_384.in21k_ft1k': _cfg(
url='',
#file='maxvit_base_tf_384_in21k_ft_in1k.pth',
url=''),
'maxvit_base_tf_384.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_base_tf_384.in21k_ft_in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
'maxvit_base_tf_512.in21k_ft1k': _cfg(
url='',
#file='maxvit_base_tf_512_in21k_ft_in1k.pth',
'maxvit_base_tf_512.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_base_tf_512.in21k_ft_in1k',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
'maxvit_large_tf_224.in21k': _cfg(
url=''),
'maxvit_large_tf_384.in21k_ft1k': _cfg(
url='',
#file='maxvit_large_tf_384_in21k_ft_in1k.pth',
'maxvit_large_tf_384.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_large_tf_384.in21k_ft_in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
'maxvit_large_tf_512.in21k_ft1k': _cfg(
url='',
#file='maxvit_large_tf_512_in21k_ft_in1k.pth',
'maxvit_large_tf_512.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_large_tf_512.in21k_ft_in1k',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
'maxvit_xlarge_tf_224.in21k': _cfg(
url=''),
'maxvit_xlarge_tf_384.in21k_ft1k': _cfg(
url='',
#file='maxvit_xlarge_tf_384_in21k_ft_in1k.pth',
'maxvit_xlarge_tf_384.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_xlarge_tf_384.in21k_ft_in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
'maxvit_xlarge_tf_512.in21k_ft1k': _cfg(
url='',
#file='maxvit_xlarge_tf_512_in21k_ft_in1k.pth',
'maxvit_xlarge_tf_512.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_xlarge_tf_512.in21k_ft_in1k',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
})

@ -7,7 +7,7 @@ import re
import sys
from collections import defaultdict, deque
from copy import deepcopy
from typing import Optional, Tuple
from typing import List, Optional, Union, Tuple
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
@ -84,7 +84,7 @@ def _natural_key(string_):
def list_models(
filter: str = '',
filter: Union[str, List[str]] = '',
module: str = '',
pretrained=False,
exclude_filters: str = '',
@ -114,7 +114,12 @@ def list_models(
else:
all_models = _model_entrypoints.keys()
# FIXME wildcard filter tag as well as model arch name
if include_tags:
# expand model names to include names w/ pretrained tags
models_with_tags = []
for m in all_models:
models_with_tags.extend(_model_with_tags[m])
all_models = models_with_tags
if filter:
models = []
@ -134,13 +139,6 @@ def list_models(
if len(exclude_models):
models = set(models).difference(exclude_models)
if include_tags:
# expand model names to include names w/ pretrained tags
models_with_tags = []
for m in models:
models_with_tags.extend(_model_with_tags[m])
models = models_with_tags
if pretrained:
models = _model_has_pretrained.intersection(models)
@ -150,6 +148,18 @@ def list_models(
return list(sorted(models, key=_natural_key))
def list_pretrained(
filter: Union[str, List[str]] = '',
exclude_filters: str = '',
):
return list_models(
filter=filter,
pretrained=True,
exclude_filters=exclude_filters,
include_tags=True,
)
def is_model(model_name):
""" Check if a model name exists
"""

@ -34,7 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from .helpers import build_model_with_cfg, named_apply, adapt_input_conv, checkpoint_seq
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
from ._pretrained import generate_defaults
from ._pretrained import generate_default_cfgs
from .registry import register_model
_logger = logging.getLogger(__name__)
@ -492,6 +492,7 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
model.patch_embed.proj.weight.copy_(embed_conv_w)
model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
if model.cls_token is not None:
model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
if pos_embed_w.shape != model.pos_embed.shape:
@ -630,51 +631,74 @@ def _cfg(url='', **kwargs):
}
default_cfgs = generate_defaults({
# patch models (weights from official Google JAX impl)
'vit_tiny_patch16_224.augreg_in21k_ft_1k': _cfg(
default_cfgs = generate_default_cfgs({
# How to train your ViT (augreg) weights, pretrained on 21k FT on in1k
'vit_tiny_patch16_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
custom_load=True),
'vit_tiny_patch16_384.augreg_in21k_ft_1k': _cfg(
'vit_tiny_patch16_384.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
'vit_small_patch32_224.augreg_in21k_ft_1k': _cfg(
'vit_small_patch32_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
custom_load=True),
'vit_small_patch32_384.augreg_in21k_ft_1k': _cfg(
'vit_small_patch32_384.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
'vit_small_patch16_224.augreg_in21k_ft_1k': _cfg(
'vit_small_patch16_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
custom_load=True),
'vit_small_patch16_384.augreg_in21k_ft_1k': _cfg(
'vit_small_patch16_384.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_patch32_224.augreg_in21k_ft_1k': _cfg(
'vit_base_patch32_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
custom_load=True),
'vit_base_patch32_384.augreg_in21k_ft_1k': _cfg(
'vit_base_patch32_384.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_patch16_224.augreg_in21k_ft_1k': _cfg(
'vit_base_patch16_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
custom_load=True),
'vit_base_patch16_384.augreg_in21k_ft_1k': _cfg(
'vit_base_patch16_384.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_patch8_224.augreg_in21k_ft_1k': _cfg(
'vit_base_patch8_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
custom_load=True),
'vit_large_patch32_384.v1_in21k_ft_1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_large_patch16_224.augreg_in21k_ft_1k': _cfg(
'vit_large_patch16_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
custom_load=True),
'vit_large_patch16_384.augreg_in21k_ft_1k': _cfg(
'vit_large_patch16_384.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
# re-finetuned augreg 21k FT on in1k weights
'vit_base_patch16_224.augreg2_in21k_ft_in1k': _cfg(
file='b16_augreg-a-8.pth'),
'vit_base_patch16_384.augreg2_in21k_ft_in1k': _cfg(
url=''),
'vit_base_patch8_224.augreg2_in21k_ft_in1k': _cfg(
url=''),
# patch models (weights from official Google JAX impl) pretrained on in21k FT on in1k
'vit_base_patch16_224.orig_in21k_ft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth'),
'vit_base_patch16_384.orig_in21k_ft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth'),
'vit_large_patch32_384.orig_in21k_ft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
input_size=(3, 384, 384), crop_pct=1.0),
# How to train your ViT (augreg) weights trained on in1k
'vit_base_patch16_224.augreg_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
custom_load=True),
'vit_base_patch16_384.augreg_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
'vit_large_patch14_224.untrained': _cfg(url=''),
'vit_huge_patch14_224.untrained': _cfg(url=''),
'vit_giant_patch14_224.untrained': _cfg(url=''),
@ -682,6 +706,15 @@ default_cfgs = generate_defaults({
# patch models, imagenet21k (weights from official Google JAX impl)
'vit_large_patch32_224.v1_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
num_classes=21843),
'vit_huge_patch14_224.v1_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
hf_hub_id='timm/vit_huge_patch14_224_in21k',
custom_load=True, num_classes=21843),
# How to train your ViT (augreg) weights, pretrained on in21k
'vit_tiny_patch16_224.augreg_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
custom_load=True, num_classes=21843),
@ -700,16 +733,9 @@ default_cfgs = generate_defaults({
'vit_base_patch8_224.augreg_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
custom_load=True, num_classes=21843),
'vit_large_patch32_224.v1_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
num_classes=21843),
'vit_large_patch16_224.augreg_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
custom_load=True, num_classes=21843),
'vit_huge_patch14_224.v1_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
hf_hub_id='timm/vit_huge_patch14_224_in21k',
custom_load=True, num_classes=21843),
# SAM trained models (https://arxiv.org/abs/2106.01548)
'vit_base_patch32_224.sam': _cfg(
@ -736,7 +762,7 @@ default_cfgs = generate_defaults({
'vit_base_patch16_224_miil.in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth',
mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221),
'vit_base_patch16_224_miil.in21k_ft_1k': _cfg(
'vit_base_patch16_224_miil.in21k_ft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_1k_miil_84_4-2deb18e3.pth',
mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear'),
@ -744,14 +770,15 @@ default_cfgs = generate_defaults({
'vit_base_patch16_rpn_224.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth'),
'vit_medium_patch16_gap_240.in12k': _cfg(
url='',
hf_hub_id='timm/vit_medium_patch16_gap_240.in12k',
input_size=(3, 240, 240), crop_pct=0.95, num_classes=11821),
'vit_medium_patch16_gap_256.in12k_ft_1k': _cfg(
url='',
'vit_medium_patch16_gap_256.in12k_ft_in1k': _cfg(
hf_hub_id='timm/vit_medium_patch16_gap_256.in12k_ft_in1k',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_medium_patch16_gap_384.in12k_ft_1k': _cfg(
url='',
input_size=(3, 384, 384), crop_pct=0.95),
'vit_medium_patch16_gap_384.in12k_ft_in1k': _cfg(
hf_hub_id='timm/vit_medium_patch16_gap_384.in12k_ft_in1k',
input_size=(3, 384, 384), crop_pct=0.95, crop_mode='squash'),
'vit_base_patch16_gap_224': _cfg(),
# CLIP pretrained image tower and related fine-tuned weights
'vit_base_patch32_clip_224.laion2b': _cfg(
@ -781,15 +808,16 @@ default_cfgs = generate_defaults({
'vit_base_patch32_clip_384.laion2b_ft_in1k': _cfg(
hf_hub_id='timm/vit_base_patch32_clip_384.laion2b_ft_in1k',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)),
'vit_base_patch32_clip_448.laion2b_ft_in1k': _cfg(
hf_hub_id='timm/vit_base_patch32_clip_448.laion2b_ft_in1k',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448)),
'vit_base_patch16_clip_224.laion2b_ft_in1k': _cfg(
hf_hub_id='timm/vit_base_patch16_clip_224.laion2b_ft_in1k',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
'vit_base_patch16_clip_384.laion2b_ft_in1k': _cfg(
hf_hub_id='timm/vit_base_patch16_clip_384.laion2b_ft_in1k',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)),
'vit_base_patch32_clip_448.laion2b_ft_in1k': _cfg(
hf_hub_id='timm/vit_base_patch32_clip_448.laion2b_ft_in1k',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448)),
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
'vit_large_patch14_clip_224.laion2b_ft_in1k': _cfg(
hf_hub_id='timm/vit_large_patch14_clip_224.laion2b_ft_in1k',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0),
@ -816,10 +844,11 @@ default_cfgs = generate_defaults({
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448)),
'vit_base_patch16_clip_224.laion2b_ft_in12k_in1k': _cfg(
hf_hub_id='timm/vit_base_patch16_clip_224.laion2b_ft_in12k_in1k',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95),
'vit_base_patch16_clip_384.laion2b_ft_in12k_in1k': _cfg(
hf_hub_id='timm/vit_base_patch16_clip_384.laion2b_ft_in12k_in1k',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)),
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
'vit_large_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg(
hf_hub_id='timm/vit_large_patch14_clip_224.laion2b_ft_in12k_in1k',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0),
@ -866,7 +895,8 @@ default_cfgs = generate_defaults({
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
'vit_base_patch16_clip_384.openai_ft_in1k': _cfg(
hf_hub_id='timm/vit_base_patch16_clip_384.openai_ft_in1k',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)),
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
'vit_large_patch14_clip_224.openai_ft_in1k': _cfg(
hf_hub_id='timm/vit_large_patch14_clip_224.openai_ft_in1k',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
@ -876,10 +906,15 @@ default_cfgs = generate_defaults({
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
'vit_base_patch32_clip_384.openai_ft_in12k_in1k': _cfg(
hf_hub_id='timm/vit_base_patch32_clip_384.openai_ft_in12k_in1k',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)),
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
'vit_base_patch16_clip_224.openai_ft_in12k_in1k': _cfg(
#hf_hub_id='timm/vit_base_patch16_clip_224.openai_ft_in12k_in1k',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
hf_hub_id='timm/vit_base_patch16_clip_224.openai_ft_in12k_in1k',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95),
'vit_base_patch16_clip_384.openai_ft_in12k_in1k': _cfg(
hf_hub_id='timm/vit_base_patch16_clip_384.openai_ft_in12k_in1k',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
'vit_large_patch14_clip_224.openai_ft_in12k_in1k': _cfg(
hf_hub_id='timm/vit_large_patch14_clip_224.openai_ft_in12k_in1k',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
@ -1118,37 +1153,48 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs):
@register_model
def vit_medium_patch16_gap_240(pretrained=False, **kwargs):
""" ViT-Base (ViT-M/16) w/o class token, w/ avg-pool @ 240x240
""" ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 240x240
"""
model_kwargs = dict(
patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
model = _create_vision_transformer('vit_medium_patch16_gap_240', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_medium_patch16_gap_256(pretrained=False, **kwargs):
""" ViT-Base (ViT-M/16) w/o class token, w/ avg-pool @ 256x256
""" ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 256x256
"""
model_kwargs = dict(
patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
model = _create_vision_transformer('vit_medium_patch16_gap_256', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_medium_patch16_gap_384(pretrained=False, **kwargs):
""" ViT-Base (ViT-M/16) w/o class token, w/ avg-pool @ 384x384
""" ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 384x384
"""
model_kwargs = dict(
patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
model = _create_vision_transformer('vit_medium_patch16_gap_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch16_gap_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 256x256
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False,
global_pool=kwargs.get('global_pool', 'avg'), fc_norm=False, **kwargs)
model = _create_vision_transformer('vit_base_patch16_gap_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch32_clip_224(pretrained=False, **kwargs):
""" ViT-B/32 CLIP image tower @ 224x224

@ -20,7 +20,7 @@ import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from ._pretrained import generate_defaults
from ._pretrained import generate_default_cfgs
from .layers import StdConv2dSame, StdConv2d, to_2tuple
from .resnet import resnet26d, resnet50d
from .resnetv2 import ResNetV2, create_resnetv2_stem
@ -39,31 +39,31 @@ def _cfg(url='', **kwargs):
}
default_cfgs = generate_defaults({
default_cfgs = generate_default_cfgs({
# hybrid in-1k models (weights from official JAX impl where they exist)
'vit_tiny_r_s16_p8_224.augreg_in21k_ft_1k': _cfg(
'vit_tiny_r_s16_p8_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
custom_load=True,
first_conv='patch_embed.backbone.conv'),
'vit_tiny_r_s16_p8_384.augreg_in21k_ft_1k': _cfg(
'vit_tiny_r_s16_p8_384.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0, custom_load=True),
'vit_small_r26_s32_224.augreg_in21k_ft_1k': _cfg(
'vit_small_r26_s32_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz',
custom_load=True,
),
'vit_small_r26_s32_384.augreg_in21k_ft_1k': _cfg(
'vit_small_r26_s32_384.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0, custom_load=True),
'vit_base_r26_s32_224.untrained': _cfg(),
'vit_base_r50_s16_384.v1_in21k_ft_1k': _cfg(
'vit_base_r50_s16_384.v1_in21k_ft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_large_r50_s32_224.augreg_in21k_ft_1k': _cfg(
'vit_large_r50_s32_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
custom_load=True,
),
'vit_large_r50_s32_384.augreg_in21k_ft_1k': _cfg(
'vit_large_r50_s32_384.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0, custom_load=True,
),

Loading…
Cancel
Save