Move FeatureHooks into features.py, switch EfficientNet, MobileNetV3 to use build model helper

pull/175/head
Ross Wightman 4 years ago
parent 9eba134d79
commit 6eec3fb4a4

@ -33,9 +33,8 @@ from typing import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
from .feature_hooks import FeatureHooks from .features import FeatureInfo, FeatureHooks
from .features import FeatureInfo from .helpers import build_model_with_cfg
from .helpers import load_pretrained, adapt_model_from_file
from .layers import SelectAdaptivePool2d, create_conv2d from .layers import SelectAdaptivePool2d, create_conv2d
from .registry import register_model from .registry import register_model
@ -471,29 +470,19 @@ class EfficientNetFeatures(nn.Module):
return self.feature_hooks.get_output(x.device) return self.feature_hooks.get_output(x.device)
def _create_effnet(model_kwargs, default_cfg, pretrained=False): def _create_effnet(model_kwargs, variant, pretrained=False):
if model_kwargs.pop('features_only', False): if model_kwargs.pop('features_only', False):
load_strict = False load_strict = False
model_kwargs.pop('num_classes', 0) model_kwargs.pop('num_classes', 0)
model_kwargs.pop('num_features', 0) model_kwargs.pop('num_features', 0)
model_kwargs.pop('head_conv', None) model_kwargs.pop('head_conv', None)
model_class = EfficientNetFeatures model_cls = EfficientNetFeatures
else: else:
load_strict = True load_strict = True
model_class = EfficientNet model_cls = EfficientNet
variant = model_kwargs.pop('variant', '') return build_model_with_cfg(
model = model_class(**model_kwargs) model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
model.default_cfg = default_cfg pretrained_strict=load_strict, **model_kwargs)
if '_pruned' in variant:
model = adapt_model_from_file(model, variant)
if pretrained:
load_pretrained(
model,
default_cfg,
num_classes=model_kwargs.get('num_classes', 0),
in_chans=model_kwargs.get('in_chans', 3),
strict=load_strict)
return model
def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
@ -528,7 +517,7 @@ def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
norm_kwargs=resolve_bn_args(kwargs), norm_kwargs=resolve_bn_args(kwargs),
**kwargs **kwargs
) )
model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained) model = _create_effnet(model_kwargs, variant, pretrained)
return model return model
@ -564,7 +553,7 @@ def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
norm_kwargs=resolve_bn_args(kwargs), norm_kwargs=resolve_bn_args(kwargs),
**kwargs **kwargs
) )
model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained) model = _create_effnet(model_kwargs, variant, pretrained)
return model return model
@ -593,7 +582,7 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar
norm_kwargs=resolve_bn_args(kwargs), norm_kwargs=resolve_bn_args(kwargs),
**kwargs **kwargs
) )
model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained) model = _create_effnet(model_kwargs, variant, pretrained)
return model return model
@ -622,7 +611,7 @@ def _gen_mobilenet_v2(
act_layer=resolve_act_layer(kwargs, 'relu6'), act_layer=resolve_act_layer(kwargs, 'relu6'),
**kwargs **kwargs
) )
model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained) model = _create_effnet(model_kwargs, variant, pretrained)
return model return model
@ -652,7 +641,7 @@ def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
norm_kwargs=resolve_bn_args(kwargs), norm_kwargs=resolve_bn_args(kwargs),
**kwargs **kwargs
) )
model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained) model = _create_effnet(model_kwargs, variant, pretrained)
return model return model
@ -687,7 +676,7 @@ def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
norm_kwargs=resolve_bn_args(kwargs), norm_kwargs=resolve_bn_args(kwargs),
**kwargs **kwargs
) )
model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained) model = _create_effnet(model_kwargs, variant, pretrained)
return model return model
@ -731,10 +720,9 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
channel_multiplier=channel_multiplier, channel_multiplier=channel_multiplier,
act_layer=resolve_act_layer(kwargs, 'swish'), act_layer=resolve_act_layer(kwargs, 'swish'),
norm_kwargs=resolve_bn_args(kwargs), norm_kwargs=resolve_bn_args(kwargs),
variant=variant,
**kwargs, **kwargs,
) )
model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained) model = _create_effnet(model_kwargs, variant, pretrained)
return model return model
@ -763,7 +751,7 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0
act_layer=resolve_act_layer(kwargs, 'relu'), act_layer=resolve_act_layer(kwargs, 'relu'),
**kwargs, **kwargs,
) )
model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained) model = _create_effnet(model_kwargs, variant, pretrained)
return model return model
@ -793,7 +781,7 @@ def _gen_efficientnet_condconv(
act_layer=resolve_act_layer(kwargs, 'swish'), act_layer=resolve_act_layer(kwargs, 'swish'),
**kwargs, **kwargs,
) )
model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained) model = _create_effnet(model_kwargs, variant, pretrained)
return model return model
@ -834,7 +822,7 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0
norm_kwargs=resolve_bn_args(kwargs), norm_kwargs=resolve_bn_args(kwargs),
**kwargs, **kwargs,
) )
model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained) model = _create_effnet(model_kwargs, variant, pretrained)
return model return model
@ -867,7 +855,7 @@ def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
norm_kwargs=resolve_bn_args(kwargs), norm_kwargs=resolve_bn_args(kwargs),
**kwargs **kwargs
) )
model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained) model = _create_effnet(model_kwargs, variant, pretrained)
return model return model
@ -900,7 +888,7 @@ def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrai
norm_kwargs=resolve_bn_args(kwargs), norm_kwargs=resolve_bn_args(kwargs),
**kwargs **kwargs
) )
model = _create_effnet(model_kwargs, default_cfgs[variant], pretrained) model = _create_effnet(model_kwargs, variant, pretrained)
return model return model
@ -1239,7 +1227,7 @@ def efficientnet_b1_pruned(pretrained=False, **kwargs):
kwargs['pad_type'] = 'same' kwargs['pad_type'] = 'same'
variant = 'efficientnet_b1_pruned' variant = 'efficientnet_b1_pruned'
model = _gen_efficientnet( model = _gen_efficientnet(
variant, channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) variant, channel_multiplier=1.0, depth_multiplier=1.1, pruned=True, pretrained=pretrained, **kwargs)
return model return model
@ -1249,7 +1237,8 @@ def efficientnet_b2_pruned(pretrained=False, **kwargs):
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same' kwargs['pad_type'] = 'same'
model = _gen_efficientnet( model = _gen_efficientnet(
'efficientnet_b2_pruned', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) 'efficientnet_b2_pruned', channel_multiplier=1.1, depth_multiplier=1.2, pruned=True,
pretrained=pretrained, **kwargs)
return model return model
@ -1259,7 +1248,8 @@ def efficientnet_b3_pruned(pretrained=False, **kwargs):
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same' kwargs['pad_type'] = 'same'
model = _gen_efficientnet( model = _gen_efficientnet(
'efficientnet_b3_pruned', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) 'efficientnet_b3_pruned', channel_multiplier=1.2, depth_multiplier=1.4, pruned=True,
pretrained=pretrained, **kwargs)
return model return model

@ -1,46 +0,0 @@
""" PyTorch Feature Hook Helper
This class helps gather features from a network via hooks specified on the module name.
Hacked together by Ross Wightman
"""
import torch
from collections import defaultdict, OrderedDict
from functools import partial, partialmethod
from typing import List
class FeatureHooks:
    """ Feature Hook Helper

    Gathers intermediate features from a network via hooks registered on
    modules selected by name.
    """

    def __init__(self, hooks, named_modules, out_as_dict=False, out_map=None, default_hook_type='forward'):
        """
        Args:
            hooks: sequence of dicts, each with a 'module' key naming the target module
                and an optional 'hook_type' key ('forward' or 'forward_pre').
            named_modules: iterable of (name, module) pairs, e.g. `model.named_modules()`.
            out_as_dict: if True, `get_output` returns an OrderedDict keyed by hook id
                instead of a list of tensors.
            out_map: optional sequence of ids used positionally as output keys in place
                of the module names.
            default_hook_type: hook type used when a hook dict has no 'hook_type' entry.
        """
        # setup feature hooks
        modules = {k: v for k, v in named_modules}
        for i, h in enumerate(hooks):
            hook_name = h['module']
            m = modules[hook_name]
            # output key: positional entry of out_map if given, else the module name
            hook_id = out_map[i] if out_map else hook_name
            hook_fn = partial(self._collect_output_hook, hook_id)
            hook_type = h['hook_type'] if 'hook_type' in h else default_hook_type
            if hook_type == 'forward_pre':
                m.register_forward_pre_hook(hook_fn)
            elif hook_type == 'forward':
                m.register_forward_hook(hook_fn)
            else:
                assert False, "Unsupported hook type"
        # outputs stored per-device so the same hooks work under DataParallel replicas
        self._feature_outputs = defaultdict(OrderedDict)
        self.out_as_dict = out_as_dict

    def _collect_output_hook(self, hook_id, *args):
        """Hook fn: stash the hooked tensor, keyed by its device and hook id."""
        x = args[-1]  # tensor we want is last argument, output for fwd, input for fwd_pre
        if isinstance(x, tuple):
            x = x[0]  # unwrap input tuple
        self._feature_outputs[x.device][hook_id] = x

    def get_output(self, device) -> List[torch.Tensor]:
        """Return (and clear) the features collected on `device` since the last call."""
        if self.out_as_dict:
            output = self._feature_outputs[device]
        else:
            output = list(self._feature_outputs[device].values())
        self._feature_outputs[device] = OrderedDict()  # clear after reading
        return output

@ -5,15 +5,14 @@ and provide a common interface for describing them.
Hacked together by Ross Wightman Hacked together by Ross Wightman
""" """
from collections import OrderedDict from collections import OrderedDict, defaultdict
from typing import Dict, List, Tuple, Any
from copy import deepcopy from copy import deepcopy
from functools import partial
from typing import Dict, List, Tuple, Any
import torch import torch
import torch.nn as nn import torch.nn as nn
from .feature_hooks import FeatureHooks
class FeatureInfo: class FeatureInfo:
@ -75,6 +74,41 @@ class FeatureInfo:
return len(self.info) return len(self.info)
class FeatureHooks:
    """ Feature Hook Helper

    Gathers intermediate features from a network via forward (or forward-pre) hooks
    registered on modules selected by name.
    """

    def __init__(self, hooks, named_modules, out_as_dict=False, out_map=None, default_hook_type='forward'):
        """
        Args:
            hooks: sequence of dicts, each with a 'module' key naming the target module
                and an optional 'hook_type' key ('forward' or 'forward_pre').
            named_modules: iterable of (name, module) pairs, e.g. `model.named_modules()`.
            out_as_dict: if True, `get_output` returns an OrderedDict keyed by hook id
                instead of a list of tensors.
            out_map: optional sequence of ids used positionally as output keys in place
                of the module names.
            default_hook_type: hook type used when a hook dict has no 'hook_type' entry.

        Raises:
            ValueError: if a hook spec requests an unsupported hook type.
        """
        # outputs are stored per-device so the same hooks work under DataParallel replicas
        self._feature_outputs = defaultdict(OrderedDict)
        self.out_as_dict = out_as_dict
        # setup feature hooks
        modules = {k: v for k, v in named_modules}
        for i, h in enumerate(hooks):
            hook_name = h['module']
            m = modules[hook_name]
            # output key: positional entry of out_map if given, else the module name
            hook_id = out_map[i] if out_map else hook_name
            hook_fn = partial(self._collect_output_hook, hook_id)
            hook_type = h.get('hook_type', default_hook_type)
            if hook_type == 'forward_pre':
                m.register_forward_pre_hook(hook_fn)
            elif hook_type == 'forward':
                m.register_forward_hook(hook_fn)
            else:
                # raise (not assert) so invalid specs fail even under `python -O`
                raise ValueError('Unsupported hook type: {}'.format(hook_type))

    def _collect_output_hook(self, hook_id, *args):
        """Hook fn: stash the hooked tensor, keyed by its device and hook id."""
        x = args[-1]  # tensor we want is last argument, output for fwd, input for fwd_pre
        if isinstance(x, tuple):
            x = x[0]  # unwrap input tuple
        self._feature_outputs[x.device][hook_id] = x

    def get_output(self, device) -> List[torch.Tensor]:  # FIXME deal with diff return types for torchscript?
        """Return (and clear) the features collected on `device` since the last call."""
        if self.out_as_dict:
            output = self._feature_outputs[device]
        else:
            output = list(self._feature_outputs[device].values())
        self._feature_outputs[device] = OrderedDict()  # clear after reading
        return output
def _module_list(module, flatten_sequential=False): def _module_list(module, flatten_sequential=False):
# a yield/iter would be better for this but wouldn't be compatible with torchscript # a yield/iter would be better for this but wouldn't be compatible with torchscript
ml = [] ml = []

@ -16,9 +16,8 @@ from typing import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
from .feature_hooks import FeatureHooks from .features import FeatureInfo, FeatureHooks
from .features import FeatureInfo from .helpers import build_model_with_cfg
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d, create_conv2d, get_act_fn, hard_sigmoid from .layers import SelectAdaptivePool2d, create_conv2d, get_act_fn, hard_sigmoid
from .registry import register_model from .registry import register_model
@ -215,27 +214,19 @@ class MobileNetV3Features(nn.Module):
return self.feature_hooks.get_output(x.device) return self.feature_hooks.get_output(x.device)
def _create_mnv3(model_kwargs, default_cfg, pretrained=False): def _create_mnv3(model_kwargs, variant, pretrained=False):
if model_kwargs.pop('features_only', False): if model_kwargs.pop('features_only', False):
load_strict = False load_strict = False
model_kwargs.pop('num_classes', 0) model_kwargs.pop('num_classes', 0)
model_kwargs.pop('num_features', 0) model_kwargs.pop('num_features', 0)
model_kwargs.pop('head_conv', None) model_kwargs.pop('head_conv', None)
model_class = MobileNetV3Features model_cls = MobileNetV3Features
else: else:
load_strict = True load_strict = True
model_class = MobileNetV3 model_cls = MobileNetV3
return build_model_with_cfg(
model = model_class(**model_kwargs) model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
model.default_cfg = default_cfg pretrained_strict=load_strict, **model_kwargs)
if pretrained:
load_pretrained(
model,
default_cfg,
num_classes=model_kwargs.get('num_classes', 0),
in_chans=model_kwargs.get('in_chans', 3),
strict=load_strict)
return model
def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs): def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
@ -272,7 +263,7 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw
se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=1), se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=1),
**kwargs, **kwargs,
) )
model = _create_mnv3(model_kwargs, default_cfgs[variant], pretrained) model = _create_mnv3(model_kwargs, variant, pretrained)
return model return model
@ -368,7 +359,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8), se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8),
**kwargs, **kwargs,
) )
model = _create_mnv3(model_kwargs, default_cfgs[variant], pretrained) model = _create_mnv3(model_kwargs, variant, pretrained)
return model return model

Loading…
Cancel
Save