diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py
index 4c1e4d3f..7de4c8c4 100644
--- a/timm/models/efficientnet.py
+++ b/timm/models/efficientnet.py
@@ -33,9 +33,8 @@
 from typing import List
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
-from .feature_hooks import FeatureHooks
-from .features import FeatureInfo
-from .helpers import load_pretrained, adapt_model_from_file
+from .features import FeatureInfo, FeatureHooks
+from .helpers import build_model_with_cfg
 from .layers import SelectAdaptivePool2d, create_conv2d
 from .registry import register_model
@@ -471,29 +470,19 @@ class EfficientNetFeatures(nn.Module):
         return self.feature_hooks.get_output(x.device)


-def _create_effnet(model_kwargs, default_cfg, pretrained=False):
+def _create_effnet(model_kwargs, variant, pretrained=False):
     if model_kwargs.pop('features_only', False):
         load_strict = False
         model_kwargs.pop('num_classes', 0)
         model_kwargs.pop('num_features', 0)
         model_kwargs.pop('head_conv', None)
-        model_class = EfficientNetFeatures
+        model_cls = EfficientNetFeatures
     else:
         load_strict = True
-        model_class = EfficientNet
-    variant = model_kwargs.pop('variant', '')
-    model = model_class(**model_kwargs)
-    model.default_cfg = default_cfg
-    if '_pruned' in variant:
-        model = adapt_model_from_file(model, variant)
-    if pretrained:
-        load_pretrained(
-            model,
-            default_cfg,
-            num_classes=model_kwargs.get('num_classes', 0),
-            in_chans=model_kwargs.get('in_chans', 3),
-            strict=load_strict)
-    return model
+        model_cls = EfficientNet
+    return build_model_with_cfg(
+        model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
+        pretrained_strict=load_strict, **model_kwargs)


 def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
@@ -528,7 +517,7 @@ def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
         norm_kwargs=resolve_bn_args(kwargs),
         **kwargs
     )
-    model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
+    model = _create_effnet(model_kwargs, variant, pretrained)
     return model


@@ -564,7 +553,7 @@ def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
         norm_kwargs=resolve_bn_args(kwargs),
         **kwargs
     )
-    model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
+    model = _create_effnet(model_kwargs, variant, pretrained)
     return model


@@ -593,7 +582,7 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar
         norm_kwargs=resolve_bn_args(kwargs),
         **kwargs
     )
-    model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
+    model = _create_effnet(model_kwargs, variant, pretrained)
     return model


@@ -622,7 +611,7 @@ def _gen_mobilenet_v2(
         act_layer=resolve_act_layer(kwargs, 'relu6'),
         **kwargs
     )
-    model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
+    model = _create_effnet(model_kwargs, variant, pretrained)
     return model


@@ -652,7 +641,7 @@ def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
         norm_kwargs=resolve_bn_args(kwargs),
         **kwargs
     )
-    model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
+    model = _create_effnet(model_kwargs, variant, pretrained)
     return model


@@ -687,7 +676,7 @@ def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
         norm_kwargs=resolve_bn_args(kwargs),
         **kwargs
     )
-    model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
+    model = _create_effnet(model_kwargs, variant, pretrained)
     return model


@@ -731,10 +720,9 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
         channel_multiplier=channel_multiplier,
         act_layer=resolve_act_layer(kwargs, 'swish'),
         norm_kwargs=resolve_bn_args(kwargs),
-        variant=variant,
         **kwargs,
     )
-    model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
+    model = _create_effnet(model_kwargs, variant, pretrained)
     return model


@@ -763,7 +751,7 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0
         act_layer=resolve_act_layer(kwargs, 'relu'),
         **kwargs,
     )
-    model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
+    model = _create_effnet(model_kwargs, variant, pretrained)
     return model


@@ -793,7 +781,7 @@ def _gen_efficientnet_condconv(
         act_layer=resolve_act_layer(kwargs, 'swish'),
         **kwargs,
     )
-    model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
+    model = _create_effnet(model_kwargs, variant, pretrained)
     return model


@@ -834,7 +822,7 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0
         norm_kwargs=resolve_bn_args(kwargs),
         **kwargs,
     )
-    model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
+    model = _create_effnet(model_kwargs, variant, pretrained)
     return model


@@ -867,7 +855,7 @@ def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
         norm_kwargs=resolve_bn_args(kwargs),
         **kwargs
     )
-    model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
+    model = _create_effnet(model_kwargs, variant, pretrained)
     return model


@@ -900,7 +888,7 @@ def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrai
         norm_kwargs=resolve_bn_args(kwargs),
         **kwargs
     )
-    model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
+    model = _create_effnet(model_kwargs, variant, pretrained)
     return model


@@ -1239,7 +1227,7 @@ def efficientnet_b1_pruned(pretrained=False, **kwargs):
     kwargs['pad_type'] = 'same'
     variant = 'efficientnet_b1_pruned'
     model = _gen_efficientnet(
-        variant, channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
+        variant, channel_multiplier=1.0, depth_multiplier=1.1, pruned=True, pretrained=pretrained, **kwargs)
     return model


@@ -1249,7 +1237,8 @@ def efficientnet_b2_pruned(pretrained=False, **kwargs):
     kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
     kwargs['pad_type'] = 'same'
     model = _gen_efficientnet(
-        'efficientnet_b2_pruned', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
+        'efficientnet_b2_pruned', channel_multiplier=1.1, depth_multiplier=1.2, pruned=True,
+        pretrained=pretrained, **kwargs)
     return model


@@ -1259,7 +1248,8 @@ def efficientnet_b3_pruned(pretrained=False, **kwargs):
     kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
     kwargs['pad_type'] = 'same'
     model = _gen_efficientnet(
-        'efficientnet_b3_pruned', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
+        'efficientnet_b3_pruned', channel_multiplier=1.2, depth_multiplier=1.4, pruned=True,
+        pretrained=pretrained, **kwargs)
     return model
diff --git a/timm/models/feature_hooks.py b/timm/models/feature_hooks.py
deleted file mode 100644
index b489b6f5..00000000
--- a/timm/models/feature_hooks.py
+++ /dev/null
@@ -1,46 +0,0 @@
-""" PyTorch Feature Hook Helper
-
-This class helps gather features from a network via hooks specified on the module
-name.
-
-Hacked together by Ross Wightman
-"""
-import torch
-
-from collections import defaultdict, OrderedDict
-from functools import partial, partialmethod
-from typing import List
-
-
-class FeatureHooks:
-
-    def __init__(self, hooks, named_modules, out_as_dict=False, out_map=None, default_hook_type='forward'):
-        # setup feature hooks
-        modules = {k: v for k, v in named_modules}
-        for i, h in enumerate(hooks):
-            hook_name = h['module']
-            m = modules[hook_name]
-            hook_id = out_map[i] if out_map else hook_name
-            hook_fn = partial(self._collect_output_hook, hook_id)
-            hook_type = h['hook_type'] if 'hook_type' in h else default_hook_type
-            if hook_type == 'forward_pre':
-                m.register_forward_pre_hook(hook_fn)
-            elif hook_type == 'forward':
-                m.register_forward_hook(hook_fn)
-            else:
-                assert False, "Unsupported hook type"
-        self._feature_outputs = defaultdict(OrderedDict)
-        self.out_as_dict = out_as_dict
-
-    def _collect_output_hook(self, hook_id, *args):
-        x = args[-1]  # tensor we want is last argument, output for fwd, input for fwd_pre
-        if isinstance(x, tuple):
-            x = x[0]  # unwrap input tuple
-        self._feature_outputs[x.device][hook_id] = x
-
-    def get_output(self, device) -> List[torch.tensor]:
-        if self.out_as_dict:
-            output = self._feature_outputs[device]
-        else:
-            output = list(self._feature_outputs[device].values())
-        self._feature_outputs[device] = OrderedDict()  # clear after reading
-        return output
diff --git a/timm/models/features.py b/timm/models/features.py
index 2c210734..46842f5d 100644
--- a/timm/models/features.py
+++ b/timm/models/features.py
@@ -5,15 +5,14 @@ and provide a common interface for describing them.

 Hacked together by Ross Wightman
 """
-from collections import OrderedDict
-from typing import Dict, List, Tuple, Any
+from collections import OrderedDict, defaultdict
 from copy import deepcopy
+from functools import partial
+from typing import Dict, List, Tuple, Any

 import torch
 import torch.nn as nn

-from .feature_hooks import FeatureHooks
-

 class FeatureInfo:
@@ -75,6 +74,41 @@
         return len(self.info)


+class FeatureHooks:
+
+    def __init__(self, hooks, named_modules, out_as_dict=False, out_map=None, default_hook_type='forward'):
+        # setup feature hooks
+        modules = {k: v for k, v in named_modules}
+        for i, h in enumerate(hooks):
+            hook_name = h['module']
+            m = modules[hook_name]
+            hook_id = out_map[i] if out_map else hook_name
+            hook_fn = partial(self._collect_output_hook, hook_id)
+            hook_type = h['hook_type'] if 'hook_type' in h else default_hook_type
+            if hook_type == 'forward_pre':
+                m.register_forward_pre_hook(hook_fn)
+            elif hook_type == 'forward':
+                m.register_forward_hook(hook_fn)
+            else:
+                assert False, "Unsupported hook type"
+        self._feature_outputs = defaultdict(OrderedDict)
+        self.out_as_dict = out_as_dict
+
+    def _collect_output_hook(self, hook_id, *args):
+        x = args[-1]  # tensor we want is last argument, output for fwd, input for fwd_pre
+        if isinstance(x, tuple):
+            x = x[0]  # unwrap input tuple
+        self._feature_outputs[x.device][hook_id] = x
+
+    def get_output(self, device) -> List[torch.tensor]:  # FIXME deal with diff return types for torchscript?
+        if self.out_as_dict:
+            output = self._feature_outputs[device]
+        else:
+            output = list(self._feature_outputs[device].values())
+        self._feature_outputs[device] = OrderedDict()  # clear after reading
+        return output
+
+
 def _module_list(module, flatten_sequential=False):
     # a yield/iter would be better for this but wouldn't be compatible with torchscript
     ml = []
diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py
index 48288223..b99f4f7a 100644
--- a/timm/models/mobilenetv3.py
+++ b/timm/models/mobilenetv3.py
@@ -16,9 +16,8 @@
 from typing import List
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
-from .feature_hooks import FeatureHooks
-from .features import FeatureInfo
-from .helpers import load_pretrained
+from .features import FeatureInfo, FeatureHooks
+from .helpers import build_model_with_cfg
 from .layers import SelectAdaptivePool2d, create_conv2d, get_act_fn, hard_sigmoid
 from .registry import register_model
@@ -215,27 +214,19 @@ class MobileNetV3Features(nn.Module):
         return self.feature_hooks.get_output(x.device)


-def _create_mnv3(model_kwargs, default_cfg, pretrained=False):
+def _create_mnv3(model_kwargs, variant, pretrained=False):
     if model_kwargs.pop('features_only', False):
         load_strict = False
         model_kwargs.pop('num_classes', 0)
         model_kwargs.pop('num_features', 0)
         model_kwargs.pop('head_conv', None)
-        model_class = MobileNetV3Features
+        model_cls = MobileNetV3Features
     else:
         load_strict = True
-        model_class = MobileNetV3
-
-    model = model_class(**model_kwargs)
-    model.default_cfg = default_cfg
-    if pretrained:
-        load_pretrained(
-            model,
-            default_cfg,
-            num_classes=model_kwargs.get('num_classes', 0),
-            in_chans=model_kwargs.get('in_chans', 3),
-            strict=load_strict)
-    return model
+        model_cls = MobileNetV3
+    return build_model_with_cfg(
+        model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
+        pretrained_strict=load_strict, **model_kwargs)


 def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
@@ -272,7 +263,7 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw
         se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=1),
         **kwargs,
     )
-    model = _create_mnv3(model_kwargs, default_cfgs[variant], pretrained)
+    model = _create_mnv3(model_kwargs, variant, pretrained)
     return model


@@ -368,7 +359,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
         se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8),
         **kwargs,
     )
-    model = _create_mnv3(model_kwargs, default_cfgs[variant], pretrained)
+    model = _create_mnv3(model_kwargs, variant, pretrained)
     return model
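
Usage sketch (reviewer note, not part of the patch): a minimal check of the refactored create path, assuming this branch is installed as `timm`. Both `_create_effnet` and `_create_mnv3` now route through `build_model_with_cfg`, which resolves `default_cfgs[variant]`, handles the pruned variants via the `pruned=True` kwarg (replacing the old `adapt_model_from_file` branch), and loads weights with the `pretrained_strict` flag set above. Model names are existing timm variants; input sizes are illustrative.

    import torch
    import timm

    # Classifier path: _create_effnet builds EfficientNet via build_model_with_cfg
    # with pretrained_strict=True.
    model = timm.create_model('efficientnet_b0', pretrained=False).eval()
    logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # torch.Size([1, 1000])

    # Feature path: features_only=True pops the classifier-specific kwargs and
    # builds EfficientNetFeatures; forward returns the per-stage tensors that
    # FeatureHooks (now defined in timm/models/features.py) collected during the pass.
    feat_model = timm.create_model('efficientnet_b0', features_only=True, pretrained=False).eval()
    feats = feat_model(torch.randn(1, 3, 224, 224))
    print([f.shape for f in feats])  # one tensor per hooked stage

The same should hold for the MobileNetV3 entry points, e.g. `timm.create_model('mobilenetv3_large_100', features_only=True)`.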