Transitioning default_cfg -> pretrained_cfg. Improving handling of pretrained_cfg source (HF-Hub, files, timm config, etc). Checkpoint handling tweaks.

pull/1014/head
Ross Wightman 3 years ago
parent de5fa791c6
commit abc9ba2544
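Note (added for this writeup, not part of the commit): the practical effect for downstream code is that the resolved pretrained configuration now lives on `model.pretrained_cfg`, with `model.default_cfg` kept as a backwards-compat alias. A minimal sketch, assuming the timm API as of this change:

```python
import timm

# Plain model names resolve against timm's built-in registry as before.
model = timm.create_model('resnet50', pretrained=False)

# The resolved config is attached as `pretrained_cfg`...
print(model.pretrained_cfg['architecture'])   # 'resnet50'
print(model.pretrained_cfg.get('url', ''))    # pretrained weight location, if any

# ...and `default_cfg` remains an alias for backwards compatibility.
assert model.default_cfg is model.pretrained_cfg
```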

@@ -49,7 +49,7 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
             # If all aux_bn keys are removed, the SplitBN layers will end up as normal and
             # load with the unmodified model using BatchNorm2d.
             continue
-        name = k[7:] if k.startswith('module') else k
+        name = k[7:] if k.startswith('module.') else k
         new_state_dict[name] = v
     print("=> Loaded state_dict from '{}'".format(checkpoint))

@@ -11,8 +11,8 @@ except ImportError:
     has_fx_feature_extraction = False

 import timm
-from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \
-    get_model_default_value
+from timm import list_models, create_model, set_scriptable, has_pretrained_cfg_key, is_pretrained_cfg_key, \
+    get_pretrained_cfg_value
 from timm.models.fx_features import _leaf_modules, _autowrap_functions

 if hasattr(torch._C, '_jit_set_profiling_executor'):

@@ -54,9 +54,9 @@ MAX_BWD_FX_SIZE = 224
 def _get_input_size(model=None, model_name='', target=None):
     if model is None:
         assert model_name, "One of model or model_name must be provided"
-        input_size = get_model_default_value(model_name, 'input_size')
-        fixed_input_size = get_model_default_value(model_name, 'fixed_input_size')
-        min_input_size = get_model_default_value(model_name, 'min_input_size')
+        input_size = get_pretrained_cfg_value(model_name, 'input_size')
+        fixed_input_size = get_pretrained_cfg_value(model_name, 'fixed_input_size')
+        min_input_size = get_pretrained_cfg_value(model_name, 'min_input_size')
     else:
         default_cfg = model.default_cfg
         input_size = default_cfg['input_size']

@@ -1,4 +1,4 @@
 from .version import __version__
 from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
-    is_scriptable, is_exportable, set_scriptable, set_exportable, has_model_default_key, is_model_default_key, \
-    get_model_default_value, is_model_pretrained
+    is_scriptable, is_exportable, set_scriptable, set_exportable, has_pretrained_cfg_key, is_pretrained_cfg_key, \
+    get_pretrained_cfg_value, is_model_pretrained

@@ -49,10 +49,10 @@ from .xception import *
 from .xception_aligned import *
 from .xcit import *
-from .factory import create_model, split_model_name, safe_model_name
+from .factory import create_model, parse_model_name, safe_model_name
 from .helpers import load_checkpoint, resume_checkpoint, model_parameters
 from .layers import TestTimePoolHead, apply_test_time_pool
 from .layers import convert_splitbn_model
 from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
 from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
-    has_model_default_key, is_model_default_key, get_model_default_value, is_model_pretrained
+    is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value

@@ -339,14 +339,12 @@ class Beit(nn.Module):
         return x


-def _create_beit(variant, pretrained=False, default_cfg=None, **kwargs):
-    default_cfg = default_cfg or default_cfgs[variant]
+def _create_beit(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Beit models.')

     model = build_model_with_cfg(
         Beit, variant, pretrained,
-        default_cfg=default_cfg,
         # FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes
         pretrained_filter_fn=checkpoint_filter_fn,
         **kwargs)

@@ -327,7 +327,6 @@ model_cfgs = dict(
 def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs):
     return build_model_with_cfg(
         ByobNet, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
         feature_cfg=dict(flatten_sequential=True),
         **kwargs)

@@ -1553,7 +1553,6 @@ def _init_weights(module, name='', zero_init_last=False):
 def _create_byobnet(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         ByobNet, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         model_cfg=model_cfgs[variant],
         feature_cfg=dict(flatten_sequential=True),
         **kwargs)

@@ -14,7 +14,7 @@ import torch.nn as nn
 from functools import partial

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .helpers import build_model_with_cfg, overlay_external_default_cfg
+from .helpers import build_model_with_cfg
 from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_
 from .registry import register_model

@@ -318,7 +318,6 @@ def _create_cait(variant, pretrained=False, **kwargs):
     model = build_model_with_cfg(
         Cait, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         pretrained_filter_fn=checkpoint_filter_fn,
         **kwargs)
     return model

@@ -16,7 +16,7 @@ import torch.nn as nn
 import torch.nn.functional as F

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .helpers import build_model_with_cfg, overlay_external_default_cfg
+from .helpers import build_model_with_cfg
 from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
 from .registry import register_model
 from .layers import _assert

@@ -610,7 +610,6 @@ def _create_coat(variant, pretrained=False, default_cfg=None, **kwargs):
     model = build_model_with_cfg(
         CoaT, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         pretrained_filter_fn=checkpoint_filter_fn,
         **kwargs)
     return model

@@ -318,10 +318,7 @@ def _create_convit(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Vision Transformer models.')

-    return build_model_with_cfg(
-        ConViT, variant, pretrained,
-        default_cfg=default_cfgs[variant],
-        **kwargs)
+    return build_model_with_cfg(ConViT, variant, pretrained, **kwargs)


 @register_model

@@ -80,7 +80,7 @@ class ConvMixer(nn.Module):
 def _create_convmixer(variant, pretrained=False, **kwargs):
-    return build_model_with_cfg(ConvMixer, variant, pretrained, default_cfg=default_cfgs[variant], **kwargs)
+    return build_model_with_cfg(ConvMixer, variant, pretrained, **kwargs)


 @register_model

@@ -413,7 +413,6 @@ def _create_crossvit(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         CrossViT, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         pretrained_filter_fn=pretrained_filter_fn,
         **kwargs)

@@ -413,7 +413,6 @@ def _create_cspnet(variant, pretrained=False, **kwargs):
     cfg_variant = variant.split('_')[0]
     return build_model_with_cfg(
         CspNet, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         feature_cfg=dict(flatten_sequential=True), model_cfg=model_cfgs[cfg_variant],
         **kwargs)

@@ -288,7 +288,6 @@ def _create_densenet(variant, growth_rate, block_config, pretrained, **kwargs):
     kwargs['block_config'] = block_config
     return build_model_with_cfg(
         DenseNet, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         feature_cfg=dict(flatten_sequential=True), pretrained_filter_fn=_filter_torchvision_pretrained,
         **kwargs)

@@ -341,7 +341,6 @@ class DLA(nn.Module):
 def _create_dla(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         DLA, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         pretrained_strict=False,
         feature_cfg=dict(out_indices=(1, 2, 3, 4, 5)),
         **kwargs)

@@ -264,7 +264,6 @@ class DPN(nn.Module):
 def _create_dpn(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         DPN, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         feature_cfg=dict(feature_concat=True, flatten_sequential=True),
         **kwargs)

@@ -48,7 +48,7 @@ from .efficientnet_blocks import SqueezeExcite
 from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\
     round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from .features import FeatureInfo, FeatureHooks
-from .helpers import build_model_with_cfg, default_cfg_for_features
+from .helpers import build_model_with_cfg, pretrained_cfg_for_features
 from .layers import create_conv2d, create_classifier, get_norm_act_layer, EvoNorm2dS0, GroupNormAct
 from .registry import register_model

@@ -599,12 +599,11 @@ def _create_effnet(variant, pretrained=False, **kwargs):
         model_cls = EfficientNetFeatures
     model = build_model_with_cfg(
         model_cls, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         pretrained_strict=not features_only,
         kwargs_filter=kwargs_filter,
         **kwargs)
     if features_only:
-        model.default_cfg = default_cfg_for_features(model.default_cfg)
+        model.default_cfg = pretrained_cfg_for_features(model.default_cfg)
     return model

@@ -1475,7 +1474,7 @@ def efficientnet_b0_g16_evos(pretrained=False, **kwargs):
     """ EfficientNet-B0 w/ group 16 conv + EvoNorm"""
     model = _gen_efficientnet(
         'efficientnet_b0_g16_evos', group_size=16, channel_divisor=16,
-        norm_layer=partial(EvoNorm2dS0, group_size=16), pretrained=pretrained, **kwargs)
+        pretrained=pretrained, **kwargs)  #norm_layer=partial(EvoNorm2dS0, group_size=16),
     return model

@@ -1,30 +1,36 @@
+from urllib.parse import urlsplit, urlunsplit
+import os
+
 from .registry import is_model, is_model_in_modules, model_entrypoint
 from .helpers import load_checkpoint
 from .layers import set_layer_config
 from .hub import load_model_config_from_hf


-def split_model_name(model_name):
-    model_split = model_name.split(':', 1)
-    if len(model_split) == 1:
-        return '', model_split[0]
+def parse_model_name(model_name):
+    model_name = model_name.replace('hf_hub', 'hf-hub')  # NOTE for backwards compat, to deprecate hf_hub use
+    parsed = urlsplit(model_name)
+    assert parsed.scheme in ('', 'timm', 'hf-hub')
+    if parsed.scheme == 'hf-hub':
+        # FIXME may use fragment as revision, currently `@` in URI path
+        return parsed.scheme, parsed.path
     else:
-        source_name, model_name = model_split
-        assert source_name in ('timm', 'hf_hub')
-        return source_name, model_name
+        model_name = os.path.split(parsed.path)[-1]
+        return 'timm', model_name


 def safe_model_name(model_name, remove_source=True):
     def make_safe(name):
         return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')
     if remove_source:
-        model_name = split_model_name(model_name)[-1]
+        model_name = parse_model_name(model_name)[-1]
     return make_safe(model_name)


 def create_model(
         model_name,
         pretrained=False,
+        pretrained_cfg=None,
         checkpoint_path='',
         scriptable=None,
         exportable=None,

@@ -45,33 +51,24 @@ def create_model(
         global_pool (str): global pool type (default: 'avg')
         **: other kwargs are model specific
     """
-    source_name, model_name = split_model_name(model_name)
-
-    # handle backwards compat with drop_connect -> drop_path change
-    drop_connect_rate = kwargs.pop('drop_connect_rate', None)
-    if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None:
-        print("WARNING: 'drop_connect' as an argument is deprecated, please use 'drop_path'."
-              " Setting drop_path to %f." % drop_connect_rate)
-        kwargs['drop_path_rate'] = drop_connect_rate
-
     # Parameters that aren't supported by all models or are intended to only override model defaults if set
     # should default to None in command line args/cfg. Remove them if they are present and not set so that
     # non-supporting models don't break and default args remain in effect.
     kwargs = {k: v for k, v in kwargs.items() if v is not None}

-    if source_name == 'hf_hub':
-        # For model names specified in the form `hf_hub:path/architecture_name#revision`,
-        # load model weights + default_cfg from Hugging Face hub.
-        hf_default_cfg, model_name = load_model_config_from_hf(model_name)
-        kwargs['external_default_cfg'] = hf_default_cfg  # FIXME revamp default_cfg interface someday
+    model_source, model_name = parse_model_name(model_name)
+    if model_source == 'hf-hub':
+        # FIXME hf-hub source overrides any passed in pretrained_cfg, warn?
+        # For model names specified in the form `hf-hub:path/architecture_name@revision`,
+        # load model weights + pretrained_cfg from Hugging Face hub.
+        pretrained_cfg, model_name = load_model_config_from_hf(model_name)

-    if is_model(model_name):
-        create_fn = model_entrypoint(model_name)
-    else:
+    if not is_model(model_name):
         raise RuntimeError('Unknown model (%s)' % model_name)
+    create_fn = model_entrypoint(model_name)

     with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
-        model = create_fn(pretrained=pretrained, **kwargs)
+        model = create_fn(pretrained=pretrained, pretrained_cfg=pretrained_cfg, **kwargs)

     if checkpoint_path:
         load_checkpoint(model, checkpoint_path)
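For reference (not part of the diff), how the new urlsplit-based name parsing behaves, per the factory.py hunk above:

```python
from timm.models import parse_model_name

# Plain names fall through to the timm registry.
print(parse_model_name('resnet50'))                   # ('timm', 'resnet50')

# The URI-style scheme selects the Hugging Face hub as the config/weight source.
print(parse_model_name('hf-hub:timm/eca_nfnet_l0'))   # ('hf-hub', 'timm/eca_nfnet_l0')

# The legacy 'hf_hub:' prefix is rewritten to 'hf-hub' for backwards compatibility.
print(parse_model_name('hf_hub:timm/eca_nfnet_l0'))   # ('hf-hub', 'timm/eca_nfnet_l0')
```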

@@ -1,13 +1,15 @@
 """ PyTorch FX Based Feature Extraction Helpers
 Using https://pytorch.org/vision/stable/feature_extraction.html
 """
-from typing import Callable
+from typing import Callable, List, Dict, Union
+
+import torch
 from torch import nn

 from .features import _get_feature_info

 try:
-    from torchvision.models.feature_extraction import create_feature_extractor
+    from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
     has_fx_feature_extraction = True
 except ImportError:
     has_fx_feature_extraction = False

@@ -61,18 +63,52 @@ def register_notrace_function(func: Callable):
     return func


+def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
+    assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
+    return _create_feature_extractor(
+        model, return_nodes,
+        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}
+    )
+
+
 class FeatureGraphNet(nn.Module):
+    """ A FX Graph based feature extractor that works with the model feature_info metadata
+    """
     def __init__(self, model, out_indices, out_map=None):
         super().__init__()
         assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
         self.feature_info = _get_feature_info(model, out_indices)
         if out_map is not None:
             assert len(out_map) == len(out_indices)
-        return_nodes = {info['module']: out_map[i] if out_map is not None else info['module']
-                        for i, info in enumerate(self.feature_info) if i in out_indices}
-        self.graph_module = create_feature_extractor(
-            model, return_nodes,
-            tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
+        return_nodes = {
+            info['module']: out_map[i] if out_map is not None else info['module']
+            for i, info in enumerate(self.feature_info) if i in out_indices}
+        self.graph_module = create_feature_extractor(model, return_nodes)

     def forward(self, x):
         return list(self.graph_module(x).values())
+
+
+class FeatureExtractNet(nn.Module):
+    """ A standalone feature extraction wrapper that maps dict -> list or single tensor
+    NOTE:
+      * one can use feature_extractor directly if dictionary output is desired
+      * unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info
+        metadata for builtin feature extraction mode
+      * feature_extractor can be used directly if dictionary output is acceptable
+
+    Args:
+        model: model to extract features from
+        return_nodes: node names to return features from (dict or list)
+        squeeze_out: if only one output, and output in list format, flatten to single tensor
+    """
+    def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True):
+        super().__init__()
+        self.squeeze_out = squeeze_out
+        self.graph_module = create_feature_extractor(model, return_nodes)
+
+    def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
+        out = list(self.graph_module(x).values())
+        if self.squeeze_out and len(out) == 1:
+            return out[0]
+        return out

@@ -250,7 +250,6 @@ def _create_ghostnet(variant, width=1.0, pretrained=False, **kwargs):
     )
     return build_model_with_cfg(
         GhostNet, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         feature_cfg=dict(flatten_sequential=True),
         **model_kwargs)

@@ -58,10 +58,7 @@ default_cfgs = {
 def _create_resnet(variant, pretrained=False, **kwargs):
-    return build_model_with_cfg(
-        ResNet, variant, pretrained,
-        default_cfg=default_cfgs[variant],
-        **kwargs)
+    return build_model_with_cfg(ResNet, variant, pretrained, **kwargs)


 @register_model

@@ -234,7 +234,6 @@ class Xception65(nn.Module):
 def _create_gluon_xception(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         Xception65, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         feature_cfg=dict(feature_cls='hook'),
         **kwargs)

@@ -5,7 +5,7 @@ import torch.nn as nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .efficientnet_blocks import SqueezeExcite
 from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args, round_channels
-from .helpers import build_model_with_cfg, default_cfg_for_features
+from .helpers import build_model_with_cfg, pretrained_cfg_for_features
 from .layers import get_act_fn
 from .mobilenetv3 import MobileNetV3, MobileNetV3Features
 from .registry import register_model

@@ -59,12 +59,11 @@ def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):
         model_cls = MobileNetV3Features
     model = build_model_with_cfg(
         model_cls, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         pretrained_strict=not features_only,
         kwargs_filter=kwargs_filter,
         **model_kwargs)
     if features_only:
-        model.default_cfg = default_cfg_for_features(model.default_cfg)
+        model.default_cfg = pretrained_cfg_for_features(model.default_cfg)
     return model

@@ -7,7 +7,7 @@ import os
 import math
 from collections import OrderedDict
 from copy import deepcopy
-from typing import Any, Callable, Optional, Tuple
+from typing import Any, Callable, Optional, Tuple, Dict

 import torch
 import torch.nn as nn

@@ -17,12 +17,28 @@ from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
 from .fx_features import FeatureGraphNet
 from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
 from .layers import Conv2dSame, Linear
+from .registry import get_pretrained_cfg

 _logger = logging.getLogger(__name__)


-def load_state_dict(checkpoint_path, use_ema=False):
+# Global variables for rarely used pretrained checkpoint download progress and hash check.
+# Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle.
+_DOWNLOAD_PROGRESS = False
+_CHECK_HASH = False
+
+
+def clean_state_dict(state_dict):
+    # 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training
+    cleaned_state_dict = OrderedDict()
+    for k, v in state_dict.items():
+        name = k[7:] if k.startswith('module.') else k
+        cleaned_state_dict[name] = v
+    return cleaned_state_dict
+
+
+def load_state_dict(checkpoint_path, use_ema=True):
     if checkpoint_path and os.path.isfile(checkpoint_path):
         checkpoint = torch.load(checkpoint_path, map_location='cpu')
         state_dict_key = ''
@@ -35,16 +51,7 @@ def load_state_dict(checkpoint_path, use_ema=False):
                 state_dict_key = 'state_dict'
             elif 'model' in checkpoint:
                 state_dict_key = 'model'
-        if state_dict_key:
-            state_dict = checkpoint[state_dict_key]
-            new_state_dict = OrderedDict()
-            for k, v in state_dict.items():
-                # strip `module.` prefix
-                name = k[7:] if k.startswith('module') else k
-                new_state_dict[name] = v
-            state_dict = new_state_dict
-        else:
-            state_dict = checkpoint
+        state_dict = clean_state_dict(checkpoint[state_dict_key] if state_dict_key else checkpoint)
         _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
         return state_dict
     else:

@@ -52,7 +59,7 @@ def load_state_dict(checkpoint_path, use_ema=False):
         raise FileNotFoundError()


-def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
+def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True):
     if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
         # numpy checkpoint, try to load via model specific load_pretrained fn
         if hasattr(model, 'load_pretrained'):

@@ -71,11 +78,8 @@ def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None,
         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
             if log_info:
                 _logger.info('Restoring model state from checkpoint...')
-            new_state_dict = OrderedDict()
-            for k, v in checkpoint['state_dict'].items():
-                name = k[7:] if k.startswith('module') else k
-                new_state_dict[name] = v
-            model.load_state_dict(new_state_dict)
+            state_dict = clean_state_dict(checkpoint['state_dict'])
+            model.load_state_dict(state_dict)

             if optimizer is not None and 'optimizer' in checkpoint:
                 if log_info:
@@ -104,7 +108,50 @@ def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None,
         raise FileNotFoundError()


-def load_custom_pretrained(model, default_cfg=None, load_fn=None, progress=False, check_hash=False):
+def _resolve_pretrained_source(pretrained_cfg):
+    cfg_source = pretrained_cfg.get('source', '')
+    pretrained_url = pretrained_cfg.get('url', None)
+    pretrained_file = pretrained_cfg.get('file', None)
+    hf_hub_id = pretrained_cfg.get('hf_hub_id', None)
+    # resolve where to load pretrained weights from
+    load_from = ''
+    pretrained_loc = ''
+    if cfg_source == 'hf-hub' and has_hf_hub(necessary=True):
+        # hf-hub specified as source via model identifier
+        load_from = 'hf-hub'
+        assert hf_hub_id
+    else:
+        # default source == timm or unspecified
+        if pretrained_file:
+            load_from = 'file'
+            pretrained_loc = pretrained_file
+        elif pretrained_url:
+            load_from = 'url'
+            pretrained_loc = pretrained_url
+        elif hf_hub_id and has_hf_hub(necessary=False):
+            # hf-hub available as alternate weight source in default_cfg
+            load_from = 'hf-hub'
+            pretrained_loc = hf_hub_id
+    return load_from, pretrained_loc
+
+
+def set_pretrained_download_progress(enable=True):
+    """ Set download progress for pretrained weights on/off (globally). """
+    global _DOWNLOAD_PROGRESS
+    _DOWNLOAD_PROGRESS = enable
+
+
+def set_pretrained_check_hash(enable=True):
+    """ Set hash checking for pretrained weights on/off (globally). """
+    global _CHECK_HASH
+    _CHECK_HASH = enable
+
+
+def load_custom_pretrained(
+        model: nn.Module,
+        pretrained_cfg: Optional[Dict] = None,
+        load_fn: Optional[Callable] = None,
+):
     r"""Loads a custom (read non .pth) weight file

     Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls
@@ -116,7 +163,7 @@ def load_custom_pretrained(model, default_cfg=None, load_fn=None, progress=False
     Args:
         model: The instantiated model to load weights into
-        default_cfg (dict): Default pretrained model cfg
+        pretrained_cfg (dict): Default pretrained model cfg
         load_fn: An external stand alone fn that loads weights into provided model, otherwise a fn named
             'laod_pretrained' on the model will be called if it exists
         progress (bool, optional): whether or not to display a progress bar to stderr. Default: False

@@ -125,17 +172,20 @@ def load_custom_pretrained(model, default_cfg=None, load_fn=None, progress=False
         digits of the SHA256 hash of the contents of the file. The hash is used to
         ensure unique names and to verify the contents of the file. Default: False
     """
-    default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {}
-    pretrained_url = default_cfg.get('url', None)
-    if not pretrained_url:
+    pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) or {}
+    load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
+    if not load_from:
         _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return
-    cached_file = download_cached_file(default_cfg['url'], check_hash=check_hash, progress=progress)
+    if load_from == 'hf-hub':  # FIXME
+        _logger.warning("Hugging Face hub not currently supported for custom load pretrained models.")
+    elif load_from == 'url':
+        pretrained_loc = download_cached_file(pretrained_loc, check_hash=_CHECK_HASH, progress=_DOWNLOAD_PROGRESS)

     if load_fn is not None:
-        load_fn(model, cached_file)
+        load_fn(model, pretrained_loc)
     elif hasattr(model, 'load_pretrained'):
-        model.load_pretrained(cached_file)
+        model.load_pretrained(pretrained_loc)
     else:
         _logger.warning("Valid function to load pretrained weights is not available, using random initialization.")
@@ -165,31 +215,41 @@ def adapt_input_conv(in_chans, conv_weight):
     return conv_weight


-def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
+def load_pretrained(
+        model: nn.Module,
+        pretrained_cfg: Optional[Dict] = None,
+        num_classes: int = 1000,
+        in_chans: int = 3,
+        filter_fn: Optional[Callable] = None,
+        strict: bool = True,
+):
     """ Load pretrained checkpoint

     Args:
         model (nn.Module) : PyTorch model module
-        default_cfg (Optional[Dict]): default configuration for pretrained weights / target dataset
+        pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset
         num_classes (int): num_classes for model
         in_chans (int): in_chans for model
         filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
         strict (bool): strict load of checkpoint
-        progress (bool): enable progress bar for weight download
     """
-    default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {}
-    pretrained_url = default_cfg.get('url', None)
-    hf_hub_id = default_cfg.get('hf_hub', None)
-    if not pretrained_url and not hf_hub_id:
-        _logger.warning("No pretrained weights exist for this model. Using random initialization.")
+    pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) or {}
+    load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
+    if load_from == 'file':
+        _logger.info(f'Loading pretrained weights from file ({pretrained_loc})')
+        state_dict = load_state_dict(pretrained_loc)
+    elif load_from == 'url':
+        _logger.info(f'Loading pretrained weights from url ({pretrained_loc})')
+        state_dict = load_state_dict_from_url(
+            pretrained_loc, map_location='cpu', progress=_DOWNLOAD_PROGRESS, check_hash=_CHECK_HASH)
+    elif load_from == 'hf-hub':
+        _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
+        state_dict = load_state_dict_from_hf(pretrained_loc)
+    else:
+        _logger.warning("No pretrained weights exist or were found for this model. Using random initialization.")
         return
-    if pretrained_url:
-        _logger.info(f'Loading pretrained weights from url ({pretrained_url})')
-        state_dict = load_state_dict_from_url(pretrained_url, progress=progress, map_location='cpu')
-    elif hf_hub_id and has_hf_hub(necessary=True):
-        _logger.info(f'Loading pretrained weights from Hugging Face hub ({hf_hub_id})')
-        state_dict = load_state_dict_from_hf(hf_hub_id)

     if filter_fn is not None:
         # for backwards compat with filter fn that take one arg, try one first, the two
         try:
@@ -197,7 +257,7 @@ def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filte
         except TypeError:
             state_dict = filter_fn(state_dict, model)

-    input_convs = default_cfg.get('first_conv', None)
+    input_convs = pretrained_cfg.get('first_conv', None)
     if input_convs is not None and in_chans != 3:
         if isinstance(input_convs, str):
             input_convs = (input_convs,)

@@ -213,12 +273,12 @@ def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filte
                 _logger.warning(
                     f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')

-    classifiers = default_cfg.get('classifier', None)
-    label_offset = default_cfg.get('label_offset', 0)
+    classifiers = pretrained_cfg.get('classifier', None)
+    label_offset = pretrained_cfg.get('label_offset', 0)
     if classifiers is not None:
         if isinstance(classifiers, str):
             classifiers = (classifiers,)
-        if num_classes != default_cfg['num_classes']:
+        if num_classes != pretrained_cfg['num_classes']:
             for classifier_name in classifiers:
                 # completely discard fully connected if model num_classes doesn't match pretrained weights
                 del state_dict[classifier_name + '.weight']
@@ -333,43 +393,43 @@ def adapt_model_from_file(parent_module, model_variant):
         return adapt_model_from_string(parent_module, f.read().strip())


-def default_cfg_for_features(default_cfg):
-    default_cfg = deepcopy(default_cfg)
+def pretrained_cfg_for_features(pretrained_cfg):
+    pretrained_cfg = deepcopy(pretrained_cfg)
     # remove default pretrained cfg fields that don't have much relevance for feature backbone
     to_remove = ('num_classes', 'crop_pct', 'classifier', 'global_pool')  # add default final pool size?
     for tr in to_remove:
-        default_cfg.pop(tr, None)
-    return default_cfg
+        pretrained_cfg.pop(tr, None)
+    return pretrained_cfg


-def overlay_external_default_cfg(default_cfg, kwargs):
-    """ Overlay 'external_default_cfg' in kwargs on top of default_cfg arg.
-    """
-    external_default_cfg = kwargs.pop('external_default_cfg', None)
-    if external_default_cfg:
-        default_cfg.pop('url', None)  # url should come from external cfg
-        default_cfg.pop('hf_hub', None)  # hf hub id should come from external cfg
-        default_cfg.update(external_default_cfg)
+# def overlay_external_pretrained_cfg(pretrained_cfg, kwargs):
+#     """ Overlay 'external_pretrained_cfg' in kwargs on top of pretrained_cfg arg.
+#     """
+#     external_pretrained_cfg = kwargs.pop('external_pretrained_cfg', None)
+#     if external_pretrained_cfg:
+#         pretrained_cfg.pop('url', None)  # url should come from external cfg
+#         pretrained_cfg.pop('hf_hub', None)  # hf hub id should come from external cfg
+#         pretrained_cfg.update(external_pretrained_cfg)


-def set_default_kwargs(kwargs, names, default_cfg):
+def set_default_kwargs(kwargs, names, pretrained_cfg):
     for n in names:
         # for legacy reasons, model __init__args uses img_size + in_chans as separate args while
-        # default_cfg has one input_size=(C, H ,W) entry
+        # pretrained_cfg has one input_size=(C, H ,W) entry
         if n == 'img_size':
-            input_size = default_cfg.get('input_size', None)
+            input_size = pretrained_cfg.get('input_size', None)
             if input_size is not None:
                 assert len(input_size) == 3
                 kwargs.setdefault(n, input_size[-2:])
         elif n == 'in_chans':
-            input_size = default_cfg.get('input_size', None)
+            input_size = pretrained_cfg.get('input_size', None)
             if input_size is not None:
                 assert len(input_size) == 3
                 kwargs.setdefault(n, input_size[0])
         else:
-            default_val = default_cfg.get(n, None)
+            default_val = pretrained_cfg.get(n, None)
             if default_val is not None:
-                kwargs.setdefault(n, default_cfg[n])
+                kwargs.setdefault(n, pretrained_cfg[n])


 def filter_kwargs(kwargs, names):
@ -379,36 +439,46 @@ def filter_kwargs(kwargs, names):
kwargs.pop(n, None) kwargs.pop(n, None)
def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter): def update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter):
""" Update the default_cfg and kwargs before passing to model """ Update the default_cfg and kwargs before passing to model
FIXME this sequence of overlay default_cfg, set default kwargs, filter kwargs
could/should be replaced by an improved configuration mechanism
Args: Args:
default_cfg: input default_cfg (updated in-place) pretrained_cfg: input pretrained cfg (updated in-place)
kwargs: keyword args passed to model build fn (updated in-place) kwargs: keyword args passed to model build fn (updated in-place)
kwargs_filter: keyword arg keys that must be removed before model __init__ kwargs_filter: keyword arg keys that must be removed before model __init__
""" """
# Overlay default cfg values from `external_default_cfg` if it exists in kwargs
overlay_external_default_cfg(default_cfg, kwargs)
# Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs) # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
default_kwarg_names = ('num_classes', 'global_pool', 'in_chans') default_kwarg_names = ('num_classes', 'global_pool', 'in_chans')
if default_cfg.get('fixed_input_size', False): if pretrained_cfg.get('fixed_input_size', False):
# if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size
default_kwarg_names += ('img_size',) default_kwarg_names += ('img_size',)
set_default_kwargs(kwargs, names=default_kwarg_names, default_cfg=default_cfg) set_default_kwargs(kwargs, names=default_kwarg_names, pretrained_cfg=pretrained_cfg)
# Filter keyword args for task specific model variants (some 'features only' models, etc.) # Filter keyword args for task specific model variants (some 'features only' models, etc.)
filter_kwargs(kwargs, names=kwargs_filter) filter_kwargs(kwargs, names=kwargs_filter)
def resolve_pretrained_cfg(variant: str, pretrained_cfg=None, kwargs=None):
if pretrained_cfg and isinstance(pretrained_cfg, dict):
# highest priority, pretrained_cfg available and passed explicitly
return deepcopy(pretrained_cfg)
if kwargs and 'pretrained_cfg' in kwargs:
# next highest, pretrained_cfg in a kwargs dict, pop and return
pretrained_cfg = kwargs.pop('pretrained_cfg', {})
if pretrained_cfg:
return deepcopy(pretrained_cfg)
# lookup pretrained cfg in model registry by variant
pretrained_cfg = get_pretrained_cfg(variant)
assert pretrained_cfg
return pretrained_cfg
def build_model_with_cfg( def build_model_with_cfg(
model_cls: Callable, model_cls: Callable,
variant: str, variant: str,
pretrained: bool, pretrained: bool,
default_cfg: dict, pretrained_cfg: Optional[Dict] = None,
model_cfg: Optional[Any] = None, model_cfg: Optional[Any] = None,
feature_cfg: Optional[dict] = None, feature_cfg: Optional[Dict] = None,
pretrained_strict: bool = True, pretrained_strict: bool = True,
pretrained_filter_fn: Optional[Callable] = None, pretrained_filter_fn: Optional[Callable] = None,
pretrained_custom_load: bool = False, pretrained_custom_load: bool = False,
@@ -417,7 +487,7 @@ def build_model_with_cfg(
     """ Build model with specified default_cfg and optional model_cfg

     This helper fn aids in the construction of a model including:
-      * handling default_cfg and associated pretained weight loading
+      * handling default_cfg and associated pretrained weight loading
       * passing through optional model_cfg for models with config based arch spec
       * features_only model adaptation
       * pruning config / model adaptation

@@ -426,7 +496,7 @@ def build_model_with_cfg(
         model_cls (nn.Module): model class
         variant (str): model variant name
         pretrained (bool): load pretrained weights
-        default_cfg (dict): model's default pretrained/task config
+        pretrained_cfg (dict): model's pretrained weight/task config
         model_cfg (Optional[Dict]): model's architecture config
         feature_cfg (Optional[Dict]: feature extraction adapter config
         pretrained_strict (bool): load pretrained weights strictly

@@ -438,9 +508,11 @@ def build_model_with_cfg(
     pruned = kwargs.pop('pruned', False)
     features = False
     feature_cfg = feature_cfg or {}
-    default_cfg = deepcopy(default_cfg) if default_cfg else {}
-    update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter)
-    default_cfg.setdefault('architecture', variant)
+
+    # resolve and update model pretrained config and model kwargs
+    pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=pretrained_cfg)
+    update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter)
+    pretrained_cfg.setdefault('architecture', variant)

     # Setup for feature extraction wrapper done at end of this fn
     if kwargs.pop('features_only', False):

@@ -451,7 +523,8 @@ def build_model_with_cfg(
     # Build the model
     model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
-    model.default_cfg = default_cfg
+    model.pretrained_cfg = pretrained_cfg
+    model.default_cfg = model.pretrained_cfg  # alias for backwards compat

     if pruned:
         model = adapt_model_from_file(model, variant)

@@ -460,10 +533,12 @@ def build_model_with_cfg(
     num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
     if pretrained:
         if pretrained_custom_load:
-            load_custom_pretrained(model)
+            # FIXME improve custom load trigger
+            load_custom_pretrained(model, pretrained_cfg=pretrained_cfg)
         else:
             load_pretrained(
                 model,
+                pretrained_cfg=pretrained_cfg,
                 num_classes=num_classes_pretrained,
                 in_chans=kwargs.get('in_chans', 3),
                 filter_fn=pretrained_filter_fn,

@@ -483,7 +558,8 @@ def build_model_with_cfg(
         else:
             assert False, f'Unknown feature class {feature_cls}'
         model = feature_cls(model, **feature_cfg)
-        model.default_cfg = default_cfg_for_features(default_cfg)  # add back default_cfg
+        model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg)  # add back default_cfg
+        model.default_cfg = model.pretrained_cfg  # alias for backwards compat

     return model
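Note (not part of the commit): a short illustration of the priority order `resolve_pretrained_cfg` applies before `build_model_with_cfg` attaches the config to the model.

```python
from timm.models.helpers import resolve_pretrained_cfg

# 1. An explicitly passed dict wins outright.
cfg = resolve_pretrained_cfg('resnet50', pretrained_cfg={'num_classes': 10, 'input_size': (3, 160, 160)})

# 2. Otherwise a 'pretrained_cfg' entry sitting in a kwargs dict is popped and used.
kwargs = {'pretrained_cfg': {'num_classes': 10}, 'drop_rate': 0.1}
cfg = resolve_pretrained_cfg('resnet50', kwargs=kwargs)   # kwargs no longer contains 'pretrained_cfg'

# 3. Failing both, the registry's default cfg for the variant is looked up.
cfg = resolve_pretrained_cfg('resnet50')
```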

@@ -17,7 +17,7 @@ import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .features import FeatureInfo
-from .helpers import build_model_with_cfg, default_cfg_for_features
+from .helpers import build_model_with_cfg, pretrained_cfg_for_features
 from .layers import create_classifier
 from .registry import register_model
 from .resnet import BasicBlock, Bottleneck  # leveraging ResNet blocks w/ additional features like SE

@@ -781,13 +781,13 @@ def _create_hrnet(variant, pretrained, **model_kwargs):
         features_only = True
     model = build_model_with_cfg(
         model_cls, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         model_cfg=cfg_cls[variant],
         pretrained_strict=not features_only,
         kwargs_filter=kwargs_filter,
         **model_kwargs)
     if features_only:
-        model.default_cfg = default_cfg_for_features(model.default_cfg)
+        model.pretrained_cfg = pretrained_cfg_for_features(model.default_cfg)
+        model.default_cfg = model.pretrained_cfg  # backwards compat
     return model

@@ -62,6 +62,7 @@ def has_hf_hub(necessary=False):
 def hf_split(hf_id):
+    # FIXME I may change @ -> # and be parsed as fragment in a URI model name scheme
     rev_split = hf_id.split('@')
     assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.'
     hf_model_id = rev_split[0]

@@ -84,10 +85,11 @@ def _download_from_hf(model_id: str, filename: str):
 def load_model_config_from_hf(model_id: str):
     assert has_hf_hub(True)
     cached_file = _download_from_hf(model_id, 'config.json')
-    default_cfg = load_cfg_from_json(cached_file)
-    default_cfg['hf_hub'] = model_id  # insert hf_hub id for pretrained weight load during model creation
-    model_name = default_cfg.get('architecture')
-    return default_cfg, model_name
+    pretrained_cfg = load_cfg_from_json(cached_file)
+    pretrained_cfg['hf_hub_id'] = model_id  # insert hf_hub id for pretrained weight load during model creation
+    pretrained_cfg['source'] = 'hf-hub'
+    model_name = pretrained_cfg.get('architecture')
+    return pretrained_cfg, model_name


 def load_state_dict_from_hf(model_id: str):

@@ -107,7 +109,7 @@ def save_for_hf(model, save_directory, model_config=None):
     torch.save(model.state_dict(), weights_path)

     config_path = save_directory / 'config.json'
-    hf_config = model.default_cfg
+    hf_config = model.pretrained_cfg
     hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes)
     hf_config['num_features'] = model_config.pop('num_features', model.num_features)
     hf_config['labels'] = model_config.pop('labels', [f"LABEL_{i}" for i in range(hf_config['num_classes'])])
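Note (not in the diff): a sketch of what the `hf-hub:` scheme yields end to end, assuming `huggingface_hub` is installed and the hub repo is reachable. The downloaded `config.json` becomes the model's `pretrained_cfg`, stamped with `hf_hub_id` and `source` as shown in the hunk above.

```python
import timm

# Config is pulled from the hub even with pretrained=False; the repo id below is
# the one referenced in the nfnet.py default_cfgs further down in this commit.
model = timm.create_model('hf-hub:timm/eca_nfnet_l0', pretrained=False)
print(model.pretrained_cfg['hf_hub_id'])   # 'timm/eca_nfnet_l0'
print(model.pretrained_cfg['source'])      # 'hf-hub'
```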

@@ -335,10 +335,7 @@ class InceptionResnetV2(nn.Module):
 def _create_inception_resnet_v2(variant, pretrained=False, **kwargs):
-    return build_model_with_cfg(
-        InceptionResnetV2, variant, pretrained,
-        default_cfg=default_cfgs[variant],
-        **kwargs)
+    return build_model_with_cfg(InceptionResnetV2, variant, pretrained, **kwargs)


 @register_model

@@ -8,7 +8,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from .helpers import build_model_with_cfg
+from .helpers import build_model_with_cfg, resolve_pretrained_cfg
 from .registry import register_model
 from .layers import trunc_normal_, create_classifier, Linear

@@ -424,18 +424,19 @@ class InceptionV3Aux(InceptionV3):
 def _create_inception_v3(variant, pretrained=False, **kwargs):
-    default_cfg = default_cfgs[variant]
+    pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs)
     aux_logits = kwargs.pop('aux_logits', False)
     if aux_logits:
         assert not kwargs.pop('features_only', False)
         model_cls = InceptionV3Aux
-        load_strict = default_cfg['has_aux']
+        load_strict = pretrained_cfg['has_aux']
     else:
         model_cls = InceptionV3
-        load_strict = not default_cfg['has_aux']
+        load_strict = not pretrained_cfg['has_aux']

     return build_model_with_cfg(
         model_cls, variant, pretrained,
-        default_cfg=default_cfg,
+        pretrained_cfg=pretrained_cfg,
         pretrained_strict=load_strict,
         **kwargs)

@@ -306,7 +306,6 @@ class InceptionV4(nn.Module):
 def _create_inception_v4(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         InceptionV4, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         feature_cfg=dict(flatten_sequential=True),
         **kwargs)

@@ -32,7 +32,7 @@ import torch
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
-from .helpers import build_model_with_cfg, overlay_external_default_cfg
+from .helpers import build_model_with_cfg
 from .layers import to_ntuple, get_act_layer
 from .vision_transformer import trunc_normal_
 from .registry import register_model

@@ -554,7 +554,6 @@ def create_levit(variant, pretrained=False, default_cfg=None, fuse=False, **kwar
     model_cfg = dict(**model_cfgs[variant], **kwargs)
     model = build_model_with_cfg(
         Levit, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         pretrained_filter_fn=checkpoint_filter_fn,
         **model_cfg)
     #if fuse:

@@ -46,7 +46,7 @@ import torch
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .helpers import build_model_with_cfg, overlay_external_default_cfg, named_apply
+from .helpers import build_model_with_cfg, named_apply
 from .layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple
 from .registry import register_model

@@ -360,7 +360,6 @@ def _create_mixer(variant, pretrained=False, **kwargs):
     model = build_model_with_cfg(
         MlpMixer, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         pretrained_filter_fn=checkpoint_filter_fn,
         **kwargs)
     return model

@@ -19,7 +19,7 @@ from .efficientnet_blocks import SqueezeExcite
 from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\
     round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from .features import FeatureInfo, FeatureHooks
-from .helpers import build_model_with_cfg, default_cfg_for_features
+from .helpers import build_model_with_cfg, pretrained_cfg_for_features
 from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, get_norm_act_layer
 from .registry import register_model
@@ -239,12 +239,11 @@ def _create_mnv3(variant, pretrained=False, **kwargs):
         model_cls = MobileNetV3Features
     model = build_model_with_cfg(
         model_cls, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         pretrained_strict=not features_only,
         kwargs_filter=kwargs_filter,
         **kwargs)
     if features_only:
-        model.default_cfg = default_cfg_for_features(model.default_cfg)
+        model.default_cfg = pretrained_cfg_for_features(model.default_cfg)
     return model

@@ -554,7 +554,6 @@ class NASNetALarge(nn.Module):
 def _create_nasnet(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         NASNetALarge, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         feature_cfg=dict(feature_cls='hook', no_rewrite=True),  # not possible to re-write this model
         **kwargs)

@@ -395,11 +395,9 @@ def checkpoint_filter_fn(state_dict, model):
     return state_dict
-def _create_nest(variant, pretrained=False, default_cfg=None, **kwargs):
-    default_cfg = default_cfg or default_cfgs[variant]
+def _create_nest(variant, pretrained=False, **kwargs):
     model = build_model_with_cfg(
         Nest, variant, pretrained,
-        default_cfg=default_cfg,
         feature_cfg=dict(out_indices=(0, 1, 2), flatten_sequential=True),
         pretrained_filter_fn=checkpoint_filter_fn,
         **kwargs)

@@ -106,7 +106,7 @@ default_cfgs = dict(
         pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288), crop_pct=1.0),
     eca_nfnet_l0=_dcfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l0_ra2-e3e9ac50.pth',
-        hf_hub='timm/eca_nfnet_l0',
+        hf_hub_id='timm/eca_nfnet_l0',
         pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288), crop_pct=1.0),
     eca_nfnet_l1=_dcfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l1_ra2-7dce93cd.pth',
@@ -592,7 +592,6 @@ def _create_normfreenet(variant, pretrained=False, **kwargs):
     feature_cfg = dict(flatten_sequential=True)
     return build_model_with_cfg(
         NormFreeNet, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         model_cfg=model_cfg,
         feature_cfg=feature_cfg,
         **kwargs)

@@ -21,7 +21,7 @@ import torch
 from torch import nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .helpers import build_model_with_cfg, overlay_external_default_cfg
+from .helpers import build_model_with_cfg
 from .layers import trunc_normal_, to_2tuple
 from .registry import register_model
 from .vision_transformer import Block
@@ -262,7 +262,6 @@ def _create_pit(variant, pretrained=False, **kwargs):
     model = build_model_with_cfg(
         PoolingVisionTransformer, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         pretrained_filter_fn=checkpoint_filter_fn,
         **kwargs)
     return model

@@ -335,7 +335,6 @@ class PNASNet5Large(nn.Module):
 def _create_pnasnet(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         PNASNet5Large, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         feature_cfg=dict(feature_cls='hook', no_rewrite=True),  # not possible to re-write this model
         **kwargs)

@@ -9,13 +9,13 @@ from collections import defaultdict
 from copy import deepcopy
 __all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
-           'is_model_default_key', 'has_model_default_key', 'get_model_default_value', 'is_model_pretrained']
+           'is_pretrained_cfg_key', 'has_pretrained_cfg_key', 'get_pretrained_cfg_value', 'is_model_pretrained']
 _module_to_models = defaultdict(set)  # dict of sets to check membership of model in module
 _model_to_module = {}  # mapping of model names to module names
 _model_entrypoints = {}  # mapping of model names to entrypoint fns
 _model_has_pretrained = set()  # set of model names that have pretrained weight url present
-_model_default_cfgs = dict()  # central repo for model default_cfgs
+_model_pretrained_cfgs = dict()  # central repo for model default_cfgs
 def register_model(fn):
@@ -35,13 +35,18 @@ def register_model(fn):
     _model_entrypoints[model_name] = fn
     _model_to_module[model_name] = module_name
     _module_to_models[module_name].add(model_name)
-    has_pretrained = False  # check if model has a pretrained url to allow filtering on this
+    has_valid_pretrained = False  # check if model has a pretrained url to allow filtering on this
     if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
         # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
         # entrypoints or non-matching combos
-        has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
-        _model_default_cfgs[model_name] = deepcopy(mod.default_cfgs[model_name])
-    if has_pretrained:
+        cfg = mod.default_cfgs[model_name]
+        has_valid_pretrained = (
+            ('url' in cfg and 'http' in cfg['url']) or
+            ('file' in cfg and cfg['file']) or
+            ('hf_hub_id' in cfg and cfg['hf_hub_id'])
+        )
+        _model_pretrained_cfgs[model_name] = mod.default_cfgs[model_name]
+    if has_valid_pretrained:
         _model_has_pretrained.add(model_name)
     return fn
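
The broadened check above treats a registered model as having pretrained weights if its cfg provides any of three sources: a downloadable http(s) url, a local checkpoint file, or a Hugging Face Hub id. A minimal, self-contained sketch of that logic (the cfg entries below are hypothetical, not from this commit):

# Illustrative only: cfgs that would (or would not) count as having valid pretrained weights.
example_cfgs = {
    'toy_net_url': {'url': 'https://example.com/toy_net.pth'},   # hosted checkpoint URL
    'toy_net_file': {'file': './weights/toy_net.pth'},           # local checkpoint file
    'toy_net_hub': {'hf_hub_id': 'org/toy_net'},                 # Hugging Face Hub repo id
    'toy_net_none': {'url': ''},                                 # no usable weight source
}

def _has_valid_pretrained(cfg):
    # Mirrors the registry check in the hunk above.
    return (
        ('url' in cfg and 'http' in cfg['url']) or
        ('file' in cfg and cfg['file']) or
        ('hf_hub_id' in cfg and cfg['hf_hub_id'])
    )

assert _has_valid_pretrained(example_cfgs['toy_net_hub'])
assert not _has_valid_pretrained(example_cfgs['toy_net_none'])
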
@@ -87,7 +92,7 @@ def list_models(filter='', module='', pretrained=False, exclude_filters='', name
     if pretrained:
         models = _model_has_pretrained.intersection(models)
     if name_matches_cfg:
-        models = set(_model_default_cfgs).intersection(models)
+        models = set(_model_pretrained_cfgs).intersection(models)
     return list(sorted(models, key=_natural_key))
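
Because _model_has_pretrained is now populated for file- and hub-backed cfgs as well (see the register_model change above), list_models(pretrained=True) also surfaces models whose only weight source is a local file or an hf_hub_id, not just http urls. A short usage sketch, assuming this version of timm is installed:

import timm

# Models with any valid pretrained source (url, file, or hf_hub_id in their cfg).
pretrained_models = timm.list_models(pretrained=True)
print(len(pretrained_models), pretrained_models[:3])
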
@@ -120,30 +125,35 @@ def is_model_in_modules(model_name, module_names):
     return any(model_name in _module_to_models[n] for n in module_names)
-def has_model_default_key(model_name, cfg_key):
+def is_model_pretrained(model_name):
+    return model_name in _model_has_pretrained
+def get_pretrained_cfg(model_name):
+    if model_name in _model_pretrained_cfgs:
+        return deepcopy(_model_pretrained_cfgs[model_name])
+    return {}
+def has_pretrained_cfg_key(model_name, cfg_key):
     """ Query model default_cfgs for existence of a specific key.
     """
-    if model_name in _model_default_cfgs and cfg_key in _model_default_cfgs[model_name]:
+    if model_name in _model_pretrained_cfgs and cfg_key in _model_pretrained_cfgs[model_name]:
         return True
     return False
-def is_model_default_key(model_name, cfg_key):
+def is_pretrained_cfg_key(model_name, cfg_key):
     """ Return truthy value for specified model default_cfg key, False if does not exist.
     """
-    if model_name in _model_default_cfgs and _model_default_cfgs[model_name].get(cfg_key, False):
+    if model_name in _model_pretrained_cfgs and _model_pretrained_cfgs[model_name].get(cfg_key, False):
         return True
     return False
-def get_model_default_value(model_name, cfg_key):
+def get_pretrained_cfg_value(model_name, cfg_key):
     """ Get a specific model default_cfg value by key. None if it doesn't exist.
     """
-    if model_name in _model_default_cfgs:
-        return _model_default_cfgs[model_name].get(cfg_key, None)
-    else:
-        return None
+    if model_name in _model_pretrained_cfgs:
+        return _model_pretrained_cfgs[model_name].get(cfg_key, None)
+    return None
-def is_model_pretrained(model_name):
-    return model_name in _model_has_pretrained
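
Usage of the renamed query helpers stays one-to-one with the old names (has_model_default_key -> has_pretrained_cfg_key, and so on); get_pretrained_cfg is the new way to get a copy of the whole cfg. A small sketch, assuming the helpers are imported from timm.models.registry as defined above:

from timm.models.registry import (
    get_pretrained_cfg, get_pretrained_cfg_value, has_pretrained_cfg_key, is_model_pretrained)

name = 'resnet50'  # example model name
if is_model_pretrained(name):
    cfg = get_pretrained_cfg(name)                       # deep copy of the registered cfg
    print(cfg.get('url', ''))
if has_pretrained_cfg_key(name, 'input_size'):
    print(get_pretrained_cfg_value(name, 'input_size'))  # e.g. (3, 224, 224)
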

@@ -472,7 +472,6 @@ def _filter_fn(state_dict):
 def _create_regnet(variant, pretrained, **kwargs):
     return build_model_with_cfg(
         RegNet, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         model_cfg=model_cfgs[variant],
         pretrained_filter_fn=_filter_fn,
         **kwargs)

@@ -133,10 +133,7 @@ class Bottle2neck(nn.Module):
 def _create_res2net(variant, pretrained=False, **kwargs):
-    return build_model_with_cfg(
-        ResNet, variant, pretrained,
-        default_cfg=default_cfgs[variant],
-        **kwargs)
+    return build_model_with_cfg(ResNet, variant, pretrained, **kwargs)
 @register_model

@@ -135,10 +135,7 @@ class ResNestBottleneck(nn.Module):
 def _create_resnest(variant, pretrained=False, **kwargs):
-    return build_model_with_cfg(
-        ResNet, variant, pretrained,
-        default_cfg=default_cfgs[variant],
-        **kwargs)
+    return build_model_with_cfg(ResNet, variant, pretrained, **kwargs)
 @register_model

@@ -680,10 +680,7 @@ class ResNet(nn.Module):
 def _create_resnet(variant, pretrained=False, **kwargs):
-    return build_model_with_cfg(
-        ResNet, variant, pretrained,
-        default_cfg=default_cfgs[variant],
-        **kwargs)
+    return build_model_with_cfg(ResNet, variant, pretrained, **kwargs)
 @register_model

@@ -482,7 +482,6 @@ def _create_resnetv2(variant, pretrained=False, **kwargs):
     feature_cfg = dict(flatten_sequential=True)
     return build_model_with_cfg(
         ResNetV2, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         feature_cfg=feature_cfg,
         pretrained_custom_load='_bit' in variant,
         **kwargs)

@@ -186,7 +186,6 @@ def _create_rexnet(variant, pretrained, **kwargs):
     feature_cfg = dict(flatten_sequential=True)
     return build_model_with_cfg(
         ReXNetV1, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         feature_cfg=feature_cfg,
         **kwargs)

@@ -321,7 +321,6 @@ def _create_selecsls(variant, pretrained, **kwargs):
     # this model can do 6 feature levels by default, unlike most others, leave as 0-4 to avoid surprises?
     return build_model_with_cfg(
         SelecSLS, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         model_cfg=cfg,
         feature_cfg=dict(out_indices=(0, 1, 2, 3, 4), flatten_sequential=True),
         **kwargs)

@@ -397,10 +397,7 @@ class SENet(nn.Module):
 def _create_senet(variant, pretrained=False, **kwargs):
-    return build_model_with_cfg(
-        SENet, variant, pretrained,
-        default_cfg=default_cfgs[variant],
-        **kwargs)
+    return build_model_with_cfg(SENet, variant, pretrained, **kwargs)
 @register_model

@@ -133,10 +133,7 @@ class SelectiveKernelBottleneck(nn.Module):
 def _create_skresnet(variant, pretrained=False, **kwargs):
-    return build_model_with_cfg(
-        ResNet, variant, pretrained,
-        default_cfg=default_cfgs[variant],
-        **kwargs)
+    return build_model_with_cfg(ResNet, variant, pretrained, **kwargs)
 @register_model

@@ -22,7 +22,7 @@ import torch.utils.checkpoint as checkpoint
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .fx_features import register_notrace_function
-from .helpers import build_model_with_cfg, overlay_external_default_cfg
+from .helpers import build_model_with_cfg
 from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
 from .layers import _assert
 from .registry import register_model
@@ -542,23 +542,9 @@ class SwinTransformer(nn.Module):
         return x
-def _create_swin_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
-    if default_cfg is None:
-        default_cfg = deepcopy(default_cfgs[variant])
-    overlay_external_default_cfg(default_cfg, kwargs)
-    default_num_classes = default_cfg['num_classes']
-    default_img_size = default_cfg['input_size'][-2:]
-    num_classes = kwargs.pop('num_classes', default_num_classes)
-    img_size = kwargs.pop('img_size', default_img_size)
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Vision Transformer models.')
+def _create_swin_transformer(variant, pretrained=False, **kwargs):
     model = build_model_with_cfg(
         SwinTransformer, variant, pretrained,
-        default_cfg=default_cfg,
-        img_size=img_size,
-        num_classes=num_classes,
         pretrained_filter_fn=checkpoint_filter_fn,
         **kwargs)

@@ -248,7 +248,6 @@ def _create_tnt(variant, pretrained=False, **kwargs):
     model = build_model_with_cfg(
         TNT, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         pretrained_filter_fn=checkpoint_filter_fn,
         **kwargs)
     return model

@@ -250,7 +250,6 @@ class TResNet(nn.Module):
 def _create_tresnet(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         TResNet, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         feature_cfg=dict(out_indices=(1, 2, 3, 4), flatten_sequential=True),
         **kwargs)

@@ -369,10 +369,7 @@ def _create_twins(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Vision Transformer models.')
-    model = build_model_with_cfg(
-        Twins, variant, pretrained,
-        default_cfg=default_cfgs[variant],
-        **kwargs)
+    model = build_model_with_cfg(Twins, variant, pretrained, **kwargs)
     return model

@@ -183,7 +183,6 @@ def _create_vgg(variant: str, pretrained: bool, **kwargs: Any) -> VGG:
     out_indices = kwargs.get('out_indices', (0, 1, 2, 3, 4, 5))
     model = build_model_with_cfg(
         VGG, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         model_cfg=cfgs[cfg],
         feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
         pretrained_filter_fn=_filter_fn,

@@ -12,7 +12,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .helpers import build_model_with_cfg, overlay_external_default_cfg
+from .helpers import build_model_with_cfg
 from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier
 from .registry import register_model
@@ -318,10 +318,7 @@ class Visformer(nn.Module):
 def _create_visformer(variant, pretrained=False, default_cfg=None, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Vision Transformer models.')
-    model = build_model_with_cfg(
-        Visformer, variant, pretrained,
-        default_cfg=default_cfgs[variant],
-        **kwargs)
+    model = build_model_with_cfg(Visformer, variant, pretrained, **kwargs)
     return model

@@ -33,7 +33,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
+from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply, adapt_input_conv
 from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
 from .registry import register_model
@@ -132,7 +132,7 @@ default_cfgs = {
         num_classes=21843),
     'vit_huge_patch14_224_in21k': _cfg(
         url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
-        hf_hub='timm/vit_huge_patch14_224_in21k',
+        hf_hub_id='timm/vit_huge_patch14_224_in21k',
         num_classes=21843),
     # SAM trained models (https://arxiv.org/abs/2106.01548)
@@ -525,13 +525,13 @@ def checkpoint_filter_fn(state_dict, model):
     return out_dict
-def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
-    default_cfg = default_cfg or default_cfgs[variant]
+def _create_vision_transformer(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Vision Transformer models.')
     # NOTE this extra code to support handling of repr size for in21k pretrained models
-    default_num_classes = default_cfg['num_classes']
+    pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs)
+    default_num_classes = pretrained_cfg['num_classes']
     num_classes = kwargs.get('num_classes', default_num_classes)
     repr_size = kwargs.pop('representation_size', None)
     if repr_size is not None and num_classes != default_num_classes:
@@ -542,10 +542,10 @@ def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kw
     model = build_model_with_cfg(
         VisionTransformer, variant, pretrained,
-        default_cfg=default_cfg,
+        pretrained_cfg=pretrained_cfg,
         representation_size=repr_size,
         pretrained_filter_fn=checkpoint_filter_fn,
-        pretrained_custom_load='npz' in default_cfg['url'],
+        pretrained_custom_load='npz' in pretrained_cfg['url'],
         **kwargs)
     return model
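
The creator above resolves the cfg itself (via resolve_pretrained_cfg) because it needs cfg values up front: num_classes for the representation-size handling, and the weight url to decide whether the custom npz loader is required. A hedged sketch of reading those fields, assuming resolve_pretrained_cfg is importable from timm.models.helpers and returns a dict-like cfg as used here:

from timm.models.helpers import resolve_pretrained_cfg

cfg = resolve_pretrained_cfg('vit_base_patch16_224', kwargs={})  # example registered variant
print(cfg['num_classes'])           # default classifier size for the pretrained weights
print('npz' in cfg.get('url', ''))  # True for original JAX .npz checkpoints -> custom loader path
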

@@ -143,8 +143,7 @@ class HybridEmbed(nn.Module):
 def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
     embed_layer = partial(HybridEmbed, backbone=backbone)
     kwargs.setdefault('patch_size', 1)  # default patch size for hybrid models if not set
-    return _create_vision_transformer(
-        variant, pretrained=pretrained, embed_layer=embed_layer, default_cfg=default_cfgs[variant], **kwargs)
+    return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs)
 def _resnetv2(layers=(3, 4, 9), **kwargs):

@@ -339,7 +339,6 @@ class VovNet(nn.Module):
 def _create_vovnet(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         VovNet, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         model_cfg=model_cfgs[variant],
         feature_cfg=dict(flatten_sequential=True),
         **kwargs)

@@ -222,7 +222,6 @@ class Xception(nn.Module):
 def _xception(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         Xception, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         feature_cfg=dict(feature_cls='hook'),
         **kwargs)

@@ -227,7 +227,6 @@ class XceptionAligned(nn.Module):
 def _xception(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         XceptionAligned, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         feature_cfg=dict(flatten_sequential=True, feature_cls='hook'),
         **kwargs)

@@ -469,9 +469,8 @@ def checkpoint_filter_fn(state_dict, model):
 def _create_xcit(variant, pretrained=False, default_cfg=None, **kwargs):
-    default_cfg = default_cfg or default_cfgs[variant]
     model = build_model_with_cfg(
-        XCiT, variant, pretrained, default_cfg=default_cfg, pretrained_filter_fn=checkpoint_filter_fn, **kwargs)
+        XCiT, variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, **kwargs)
     return model

@@ -234,8 +234,6 @@ parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                     help='Drop block rate (default: None)')
 # Batch norm parameters (only works with gen_efficientnet based models currently)
-parser.add_argument('--bn-tf', action='store_true', default=False,
-                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
 parser.add_argument('--bn-momentum', type=float, default=None,
                     help='BatchNorm momentum override (if not None)')
 parser.add_argument('--bn-eps', type=float, default=None,
@@ -375,7 +373,6 @@ def main():
         drop_path_rate=args.drop_path,
         drop_block_rate=args.drop_block,
         global_pool=args.gp,
-        bn_tf=args.bn_tf,
         bn_momentum=args.bn_momentum,
         bn_eps=args.bn_eps,
         scriptable=args.torchscript,
@@ -443,6 +440,7 @@ def main():
         if args.local_rank == 0:
             _logger.info('AMP not enabled. Training in float32.')
+    # optionally resume from a checkpoint
     resume_epoch = None
     if args.resume:

@@ -216,7 +216,9 @@ def validate(args):
         input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
         if args.channels_last:
             input = input.contiguous(memory_format=torch.channels_last)
-        model(input)
+        with amp_autocast():
+            model(input)
         end = time.time()
         for batch_idx, (input, target) in enumerate(loader):
             if args.no_prefetcher:
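
Wrapping the warm-up forward in amp_autocast means the warm-up runs under the same autocast context (and precision) as the timed batches that follow, so kernel selection happens before timing starts. A standalone sketch of the pattern, assuming validate.py's amp_autocast is either torch.cuda.amp.autocast or a no-op context:

from contextlib import suppress
import torch

use_amp = torch.cuda.is_available()
amp_autocast = torch.cuda.amp.autocast if use_amp else suppress  # no-op context when AMP is off

model = torch.nn.Linear(8, 4).eval()
x = torch.randn(2, 8)
if use_amp:
    model, x = model.cuda(), x.cuda()
with torch.no_grad():
    with amp_autocast():  # warm-up pass under the same context as the timed loop
        model(x)
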
