Initial multi-weight support, handled so the old pretrained config handling co-exists with the new tags.

pull/1520/head
Ross Wightman 2 years ago
parent 475ecdfa3d
commit ebb99a1f8d

@@ -13,8 +13,7 @@ except ImportError:
     has_fx_feature_extraction = False

 import timm
-from timm import list_models, create_model, set_scriptable, has_pretrained_cfg_key, is_pretrained_cfg_key, \
-    get_pretrained_cfg_value
+from timm import list_models, create_model, set_scriptable, get_pretrained_cfg_value
 from timm.models.fx_features import _leaf_modules, _autowrap_functions

 if hasattr(torch._C, '_jit_set_profiling_executor'):

@@ -1,4 +1,4 @@
 from .version import __version__
 from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
-    is_scriptable, is_exportable, set_scriptable, set_exportable, has_pretrained_cfg_key, is_pretrained_cfg_key, \
-    get_pretrained_cfg_value, is_model_pretrained
+    is_scriptable, is_exportable, set_scriptable, set_exportable, \
+    is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value

@@ -13,64 +13,64 @@ def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, v

     # Resolve input/image size
     in_chans = 3
-    if 'chans' in args and args['chans'] is not None:
+    if args.get('chans', None) is not None:
         in_chans = args['chans']

     input_size = (in_chans, 224, 224)
-    if 'input_size' in args and args['input_size'] is not None:
+    if args.get('input_size', None) is not None:
         assert isinstance(args['input_size'], (tuple, list))
         assert len(args['input_size']) == 3
         input_size = tuple(args['input_size'])
         in_chans = input_size[0]  # input_size overrides in_chans
-    elif 'img_size' in args and args['img_size'] is not None:
+    elif args.get('img_size', None) is not None:
         assert isinstance(args['img_size'], int)
         input_size = (in_chans, args['img_size'], args['img_size'])
     else:
-        if use_test_size and 'test_input_size' in default_cfg:
+        if use_test_size and default_cfg.get('test_input_size', None) is not None:
             input_size = default_cfg['test_input_size']
-        elif 'input_size' in default_cfg:
+        elif default_cfg.get('input_size', None) is not None:
             input_size = default_cfg['input_size']
     new_config['input_size'] = input_size

     # resolve interpolation method
     new_config['interpolation'] = 'bicubic'
-    if 'interpolation' in args and args['interpolation']:
+    if args.get('interpolation', None):
         new_config['interpolation'] = args['interpolation']
-    elif 'interpolation' in default_cfg:
+    elif default_cfg.get('interpolation', None):
         new_config['interpolation'] = default_cfg['interpolation']

     # resolve dataset + model mean for normalization
     new_config['mean'] = IMAGENET_DEFAULT_MEAN
-    if 'mean' in args and args['mean'] is not None:
+    if args.get('mean', None) is not None:
         mean = tuple(args['mean'])
         if len(mean) == 1:
             mean = tuple(list(mean) * in_chans)
         else:
             assert len(mean) == in_chans
         new_config['mean'] = mean
-    elif 'mean' in default_cfg:
+    elif default_cfg.get('mean', None):
         new_config['mean'] = default_cfg['mean']

     # resolve dataset + model std deviation for normalization
     new_config['std'] = IMAGENET_DEFAULT_STD
-    if 'std' in args and args['std'] is not None:
+    if args.get('std', None) is not None:
         std = tuple(args['std'])
         if len(std) == 1:
             std = tuple(list(std) * in_chans)
         else:
             assert len(std) == in_chans
         new_config['std'] = std
-    elif 'std' in default_cfg:
+    elif default_cfg.get('std', None):
         new_config['std'] = default_cfg['std']

     # resolve default crop percentage
     crop_pct = DEFAULT_CROP_PCT
-    if 'crop_pct' in args and args['crop_pct'] is not None:
+    if args.get('crop_pct', None):
         crop_pct = args['crop_pct']
     else:
-        if use_test_size and 'test_crop_pct' in default_cfg:
+        if use_test_size and default_cfg.get('test_crop_pct', None):
             crop_pct = default_cfg['test_crop_pct']
-        elif 'crop_pct' in default_cfg:
+        elif default_cfg.get('crop_pct', None):
             crop_pct = default_cfg['crop_pct']
     new_config['crop_pct'] = crop_pct
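Quick usage sketch of the resolved data config (not part of the diff; 'resnet50' is just an arbitrary registered model):

import timm
from timm.data import resolve_data_config

# Empty args dict: every value falls through args.get(...) to the model's pretrained cfg.
model = timm.create_model('resnet50', pretrained=False)
cfg = resolve_data_config({}, model=model, use_test_size=True)
print(cfg['input_size'], cfg['interpolation'], cfg['mean'], cfg['std'], cfg['crop_pct'])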

@@ -71,4 +71,4 @@ from .layers import convert_splitbn_model, convert_sync_batchnorm
 from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
 from .layers import set_fast_norm
 from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
-    is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value
+    is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value

@@ -0,0 +1,102 @@
from collections import deque, defaultdict
from dataclasses import dataclass, field, replace
from typing import Any, Deque, Dict, Tuple, Optional, Union


@dataclass
class PretrainedCfg:
    """
    """
    # weight locations
    url: str = ''
    file: str = ''
    hf_hub_id: str = ''
    hf_hub_filename: str = ''

    source: str = ''  # source of cfg / weight location used (url, file, hf-hub)
    architecture: str = ''  # architecture variant can be set when not implicit
    custom_load: bool = False  # use custom model specific model.load_pretrained() (ie for npz files)

    # input / data config
    input_size: Tuple[int, int, int] = (3, 224, 224)
    test_input_size: Optional[Tuple[int, int, int]] = None
    min_input_size: Optional[Tuple[int, int, int]] = None
    fixed_input_size: bool = False
    interpolation: str = 'bicubic'
    crop_pct: float = 0.875
    test_crop_pct: Optional[float] = None
    crop_type: str = 'pct'
    mean: Tuple[float, ...] = (0.485, 0.456, 0.406)
    std: Tuple[float, ...] = (0.229, 0.224, 0.225)

    # head config
    num_classes: int = 1000
    label_offset: int = 0

    # model attributes that vary with above or required for pretrained adaptation
    pool_size: Optional[Tuple[int, ...]] = None
    test_pool_size: Optional[Tuple[int, ...]] = None
    first_conv: str = ''
    classifier: str = ''

    license: str = ''
    source_url: str = ''
    paper: str = ''
    notes: str = ''

    @property
    def has_weights(self):
        return self.url.startswith('http') or self.file or self.hf_hub_id


@dataclass
class DefaultCfg:
    tags: Deque[str] = field(default_factory=deque)  # priority queue of tags (first is default)
    cfgs: Dict[str, PretrainedCfg] = field(default_factory=dict)  # pretrained cfgs by tag
    is_pretrained: bool = False  # at least one of the configs has a pretrained source set

    @property
    def default(self):
        return self.cfgs[self.tags[0]]

    @property
    def default_with_tag(self):
        tag = self.tags[0]
        return tag, self.cfgs[tag]


def split_model_name_tag(model_name: str, no_tag=''):
    model_name, *tag_list = model_name.split('.', 1)
    tag = tag_list[0] if tag_list else no_tag
    return model_name, tag


def generate_defaults(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):
    out = defaultdict(DefaultCfg)
    default_set = set()  # no tag and tags ending with * are prioritized as default

    for k, v in cfgs.items():
        if isinstance(v, dict):
            v = PretrainedCfg(**v)
        has_weights = v.has_weights

        model, tag = split_model_name_tag(k)
        is_default_set = model in default_set
        priority = not tag or (tag.endswith('*') and not is_default_set)
        tag = tag.strip('*')

        default_cfg = out[model]
        if has_weights:
            default_cfg.is_pretrained = True

        if priority:
            default_cfg.tags.appendleft(tag)
            default_set.add(model)
        elif has_weights and not default_set:
            default_cfg.tags.appendleft(tag)
        else:
            default_cfg.tags.append(tag)

        default_cfg.cfgs[tag] = v

    return out
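A small sketch of how the tag mechanics above are intended to behave; the 'mynet_base' architecture, its tags and URLs are made up for illustration:

from timm.models._pretrained import PretrainedCfg, generate_defaults, split_model_name_tag

# Two tagged weight sources for one architecture; the trailing '*' marks the default tag.
defaults = generate_defaults({
    'mynet_base.in1k*': PretrainedCfg(url='https://example.com/mynet_base_in1k.pth'),
    'mynet_base.in21k': PretrainedCfg(url='https://example.com/mynet_base_in21k.pth', num_classes=21843),
})
d = defaults['mynet_base']
assert d.is_pretrained
assert d.default_with_tag[0] == 'in1k'          # first tag in the deque is the default
assert split_model_name_tag('mynet_base.in21k') == ('mynet_base', 'in21k')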

@@ -1,14 +1,18 @@
-from urllib.parse import urlsplit, urlunsplit
 import os
+from typing import Any, Dict, Optional, Union
+from urllib.parse import urlsplit

-from .registry import is_model, is_model_in_modules, model_entrypoint
+from ._pretrained import PretrainedCfg, split_model_name_tag
 from .helpers import load_checkpoint
-from .layers import set_layer_config
 from .hub import load_model_config_from_hf
+from .layers import set_layer_config
+from .registry import is_model, model_entrypoint


 def parse_model_name(model_name):
-    model_name = model_name.replace('hf_hub', 'hf-hub')  # NOTE for backwards compat, to deprecate hf_hub use
+    if model_name.startswith('hf_hub'):
+        # NOTE for backwards compat, deprecate hf_hub use
+        model_name = model_name.replace('hf_hub', 'hf-hub')
     parsed = urlsplit(model_name)
     assert parsed.scheme in ('', 'timm', 'hf-hub')
     if parsed.scheme == 'hf-hub':
@@ -20,6 +24,7 @@ def parse_model_name(model_name):


 def safe_model_name(model_name, remove_source=True):
+    # return a filename / path safe model name
     def make_safe(name):
         return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')
     if remove_source:
@@ -28,20 +33,29 @@ def safe_model_name(model_name, remove_source=True):


 def create_model(
-        model_name,
-        pretrained=False,
-        pretrained_cfg=None,
-        checkpoint_path='',
-        scriptable=None,
-        exportable=None,
-        no_jit=None,
-        **kwargs):
+        model_name: str,
+        pretrained: bool = False,
+        pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
+        pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
+        checkpoint_path: str = '',
+        scriptable: Optional[bool] = None,
+        exportable: Optional[bool] = None,
+        no_jit: Optional[bool] = None,
+        **kwargs,
+):
     """Create a model

+    Lookup model's entrypoint function and pass relevant args to create a new model.
+
+    **kwargs will be passed through entrypoint fn to timm.models.build_model_with_cfg()
+    and then the model class __init__(). kwargs values set to None are pruned before passing.
+
     Args:
         model_name (str): name of model to instantiate
         pretrained (bool): load pretrained ImageNet-1k weights if true
-        checkpoint_path (str): path of checkpoint to load after model is initialized
+        pretrained_cfg (Union[str, dict, PretrainedCfg]): pass in external pretrained_cfg for model
+        pretrained_cfg_overlay (dict): replace key-values in base pretrained_cfg with these
+        checkpoint_path (str): path of checkpoint to load _after_ the model is initialized
         scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
         exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
         no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only)
@@ -49,7 +63,7 @@ def create_model(
     Keyword Args:
         drop_rate (float): dropout rate for training (default: 0.0)
         global_pool (str): global pool type (default: 'avg')
-        **: other kwargs are model specific
+        **: other kwargs are consumed by builder or model __init__()
     """
     # Parameters that aren't supported by all models or are intended to only override model defaults if set
     # should default to None in command line args/cfg. Remove them if they are present and not set so that
@@ -58,17 +72,27 @@ def create_model(
     model_source, model_name = parse_model_name(model_name)
     if model_source == 'hf-hub':
-        # FIXME hf-hub source overrides any passed in pretrained_cfg, warn?
+        assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
         # For model names specified in the form `hf-hub:path/architecture_name@revision`,
         # load model weights + pretrained_cfg from Hugging Face hub.
         pretrained_cfg, model_name = load_model_config_from_hf(model_name)
+    else:
+        model_name, pretrained_tag = split_model_name_tag(model_name)
+        if not pretrained_cfg:
+            # a valid pretrained_cfg argument takes priority over tag in model name
+            pretrained_cfg = pretrained_tag

     if not is_model(model_name):
         raise RuntimeError('Unknown model (%s)' % model_name)

     create_fn = model_entrypoint(model_name)
     with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
-        model = create_fn(pretrained=pretrained, pretrained_cfg=pretrained_cfg, **kwargs)
+        model = create_fn(
+            pretrained=pretrained,
+            pretrained_cfg=pretrained_cfg,
+            pretrained_cfg_overlay=pretrained_cfg_overlay,
+            **kwargs,
+        )

     if checkpoint_path:
         load_checkpoint(model, checkpoint_path)
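Sketch of the new factory surface (tag and file path are illustrative; only tags registered in a model's default_cfgs will resolve):

import timm

# Old-style names keep working, resolving to the default (first) tag for the arch.
m1 = timm.create_model('regnety_320', pretrained=False)

# A '.tag' suffix selects a specific pretrained cfg, e.g. the ViT '.in21ft1k' tags below.
m2 = timm.create_model('vit_base_patch16_224.in21ft1k', pretrained=False)

# pretrained_cfg_overlay patches individual fields of whatever cfg was resolved,
# e.g. pointing 'file' at a local checkpoint (hypothetical path) instead of the url.
m3 = timm.create_model(
    'regnety_320',
    pretrained=False,
    pretrained_cfg_overlay=dict(file='/path/to/local/weights.pth'),
)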

@@ -3,6 +3,7 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import collections.abc
+import dataclasses
 import logging
 import math
 import os
@@ -17,6 +18,7 @@ import torch.nn as nn
 from torch.hub import load_state_dict_from_url
 from torch.utils.checkpoint import checkpoint

+from ._pretrained import PretrainedCfg
 from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
 from .fx_features import FeatureGraphNet
 from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
@@ -191,10 +193,14 @@ def load_custom_pretrained(
     Args:
         model: The instantiated model to load weights into
         pretrained_cfg (dict): Default pretrained model cfg
-        load_fn: An external stand alone fn that loads weights into provided model, otherwise a fn named
+        load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named
             'laod_pretrained' on the model will be called if it exists
     """
-    pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) or {}
+    pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
+    if not pretrained_cfg:
+        _logger.warning("Invalid pretrained config, cannot load weights.")
+        return
+
     load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
     if not load_from:
         _logger.warning("No pretrained weights exist for this model. Using random initialization.")
@@ -202,7 +208,11 @@ def load_custom_pretrained(
     if load_from == 'hf-hub':  # FIXME
         _logger.warning("Hugging Face hub not currently supported for custom load pretrained models.")
     elif load_from == 'url':
-        pretrained_loc = download_cached_file(pretrained_loc, check_hash=_CHECK_HASH, progress=_DOWNLOAD_PROGRESS)
+        pretrained_loc = download_cached_file(
+            pretrained_loc,
+            check_hash=_CHECK_HASH,
+            progress=_DOWNLOAD_PROGRESS
+        )

     if load_fn is not None:
         load_fn(model, pretrained_loc)
@@ -250,13 +260,17 @@ def load_pretrained(
     Args:
         model (nn.Module) : PyTorch model module
         pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset
-        num_classes (int): num_classes for model
-        in_chans (int): in_chans for model
+        num_classes (int): num_classes for target model
+        in_chans (int): in_chans for target model
         filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
         strict (bool): strict load of checkpoint
     """
-    pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) or {}
+    pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
+    if not pretrained_cfg:
+        _logger.warning("Invalid pretrained config, cannot load weights.")
+        return
+
     load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
     if load_from == 'file':
         _logger.info(f'Loading pretrained weights from file ({pretrained_loc})')
@@ -264,7 +278,11 @@ def load_pretrained(
     elif load_from == 'url':
         _logger.info(f'Loading pretrained weights from url ({pretrained_loc})')
         state_dict = load_state_dict_from_url(
-            pretrained_loc, map_location='cpu', progress=_DOWNLOAD_PROGRESS, check_hash=_CHECK_HASH)
+            pretrained_loc,
+            map_location='cpu',
+            progress=_DOWNLOAD_PROGRESS,
+            check_hash=_CHECK_HASH,
+        )
     elif load_from == 'hf-hub':
         _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
         if isinstance(pretrained_loc, (list, tuple)):
@@ -428,40 +446,20 @@ def adapt_model_from_file(parent_module, model_variant):

 def pretrained_cfg_for_features(pretrained_cfg):
     pretrained_cfg = deepcopy(pretrained_cfg)
     # remove default pretrained cfg fields that don't have much relevance for feature backbone
-    to_remove = ('num_classes', 'crop_pct', 'classifier', 'global_pool')  # add default final pool size?
+    to_remove = ('num_classes', 'classifier', 'global_pool')  # add default final pool size?
     for tr in to_remove:
         pretrained_cfg.pop(tr, None)
     return pretrained_cfg


-def set_default_kwargs(kwargs, names, pretrained_cfg):
-    for n in names:
-        # for legacy reasons, model __init__args uses img_size + in_chans as separate args while
-        # pretrained_cfg has one input_size=(C, H ,W) entry
-        if n == 'img_size':
-            input_size = pretrained_cfg.get('input_size', None)
-            if input_size is not None:
-                assert len(input_size) == 3
-                kwargs.setdefault(n, input_size[-2:])
-        elif n == 'in_chans':
-            input_size = pretrained_cfg.get('input_size', None)
-            if input_size is not None:
-                assert len(input_size) == 3
-                kwargs.setdefault(n, input_size[0])
-        else:
-            default_val = pretrained_cfg.get(n, None)
-            if default_val is not None:
-                kwargs.setdefault(n, pretrained_cfg[n])
-
-
-def filter_kwargs(kwargs, names):
+def _filter_kwargs(kwargs, names):
     if not kwargs or not names:
         return
     for n in names:
         kwargs.pop(n, None)


-def update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter):
+def _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter):
     """ Update the default_cfg and kwargs before passing to model

     Args:
@@ -474,31 +472,61 @@ def update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter):
     if pretrained_cfg.get('fixed_input_size', False):
         # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size
         default_kwarg_names += ('img_size',)
-    set_default_kwargs(kwargs, names=default_kwarg_names, pretrained_cfg=pretrained_cfg)
+
+    for n in default_kwarg_names:
+        # for legacy reasons, model __init__args uses img_size + in_chans as separate args while
+        # pretrained_cfg has one input_size=(C, H ,W) entry
+        if n == 'img_size':
+            input_size = pretrained_cfg.get('input_size', None)
+            if input_size is not None:
+                assert len(input_size) == 3
+                kwargs.setdefault(n, input_size[-2:])
+        elif n == 'in_chans':
+            input_size = pretrained_cfg.get('input_size', None)
+            if input_size is not None:
+                assert len(input_size) == 3
+                kwargs.setdefault(n, input_size[0])
+        else:
+            default_val = pretrained_cfg.get(n, None)
+            if default_val is not None:
+                kwargs.setdefault(n, pretrained_cfg[n])
+
     # Filter keyword args for task specific model variants (some 'features only' models, etc.)
-    filter_kwargs(kwargs, names=kwargs_filter)
+    _filter_kwargs(kwargs, names=kwargs_filter)


-def resolve_pretrained_cfg(variant: str, pretrained_cfg=None):
-    if pretrained_cfg and isinstance(pretrained_cfg, dict):
-        # highest priority, pretrained_cfg available and passed as arg
-        return deepcopy(pretrained_cfg)
+def resolve_pretrained_cfg(
+        variant: str,
+        pretrained_cfg=None,
+        pretrained_cfg_overlay=None,
+) -> PretrainedCfg:
+    model_with_tag = variant
+    pretrained_tag = None
+    if pretrained_cfg:
+        if isinstance(pretrained_cfg, dict):
+            # pretrained_cfg dict passed as arg, validate by converting to PretrainedCfg
+            pretrained_cfg = PretrainedCfg(**pretrained_cfg)
+        elif isinstance(pretrained_cfg, str):
+            pretrained_tag = pretrained_cfg
+            pretrained_cfg = None
+
     # fallback to looking up pretrained cfg in model registry by variant identifier
-    pretrained_cfg = get_pretrained_cfg(variant)
+    if not pretrained_cfg:
+        if pretrained_tag:
+            model_with_tag = '.'.join([variant, pretrained_tag])
+        pretrained_cfg = get_pretrained_cfg(model_with_tag)
+
     if not pretrained_cfg:
         _logger.warning(
-            f"No pretrained configuration specified for {variant} model. Using a default."
+            f"No pretrained configuration specified for {model_with_tag} model. Using a default."
             f" Please add a config to the model pretrained_cfg registry or pass explicitly.")
-        pretrained_cfg = dict(
-            url='',
-            num_classes=1000,
-            input_size=(3, 224, 224),
-            pool_size=None,
-            crop_pct=.9,
-            interpolation='bicubic',
-            first_conv='',
-            classifier='',
-        )
+        pretrained_cfg = PretrainedCfg()  # instance with defaults
+
+    pretrained_cfg_overlay = pretrained_cfg_overlay or {}
+    if not pretrained_cfg.architecture:
+        pretrained_cfg_overlay.setdefault('architecture', variant)
+    pretrained_cfg = dataclasses.replace(pretrained_cfg, **pretrained_cfg_overlay)
+
     return pretrained_cfg
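Sketch of the resolution order implemented above (the 'in21ft1k' tag assumes the ViT cfgs converted later in this commit):

from timm.models.helpers import resolve_pretrained_cfg

# 1) explicit dict / PretrainedCfg argument wins,
# 2) a string argument is treated as a tag and looked up as '<variant>.<tag>',
# 3) otherwise the variant's default (first) tag is used,
# 4) pretrained_cfg_overlay is applied last via dataclasses.replace().
cfg = resolve_pretrained_cfg(
    'vit_base_patch16_224',
    pretrained_cfg='in21ft1k',                    # tag lookup
    pretrained_cfg_overlay=dict(num_classes=10),  # illustrative overlay
)
print(cfg.num_classes, cfg.custom_load)  # 10 True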
@@ -507,13 +535,14 @@ def build_model_with_cfg(
         variant: str,
         pretrained: bool,
         pretrained_cfg: Optional[Dict] = None,
+        pretrained_cfg_overlay: Optional[Dict] = None,
         model_cfg: Optional[Any] = None,
         feature_cfg: Optional[Dict] = None,
         pretrained_strict: bool = True,
         pretrained_filter_fn: Optional[Callable] = None,
-        pretrained_custom_load: bool = False,
         kwargs_filter: Optional[Tuple[str]] = None,
-        **kwargs):
+        **kwargs,
+):
     """ Build model with specified default_cfg and optional model_cfg

     This helper fn aids in the construction of a model including:
@@ -531,7 +560,6 @@ def build_model_with_cfg(
         feature_cfg (Optional[Dict]: feature extraction adapter config
         pretrained_strict (bool): load pretrained weights strictly
         pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
-        pretrained_custom_load (bool): use custom load fn, to load numpy or other non PyTorch weights
         kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
         **kwargs: model args passed through to model __init__
     """
@@ -540,9 +568,16 @@ def build_model_with_cfg(
     feature_cfg = feature_cfg or {}

     # resolve and update model pretrained config and model kwargs
-    pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=pretrained_cfg)
-    update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter)
-    pretrained_cfg.setdefault('architecture', variant)
+    pretrained_cfg = resolve_pretrained_cfg(
+        variant,
+        pretrained_cfg=pretrained_cfg,
+        pretrained_cfg_overlay=pretrained_cfg_overlay
+    )
+
+    # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
+    pretrained_cfg = dataclasses.asdict(pretrained_cfg)
+
+    _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter)

     # Setup for feature extraction wrapper done at end of this fn
     if kwargs.pop('features_only', False):
@@ -551,8 +586,11 @@ def build_model_with_cfg(
         if 'out_indices' in kwargs:
             feature_cfg['out_indices'] = kwargs.pop('out_indices')

-    # Build the model
-    model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
+    # Instantiate the model
+    if model_cfg is None:
+        model = model_cls(**kwargs)
+    else:
+        model = model_cls(cfg=model_cfg, **kwargs)
     model.pretrained_cfg = pretrained_cfg
     model.default_cfg = model.pretrained_cfg  # alias for backwards compat
@@ -562,9 +600,11 @@ def build_model_with_cfg(
     # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
     num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
     if pretrained:
-        if pretrained_custom_load:
-            # FIXME improve custom load trigger
-            load_custom_pretrained(model, pretrained_cfg=pretrained_cfg)
+        if pretrained_cfg.get('custom_load', False):
+            load_custom_pretrained(
+                model,
+                pretrained_cfg=pretrained_cfg,
+            )
         else:
             load_pretrained(
                 model,
@@ -572,7 +612,8 @@ def build_model_with_cfg(
                 num_classes=num_classes_pretrained,
                 in_chans=kwargs.get('in_chans', 3),
                 filter_fn=pretrained_filter_fn,
-                strict=pretrained_strict)
+                strict=pretrained_strict,
+            )

     # Wrap the model in a feature extraction module if enabled
     if features:

@@ -27,24 +27,23 @@ def _cfg(url='', **kwargs):

 default_cfgs = {
     # original PyTorch weights, ported from Tensorflow but modified
     'inception_v3': _cfg(
-        url='https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
-        has_aux=True),  # checkpoint has aux logit layer weights
+        # NOTE checkpoint has aux logit layer weights
+        url='https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth'),
     # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
     'tf_inception_v3': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth',
-        num_classes=1000, has_aux=False, label_offset=1),
+        num_classes=1000, label_offset=1),
     # my port of Tensorflow adversarially trained Inception V3 from
     # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
     'adv_inception_v3': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth',
-        num_classes=1000, has_aux=False, label_offset=1),
+        num_classes=1000, label_offset=1),
     # from gluon pretrained models, best performing in terms of accuracy/loss metrics
     # https://gluon-cv.mxnet.io/model_zoo/classification.html
     'gluon_inception_v3': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_inception_v3-9f746940.pth',
         mean=IMAGENET_DEFAULT_MEAN,  # also works well with inception defaults
         std=IMAGENET_DEFAULT_STD,  # also works well with inception defaults
-        has_aux=False,
     )
 }
@@ -433,10 +432,10 @@ def _create_inception_v3(variant, pretrained=False, **kwargs):
     if aux_logits:
         assert not kwargs.pop('features_only', False)
         model_cls = InceptionV3Aux
-        load_strict = pretrained_cfg['has_aux']
+        load_strict = variant == 'inception_v3'
     else:
         model_cls = InceptionV3
-        load_strict = not pretrained_cfg['has_aux']
+        load_strict = variant != 'inception_v3'

     return build_model_with_cfg(
         model_cls, variant, pretrained,
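For reference, strict loading of the aux-logits head is now keyed off the variant name, since only the torchvision-ported checkpoint carries aux classifier weights; a quick sketch (aux_logits is an existing kwarg of these entrypoints):

import timm

# 'inception_v3' weights include the aux head -> strict load when aux_logits=True.
m1 = timm.create_model('inception_v3', pretrained=False, aux_logits=True)
# 'tf_inception_v3' / 'adv_inception_v3' weights do not, so their aux variants load non-strict.
m2 = timm.create_model('tf_inception_v3', pretrained=False)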

@@ -2,20 +2,30 @@

 Hacked together by / Copyright 2020 Ross Wightman
 """
-import sys
-import re
 import fnmatch
-from collections import defaultdict
+import re
+import sys
+from collections import defaultdict, deque
 from copy import deepcopy
+from typing import Optional, Tuple

+from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag

-__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
-           'is_pretrained_cfg_key', 'has_pretrained_cfg_key', 'get_pretrained_cfg_value', 'is_model_pretrained']
+__all__ = [
+    'list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
+    'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']

 _module_to_models = defaultdict(set)  # dict of sets to check membership of model in module
 _model_to_module = {}  # mapping of model names to module names
-_model_entrypoints = {}  # mapping of model names to entrypoint fns
+_model_entrypoints = {}  # mapping of model names to architecture entrypoint fns
 _model_has_pretrained = set()  # set of model names that have pretrained weight url present
-_model_pretrained_cfgs = dict()  # central repo for model default_cfgs
+_model_default_cfgs = dict()  # central repo for model arch -> default cfg objects
+_model_pretrained_cfgs = dict()  # central repo for model arch + tag -> pretrained cfgs
+_model_with_tags = defaultdict(list)  # shortcut to map each model arch to all model + tag names


+def get_arch_name(model_name: str) -> Tuple[str, Optional[str]]:
+    return split_model_name_tag(model_name)[0]


 def register_model(fn):
@@ -35,19 +45,37 @@ def register_model(fn):
     _model_entrypoints[model_name] = fn
     _model_to_module[model_name] = module_name
     _module_to_models[module_name].add(model_name)
-    has_valid_pretrained = False  # check if model has a pretrained url to allow filtering on this
     if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
         # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
         # entrypoints or non-matching combos
         cfg = mod.default_cfgs[model_name]
-        has_valid_pretrained = (
-            ('url' in cfg and 'http' in cfg['url']) or
-            ('file' in cfg and cfg['file']) or
-            ('hf_hub_id' in cfg and cfg['hf_hub_id'])
-        )
-        _model_pretrained_cfgs[model_name] = mod.default_cfgs[model_name]
-    if has_valid_pretrained:
-        _model_has_pretrained.add(model_name)
+        if not isinstance(cfg, DefaultCfg):
+            # new style default cfg dataclass w/ multiple entries per model-arch
+            assert isinstance(cfg, dict)
+            # old style cfg dict per model-arch
+            cfg = PretrainedCfg(**cfg)
+            cfg = DefaultCfg(tags=deque(['']), cfgs={'': cfg})
+
+        for tag_idx, tag in enumerate(cfg.tags):
+            is_default = tag_idx == 0
+            pretrained_cfg = cfg.cfgs[tag]
+            if is_default:
+                _model_pretrained_cfgs[model_name] = pretrained_cfg
+                if pretrained_cfg.has_weights:
+                    # add tagless entry if it's default and has weights
+                    _model_has_pretrained.add(model_name)
+            if tag:
+                model_name_tag = '.'.join([model_name, tag])
+                _model_pretrained_cfgs[model_name_tag] = pretrained_cfg
+                if pretrained_cfg.has_weights:
+                    # add model w/ tag if tag is valid
+                    _model_has_pretrained.add(model_name_tag)
+                _model_with_tags[model_name].append(model_name_tag)
+            else:
+                _model_with_tags[model_name].append(model_name)  # has empty tag (to slowly remove these instances)
+
+        _model_default_cfgs[model_name] = cfg
+
     return fn
@@ -55,24 +83,39 @@ def _natural_key(string_):
     return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]


-def list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False):
+def list_models(
+        filter: str = '',
+        module: str = '',
+        pretrained=False,
+        exclude_filters: str = '',
+        name_matches_cfg: bool = False,
+        include_tags: Optional[bool] = None,
+):
     """ Return list of available model names, sorted alphabetically

     Args:
         filter (str) - Wildcard filter string that works with fnmatch
-        module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet')
-        pretrained (bool) - Include only models with pretrained weights if True
+        module (str) - Limit model selection to a specific submodule (ie 'vision_transformer')
+        pretrained (bool) - Include only models with valid pretrained weights if True
         exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
         name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
+        include_tags (Optional[boo]) - Include pretrained tags in model names (model.tag). If None, defaults
+            set to True when pretrained=True else False (default: None)

     Example:
         model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
         model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
     """
+    if include_tags is None:
+        # FIXME should this be default behaviour? or default to include_tags=True?
+        include_tags = pretrained
+
     if module:
         all_models = list(_module_to_models[module])
     else:
         all_models = _model_entrypoints.keys()
+
+    # FIXME wildcard filter tag as well as model arch name
     if filter:
         models = []
         include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
@@ -82,6 +125,7 @@ def list_models(filter='', module='', pretrained=False, exclude_filters='', name
             models = set(models).union(include_models)
     else:
         models = all_models
+
     if exclude_filters:
         if not isinstance(exclude_filters, (tuple, list)):
             exclude_filters = [exclude_filters]
@@ -89,23 +133,35 @@ def list_models(filter='', module='', pretrained=False, exclude_filters='', name
             exclude_models = fnmatch.filter(models, xf)  # exclude these models
             if len(exclude_models):
                 models = set(models).difference(exclude_models)
+
+    if include_tags:
+        # expand model names to include names w/ pretrained tags
+        models_with_tags = []
+        for m in models:
+            models_with_tags.extend(_model_with_tags[m])
+        models = models_with_tags
+
     if pretrained:
         models = _model_has_pretrained.intersection(models)
+
     if name_matches_cfg:
         models = set(_model_pretrained_cfgs).intersection(models)
+
     return list(sorted(models, key=_natural_key))


 def is_model(model_name):
     """ Check if a model name exists
     """
-    return model_name in _model_entrypoints
+    arch_name = get_arch_name(model_name)
+    return arch_name in _model_entrypoints


 def model_entrypoint(model_name):
     """Fetch a model entrypoint for specified model name
     """
-    return _model_entrypoints[model_name]
+    arch_name = get_arch_name(model_name)
+    return _model_entrypoints[arch_name]


 def list_modules():
@@ -121,8 +177,9 @@ def is_model_in_modules(model_name, module_names):
         model_name (str) - name of model to check
         module_names (tuple, list, set) - names of modules to search in
     """
+    arch_name = get_arch_name(model_name)
     assert isinstance(module_names, (tuple, list, set))
-    return any(model_name in _module_to_models[n] for n in module_names)
+    return any(arch_name in _module_to_models[n] for n in module_names)


 def is_model_pretrained(model_name):
@@ -132,28 +189,12 @@ def is_model_pretrained(model_name):

 def get_pretrained_cfg(model_name):
     if model_name in _model_pretrained_cfgs:
         return deepcopy(_model_pretrained_cfgs[model_name])
-    return {}
-
-
-def has_pretrained_cfg_key(model_name, cfg_key):
-    """ Query model default_cfgs for existence of a specific key.
-    """
-    if model_name in _model_pretrained_cfgs and cfg_key in _model_pretrained_cfgs[model_name]:
-        return True
-    return False
-
-
-def is_pretrained_cfg_key(model_name, cfg_key):
-    """ Return truthy value for specified model default_cfg key, False if does not exist.
-    """
-    if model_name in _model_pretrained_cfgs and _model_pretrained_cfgs[model_name].get(cfg_key, False):
-        return True
-    return False
+    raise RuntimeError(f'No pretrained config exists for model {model_name}.')


 def get_pretrained_cfg_value(model_name, cfg_key):
-    """ Get a specific model default_cfg value by key. None if it doesn't exist.
+    """ Get a specific model default_cfg value by key. None if key doesn't exist.
     """
     if model_name in _model_pretrained_cfgs:
-        return _model_pretrained_cfgs[model_name].get(cfg_key, None)
-    return None
+        return getattr(_model_pretrained_cfgs[model_name], cfg_key, None)
+    raise RuntimeError(f'No pretrained config exist for model {model_name}.')
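Sketch of the registry behaviour after this change (tagged names assume the ViT cfgs converted later in this commit):

import timm
from timm.models.registry import get_pretrained_cfg, get_pretrained_cfg_value

# With pretrained=True, include_tags defaults to True, so tagged names are listed,
# e.g. ['vit_tiny_patch16_224.in21ft1k', 'vit_tiny_patch16_384.in21ft1k', ...].
names = timm.list_models('vit_tiny_patch16_*', pretrained=True)

# Arch-only queries still work; the default (first) tag's cfg is returned.
cfg = get_pretrained_cfg('vit_tiny_patch16_224')
print(get_pretrained_cfg_value('vit_tiny_patch16_224', 'input_size'))  # (3, 224, 224)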

@@ -76,6 +76,9 @@ model_cfgs = dict(
     regnety_120=RegNetCfg(w0=168, wa=73.36, wm=2.37, group_size=112, depth=19, se_ratio=0.25),
     regnety_160=RegNetCfg(w0=200, wa=106.23, wm=2.48, group_size=112, depth=18, se_ratio=0.25),
     regnety_320=RegNetCfg(w0=232, wa=115.89, wm=2.53, group_size=232, depth=20, se_ratio=0.25),
+    regnety_640=RegNetCfg(w0=352, wa=147.48, wm=2.4, group_size=328, depth=20, se_ratio=0.25),
+    regnety_1280=RegNetCfg(w0=456, wa=160.83, wm=2.52, group_size=264, depth=27, se_ratio=0.25),
+    regnety_2560=RegNetCfg(w0=640, wa=124.47, wm=2.04, group_size=848, depth=27, se_ratio=0.25),

     # Experimental
     regnety_040s_gn=RegNetCfg(
@@ -150,7 +153,12 @@ default_cfgs = dict(
     regnety_160=_cfg(
         url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth',  # from Facebook DeiT GitHub repository
         crop_pct=1.0, test_input_size=(3, 288, 288)),
-    regnety_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'),
+    regnety_320=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'
+    ),
+    regnety_640=_cfg(url=''),
+    regnety_1280=_cfg(url=''),
+    regnety_2560=_cfg(url=''),

     regnety_040s_gn=_cfg(url=''),
     regnetv_040=_cfg(
@@ -508,6 +516,34 @@ def _init_weights(module, name='', zero_init_last=False):

 def _filter_fn(state_dict):
     """ convert patch embedding weight from manual patchify + linear proj to conv"""
+    if 'classy_state_dict' in state_dict:
+        import re
+        state_dict = state_dict['classy_state_dict']['base_model']['model']
+        out = {}
+        for k, v in state_dict['trunk'].items():
+            k = k.replace('_feature_blocks.conv1.stem.0', 'stem.conv')
+            k = k.replace('_feature_blocks.conv1.stem.1', 'stem.bn')
+            k = re.sub(
+                r'^_feature_blocks.res\d.block(\d)-(\d+)',
+                lambda x: f's{int(x.group(1))}.b{int(x.group(2)) + 1}', k)
+            k = re.sub(r's(\d)\.b(\d+)\.bn', r's\1.b\2.downsample.bn', k)
+            k = k.replace('proj', 'downsample.conv')
+            k = k.replace('f.a.0', 'conv1.conv')
+            k = k.replace('f.a.1', 'conv1.bn')
+            k = k.replace('f.b.0', 'conv2.conv')
+            k = k.replace('f.b.1', 'conv2.bn')
+            k = k.replace('f.c', 'conv3.conv')
+            k = k.replace('f.final_bn', 'conv3.bn')
+            k = k.replace('f.se.excitation.0', 'se.fc1')
+            k = k.replace('f.se.excitation.2', 'se.fc2')
+            out[k] = v
+        for k, v in state_dict['heads'].items():
+            if 'projection_head' in k or 'prototypes' in k:
+                continue
+            k = k.replace('0.clf.0', 'head.fc')
+            out[k] = v
+        return out
     if 'model' in state_dict:
         # For DeiT trained regnety_160 pretraiend model
         state_dict = state_dict['model']
@@ -666,6 +702,24 @@ def regnety_320(pretrained=False, **kwargs):
     return _create_regnet('regnety_320', pretrained, **kwargs)


+@register_model
+def regnety_640(pretrained=False, **kwargs):
+    """RegNetY-64GF"""
+    return _create_regnet('regnety_640', pretrained, **kwargs)
+
+
+@register_model
+def regnety_1280(pretrained=False, **kwargs):
+    """RegNetY-128GF"""
+    return _create_regnet('regnety_1280', pretrained, **kwargs)
+
+
+@register_model
+def regnety_2560(pretrained=False, **kwargs):
+    """RegNetY-256GF"""
+    return _create_regnet('regnety_2560', pretrained, **kwargs)
+
+
 @register_model
 def regnety_040s_gn(pretrained=False, **kwargs):
     """RegNetY-4.0GF w/ GroupNorm """

@@ -57,52 +57,52 @@ default_cfgs = {
     # pretrained on imagenet21k, finetuned on imagenet1k
     'resnetv2_50x1_bitm': _cfg(
         url='https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz',
-        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
+        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
     'resnetv2_50x3_bitm': _cfg(
         url='https://storage.googleapis.com/bit_models/BiT-M-R50x3-ILSVRC2012.npz',
-        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
+        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
     'resnetv2_101x1_bitm': _cfg(
         url='https://storage.googleapis.com/bit_models/BiT-M-R101x1-ILSVRC2012.npz',
-        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
+        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
     'resnetv2_101x3_bitm': _cfg(
         url='https://storage.googleapis.com/bit_models/BiT-M-R101x3-ILSVRC2012.npz',
-        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
+        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
     'resnetv2_152x2_bitm': _cfg(
         url='https://storage.googleapis.com/bit_models/BiT-M-R152x2-ILSVRC2012.npz',
-        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
+        input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, custom_load=True),
     'resnetv2_152x4_bitm': _cfg(
         url='https://storage.googleapis.com/bit_models/BiT-M-R152x4-ILSVRC2012.npz',
-        input_size=(3, 480, 480), pool_size=(15, 15), crop_pct=1.0),  # only one at 480x480?
+        input_size=(3, 480, 480), pool_size=(15, 15), crop_pct=1.0, custom_load=True),  # only one at 480x480?

     # trained on imagenet-21k
     'resnetv2_50x1_bitm_in21k': _cfg(
         url='https://storage.googleapis.com/bit_models/BiT-M-R50x1.npz',
-        num_classes=21843),
+        num_classes=21843, custom_load=True),
     'resnetv2_50x3_bitm_in21k': _cfg(
         url='https://storage.googleapis.com/bit_models/BiT-M-R50x3.npz',
-        num_classes=21843),
+        num_classes=21843, custom_load=True),
     'resnetv2_101x1_bitm_in21k': _cfg(
         url='https://storage.googleapis.com/bit_models/BiT-M-R101x1.npz',
-        num_classes=21843),
+        num_classes=21843, custom_load=True),
     'resnetv2_101x3_bitm_in21k': _cfg(
         url='https://storage.googleapis.com/bit_models/BiT-M-R101x3.npz',
-        num_classes=21843),
+        num_classes=21843, custom_load=True),
     'resnetv2_152x2_bitm_in21k': _cfg(
         url='https://storage.googleapis.com/bit_models/BiT-M-R152x2.npz',
-        num_classes=21843),
+        num_classes=21843, custom_load=True),
     'resnetv2_152x4_bitm_in21k': _cfg(
         url='https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz',
-        num_classes=21843),
+        num_classes=21843, custom_load=True),

     'resnetv2_50x1_bit_distilled': _cfg(
         url='https://storage.googleapis.com/bit_models/distill/R50x1_224.npz',
-        interpolation='bicubic'),
+        interpolation='bicubic', custom_load=True),
     'resnetv2_152x2_bit_teacher': _cfg(
         url='https://storage.googleapis.com/bit_models/distill/R152x2_T_224.npz',
-        interpolation='bicubic'),
+        interpolation='bicubic', custom_load=True),
     'resnetv2_152x2_bit_teacher_384': _cfg(
         url='https://storage.googleapis.com/bit_models/distill/R152x2_T_384.npz',
-        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, interpolation='bicubic'),
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, interpolation='bicubic', custom_load=True),

     'resnetv2_50': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetv2_50_a1h-000cdf49.pth',
@@ -507,8 +507,8 @@ def _create_resnetv2(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         ResNetV2, variant, pretrained,
         feature_cfg=feature_cfg,
-        pretrained_custom_load='_bit' in variant,
-        **kwargs)
+        **kwargs,
+    )


 def _create_resnetv2_bit(variant, pretrained=False, **kwargs):
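Sketch of what the custom_load move means for the BiT weights (uses the registry helper reworked earlier in this commit):

from timm.models.registry import get_pretrained_cfg_value

# The npz BiT checkpoints are flagged in their pretrained cfg, so build_model_with_cfg
# picks load_custom_pretrained() automatically instead of needing pretrained_custom_load.
assert get_pretrained_cfg_value('resnetv2_50x1_bitm', 'custom_load') is True
assert not get_pretrained_cfg_value('resnetv2_50', 'custom_load')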

@@ -34,6 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply, adapt_input_conv, checkpoint_seq
 from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
+from ._pretrained import generate_defaults
 from .registry import register_model

 _logger = logging.getLogger(__name__)
@@ -50,59 +51,50 @@ def _cfg(url='', **kwargs):
     }


-default_cfgs = {
+default_cfgs = generate_defaults({
     # patch models (weights from official Google JAX impl)
-    'vit_tiny_patch16_224': _cfg(
-        url='https://storage.googleapis.com/vit_models/augreg/'
-            'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
-    'vit_tiny_patch16_384': _cfg(
-        url='https://storage.googleapis.com/vit_models/augreg/'
-            'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
-        input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_small_patch32_224': _cfg(
-        url='https://storage.googleapis.com/vit_models/augreg/'
-            'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
-    'vit_small_patch32_384': _cfg(
-        url='https://storage.googleapis.com/vit_models/augreg/'
-            'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
-        input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_small_patch16_224': _cfg(
-        url='https://storage.googleapis.com/vit_models/augreg/'
-            'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
-    'vit_small_patch16_384': _cfg(
-        url='https://storage.googleapis.com/vit_models/augreg/'
-            'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
-        input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_base_patch32_224': _cfg(
-        url='https://storage.googleapis.com/vit_models/augreg/'
-            'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
-    'vit_base_patch32_384': _cfg(
-        url='https://storage.googleapis.com/vit_models/augreg/'
-            'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
-        input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_base_patch16_224': _cfg(
-        url='https://storage.googleapis.com/vit_models/augreg/'
-            'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
-    'vit_base_patch16_384': _cfg(
-        url='https://storage.googleapis.com/vit_models/augreg/'
-            'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
-        input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_base_patch8_224': _cfg(
-        url='https://storage.googleapis.com/vit_models/augreg/'
-            'B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
-    'vit_large_patch32_224': _cfg(
-        url='',  # no official model weights for this combo, only for in21k
-    ),
-    'vit_large_patch32_384': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
-        input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_large_patch16_224': _cfg(
-        url='https://storage.googleapis.com/vit_models/augreg/'
-            'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
-    'vit_large_patch16_384': _cfg(
-        url='https://storage.googleapis.com/vit_models/augreg/'
+    'vit_tiny_patch16_224.in21ft1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
+        custom_load=True),
+    'vit_tiny_patch16_384.in21ft1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
+        custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
+    'vit_small_patch32_224.in21ft1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
+        custom_load=True),
+    'vit_small_patch32_384.in21ft1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
+        custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
+    'vit_small_patch16_224.in21ft1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
+        custom_load=True),
+    'vit_small_patch16_384.in21ft1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
+        custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
+    'vit_base_patch32_224.in21ft1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
+        custom_load=True),
+    'vit_base_patch32_384.in21ft1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
+        custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
+    'vit_base_patch16_224.in21ft1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
+        custom_load=True),
+    'vit_base_patch16_384.in21ft1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
+        custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
+    'vit_base_patch8_224.in21ft1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
+        custom_load=True),
+    'vit_large_patch32_384.in21ft1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
+        input_size=(3, 384, 384), crop_pct=1.0),
+    'vit_large_patch16_224.in21ft1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
+        custom_load=True),
+    'vit_large_patch16_384.in21ft1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
input_size=(3, 384, 384), crop_pct=1.0),
'vit_large_patch14_224': _cfg(url=''), 'vit_large_patch14_224': _cfg(url=''),
'vit_huge_patch14_224': _cfg(url=''), 'vit_huge_patch14_224': _cfg(url=''),
@ -111,92 +103,128 @@ default_cfgs = {
# patch models, imagenet21k (weights from official Google JAX impl) # patch models, imagenet21k (weights from official Google JAX impl)
'vit_tiny_patch16_224_in21k': _cfg( 'vit_tiny_patch16_224.in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843), custom_load=True, num_classes=21843),
'vit_small_patch32_224_in21k': _cfg( 'vit_small_patch32_224.in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843), custom_load=True, num_classes=21843),
'vit_small_patch16_224_in21k': _cfg( 'vit_small_patch16_224.in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843), custom_load=True, num_classes=21843),
'vit_base_patch32_224_in21k': _cfg( 'vit_base_patch32_224.in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz', url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843), custom_load=True, num_classes=21843),
'vit_base_patch16_224_in21k': _cfg( 'vit_base_patch16_224.in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
num_classes=21843), custom_load=True, num_classes=21843),
'vit_base_patch8_224_in21k': _cfg( 'vit_base_patch8_224.in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
num_classes=21843), custom_load=True, num_classes=21843),
'vit_large_patch32_224_in21k': _cfg( 'vit_large_patch32_224.in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
num_classes=21843), num_classes=21843),
'vit_large_patch16_224_in21k': _cfg( 'vit_large_patch16_224.in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
num_classes=21843), custom_load=True, num_classes=21843),
'vit_huge_patch14_224_in21k': _cfg( 'vit_huge_patch14_224.in21k': _cfg(
url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
hf_hub_id='timm/vit_huge_patch14_224_in21k', hf_hub_id='timm/vit_huge_patch14_224_in21k',
num_classes=21843), custom_load=True, num_classes=21843),
# SAM trained models (https://arxiv.org/abs/2106.01548) # SAM trained models (https://arxiv.org/abs/2106.01548)
'vit_base_patch32_224_sam': _cfg( 'vit_base_patch32_224.sam': _cfg(
url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'), url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz', custom_load=True),
'vit_base_patch16_224_sam': _cfg( 'vit_base_patch16_224.sam': _cfg(
url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'), url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz', custom_load=True),
# DINO pretrained - https://arxiv.org/abs/2104.14294 (no classifier head, for fine-tune only) # DINO pretrained - https://arxiv.org/abs/2104.14294 (no classifier head, for fine-tune only)
'vit_small_patch16_224_dino': _cfg( 'vit_small_patch16_224.dino': _cfg(
url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth', url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
'vit_small_patch8_224_dino': _cfg( 'vit_small_patch8_224.dino': _cfg(
url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth', url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
'vit_base_patch16_224_dino': _cfg( 'vit_base_patch16_224.dino': _cfg(
url='https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth', url='https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
'vit_base_patch8_224_dino': _cfg( 'vit_base_patch8_224.dino': _cfg(
url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth', url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
# ViT ImageNet-21K-P pretraining by MILL # ViT ImageNet-21K-P pretraining by MILL
'vit_base_patch16_224_miil_in21k': _cfg( 'vit_base_patch16_224_miil.in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth',
mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221), mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221),
'vit_base_patch16_224_miil': _cfg( 'vit_base_patch16_224_miil.in21ft1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_1k_miil_84_4-2deb18e3.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_1k_miil_84_4-2deb18e3.pth',
mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear'), mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear'),
# custom timm variants
'vit_base_patch16_rpn_224': _cfg( 'vit_base_patch16_rpn_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth'), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth'),
'vit_medium_patch16_gap_240.in12k': _cfg(
# experimental (may be removed) url='',
'vit_base_patch32_plus_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95), input_size=(3, 240, 240), crop_pct=0.95, num_classes=11821),
'vit_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95), 'vit_medium_patch16_gap_256.in12ft1k': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),
'vit_small_patch16_36x1_224': _cfg(url=''), 'vit_medium_patch16_gap_384.in12ft1k': _cfg(url='', input_size=(3, 384, 384), crop_pct=0.95),
'vit_small_patch16_18x2_224': _cfg(url=''),
'vit_base_patch16_18x2_224': _cfg(url=''), # CLIP pretrained image tower and related fine-tuned weights
'vit_base_patch32_224_clip.laion2b': _cfg(
'vit_base_patch32_224_clip_laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-B-32-laion2B-s34B-b79K', hf_hub_id='laion/CLIP-ViT-B-32-laion2B-s34B-b79K',
hf_hub_filename='open_clip_pytorch_model.bin', hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
'vit_large_patch14_224_clip_laion2b': _cfg( 'vit_large_patch14_224_clip.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K', hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K',
hf_hub_filename='open_clip_pytorch_model.bin', hf_hub_filename='open_clip_pytorch_model.bin',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=768), mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=768),
'vit_huge_patch14_224_clip_laion2b': _cfg( 'vit_huge_patch14_224_clip.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K', hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
hf_hub_filename='open_clip_pytorch_model.bin', hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=1024), mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=1024),
'vit_giant_patch14_224_clip_laion2b': _cfg( 'vit_giant_patch14_224_clip.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K', hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K',
hf_hub_filename='open_clip_pytorch_model.bin', hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=1024), mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=1024),
} 'vit_base_patch32_224_clip.laion2b_ft_in1k': _cfg(
hf_hub_id='',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
'vit_large_patch14_224_clip.laion2b_ft_in1k': _cfg(
hf_hub_id='',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'vit_huge_patch14_224_clip.laion2b_ft_in1k': _cfg(
hf_hub_id='',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
'vit_base_patch32_224_clip.laion2b_ft_in12k_in1k': _cfg(
hf_hub_id='',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
'vit_large_patch14_224_clip.laion2b_ft_in12k_in1k': _cfg(
hf_hub_id='',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'vit_huge_patch14_224_clip.laion2b_ft_in12k_in1k': _cfg(
hf_hub_id='',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
'vit_base_patch32_224_clip.laion2b_ft_in12k': _cfg(
hf_hub_id='',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
'vit_large_patch14_224_clip.laion2b_ft_in12k': _cfg(
hf_hub_id='',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=11821),
'vit_huge_patch14_224_clip.laion2b_ft_in12k': _cfg(
hf_hub_id='',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
# experimental (may be removed)
'vit_base_patch32_plus_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),
'vit_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95),
'vit_small_patch16_36x1_224': _cfg(url=''),
'vit_small_patch16_18x2_224': _cfg(url=''),
'vit_base_patch16_18x2_224': _cfg(url=''),
})
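With default_cfgs keyed as 'model_name.tag', a specific pretrained weight set can be requested by its tagged name, while the bare entrypoint name is expected to keep resolving to a default tag for backwards compatibility. A hedged usage sketch under that assumption:

import timm

# Explicitly select the ImageNet-21k pretraining tag defined above.
model_in21k = timm.create_model('vit_tiny_patch16_224.in21k', pretrained=True)

# The untagged name should still work, falling back to the model's default tag.
model_default = timm.create_model('vit_tiny_patch16_224', pretrained=True)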
class Attention(nn.Module): class Attention(nn.Module):
@ -782,14 +810,11 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None): if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.') raise RuntimeError('features_only not implemented for Vision Transformer models.')
pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None)) return build_model_with_cfg(
model = build_model_with_cfg(
VisionTransformer, variant, pretrained, VisionTransformer, variant, pretrained,
pretrained_cfg=pretrained_cfg,
pretrained_filter_fn=checkpoint_filter_fn, pretrained_filter_fn=checkpoint_filter_fn,
pretrained_custom_load='npz' in pretrained_cfg['url'], **kwargs,
**kwargs) )
return model
@register_model @register_model
@ -831,7 +856,6 @@ def vit_small_patch32_384(pretrained=False, **kwargs):
@register_model @register_model
def vit_small_patch16_224(pretrained=False, **kwargs): def vit_small_patch16_224(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/16) """ ViT-Small (ViT-S/16)
NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
""" """
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
@ -841,13 +865,21 @@ def vit_small_patch16_224(pretrained=False, **kwargs):
@register_model @register_model
def vit_small_patch16_384(pretrained=False, **kwargs): def vit_small_patch16_384(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/16) """ ViT-Small (ViT-S/16)
NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
""" """
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs)
return model return model
@register_model
def vit_small_patch8_224(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/8)
"""
model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch8_224', pretrained=pretrained, **model_kwargs)
return model
@register_model @register_model
def vit_base_patch32_224(pretrained=False, **kwargs): def vit_base_patch32_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
@ -974,175 +1006,90 @@ def vit_gigantic_patch14_224(pretrained=False, **kwargs):
@register_model @register_model
def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs): def vit_base_patch16_224_miil(pretrained=False, **kwargs):
""" ViT-Tiny (Vit-Ti/16). """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
"""
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/16)
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
"""
model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/16)
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
""" """
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs)
return model return model
@register_model @register_model
def vit_base_patch32_224_in21k(pretrained=False, **kwargs): def vit_base_patch32_224_clip(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). """ ViT-B/32
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
""" """
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) model_kwargs = dict(
model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs) patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_base_patch32_224_clip', pretrained=pretrained, **model_kwargs)
return model return model
@register_model @register_model
def vit_base_patch16_224_in21k(pretrained=False, **kwargs): def vit_medium_patch16_gap_240(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). """ ViT-Base (ViT-M/16) w/o class token, w/ avg-pool @ 240x240
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
""" """
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model_kwargs = dict(
model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
model = _create_vision_transformer('vit_medium_patch16_gap_240', pretrained=pretrained, **model_kwargs)
return model return model
@register_model @register_model
def vit_base_patch8_224_in21k(pretrained=False, **kwargs): def vit_medium_patch16_gap_256(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929). """ ViT-Base (ViT-M/16) w/o class token, w/ avg-pool @ 256x256
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
""" """
model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) model_kwargs = dict(
model = _create_vision_transformer('vit_base_patch8_224_in21k', pretrained=pretrained, **model_kwargs) patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
model = _create_vision_transformer('vit_medium_patch16_gap_256', pretrained=pretrained, **model_kwargs)
return model return model
@register_model @register_model
def vit_large_patch32_224_in21k(pretrained=False, **kwargs): def vit_medium_patch16_gap_384(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). """ ViT-Base (ViT-M/16) w/o class token, w/ avg-pool @ 384x384
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
""" """
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) model_kwargs = dict(
model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs) patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
model = _create_vision_transformer('vit_medium_patch16_gap_384', pretrained=pretrained, **model_kwargs)
return model return model
@register_model @register_model
def vit_large_patch16_224_in21k(pretrained=False, **kwargs): def vit_large_patch14_224_clip(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). """ ViT-Large model (ViT-L/14)
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
""" """
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) model_kwargs = dict(
model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_large_patch14_224_clip', pretrained=pretrained, **model_kwargs)
return model return model
@register_model @register_model
def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): def vit_huge_patch14_224_clip(pretrained=False, **kwargs):
""" ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
"""
model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch16_224_sam(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224_sam', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch32_224_sam(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
"""
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch32_224_sam', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_patch16_224_dino(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/16) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch16_224_dino', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_patch8_224_dino(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/8) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294
"""
model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch8_224_dino', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch16_224_dino(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) /w DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224_dino', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch8_224_dino(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/8) w/ DINO pretrained weights (no head) - https://arxiv.org/abs/2104.14294
"""
model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch8_224_dino', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch16_224_miil_in21k(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
""" """
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) model_kwargs = dict(
model = _create_vision_transformer('vit_base_patch16_224_miil_in21k', pretrained=pretrained, **model_kwargs) patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_huge_patch14_224_clip', pretrained=pretrained, **model_kwargs)
return model return model
@register_model @register_model
def vit_base_patch16_224_miil(pretrained=False, **kwargs): def vit_giant_patch14_224_clip(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
""" """
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) model_kwargs = dict(
model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs) patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16,
pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_giant_patch14_224_clip', pretrained=pretrained, **model_kwargs)
return model return model
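The dedicated *_in21k, *_sam, *_dino and *_miil_in21k entrypoints removed in this hunk are not dropped weights: they are assumed to reappear as tags on the corresponding base entrypoints (e.g. 'vit_base_patch16_224.sam'). A small check, assuming tagged names are enumerated once the multi-weight registry is in place:

import timm

# Hypothetical check: list the weight variants now attached to a single base entrypoint.
vit_base_weights = [m for m in timm.list_models(pretrained=True) if m.startswith('vit_base_patch16_224')]
print(vit_base_weights)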
@ -1211,46 +1158,3 @@ def vit_base_patch16_18x2_224(pretrained=False, **kwargs):
return model return model
@register_model
def vit_base_patch32_224_clip_laion2b(pretrained=False, **kwargs):
""" ViT-B/32
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
"""
model_kwargs = dict(
patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_base_patch32_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_large_patch14_224_clip_laion2b(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/14)
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_large_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_huge_patch14_224_clip_laion2b(pretrained=False, **kwargs):
""" ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_huge_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_giant_patch14_224_clip_laion2b(pretrained=False, **kwargs):
""" ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
Pretrained weights from CLIP image tower trained on LAION-2B image-text pairs.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16,
pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_giant_patch14_224_clip_laion2b', pretrained=pretrained, **model_kwargs)
return model

@ -20,6 +20,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from ._pretrained import generate_defaults
from .layers import StdConv2dSame, StdConv2d, to_2tuple from .layers import StdConv2dSame, StdConv2d, to_2tuple
from .resnet import resnet26d, resnet50d from .resnet import resnet26d, resnet50d
from .resnetv2 import ResNetV2, create_resnetv2_stem from .resnetv2 import ResNetV2, create_resnetv2_stem
@ -38,52 +39,49 @@ def _cfg(url='', **kwargs):
} }
default_cfgs = { default_cfgs = generate_defaults({
# hybrid in-1k models (weights from official JAX impl where they exist) # hybrid in-1k models (weights from official JAX impl where they exist)
'vit_tiny_r_s16_p8_224': _cfg( 'vit_tiny_r_s16_p8_224.in21ft1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/' url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', custom_load=True,
first_conv='patch_embed.backbone.conv'), first_conv='patch_embed.backbone.conv'),
'vit_tiny_r_s16_p8_384': _cfg( 'vit_tiny_r_s16_p8_384.in21ft1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/' url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0, custom_load=True),
first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0), 'vit_small_r26_s32_224.in21ft1k': _cfg(
'vit_small_r26_s32_224': _cfg( url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz',
url='https://storage.googleapis.com/vit_models/augreg/' custom_load=True,
'R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz',
), ),
'vit_small_r26_s32_384': _cfg( 'vit_small_r26_s32_384.in21ft1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/' url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
'R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', input_size=(3, 384, 384), crop_pct=1.0, custom_load=True),
input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_r26_s32_224': _cfg(), 'vit_base_r26_s32_224': _cfg(),
'vit_base_r50_s16_224': _cfg(), 'vit_base_r50_s16_224': _cfg(),
'vit_base_r50_s16_384': _cfg( 'vit_base_r50_s16_384.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
input_size=(3, 384, 384), crop_pct=1.0), input_size=(3, 384, 384), crop_pct=1.0),
'vit_large_r50_s32_224': _cfg( 'vit_large_r50_s32_224.in21ft1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/' url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
'R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz' custom_load=True,
), ),
'vit_large_r50_s32_384': _cfg( 'vit_large_r50_s32_384.in21ft1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/' url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
'R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', input_size=(3, 384, 384), crop_pct=1.0, custom_load=True,
input_size=(3, 384, 384), crop_pct=1.0
), ),
# hybrid in-21k models (weights from official Google JAX impl where they exist) # hybrid in-21k models (weights from official Google JAX impl where they exist)
'vit_tiny_r_s16_p8_224_in21k': _cfg( 'vit_tiny_r_s16_p8_224.in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv'), num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv', custom_load=True),
'vit_small_r26_s32_224_in21k': _cfg( 'vit_small_r26_s32_224.in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0.npz', url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843, crop_pct=0.9), num_classes=21843, crop_pct=0.9, custom_load=True),
'vit_base_r50_s16_224_in21k': _cfg( 'vit_base_r50_s16_224.in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
num_classes=21843, crop_pct=0.9), num_classes=21843, crop_pct=0.9),
'vit_large_r50_s32_224_in21k': _cfg( 'vit_large_r50_s32_224.in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz', url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz',
num_classes=21843, crop_pct=0.9), num_classes=21843, crop_pct=0.9, custom_load=True),
# hybrid models (using timm resnet backbones) # hybrid models (using timm resnet backbones)
'vit_small_resnet26d_224': _cfg( 'vit_small_resnet26d_224': _cfg(
@ -94,7 +92,7 @@ default_cfgs = {
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
'vit_base_resnet50d_224': _cfg( 'vit_base_resnet50d_224': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
} })
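Note that the hybrid .in21k tags carry num_classes=21843 in their cfgs while the .in21ft1k tags keep the 1000-class default, so the selected tag is assumed to drive the classifier head size when a model is built:

import timm

# Assumed behaviour: the tag's cfg determines the default head size.
m_21k = timm.create_model('vit_tiny_r_s16_p8_224.in21k', pretrained=True)
m_1k = timm.create_model('vit_tiny_r_s16_p8_224.in21ft1k', pretrained=True)
print(m_21k.num_classes, m_1k.num_classes)  # expected: 21843 1000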
class HybridEmbed(nn.Module): class HybridEmbed(nn.Module):
@ -248,12 +246,6 @@ def vit_base_r50_s16_384(pretrained=False, **kwargs):
return model return model
@register_model
def vit_base_resnet50_384(pretrained=False, **kwargs):
# DEPRECATED this is forwarding to model def above for backwards compatibility
return vit_base_r50_s16_384(pretrained=pretrained, **kwargs)
@register_model @register_model
def vit_large_r50_s32_224(pretrained=False, **kwargs): def vit_large_r50_s32_224(pretrained=False, **kwargs):
""" R50+ViT-L/S32 hybrid. """ R50+ViT-L/S32 hybrid.
@ -276,57 +268,6 @@ def vit_large_r50_s32_384(pretrained=False, **kwargs):
return model return model
@register_model
def vit_tiny_r_s16_p8_224_in21k(pretrained=False, **kwargs):
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid. ImageNet-21k.
"""
backbone = _resnetv2(layers=(), **kwargs)
model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_tiny_r_s16_p8_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs):
""" R26+ViT-S/S32 hybrid. ImageNet-21k.
"""
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_small_r26_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
""" R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
"""
backbone = _resnetv2(layers=(3, 4, 9), **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
# DEPRECATED this is forwarding to model def above for backwards compatibility
return vit_base_r50_s16_224_in21k(pretrained=pretrained, **kwargs)
@register_model
def vit_large_r50_s32_224_in21k(pretrained=False, **kwargs):
""" R50+ViT-L/S32 hybrid. ImageNet-21k.
"""
backbone = _resnetv2((3, 4, 6, 3), **kwargs)
model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_large_r50_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model @register_model
def vit_small_resnet26d_224(pretrained=False, **kwargs): def vit_small_resnet26d_224(pretrained=False, **kwargs):
""" Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights. """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.
