Move FeatureHooks into features.py, switch EfficientNet, MobileNetV3 to use build model helper

pull/175/head
Ross Wightman 4 years ago
parent 9eba134d79
commit 6eec3fb4a4

@ -33,9 +33,8 @@ from typing import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
from .feature_hooks import FeatureHooks
from .features import FeatureInfo
from .helpers import load_pretrained, adapt_model_from_file
from .features import FeatureInfo, FeatureHooks
from .helpers import build_model_with_cfg
from .layers import SelectAdaptivePool2d, create_conv2d
from .registry import register_model
@ -471,29 +470,19 @@ class EfficientNetFeatures(nn.Module):
        return self.feature_hooks.get_output(x.device)
def _create_effnet(model_kwargs, default_cfg, pretrained=False):
def _create_effnet(model_kwargs, variant, pretrained=False):
    if model_kwargs.pop('features_only', False):
        load_strict = False
        model_kwargs.pop('num_classes', 0)
        model_kwargs.pop('num_features', 0)
        model_kwargs.pop('head_conv', None)
        model_class = EfficientNetFeatures
        model_cls = EfficientNetFeatures
    else:
        load_strict = True
        model_class = EfficientNet
    variant = model_kwargs.pop('variant', '')
    model = model_class(**model_kwargs)
    model.default_cfg = default_cfg
    if '_pruned' in variant:
        model = adapt_model_from_file(model, variant)
    if pretrained:
        load_pretrained(
            model,
            default_cfg,
            num_classes=model_kwargs.get('num_classes', 0),
            in_chans=model_kwargs.get('in_chans', 3),
            strict=load_strict)
    return model
        model_cls = EfficientNet
    return build_model_with_cfg(
        model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
        pretrained_strict=load_strict, **model_kwargs)
def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
@ -528,7 +517,7 @@ def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
        norm_kwargs=resolve_bn_args(kwargs),
        **kwargs
    )
    model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
    model = _create_effnet(model_kwargs, variant, pretrained)
    return model
@ -564,7 +553,7 @@ def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
        norm_kwargs=resolve_bn_args(kwargs),
        **kwargs
    )
    model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
    model = _create_effnet(model_kwargs, variant, pretrained)
    return model
@ -593,7 +582,7 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar
        norm_kwargs=resolve_bn_args(kwargs),
        **kwargs
    )
    model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
    model = _create_effnet(model_kwargs, variant, pretrained)
    return model
@ -622,7 +611,7 @@ def _gen_mobilenet_v2(
        act_layer=resolve_act_layer(kwargs, 'relu6'),
        **kwargs
    )
    model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
    model = _create_effnet(model_kwargs, variant, pretrained)
    return model
@ -652,7 +641,7 @@ def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
        norm_kwargs=resolve_bn_args(kwargs),
        **kwargs
    )
    model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
    model = _create_effnet(model_kwargs, variant, pretrained)
    return model
@ -687,7 +676,7 @@ def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
        norm_kwargs=resolve_bn_args(kwargs),
        **kwargs
    )
    model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
    model = _create_effnet(model_kwargs, variant, pretrained)
    return model
@ -731,10 +720,9 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
        channel_multiplier=channel_multiplier,
        act_layer=resolve_act_layer(kwargs, 'swish'),
        norm_kwargs=resolve_bn_args(kwargs),
        variant=variant,
        **kwargs,
    )
    model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
    model = _create_effnet(model_kwargs, variant, pretrained)
    return model
@ -763,7 +751,7 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0
        act_layer=resolve_act_layer(kwargs, 'relu'),
        **kwargs,
    )
    model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
    model = _create_effnet(model_kwargs, variant, pretrained)
    return model
@ -793,7 +781,7 @@ def _gen_efficientnet_condconv(
        act_layer=resolve_act_layer(kwargs, 'swish'),
        **kwargs,
    )
    model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
    model = _create_effnet(model_kwargs, variant, pretrained)
    return model
@ -834,7 +822,7 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0
        norm_kwargs=resolve_bn_args(kwargs),
        **kwargs,
    )
    model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
    model = _create_effnet(model_kwargs, variant, pretrained)
    return model
@ -867,7 +855,7 @@ def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
        norm_kwargs=resolve_bn_args(kwargs),
        **kwargs
    )
    model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
    model = _create_effnet(model_kwargs, variant, pretrained)
    return model
@ -900,7 +888,7 @@ def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrai
        norm_kwargs=resolve_bn_args(kwargs),
        **kwargs
    )
    model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained)
    model = _create_effnet(model_kwargs, variant, pretrained)
    return model
@ -1239,7 +1227,7 @@ def efficientnet_b1_pruned(pretrained=False, **kwargs):
    kwargs['pad_type'] = 'same'
    variant = 'efficientnet_b1_pruned'
    model = _gen_efficientnet(
        variant, channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
        variant, channel_multiplier=1.0, depth_multiplier=1.1, pruned=True, pretrained=pretrained, **kwargs)
    return model
@ -1249,7 +1237,8 @@ def efficientnet_b2_pruned(pretrained=False, **kwargs):
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'efficientnet_b2_pruned', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
        'efficientnet_b2_pruned', channel_multiplier=1.1, depth_multiplier=1.2, pruned=True,
        pretrained=pretrained, **kwargs)
    return model
@ -1259,7 +1248,8 @@ def efficientnet_b3_pruned(pretrained=False, **kwargs):
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'efficientnet_b3_pruned', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
        'efficientnet_b3_pruned', channel_multiplier=1.2, depth_multiplier=1.4, pruned=True,
        pretrained=pretrained, **kwargs)
    return model
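Everything the removed per-model boilerplate above did (construct the model class, attach default_cfg, adapt pruned variants, optionally load pretrained weights with the right strictness) is now handled in one place by build_model_with_cfg, with the pruned variants passing an explicit pruned=True rather than relying on the old '_pruned' in variant check. A rough sketch of that consolidated flow, pieced together from the removed lines; this is an illustration only, not the actual helper, and the real signature may differ:

# illustrative sketch of what the shared builder does for these models (not the real implementation)
from timm.models.helpers import load_pretrained, adapt_model_from_file  # assumed helper locations at this commit

def build_with_cfg_sketch(model_cls, variant, pretrained, default_cfg, pretrained_strict=True, **model_kwargs):
    pruned = model_kwargs.pop('pruned', False)  # replaces the old "'_pruned' in variant" string check
    model = model_cls(**model_kwargs)
    model.default_cfg = default_cfg
    if pruned:
        model = adapt_model_from_file(model, variant)  # adapt the architecture for the pruned variant
    if pretrained:
        load_pretrained(
            model, default_cfg,
            num_classes=model_kwargs.get('num_classes', 0),
            in_chans=model_kwargs.get('in_chans', 3),
            strict=pretrained_strict)
    return model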

@ -1,46 +0,0 @@
""" PyTorch Feature Hook Helper
This class helps gather features from a network via hooks specified on the module name.
Hacked together by Ross Wightman
"""
import torch
from collections import defaultdict, OrderedDict
from functools import partial, partialmethod
from typing import List
class FeatureHooks:
    def __init__(self, hooks, named_modules, out_as_dict=False, out_map=None, default_hook_type='forward'):
        # setup feature hooks
        modules = {k: v for k, v in named_modules}
        for i, h in enumerate(hooks):
            hook_name = h['module']
            m = modules[hook_name]
            hook_id = out_map[i] if out_map else hook_name
            hook_fn = partial(self._collect_output_hook, hook_id)
            hook_type = h['hook_type'] if 'hook_type' in h else default_hook_type
            if hook_type == 'forward_pre':
                m.register_forward_pre_hook(hook_fn)
            elif hook_type == 'forward':
                m.register_forward_hook(hook_fn)
            else:
                assert False, "Unsupported hook type"
        self._feature_outputs = defaultdict(OrderedDict)
        self.out_as_dict = out_as_dict
    def _collect_output_hook(self, hook_id, *args):
        x = args[-1]  # tensor we want is last argument, output for fwd, input for fwd_pre
        if isinstance(x, tuple):
            x = x[0]  # unwrap input tuple
        self._feature_outputs[x.device][hook_id] = x
    def get_output(self, device) -> List[torch.tensor]:
        if self.out_as_dict:
            output = self._feature_outputs[device]
        else:
            output = list(self._feature_outputs[device].values())
        self._feature_outputs[device] = OrderedDict()  # clear after reading
        return output
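As the docstring above says, FeatureHooks gathers features via hooks registered by module name: each entry in hooks names a target module and may override hook_type ('forward' or 'forward_pre'). A minimal usage sketch with a toy network; the toy model and the names below are illustrative, not part of this commit:

from collections import OrderedDict
import torch
import torch.nn as nn
from timm.models.features import FeatureHooks  # the class's new home after this commit

# toy two-stage net; collect the forward output of each named child
net = nn.Sequential(OrderedDict([
    ('conv1', nn.Conv2d(3, 8, 3)),
    ('act1', nn.ReLU()),
]))
hooks = [{'module': 'conv1'}, {'module': 'act1'}]  # default hook_type is 'forward'
fh = FeatureHooks(hooks, net.named_modules())
_ = net(torch.randn(1, 3, 16, 16))
feats = fh.get_output(torch.device('cpu'))  # list of 2 tensors, cleared after each read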

@ -5,15 +5,14 @@ and provide a common interface for describing them.
Hacked together by Ross Wightman
"""
from collections import OrderedDict
from typing import Dict, List, Tuple, Any
from collections import OrderedDict, defaultdict
from copy import deepcopy
from functools import partial
from typing import Dict, List, Tuple, Any
import torch
import torch.nn as nn
from .feature_hooks import FeatureHooks
class FeatureInfo:
@ -75,6 +74,41 @@ class FeatureInfo:
        return len(self.info)
class FeatureHooks:
    def __init__(self, hooks, named_modules, out_as_dict=False, out_map=None, default_hook_type='forward'):
        # setup feature hooks
        modules = {k: v for k, v in named_modules}
        for i, h in enumerate(hooks):
            hook_name = h['module']
            m = modules[hook_name]
            hook_id = out_map[i] if out_map else hook_name
            hook_fn = partial(self._collect_output_hook, hook_id)
            hook_type = h['hook_type'] if 'hook_type' in h else default_hook_type
            if hook_type == 'forward_pre':
                m.register_forward_pre_hook(hook_fn)
            elif hook_type == 'forward':
                m.register_forward_hook(hook_fn)
            else:
                assert False, "Unsupported hook type"
        self._feature_outputs = defaultdict(OrderedDict)
        self.out_as_dict = out_as_dict
    def _collect_output_hook(self, hook_id, *args):
        x = args[-1]  # tensor we want is last argument, output for fwd, input for fwd_pre
        if isinstance(x, tuple):
            x = x[0]  # unwrap input tuple
        self._feature_outputs[x.device][hook_id] = x
    def get_output(self, device) -> List[torch.tensor]:  # FIXME deal with diff return types for torchscript?
        if self.out_as_dict:
            output = self._feature_outputs[device]
        else:
            output = list(self._feature_outputs[device].values())
        self._feature_outputs[device] = OrderedDict()  # clear after reading
        return output
def _module_list(module, flatten_sequential=False):
    # a yield/iter would be better for this but wouldn't be compatible with torchscript
    ml = []

@ -16,9 +16,8 @@ from typing import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
from .feature_hooks import FeatureHooks
from .features import FeatureInfo
from .helpers import load_pretrained
from .features import FeatureInfo, FeatureHooks
from .helpers import build_model_with_cfg
from .layers import SelectAdaptivePool2d, create_conv2d, get_act_fn, hard_sigmoid
from .registry import register_model
@ -215,27 +214,19 @@ class MobileNetV3Features(nn.Module):
        return self.feature_hooks.get_output(x.device)
def _create_mnv3(model_kwargs, default_cfg, pretrained=False):
def _create_mnv3(model_kwargs, variant, pretrained=False):
    if model_kwargs.pop('features_only', False):
        load_strict = False
        model_kwargs.pop('num_classes', 0)
        model_kwargs.pop('num_features', 0)
        model_kwargs.pop('head_conv', None)
        model_class = MobileNetV3Features
        model_cls = MobileNetV3Features
    else:
        load_strict = True
        model_class = MobileNetV3
    model = model_class(**model_kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(
            model,
            default_cfg,
            num_classes=model_kwargs.get('num_classes', 0),
            in_chans=model_kwargs.get('in_chans', 3),
            strict=load_strict)
    return model
        model_cls = MobileNetV3
    return build_model_with_cfg(
        model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
        pretrained_strict=load_strict, **model_kwargs)
def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
@ -272,7 +263,7 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw
        se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=1),
        **kwargs,
    )
    model = _create_mnv3(model_kwargs, default_cfgs[variant], pretrained)
    model = _create_mnv3(model_kwargs, variant, pretrained)
    return model
@ -368,7 +359,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
        se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8),
        **kwargs,
    )
    model = _create_mnv3(model_kwargs, default_cfgs[variant], pretrained)
    model = _create_mnv3(model_kwargs, variant, pretrained)
    return model
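With both _create_effnet and _create_mnv3 routed through build_model_with_cfg, the features_only path drops the classifier-related kwargs and builds the *Features variant, whose forward returns whatever FeatureHooks collected. A hedged end-to-end sketch, assuming create_model forwards features_only down to these builders as the popped kwarg suggests:

import torch
import timm

# standard classifier vs. hook-based feature extractor, both built via the shared helper
clf = timm.create_model('mobilenetv3_rw', pretrained=False)
feat = timm.create_model('efficientnet_b0', features_only=True, pretrained=False)

x = torch.randn(1, 3, 224, 224)
logits = clf(x)  # (1, 1000) classification logits
maps = feat(x)   # list of intermediate feature maps gathered by FeatureHooks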
