From 23b357f1dffc2bb4fe20b8f697d81d56a5b9e1cf Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 2 Dec 2022 16:54:18 -0800 Subject: [PATCH] Add ported Tensorflow MaxVit weights. Add a few more CLIP ViT fine-tunes. Tweak some model tag names. Improve model tag name sorting. Update HF hub push config layout. --- timm/__init__.py | 2 +- timm/models/__init__.py | 5 +- timm/models/_pretrained.py | 60 ++++++---- timm/models/convnext.py | 4 +- timm/models/helpers.py | 2 +- timm/models/hub.py | 77 +++++++++--- timm/models/maxxvit.py | 75 +++++------- timm/models/registry.py | 30 +++-- timm/models/vision_transformer.py | 146 +++++++++++++++-------- timm/models/vision_transformer_hybrid.py | 18 +-- 10 files changed, 262 insertions(+), 157 deletions(-) diff --git a/timm/__init__.py b/timm/__init__.py index b8053a2b..faf34dbc 100644 --- a/timm/__init__.py +++ b/timm/__init__.py @@ -1,4 +1,4 @@ from .version import __version__ -from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \ +from .models import create_model, list_models, list_pretrained, is_model, list_modules, model_entrypoint, \ is_scriptable, is_exportable, set_scriptable, set_exportable, \ is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 3d36ce07..e3449103 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -70,5 +70,6 @@ from .layers import TestTimePoolHead, apply_test_time_pool from .layers import convert_splitbn_model, convert_sync_batchnorm from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit from .layers import set_fast_norm -from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\ - is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value +from ._pretrained import PretrainedCfg, filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag +from .registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, list_modules,\ + is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value diff --git a/timm/models/_pretrained.py b/timm/models/_pretrained.py index 61fb718c..60f38fd4 100644 --- a/timm/models/_pretrained.py +++ b/timm/models/_pretrained.py @@ -1,5 +1,6 @@ +import copy from collections import deque, defaultdict -from dataclasses import dataclass, field, replace +from dataclasses import dataclass, field, replace, asdict from typing import Any, Deque, Dict, Tuple, Optional, Union @@ -8,13 +9,13 @@ class PretrainedCfg: """ """ # weight locations - url: str = '' - file: str = '' - hf_hub_id: str = '' - hf_hub_filename: str = '' + url: Optional[Union[str, Tuple[str, str]]] = None + file: Optional[str] = None + hf_hub_id: Optional[str] = None + hf_hub_filename: Optional[str] = None - source: str = '' # source of cfg / weight location used (url, file, hf-hub) - architecture: str = '' # architecture variant can be set when not implicit + source: Optional[str] = None # source of cfg / weight location used (url, file, hf-hub) + architecture: Optional[str] = None # architecture variant can be set when not implicit custom_load: bool = False # use custom model specific model.load_pretrained() (ie for npz files) # input / data config @@ -31,22 +32,40 @@ class PretrainedCfg: # head config num_classes: int = 1000 - label_offset: int = 0 + label_offset: Optional[int] = None # model attributes that vary with above or required for 
pretrained adaptation pool_size: Optional[Tuple[int, ...]] = None test_pool_size: Optional[Tuple[int, ...]] = None - first_conv: str = '' - classifier: str = '' + first_conv: Optional[str] = None + classifier: Optional[str] = None - license: str = '' - source_url: str = '' - paper: str = '' - notes: str = '' + license: Optional[str] = None + source_url: Optional[str] = None + paper: Optional[str] = None + notes: Optional[str] = None @property def has_weights(self): - return self.url.startswith('http') or self.file or self.hf_hub_id + return self.url or self.file or self.hf_hub_id + + def to_dict(self, remove_source=False, remove_null=True): + return filter_pretrained_cfg( + asdict(self), + remove_source=remove_source, + remove_null=remove_null + ) + + +def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True): + filtered_cfg = {} + for k, v in cfg.items(): + if remove_source and k in {'url', 'file', 'hf_hub_id', 'hf_hub_id', 'hf_hub_filename', 'source'}: + continue + if remove_null and v is None: + continue + filtered_cfg[k] = v + return filtered_cfg @dataclass @@ -71,7 +90,7 @@ def split_model_name_tag(model_name: str, no_tag=''): return model_name, tag -def generate_defaults(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]): +def generate_default_cfgs(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]): out = defaultdict(DefaultCfg) default_set = set() # no tag and tags ending with * are prioritized as default @@ -82,21 +101,22 @@ def generate_defaults(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]): model, tag = split_model_name_tag(k) is_default_set = model in default_set - priority = not tag or (tag.endswith('*') and not is_default_set) + priority = (has_weights and not tag) or (tag.endswith('*') and not is_default_set) tag = tag.strip('*') default_cfg = out[model] - if has_weights: - default_cfg.is_pretrained = True if priority: default_cfg.tags.appendleft(tag) default_set.add(model) - elif has_weights and not default_set: + elif has_weights and not default_cfg.is_pretrained: default_cfg.tags.appendleft(tag) else: default_cfg.tags.append(tag) + if has_weights: + default_cfg.is_pretrained = True + default_cfg.cfgs[tag] = v return out diff --git a/timm/models/convnext.py b/timm/models/convnext.py index e64bd0ef..4c89ebf2 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -21,7 +21,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import named_apply, build_model_with_cfg, checkpoint_seq from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \ create_conv2d, get_act_layer, make_divisible, to_ntuple -from ._pretrained import generate_defaults +from ._pretrained import generate_default_cfgs from .registry import register_model @@ -373,7 +373,7 @@ def _cfg(url='', **kwargs): } -default_cfgs = generate_defaults({ +default_cfgs = generate_default_cfgs({ # timm specific variants 'convnext_atto.timm_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth', diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 93f81030..60ff2b0a 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -575,7 +575,7 @@ def build_model_with_cfg( ) # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model - pretrained_cfg = dataclasses.asdict(pretrained_cfg) + pretrained_cfg = pretrained_cfg.to_dict() _update_default_kwargs(pretrained_cfg, 
kwargs, kwargs_filter) diff --git a/timm/models/hub.py b/timm/models/hub.py index 4cf3e4e7..2a87ae7e 100644 --- a/timm/models/hub.py +++ b/timm/models/hub.py @@ -15,11 +15,13 @@ except ImportError: from torch.hub import _get_torch_home as get_dir from timm import __version__ +from timm.models._pretrained import filter_pretrained_cfg try: - from huggingface_hub import (create_repo, get_hf_file_metadata, - hf_hub_download, hf_hub_url, - repo_type_and_id_from_hf_id, upload_folder) + from huggingface_hub import ( + create_repo, get_hf_file_metadata, + hf_hub_download, hf_hub_url, + repo_type_and_id_from_hf_id, upload_folder) from huggingface_hub.utils import EntryNotFoundError hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__) _has_hf_hub = True @@ -46,8 +48,11 @@ def get_cache_dir(child_dir=''): def download_cached_file(url, check_hash=True, progress=False): - parts = urlparse(url) - filename = os.path.basename(parts.path) + if isinstance(url, (list, tuple)): + url, filename = url + else: + parts = urlparse(url) + filename = os.path.basename(parts.path) cached_file = os.path.join(get_cache_dir(), filename) if not os.path.exists(cached_file): _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file)) @@ -90,10 +95,27 @@ def _download_from_hf(model_id: str, filename: str): def load_model_config_from_hf(model_id: str): assert has_hf_hub(True) cached_file = _download_from_hf(model_id, 'config.json') - pretrained_cfg = load_cfg_from_json(cached_file) + + hf_config = load_cfg_from_json(cached_file) + if 'pretrained_cfg' not in hf_config: + # old form, pull pretrain_cfg out of the base dict + pretrained_cfg = hf_config + hf_config = {} + hf_config['architecture'] = pretrained_cfg.pop('architecture') + hf_config['num_features'] = pretrained_cfg.pop('num_features', None) + if 'labels' in pretrained_cfg: + hf_config['label_name'] = pretrained_cfg.pop('labels') + hf_config['pretrained_cfg'] = pretrained_cfg + + # NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now + pretrained_cfg = hf_config['pretrained_cfg'] pretrained_cfg['hf_hub_id'] = model_id # insert hf_hub id for pretrained weight load during model creation pretrained_cfg['source'] = 'hf-hub' - model_name = pretrained_cfg.get('architecture') + if 'num_classes' in hf_config: + # model should be created with parent num_classes if they exist + pretrained_cfg['num_classes'] = hf_config['num_classes'] + model_name = hf_config['architecture'] + return pretrained_cfg, model_name @@ -114,10 +136,34 @@ def save_for_hf(model, save_directory, model_config=None): torch.save(model.state_dict(), weights_path) config_path = save_directory / 'config.json' - hf_config = model.pretrained_cfg - hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes) - hf_config['num_features'] = model_config.pop('num_features', model.num_features) - hf_config['labels'] = model_config.pop('labels', [f"LABEL_{i}" for i in range(hf_config['num_classes'])]) + hf_config = {} + pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True) + # set some values at root config level + hf_config['architecture'] = pretrained_cfg.pop('architecture') + hf_config['num_classes'] = model_config.get('num_classes', model.num_classes) + hf_config['num_features'] = model_config.get('num_features', model.num_features) + hf_config['global_pool'] = model_config.get('global_pool', getattr(model, 'global_pool', None)) + + if 'label' in model_config: + 
_logger.warning( + "'label' as a config field for timm models is deprecated. Please use 'label_name' and 'display_name'. " + "Using provided 'label' field as 'label_name'.") + model_config['label_name'] = model_config.pop('label') + + label_name = model_config.pop('label_name', None) + if label_name: + assert isinstance(label_name, (dict, list, tuple)) + # map label id (classifier index) -> unique label name (ie synset for ImageNet, MID for OpenImages) + # can be a dict id: name if there are id gaps, or tuple/list if no gaps. + hf_config['label_name'] = model_config['label_name'] + + display_name = model_config.pop('display_name', None) + if display_name: + assert isinstance(display_name, dict) + # map label_name -> user interface display name + hf_config['display_name'] = model_config['display_name'] + + hf_config['pretrained_cfg'] = pretrained_cfg hf_config.update(model_config) with config_path.open('w') as f: @@ -127,14 +173,14 @@ def save_for_hf(model, save_directory, model_config=None): def push_to_hf_hub( model, repo_id: str, - commit_message: str ='Add model', + commit_message: str = 'Add model', token: Optional[str] = None, revision: Optional[str] = None, private: bool = False, create_pr: bool = False, model_config: Optional[dict] = None, ): - # Create repo if doesn't exist yet + # Create repo if it doesn't exist yet repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) # Infer complete repo_id from repo_url @@ -154,10 +200,11 @@ def push_to_hf_hub( # Save model weights and config. save_for_hf(model, tmpdir, model_config=model_config) - # Add readme if does not exist + # Add readme if it does not exist if not has_readme: + model_name = repo_id.split('/')[-1] readme_path = Path(tmpdir) / "README.md" - readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_id}' + readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {model_name}' readme_path.write_text(readme_text) # Upload model and return diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index 13fd7abf..c03c7d0f 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -54,7 +54,7 @@ from .layers import Mlp, ConvMlp, DropPath, ClassifierHead, trunc_normal_tf_, La from .layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d from .layers import SelectAdaptivePool2d, create_pool2d from .layers import to_2tuple, extend_tuple, make_divisible, _assert -from ._pretrained import generate_defaults +from ._pretrained import generate_default_cfgs from .registry import register_model from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move these to common location @@ -1859,7 +1859,7 @@ def _cfg(url='', **kwargs): } -default_cfgs = generate_defaults({ +default_cfgs = generate_default_cfgs({ # Fiddling with configs / defaults / still pretraining 'coatnet_pico_rw_224': _cfg(url=''), 'coatnet_nano_rw_224': _cfg( @@ -1941,86 +1941,67 @@ default_cfgs = generate_defaults({ 'maxxvit_rmlp_large_rw_224': _cfg(url=''), - # Trying to be like the MaxViT paper configs + # MaxViT models ported from official Tensorflow impl 'maxvit_tiny_tf_224.in1k': _cfg( - url='', - #file='maxvit_tiny_tf_224_in1k.pth', + hf_hub_id='timm/maxvit_tiny_tf_224.in1k', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 'maxvit_tiny_tf_384.in1k': _cfg( - url='', - #file='maxvit_tiny_tf_384_in1k.pth', + hf_hub_id='timm/maxvit_tiny_tf_384.in1k', input_size=(3, 384, 384), crop_pct=1.0, 
crop_mode='squash'), 'maxvit_tiny_tf_512.in1k': _cfg( - url='', - #file='maxvit_tiny_tf_512_in1k.pth', + hf_hub_id='timm/maxvit_tiny_tf_512.in1k', input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), 'maxvit_small_tf_224.in1k': _cfg( - url='', - #file='maxvit_small_tf_224_in1k.pth', + hf_hub_id='timm/maxvit_small_tf_224.in1k', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 'maxvit_small_tf_384.in1k': _cfg( - url='', - #file='maxvit_small_tf_384_in1k.pth', + hf_hub_id='timm/maxvit_small_tf_384.in1k', input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), 'maxvit_small_tf_512.in1k': _cfg( - url='', - #file='maxvit_small_tf_512_in1k.pth', + hf_hub_id='timm/maxvit_small_tf_512.in1k', input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), 'maxvit_base_tf_224.in1k': _cfg( - url='', - #file='maxvit_base_tf_224_in1k.pth', + hf_hub_id='timm/maxvit_base_tf_224.in1k', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 'maxvit_base_tf_384.in1k': _cfg( - url='', - #file='maxvit_base_tf_384_in1k.pth', + hf_hub_id='timm/maxvit_base_tf_384.in1k', input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), 'maxvit_base_tf_512.in1k': _cfg( - url='', - #file='maxvit_base_tf_512_in1k.pth', + hf_hub_id='timm/maxvit_base_tf_512.in1k', input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), 'maxvit_large_tf_224.in1k': _cfg( - url='', - #file='maxvit_large_tf_224_in1k.pth', + hf_hub_id='timm/maxvit_large_tf_224.in1k', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 'maxvit_large_tf_384.in1k': _cfg( - url='', - #file='maxvit_large_tf_384_in1k.pth', + hf_hub_id='timm/maxvit_large_tf_384.in1k', input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), 'maxvit_large_tf_512.in1k': _cfg( - url='', - #file='maxvit_large_tf_512_in1k.pth', + hf_hub_id='timm/maxvit_large_tf_512.in1k', input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), 'maxvit_base_tf_224.in21k': _cfg( - url='', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), - 'maxvit_base_tf_384.in21k_ft1k': _cfg( - url='', - #file='maxvit_base_tf_384_in21k_ft_in1k.pth', + url=''), + 'maxvit_base_tf_384.in21k_ft_in1k': _cfg( + hf_hub_id='timm/maxvit_base_tf_384.in21k_ft_in1k', input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), - 'maxvit_base_tf_512.in21k_ft1k': _cfg( - url='', - #file='maxvit_base_tf_512_in21k_ft_in1k.pth', + 'maxvit_base_tf_512.in21k_ft_in1k': _cfg( + hf_hub_id='timm/maxvit_base_tf_512.in21k_ft_in1k', input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), 'maxvit_large_tf_224.in21k': _cfg( url=''), - 'maxvit_large_tf_384.in21k_ft1k': _cfg( - url='', - #file='maxvit_large_tf_384_in21k_ft_in1k.pth', + 'maxvit_large_tf_384.in21k_ft_in1k': _cfg( + hf_hub_id='timm/maxvit_large_tf_384.in21k_ft_in1k', input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), - 'maxvit_large_tf_512.in21k_ft1k': _cfg( - url='', - #file='maxvit_large_tf_512_in21k_ft_in1k.pth', + 'maxvit_large_tf_512.in21k_ft_in1k': _cfg( + hf_hub_id='timm/maxvit_large_tf_512.in21k_ft_in1k', input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), 'maxvit_xlarge_tf_224.in21k': _cfg( url=''), - 'maxvit_xlarge_tf_384.in21k_ft1k': _cfg( - url='', - #file='maxvit_xlarge_tf_384_in21k_ft_in1k.pth', + 'maxvit_xlarge_tf_384.in21k_ft_in1k': _cfg( + hf_hub_id='timm/maxvit_xlarge_tf_384.in21k_ft_in1k', input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), - 'maxvit_xlarge_tf_512.in21k_ft1k': _cfg( - url='', - #file='maxvit_xlarge_tf_512_in21k_ft_in1k.pth', + 'maxvit_xlarge_tf_512.in21k_ft_in1k': _cfg( + 
hf_hub_id='timm/maxvit_xlarge_tf_512.in21k_ft_in1k', input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), }) diff --git a/timm/models/registry.py b/timm/models/registry.py index 9fa6f007..00857a03 100644 --- a/timm/models/registry.py +++ b/timm/models/registry.py @@ -7,7 +7,7 @@ import re import sys from collections import defaultdict, deque from copy import deepcopy -from typing import Optional, Tuple +from typing import List, Optional, Union, Tuple from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag @@ -84,7 +84,7 @@ def _natural_key(string_): def list_models( - filter: str = '', + filter: Union[str, List[str]] = '', module: str = '', pretrained=False, exclude_filters: str = '', @@ -114,7 +114,12 @@ def list_models( else: all_models = _model_entrypoints.keys() - # FIXME wildcard filter tag as well as model arch name + if include_tags: + # expand model names to include names w/ pretrained tags + models_with_tags = [] + for m in all_models: + models_with_tags.extend(_model_with_tags[m]) + all_models = models_with_tags if filter: models = [] @@ -134,13 +139,6 @@ def list_models( if len(exclude_models): models = set(models).difference(exclude_models) - if include_tags: - # expand model names to include names w/ pretrained tags - models_with_tags = [] - for m in models: - models_with_tags.extend(_model_with_tags[m]) - models = models_with_tags - if pretrained: models = _model_has_pretrained.intersection(models) @@ -150,6 +148,18 @@ def list_models( return list(sorted(models, key=_natural_key)) +def list_pretrained( + filter: Union[str, List[str]] = '', + exclude_filters: str = '', +): + return list_models( + filter=filter, + pretrained=True, + exclude_filters=exclude_filters, + include_tags=True, + ) + + def is_model(model_name): """ Check if a model name exists """ diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index cde0018b..3e667eef 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -34,7 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from .helpers import build_model_with_cfg, named_apply, adapt_input_conv, checkpoint_seq from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ -from ._pretrained import generate_defaults +from ._pretrained import generate_default_cfgs from .registry import register_model _logger = logging.getLogger(__name__) @@ -492,7 +492,8 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) model.patch_embed.proj.weight.copy_(embed_conv_w) model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) - model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) + if model.cls_token is not None: + model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) if pos_embed_w.shape != model.pos_embed.shape: pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights @@ -630,51 +631,74 @@ def _cfg(url='', **kwargs): } -default_cfgs = generate_defaults({ - # patch models (weights from official Google JAX impl) - 'vit_tiny_patch16_224.augreg_in21k_ft_1k': _cfg( +default_cfgs = generate_default_cfgs({ + + # How to train your ViT (augreg) weights, pretrained on 21k FT on in1k + 'vit_tiny_patch16_224.augreg_in21k_ft_in1k': _cfg( 
url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', custom_load=True), - 'vit_tiny_patch16_384.augreg_in21k_ft_1k': _cfg( + 'vit_tiny_patch16_384.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), - 'vit_small_patch32_224.augreg_in21k_ft_1k': _cfg( + 'vit_small_patch32_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', custom_load=True), - 'vit_small_patch32_384.augreg_in21k_ft_1k': _cfg( + 'vit_small_patch32_384.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), - 'vit_small_patch16_224.augreg_in21k_ft_1k': _cfg( + 'vit_small_patch16_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', custom_load=True), - 'vit_small_patch16_384.augreg_in21k_ft_1k': _cfg( + 'vit_small_patch16_384.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), - 'vit_base_patch32_224.augreg_in21k_ft_1k': _cfg( + 'vit_base_patch32_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', custom_load=True), - 'vit_base_patch32_384.augreg_in21k_ft_1k': _cfg( + 'vit_base_patch32_384.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), - 'vit_base_patch16_224.augreg_in21k_ft_1k': _cfg( + 'vit_base_patch16_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz', custom_load=True), - 'vit_base_patch16_384.augreg_in21k_ft_1k': _cfg( + 'vit_base_patch16_384.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), - 'vit_base_patch8_224.augreg_in21k_ft_1k': _cfg( + 'vit_base_patch8_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz', custom_load=True), - 'vit_large_patch32_384.v1_in21k_ft_1k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', - input_size=(3, 384, 384), crop_pct=1.0), - 'vit_large_patch16_224.augreg_in21k_ft_1k': _cfg( + 'vit_large_patch16_224.augreg_in21k_ft_in1k': _cfg( 
url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz', custom_load=True), - 'vit_large_patch16_384.augreg_in21k_ft_1k': _cfg( + 'vit_large_patch16_384.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + # re-finetuned augreg 21k FT on in1k weights + 'vit_base_patch16_224.augreg2_in21k_ft_in1k': _cfg( + file='b16_augreg-a-8.pth'), + 'vit_base_patch16_384.augreg2_in21k_ft_in1k': _cfg( + url=''), + 'vit_base_patch8_224.augreg2_in21k_ft_in1k': _cfg( + url=''), + + # patch models (weights from official Google JAX impl) pretrained on in21k FT on in1k + 'vit_base_patch16_224.orig_in21k_ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth'), + 'vit_base_patch16_384.orig_in21k_ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth'), + 'vit_large_patch32_384.orig_in21k_ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', + input_size=(3, 384, 384), crop_pct=1.0), + + # How to train your ViT (augreg) weights trained on in1k + 'vit_base_patch16_224.augreg_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz', + custom_load=True), + 'vit_base_patch16_384.augreg_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_large_patch14_224.untrained': _cfg(url=''), 'vit_huge_patch14_224.untrained': _cfg(url=''), 'vit_giant_patch14_224.untrained': _cfg(url=''), @@ -682,6 +706,15 @@ default_cfgs = generate_defaults({ # patch models, imagenet21k (weights from official Google JAX impl) + 'vit_large_patch32_224.v1_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', + num_classes=21843), + 'vit_huge_patch14_224.v1_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', + hf_hub_id='timm/vit_huge_patch14_224_in21k', + custom_load=True, num_classes=21843), + + # How to train your ViT (augreg) weights, pretrained on in21k 'vit_tiny_patch16_224.augreg_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', custom_load=True, num_classes=21843), @@ -700,16 +733,9 @@ default_cfgs = generate_defaults({ 'vit_base_patch8_224.augreg_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', custom_load=True, num_classes=21843), - 'vit_large_patch32_224.v1_in21k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', - num_classes=21843), 'vit_large_patch16_224.augreg_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', custom_load=True, num_classes=21843), - 
'vit_huge_patch14_224.v1_in21k': _cfg( - url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', - hf_hub_id='timm/vit_huge_patch14_224_in21k', - custom_load=True, num_classes=21843), # SAM trained models (https://arxiv.org/abs/2106.01548) 'vit_base_patch32_224.sam': _cfg( @@ -736,7 +762,7 @@ default_cfgs = generate_defaults({ 'vit_base_patch16_224_miil.in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth', mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221), - 'vit_base_patch16_224_miil.in21k_ft_1k': _cfg( + 'vit_base_patch16_224_miil.in21k_ft_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_1k_miil_84_4-2deb18e3.pth', mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear'), @@ -744,14 +770,15 @@ default_cfgs = generate_defaults({ 'vit_base_patch16_rpn_224.in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth'), 'vit_medium_patch16_gap_240.in12k': _cfg( - url='', + hf_hub_id='timm/vit_medium_patch16_gap_240.in12k', input_size=(3, 240, 240), crop_pct=0.95, num_classes=11821), - 'vit_medium_patch16_gap_256.in12k_ft_1k': _cfg( - url='', + 'vit_medium_patch16_gap_256.in12k_ft_in1k': _cfg( + hf_hub_id='timm/vit_medium_patch16_gap_256.in12k_ft_in1k', input_size=(3, 256, 256), crop_pct=0.95), - 'vit_medium_patch16_gap_384.in12k_ft_1k': _cfg( - url='', - input_size=(3, 384, 384), crop_pct=0.95), + 'vit_medium_patch16_gap_384.in12k_ft_in1k': _cfg( + hf_hub_id='timm/vit_medium_patch16_gap_384.in12k_ft_in1k', + input_size=(3, 384, 384), crop_pct=0.95, crop_mode='squash'), + 'vit_base_patch16_gap_224': _cfg(), # CLIP pretrained image tower and related fine-tuned weights 'vit_base_patch32_clip_224.laion2b': _cfg( @@ -781,15 +808,16 @@ default_cfgs = generate_defaults({ 'vit_base_patch32_clip_384.laion2b_ft_in1k': _cfg( hf_hub_id='timm/vit_base_patch32_clip_384.laion2b_ft_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)), + 'vit_base_patch32_clip_448.laion2b_ft_in1k': _cfg( + hf_hub_id='timm/vit_base_patch32_clip_448.laion2b_ft_in1k', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448)), 'vit_base_patch16_clip_224.laion2b_ft_in1k': _cfg( hf_hub_id='timm/vit_base_patch16_clip_224.laion2b_ft_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), 'vit_base_patch16_clip_384.laion2b_ft_in1k': _cfg( hf_hub_id='timm/vit_base_patch16_clip_384.laion2b_ft_in1k', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)), - 'vit_base_patch32_clip_448.laion2b_ft_in1k': _cfg( - hf_hub_id='timm/vit_base_patch32_clip_448.laion2b_ft_in1k', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448)), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'), 'vit_large_patch14_clip_224.laion2b_ft_in1k': _cfg( hf_hub_id='timm/vit_large_patch14_clip_224.laion2b_ft_in1k', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0), @@ -816,10 +844,11 @@ default_cfgs = generate_defaults({ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448)), 'vit_base_patch16_clip_224.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_base_patch16_clip_224.laion2b_ft_in12k_in1k', 
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95), 'vit_base_patch16_clip_384.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_base_patch16_clip_384.laion2b_ft_in12k_in1k', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'), 'vit_large_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_large_patch14_clip_224.laion2b_ft_in12k_in1k', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0), @@ -866,7 +895,8 @@ default_cfgs = generate_defaults({ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), 'vit_base_patch16_clip_384.openai_ft_in1k': _cfg( hf_hub_id='timm/vit_base_patch16_clip_384.openai_ft_in1k', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'), 'vit_large_patch14_clip_224.openai_ft_in1k': _cfg( hf_hub_id='timm/vit_large_patch14_clip_224.openai_ft_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), @@ -876,10 +906,15 @@ default_cfgs = generate_defaults({ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), 'vit_base_patch32_clip_384.openai_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_base_patch32_clip_384.openai_ft_in12k_in1k', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'), 'vit_base_patch16_clip_224.openai_ft_in12k_in1k': _cfg( - #hf_hub_id='timm/vit_base_patch16_clip_224.openai_ft_in12k_in1k', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), + hf_hub_id='timm/vit_base_patch16_clip_224.openai_ft_in12k_in1k', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95), + 'vit_base_patch16_clip_384.openai_ft_in12k_in1k': _cfg( + hf_hub_id='timm/vit_base_patch16_clip_384.openai_ft_in12k_in1k', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'), 'vit_large_patch14_clip_224.openai_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_large_patch14_clip_224.openai_ft_in12k_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), @@ -1118,37 +1153,48 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs): @register_model def vit_medium_patch16_gap_240(pretrained=False, **kwargs): - """ ViT-Base (ViT-M/16) w/o class token, w/ avg-pool @ 240x240 + """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 240x240 """ model_kwargs = dict( patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, - global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs) + global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs) model = _create_vision_transformer('vit_medium_patch16_gap_240', pretrained=pretrained, **model_kwargs) return model @register_model def vit_medium_patch16_gap_256(pretrained=False, **kwargs): - """ ViT-Base (ViT-M/16) w/o class token, w/ avg-pool @ 256x256 + """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 256x256 """ model_kwargs = dict( patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, - global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs) + global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs) model = 
_create_vision_transformer('vit_medium_patch16_gap_256', pretrained=pretrained, **model_kwargs) return model @register_model def vit_medium_patch16_gap_384(pretrained=False, **kwargs): - """ ViT-Base (ViT-M/16) w/o class token, w/ avg-pool @ 384x384 + """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 384x384 """ model_kwargs = dict( patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, - global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs) + global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs) model = _create_vision_transformer('vit_medium_patch16_gap_384', pretrained=pretrained, **model_kwargs) return model +@register_model +def vit_base_patch16_gap_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 256x256 + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False, + global_pool=kwargs.get('global_pool', 'avg'), fc_norm=False, **kwargs) + model = _create_vision_transformer('vit_base_patch16_gap_224', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_base_patch32_clip_224(pretrained=False, **kwargs): """ ViT-B/32 CLIP image tower @ 224x224 diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index ffd1be54..b4c7d9e7 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -20,7 +20,7 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from ._pretrained import generate_defaults +from ._pretrained import generate_default_cfgs from .layers import StdConv2dSame, StdConv2d, to_2tuple from .resnet import resnet26d, resnet50d from .resnetv2 import ResNetV2, create_resnetv2_stem @@ -39,31 +39,31 @@ def _cfg(url='', **kwargs): } -default_cfgs = generate_defaults({ +default_cfgs = generate_default_cfgs({ # hybrid in-1k models (weights from official JAX impl where they exist) - 'vit_tiny_r_s16_p8_224.augreg_in21k_ft_1k': _cfg( + 'vit_tiny_r_s16_p8_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', custom_load=True, first_conv='patch_embed.backbone.conv'), - 'vit_tiny_r_s16_p8_384.augreg_in21k_ft_1k': _cfg( + 'vit_tiny_r_s16_p8_384.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0, custom_load=True), - 'vit_small_r26_s32_224.augreg_in21k_ft_1k': _cfg( + 'vit_small_r26_s32_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz', custom_load=True, ), - 'vit_small_r26_s32_384.augreg_in21k_ft_1k': _cfg( + 'vit_small_r26_s32_384.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', input_size=(3, 384, 384), crop_pct=1.0, custom_load=True), 'vit_base_r26_s32_224.untrained': _cfg(), - 'vit_base_r50_s16_384.v1_in21k_ft_1k': _cfg( + 'vit_base_r50_s16_384.v1_in21k_ft_in1k': _cfg( 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth', input_size=(3, 384, 384), crop_pct=1.0), - 'vit_large_r50_s32_224.augreg_in21k_ft_1k': _cfg( + 'vit_large_r50_s32_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz', custom_load=True, ), - 'vit_large_r50_s32_384.augreg_in21k_ft_1k': _cfg( + 'vit_large_r50_s32_384.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', input_size=(3, 384, 384), crop_pct=1.0, custom_load=True, ),
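A minimal usage sketch of the tagged pretrained naming this patch sets up (assuming create_model resolves 'model_name.tag' names as wired up by this patch series; the model names below are illustrative picks from the configs in this diff):

    import timm
    from timm.models import split_model_name_tag

    # list pretrained weights as 'model_name.tag' entries, with optional wildcard filtering
    print(timm.list_pretrained('maxvit_*_tf_*'))

    # create a model from a tagged name; the tag selects the weight location
    # (hf_hub_id / url / file) recorded in that tag's PretrainedCfg
    model = timm.create_model('maxvit_tiny_tf_224.in1k', pretrained=True)

    # split a tagged name back into (architecture, tag)
    arch, tag = split_model_name_tag('vit_base_patch16_384.augreg_in21k_ft_in1k')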