Add default_cfg back to models wrapped in feature extraction module as per discussion in #294.

Branch: pull/297/head
Ross Wightman (4 years ago)
Parent 4ca52d73d8, commit 867a0e5a04
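
In practical terms, a backbone created with features_only=True once again carries a default_cfg, filtered of classifier-head fields by the new default_cfg_for_features helper introduced below. A minimal sketch of the restored behavior, assuming a timm build that includes this commit:

    import timm

    # feature-extraction backbone; the model name is just a convenient example
    model = timm.create_model('efficientnet_b0', features_only=True, pretrained=False)

    # the wrapped model exposes a default_cfg again ...
    print(model.default_cfg['mean'])           # (0.485, 0.456, 0.406)
    # ... minus the fields that only make sense for a classification head
    print('num_classes' in model.default_cfg)  # False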

timm/models/efficientnet.py

@@ -34,7 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
 from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
 from .features import FeatureInfo, FeatureHooks
-from .helpers import build_model_with_cfg
+from .helpers import build_model_with_cfg, default_cfg_for_features
 from .layers import create_conv2d, create_classifier
 from .registry import register_model
@@ -462,9 +462,11 @@ def _create_effnet(model_kwargs, variant, pretrained=False):
     else:
         load_strict = True
         model_cls = EfficientNet
-    return build_model_with_cfg(
+    model = build_model_with_cfg(
         model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
         pretrained_strict=load_strict, **model_kwargs)
+    model.default_cfg = default_cfg_for_features(model.default_cfg)
+    return model


 def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
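
Note that the EfficientNet builders select a dedicated features class (EfficientNetFeatures in this file) rather than going through the generic wrapper in helpers.py, so they attach the filtered cfg themselves. After the patch, the builder tail reads roughly as follows (a sketch assembled from the hunk above; the features branch is elided, as it is in the hunk):

    def _create_effnet(model_kwargs, variant, pretrained=False):
        if model_kwargs.pop('features_only', False):
            ...  # switch to EfficientNetFeatures, relax strict loading (not shown)
        else:
            load_strict = True
            model_cls = EfficientNet
        model = build_model_with_cfg(
            model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
            pretrained_strict=load_strict, **model_kwargs)
        model.default_cfg = default_cfg_for_features(model.default_cfg)
        return model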

timm/models/helpers.py

@@ -251,6 +251,15 @@ def adapt_model_from_file(parent_module, model_variant):
         return adapt_model_from_string(parent_module, f.read().strip())


+def default_cfg_for_features(default_cfg):
+    default_cfg = deepcopy(default_cfg)
+    # remove default pretrained cfg fields that don't have much relevance for feature backbone
+    to_remove = ('num_classes', 'crop_pct', 'classifier')  # add default final pool size?
+    for tr in to_remove:
+        default_cfg.pop(tr, None)
+    return default_cfg
+
+
 def build_model_with_cfg(
         model_cls: Callable,
         variant: str,
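
default_cfg_for_features is a pure function over the cfg dict; the deepcopy ensures the registered cfg in default_cfgs is never mutated. A quick illustration against a made-up cfg shaped like a typical default_cfgs entry (all values hypothetical):

    # hypothetical cfg mirroring the usual default_cfgs fields
    cfg = dict(
        url='', num_classes=1000, input_size=(3, 224, 224), pool_size=(7, 7),
        crop_pct=0.875, interpolation='bicubic',
        mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
        first_conv='conv_stem', classifier='classifier')

    feat_cfg = default_cfg_for_features(cfg)
    assert 'num_classes' not in feat_cfg    # head-only fields are stripped
    assert 'crop_pct' not in feat_cfg and 'classifier' not in feat_cfg
    assert feat_cfg['mean'] == (0.485, 0.456, 0.406)  # data fields survive
    assert cfg['num_classes'] == 1000       # original cfg left intact (deepcopy)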
@@ -296,5 +305,6 @@ def build_model_with_cfg(
         else:
             assert False, f'Unknown feature class {feature_cls}'
         model = feature_cls(model, **feature_cfg)
+        model.default_cfg = default_cfg_for_features(default_cfg)  # add back default_cfg
     return model
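
This is the generic path: any model wrapped by feature_cls gets the filtered cfg re-attached immediately after wrapping. One concrete consumer is timm's data-config resolution, which reads normalization constants off model.default_cfg and previously found no such attribute on feature-wrapped models. A sketch, assuming resolve_data_config from timm.data (its exact signature may differ slightly across versions):

    import timm
    from timm.data import resolve_data_config

    # resnet50 goes through the generic wrapper path in build_model_with_cfg
    model = timm.create_model('resnet50', features_only=True, pretrained=False)
    data_cfg = resolve_data_config({}, model=model)  # reads model.default_cfg
    print(data_cfg['mean'], data_cfg['std'])         # per-model normalization constants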

timm/models/hrnet.py

@@ -17,7 +17,7 @@ import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .features import FeatureInfo
-from .helpers import build_model_with_cfg
+from .helpers import build_model_with_cfg, default_cfg_for_features
 from .layers import create_classifier
 from .registry import register_model
 from .resnet import BasicBlock, Bottleneck  # leveraging ResNet blocks w/ additional features like SE
@@ -779,9 +779,11 @@ def _create_hrnet(variant, pretrained, **model_kwargs):
         model_kwargs['num_classes'] = 0
         strict = False
-    return build_model_with_cfg(
+    model = build_model_with_cfg(
         model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
         model_cfg=cfg_cls[variant], pretrained_strict=strict, **model_kwargs)
+    model.default_cfg = default_cfg_for_features(model.default_cfg)
+    return model


 @register_model
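
_create_hrnet gets the identical treatment. Feature-mode HRNet already exposed per-stage metadata through FeatureInfo; the data cfg is now available alongside it (a sketch; channel counts and cfg fields vary by variant):

    import timm

    model = timm.create_model('hrnet_w18', features_only=True, pretrained=False)
    print(model.feature_info.channels())    # per-stage channel counts
    print(model.default_cfg['input_size'])  # restored by this commit, e.g. (3, 224, 224)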

timm/models/mobilenetv3.py

@@ -17,7 +17,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
 from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
 from .features import FeatureInfo, FeatureHooks
-from .helpers import build_model_with_cfg
+from .helpers import build_model_with_cfg, default_cfg_for_features
 from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, hard_sigmoid
 from .registry import register_model
@@ -211,9 +211,11 @@ def _create_mnv3(model_kwargs, variant, pretrained=False):
     else:
         load_strict = True
         model_cls = MobileNetV3
-    return build_model_with_cfg(
+    model = build_model_with_cfg(
         model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
         pretrained_strict=load_strict, **model_kwargs)
+    model.default_cfg = default_cfg_for_features(model.default_cfg)
+    return model


 def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
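
_create_mnv3 completes the set. A usage sketch combining stage selection via out_indices with the restored cfg; the model name and indices here are illustrative:

    import torch
    import timm

    model = timm.create_model(
        'mobilenetv3_large_100', features_only=True, out_indices=(2, 3, 4),
        pretrained=False)
    feats = model(torch.randn(1, 3, 224, 224))
    print([f.shape for f in feats])         # three maps at increasing stride
    print(model.default_cfg.get('mean'))    # data constants back on the wrapper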
