From abc9ba254430ef971ea3dbd12f2b4f1969da55be Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 25 Jan 2022 21:54:13 -0800 Subject: [PATCH] Transitioning default_cfg -> pretrained_cfg. Improving handling of pretrained_cfg source (HF-Hub, files, timm config, etc). Checkpoint handling tweaks. --- clean_checkpoint.py | 2 +- tests/test_models.py | 10 +- timm/__init__.py | 4 +- timm/models/__init__.py | 4 +- timm/models/beit.py | 4 +- timm/models/byoanet.py | 1 - timm/models/byobnet.py | 1 - timm/models/cait.py | 3 +- timm/models/coat.py | 3 +- timm/models/convit.py | 5 +- timm/models/convmixer.py | 2 +- timm/models/crossvit.py | 1 - timm/models/cspnet.py | 1 - timm/models/densenet.py | 1 - timm/models/dla.py | 1 - timm/models/dpn.py | 1 - timm/models/efficientnet.py | 7 +- timm/models/factory.py | 49 +++-- timm/models/fx_features.py | 50 ++++- timm/models/ghostnet.py | 1 - timm/models/gluon_resnet.py | 5 +- timm/models/gluon_xception.py | 1 - timm/models/hardcorenas.py | 5 +- timm/models/helpers.py | 238 +++++++++++++++-------- timm/models/hrnet.py | 6 +- timm/models/hub.py | 12 +- timm/models/inception_resnet_v2.py | 5 +- timm/models/inception_v3.py | 11 +- timm/models/inception_v4.py | 1 - timm/models/levit.py | 3 +- timm/models/mlp_mixer.py | 3 +- timm/models/mobilenetv3.py | 5 +- timm/models/nasnet.py | 1 - timm/models/nest.py | 4 +- timm/models/nfnet.py | 3 +- timm/models/pit.py | 3 +- timm/models/pnasnet.py | 1 - timm/models/registry.py | 50 +++-- timm/models/regnet.py | 1 - timm/models/res2net.py | 5 +- timm/models/resnest.py | 5 +- timm/models/resnet.py | 5 +- timm/models/resnetv2.py | 1 - timm/models/rexnet.py | 1 - timm/models/selecsls.py | 1 - timm/models/senet.py | 5 +- timm/models/sknet.py | 5 +- timm/models/swin_transformer.py | 18 +- timm/models/tnt.py | 1 - timm/models/tresnet.py | 1 - timm/models/twins.py | 5 +- timm/models/vgg.py | 1 - timm/models/visformer.py | 7 +- timm/models/vision_transformer.py | 14 +- timm/models/vision_transformer_hybrid.py | 3 +- timm/models/vovnet.py | 1 - timm/models/xception.py | 1 - timm/models/xception_aligned.py | 1 - timm/models/xcit.py | 3 +- train.py | 4 +- validate.py | 4 +- 61 files changed, 321 insertions(+), 280 deletions(-) diff --git a/clean_checkpoint.py b/clean_checkpoint.py index 3eea15e6..8ec892b2 100755 --- a/clean_checkpoint.py +++ b/clean_checkpoint.py @@ -49,7 +49,7 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False): # If all aux_bn keys are removed, the SplitBN layers will end up as normal and # load with the unmodified model using BatchNorm2d. 
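
# A minimal usage sketch of the rename this patch makes: models built through
# build_model_with_cfg now carry a `pretrained_cfg` dict, with `default_cfg` kept as a
# backwards-compatible alias of the same object (see timm/models/helpers.py below).
import timm

model = timm.create_model('resnet50', pretrained=False)
assert model.pretrained_cfg is model.default_cfg     # alias, not a copy
print(model.pretrained_cfg['input_size'])            # (3, 224, 224) for resnet50
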
continue - name = k[7:] if k.startswith('module') else k + name = k[7:] if k.startswith('module.') else k new_state_dict[name] = v print("=> Loaded state_dict from '{}'".format(checkpoint)) diff --git a/tests/test_models.py b/tests/test_models.py index 01ad7489..77155cfa 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -11,8 +11,8 @@ except ImportError: has_fx_feature_extraction = False import timm -from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \ - get_model_default_value +from timm import list_models, create_model, set_scriptable, has_pretrained_cfg_key, is_pretrained_cfg_key, \ + get_pretrained_cfg_value from timm.models.fx_features import _leaf_modules, _autowrap_functions if hasattr(torch._C, '_jit_set_profiling_executor'): @@ -54,9 +54,9 @@ MAX_BWD_FX_SIZE = 224 def _get_input_size(model=None, model_name='', target=None): if model is None: assert model_name, "One of model or model_name must be provided" - input_size = get_model_default_value(model_name, 'input_size') - fixed_input_size = get_model_default_value(model_name, 'fixed_input_size') - min_input_size = get_model_default_value(model_name, 'min_input_size') + input_size = get_pretrained_cfg_value(model_name, 'input_size') + fixed_input_size = get_pretrained_cfg_value(model_name, 'fixed_input_size') + min_input_size = get_pretrained_cfg_value(model_name, 'min_input_size') else: default_cfg = model.default_cfg input_size = default_cfg['input_size'] diff --git a/timm/__init__.py b/timm/__init__.py index 04ec7e51..c5f797b1 100644 --- a/timm/__init__.py +++ b/timm/__init__.py @@ -1,4 +1,4 @@ from .version import __version__ from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \ - is_scriptable, is_exportable, set_scriptable, set_exportable, has_model_default_key, is_model_default_key, \ - get_model_default_value, is_model_pretrained + is_scriptable, is_exportable, set_scriptable, set_exportable, has_pretrained_cfg_key, is_pretrained_cfg_key, \ + get_pretrained_cfg_value, is_model_pretrained diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 0982b6e1..9682480c 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -49,10 +49,10 @@ from .xception import * from .xception_aligned import * from .xcit import * -from .factory import create_model, split_model_name, safe_model_name +from .factory import create_model, parse_model_name, safe_model_name from .helpers import load_checkpoint, resume_checkpoint, model_parameters from .layers import TestTimePoolHead, apply_test_time_pool from .layers import convert_splitbn_model from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\ - has_model_default_key, is_model_default_key, get_model_default_value, is_model_pretrained + is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value diff --git a/timm/models/beit.py b/timm/models/beit.py index f644b657..e82f6f63 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -339,14 +339,12 @@ class Beit(nn.Module): return x -def _create_beit(variant, pretrained=False, default_cfg=None, **kwargs): - default_cfg = default_cfg or default_cfgs[variant] +def _create_beit(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for 
Beit models.') model = build_model_with_cfg( Beit, variant, pretrained, - default_cfg=default_cfg, # FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes pretrained_filter_fn=checkpoint_filter_fn, **kwargs) diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index f44040b0..3815fa30 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -327,7 +327,6 @@ model_cfgs = dict( def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs): return build_model_with_cfg( ByobNet, variant, pretrained, - default_cfg=default_cfgs[variant], model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant], feature_cfg=dict(flatten_sequential=True), **kwargs) diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index e7faa63d..554b9a6e 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -1553,7 +1553,6 @@ def _init_weights(module, name='', zero_init_last=False): def _create_byobnet(variant, pretrained=False, **kwargs): return build_model_with_cfg( ByobNet, variant, pretrained, - default_cfg=default_cfgs[variant], model_cfg=model_cfgs[variant], feature_cfg=dict(flatten_sequential=True), **kwargs) diff --git a/timm/models/cait.py b/timm/models/cait.py index 69b4ba06..28847bf2 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -14,7 +14,7 @@ import torch.nn as nn from functools import partial from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .helpers import build_model_with_cfg from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_ from .registry import register_model @@ -318,7 +318,6 @@ def _create_cait(variant, pretrained=False, **kwargs): model = build_model_with_cfg( Cait, variant, pretrained, - default_cfg=default_cfgs[variant], pretrained_filter_fn=checkpoint_filter_fn, **kwargs) return model diff --git a/timm/models/coat.py b/timm/models/coat.py index 18ff8ab9..6425d67e 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -16,7 +16,7 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .helpers import build_model_with_cfg from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ from .registry import register_model from .layers import _assert @@ -610,7 +610,6 @@ def _create_coat(variant, pretrained=False, default_cfg=None, **kwargs): model = build_model_with_cfg( CoaT, variant, pretrained, - default_cfg=default_cfgs[variant], pretrained_filter_fn=checkpoint_filter_fn, **kwargs) return model diff --git a/timm/models/convit.py b/timm/models/convit.py index 6ef1da72..a4aafac0 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -318,10 +318,7 @@ def _create_convit(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') - return build_model_with_cfg( - ConViT, variant, pretrained, - default_cfg=default_cfgs[variant], - **kwargs) + return build_model_with_cfg(ConViT, variant, pretrained, **kwargs) @register_model diff --git a/timm/models/convmixer.py b/timm/models/convmixer.py index a2400782..df551788 100644 --- a/timm/models/convmixer.py +++ b/timm/models/convmixer.py @@ -80,7 +80,7 @@ class ConvMixer(nn.Module): def _create_convmixer(variant, pretrained=False, **kwargs): - return 
build_model_with_cfg(ConvMixer, variant, pretrained, default_cfg=default_cfgs[variant], **kwargs) + return build_model_with_cfg(ConvMixer, variant, pretrained, **kwargs) @register_model diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py index ddc4f64c..37a17dba 100644 --- a/timm/models/crossvit.py +++ b/timm/models/crossvit.py @@ -413,7 +413,6 @@ def _create_crossvit(variant, pretrained=False, **kwargs): return build_model_with_cfg( CrossViT, variant, pretrained, - default_cfg=default_cfgs[variant], pretrained_filter_fn=pretrained_filter_fn, **kwargs) diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index aa57bd88..897b9f3d 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -413,7 +413,6 @@ def _create_cspnet(variant, pretrained=False, **kwargs): cfg_variant = variant.split('_')[0] return build_model_with_cfg( CspNet, variant, pretrained, - default_cfg=default_cfgs[variant], feature_cfg=dict(flatten_sequential=True), model_cfg=model_cfgs[cfg_variant], **kwargs) diff --git a/timm/models/densenet.py b/timm/models/densenet.py index 7be15f49..ee66666e 100644 --- a/timm/models/densenet.py +++ b/timm/models/densenet.py @@ -288,7 +288,6 @@ def _create_densenet(variant, growth_rate, block_config, pretrained, **kwargs): kwargs['block_config'] = block_config return build_model_with_cfg( DenseNet, variant, pretrained, - default_cfg=default_cfgs[variant], feature_cfg=dict(flatten_sequential=True), pretrained_filter_fn=_filter_torchvision_pretrained, **kwargs) diff --git a/timm/models/dla.py b/timm/models/dla.py index f6e4dd28..2d8597a5 100644 --- a/timm/models/dla.py +++ b/timm/models/dla.py @@ -341,7 +341,6 @@ class DLA(nn.Module): def _create_dla(variant, pretrained=False, **kwargs): return build_model_with_cfg( DLA, variant, pretrained, - default_cfg=default_cfgs[variant], pretrained_strict=False, feature_cfg=dict(out_indices=(1, 2, 3, 4, 5)), **kwargs) diff --git a/timm/models/dpn.py b/timm/models/dpn.py index 07e4a128..79358695 100644 --- a/timm/models/dpn.py +++ b/timm/models/dpn.py @@ -264,7 +264,6 @@ class DPN(nn.Module): def _create_dpn(variant, pretrained=False, **kwargs): return build_model_with_cfg( DPN, variant, pretrained, - default_cfg=default_cfgs[variant], feature_cfg=dict(feature_concat=True, flatten_sequential=True), **kwargs) diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 11e4e827..87f50e27 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -48,7 +48,7 @@ from .efficientnet_blocks import SqueezeExcite from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\ round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT from .features import FeatureInfo, FeatureHooks -from .helpers import build_model_with_cfg, default_cfg_for_features +from .helpers import build_model_with_cfg, pretrained_cfg_for_features from .layers import create_conv2d, create_classifier, get_norm_act_layer, EvoNorm2dS0, GroupNormAct from .registry import register_model @@ -599,12 +599,11 @@ def _create_effnet(variant, pretrained=False, **kwargs): model_cls = EfficientNetFeatures model = build_model_with_cfg( model_cls, variant, pretrained, - default_cfg=default_cfgs[variant], pretrained_strict=not features_only, kwargs_filter=kwargs_filter, **kwargs) if features_only: - model.default_cfg = default_cfg_for_features(model.default_cfg) + model.default_cfg = pretrained_cfg_for_features(model.default_cfg) return model @@ -1475,7 +1474,7 @@ def 
efficientnet_b0_g16_evos(pretrained=False, **kwargs): """ EfficientNet-B0 w/ group 16 conv + EvoNorm""" model = _gen_efficientnet( 'efficientnet_b0_g16_evos', group_size=16, channel_divisor=16, - norm_layer=partial(EvoNorm2dS0, group_size=16), pretrained=pretrained, **kwargs) + pretrained=pretrained, **kwargs) #norm_layer=partial(EvoNorm2dS0, group_size=16), return model diff --git a/timm/models/factory.py b/timm/models/factory.py index 6d3fd982..f7a8fd9c 100644 --- a/timm/models/factory.py +++ b/timm/models/factory.py @@ -1,30 +1,36 @@ +from urllib.parse import urlsplit, urlunsplit +import os + from .registry import is_model, is_model_in_modules, model_entrypoint from .helpers import load_checkpoint from .layers import set_layer_config from .hub import load_model_config_from_hf -def split_model_name(model_name): - model_split = model_name.split(':', 1) - if len(model_split) == 1: - return '', model_split[0] +def parse_model_name(model_name): + model_name = model_name.replace('hf_hub', 'hf-hub') # NOTE for backwards compat, to deprecate hf_hub use + parsed = urlsplit(model_name) + assert parsed.scheme in ('', 'timm', 'hf-hub') + if parsed.scheme == 'hf-hub': + # FIXME may use fragment as revision, currently `@` in URI path + return parsed.scheme, parsed.path else: - source_name, model_name = model_split - assert source_name in ('timm', 'hf_hub') - return source_name, model_name + model_name = os.path.split(parsed.path)[-1] + return 'timm', model_name def safe_model_name(model_name, remove_source=True): def make_safe(name): return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_') if remove_source: - model_name = split_model_name(model_name)[-1] + model_name = parse_model_name(model_name)[-1] return make_safe(model_name) def create_model( model_name, pretrained=False, + pretrained_cfg=None, checkpoint_path='', scriptable=None, exportable=None, @@ -45,33 +51,24 @@ def create_model( global_pool (str): global pool type (default: 'avg') **: other kwargs are model specific """ - source_name, model_name = split_model_name(model_name) - - # handle backwards compat with drop_connect -> drop_path change - drop_connect_rate = kwargs.pop('drop_connect_rate', None) - if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None: - print("WARNING: 'drop_connect' as an argument is deprecated, please use 'drop_path'." - " Setting drop_path to %f." % drop_connect_rate) - kwargs['drop_path_rate'] = drop_connect_rate - # Parameters that aren't supported by all models or are intended to only override model defaults if set # should default to None in command line args/cfg. Remove them if they are present and not set so that # non-supporting models don't break and default args remain in effect. kwargs = {k: v for k, v in kwargs.items() if v is not None} - if source_name == 'hf_hub': - # For model names specified in the form `hf_hub:path/architecture_name#revision`, - # load model weights + default_cfg from Hugging Face hub. - hf_default_cfg, model_name = load_model_config_from_hf(model_name) - kwargs['external_default_cfg'] = hf_default_cfg # FIXME revamp default_cfg interface someday + model_source, model_name = parse_model_name(model_name) + if model_source == 'hf-hub': + # FIXME hf-hub source overrides any passed in pretrained_cfg, warn? + # For model names specified in the form `hf-hub:path/architecture_name@revision`, + # load model weights + pretrained_cfg from Hugging Face hub. 
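
# A short sketch of the parsing that parse_model_name (above) performs: model names may
# now carry a URI-style source prefix, and the old 'hf_hub:' form is remapped to
# 'hf-hub' for backwards compatibility.
from timm.models.factory import parse_model_name

parse_model_name('resnet50')                    # ('timm', 'resnet50')
parse_model_name('hf-hub:timm/eca_nfnet_l0')    # ('hf-hub', 'timm/eca_nfnet_l0')
parse_model_name('hf_hub:timm/eca_nfnet_l0')    # legacy prefix, treated as 'hf-hub'
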
+ pretrained_cfg, model_name = load_model_config_from_hf(model_name) - if is_model(model_name): - create_fn = model_entrypoint(model_name) - else: + if not is_model(model_name): raise RuntimeError('Unknown model (%s)' % model_name) + create_fn = model_entrypoint(model_name) with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): - model = create_fn(pretrained=pretrained, **kwargs) + model = create_fn(pretrained=pretrained, pretrained_cfg=pretrained_cfg, **kwargs) if checkpoint_path: load_checkpoint(model, checkpoint_path) diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py index f709d92e..c7ca0f8b 100644 --- a/timm/models/fx_features.py +++ b/timm/models/fx_features.py @@ -1,13 +1,15 @@ """ PyTorch FX Based Feature Extraction Helpers Using https://pytorch.org/vision/stable/feature_extraction.html """ -from typing import Callable +from typing import Callable, List, Dict, Union + +import torch from torch import nn from .features import _get_feature_info try: - from torchvision.models.feature_extraction import create_feature_extractor + from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor has_fx_feature_extraction = True except ImportError: has_fx_feature_extraction = False @@ -61,18 +63,52 @@ def register_notrace_function(func: Callable): return func +def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]): + assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' + return _create_feature_extractor( + model, return_nodes, + tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)} + ) + + class FeatureGraphNet(nn.Module): + """ A FX Graph based feature extractor that works with the model feature_info metadata + """ def __init__(self, model, out_indices, out_map=None): super().__init__() assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' self.feature_info = _get_feature_info(model, out_indices) if out_map is not None: assert len(out_map) == len(out_indices) - return_nodes = {info['module']: out_map[i] if out_map is not None else info['module'] - for i, info in enumerate(self.feature_info) if i in out_indices} - self.graph_module = create_feature_extractor( - model, return_nodes, - tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) + return_nodes = { + info['module']: out_map[i] if out_map is not None else info['module'] + for i, info in enumerate(self.feature_info) if i in out_indices} + self.graph_module = create_feature_extractor(model, return_nodes) def forward(self, x): return list(self.graph_module(x).values()) + + +class FeatureExtractNet(nn.Module): + """ A standalone feature extraction wrapper that maps dict -> list or single tensor + NOTE: + * one can use feature_extractor directly if dictionary output is desired + * unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info + metadata for builtin feature extraction mode + * feature_extractor can be used directly if dictionary output is acceptable + + Args: + model: model to extract features from + return_nodes: node names to return features from (dict or list) + squeeze_out: if only one output, and output in list format, flatten to single tensor + """ + def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True): + 
super().__init__() + self.squeeze_out = squeeze_out + self.graph_module = create_feature_extractor(model, return_nodes) + + def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]: + out = list(self.graph_module(x).values()) + if self.squeeze_out and len(out) == 1: + return out[0] + return out diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index 3b6f90a4..684d6651 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -250,7 +250,6 @@ def _create_ghostnet(variant, width=1.0, pretrained=False, **kwargs): ) return build_model_with_cfg( GhostNet, variant, pretrained, - default_cfg=default_cfgs[variant], feature_cfg=dict(flatten_sequential=True), **model_kwargs) diff --git a/timm/models/gluon_resnet.py b/timm/models/gluon_resnet.py index 027a10b5..a1e73554 100644 --- a/timm/models/gluon_resnet.py +++ b/timm/models/gluon_resnet.py @@ -58,10 +58,7 @@ default_cfgs = { def _create_resnet(variant, pretrained=False, **kwargs): - return build_model_with_cfg( - ResNet, variant, pretrained, - default_cfg=default_cfgs[variant], - **kwargs) + return build_model_with_cfg(ResNet, variant, pretrained, **kwargs) @register_model diff --git a/timm/models/gluon_xception.py b/timm/models/gluon_xception.py index fbd668a5..6a2168c9 100644 --- a/timm/models/gluon_xception.py +++ b/timm/models/gluon_xception.py @@ -234,7 +234,6 @@ class Xception65(nn.Module): def _create_gluon_xception(variant, pretrained=False, **kwargs): return build_model_with_cfg( Xception65, variant, pretrained, - default_cfg=default_cfgs[variant], feature_cfg=dict(feature_cls='hook'), **kwargs) diff --git a/timm/models/hardcorenas.py b/timm/models/hardcorenas.py index 9988a044..a4e42de0 100644 --- a/timm/models/hardcorenas.py +++ b/timm/models/hardcorenas.py @@ -5,7 +5,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .efficientnet_blocks import SqueezeExcite from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args, round_channels -from .helpers import build_model_with_cfg, default_cfg_for_features +from .helpers import build_model_with_cfg, pretrained_cfg_for_features from .layers import get_act_fn from .mobilenetv3 import MobileNetV3, MobileNetV3Features from .registry import register_model @@ -59,12 +59,11 @@ def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs): model_cls = MobileNetV3Features model = build_model_with_cfg( model_cls, variant, pretrained, - default_cfg=default_cfgs[variant], pretrained_strict=not features_only, kwargs_filter=kwargs_filter, **model_kwargs) if features_only: - model.default_cfg = default_cfg_for_features(model.default_cfg) + model.default_cfg = pretrained_cfg_for_features(model.default_cfg) return model diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 16ce64d0..169fe884 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -7,7 +7,7 @@ import os import math from collections import OrderedDict from copy import deepcopy -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Optional, Tuple, Dict import torch import torch.nn as nn @@ -17,12 +17,28 @@ from .features import FeatureListNet, FeatureDictNet, FeatureHookNet from .fx_features import FeatureGraphNet from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf from .layers import Conv2dSame, Linear +from .registry import get_pretrained_cfg _logger = logging.getLogger(__name__) -def load_state_dict(checkpoint_path, use_ema=False): +# Global variables for 
rarely used pretrained checkpoint download progress and hash check. +# Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle. +_DOWNLOAD_PROGRESS = False +_CHECK_HASH = False + + +def clean_state_dict(state_dict): + # 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training + cleaned_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k[7:] if k.startswith('module.') else k + cleaned_state_dict[name] = v + return cleaned_state_dict + + +def load_state_dict(checkpoint_path, use_ema=True): if checkpoint_path and os.path.isfile(checkpoint_path): checkpoint = torch.load(checkpoint_path, map_location='cpu') state_dict_key = '' @@ -35,16 +51,7 @@ def load_state_dict(checkpoint_path, use_ema=False): state_dict_key = 'state_dict' elif 'model' in checkpoint: state_dict_key = 'model' - if state_dict_key: - state_dict = checkpoint[state_dict_key] - new_state_dict = OrderedDict() - for k, v in state_dict.items(): - # strip `module.` prefix - name = k[7:] if k.startswith('module') else k - new_state_dict[name] = v - state_dict = new_state_dict - else: - state_dict = checkpoint + state_dict = clean_state_dict(checkpoint[state_dict_key] if state_dict_key else checkpoint) _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)) return state_dict else: @@ -52,7 +59,7 @@ def load_state_dict(checkpoint_path, use_ema=False): raise FileNotFoundError() -def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True): +def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True): if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'): # numpy checkpoint, try to load via model specific load_pretrained fn if hasattr(model, 'load_pretrained'): @@ -71,11 +78,8 @@ def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: if log_info: _logger.info('Restoring model state from checkpoint...') - new_state_dict = OrderedDict() - for k, v in checkpoint['state_dict'].items(): - name = k[7:] if k.startswith('module') else k - new_state_dict[name] = v - model.load_state_dict(new_state_dict) + state_dict = clean_state_dict(checkpoint['state_dict']) + model.load_state_dict(state_dict) if optimizer is not None and 'optimizer' in checkpoint: if log_info: @@ -104,7 +108,50 @@ def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, raise FileNotFoundError() -def load_custom_pretrained(model, default_cfg=None, load_fn=None, progress=False, check_hash=False): +def _resolve_pretrained_source(pretrained_cfg): + cfg_source = pretrained_cfg.get('source', '') + pretrained_url = pretrained_cfg.get('url', None) + pretrained_file = pretrained_cfg.get('file', None) + hf_hub_id = pretrained_cfg.get('hf_hub_id', None) + # resolve where to load pretrained weights from + load_from = '' + pretrained_loc = '' + if cfg_source == 'hf-hub' and has_hf_hub(necessary=True): + # hf-hub specified as source via model identifier + load_from = 'hf-hub' + assert hf_hub_id + else: + # default source == timm or unspecified + if pretrained_file: + load_from = 'file' + pretrained_loc = pretrained_file + elif pretrained_url: + load_from = 'url' + pretrained_loc = pretrained_url + elif hf_hub_id and has_hf_hub(necessary=False): + # hf-hub available as alternate weight source in default_cfg + load_from = 'hf-hub' + pretrained_loc = hf_hub_id + return load_from, pretrained_loc + + +def 
set_pretrained_download_progress(enable=True): + """ Set download progress for pretrained weights on/off (globally). """ + global _DOWNLOAD_PROGRESS + _DOWNLOAD_PROGRESS = enable + + +def set_pretrained_check_hash(enable=True): + """ Set hash checking for pretrained weights on/off (globally). """ + global _CHECK_HASH + _CHECK_HASH = enable + + +def load_custom_pretrained( + model: nn.Module, + pretrained_cfg: Optional[Dict] = None, + load_fn: Optional[Callable] = None, +): r"""Loads a custom (read non .pth) weight file Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls @@ -116,7 +163,7 @@ def load_custom_pretrained(model, default_cfg=None, load_fn=None, progress=False Args: model: The instantiated model to load weights into - default_cfg (dict): Default pretrained model cfg + pretrained_cfg (dict): Default pretrained model cfg load_fn: An external stand alone fn that loads weights into provided model, otherwise a fn named 'laod_pretrained' on the model will be called if it exists progress (bool, optional): whether or not to display a progress bar to stderr. Default: False @@ -125,17 +172,20 @@ def load_custom_pretrained(model, default_cfg=None, load_fn=None, progress=False digits of the SHA256 hash of the contents of the file. The hash is used to ensure unique names and to verify the contents of the file. Default: False """ - default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {} - pretrained_url = default_cfg.get('url', None) - if not pretrained_url: + pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) or {} + load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg) + if not load_from: _logger.warning("No pretrained weights exist for this model. Using random initialization.") return - cached_file = download_cached_file(default_cfg['url'], check_hash=check_hash, progress=progress) + if load_from == 'hf-hub': # FIXME + _logger.warning("Hugging Face hub not currently supported for custom load pretrained models.") + elif load_from == 'url': + pretrained_loc = download_cached_file(pretrained_loc, check_hash=_CHECK_HASH, progress=_DOWNLOAD_PROGRESS) if load_fn is not None: - load_fn(model, cached_file) + load_fn(model, pretrained_loc) elif hasattr(model, 'load_pretrained'): - model.load_pretrained(cached_file) + model.load_pretrained(pretrained_loc) else: _logger.warning("Valid function to load pretrained weights is not available, using random initialization.") @@ -165,31 +215,41 @@ def adapt_input_conv(in_chans, conv_weight): return conv_weight -def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False): +def load_pretrained( + model: nn.Module, + pretrained_cfg: Optional[Dict] = None, + num_classes: int = 1000, + in_chans: int = 3, + filter_fn: Optional[Callable] = None, + strict: bool = True, +): """ Load pretrained checkpoint Args: model (nn.Module) : PyTorch model module - default_cfg (Optional[Dict]): default configuration for pretrained weights / target dataset + pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset num_classes (int): num_classes for model in_chans (int): in_chans for model filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args) strict (bool): strict load of checkpoint - progress (bool): enable progress bar for weight download """ - default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {} - pretrained_url = default_cfg.get('url', None) - 
hf_hub_id = default_cfg.get('hf_hub', None) - if not pretrained_url and not hf_hub_id: - _logger.warning("No pretrained weights exist for this model. Using random initialization.") + pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None) or {} + load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg) + if load_from == 'file': + _logger.info(f'Loading pretrained weights from file ({pretrained_loc})') + state_dict = load_state_dict(pretrained_loc) + elif load_from == 'url': + _logger.info(f'Loading pretrained weights from url ({pretrained_loc})') + state_dict = load_state_dict_from_url( + pretrained_loc, map_location='cpu', progress=_DOWNLOAD_PROGRESS, check_hash=_CHECK_HASH) + elif load_from == 'hf-hub': + _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})') + state_dict = load_state_dict_from_hf(pretrained_loc) + else: + _logger.warning("No pretrained weights exist or were found for this model. Using random initialization.") return - if pretrained_url: - _logger.info(f'Loading pretrained weights from url ({pretrained_url})') - state_dict = load_state_dict_from_url(pretrained_url, progress=progress, map_location='cpu') - elif hf_hub_id and has_hf_hub(necessary=True): - _logger.info(f'Loading pretrained weights from Hugging Face hub ({hf_hub_id})') - state_dict = load_state_dict_from_hf(hf_hub_id) + if filter_fn is not None: # for backwards compat with filter fn that take one arg, try one first, the two try: @@ -197,7 +257,7 @@ def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filte except TypeError: state_dict = filter_fn(state_dict, model) - input_convs = default_cfg.get('first_conv', None) + input_convs = pretrained_cfg.get('first_conv', None) if input_convs is not None and in_chans != 3: if isinstance(input_convs, str): input_convs = (input_convs,) @@ -213,12 +273,12 @@ def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filte _logger.warning( f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.') - classifiers = default_cfg.get('classifier', None) - label_offset = default_cfg.get('label_offset', 0) + classifiers = pretrained_cfg.get('classifier', None) + label_offset = pretrained_cfg.get('label_offset', 0) if classifiers is not None: if isinstance(classifiers, str): classifiers = (classifiers,) - if num_classes != default_cfg['num_classes']: + if num_classes != pretrained_cfg['num_classes']: for classifier_name in classifiers: # completely discard fully connected if model num_classes doesn't match pretrained weights del state_dict[classifier_name + '.weight'] @@ -333,43 +393,43 @@ def adapt_model_from_file(parent_module, model_variant): return adapt_model_from_string(parent_module, f.read().strip()) -def default_cfg_for_features(default_cfg): - default_cfg = deepcopy(default_cfg) +def pretrained_cfg_for_features(pretrained_cfg): + pretrained_cfg = deepcopy(pretrained_cfg) # remove default pretrained cfg fields that don't have much relevance for feature backbone to_remove = ('num_classes', 'crop_pct', 'classifier', 'global_pool') # add default final pool size? for tr in to_remove: - default_cfg.pop(tr, None) - return default_cfg + pretrained_cfg.pop(tr, None) + return pretrained_cfg -def overlay_external_default_cfg(default_cfg, kwargs): - """ Overlay 'external_default_cfg' in kwargs on top of default_cfg arg. 
- """ - external_default_cfg = kwargs.pop('external_default_cfg', None) - if external_default_cfg: - default_cfg.pop('url', None) # url should come from external cfg - default_cfg.pop('hf_hub', None) # hf hub id should come from external cfg - default_cfg.update(external_default_cfg) +# def overlay_external_pretrained_cfg(pretrained_cfg, kwargs): +# """ Overlay 'external_pretrained_cfg' in kwargs on top of pretrained_cfg arg. +# """ +# external_pretrained_cfg = kwargs.pop('external_pretrained_cfg', None) +# if external_pretrained_cfg: +# pretrained_cfg.pop('url', None) # url should come from external cfg +# pretrained_cfg.pop('hf_hub', None) # hf hub id should come from external cfg +# pretrained_cfg.update(external_pretrained_cfg) -def set_default_kwargs(kwargs, names, default_cfg): +def set_default_kwargs(kwargs, names, pretrained_cfg): for n in names: # for legacy reasons, model __init__args uses img_size + in_chans as separate args while - # default_cfg has one input_size=(C, H ,W) entry + # pretrained_cfg has one input_size=(C, H ,W) entry if n == 'img_size': - input_size = default_cfg.get('input_size', None) + input_size = pretrained_cfg.get('input_size', None) if input_size is not None: assert len(input_size) == 3 kwargs.setdefault(n, input_size[-2:]) elif n == 'in_chans': - input_size = default_cfg.get('input_size', None) + input_size = pretrained_cfg.get('input_size', None) if input_size is not None: assert len(input_size) == 3 kwargs.setdefault(n, input_size[0]) else: - default_val = default_cfg.get(n, None) + default_val = pretrained_cfg.get(n, None) if default_val is not None: - kwargs.setdefault(n, default_cfg[n]) + kwargs.setdefault(n, pretrained_cfg[n]) def filter_kwargs(kwargs, names): @@ -379,36 +439,46 @@ def filter_kwargs(kwargs, names): kwargs.pop(n, None) -def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter): +def update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter): """ Update the default_cfg and kwargs before passing to model - FIXME this sequence of overlay default_cfg, set default kwargs, filter kwargs - could/should be replaced by an improved configuration mechanism - Args: - default_cfg: input default_cfg (updated in-place) + pretrained_cfg: input pretrained cfg (updated in-place) kwargs: keyword args passed to model build fn (updated in-place) kwargs_filter: keyword arg keys that must be removed before model __init__ """ - # Overlay default cfg values from `external_default_cfg` if it exists in kwargs - overlay_external_default_cfg(default_cfg, kwargs) # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs) default_kwarg_names = ('num_classes', 'global_pool', 'in_chans') - if default_cfg.get('fixed_input_size', False): + if pretrained_cfg.get('fixed_input_size', False): # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size default_kwarg_names += ('img_size',) - set_default_kwargs(kwargs, names=default_kwarg_names, default_cfg=default_cfg) + set_default_kwargs(kwargs, names=default_kwarg_names, pretrained_cfg=pretrained_cfg) # Filter keyword args for task specific model variants (some 'features only' models, etc.) 
filter_kwargs(kwargs, names=kwargs_filter) +def resolve_pretrained_cfg(variant: str, pretrained_cfg=None, kwargs=None): + if pretrained_cfg and isinstance(pretrained_cfg, dict): + # highest priority, pretrained_cfg available and passed explicitly + return deepcopy(pretrained_cfg) + if kwargs and 'pretrained_cfg' in kwargs: + # next highest, pretrained_cfg in a kwargs dict, pop and return + pretrained_cfg = kwargs.pop('pretrained_cfg', {}) + if pretrained_cfg: + return deepcopy(pretrained_cfg) + # lookup pretrained cfg in model registry by variant + pretrained_cfg = get_pretrained_cfg(variant) + assert pretrained_cfg + return pretrained_cfg + + def build_model_with_cfg( model_cls: Callable, variant: str, pretrained: bool, - default_cfg: dict, + pretrained_cfg: Optional[Dict] = None, model_cfg: Optional[Any] = None, - feature_cfg: Optional[dict] = None, + feature_cfg: Optional[Dict] = None, pretrained_strict: bool = True, pretrained_filter_fn: Optional[Callable] = None, pretrained_custom_load: bool = False, @@ -417,7 +487,7 @@ def build_model_with_cfg( """ Build model with specified default_cfg and optional model_cfg This helper fn aids in the construction of a model including: - * handling default_cfg and associated pretained weight loading + * handling default_cfg and associated pretrained weight loading * passing through optional model_cfg for models with config based arch spec * features_only model adaptation * pruning config / model adaptation @@ -426,7 +496,7 @@ def build_model_with_cfg( model_cls (nn.Module): model class variant (str): model variant name pretrained (bool): load pretrained weights - default_cfg (dict): model's default pretrained/task config + pretrained_cfg (dict): model's pretrained weight/task config model_cfg (Optional[Dict]): model's architecture config feature_cfg (Optional[Dict]: feature extraction adapter config pretrained_strict (bool): load pretrained weights strictly @@ -438,9 +508,11 @@ def build_model_with_cfg( pruned = kwargs.pop('pruned', False) features = False feature_cfg = feature_cfg or {} - default_cfg = deepcopy(default_cfg) if default_cfg else {} - update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter) - default_cfg.setdefault('architecture', variant) + + # resolve and update model pretrained config and model kwargs + pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=pretrained_cfg) + update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter) + pretrained_cfg.setdefault('architecture', variant) # Setup for feature extraction wrapper done at end of this fn if kwargs.pop('features_only', False): @@ -451,7 +523,8 @@ def build_model_with_cfg( # Build the model model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs) - model.default_cfg = default_cfg + model.pretrained_cfg = pretrained_cfg + model.default_cfg = model.pretrained_cfg # alias for backwards compat if pruned: model = adapt_model_from_file(model, variant) @@ -460,10 +533,12 @@ def build_model_with_cfg( num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000)) if pretrained: if pretrained_custom_load: - load_custom_pretrained(model) + # FIXME improve custom load trigger + load_custom_pretrained(model, pretrained_cfg=pretrained_cfg) else: load_pretrained( model, + pretrained_cfg=pretrained_cfg, num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3), filter_fn=pretrained_filter_fn, @@ -483,7 +558,8 @@ def build_model_with_cfg( else: assert False, f'Unknown 
feature class {feature_cls}' model = feature_cls(model, **feature_cfg) - model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfg + model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back default_cfg + model.default_cfg = model.pretrained_cfg # alias for backwards compat return model diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index c56964f6..32b4eb32 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -17,7 +17,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .features import FeatureInfo -from .helpers import build_model_with_cfg, default_cfg_for_features +from .helpers import build_model_with_cfg, pretrained_cfg_for_features from .layers import create_classifier from .registry import register_model from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE @@ -781,13 +781,13 @@ def _create_hrnet(variant, pretrained, **model_kwargs): features_only = True model = build_model_with_cfg( model_cls, variant, pretrained, - default_cfg=default_cfgs[variant], model_cfg=cfg_cls[variant], pretrained_strict=not features_only, kwargs_filter=kwargs_filter, **model_kwargs) if features_only: - model.default_cfg = default_cfg_for_features(model.default_cfg) + model.pretrained_cfg = pretrained_cfg_for_features(model.default_cfg) + model.default_cfg = model.pretrained_cfg # backwards compat return model diff --git a/timm/models/hub.py b/timm/models/hub.py index 65e7ba9a..dd7870cb 100644 --- a/timm/models/hub.py +++ b/timm/models/hub.py @@ -62,6 +62,7 @@ def has_hf_hub(necessary=False): def hf_split(hf_id): + # FIXME I may change @ -> # and be parsed as fragment in a URI model name scheme rev_split = hf_id.split('@') assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.' 
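
# A sketch of the lookup order implemented by resolve_pretrained_cfg (added earlier in
# this patch): an explicitly passed dict wins, then a 'pretrained_cfg' entry popped from
# kwargs, and finally the cfg registered for the variant in timm.models.registry.
from timm.models.helpers import resolve_pretrained_cfg

cfg = resolve_pretrained_cfg('resnet50')        # falls back to the model registry
cfg = resolve_pretrained_cfg('resnet50', pretrained_cfg={'num_classes': 10})
# returns a copy of the explicitly passed dict; the registry entry is not consulted
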
hf_model_id = rev_split[0] @@ -84,10 +85,11 @@ def _download_from_hf(model_id: str, filename: str): def load_model_config_from_hf(model_id: str): assert has_hf_hub(True) cached_file = _download_from_hf(model_id, 'config.json') - default_cfg = load_cfg_from_json(cached_file) - default_cfg['hf_hub'] = model_id # insert hf_hub id for pretrained weight load during model creation - model_name = default_cfg.get('architecture') - return default_cfg, model_name + pretrained_cfg = load_cfg_from_json(cached_file) + pretrained_cfg['hf_hub_id'] = model_id # insert hf_hub id for pretrained weight load during model creation + pretrained_cfg['source'] = 'hf-hub' + model_name = pretrained_cfg.get('architecture') + return pretrained_cfg, model_name def load_state_dict_from_hf(model_id: str): @@ -107,7 +109,7 @@ def save_for_hf(model, save_directory, model_config=None): torch.save(model.state_dict(), weights_path) config_path = save_directory / 'config.json' - hf_config = model.default_cfg + hf_config = model.pretrained_cfg hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes) hf_config['num_features'] = model_config.pop('num_features', model.num_features) hf_config['labels'] = model_config.pop('labels', [f"LABEL_{i}" for i in range(hf_config['num_classes'])]) diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py index 71672849..d4aced05 100644 --- a/timm/models/inception_resnet_v2.py +++ b/timm/models/inception_resnet_v2.py @@ -335,10 +335,7 @@ class InceptionResnetV2(nn.Module): def _create_inception_resnet_v2(variant, pretrained=False, **kwargs): - return build_model_with_cfg( - InceptionResnetV2, variant, pretrained, - default_cfg=default_cfgs[variant], - **kwargs) + return build_model_with_cfg(InceptionResnetV2, variant, pretrained, **kwargs) @register_model diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py index cbb1107b..eb6fb2cf 100644 --- a/timm/models/inception_v3.py +++ b/timm/models/inception_v3.py @@ -8,7 +8,7 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg +from .helpers import build_model_with_cfg, resolve_pretrained_cfg from .registry import register_model from .layers import trunc_normal_, create_classifier, Linear @@ -424,18 +424,19 @@ class InceptionV3Aux(InceptionV3): def _create_inception_v3(variant, pretrained=False, **kwargs): - default_cfg = default_cfgs[variant] + pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs) aux_logits = kwargs.pop('aux_logits', False) if aux_logits: assert not kwargs.pop('features_only', False) model_cls = InceptionV3Aux - load_strict = default_cfg['has_aux'] + load_strict = pretrained_cfg['has_aux'] else: model_cls = InceptionV3 - load_strict = not default_cfg['has_aux'] + load_strict = not pretrained_cfg['has_aux'] + return build_model_with_cfg( model_cls, variant, pretrained, - default_cfg=default_cfg, + pretrained_cfg=pretrained_cfg, pretrained_strict=load_strict, **kwargs) diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py index cc899e15..f95db28e 100644 --- a/timm/models/inception_v4.py +++ b/timm/models/inception_v4.py @@ -306,7 +306,6 @@ class InceptionV4(nn.Module): def _create_inception_v4(variant, pretrained=False, **kwargs): return build_model_with_cfg( InceptionV4, variant, pretrained, - default_cfg=default_cfgs[variant], feature_cfg=dict(flatten_sequential=True), **kwargs) 
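
# Putting the hub pieces together, a hedged end-to-end sketch (requires the huggingface_hub
# package and network access): create_model parses the 'hf-hub:' prefix, pulls config.json
# via load_model_config_from_hf above (which stamps 'hf_hub_id' and source='hf-hub' into the
# pretrained_cfg), and load_pretrained then fetches the matching weights from the hub.
import timm

model = timm.create_model('hf-hub:timm/eca_nfnet_l0', pretrained=True)
print(model.pretrained_cfg['source'])   # 'hf-hub'
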
diff --git a/timm/models/levit.py b/timm/models/levit.py index 9987e4ba..fcb237dd 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -32,7 +32,7 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN -from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .helpers import build_model_with_cfg from .layers import to_ntuple, get_act_layer from .vision_transformer import trunc_normal_ from .registry import register_model @@ -554,7 +554,6 @@ def create_levit(variant, pretrained=False, default_cfg=None, fuse=False, **kwar model_cfg = dict(**model_cfgs[variant], **kwargs) model = build_model_with_cfg( Levit, variant, pretrained, - default_cfg=default_cfgs[variant], pretrained_filter_fn=checkpoint_filter_fn, **model_cfg) #if fuse: diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index 727b655b..dc5d70a4 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -46,7 +46,7 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, overlay_external_default_cfg, named_apply +from .helpers import build_model_with_cfg, named_apply from .layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple from .registry import register_model @@ -360,7 +360,6 @@ def _create_mixer(variant, pretrained=False, **kwargs): model = build_model_with_cfg( MlpMixer, variant, pretrained, - default_cfg=default_cfgs[variant], pretrained_filter_fn=checkpoint_filter_fn, **kwargs) return model diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index c97223ed..86f599c1 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -19,7 +19,7 @@ from .efficientnet_blocks import SqueezeExcite from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\ round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT from .features import FeatureInfo, FeatureHooks -from .helpers import build_model_with_cfg, default_cfg_for_features +from .helpers import build_model_with_cfg, pretrained_cfg_for_features from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, get_norm_act_layer from .registry import register_model @@ -239,12 +239,11 @@ def _create_mnv3(variant, pretrained=False, **kwargs): model_cls = MobileNetV3Features model = build_model_with_cfg( model_cls, variant, pretrained, - default_cfg=default_cfgs[variant], pretrained_strict=not features_only, kwargs_filter=kwargs_filter, **kwargs) if features_only: - model.default_cfg = default_cfg_for_features(model.default_cfg) + model.default_cfg = pretrained_cfg_for_features(model.default_cfg) return model diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py index 9c257d9d..571312d5 100644 --- a/timm/models/nasnet.py +++ b/timm/models/nasnet.py @@ -554,7 +554,6 @@ class NASNetALarge(nn.Module): def _create_nasnet(variant, pretrained=False, **kwargs): return build_model_with_cfg( NASNetALarge, variant, pretrained, - default_cfg=default_cfgs[variant], feature_cfg=dict(feature_cls='hook', no_rewrite=True), # not possible to re-write this model **kwargs) diff --git a/timm/models/nest.py b/timm/models/nest.py index 22cf6099..6b9be873 100644 --- a/timm/models/nest.py +++ b/timm/models/nest.py @@ -395,11 +395,9 @@ def checkpoint_filter_fn(state_dict, model): return state_dict -def _create_nest(variant, pretrained=False, default_cfg=None, **kwargs): - default_cfg = 
default_cfg or default_cfgs[variant] +def _create_nest(variant, pretrained=False, **kwargs): model = build_model_with_cfg( Nest, variant, pretrained, - default_cfg=default_cfg, feature_cfg=dict(out_indices=(0, 1, 2), flatten_sequential=True), pretrained_filter_fn=checkpoint_filter_fn, **kwargs) diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 973cbd66..dd15ff14 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -106,7 +106,7 @@ default_cfgs = dict( pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288), crop_pct=1.0), eca_nfnet_l0=_dcfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l0_ra2-e3e9ac50.pth', - hf_hub='timm/eca_nfnet_l0', + hf_hub_id='timm/eca_nfnet_l0', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288), crop_pct=1.0), eca_nfnet_l1=_dcfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l1_ra2-7dce93cd.pth', @@ -592,7 +592,6 @@ def _create_normfreenet(variant, pretrained=False, **kwargs): feature_cfg = dict(flatten_sequential=True) return build_model_with_cfg( NormFreeNet, variant, pretrained, - default_cfg=default_cfgs[variant], model_cfg=model_cfg, feature_cfg=feature_cfg, **kwargs) diff --git a/timm/models/pit.py b/timm/models/pit.py index 460824e2..843880e7 100644 --- a/timm/models/pit.py +++ b/timm/models/pit.py @@ -21,7 +21,7 @@ import torch from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .helpers import build_model_with_cfg from .layers import trunc_normal_, to_2tuple from .registry import register_model from .vision_transformer import Block @@ -262,7 +262,6 @@ def _create_pit(variant, pretrained=False, **kwargs): model = build_model_with_cfg( PoolingVisionTransformer, variant, pretrained, - default_cfg=default_cfgs[variant], pretrained_filter_fn=checkpoint_filter_fn, **kwargs) return model diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py index 208bccf3..4aef89f4 100644 --- a/timm/models/pnasnet.py +++ b/timm/models/pnasnet.py @@ -335,7 +335,6 @@ class PNASNet5Large(nn.Module): def _create_pnasnet(variant, pretrained=False, **kwargs): return build_model_with_cfg( PNASNet5Large, variant, pretrained, - default_cfg=default_cfgs[variant], feature_cfg=dict(feature_cls='hook', no_rewrite=True), # not possible to re-write this model **kwargs) diff --git a/timm/models/registry.py b/timm/models/registry.py index f92219b2..9f58060f 100644 --- a/timm/models/registry.py +++ b/timm/models/registry.py @@ -9,13 +9,13 @@ from collections import defaultdict from copy import deepcopy __all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', - 'is_model_default_key', 'has_model_default_key', 'get_model_default_value', 'is_model_pretrained'] + 'is_pretrained_cfg_key', 'has_pretrained_cfg_key', 'get_pretrained_cfg_value', 'is_model_pretrained'] _module_to_models = defaultdict(set) # dict of sets to check membership of model in module _model_to_module = {} # mapping of model names to module names _model_entrypoints = {} # mapping of model names to entrypoint fns _model_has_pretrained = set() # set of model names that have pretrained weight url present -_model_default_cfgs = dict() # central repo for model default_cfgs +_model_pretrained_cfgs = dict() # central repo for model default_cfgs def register_model(fn): @@ -35,13 +35,18 @@ def register_model(fn): 
_model_entrypoints[model_name] = fn _model_to_module[model_name] = module_name _module_to_models[module_name].add(model_name) - has_pretrained = False # check if model has a pretrained url to allow filtering on this + has_valid_pretrained = False # check if model has a pretrained url to allow filtering on this if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs: # this will catch all models that have entrypoint matching cfg key, but miss any aliasing # entrypoints or non-matching combos - has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url'] - _model_default_cfgs[model_name] = deepcopy(mod.default_cfgs[model_name]) - if has_pretrained: + cfg = mod.default_cfgs[model_name] + has_valid_pretrained = ( + ('url' in cfg and 'http' in cfg['url']) or + ('file' in cfg and cfg['file']) or + ('hf_hub_id' in cfg and cfg['hf_hub_id']) + ) + _model_pretrained_cfgs[model_name] = mod.default_cfgs[model_name] + if has_valid_pretrained: _model_has_pretrained.add(model_name) return fn @@ -87,7 +92,7 @@ def list_models(filter='', module='', pretrained=False, exclude_filters='', name if pretrained: models = _model_has_pretrained.intersection(models) if name_matches_cfg: - models = set(_model_default_cfgs).intersection(models) + models = set(_model_pretrained_cfgs).intersection(models) return list(sorted(models, key=_natural_key)) @@ -120,30 +125,35 @@ def is_model_in_modules(model_name, module_names): return any(model_name in _module_to_models[n] for n in module_names) -def has_model_default_key(model_name, cfg_key): +def is_model_pretrained(model_name): + return model_name in _model_has_pretrained + + +def get_pretrained_cfg(model_name): + if model_name in _model_pretrained_cfgs: + return deepcopy(_model_pretrained_cfgs[model_name]) + return {} + + +def has_pretrained_cfg_key(model_name, cfg_key): """ Query model default_cfgs for existence of a specific key. """ - if model_name in _model_default_cfgs and cfg_key in _model_default_cfgs[model_name]: + if model_name in _model_pretrained_cfgs and cfg_key in _model_pretrained_cfgs[model_name]: return True return False -def is_model_default_key(model_name, cfg_key): +def is_pretrained_cfg_key(model_name, cfg_key): """ Return truthy value for specified model default_cfg key, False if does not exist. """ - if model_name in _model_default_cfgs and _model_default_cfgs[model_name].get(cfg_key, False): + if model_name in _model_pretrained_cfgs and _model_pretrained_cfgs[model_name].get(cfg_key, False): return True return False -def get_model_default_value(model_name, cfg_key): +def get_pretrained_cfg_value(model_name, cfg_key): """ Get a specific model default_cfg value by key. None if it doesn't exist. 
""" - if model_name in _model_default_cfgs: - return _model_default_cfgs[model_name].get(cfg_key, None) - else: - return None - - -def is_model_pretrained(model_name): - return model_name in _model_has_pretrained + if model_name in _model_pretrained_cfgs: + return _model_pretrained_cfgs[model_name].get(cfg_key, None) + return None \ No newline at end of file diff --git a/timm/models/regnet.py b/timm/models/regnet.py index e6d0e646..5497b74b 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -472,7 +472,6 @@ def _filter_fn(state_dict): def _create_regnet(variant, pretrained, **kwargs): return build_model_with_cfg( RegNet, variant, pretrained, - default_cfg=default_cfgs[variant], model_cfg=model_cfgs[variant], pretrained_filter_fn=_filter_fn, **kwargs) diff --git a/timm/models/res2net.py b/timm/models/res2net.py index 282baba3..109fee1f 100644 --- a/timm/models/res2net.py +++ b/timm/models/res2net.py @@ -133,10 +133,7 @@ class Bottle2neck(nn.Module): def _create_res2net(variant, pretrained=False, **kwargs): - return build_model_with_cfg( - ResNet, variant, pretrained, - default_cfg=default_cfgs[variant], - **kwargs) + return build_model_with_cfg(ResNet, variant, pretrained, **kwargs) @register_model diff --git a/timm/models/resnest.py b/timm/models/resnest.py index f3119807..7bbe58e0 100644 --- a/timm/models/resnest.py +++ b/timm/models/resnest.py @@ -135,10 +135,7 @@ class ResNestBottleneck(nn.Module): def _create_resnest(variant, pretrained=False, **kwargs): - return build_model_with_cfg( - ResNet, variant, pretrained, - default_cfg=default_cfgs[variant], - **kwargs) + return build_model_with_cfg(ResNet, variant, pretrained, **kwargs) @register_model diff --git a/timm/models/resnet.py b/timm/models/resnet.py index cb71c464..6305dbcb 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -680,10 +680,7 @@ class ResNet(nn.Module): def _create_resnet(variant, pretrained=False, **kwargs): - return build_model_with_cfg( - ResNet, variant, pretrained, - default_cfg=default_cfgs[variant], - **kwargs) + return build_model_with_cfg(ResNet, variant, pretrained, **kwargs) @register_model diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index 2c6fb9a0..dbf3e9cc 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -482,7 +482,6 @@ def _create_resnetv2(variant, pretrained=False, **kwargs): feature_cfg = dict(flatten_sequential=True) return build_model_with_cfg( ResNetV2, variant, pretrained, - default_cfg=default_cfgs[variant], feature_cfg=feature_cfg, pretrained_custom_load='_bit' in variant, **kwargs) diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py index 1cb8e2f5..df1e0afe 100644 --- a/timm/models/rexnet.py +++ b/timm/models/rexnet.py @@ -186,7 +186,6 @@ def _create_rexnet(variant, pretrained, **kwargs): feature_cfg = dict(flatten_sequential=True) return build_model_with_cfg( ReXNetV1, variant, pretrained, - default_cfg=default_cfgs[variant], feature_cfg=feature_cfg, **kwargs) diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py index 1f3379db..5bc1b78a 100644 --- a/timm/models/selecsls.py +++ b/timm/models/selecsls.py @@ -321,7 +321,6 @@ def _create_selecsls(variant, pretrained, **kwargs): # this model can do 6 feature levels by default, unlike most others, leave as 0-4 to avoid surprises? 
     return build_model_with_cfg(
         SelecSLS, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         model_cfg=cfg,
         feature_cfg=dict(out_indices=(0, 1, 2, 3, 4), flatten_sequential=True),
         **kwargs)
diff --git a/timm/models/senet.py b/timm/models/senet.py
index 3d0ba7b3..d07f01ad 100644
--- a/timm/models/senet.py
+++ b/timm/models/senet.py
@@ -397,10 +397,7 @@ class SENet(nn.Module):
 
 
 def _create_senet(variant, pretrained=False, **kwargs):
-    return build_model_with_cfg(
-        SENet, variant, pretrained,
-        default_cfg=default_cfgs[variant],
-        **kwargs)
+    return build_model_with_cfg(SENet, variant, pretrained, **kwargs)
 
 
 @register_model
diff --git a/timm/models/sknet.py b/timm/models/sknet.py
index 87520fbe..0ca38b87 100644
--- a/timm/models/sknet.py
+++ b/timm/models/sknet.py
@@ -133,10 +133,7 @@ class SelectiveKernelBottleneck(nn.Module):
 
 
 def _create_skresnet(variant, pretrained=False, **kwargs):
-    return build_model_with_cfg(
-        ResNet, variant, pretrained,
-        default_cfg=default_cfgs[variant],
-        **kwargs)
+    return build_model_with_cfg(ResNet, variant, pretrained, **kwargs)
 
 
 @register_model
diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py
index 92057902..2b874737 100644
--- a/timm/models/swin_transformer.py
+++ b/timm/models/swin_transformer.py
@@ -22,7 +22,7 @@ import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .fx_features import register_notrace_function
-from .helpers import build_model_with_cfg, overlay_external_default_cfg
+from .helpers import build_model_with_cfg
 from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
 from .layers import _assert
 from .registry import register_model
@@ -542,23 +542,9 @@ class SwinTransformer(nn.Module):
         return x
 
 
-def _create_swin_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
-    if default_cfg is None:
-        default_cfg = deepcopy(default_cfgs[variant])
-    overlay_external_default_cfg(default_cfg, kwargs)
-    default_num_classes = default_cfg['num_classes']
-    default_img_size = default_cfg['input_size'][-2:]
-
-    num_classes = kwargs.pop('num_classes', default_num_classes)
-    img_size = kwargs.pop('img_size', default_img_size)
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Vision Transformer models.')
-
+def _create_swin_transformer(variant, pretrained=False, **kwargs):
     model = build_model_with_cfg(
         SwinTransformer, variant, pretrained,
-        default_cfg=default_cfg,
-        img_size=img_size,
-        num_classes=num_classes,
         pretrained_filter_fn=checkpoint_filter_fn,
         **kwargs)
 
diff --git a/timm/models/tnt.py b/timm/models/tnt.py
index d52f9ce6..8affc13e 100644
--- a/timm/models/tnt.py
+++ b/timm/models/tnt.py
@@ -248,7 +248,6 @@ def _create_tnt(variant, pretrained=False, **kwargs):
 
     model = build_model_with_cfg(
         TNT, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         pretrained_filter_fn=checkpoint_filter_fn,
         **kwargs)
     return model
diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py
index 372bfb7b..77c96aee 100644
--- a/timm/models/tresnet.py
+++ b/timm/models/tresnet.py
@@ -250,7 +250,6 @@ class TResNet(nn.Module):
 def _create_tresnet(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         TResNet, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         feature_cfg=dict(out_indices=(1, 2, 3, 4), flatten_sequential=True),
         **kwargs)
 
diff --git a/timm/models/twins.py b/timm/models/twins.py
index 67a939d4..6894e5c2 100644
--- a/timm/models/twins.py
+++ b/timm/models/twins.py
@@ -369,10 +369,7 @@ def _create_twins(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Vision Transformer models.')
 
-    model = build_model_with_cfg(
-        Twins, variant, pretrained,
-        default_cfg=default_cfgs[variant],
-        **kwargs)
+    model = build_model_with_cfg(Twins, variant, pretrained, **kwargs)
     return model
 
 
diff --git a/timm/models/vgg.py b/timm/models/vgg.py
index 11f6d0ea..0b1c16ba 100644
--- a/timm/models/vgg.py
+++ b/timm/models/vgg.py
@@ -183,7 +183,6 @@ def _create_vgg(variant: str, pretrained: bool, **kwargs: Any) -> VGG:
     out_indices = kwargs.get('out_indices', (0, 1, 2, 3, 4, 5))
     model = build_model_with_cfg(
         VGG, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         model_cfg=cfgs[cfg],
         feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
         pretrained_filter_fn=_filter_fn,
diff --git a/timm/models/visformer.py b/timm/models/visformer.py
index 37284c9d..66d50dc7 100644
--- a/timm/models/visformer.py
+++ b/timm/models/visformer.py
@@ -12,7 +12,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .helpers import build_model_with_cfg, overlay_external_default_cfg
+from .helpers import build_model_with_cfg
 from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier
 from .registry import register_model
 
@@ -318,10 +318,7 @@ class Visformer(nn.Module):
 def _create_visformer(variant, pretrained=False, default_cfg=None, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Vision Transformer models.')
-    model = build_model_with_cfg(
-        Visformer, variant, pretrained,
-        default_cfg=default_cfgs[variant],
-        **kwargs)
+    model = build_model_with_cfg(Visformer, variant, pretrained, **kwargs)
     return model
 
 
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 6e568abf..e7ce1866 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -33,7 +33,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
+from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply, adapt_input_conv
 from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
 from .registry import register_model
 
@@ -132,7 +132,7 @@ default_cfgs = {
         num_classes=21843),
     'vit_huge_patch14_224_in21k': _cfg(
         url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
-        hf_hub='timm/vit_huge_patch14_224_in21k',
+        hf_hub_id='timm/vit_huge_patch14_224_in21k',
         num_classes=21843),
 
     # SAM trained models (https://arxiv.org/abs/2106.01548)
@@ -525,13 +525,13 @@ def checkpoint_filter_fn(state_dict, model):
     return out_dict
 
 
-def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
-    default_cfg = default_cfg or default_cfgs[variant]
+def _create_vision_transformer(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Vision Transformer models.')
 
     # NOTE this extra code to support handling of repr size for in21k pretrained models
-    default_num_classes = default_cfg['num_classes']
+    pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs)
+    default_num_classes = pretrained_cfg['num_classes']
     num_classes = kwargs.get('num_classes', default_num_classes)
     repr_size = kwargs.pop('representation_size', None)
     if repr_size is not None and num_classes != default_num_classes:
@@ -542,10 +542,10 @@ def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kw
 
     model = build_model_with_cfg(
         VisionTransformer, variant, pretrained,
-        default_cfg=default_cfg,
+        pretrained_cfg=pretrained_cfg,
         representation_size=repr_size,
         pretrained_filter_fn=checkpoint_filter_fn,
-        pretrained_custom_load='npz' in default_cfg['url'],
+        pretrained_custom_load='npz' in pretrained_cfg['url'],
         **kwargs)
     return model
 
diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py
index d5f0a537..df0fe381 100644
--- a/timm/models/vision_transformer_hybrid.py
+++ b/timm/models/vision_transformer_hybrid.py
@@ -143,8 +143,7 @@ class HybridEmbed(nn.Module):
 def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
     embed_layer = partial(HybridEmbed, backbone=backbone)
     kwargs.setdefault('patch_size', 1)  # default patch size for hybrid models if not set
-    return _create_vision_transformer(
-        variant, pretrained=pretrained, embed_layer=embed_layer, default_cfg=default_cfgs[variant], **kwargs)
+    return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs)
 
 
 def _resnetv2(layers=(3, 4, 9), **kwargs):
diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py
index c9d8c6ff..507e4bb5 100644
--- a/timm/models/vovnet.py
+++ b/timm/models/vovnet.py
@@ -339,7 +339,6 @@ class VovNet(nn.Module):
 def _create_vovnet(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         VovNet, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         model_cfg=model_cfgs[variant],
         feature_cfg=dict(flatten_sequential=True),
         **kwargs)
diff --git a/timm/models/xception.py b/timm/models/xception.py
index 86f558cb..f9428d07 100644
--- a/timm/models/xception.py
+++ b/timm/models/xception.py
@@ -222,7 +222,6 @@ class Xception(nn.Module):
 def _xception(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         Xception, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         feature_cfg=dict(feature_cls='hook'),
         **kwargs)
 
diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py
index 457dc11a..e1156674 100644
--- a/timm/models/xception_aligned.py
+++ b/timm/models/xception_aligned.py
@@ -227,7 +227,6 @@ class XceptionAligned(nn.Module):
 def _xception(variant, pretrained=False, **kwargs):
     return build_model_with_cfg(
         XceptionAligned, variant, pretrained,
-        default_cfg=default_cfgs[variant],
         feature_cfg=dict(flatten_sequential=True, feature_cls='hook'),
         **kwargs)
 
diff --git a/timm/models/xcit.py b/timm/models/xcit.py
index ac5e802c..a7750d12 100644
--- a/timm/models/xcit.py
+++ b/timm/models/xcit.py
@@ -469,9 +469,8 @@ def checkpoint_filter_fn(state_dict, model):
 
 
 def _create_xcit(variant, pretrained=False, default_cfg=None, **kwargs):
-    default_cfg = default_cfg or default_cfgs[variant]
     model = build_model_with_cfg(
-        XCiT, variant, pretrained, default_cfg=default_cfg, pretrained_filter_fn=checkpoint_filter_fn, **kwargs)
+        XCiT, variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, **kwargs)
     return model
 
 
diff --git a/train.py b/train.py
index 10d839be..8ba8a2a6 100755
--- a/train.py
+++ b/train.py
@@ -234,8 +234,6 @@ parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                     help='Drop block rate (default: None)')
 
 # Batch norm parameters (only works with gen_efficientnet based models currently)
-parser.add_argument('--bn-tf', action='store_true', default=False,
-                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
 parser.add_argument('--bn-momentum', type=float, default=None,
                     help='BatchNorm momentum override (if not None)')
 parser.add_argument('--bn-eps', type=float, default=None,
@@ -375,7 +373,6 @@ def main():
         drop_path_rate=args.drop_path,
         drop_block_rate=args.drop_block,
         global_pool=args.gp,
-        bn_tf=args.bn_tf,
         bn_momentum=args.bn_momentum,
         bn_eps=args.bn_eps,
         scriptable=args.torchscript,
@@ -443,6 +440,7 @@ def main():
 
         if args.local_rank == 0:
             _logger.info('AMP not enabled. Training in float32.')
+
     # optionally resume from a checkpoint
     resume_epoch = None
     if args.resume:
diff --git a/validate.py b/validate.py
index a99e5b5c..8098e1d9 100755
--- a/validate.py
+++ b/validate.py
@@ -216,7 +216,9 @@ def validate(args):
         input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
         if args.channels_last:
             input = input.contiguous(memory_format=torch.channels_last)
-        model(input)
+        with amp_autocast():
+            model(input)
+
         end = time.time()
         for batch_idx, (input, target) in enumerate(loader):
             if args.no_prefetcher: