Initial multi-weight support, handled so old pretraing config handling co-exists with new tags.

pull/1520/head
Ross Wightman 2 years ago
parent 475ecdfa3d
commit ebb99a1f8d

@ -13,8 +13,7 @@ except ImportError:
has_fx_feature_extraction = False
import timm
from timm import list_models, create_model, set_scriptable, has_pretrained_cfg_key, is_pretrained_cfg_key, \
get_pretrained_cfg_value
from timm import list_models, create_model, set_scriptable, get_pretrained_cfg_value
from timm.models.fx_features import _leaf_modules, _autowrap_functions
if hasattr(torch._C, '_jit_set_profiling_executor'):

@ -1,4 +1,4 @@
from .version import __version__
from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
is_scriptable, is_exportable, set_scriptable, set_exportable, has_pretrained_cfg_key, is_pretrained_cfg_key, \
get_pretrained_cfg_value, is_model_pretrained
is_scriptable, is_exportable, set_scriptable, set_exportable, \
is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value

@ -13,64 +13,64 @@ def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, v
# Resolve input/image size
in_chans = 3
if 'chans' in args and args['chans'] is not None:
if args.get('chans', None) is not None:
in_chans = args['chans']
input_size = (in_chans, 224, 224)
if 'input_size' in args and args['input_size'] is not None:
if args.get('input_size', None) is not None:
assert isinstance(args['input_size'], (tuple, list))
assert len(args['input_size']) == 3
input_size = tuple(args['input_size'])
in_chans = input_size[0] # input_size overrides in_chans
elif 'img_size' in args and args['img_size'] is not None:
elif args.get('img_size', None) is not None:
assert isinstance(args['img_size'], int)
input_size = (in_chans, args['img_size'], args['img_size'])
else:
if use_test_size and 'test_input_size' in default_cfg:
if use_test_size and default_cfg.get('test_input_size', None) is not None:
input_size = default_cfg['test_input_size']
elif 'input_size' in default_cfg:
elif default_cfg.get('input_size', None) is not None:
input_size = default_cfg['input_size']
new_config['input_size'] = input_size
# resolve interpolation method
new_config['interpolation'] = 'bicubic'
if 'interpolation' in args and args['interpolation']:
if args.get('interpolation', None):
new_config['interpolation'] = args['interpolation']
elif 'interpolation' in default_cfg:
elif default_cfg.get('interpolation', None):
new_config['interpolation'] = default_cfg['interpolation']
# resolve dataset + model mean for normalization
new_config['mean'] = IMAGENET_DEFAULT_MEAN
if 'mean' in args and args['mean'] is not None:
if args.get('mean', None) is not None:
mean = tuple(args['mean'])
if len(mean) == 1:
mean = tuple(list(mean) * in_chans)
else:
assert len(mean) == in_chans
new_config['mean'] = mean
elif 'mean' in default_cfg:
elif default_cfg.get('mean', None):
new_config['mean'] = default_cfg['mean']
# resolve dataset + model std deviation for normalization
new_config['std'] = IMAGENET_DEFAULT_STD
if 'std' in args and args['std'] is not None:
if args.get('std', None) is not None:
std = tuple(args['std'])
if len(std) == 1:
std = tuple(list(std) * in_chans)
else:
assert len(std) == in_chans
new_config['std'] = std
elif 'std' in default_cfg:
elif default_cfg.get('std', None):
new_config['std'] = default_cfg['std']
# resolve default crop percentage
crop_pct = DEFAULT_CROP_PCT
if 'crop_pct' in args and args['crop_pct'] is not None:
if args.get('crop_pct', None):
crop_pct = args['crop_pct']
else:
if use_test_size and 'test_crop_pct' in default_cfg:
if use_test_size and default_cfg.get('test_crop_pct', None):
crop_pct = default_cfg['test_crop_pct']
elif 'crop_pct' in default_cfg:
elif default_cfg.get('crop_pct', None):
crop_pct = default_cfg['crop_pct']
new_config['crop_pct'] = crop_pct

@ -71,4 +71,4 @@ from .layers import convert_splitbn_model, convert_sync_batchnorm
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
from .layers import set_fast_norm
from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value
is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value

@ -0,0 +1,102 @@
from collections import deque, defaultdict
from dataclasses import dataclass, field, replace
from typing import Any, Deque, Dict, Tuple, Optional, Union
@dataclass
class PretrainedCfg:
"""
"""
# weight locations
url: str = ''
file: str = ''
hf_hub_id: str = ''
hf_hub_filename: str = ''
source: str = '' # source of cfg / weight location used (url, file, hf-hub)
architecture: str = '' # architecture variant can be set when not implicit
custom_load: bool = False # use custom model specific model.load_pretrained() (ie for npz files)
# input / data config
input_size: Tuple[int, int, int] = (3, 224, 224)
test_input_size: Optional[Tuple[int, int, int]] = None
min_input_size: Optional[Tuple[int, int, int]] = None
fixed_input_size: bool = False
interpolation: str = 'bicubic'
crop_pct: float = 0.875
test_crop_pct: Optional[float] = None
crop_type: str = 'pct'
mean: Tuple[float, ...] = (0.485, 0.456, 0.406)
std: Tuple[float, ...] = (0.229, 0.224, 0.225)
# head config
num_classes: int = 1000
label_offset: int = 0
# model attributes that vary with above or required for pretrained adaptation
pool_size: Optional[Tuple[int, ...]] = None
test_pool_size: Optional[Tuple[int, ...]] = None
first_conv: str = ''
classifier: str = ''
license: str = ''
source_url: str = ''
paper: str = ''
notes: str = ''
@property
def has_weights(self):
return self.url.startswith('http') or self.file or self.hf_hub_id
@dataclass
class DefaultCfg:
tags: Deque[str] = field(default_factory=deque) # priority queue of tags (first is default)
cfgs: Dict[str, PretrainedCfg] = field(default_factory=dict) # pretrained cfgs by tag
is_pretrained: bool = False # at least one of the configs has a pretrained source set
@property
def default(self):
return self.cfgs[self.tags[0]]
@property
def default_with_tag(self):
tag = self.tags[0]
return tag, self.cfgs[tag]
def split_model_name_tag(model_name: str, no_tag=''):
model_name, *tag_list = model_name.split('.', 1)
tag = tag_list[0] if tag_list else no_tag
return model_name, tag
def generate_defaults(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):
out = defaultdict(DefaultCfg)
default_set = set() # no tag and tags ending with * are prioritized as default
for k, v in cfgs.items():
if isinstance(v, dict):
v = PretrainedCfg(**v)
has_weights = v.has_weights
model, tag = split_model_name_tag(k)
is_default_set = model in default_set
priority = not tag or (tag.endswith('*') and not is_default_set)
tag = tag.strip('*')
default_cfg = out[model]
if has_weights:
default_cfg.is_pretrained = True
if priority:
default_cfg.tags.appendleft(tag)
default_set.add(model)
elif has_weights and not default_set:
default_cfg.tags.appendleft(tag)
else:
default_cfg.tags.append(tag)
default_cfg.cfgs[tag] = v
return out

@ -1,14 +1,18 @@
from urllib.parse import urlsplit, urlunsplit
import os
from typing import Any, Dict, Optional, Union
from urllib.parse import urlsplit
from .registry import is_model, is_model_in_modules, model_entrypoint
from ._pretrained import PretrainedCfg, split_model_name_tag
from .helpers import load_checkpoint
from .layers import set_layer_config
from .hub import load_model_config_from_hf
from .layers import set_layer_config
from .registry import is_model, model_entrypoint
def parse_model_name(model_name):
model_name = model_name.replace('hf_hub', 'hf-hub') # NOTE for backwards compat, to deprecate hf_hub use
if model_name.startswith('hf_hub'):
# NOTE for backwards compat, deprecate hf_hub use
model_name = model_name.replace('hf_hub', 'hf-hub')
parsed = urlsplit(model_name)
assert parsed.scheme in ('', 'timm', 'hf-hub')
if parsed.scheme == 'hf-hub':
@ -20,6 +24,7 @@ def parse_model_name(model_name):
def safe_model_name(model_name, remove_source=True):
# return a filename / path safe model name
def make_safe(name):
return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')
if remove_source:
@ -28,20 +33,29 @@ def safe_model_name(model_name, remove_source=True):
def create_model(
model_name,
pretrained=False,
pretrained_cfg=None,
checkpoint_path='',
scriptable=None,
exportable=None,
no_jit=None,
**kwargs):
model_name: str,
pretrained: bool = False,
pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
checkpoint_path: str = '',
scriptable: Optional[bool] = None,
exportable: Optional[bool] = None,
no_jit: Optional[bool] = None,
**kwargs,
):
"""Create a model
Lookup model's entrypoint function and pass relevant args to create a new model.
**kwargs will be passed through entrypoint fn to timm.models.build_model_with_cfg()
and then the model class __init__(). kwargs values set to None are pruned before passing.
Args:
model_name (str): name of model to instantiate
pretrained (bool): load pretrained ImageNet-1k weights if true
checkpoint_path (str): path of checkpoint to load after model is initialized
pretrained_cfg (Union[str, dict, PretrainedCfg]): pass in external pretrained_cfg for model
pretrained_cfg_overlay (dict): replace key-values in base pretrained_cfg with these
checkpoint_path (str): path of checkpoint to load _after_ the model is initialized
scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only)
@ -49,7 +63,7 @@ def create_model(
Keyword Args:
drop_rate (float): dropout rate for training (default: 0.0)
global_pool (str): global pool type (default: 'avg')
**: other kwargs are model specific
**: other kwargs are consumed by builder or model __init__()
"""
# Parameters that aren't supported by all models or are intended to only override model defaults if set
# should default to None in command line args/cfg. Remove them if they are present and not set so that
@ -58,17 +72,27 @@ def create_model(
model_source, model_name = parse_model_name(model_name)
if model_source == 'hf-hub':
# FIXME hf-hub source overrides any passed in pretrained_cfg, warn?
assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
# For model names specified in the form `hf-hub:path/architecture_name@revision`,
# load model weights + pretrained_cfg from Hugging Face hub.
pretrained_cfg, model_name = load_model_config_from_hf(model_name)
else:
model_name, pretrained_tag = split_model_name_tag(model_name)
if not pretrained_cfg:
# a valid pretrained_cfg argument takes priority over tag in model name
pretrained_cfg = pretrained_tag
if not is_model(model_name):
raise RuntimeError('Unknown model (%s)' % model_name)
create_fn = model_entrypoint(model_name)
with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
model = create_fn(pretrained=pretrained, pretrained_cfg=pretrained_cfg, **kwargs)
model = create_fn(
pretrained=pretrained,
pretrained_cfg=pretrained_cfg,
pretrained_cfg_overlay=pretrained_cfg_overlay,
**kwargs,
)
if checkpoint_path:
load_checkpoint(model, checkpoint_path)

@ -3,6 +3,7 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
import collections.abc
import dataclasses
import logging
import math
import os
@ -17,6 +18,7 @@ import torch.nn as nn
from torch.hub import load_state_dict_from_url
from torch.utils.checkpoint import checkpoint
from ._pretrained import PretrainedCfg
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
from .fx_features import FeatureGraphNet
from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
@ -194,7 +196,11 @@ def load_custom_pretrained(
load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named
'laod_pretrained' on the model will be called if it exists
"""
pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) or {}
pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
if not pretrained_cfg:
_logger.warning("Invalid pretrained config, cannot load weights.")
return
load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
if not load_from:
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
@ -202,7 +208,11 @@ def load_custom_pretrained(
if load_from == 'hf-hub': # FIXME
_logger.warning("Hugging Face hub not currently supported for custom load pretrained models.")
elif load_from == 'url':
pretrained_loc = download_cached_file(pretrained_loc, check_hash=_CHECK_HASH, progress=_DOWNLOAD_PROGRESS)
pretrained_loc = download_cached_file(
pretrained_loc,
check_hash=_CHECK_HASH,
progress=_DOWNLOAD_PROGRESS
)
if load_fn is not None:
load_fn(model, pretrained_loc)
@ -250,13 +260,17 @@ def load_pretrained(
Args:
model (nn.Module) : PyTorch model module
pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset
num_classes (int): num_classes for model
in_chans (int): in_chans for model
num_classes (int): num_classes for target model
in_chans (int): in_chans for target model
filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
strict (bool): strict load of checkpoint
"""
pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) or {}
pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
if not pretrained_cfg:
_logger.warning("Invalid pretrained config, cannot load weights.")
return
load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
if load_from == 'file':
_logger.info(f'Loading pretrained weights from file ({pretrained_loc})')
@ -264,7 +278,11 @@ def load_pretrained(
elif load_from == 'url':
_logger.info(f'Loading pretrained weights from url ({pretrained_loc})')
state_dict = load_state_dict_from_url(
pretrained_loc, map_location='cpu', progress=_DOWNLOAD_PROGRESS, check_hash=_CHECK_HASH)
pretrained_loc,
map_location='cpu',
progress=_DOWNLOAD_PROGRESS,
check_hash=_CHECK_HASH,
)
elif load_from == 'hf-hub':
_logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
if isinstance(pretrained_loc, (list, tuple)):
@ -428,40 +446,20 @@ def adapt_model_from_file(parent_module, model_variant):
def pretrained_cfg_for_features(pretrained_cfg):
pretrained_cfg = deepcopy(pretrained_cfg)
# remove default pretrained cfg fields that don't have much relevance for feature backbone
to_remove = ('num_classes', 'crop_pct', 'classifier', 'global_pool') # add default final pool size?
to_remove = ('num_classes', 'classifier', 'global_pool') # add default final pool size?
for tr in to_remove:
pretrained_cfg.pop(tr, None)
return pretrained_cfg
def set_default_kwargs(kwargs, names, pretrained_cfg):
for n in names:
# for legacy reasons, model __init__args uses img_size + in_chans as separate args while
# pretrained_cfg has one input_size=(C, H ,W) entry
if n == 'img_size':
input_size = pretrained_cfg.get('input_size', None)
if input_size is not None:
assert len(input_size) == 3
kwargs.setdefault(n, input_size[-2:])
elif n == 'in_chans':
input_size = pretrained_cfg.get('input_size', None)
if input_size is not None:
assert len(input_size) == 3
kwargs.setdefault(n, input_size[0])
else:
default_val = pretrained_cfg.get(n, None)
if default_val is not None:
kwargs.setdefault(n, pretrained_cfg[n])
def filter_kwargs(kwargs, names):
def _filter_kwargs(kwargs, names):
if not kwargs or not names:
return
for n in names:
kwargs.pop(n, None)
def update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter):
def _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter):
""" Update the default_cfg and kwargs before passing to model
Args:
@ -474,31 +472,61 @@ def update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter):
if pretrained_cfg.get('fixed_input_size', False):
# if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size
default_kwarg_names += ('img_size',)
set_default_kwargs(kwargs, names=default_kwarg_names, pretrained_cfg=pretrained_cfg)
for n in default_kwarg_names:
# for legacy reasons, model __init__args uses img_size + in_chans as separate args while
# pretrained_cfg has one input_size=(C, H ,W) entry
if n == 'img_size':
input_size = pretrained_cfg.get('input_size', None)
if input_size is not None:
assert len(input_size) == 3
kwargs.setdefault(n, input_size[-2:])
elif n == 'in_chans':
input_size = pretrained_cfg.get('input_size', None)
if input_size is not None:
assert len(input_size) == 3
kwargs.setdefault(n, input_size[0])
else:
default_val = pretrained_cfg.get(n, None)
if default_val is not None:
kwargs.setdefault(n, pretrained_cfg[n])
# Filter keyword args for task specific model variants (some 'features only' models, etc.)
filter_kwargs(kwargs, names=kwargs_filter)
_filter_kwargs(kwargs, names=kwargs_filter)
def resolve_pretrained_cfg(variant: str, pretrained_cfg=None):
if pretrained_cfg and isinstance(pretrained_cfg, dict):
# highest priority, pretrained_cfg available and passed as arg
return deepcopy(pretrained_cfg)
def resolve_pretrained_cfg(
variant: str,
pretrained_cfg=None,
pretrained_cfg_overlay=None,
) -> PretrainedCfg:
model_with_tag = variant
pretrained_tag = None
if pretrained_cfg:
if isinstance(pretrained_cfg, dict):
# pretrained_cfg dict passed as arg, validate by converting to PretrainedCfg
pretrained_cfg = PretrainedCfg(**pretrained_cfg)
elif isinstance(pretrained_cfg, str):
pretrained_tag = pretrained_cfg
pretrained_cfg = None
# fallback to looking up pretrained cfg in model registry by variant identifier
pretrained_cfg = get_pretrained_cfg(variant)
if not pretrained_cfg:
if pretrained_tag:
model_with_tag = '.'.join([variant, pretrained_tag])
pretrained_cfg = get_pretrained_cfg(model_with_tag)
if not pretrained_cfg:
_logger.warning(
f"No pretrained configuration specified for {variant} model. Using a default."
f"No pretrained configuration specified for {model_with_tag} model. Using a default."
f" Please add a config to the model pretrained_cfg registry or pass explicitly.")
pretrained_cfg = dict(
url='',
num_classes=1000,
input_size=(3, 224, 224),
pool_size=None,
crop_pct=.9,
interpolation='bicubic',
first_conv='',
classifier='',
)
pretrained_cfg = PretrainedCfg() # instance with defaults
pretrained_cfg_overlay = pretrained_cfg_overlay or {}
if not pretrained_cfg.architecture:
pretrained_cfg_overlay.setdefault('architecture', variant)
pretrained_cfg = dataclasses.replace(pretrained_cfg, **pretrained_cfg_overlay)
return pretrained_cfg
@ -507,13 +535,14 @@ def build_model_with_cfg(
variant: str,
pretrained: bool,
pretrained_cfg: Optional[Dict] = None,
pretrained_cfg_overlay: Optional[Dict] = None,
model_cfg: Optional[Any] = None,
feature_cfg: Optional[Dict] = None,
pretrained_strict: bool = True,
pretrained_filter_fn: Optional[Callable] = None,
pretrained_custom_load: bool = False,
kwargs_filter: Optional[Tuple[str]] = None,
**kwargs):
**kwargs,
):
""" Build model with specified default_cfg and optional model_cfg
This helper fn aids in the construction of a model including:
@ -531,7 +560,6 @@ def build_model_with_cfg(
feature_cfg (Optional[Dict]: feature extraction adapter config
pretrained_strict (bool): load pretrained weights strictly
pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
pretrained_custom_load (bool): use custom load fn, to load numpy or other non PyTorch weights
kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
**kwargs: model args passed through to model __init__
"""
@ -540,9 +568,16 @@ def build_model_with_cfg(
feature_cfg = feature_cfg or {}
# resolve and update model pretrained config and model kwargs
pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=pretrained_cfg)
update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter)
pretrained_cfg.setdefault('architecture', variant)
pretrained_cfg = resolve_pretrained_cfg(
variant,
pretrained_cfg=pretrained_cfg,
pretrained_cfg_overlay=pretrained_cfg_overlay
)
# FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
pretrained_cfg = dataclasses.asdict(pretrained_cfg)
_update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter)
# Setup for feature extraction wrapper done at end of this fn
if kwargs.pop('features_only', False):
@ -551,8 +586,11 @@ def build_model_with_cfg(
if 'out_indices' in kwargs:
feature_cfg['out_indices'] = kwargs.pop('out_indices')
# Build the model
model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
# Instantiate the model
if model_cfg is None:
model = model_cls(**kwargs)
else:
model = model_cls(cfg=model_cfg, **kwargs)
model.pretrained_cfg = pretrained_cfg
model.default_cfg = model.pretrained_cfg # alias for backwards compat
@ -562,9 +600,11 @@ def build_model_with_cfg(
# For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
if pretrained:
if pretrained_custom_load:
# FIXME improve custom load trigger
load_custom_pretrained(model, pretrained_cfg=pretrained_cfg)
if pretrained_cfg.get('custom_load', False):
load_custom_pretrained(
model,
pretrained_cfg=pretrained_cfg,
)
else:
load_pretrained(
model,
@ -572,7 +612,8 @@ def build_model_with_cfg(
num_classes=num_classes_pretrained,
in_chans=kwargs.get('in_chans', 3),
filter_fn=pretrained_filter_fn,
strict=pretrained_strict)
strict=pretrained_strict,
)
# Wrap the model in a feature extraction module if enabled
if features:

@ -27,24 +27,23 @@ def _cfg(url='', **kwargs):
default_cfgs = {
# original PyTorch weights, ported from Tensorflow but modified
'inception_v3': _cfg(
url='https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
has_aux=True), # checkpoint has aux logit layer weights
# NOTE checkpoint has aux logit layer weights
url='https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth'),
# my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
'tf_inception_v3': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth',
num_classes=1000, has_aux=False, label_offset=1),
num_classes=1000, label_offset=1),
# my port of Tensorflow adversarially trained Inception V3 from
# http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
'adv_inception_v3': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth',
num_classes=1000, has_aux=False, label_offset=1),
num_classes=1000, label_offset=1),
# from gluon pretrained models, best performing in terms of accuracy/loss metrics
# https://gluon-cv.mxnet.io/model_zoo/classification.html
'gluon_inception_v3': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_inception_v3-9f746940.pth',
mean=IMAGENET_DEFAULT_MEAN, # also works well with inception defaults
std=IMAGENET_DEFAULT_STD, # also works well with inception defaults
has_aux=False,
)
}
@ -433,10 +432,10 @@ def _create_inception_v3(variant, pretrained=False, **kwargs):
if aux_logits:
assert not kwargs.pop('features_only', False)
model_cls = InceptionV3Aux
load_strict = pretrained_cfg['has_aux']
load_strict = variant == 'inception_v3'
else:
model_cls = InceptionV3
load_strict = not pretrained_cfg['has_aux']
load_strict = variant != 'inception_v3'
return build_model_with_cfg(
model_cls, variant, pretrained,

@ -2,20 +2,30 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
import sys
import re
import fnmatch
from collections import defaultdict
import re
import sys
from collections import defaultdict, deque
from copy import deepcopy
from typing import Optional, Tuple
__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
'is_pretrained_cfg_key', 'has_pretrained_cfg_key', 'get_pretrained_cfg_value', 'is_model_pretrained']
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
__all__ = [
'list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
_module_to_models = defaultdict(set) # dict of sets to check membership of model in module
_model_to_module = {} # mapping of model names to module names
_model_entrypoints = {} # mapping of model names to entrypoint fns
_model_entrypoints = {} # mapping of model names to architecture entrypoint fns
_model_has_pretrained = set() # set of model names that have pretrained weight url present
_model_pretrained_cfgs = dict() # central repo for model default_cfgs
_model_default_cfgs = dict() # central repo for model arch -> default cfg objects
_model_pretrained_cfgs = dict() # central repo for model arch + tag -> pretrained cfgs
_model_with_tags = defaultdict(list) # shortcut to map each model arch to all model + tag names
def get_arch_name(model_name: str) -> Tuple[str, Optional[str]]:
return split_model_name_tag(model_name)[0]
def register_model(fn):
@ -35,19 +45,37 @@ def register_model(fn):
_model_entrypoints[model_name] = fn
_model_to_module[model_name] = module_name
_module_to_models[module_name].add(model_name)
has_valid_pretrained = False # check if model has a pretrained url to allow filtering on this
if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
# this will catch all models that have entrypoint matching cfg key, but miss any aliasing
# entrypoints or non-matching combos
cfg = mod.default_cfgs[model_name]
has_valid_pretrained = (
('url' in cfg and 'http' in cfg['url']) or
('file' in cfg and cfg['file']) or
('hf_hub_id' in cfg and cfg['hf_hub_id'])
)
_model_pretrained_cfgs[model_name] = mod.default_cfgs[model_name]
if has_valid_pretrained:
if not isinstance(cfg, DefaultCfg):
# new style default cfg dataclass w/ multiple entries per model-arch
assert isinstance(cfg, dict)
# old style cfg dict per model-arch
cfg = PretrainedCfg(**cfg)
cfg = DefaultCfg(tags=deque(['']), cfgs={'': cfg})
for tag_idx, tag in enumerate(cfg.tags):
is_default = tag_idx == 0
pretrained_cfg = cfg.cfgs[tag]
if is_default:
_model_pretrained_cfgs[model_name] = pretrained_cfg
if pretrained_cfg.has_weights:
# add tagless entry if it's default and has weights
_model_has_pretrained.add(model_name)
if tag:
model_name_tag = '.'.join([model_name, tag])
_model_pretrained_cfgs[model_name_tag] = pretrained_cfg
if pretrained_cfg.has_weights:
# add model w/ tag if tag is valid
_model_has_pretrained.add(model_name_tag)
_model_with_tags[model_name].append(model_name_tag)
else:
_model_with_tags[model_name].append(model_name) # has empty tag (to slowly remove these instances)
_model_default_cfgs[model_name] = cfg
return fn
@ -55,24 +83,39 @@ def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False):
def list_models(
filter: str = '',
module: str = '',
pretrained=False,
exclude_filters: str = '',
name_matches_cfg: bool = False,
include_tags: Optional[bool] = None,
):
""" Return list of available model names, sorted alphabetically
Args:
filter (str) - Wildcard filter string that works with fnmatch
module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet')
pretrained (bool) - Include only models with pretrained weights if True
module (str) - Limit model selection to a specific submodule (ie 'vision_transformer')
pretrained (bool) - Include only models with valid pretrained weights if True
exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
include_tags (Optional[boo]) - Include pretrained tags in model names (model.tag). If None, defaults
set to True when pretrained=True else False (default: None)
Example:
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
"""
if include_tags is None:
# FIXME should this be default behaviour? or default to include_tags=True?
include_tags = pretrained
if module:
all_models = list(_module_to_models[module])
else:
all_models = _model_entrypoints.keys()
# FIXME wildcard filter tag as well as model arch name
if filter:
models = []
include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
@ -82,6 +125,7 @@ def list_models(filter='', module='', pretrained=False, exclude_filters='', name
models = set(models).union(include_models)
else:
models = all_models
if exclude_filters:
if not isinstance(exclude_filters, (tuple, list)):
exclude_filters = [exclude_filters]
@ -89,23 +133,35 @@ def list_models(filter='', module='', pretrained=False, exclude_filters='', name
exclude_models = fnmatch.filter(models, xf) # exclude these models
if len(exclude_models):
models = set(models).difference(exclude_models)
if include_tags:
# expand model names to include names w/ pretrained tags
models_with_tags = []
for m in models:
models_with_tags.extend(_model_with_tags[m])
models = models_with_tags
if pretrained:
models = _model_has_pretrained.intersection(models)
if name_matches_cfg:
models = set(_model_pretrained_cfgs).intersection(models)
return list(sorted(models, key=_natural_key))
def is_model(model_name):
""" Check if a model name exists
"""
return model_name in _model_entrypoints
arch_name = get_arch_name(model_name)
return arch_name in _model_entrypoints
def model_entrypoint(model_name):
"""Fetch a model entrypoint for specified model name
"""
return _model_entrypoints[model_name]
arch_name = get_arch_name(model_name)
return _model_entrypoints[arch_name]
def list_modules():
@ -121,8 +177,9 @@ def is_model_in_modules(model_name, module_names):
model_name (str) - name of model to check
module_names (tuple, list, set) - names of modules to search in
"""
arch_name = get_arch_name(model_name)
assert isinstance(module_names, (tuple, list, set))
return any(model_name in _module_to_models[n] for n in module_names)
return any(arch_name in _module_to_models[n] for n in module_names)
def is_model_pretrained(model_name):
@ -132,28 +189,12 @@ def is_model_pretrained(model_name):
def get_pretrained_cfg(model_name):
if model_name in _model_pretrained_cfgs:
return deepcopy(_model_pretrained_cfgs[model_name])
return {}
def has_pretrained_cfg_key(model_name, cfg_key):
""" Query model default_cfgs for existence of a specific key.
"""
if model_name in _model_pretrained_cfgs and cfg_key in _model_pretrained_cfgs[model_name]:
return True
return False
def is_pretrained_cfg_key(model_name, cfg_key):
""" Return truthy value for specified model default_cfg key, False if does not exist.
"""
if model_name in _model_pretrained_cfgs and _model_pretrained_cfgs[model_name].get(cfg_key, False):
return True
return False
raise RuntimeError(f'No pretrained config exists for model {model_name}.')
def get_pretrained_cfg_value(model_name, cfg_key):
""" Get a specific model default_cfg value by key. None if it doesn't exist.
""" Get a specific model default_cfg value by key. None if key doesn't exist.
"""
if model_name in _model_pretrained_cfgs:
return _model_pretrained_cfgs[model_name].get(cfg_key, None)
return None
return getattr(_model_pretrained_cfgs[model_name], cfg_key, None)
raise RuntimeError(f'No pretrained config exist for model {model_name}.')

@ -76,6 +76,9 @@ model_cfgs = dict(
regnety_120=RegNetCfg(w0=168, wa=73.36, wm=2.37, group_size=112, depth=19, se_ratio=0.25),
regnety_160=RegNetCfg(w0=200, wa=106.23, wm=2.48, group_size=112, depth=18, se_ratio=0.25),
regnety_320=RegNetCfg(w0=232, wa=115.89, wm=2.53, group_size=232, depth=20, se_ratio=0.25),
regnety_640=RegNetCfg(w0=352, wa=147.48, wm=2.4, group_size=328, depth=20, se_ratio=0.25),
regnety_1280=RegNetCfg(w0=456, wa=160.83, wm=2.52, group_size=264, depth=27, se_ratio=0.25),
regnety_2560=RegNetCfg(w0=640, wa=124.47, wm=2.04, group_size=848, depth=27, se_ratio=0.25),
# Experimental
regnety_040s_gn=RegNetCfg(
@ -150,7 +153,12 @@ default_cfgs = dict(
regnety_160=_cfg(
url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth', # from Facebook DeiT GitHub repository
crop_pct=1.0, test_input_size=(3, 288, 288)),
regnety_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'),
regnety_320=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'
),
regnety_640=_cfg(url=''),
regnety_1280=_cfg(url=''),
regnety_2560=_cfg(url=''),
regnety_040s_gn=_cfg(url=''),
regnetv_040=_cfg(
@ -508,6 +516,34 @@ def _init_weights(module, name='', zero_init_last=False):
def _filter_fn(state_dict):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
if 'classy_state_dict' in state_dict:
import re
state_dict = state_dict['classy_state_dict']['base_model']['model']
out = {}
for k, v in state_dict['trunk'].items():
k = k.replace('_feature_blocks.conv1.stem.0', 'stem.conv')
k = k.replace('_feature_blocks.conv1.stem.1', 'stem.bn')
k = re.sub(
r'^_feature_blocks.res\d.block(\d)-(\d+)',
lambda x: f's{int(x.group(1))}.b{int(x.group(2)) + 1}', k)
k = re.sub(r's(\d)\.b(\d+)\.bn', r's\1.b\2.downsample.bn', k)
k = k.replace('proj', 'downsample.conv')
k = k.replace('f.a.0', 'conv1.conv')
k = k.replace('f.a.1', 'conv1.bn')
k = k.replace('f.b.0', 'conv2.conv')
k = k.replace('f.b.1', 'conv2.bn')
k = k.replace('f.c', 'conv3.conv')
k = k.replace('f.final_bn', 'conv3.bn')
k = k.replace('f.se.excitation.0', 'se.fc1')
k = k.replace('f.se.excitation.2', 'se.fc2')
out[k] = v
for k, v in state_dict['heads'].items():
if 'projection_head' in k or 'prototypes' in k:
continue
k = k.replace('0.clf.0', 'head.fc')
out[k] = v
return out
if 'model' in state_dict:
# For DeiT trained regnety_160 pretraiend model
state_dict = state_dict['model']
@ -666,6 +702,24 @@ def regnety_320(pretrained=False, **kwargs):
return _create_regnet('regnety_320', pretrained, **kwargs)
@register_model
def regnety_640(pretrained=False, **kwargs):
"""RegNetY-64GF"""
return _create_regnet('regnety_640', pretrained, **kwargs)
@register_model
def regnety_1280(pretrained=False, **kwargs):
"""RegNetY-128GF"""
return _create_regnet('regnety_1280', pretrained, **kwargs)
@register_model
def regnety_2560(pretrained=False, **kwargs):
"""RegNetY-256GF"""
return _create_regnet('regnety_2560', pretrained, **kwargs)
@register_model
def regnety_040s_gn(pretrained=False, **kwargs):
"""RegNetY-4.0GF w/ GroupNorm """

@ -57,52 +57,52 @@ default_cfgs = {
# pretrained on imagenet21k, finetuned on imagenet1k
'resnetv2_50x1_bitm': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz',
input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
'resnetv2_50x3_bitm': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R50x3-ILSVRC2012.npz',
input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
'resnetv2_101x1_bitm': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R101x1-ILSVRC2012.npz',
input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
'resnetv2_101x3_bitm': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R101x3-ILSVRC2012.npz',
input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
'resnetv2_152x2_bitm': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R152x2-ILSVRC2012.npz',
input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
'resnetv2_152x4_bitm': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R152x4-ILSVRC2012.npz',
input_size=(3, 480, 480), pool_size=(15, 15), crop_pct=1.0), # only one at 480x480?
input_size=(3, 480, 480), pool_size=(15, 15), crop_pct=1.0, custom_load=True), # only one at 480x480?
# trained on imagenet-21k
'resnetv2_50x1_bitm_in21k': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R50x1.npz',
num_classes=21843),
num_classes=21843, custom_load=True),
'resnetv2_50x3_bitm_in21k': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R50x3.npz',
num_classes=21843),
num_classes=21843, custom_load=True),
'resnetv2_101x1_bitm_in21k': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R101x1.npz',
num_classes=21843),
num_classes=21843, custom_load=True),
'resnetv2_101x3_bitm_in21k': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R101x3.npz',
num_classes=21843),
num_classes=21843, custom_load=True),
'resnetv2_152x2_bitm_in21k': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R152x2.npz',
num_classes=21843),
num_classes=21843, custom_load=True),
'resnetv2_152x4_bitm_in21k': _cfg(
url='https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz',
num_classes=21843),
num_classes=21843, custom_load=True),
'resnetv2_50x1_bit_distilled': _cfg(
url='https://storage.googleapis.com/bit_models/distill/R50x1_224.npz',
interpolation='bicubic'),
interpolation='bicubic', custom_load=True),
'resnetv2_152x2_bit_teacher': _cfg(
url='https://storage.googleapis.com/bit_models/distill/R152x2_T_224.npz',
interpolation='bicubic'),
interpolation='bicubic', custom_load=True),
'resnetv2_152x2_bit_teacher_384': _cfg(
url='https://storage.googleapis.com/bit_models/distill/R152x2_T_384.npz',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, interpolation='bicubic'),
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, interpolation='bicubic', custom_load=True),
'resnetv2_50': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetv2_50_a1h-000cdf49.pth',
@ -507,8 +507,8 @@ def _create_resnetv2(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
ResNetV2, variant, pretrained,
feature_cfg=feature_cfg,
pretrained_custom_load='_bit' in variant,
**kwargs)
**kwargs,
)
def _create_resnetv2_bit(variant, pretrained=False, **kwargs):

@ -34,6 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply, adapt_input_conv, checkpoint_seq
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
from ._pretrained import generate_defaults
from .registry import register_model
_logger = logging.getLogger(__name__)
@ -50,59 +51,50 @@ def _cfg(url='', **kwargs):
}
default_cfgs = {
default_cfgs = generate_defaults({
# patch models (weights from official Google JAX impl)
'vit_tiny_patch16_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
'vit_tiny_patch16_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_small_patch32_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
'vit_small_patch32_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_small_patch16_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
'vit_small_patch16_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_patch32_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
'vit_base_patch32_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_patch16_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
'vit_base_patch16_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_patch8_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
'vit_large_patch32_224': _cfg(
url='', # no official model weights for this combo, only for in21k
),
'vit_large_patch32_384': _cfg(
'vit_tiny_patch16_224.in21ft1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
custom_load=True),
'vit_tiny_patch16_384.in21ft1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
'vit_small_patch32_224.in21ft1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
custom_load=True),
'vit_small_patch32_384.in21ft1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
'vit_small_patch16_224.in21ft1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
custom_load=True),
'vit_small_patch16_384.in21ft1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_patch32_224.in21ft1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
custom_load=True),
'vit_base_patch32_384.in21ft1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_patch16_224.in21ft1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
custom_load=True),
'vit_base_patch16_384.in21ft1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_patch8_224.in21ft1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
custom_load=True),
'vit_large_patch32_384.in21ft1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_large_patch16_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
'vit_large_patch16_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_large_patch16_224.in21ft1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
custom_load=True),
'vit_large_patch16_384.in21ft1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
'vit_large_patch14_224': _cfg(url=''),
'vit_huge_patch14_224': _cfg(url=''),
@ -111,92 +103,128 @@ default_cfgs = {
# patch models, imagenet21k (weights from official Google JAX impl)
'vit_tiny_patch16_224_in21k': _cfg(
'vit_tiny_patch16_224.in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843),
'vit_small_patch32_224_in21k': _cfg(
custom_load=True, num_classes=21843),
'vit_small_patch32_224.in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843),
'vit_small_patch16_224_in21k': _cfg(
custom_load=True, num_classes=21843),
'vit_small_patch16_224.in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843),
'vit_base_patch32_224_in21k': _cfg(
custom_load=True, num_classes=21843),
'vit_base_patch32_224.in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843),
'vit_base_patch16_224_in21k': _cfg(
custom_load=True, num_classes=21843),
'vit_base_patch16_224.in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
num_classes=21843),
'vit_base_patch8_224_in21k': _cfg(
custom_load=True, num_classes=21843),
'vit_base_patch8_224.in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
num_classes=21843),
'vit_large_patch32_224_in21k': _cfg(
custom_load=True, num_classes=21843),
'vit_large_patch32_224.in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
num_classes=21843),
'vit_large_patch16_224_in21k': _cfg(
'vit_large_patch16_224.in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
num_classes=21843),
'vit_huge_patch14_224_in21k': _cfg(
custom_load=True, num_classes=21843),
'vit_huge_patch14_224.in21k': _cfg(
url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
hf_hub_id='timm/vit_huge_patch14_224_in21k',
num_classes=21843),
custom_load=True, num_classes=21843),
# SAM trained models (https://arxiv.org/abs/2106.01548)
'vit_base_patch32_224_sam': _cfg(
url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'),
'vit_base_patch16_224_sam': _cfg(
url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'),
'vit_base_patch32_224.sam': _cfg(
url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz', custom_load=True),
'vit_base_patch16_224.sam': _cfg(
url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz', custom_load=True),
# DINO pretrained - https://arxiv.org/abs/2104.14294 (no classifier head, for fine-tune only)
'vit_small_patch16_224_dino': _cfg(
'vit_small_patch16_224.dino': _cfg(
url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
'vit_small_patch8_224_dino': _cfg(
'vit_small_patch8_224.dino': _cfg(
url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
'vit_base_patch16_224_dino': _cfg(
'vit_base_patch16_224.dino': _cfg(
url='https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
'vit_base_patch8_224_dino': _cfg(
'vit_base_patch8_224.dino': _cfg(
url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
# ViT ImageNet-21K-P pretraining by MILL
'vit_base_patch16_224_miil_in21k': _cfg(
'vit_base_patch16_224_miil.in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth',
mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221),
'vit_base_patch16_224_miil': _cfg(
'vit_base_patch16_224_miil.in21ft1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_1k_miil_84_4-2deb18e3.pth',
mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear'),
# custom timm variants
'vit_base_patch16_rpn_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth'),
# experimental (may be removed)
'vit_base_patch32_plus_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),
'vit_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95),
'vit_small_patch16_36x1_224': _cfg(url=''),
'vit_small_patch16_18x2_224': _cfg(url=''),
'vit_base_patch16_18x2_224': _cfg(url=''),
'vit_base_patch32_224_clip_laion2b': _cfg(
'vit_medium_patch16_gap_240.in12k': _cfg(
url='',
input_size=(3, 240, 240), crop_pct=0.95, num_classes=11821),
'vit_medium_patch16_gap_256.in12ft1k': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),
'vit_medium_patch16_gap_384.in12ft1k': _cfg(url='', input_size=(3, 384, 384), crop_pct=0.95),
# CLIP pretrained image tower and related fine-tuned weights
'vit_base_patch32_224_clip.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-B-32-laion2B-s34B-b79K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
'vit_large_patch14_224_clip_laion2b': _cfg(
'vit_large_patch14_224_clip.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=768),
'vit_huge_patch14_224_clip_laion2b': _cfg(
'vit_huge_patch14_224_clip.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=1024),
'vit_giant_patch14_224_clip_laion2b': _cfg(
'vit_giant_patch14_224_clip.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=1024),
}
'vit_base_patch32_224_clip.laion2b_ft_in1k': _cfg(
hf_hub_id='',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
'vit_large_patch14_224_clip.laion2b_ft_in1k': _cfg(
hf_hub_id='',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'vit_huge_patch14_224_clip.laion2b_ft_in1k': _cfg(
hf_hub_id='',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
'vit_base_patch32_224_clip.laion2b_ft_in12k_in1k': _cfg(
hf_hub_id='',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
'vit_large_patch14_224_clip.laion2b_ft_in12k_in1k': _cfg(
hf_hub_id='',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'vit_huge_patch14_224_clip.laion2b_ft_in12k_in1k': _cfg(
hf_hub_id='',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
'vit_base_patch32_224_clip.laion2b_ft_in12k': _cfg(
hf_hub_id='',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
'vit_large_patch14_224_clip.laion2b_ft_in12k': _cfg(
hf_hub_id='',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=11821),
'vit_huge_patch14_224_clip.laion2b_ft_in12k': _cfg(
hf_hub_id='',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
# experimental (may be removed)
'vit_base_patch32_plus_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),
'vit_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95),
'vit_small_patch16_36x1_224': _cfg(url=''),
'vit_small_patch16_18x2_224': _cfg(url=''),
'vit_base_patch16_18x2_224': _cfg(url=''),
})
class Attention(nn.Module):
@ -782,14 +810,11 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
model = build_model_with_cfg(
return build_model_with_cfg(
VisionTransformer, variant, pretrained,
pretrained_cfg=pretrained_cfg,
pretrained_filter_fn=checkpoint_filter_fn,
pretrained_custom_load='npz' in pretrained_cfg['url'],
**kwargs)
return model
**kwargs,
)
@register_model
@ -831,7 +856,6 @@ def vit_small_patch32_384(pretrained=False, **kwargs):
@register_model
def vit_small_patch16_224(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/16)
NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
@ -841,13 +865,21 @@ def vit_small_patch16_224(pretrained=False, **kwargs):
@register_model
def vit_small_patch16_384(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/16)
NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_patch8_224(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/8)
"""
model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch8_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch32_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
@ -974,175 +1006,90 @@ def vit_gigantic_patch14_224(pretrained=False, **kwargs):
@register_model
def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
""" ViT-Tiny (Vit-Ti/16).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
"""
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/16)
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
"""
model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/16)
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
def vit_base_patch16_224_miil(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
def vit_base_patch32_224_clip(pretrained=False, **kwargs):
""" ViT-B/32
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
"""
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(
patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_base_patch32_224_clip', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
def vit_medium_patch16_gap_240(pretrained=False, **kwargs):
""" ViT-Base (ViT-M/16) w/o class token, w/ avg-pool @ 240x240
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(
patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
model = _create_vision_transformer('vit_medium_patch16_gap_240', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch8_224_in21k(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
def vit_medium_patch16_gap_256(pretrained=False, **kwargs):
""" ViT-Base (ViT-M/16) w/o class token, w/ avg-pool @ 256x256
"""
model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch8_224_in21k', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(
patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
model = _create_vision_transformer('vit_medium_patch16_gap_256', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
def vit_medium_patch16_gap_384(pretrained=False, **kwargs):
""" ViT-Base (ViT-M/16) w/o class token, w/ avg-pool @ 384x384
"""
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(
patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
model = _create_vision_transformer('vit_medium_patch16_gap_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
def vit_large_patch14_224_clip(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/14)
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
"""
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(
patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_large_patch14_224_clip', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
def vit_huge_patch14_224_clip(pretrained=False, **kwargs):
""" ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
"""
model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch16_224_sam(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224_sam', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch32_224_sam(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
"""
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch32_224_sam', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_patch16_224_dino(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/16) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch16_224_dino', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_patch8_224_dino(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/8) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294
"""
model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch8_224_dino', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch16_224_dino(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) /w DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224_dino', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch8_224_dino(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/8) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294
"""
model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch8_224_dino', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch16_224_miil_in21k(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224_miil_in21k', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_huge_patch14_224_clip', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch16_224_miil(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
def vit_giant_patch14_224_clip(pretrained=False, **kwargs):
""" ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(
patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16,
pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_giant_patch14_224_clip', pretrained=pretrained, **model_kwargs)
return model
@ -1211,46 +1158,3 @@ def vit_base_patch16_18x2_224(pretrained=False, **kwargs):
return model
@register_model
def vit_base_patch32_224_clip_laion2b(pretrained=False, **kwargs):
""" ViT-B/32
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
"""
model_kwargs = dict(
patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_base_patch32_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_large_patch14_224_clip_laion2b(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/14)
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_large_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_huge_patch14_224_clip_laion2b(pretrained=False, **kwargs):
""" ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_huge_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_giant_patch14_224_clip_laion2b(pretrained=False, **kwargs):
""" ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16,
pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_giant_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
return model

@ -20,6 +20,7 @@ import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from ._pretrained import generate_defaults
from .layers import StdConv2dSame, StdConv2d, to_2tuple
from .resnet import resnet26d, resnet50d
from .resnetv2 import ResNetV2, create_resnetv2_stem
@ -38,52 +39,49 @@ def _cfg(url='', **kwargs):
}
default_cfgs = {
default_cfgs = generate_defaults({
# hybrid in-1k models (weights from official JAX impl where they exist)
'vit_tiny_r_s16_p8_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
'vit_tiny_r_s16_p8_224.in21ft1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
custom_load=True,
first_conv='patch_embed.backbone.conv'),
'vit_tiny_r_s16_p8_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0),
'vit_small_r26_s32_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz',
'vit_tiny_r_s16_p8_384.in21ft1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0, custom_load=True),
'vit_small_r26_s32_224.in21ft1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz',
custom_load=True,
),
'vit_small_r26_s32_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_small_r26_s32_384.in21ft1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0, custom_load=True),
'vit_base_r26_s32_224': _cfg(),
'vit_base_r50_s16_224': _cfg(),
'vit_base_r50_s16_384': _cfg(
'vit_base_r50_s16_384.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_large_r50_s32_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'
'vit_large_r50_s32_224.in21ft1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
custom_load=True,
),
'vit_large_r50_s32_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0
'vit_large_r50_s32_384.in21ft1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0, custom_load=True,
),
# hybrid in-21k models (weights from official Google JAX impl where they exist)
'vit_tiny_r_s16_p8_224_in21k': _cfg(
'vit_tiny_r_s16_p8_224.in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv'),
'vit_small_r26_s32_224_in21k': _cfg(
num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv', custom_load=True),
'vit_small_r26_s32_224.in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843, crop_pct=0.9),
'vit_base_r50_s16_224_in21k': _cfg(
num_classes=21843, crop_pct=0.9, custom_load=True),
'vit_base_r50_s16_224.in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
num_classes=21843, crop_pct=0.9),
'vit_large_r50_s32_224_in21k': _cfg(
'vit_large_r50_s32_224.in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz',
num_classes=21843, crop_pct=0.9),
num_classes=21843, crop_pct=0.9, custom_load=True),
# hybrid models (using timm resnet backbones)
'vit_small_resnet26d_224': _cfg(
@ -94,7 +92,7 @@ default_cfgs = {
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
'vit_base_resnet50d_224': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
}
})
class HybridEmbed(nn.Module):
@ -248,12 +246,6 @@ def vit_base_r50_s16_384(pretrained=False, **kwargs):
return model
@register_model
def vit_base_resnet50_384(pretrained=False, **kwargs):
# DEPRECATED this is forwarding to model def above for backwards compatibility
return vit_base_r50_s16_384(pretrained=pretrained, **kwargs)
@register_model
def vit_large_r50_s32_224(pretrained=False, **kwargs):
""" R50+ViT-L/S32 hybrid.
@ -276,57 +268,6 @@ def vit_large_r50_s32_384(pretrained=False, **kwargs):
return model
@register_model
def vit_tiny_r_s16_p8_224_in21k(pretrained=False, **kwargs):
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid. ImageNet-21k.
"""
backbone = _resnetv2(layers=(), **kwargs)
model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_tiny_r_s16_p8_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs):
""" R26+ViT-S/S32 hybrid. ImageNet-21k.
"""
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_small_r26_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
""" R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
"""
backbone = _resnetv2(layers=(3, 4, 9), **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
# DEPRECATED this is forwarding to model def above for backwards compatibility
return vit_base_r50_s16_224_in21k(pretrained=pretrained, **kwargs)
@register_model
def vit_large_r50_s32_224_in21k(pretrained=False, **kwargs):
""" R50+ViT-L/S32 hybrid. ImageNet-21k.
"""
backbone = _resnetv2((3, 4, 6, 3), **kwargs)
model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_large_r50_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_resnet26d_224(pretrained=False, **kwargs):
""" Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.

Loading…
Cancel
Save