From 0dadb4a6e9e245c30653db2a48752423df98fa44 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 26 Oct 2022 22:18:39 -0700 Subject: [PATCH] Initial multi-weight support, handled so old pretraing config handling co-exists with new tags. --- tests/test_models.py | 3 +- timm/__init__.py | 4 +- timm/data/config.py | 28 +- timm/models/__init__.py | 2 +- timm/models/_pretrained.py | 102 +++++ timm/models/factory.py | 56 ++- timm/models/helpers.py | 161 +++++--- timm/models/inception_v3.py | 13 +- timm/models/registry.py | 127 ++++--- timm/models/regnet.py | 56 ++- timm/models/resnetv2.py | 34 +- timm/models/vision_transformer.py | 450 +++++++++-------------- timm/models/vision_transformer_hybrid.py | 117 ++---- 13 files changed, 629 insertions(+), 524 deletions(-) create mode 100644 timm/models/_pretrained.py diff --git a/tests/test_models.py b/tests/test_models.py index d007d65a..dd1330eb 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -13,8 +13,7 @@ except ImportError: has_fx_feature_extraction = False import timm -from timm import list_models, create_model, set_scriptable, has_pretrained_cfg_key, is_pretrained_cfg_key, \ - get_pretrained_cfg_value +from timm import list_models, create_model, set_scriptable, get_pretrained_cfg_value from timm.models.fx_features import _leaf_modules, _autowrap_functions if hasattr(torch._C, '_jit_set_profiling_executor'): diff --git a/timm/__init__.py b/timm/__init__.py index c5f797b1..b8053a2b 100644 --- a/timm/__init__.py +++ b/timm/__init__.py @@ -1,4 +1,4 @@ from .version import __version__ from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \ - is_scriptable, is_exportable, set_scriptable, set_exportable, has_pretrained_cfg_key, is_pretrained_cfg_key, \ - get_pretrained_cfg_value, is_model_pretrained + is_scriptable, is_exportable, set_scriptable, set_exportable, \ + is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value diff --git a/timm/data/config.py b/timm/data/config.py index 78176e4b..c5da81f1 100644 --- a/timm/data/config.py +++ b/timm/data/config.py @@ -13,64 +13,64 @@ def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, v # Resolve input/image size in_chans = 3 - if 'chans' in args and args['chans'] is not None: + if args.get('chans', None) is not None: in_chans = args['chans'] input_size = (in_chans, 224, 224) - if 'input_size' in args and args['input_size'] is not None: + if args.get('input_size', None) is not None: assert isinstance(args['input_size'], (tuple, list)) assert len(args['input_size']) == 3 input_size = tuple(args['input_size']) in_chans = input_size[0] # input_size overrides in_chans - elif 'img_size' in args and args['img_size'] is not None: + elif args.get('img_size', None) is not None: assert isinstance(args['img_size'], int) input_size = (in_chans, args['img_size'], args['img_size']) else: - if use_test_size and 'test_input_size' in default_cfg: + if use_test_size and default_cfg.get('test_input_size', None) is not None: input_size = default_cfg['test_input_size'] - elif 'input_size' in default_cfg: + elif default_cfg.get('input_size', None) is not None: input_size = default_cfg['input_size'] new_config['input_size'] = input_size # resolve interpolation method new_config['interpolation'] = 'bicubic' - if 'interpolation' in args and args['interpolation']: + if args.get('interpolation', None): new_config['interpolation'] = args['interpolation'] - elif 'interpolation' in default_cfg: + elif default_cfg.get('interpolation', None): new_config['interpolation'] = default_cfg['interpolation'] # resolve dataset + model mean for normalization new_config['mean'] = IMAGENET_DEFAULT_MEAN - if 'mean' in args and args['mean'] is not None: + if args.get('mean', None) is not None: mean = tuple(args['mean']) if len(mean) == 1: mean = tuple(list(mean) * in_chans) else: assert len(mean) == in_chans new_config['mean'] = mean - elif 'mean' in default_cfg: + elif default_cfg.get('mean', None): new_config['mean'] = default_cfg['mean'] # resolve dataset + model std deviation for normalization new_config['std'] = IMAGENET_DEFAULT_STD - if 'std' in args and args['std'] is not None: + if args.get('std', None) is not None: std = tuple(args['std']) if len(std) == 1: std = tuple(list(std) * in_chans) else: assert len(std) == in_chans new_config['std'] = std - elif 'std' in default_cfg: + elif default_cfg.get('std', None): new_config['std'] = default_cfg['std'] # resolve default crop percentage crop_pct = DEFAULT_CROP_PCT - if 'crop_pct' in args and args['crop_pct'] is not None: + if args.get('crop_pct', None): crop_pct = args['crop_pct'] else: - if use_test_size and 'test_crop_pct' in default_cfg: + if use_test_size and default_cfg.get('test_crop_pct', None): crop_pct = default_cfg['test_crop_pct'] - elif 'crop_pct' in default_cfg: + elif default_cfg.get('crop_pct', None): crop_pct = default_cfg['crop_pct'] new_config['crop_pct'] = crop_pct diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 5ff79595..3d36ce07 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -71,4 +71,4 @@ from .layers import convert_splitbn_model, convert_sync_batchnorm from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit from .layers import set_fast_norm from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\ - is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value + is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value diff --git a/timm/models/_pretrained.py b/timm/models/_pretrained.py new file mode 100644 index 00000000..8a63ebe9 --- /dev/null +++ b/timm/models/_pretrained.py @@ -0,0 +1,102 @@ +from collections import deque, defaultdict +from dataclasses import dataclass, field, replace +from typing import Any, Deque, Dict, Tuple, Optional, Union + + +@dataclass +class PretrainedCfg: + """ + """ + # weight locations + url: str = '' + file: str = '' + hf_hub_id: str = '' + hf_hub_filename: str = '' + + source: str = '' # source of cfg / weight location used (url, file, hf-hub) + architecture: str = '' # architecture variant can be set when not implicit + custom_load: bool = False # use custom model specific model.load_pretrained() (ie for npz files) + + # input / data config + input_size: Tuple[int, int, int] = (3, 224, 224) + test_input_size: Optional[Tuple[int, int, int]] = None + min_input_size: Optional[Tuple[int, int, int]] = None + fixed_input_size: bool = False + interpolation: str = 'bicubic' + crop_pct: float = 0.875 + test_crop_pct: Optional[float] = None + crop_type: str = 'pct' + mean: Tuple[float, ...] = (0.485, 0.456, 0.406) + std: Tuple[float, ...] = (0.229, 0.224, 0.225) + + # head config + num_classes: int = 1000 + label_offset: int = 0 + + # model attributes that vary with above or required for pretrained adaptation + pool_size: Optional[Tuple[int, ...]] = None + test_pool_size: Optional[Tuple[int, ...]] = None + first_conv: str = '' + classifier: str = '' + + license: str = '' + source_url: str = '' + paper: str = '' + notes: str = '' + + @property + def has_weights(self): + return self.url.startswith('http') or self.file or self.hf_hub_id + + +@dataclass +class DefaultCfg: + tags: Deque[str] = field(default_factory=deque) # priority queue of tags (first is default) + cfgs: Dict[str, PretrainedCfg] = field(default_factory=dict) # pretrained cfgs by tag + is_pretrained: bool = False # at least one of the configs has a pretrained source set + + @property + def default(self): + return self.cfgs[self.tags[0]] + + @property + def default_with_tag(self): + tag = self.tags[0] + return tag, self.cfgs[tag] + + +def split_model_name_tag(model_name: str, no_tag=''): + model_name, *tag_list = model_name.split('.', 1) + tag = tag_list[0] if tag_list else no_tag + return model_name, tag + + +def generate_defaults(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]): + out = defaultdict(DefaultCfg) + default_set = set() # no tag and tags ending with * are prioritized as default + + for k, v in cfgs.items(): + if isinstance(v, dict): + v = PretrainedCfg(**v) + has_weights = v.has_weights + + model, tag = split_model_name_tag(k) + is_default_set = model in default_set + priority = not tag or (tag.endswith('*') and not is_default_set) + tag = tag.strip('*') + + default_cfg = out[model] + if has_weights: + default_cfg.is_pretrained = True + + if priority: + default_cfg.tags.appendleft(tag) + default_set.add(model) + elif has_weights and not default_set: + default_cfg.tags.appendleft(tag) + else: + default_cfg.tags.append(tag) + + default_cfg.cfgs[tag] = v + + return out diff --git a/timm/models/factory.py b/timm/models/factory.py index f7a8fd9c..0405f435 100644 --- a/timm/models/factory.py +++ b/timm/models/factory.py @@ -1,14 +1,18 @@ -from urllib.parse import urlsplit, urlunsplit import os +from typing import Any, Dict, Optional, Union +from urllib.parse import urlsplit -from .registry import is_model, is_model_in_modules, model_entrypoint +from ._pretrained import PretrainedCfg, split_model_name_tag from .helpers import load_checkpoint -from .layers import set_layer_config from .hub import load_model_config_from_hf +from .layers import set_layer_config +from .registry import is_model, model_entrypoint def parse_model_name(model_name): - model_name = model_name.replace('hf_hub', 'hf-hub') # NOTE for backwards compat, to deprecate hf_hub use + if model_name.startswith('hf_hub'): + # NOTE for backwards compat, deprecate hf_hub use + model_name = model_name.replace('hf_hub', 'hf-hub') parsed = urlsplit(model_name) assert parsed.scheme in ('', 'timm', 'hf-hub') if parsed.scheme == 'hf-hub': @@ -20,6 +24,7 @@ def parse_model_name(model_name): def safe_model_name(model_name, remove_source=True): + # return a filename / path safe model name def make_safe(name): return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_') if remove_source: @@ -28,20 +33,29 @@ def safe_model_name(model_name, remove_source=True): def create_model( - model_name, - pretrained=False, - pretrained_cfg=None, - checkpoint_path='', - scriptable=None, - exportable=None, - no_jit=None, - **kwargs): + model_name: str, + pretrained: bool = False, + pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None, + pretrained_cfg_overlay: Optional[Dict[str, Any]] = None, + checkpoint_path: str = '', + scriptable: Optional[bool] = None, + exportable: Optional[bool] = None, + no_jit: Optional[bool] = None, + **kwargs, +): """Create a model + Lookup model's entrypoint function and pass relevant args to create a new model. + + **kwargs will be passed through entrypoint fn to timm.models.build_model_with_cfg() + and then the model class __init__(). kwargs values set to None are pruned before passing. + Args: model_name (str): name of model to instantiate pretrained (bool): load pretrained ImageNet-1k weights if true - checkpoint_path (str): path of checkpoint to load after model is initialized + pretrained_cfg (Union[str, dict, PretrainedCfg]): pass in external pretrained_cfg for model + pretrained_cfg_overlay (dict): replace key-values in base pretrained_cfg with these + checkpoint_path (str): path of checkpoint to load _after_ the model is initialized scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet) exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet) no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only) @@ -49,7 +63,7 @@ def create_model( Keyword Args: drop_rate (float): dropout rate for training (default: 0.0) global_pool (str): global pool type (default: 'avg') - **: other kwargs are model specific + **: other kwargs are consumed by builder or model __init__() """ # Parameters that aren't supported by all models or are intended to only override model defaults if set # should default to None in command line args/cfg. Remove them if they are present and not set so that @@ -58,17 +72,27 @@ def create_model( model_source, model_name = parse_model_name(model_name) if model_source == 'hf-hub': - # FIXME hf-hub source overrides any passed in pretrained_cfg, warn? + assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.' # For model names specified in the form `hf-hub:path/architecture_name@revision`, # load model weights + pretrained_cfg from Hugging Face hub. pretrained_cfg, model_name = load_model_config_from_hf(model_name) + else: + model_name, pretrained_tag = split_model_name_tag(model_name) + if not pretrained_cfg: + # a valid pretrained_cfg argument takes priority over tag in model name + pretrained_cfg = pretrained_tag if not is_model(model_name): raise RuntimeError('Unknown model (%s)' % model_name) create_fn = model_entrypoint(model_name) with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): - model = create_fn(pretrained=pretrained, pretrained_cfg=pretrained_cfg, **kwargs) + model = create_fn( + pretrained=pretrained, + pretrained_cfg=pretrained_cfg, + pretrained_cfg_overlay=pretrained_cfg_overlay, + **kwargs, + ) if checkpoint_path: load_checkpoint(model, checkpoint_path) diff --git a/timm/models/helpers.py b/timm/models/helpers.py index c771e825..9050dea5 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -3,6 +3,7 @@ Hacked together by / Copyright 2020 Ross Wightman """ import collections.abc +import dataclasses import logging import math import os @@ -17,6 +18,7 @@ import torch.nn as nn from torch.hub import load_state_dict_from_url from torch.utils.checkpoint import checkpoint +from ._pretrained import PretrainedCfg from .features import FeatureListNet, FeatureDictNet, FeatureHookNet from .fx_features import FeatureGraphNet from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf @@ -191,10 +193,14 @@ def load_custom_pretrained( Args: model: The instantiated model to load weights into pretrained_cfg (dict): Default pretrained model cfg - load_fn: An external stand alone fn that loads weights into provided model, otherwise a fn named + load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named 'laod_pretrained' on the model will be called if it exists """ - pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) or {} + pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) + if not pretrained_cfg: + _logger.warning("Invalid pretrained config, cannot load weights.") + return + load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg) if not load_from: _logger.warning("No pretrained weights exist for this model. Using random initialization.") @@ -202,7 +208,11 @@ def load_custom_pretrained( if load_from == 'hf-hub': # FIXME _logger.warning("Hugging Face hub not currently supported for custom load pretrained models.") elif load_from == 'url': - pretrained_loc = download_cached_file(pretrained_loc, check_hash=_CHECK_HASH, progress=_DOWNLOAD_PROGRESS) + pretrained_loc = download_cached_file( + pretrained_loc, + check_hash=_CHECK_HASH, + progress=_DOWNLOAD_PROGRESS + ) if load_fn is not None: load_fn(model, pretrained_loc) @@ -250,13 +260,17 @@ def load_pretrained( Args: model (nn.Module) : PyTorch model module pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset - num_classes (int): num_classes for model - in_chans (int): in_chans for model + num_classes (int): num_classes for target model + in_chans (int): in_chans for target model filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args) strict (bool): strict load of checkpoint """ - pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) or {} + pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) + if not pretrained_cfg: + _logger.warning("Invalid pretrained config, cannot load weights.") + return + load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg) if load_from == 'file': _logger.info(f'Loading pretrained weights from file ({pretrained_loc})') @@ -264,7 +278,11 @@ def load_pretrained( elif load_from == 'url': _logger.info(f'Loading pretrained weights from url ({pretrained_loc})') state_dict = load_state_dict_from_url( - pretrained_loc, map_location='cpu', progress=_DOWNLOAD_PROGRESS, check_hash=_CHECK_HASH) + pretrained_loc, + map_location='cpu', + progress=_DOWNLOAD_PROGRESS, + check_hash=_CHECK_HASH, + ) elif load_from == 'hf-hub': _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})') if isinstance(pretrained_loc, (list, tuple)): @@ -428,40 +446,20 @@ def adapt_model_from_file(parent_module, model_variant): def pretrained_cfg_for_features(pretrained_cfg): pretrained_cfg = deepcopy(pretrained_cfg) # remove default pretrained cfg fields that don't have much relevance for feature backbone - to_remove = ('num_classes', 'crop_pct', 'classifier', 'global_pool') # add default final pool size? + to_remove = ('num_classes', 'classifier', 'global_pool') # add default final pool size? for tr in to_remove: pretrained_cfg.pop(tr, None) return pretrained_cfg -def set_default_kwargs(kwargs, names, pretrained_cfg): - for n in names: - # for legacy reasons, model __init__args uses img_size + in_chans as separate args while - # pretrained_cfg has one input_size=(C, H ,W) entry - if n == 'img_size': - input_size = pretrained_cfg.get('input_size', None) - if input_size is not None: - assert len(input_size) == 3 - kwargs.setdefault(n, input_size[-2:]) - elif n == 'in_chans': - input_size = pretrained_cfg.get('input_size', None) - if input_size is not None: - assert len(input_size) == 3 - kwargs.setdefault(n, input_size[0]) - else: - default_val = pretrained_cfg.get(n, None) - if default_val is not None: - kwargs.setdefault(n, pretrained_cfg[n]) - - -def filter_kwargs(kwargs, names): +def _filter_kwargs(kwargs, names): if not kwargs or not names: return for n in names: kwargs.pop(n, None) -def update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter): +def _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter): """ Update the default_cfg and kwargs before passing to model Args: @@ -474,31 +472,61 @@ def update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter): if pretrained_cfg.get('fixed_input_size', False): # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size default_kwarg_names += ('img_size',) - set_default_kwargs(kwargs, names=default_kwarg_names, pretrained_cfg=pretrained_cfg) + + for n in default_kwarg_names: + # for legacy reasons, model __init__args uses img_size + in_chans as separate args while + # pretrained_cfg has one input_size=(C, H ,W) entry + if n == 'img_size': + input_size = pretrained_cfg.get('input_size', None) + if input_size is not None: + assert len(input_size) == 3 + kwargs.setdefault(n, input_size[-2:]) + elif n == 'in_chans': + input_size = pretrained_cfg.get('input_size', None) + if input_size is not None: + assert len(input_size) == 3 + kwargs.setdefault(n, input_size[0]) + else: + default_val = pretrained_cfg.get(n, None) + if default_val is not None: + kwargs.setdefault(n, pretrained_cfg[n]) + # Filter keyword args for task specific model variants (some 'features only' models, etc.) - filter_kwargs(kwargs, names=kwargs_filter) + _filter_kwargs(kwargs, names=kwargs_filter) -def resolve_pretrained_cfg(variant: str, pretrained_cfg=None): - if pretrained_cfg and isinstance(pretrained_cfg, dict): - # highest priority, pretrained_cfg available and passed as arg - return deepcopy(pretrained_cfg) +def resolve_pretrained_cfg( + variant: str, + pretrained_cfg=None, + pretrained_cfg_overlay=None, +) -> PretrainedCfg: + model_with_tag = variant + pretrained_tag = None + if pretrained_cfg: + if isinstance(pretrained_cfg, dict): + # pretrained_cfg dict passed as arg, validate by converting to PretrainedCfg + pretrained_cfg = PretrainedCfg(**pretrained_cfg) + elif isinstance(pretrained_cfg, str): + pretrained_tag = pretrained_cfg + pretrained_cfg = None + # fallback to looking up pretrained cfg in model registry by variant identifier - pretrained_cfg = get_pretrained_cfg(variant) + if not pretrained_cfg: + if pretrained_tag: + model_with_tag = '.'.join([variant, pretrained_tag]) + pretrained_cfg = get_pretrained_cfg(model_with_tag) + if not pretrained_cfg: _logger.warning( - f"No pretrained configuration specified for {variant} model. Using a default." + f"No pretrained configuration specified for {model_with_tag} model. Using a default." f" Please add a config to the model pretrained_cfg registry or pass explicitly.") - pretrained_cfg = dict( - url='', - num_classes=1000, - input_size=(3, 224, 224), - pool_size=None, - crop_pct=.9, - interpolation='bicubic', - first_conv='', - classifier='', - ) + pretrained_cfg = PretrainedCfg() # instance with defaults + + pretrained_cfg_overlay = pretrained_cfg_overlay or {} + if not pretrained_cfg.architecture: + pretrained_cfg_overlay.setdefault('architecture', variant) + pretrained_cfg = dataclasses.replace(pretrained_cfg, **pretrained_cfg_overlay) + return pretrained_cfg @@ -507,13 +535,14 @@ def build_model_with_cfg( variant: str, pretrained: bool, pretrained_cfg: Optional[Dict] = None, + pretrained_cfg_overlay: Optional[Dict] = None, model_cfg: Optional[Any] = None, feature_cfg: Optional[Dict] = None, pretrained_strict: bool = True, pretrained_filter_fn: Optional[Callable] = None, - pretrained_custom_load: bool = False, kwargs_filter: Optional[Tuple[str]] = None, - **kwargs): + **kwargs, +): """ Build model with specified default_cfg and optional model_cfg This helper fn aids in the construction of a model including: @@ -531,7 +560,6 @@ def build_model_with_cfg( feature_cfg (Optional[Dict]: feature extraction adapter config pretrained_strict (bool): load pretrained weights strictly pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights - pretrained_custom_load (bool): use custom load fn, to load numpy or other non PyTorch weights kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model **kwargs: model args passed through to model __init__ """ @@ -540,9 +568,16 @@ def build_model_with_cfg( feature_cfg = feature_cfg or {} # resolve and update model pretrained config and model kwargs - pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=pretrained_cfg) - update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter) - pretrained_cfg.setdefault('architecture', variant) + pretrained_cfg = resolve_pretrained_cfg( + variant, + pretrained_cfg=pretrained_cfg, + pretrained_cfg_overlay=pretrained_cfg_overlay + ) + + # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model + pretrained_cfg = dataclasses.asdict(pretrained_cfg) + + _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter) # Setup for feature extraction wrapper done at end of this fn if kwargs.pop('features_only', False): @@ -551,8 +586,11 @@ def build_model_with_cfg( if 'out_indices' in kwargs: feature_cfg['out_indices'] = kwargs.pop('out_indices') - # Build the model - model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs) + # Instantiate the model + if model_cfg is None: + model = model_cls(**kwargs) + else: + model = model_cls(cfg=model_cfg, **kwargs) model.pretrained_cfg = pretrained_cfg model.default_cfg = model.pretrained_cfg # alias for backwards compat @@ -562,9 +600,11 @@ def build_model_with_cfg( # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000)) if pretrained: - if pretrained_custom_load: - # FIXME improve custom load trigger - load_custom_pretrained(model, pretrained_cfg=pretrained_cfg) + if pretrained_cfg.get('custom_load', False): + load_custom_pretrained( + model, + pretrained_cfg=pretrained_cfg, + ) else: load_pretrained( model, @@ -572,7 +612,8 @@ def build_model_with_cfg( num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3), filter_fn=pretrained_filter_fn, - strict=pretrained_strict) + strict=pretrained_strict, + ) # Wrap the model in a feature extraction module if enabled if features: diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py index 2c6e7eb7..c70bd608 100644 --- a/timm/models/inception_v3.py +++ b/timm/models/inception_v3.py @@ -27,24 +27,23 @@ def _cfg(url='', **kwargs): default_cfgs = { # original PyTorch weights, ported from Tensorflow but modified 'inception_v3': _cfg( - url='https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth', - has_aux=True), # checkpoint has aux logit layer weights + # NOTE checkpoint has aux logit layer weights + url='https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth'), # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz) 'tf_inception_v3': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth', - num_classes=1000, has_aux=False, label_offset=1), + num_classes=1000, label_offset=1), # my port of Tensorflow adversarially trained Inception V3 from # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz 'adv_inception_v3': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth', - num_classes=1000, has_aux=False, label_offset=1), + num_classes=1000, label_offset=1), # from gluon pretrained models, best performing in terms of accuracy/loss metrics # https://gluon-cv.mxnet.io/model_zoo/classification.html 'gluon_inception_v3': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_inception_v3-9f746940.pth', mean=IMAGENET_DEFAULT_MEAN, # also works well with inception defaults std=IMAGENET_DEFAULT_STD, # also works well with inception defaults - has_aux=False, ) } @@ -433,10 +432,10 @@ def _create_inception_v3(variant, pretrained=False, **kwargs): if aux_logits: assert not kwargs.pop('features_only', False) model_cls = InceptionV3Aux - load_strict = pretrained_cfg['has_aux'] + load_strict = variant == 'inception_v3' else: model_cls = InceptionV3 - load_strict = not pretrained_cfg['has_aux'] + load_strict = variant != 'inception_v3' return build_model_with_cfg( model_cls, variant, pretrained, diff --git a/timm/models/registry.py b/timm/models/registry.py index 9f58060f..9fa6f007 100644 --- a/timm/models/registry.py +++ b/timm/models/registry.py @@ -2,20 +2,30 @@ Hacked together by / Copyright 2020 Ross Wightman """ -import sys -import re import fnmatch -from collections import defaultdict +import re +import sys +from collections import defaultdict, deque from copy import deepcopy +from typing import Optional, Tuple + +from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag -__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', - 'is_pretrained_cfg_key', 'has_pretrained_cfg_key', 'get_pretrained_cfg_value', 'is_model_pretrained'] +__all__ = [ + 'list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', + 'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name'] _module_to_models = defaultdict(set) # dict of sets to check membership of model in module _model_to_module = {} # mapping of model names to module names -_model_entrypoints = {} # mapping of model names to entrypoint fns +_model_entrypoints = {} # mapping of model names to architecture entrypoint fns _model_has_pretrained = set() # set of model names that have pretrained weight url present -_model_pretrained_cfgs = dict() # central repo for model default_cfgs +_model_default_cfgs = dict() # central repo for model arch -> default cfg objects +_model_pretrained_cfgs = dict() # central repo for model arch + tag -> pretrained cfgs +_model_with_tags = defaultdict(list) # shortcut to map each model arch to all model + tag names + + +def get_arch_name(model_name: str) -> Tuple[str, Optional[str]]: + return split_model_name_tag(model_name)[0] def register_model(fn): @@ -35,19 +45,37 @@ def register_model(fn): _model_entrypoints[model_name] = fn _model_to_module[model_name] = module_name _module_to_models[module_name].add(model_name) - has_valid_pretrained = False # check if model has a pretrained url to allow filtering on this if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs: # this will catch all models that have entrypoint matching cfg key, but miss any aliasing # entrypoints or non-matching combos cfg = mod.default_cfgs[model_name] - has_valid_pretrained = ( - ('url' in cfg and 'http' in cfg['url']) or - ('file' in cfg and cfg['file']) or - ('hf_hub_id' in cfg and cfg['hf_hub_id']) - ) - _model_pretrained_cfgs[model_name] = mod.default_cfgs[model_name] - if has_valid_pretrained: - _model_has_pretrained.add(model_name) + if not isinstance(cfg, DefaultCfg): + # new style default cfg dataclass w/ multiple entries per model-arch + assert isinstance(cfg, dict) + # old style cfg dict per model-arch + cfg = PretrainedCfg(**cfg) + cfg = DefaultCfg(tags=deque(['']), cfgs={'': cfg}) + + for tag_idx, tag in enumerate(cfg.tags): + is_default = tag_idx == 0 + pretrained_cfg = cfg.cfgs[tag] + if is_default: + _model_pretrained_cfgs[model_name] = pretrained_cfg + if pretrained_cfg.has_weights: + # add tagless entry if it's default and has weights + _model_has_pretrained.add(model_name) + if tag: + model_name_tag = '.'.join([model_name, tag]) + _model_pretrained_cfgs[model_name_tag] = pretrained_cfg + if pretrained_cfg.has_weights: + # add model w/ tag if tag is valid + _model_has_pretrained.add(model_name_tag) + _model_with_tags[model_name].append(model_name_tag) + else: + _model_with_tags[model_name].append(model_name) # has empty tag (to slowly remove these instances) + + _model_default_cfgs[model_name] = cfg + return fn @@ -55,24 +83,39 @@ def _natural_key(string_): return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] -def list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False): +def list_models( + filter: str = '', + module: str = '', + pretrained=False, + exclude_filters: str = '', + name_matches_cfg: bool = False, + include_tags: Optional[bool] = None, +): """ Return list of available model names, sorted alphabetically Args: filter (str) - Wildcard filter string that works with fnmatch - module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet') - pretrained (bool) - Include only models with pretrained weights if True + module (str) - Limit model selection to a specific submodule (ie 'vision_transformer') + pretrained (bool) - Include only models with valid pretrained weights if True exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases) - + include_tags (Optional[boo]) - Include pretrained tags in model names (model.tag). If None, defaults + set to True when pretrained=True else False (default: None) Example: model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module """ + if include_tags is None: + # FIXME should this be default behaviour? or default to include_tags=True? + include_tags = pretrained + if module: all_models = list(_module_to_models[module]) else: all_models = _model_entrypoints.keys() + + # FIXME wildcard filter tag as well as model arch name + if filter: models = [] include_filters = filter if isinstance(filter, (tuple, list)) else [filter] @@ -82,6 +125,7 @@ def list_models(filter='', module='', pretrained=False, exclude_filters='', name models = set(models).union(include_models) else: models = all_models + if exclude_filters: if not isinstance(exclude_filters, (tuple, list)): exclude_filters = [exclude_filters] @@ -89,23 +133,35 @@ def list_models(filter='', module='', pretrained=False, exclude_filters='', name exclude_models = fnmatch.filter(models, xf) # exclude these models if len(exclude_models): models = set(models).difference(exclude_models) + + if include_tags: + # expand model names to include names w/ pretrained tags + models_with_tags = [] + for m in models: + models_with_tags.extend(_model_with_tags[m]) + models = models_with_tags + if pretrained: models = _model_has_pretrained.intersection(models) + if name_matches_cfg: models = set(_model_pretrained_cfgs).intersection(models) + return list(sorted(models, key=_natural_key)) def is_model(model_name): """ Check if a model name exists """ - return model_name in _model_entrypoints + arch_name = get_arch_name(model_name) + return arch_name in _model_entrypoints def model_entrypoint(model_name): """Fetch a model entrypoint for specified model name """ - return _model_entrypoints[model_name] + arch_name = get_arch_name(model_name) + return _model_entrypoints[arch_name] def list_modules(): @@ -121,8 +177,9 @@ def is_model_in_modules(model_name, module_names): model_name (str) - name of model to check module_names (tuple, list, set) - names of modules to search in """ + arch_name = get_arch_name(model_name) assert isinstance(module_names, (tuple, list, set)) - return any(model_name in _module_to_models[n] for n in module_names) + return any(arch_name in _module_to_models[n] for n in module_names) def is_model_pretrained(model_name): @@ -132,28 +189,12 @@ def is_model_pretrained(model_name): def get_pretrained_cfg(model_name): if model_name in _model_pretrained_cfgs: return deepcopy(_model_pretrained_cfgs[model_name]) - return {} - - -def has_pretrained_cfg_key(model_name, cfg_key): - """ Query model default_cfgs for existence of a specific key. - """ - if model_name in _model_pretrained_cfgs and cfg_key in _model_pretrained_cfgs[model_name]: - return True - return False - - -def is_pretrained_cfg_key(model_name, cfg_key): - """ Return truthy value for specified model default_cfg key, False if does not exist. - """ - if model_name in _model_pretrained_cfgs and _model_pretrained_cfgs[model_name].get(cfg_key, False): - return True - return False + raise RuntimeError(f'No pretrained config exists for model {model_name}.') def get_pretrained_cfg_value(model_name, cfg_key): - """ Get a specific model default_cfg value by key. None if it doesn't exist. + """ Get a specific model default_cfg value by key. None if key doesn't exist. """ if model_name in _model_pretrained_cfgs: - return _model_pretrained_cfgs[model_name].get(cfg_key, None) - return None \ No newline at end of file + return getattr(_model_pretrained_cfgs[model_name], cfg_key, None) + raise RuntimeError(f'No pretrained config exist for model {model_name}.') \ No newline at end of file diff --git a/timm/models/regnet.py b/timm/models/regnet.py index 9d1f1f64..0ad7c826 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -76,6 +76,9 @@ model_cfgs = dict( regnety_120=RegNetCfg(w0=168, wa=73.36, wm=2.37, group_size=112, depth=19, se_ratio=0.25), regnety_160=RegNetCfg(w0=200, wa=106.23, wm=2.48, group_size=112, depth=18, se_ratio=0.25), regnety_320=RegNetCfg(w0=232, wa=115.89, wm=2.53, group_size=232, depth=20, se_ratio=0.25), + regnety_640=RegNetCfg(w0=352, wa=147.48, wm=2.4, group_size=328, depth=20, se_ratio=0.25), + regnety_1280=RegNetCfg(w0=456, wa=160.83, wm=2.52, group_size=264, depth=27, se_ratio=0.25), + regnety_2560=RegNetCfg(w0=640, wa=124.47, wm=2.04, group_size=848, depth=27, se_ratio=0.25), # Experimental regnety_040s_gn=RegNetCfg( @@ -150,7 +153,12 @@ default_cfgs = dict( regnety_160=_cfg( url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth', # from Facebook DeiT GitHub repository crop_pct=1.0, test_input_size=(3, 288, 288)), - regnety_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'), + regnety_320=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth' + ), + regnety_640=_cfg(url=''), + regnety_1280=_cfg(url=''), + regnety_2560=_cfg(url=''), regnety_040s_gn=_cfg(url=''), regnetv_040=_cfg( @@ -508,6 +516,34 @@ def _init_weights(module, name='', zero_init_last=False): def _filter_fn(state_dict): """ convert patch embedding weight from manual patchify + linear proj to conv""" + if 'classy_state_dict' in state_dict: + import re + state_dict = state_dict['classy_state_dict']['base_model']['model'] + out = {} + for k, v in state_dict['trunk'].items(): + k = k.replace('_feature_blocks.conv1.stem.0', 'stem.conv') + k = k.replace('_feature_blocks.conv1.stem.1', 'stem.bn') + k = re.sub( + r'^_feature_blocks.res\d.block(\d)-(\d+)', + lambda x: f's{int(x.group(1))}.b{int(x.group(2)) + 1}', k) + k = re.sub(r's(\d)\.b(\d+)\.bn', r's\1.b\2.downsample.bn', k) + k = k.replace('proj', 'downsample.conv') + k = k.replace('f.a.0', 'conv1.conv') + k = k.replace('f.a.1', 'conv1.bn') + k = k.replace('f.b.0', 'conv2.conv') + k = k.replace('f.b.1', 'conv2.bn') + k = k.replace('f.c', 'conv3.conv') + k = k.replace('f.final_bn', 'conv3.bn') + k = k.replace('f.se.excitation.0', 'se.fc1') + k = k.replace('f.se.excitation.2', 'se.fc2') + out[k] = v + for k, v in state_dict['heads'].items(): + if 'projection_head' in k or 'prototypes' in k: + continue + k = k.replace('0.clf.0', 'head.fc') + out[k] = v + return out + if 'model' in state_dict: # For DeiT trained regnety_160 pretraiend model state_dict = state_dict['model'] @@ -666,6 +702,24 @@ def regnety_320(pretrained=False, **kwargs): return _create_regnet('regnety_320', pretrained, **kwargs) +@register_model +def regnety_640(pretrained=False, **kwargs): + """RegNetY-64GF""" + return _create_regnet('regnety_640', pretrained, **kwargs) + + +@register_model +def regnety_1280(pretrained=False, **kwargs): + """RegNetY-128GF""" + return _create_regnet('regnety_1280', pretrained, **kwargs) + + +@register_model +def regnety_2560(pretrained=False, **kwargs): + """RegNetY-256GF""" + return _create_regnet('regnety_2560', pretrained, **kwargs) + + @register_model def regnety_040s_gn(pretrained=False, **kwargs): """RegNetY-4.0GF w/ GroupNorm """ diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index bde088db..b21ef7f5 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -57,52 +57,52 @@ default_cfgs = { # pretrained on imagenet21k, finetuned on imagenet1k 'resnetv2_50x1_bitm': _cfg( url='https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz', - input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0), + input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True), 'resnetv2_50x3_bitm': _cfg( url='https://storage.googleapis.com/bit_models/BiT-M-R50x3-ILSVRC2012.npz', - input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0), + input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True), 'resnetv2_101x1_bitm': _cfg( url='https://storage.googleapis.com/bit_models/BiT-M-R101x1-ILSVRC2012.npz', - input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0), + input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True), 'resnetv2_101x3_bitm': _cfg( url='https://storage.googleapis.com/bit_models/BiT-M-R101x3-ILSVRC2012.npz', - input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0), + input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True), 'resnetv2_152x2_bitm': _cfg( url='https://storage.googleapis.com/bit_models/BiT-M-R152x2-ILSVRC2012.npz', - input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0), + input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True), 'resnetv2_152x4_bitm': _cfg( url='https://storage.googleapis.com/bit_models/BiT-M-R152x4-ILSVRC2012.npz', - input_size=(3, 480, 480), pool_size=(15, 15), crop_pct=1.0), # only one at 480x480? + input_size=(3, 480, 480), pool_size=(15, 15), crop_pct=1.0, custom_load=True), # only one at 480x480? # trained on imagenet-21k 'resnetv2_50x1_bitm_in21k': _cfg( url='https://storage.googleapis.com/bit_models/BiT-M-R50x1.npz', - num_classes=21843), + num_classes=21843, custom_load=True), 'resnetv2_50x3_bitm_in21k': _cfg( url='https://storage.googleapis.com/bit_models/BiT-M-R50x3.npz', - num_classes=21843), + num_classes=21843, custom_load=True), 'resnetv2_101x1_bitm_in21k': _cfg( url='https://storage.googleapis.com/bit_models/BiT-M-R101x1.npz', - num_classes=21843), + num_classes=21843, custom_load=True), 'resnetv2_101x3_bitm_in21k': _cfg( url='https://storage.googleapis.com/bit_models/BiT-M-R101x3.npz', - num_classes=21843), + num_classes=21843, custom_load=True), 'resnetv2_152x2_bitm_in21k': _cfg( url='https://storage.googleapis.com/bit_models/BiT-M-R152x2.npz', - num_classes=21843), + num_classes=21843, custom_load=True), 'resnetv2_152x4_bitm_in21k': _cfg( url='https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz', - num_classes=21843), + num_classes=21843, custom_load=True), 'resnetv2_50x1_bit_distilled': _cfg( url='https://storage.googleapis.com/bit_models/distill/R50x1_224.npz', - interpolation='bicubic'), + interpolation='bicubic', custom_load=True), 'resnetv2_152x2_bit_teacher': _cfg( url='https://storage.googleapis.com/bit_models/distill/R152x2_T_224.npz', - interpolation='bicubic'), + interpolation='bicubic', custom_load=True), 'resnetv2_152x2_bit_teacher_384': _cfg( url='https://storage.googleapis.com/bit_models/distill/R152x2_T_384.npz', - input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, interpolation='bicubic'), + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, interpolation='bicubic', custom_load=True), 'resnetv2_50': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetv2_50_a1h-000cdf49.pth', @@ -507,8 +507,8 @@ def _create_resnetv2(variant, pretrained=False, **kwargs): return build_model_with_cfg( ResNetV2, variant, pretrained, feature_cfg=feature_cfg, - pretrained_custom_load='_bit' in variant, - **kwargs) + **kwargs, + ) def _create_resnetv2_bit(variant, pretrained=False, **kwargs): diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 6ed70045..f735d44c 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -34,6 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply, adapt_input_conv, checkpoint_seq from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ +from ._pretrained import generate_defaults from .registry import register_model _logger = logging.getLogger(__name__) @@ -50,59 +51,50 @@ def _cfg(url='', **kwargs): } -default_cfgs = { +default_cfgs = generate_defaults({ # patch models (weights from official Google JAX impl) - 'vit_tiny_patch16_224': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), - 'vit_tiny_patch16_384': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', - input_size=(3, 384, 384), crop_pct=1.0), - 'vit_small_patch32_224': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), - 'vit_small_patch32_384': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', - input_size=(3, 384, 384), crop_pct=1.0), - 'vit_small_patch16_224': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), - 'vit_small_patch16_384': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', - input_size=(3, 384, 384), crop_pct=1.0), - 'vit_base_patch32_224': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), - 'vit_base_patch32_384': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', - input_size=(3, 384, 384), crop_pct=1.0), - 'vit_base_patch16_224': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), - 'vit_base_patch16_384': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', - input_size=(3, 384, 384), crop_pct=1.0), - 'vit_base_patch8_224': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), - 'vit_large_patch32_224': _cfg( - url='', # no official model weights for this combo, only for in21k - ), - 'vit_large_patch32_384': _cfg( + 'vit_tiny_patch16_224.in21ft1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', + custom_load=True), + 'vit_tiny_patch16_384.in21ft1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_small_patch32_224.in21ft1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', + custom_load=True), + 'vit_small_patch32_384.in21ft1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_small_patch16_224.in21ft1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', + custom_load=True), + 'vit_small_patch16_384.in21ft1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch32_224.in21ft1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', + custom_load=True), + 'vit_base_patch32_384.in21ft1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch16_224.in21ft1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz', + custom_load=True), + 'vit_base_patch16_384.in21ft1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch8_224.in21ft1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz', + custom_load=True), + 'vit_large_patch32_384.in21ft1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', input_size=(3, 384, 384), crop_pct=1.0), - 'vit_large_patch16_224': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'), - 'vit_large_patch16_384': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', - input_size=(3, 384, 384), crop_pct=1.0), + 'vit_large_patch16_224.in21ft1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz', + custom_load=True), + 'vit_large_patch16_384.in21ft1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), 'vit_large_patch14_224': _cfg(url=''), 'vit_huge_patch14_224': _cfg(url=''), @@ -111,92 +103,128 @@ default_cfgs = { # patch models, imagenet21k (weights from official Google JAX impl) - 'vit_tiny_patch16_224_in21k': _cfg( + 'vit_tiny_patch16_224.in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', - num_classes=21843), - 'vit_small_patch32_224_in21k': _cfg( + custom_load=True, num_classes=21843), + 'vit_small_patch32_224.in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', - num_classes=21843), - 'vit_small_patch16_224_in21k': _cfg( + custom_load=True, num_classes=21843), + 'vit_small_patch16_224.in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', - num_classes=21843), - 'vit_base_patch32_224_in21k': _cfg( + custom_load=True, num_classes=21843), + 'vit_base_patch32_224.in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz', - num_classes=21843), - 'vit_base_patch16_224_in21k': _cfg( + custom_load=True, num_classes=21843), + 'vit_base_patch16_224.in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', - num_classes=21843), - 'vit_base_patch8_224_in21k': _cfg( + custom_load=True, num_classes=21843), + 'vit_base_patch8_224.in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', - num_classes=21843), - 'vit_large_patch32_224_in21k': _cfg( + custom_load=True, num_classes=21843), + 'vit_large_patch32_224.in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', num_classes=21843), - 'vit_large_patch16_224_in21k': _cfg( + 'vit_large_patch16_224.in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', - num_classes=21843), - 'vit_huge_patch14_224_in21k': _cfg( + custom_load=True, num_classes=21843), + 'vit_huge_patch14_224.in21k': _cfg( url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', hf_hub_id='timm/vit_huge_patch14_224_in21k', - num_classes=21843), + custom_load=True, num_classes=21843), # SAM trained models (https://arxiv.org/abs/2106.01548) - 'vit_base_patch32_224_sam': _cfg( - url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'), - 'vit_base_patch16_224_sam': _cfg( - url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'), + 'vit_base_patch32_224.sam': _cfg( + url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz', custom_load=True), + 'vit_base_patch16_224.sam': _cfg( + url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz', custom_load=True), # DINO pretrained - https://arxiv.org/abs/2104.14294 (no classifier head, for fine-tune only) - 'vit_small_patch16_224_dino': _cfg( + 'vit_small_patch16_224.dino': _cfg( url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), - 'vit_small_patch8_224_dino': _cfg( + 'vit_small_patch8_224.dino': _cfg( url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), - 'vit_base_patch16_224_dino': _cfg( + 'vit_base_patch16_224.dino': _cfg( url='https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), - 'vit_base_patch8_224_dino': _cfg( + 'vit_base_patch8_224.dino': _cfg( url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), # ViT ImageNet-21K-P pretraining by MILL - 'vit_base_patch16_224_miil_in21k': _cfg( + 'vit_base_patch16_224_miil.in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth', mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221), - 'vit_base_patch16_224_miil': _cfg( + 'vit_base_patch16_224_miil.in21ft1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_1k_miil_84_4-2deb18e3.pth', mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear'), + # custom timm variants 'vit_base_patch16_rpn_224': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth'), - - # experimental (may be removed) - 'vit_base_patch32_plus_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95), - 'vit_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95), - 'vit_small_patch16_36x1_224': _cfg(url=''), - 'vit_small_patch16_18x2_224': _cfg(url=''), - 'vit_base_patch16_18x2_224': _cfg(url=''), - - 'vit_base_patch32_224_clip_laion2b': _cfg( + 'vit_medium_patch16_gap_240.in12k': _cfg( + url='', + input_size=(3, 240, 240), crop_pct=0.95, num_classes=11821), + 'vit_medium_patch16_gap_256.in12ft1k': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95), + 'vit_medium_patch16_gap_384.in12ft1k': _cfg(url='', input_size=(3, 384, 384), crop_pct=0.95), + + # CLIP pretrained image tower and related fine-tuned weights + 'vit_base_patch32_224_clip.laion2b': _cfg( hf_hub_id='laion/CLIP-ViT-B-32-laion2B-s34B-b79K', hf_hub_filename='open_clip_pytorch_model.bin', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), - 'vit_large_patch14_224_clip_laion2b': _cfg( + 'vit_large_patch14_224_clip.laion2b': _cfg( hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K', hf_hub_filename='open_clip_pytorch_model.bin', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=768), - 'vit_huge_patch14_224_clip_laion2b': _cfg( + 'vit_huge_patch14_224_clip.laion2b': _cfg( hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K', hf_hub_filename='open_clip_pytorch_model.bin', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=1024), - 'vit_giant_patch14_224_clip_laion2b': _cfg( + 'vit_giant_patch14_224_clip.laion2b': _cfg( hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K', hf_hub_filename='open_clip_pytorch_model.bin', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=1024), -} + 'vit_base_patch32_224_clip.laion2b_ft_in1k': _cfg( + hf_hub_id='', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), + 'vit_large_patch14_224_clip.laion2b_ft_in1k': _cfg( + hf_hub_id='', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'vit_huge_patch14_224_clip.laion2b_ft_in1k': _cfg( + hf_hub_id='', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), + + 'vit_base_patch32_224_clip.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), + 'vit_large_patch14_224_clip.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'vit_huge_patch14_224_clip.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), + + 'vit_base_patch32_224_clip.laion2b_ft_in12k': _cfg( + hf_hub_id='', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), + 'vit_large_patch14_224_clip.laion2b_ft_in12k': _cfg( + hf_hub_id='', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=11821), + 'vit_huge_patch14_224_clip.laion2b_ft_in12k': _cfg( + hf_hub_id='', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), + + # experimental (may be removed) + 'vit_base_patch32_plus_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95), + 'vit_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95), + 'vit_small_patch16_36x1_224': _cfg(url=''), + 'vit_small_patch16_18x2_224': _cfg(url=''), + 'vit_base_patch16_18x2_224': _cfg(url=''), +}) class Attention(nn.Module): @@ -782,14 +810,11 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') - pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None)) - model = build_model_with_cfg( + return build_model_with_cfg( VisionTransformer, variant, pretrained, - pretrained_cfg=pretrained_cfg, pretrained_filter_fn=checkpoint_filter_fn, - pretrained_custom_load='npz' in pretrained_cfg['url'], - **kwargs) - return model + **kwargs, + ) @register_model @@ -831,7 +856,6 @@ def vit_small_patch32_384(pretrained=False, **kwargs): @register_model def vit_small_patch16_224(pretrained=False, **kwargs): """ ViT-Small (ViT-S/16) - NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper """ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) @@ -841,13 +865,21 @@ def vit_small_patch16_224(pretrained=False, **kwargs): @register_model def vit_small_patch16_384(pretrained=False, **kwargs): """ ViT-Small (ViT-S/16) - NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper """ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs) return model +@register_model +def vit_small_patch8_224(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/8) + """ + model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch8_224', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_base_patch32_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). @@ -974,175 +1006,90 @@ def vit_gigantic_patch14_224(pretrained=False, **kwargs): @register_model -def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs): - """ ViT-Tiny (Vit-Ti/16). - ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. - NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer - """ - model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) - model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_small_patch32_224_in21k(pretrained=False, **kwargs): - """ ViT-Small (ViT-S/16) - ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. - NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer - """ - model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_small_patch16_224_in21k(pretrained=False, **kwargs): - """ ViT-Small (ViT-S/16) - ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. - NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer +def vit_base_patch16_224_miil(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K """ - model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_base_patch32_224_in21k(pretrained=False, **kwargs): - """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). - ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. - NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer +def vit_base_patch32_224_clip(pretrained=False, **kwargs): + """ ViT-B/32 + Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs. """ - model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs) + model_kwargs = dict( + patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) + model = _create_vision_transformer('vit_base_patch32_224_clip', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_base_patch16_224_in21k(pretrained=False, **kwargs): - """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). - ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. - NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer +def vit_medium_patch16_gap_240(pretrained=False, **kwargs): + """ ViT-Base (ViT-M/16) w/o class token, w/ avg-pool @ 240x240 """ - model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + model_kwargs = dict( + patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, + global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs) + model = _create_vision_transformer('vit_medium_patch16_gap_240', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_base_patch8_224_in21k(pretrained=False, **kwargs): - """ ViT-Base model (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929). - ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. - NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer +def vit_medium_patch16_gap_256(pretrained=False, **kwargs): + """ ViT-Base (ViT-M/16) w/o class token, w/ avg-pool @ 256x256 """ - model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('vit_base_patch8_224_in21k', pretrained=pretrained, **model_kwargs) + model_kwargs = dict( + patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, + global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs) + model = _create_vision_transformer('vit_medium_patch16_gap_256', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_large_patch32_224_in21k(pretrained=False, **kwargs): - """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). - ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. - NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights +def vit_medium_patch16_gap_384(pretrained=False, **kwargs): + """ ViT-Base (ViT-M/16) w/o class token, w/ avg-pool @ 384x384 """ - model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) - model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs) + model_kwargs = dict( + patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, + global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs) + model = _create_vision_transformer('vit_medium_patch16_gap_384', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_large_patch16_224_in21k(pretrained=False, **kwargs): - """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). - ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. - NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer +def vit_large_patch14_224_clip(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/14) + Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs. """ - model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) - model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + model_kwargs = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) + model = _create_vision_transformer('vit_large_patch14_224_clip', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): +def vit_huge_patch14_224_clip(pretrained=False, **kwargs): """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). - ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. - NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights - """ - model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs) - model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_base_patch16_224_sam(pretrained=False, **kwargs): - """ ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548 - """ - model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('vit_base_patch16_224_sam', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_base_patch32_224_sam(pretrained=False, **kwargs): - """ ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548 - """ - model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('vit_base_patch32_224_sam', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_small_patch16_224_dino(pretrained=False, **kwargs): - """ ViT-Small (ViT-S/16) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294 - """ - model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer('vit_small_patch16_224_dino', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_small_patch8_224_dino(pretrained=False, **kwargs): - """ ViT-Small (ViT-S/8) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294 - """ - model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer('vit_small_patch8_224_dino', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_base_patch16_224_dino(pretrained=False, **kwargs): - """ ViT-Base (ViT-B/16) /w DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294 - """ - model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('vit_base_patch16_224_dino', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_base_patch8_224_dino(pretrained=False, **kwargs): - """ ViT-Base (ViT-B/8) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294 - """ - model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('vit_base_patch8_224_dino', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_base_patch16_224_miil_in21k(pretrained=False, **kwargs): - """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). - Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K + Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs. """ - model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) - model = _create_vision_transformer('vit_base_patch16_224_miil_in21k', pretrained=pretrained, **model_kwargs) + model_kwargs = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) + model = _create_vision_transformer('vit_huge_patch14_224_clip', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_base_patch16_224_miil(pretrained=False, **kwargs): - """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). - Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K +def vit_giant_patch14_224_clip(pretrained=False, **kwargs): + """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 + Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs. """ - model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) - model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs) + model_kwargs = dict( + patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, + pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) + model = _create_vision_transformer('vit_giant_patch14_224_clip', pretrained=pretrained, **model_kwargs) return model @@ -1211,46 +1158,3 @@ def vit_base_patch16_18x2_224(pretrained=False, **kwargs): return model -@register_model -def vit_base_patch32_224_clip_laion2b(pretrained=False, **kwargs): - """ ViT-B/32 - Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs. - """ - model_kwargs = dict( - patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) - model = _create_vision_transformer('vit_base_patch32_224_clip_laion2b', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_large_patch14_224_clip_laion2b(pretrained=False, **kwargs): - """ ViT-Large model (ViT-L/14) - Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs. - """ - model_kwargs = dict( - patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) - model = _create_vision_transformer('vit_large_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_huge_patch14_224_clip_laion2b(pretrained=False, **kwargs): - """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). - Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs. - """ - model_kwargs = dict( - patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) - model = _create_vision_transformer('vit_huge_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_giant_patch14_224_clip_laion2b(pretrained=False, **kwargs): - """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 - Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs. - """ - model_kwargs = dict( - patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, - pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) - model = _create_vision_transformer('vit_giant_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs) - return model diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 156894ac..043df661 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -20,6 +20,7 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from ._pretrained import generate_defaults from .layers import StdConv2dSame, StdConv2d, to_2tuple from .resnet import resnet26d, resnet50d from .resnetv2 import ResNetV2, create_resnetv2_stem @@ -38,52 +39,49 @@ def _cfg(url='', **kwargs): } -default_cfgs = { +default_cfgs = generate_defaults({ # hybrid in-1k models (weights from official JAX impl where they exist) - 'vit_tiny_r_s16_p8_224': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', + 'vit_tiny_r_s16_p8_224.in21ft1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', + custom_load=True, first_conv='patch_embed.backbone.conv'), - 'vit_tiny_r_s16_p8_384': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', - first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0), - 'vit_small_r26_s32_224': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz', + 'vit_tiny_r_s16_p8_384.in21ft1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0, custom_load=True), + 'vit_small_r26_s32_224.in21ft1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz', + custom_load=True, ), - 'vit_small_r26_s32_384': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', - input_size=(3, 384, 384), crop_pct=1.0), + 'vit_small_r26_s32_384.in21ft1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0, custom_load=True), 'vit_base_r26_s32_224': _cfg(), 'vit_base_r50_s16_224': _cfg(), - 'vit_base_r50_s16_384': _cfg( + 'vit_base_r50_s16_384.in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth', input_size=(3, 384, 384), crop_pct=1.0), - 'vit_large_r50_s32_224': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz' + 'vit_large_r50_s32_224.in21ft1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz', + custom_load=True, ), - 'vit_large_r50_s32_384': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/' - 'R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', - input_size=(3, 384, 384), crop_pct=1.0 + 'vit_large_r50_s32_384.in21ft1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0, custom_load=True, ), # hybrid in-21k models (weights from official Google JAX impl where they exist) - 'vit_tiny_r_s16_p8_224_in21k': _cfg( + 'vit_tiny_r_s16_p8_224.in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', - num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv'), - 'vit_small_r26_s32_224_in21k': _cfg( + num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv', custom_load=True), + 'vit_small_r26_s32_224.in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0.npz', - num_classes=21843, crop_pct=0.9), - 'vit_base_r50_s16_224_in21k': _cfg( + num_classes=21843, crop_pct=0.9, custom_load=True), + 'vit_base_r50_s16_224.in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth', num_classes=21843, crop_pct=0.9), - 'vit_large_r50_s32_224_in21k': _cfg( + 'vit_large_r50_s32_224.in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz', - num_classes=21843, crop_pct=0.9), + num_classes=21843, crop_pct=0.9, custom_load=True), # hybrid models (using timm resnet backbones) 'vit_small_resnet26d_224': _cfg( @@ -94,7 +92,7 @@ default_cfgs = { mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), 'vit_base_resnet50d_224': _cfg( mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), -} +}) class HybridEmbed(nn.Module): @@ -248,12 +246,6 @@ def vit_base_r50_s16_384(pretrained=False, **kwargs): return model -@register_model -def vit_base_resnet50_384(pretrained=False, **kwargs): - # DEPRECATED this is forwarding to model def above for backwards compatibility - return vit_base_r50_s16_384(pretrained=pretrained, **kwargs) - - @register_model def vit_large_r50_s32_224(pretrained=False, **kwargs): """ R50+ViT-L/S32 hybrid. @@ -276,57 +268,6 @@ def vit_large_r50_s32_384(pretrained=False, **kwargs): return model -@register_model -def vit_tiny_r_s16_p8_224_in21k(pretrained=False, **kwargs): - """ R+ViT-Ti/S16 w/ 8x8 patch hybrid. ImageNet-21k. - """ - backbone = _resnetv2(layers=(), **kwargs) - model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs) - model = _create_vision_transformer_hybrid( - 'vit_tiny_r_s16_p8_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs): - """ R26+ViT-S/S32 hybrid. ImageNet-21k. - """ - backbone = _resnetv2((2, 2, 2, 2), **kwargs) - model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer_hybrid( - 'vit_small_r26_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs): - """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929). - ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. - """ - backbone = _resnetv2(layers=(3, 4, 9), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer_hybrid( - 'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_base_resnet50_224_in21k(pretrained=False, **kwargs): - # DEPRECATED this is forwarding to model def above for backwards compatibility - return vit_base_r50_s16_224_in21k(pretrained=pretrained, **kwargs) - - -@register_model -def vit_large_r50_s32_224_in21k(pretrained=False, **kwargs): - """ R50+ViT-L/S32 hybrid. ImageNet-21k. - """ - backbone = _resnetv2((3, 4, 6, 3), **kwargs) - model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs) - model = _create_vision_transformer_hybrid( - 'vit_large_r50_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) - return model - - @register_model def vit_small_resnet26d_224(pretrained=False, **kwargs): """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.