From 867a0e5a049516b9597e05751799a2502fae0ec8 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 3 Dec 2020 10:24:35 -0800
Subject: [PATCH] Add default_cfg back to models wrapped in feature extraction
 module as per discussion in #294.

---
 timm/models/efficientnet.py |  6 ++++--
 timm/models/helpers.py      | 10 ++++++++++
 timm/models/hrnet.py        |  6 ++++--
 timm/models/mobilenetv3.py  |  6 ++++--
 4 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py
index a61a6f47..7eeda3ca 100644
--- a/timm/models/efficientnet.py
+++ b/timm/models/efficientnet.py
@@ -34,7 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
 from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
 from .features import FeatureInfo, FeatureHooks
-from .helpers import build_model_with_cfg
+from .helpers import build_model_with_cfg, default_cfg_for_features
 from .layers import create_conv2d, create_classifier
 from .registry import register_model
 
@@ -462,9 +462,11 @@ def _create_effnet(model_kwargs, variant, pretrained=False):
     else:
         load_strict = True
         model_cls = EfficientNet
-    return build_model_with_cfg(
+    model = build_model_with_cfg(
         model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
         pretrained_strict=load_strict, **model_kwargs)
+    model.default_cfg = default_cfg_for_features(model.default_cfg)
+    return model
 
 
 def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
diff --git a/timm/models/helpers.py b/timm/models/helpers.py
index 0bc6d2f7..77b98dc6 100644
--- a/timm/models/helpers.py
+++ b/timm/models/helpers.py
@@ -251,6 +251,15 @@ def adapt_model_from_file(parent_module, model_variant):
         return adapt_model_from_string(parent_module, f.read().strip())
 
 
+def default_cfg_for_features(default_cfg):
+    default_cfg = deepcopy(default_cfg)
+    # remove default pretrained cfg fields that don't have much relevance for feature backbone
+    to_remove = ('num_classes', 'crop_pct', 'classifier') # add default final pool size?
+    for tr in to_remove:
+        default_cfg.pop(tr, None)
+    return default_cfg
+
+
 def build_model_with_cfg(
         model_cls: Callable,
         variant: str,
@@ -296,5 +305,6 @@
         else:
             assert False, f'Unknown feature class {feature_cls}'
         model = feature_cls(model, **feature_cfg)
+        model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfg
 
     return model
diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py
index 2e8757b5..d246812e 100644
--- a/timm/models/hrnet.py
+++ b/timm/models/hrnet.py
@@ -17,7 +17,7 @@ import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .features import FeatureInfo
-from .helpers import build_model_with_cfg
+from .helpers import build_model_with_cfg, default_cfg_for_features
 from .layers import create_classifier
 from .registry import register_model
 from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE
@@ -779,9 +779,11 @@ def _create_hrnet(variant, pretrained, **model_kwargs):
         model_kwargs['num_classes'] = 0
         strict = False
 
-    return build_model_with_cfg(
+    model = build_model_with_cfg(
         model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
         model_cfg=cfg_cls[variant], pretrained_strict=strict, **model_kwargs)
+    model.default_cfg = default_cfg_for_features(model.default_cfg)
+    return model
 
 
 @register_model
diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py
index ea930308..afded75f 100644
--- a/timm/models/mobilenetv3.py
+++ b/timm/models/mobilenetv3.py
@@ -17,7 +17,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
 from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
 from .features import FeatureInfo, FeatureHooks
-from .helpers import build_model_with_cfg
+from .helpers import build_model_with_cfg, default_cfg_for_features
 from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, hard_sigmoid
 from .registry import register_model
 
@@ -211,9 +211,11 @@
     else:
         load_strict = True
         model_cls = MobileNetV3
-    return build_model_with_cfg(
+    model = build_model_with_cfg(
         model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
         pretrained_strict=load_strict, **model_kwargs)
+    model.default_cfg = default_cfg_for_features(model.default_cfg)
+    return model
 
 
 def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):