@@ -75,8 +75,14 @@ class FeatureInfo:
 
 
 class FeatureHooks:
+    """ Feature Hook Helper
 
-    def __init__(self, hooks, named_modules, out_as_dict=False, out_map=None, default_hook_type='forward'):
+    This module helps with the setup and extraction of hooks for extracting features from
+    internal nodes in a model by node name. This works quite well in eager Python but needs
+    redesign for torchscript.
+    """
+
+    def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'):
         # setup feature hooks
         modules = {k: v for k, v in named_modules}
         for i, h in enumerate(hooks):
@@ -92,7 +98,6 @@ class FeatureHooks:
             else:
                 assert False, "Unsupported hook type"
         self._feature_outputs = defaultdict(OrderedDict)
-        self.out_as_dict = out_as_dict
 
     def _collect_output_hook(self, hook_id, *args):
         x = args[-1]  # tensor we want is last argument, output for fwd, input for fwd_pre
@@ -100,11 +105,8 @@ class FeatureHooks:
             x = x[0]  # unwrap input tuple
         self._feature_outputs[x.device][hook_id] = x
 
-    def get_output(self, device) -> List[torch.tensor]:  # FIXME deal with diff return types for torchscript?
-        if self.out_as_dict:
-            output = self._feature_outputs[device]
-        else:
-            output = list(self._feature_outputs[device].values())
+    def get_output(self, device) -> Dict[str, torch.tensor]:
+        output = self._feature_outputs[device]
         self._feature_outputs[device] = OrderedDict()  # clear after reading
         return output
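Note for reviewers (illustration only, not part of the patch): a minimal sketch of driving the revised FeatureHooks directly, covering both the constructor and get_output() changes in the hunks above. The toy model and the 'stage1'/'stage2' names are invented, and the import location of FeatureHooks depends on where this file lives.

    import torch
    import torch.nn as nn
    from collections import OrderedDict
    # FeatureHooks is the class patched above; import it from wherever this module is installed.

    # hypothetical two-stage model purely for illustration
    net = nn.Sequential(OrderedDict([
        ('stage1', nn.Conv2d(3, 8, 3, stride=2, padding=1)),
        ('stage2', nn.Conv2d(8, 16, 3, stride=2, padding=1)),
    ]))

    # one dict per tap: 'module' is the node name, 'hook_type' falls back to default_hook_type
    hooks = [dict(module='stage1'), dict(module='stage2', hook_type='forward')]
    feature_hooks = FeatureHooks(hooks, net.named_modules())

    x = torch.randn(1, 3, 32, 32)
    net(x)                                       # hooks fire during the forward pass
    feats = feature_hooks.get_output(x.device)   # with this change, always a dict keyed by hook id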
@@ -123,83 +125,72 @@ def _module_list(module, flatten_sequential=False):
     return ml
 
 
-class LayerGetterHooks(nn.ModuleDict):
-    """ LayerGetterHooks
-    TODO
-    """
-
-    def __init__(self, model, feature_info, flatten_sequential=False, out_as_dict=False, out_map=None,
-                 default_hook_type='forward'):
-        modules = _module_list(model, flatten_sequential=flatten_sequential)
-        remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type for f in feature_info}
-        layers = OrderedDict()
-        hooks = []
-        for new_name, old_name, module in modules:
-            layers[new_name] = module
-            for fn, fm in module.named_modules(prefix=old_name):
-                if fn in remaining:
-                    hooks.append(dict(module=fn, hook_type=remaining[fn]))
-                    del remaining[fn]
-            if not remaining:
-                break
-        assert not remaining, f'Return layers ({remaining}) are not present in model'
-        super(LayerGetterHooks, self).__init__(layers)
-        self.hooks = FeatureHooks(hooks, model.named_modules(), out_as_dict=out_as_dict, out_map=out_map)
-
-    def forward(self, x) -> Dict[Any, torch.Tensor]:
-        for name, module in self.items():
-            x = module(x)
-        return self.hooks.get_output(x.device)
-
-
-class LayerGetterDict(nn.ModuleDict):
-    """
-    Module wrapper that returns intermediate layers from a model as a dictionary
-
-    Originally based on concepts from IntermediateLayerGetter at
-    https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py
-
-    It has a strong assumption that the modules have been registered into the model in the same
-    order as they are used. This means that one should **not** reuse the same nn.Module twice
-    in the forward if you want this to work.
-
-    Additionally, it is only able to query submodules that are directly assigned to the model
-    class (`model.feature1`) or at most one Sequential container deep (`model.features.1`, so
-    long as `features` is a sequential container assigned to the model).
-
+def _get_feature_info(net, out_indices):
+    feature_info = getattr(net, 'feature_info')
+    if isinstance(feature_info, FeatureInfo):
+        return feature_info.from_other(out_indices)
+    elif isinstance(feature_info, (list, tuple)):
+        return FeatureInfo(net.feature_info, out_indices)
+    else:
+        assert False, "Provided feature_info is not valid"
+
+
+def _get_return_layers(feature_info, out_map):
+    module_names = feature_info.module_name()
+    return_layers = {}
+    for i, name in enumerate(module_names):
+        return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
+    return return_layers
+
+
+class FeatureDictNet(nn.ModuleDict):
+    """ Feature extractor with OrderedDict return
+
+    Wrap a model and extract features as specified by the out indices, the network is
+    partially re-built from contained modules.
+
+    There is a strong assumption that the modules have been registered into the model in the same
+    order as they are used. There should be no reuse of the same nn.Module more than once, including
+    trivial modules like `self.relu = nn.ReLU`.
+
+    Only submodules that are directly assigned to the model class (`model.feature1`) or at most
+    one Sequential container deep (`model.features.1`, with flatten_sequential=True) can be captured.
     All Sequential containers that are directly assigned to the original model will have their
     modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
 
     Arguments:
-        model (nn.Module): model on which we will extract the features
-        return_layers (Dict[name, new_name]): a dict containing the names
-            of the modules for which the activations will be returned as
-            the key of the dict, and the value of the dict is the name
-            of the returned activation (which the user can specify).
-        concat (bool): whether to concatenate intermediate features that are lists or tuples
+        model (nn.Module): model from which we will extract the features
+        out_indices (tuple[int]): model output indices to extract features for
+        out_map (sequence): list or tuple specifying desired return id for each out index,
+            otherwise str(index) is used
+        feature_concat (bool): whether to concatenate intermediate features that are lists or tuples
             vs select element [0]
         flatten_sequential (bool): whether to flatten sequential modules assigned to model
     """
-    def __init__(self, model, return_layers, concat=False, flatten_sequential=False):
+    def __init__(
+            self, model,
+            out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
+        super(FeatureDictNet, self).__init__()
+        self.feature_info = _get_feature_info(model, out_indices)
+        self.concat = feature_concat
         self.return_layers = {}
-        self.concat = concat
+        return_layers = _get_return_layers(self.feature_info, out_map)
         modules = _module_list(model, flatten_sequential=flatten_sequential)
         remaining = set(return_layers.keys())
         layers = OrderedDict()
         for new_name, old_name, module in modules:
             layers[new_name] = module
             if old_name in remaining:
-                self.return_layers[new_name] = return_layers[old_name]
+                # return id has to be consistently str type for torchscript
+                self.return_layers[new_name] = str(return_layers[old_name])
                 remaining.remove(old_name)
             if not remaining:
                 break
         assert not remaining and len(self.return_layers) == len(return_layers), \
             f'Return layers ({remaining}) are not present in model'
-        super(LayerGetterDict, self).__init__(layers)
+        self.update(layers)
 
-    def forward(self, x) -> Dict[Any, torch.Tensor]:
+    def _collect(self, x) -> (Dict[str, torch.Tensor]):
         out = OrderedDict()
         for name, module in self.items():
             x = module(x)
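Note for reviewers (illustration only, not part of the patch): a usage sketch for the FeatureDictNet added in the hunk above. The FeatureInfo schema is defined outside this diff, so the 'module'/'num_chs'/'reduction' fields below are assumed, and ToyNet is a hypothetical stand-in for a model that exposes a feature_info attribute.

    import torch
    import torch.nn as nn
    # FeatureDictNet from the patched module is assumed to be in scope

    class ToyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.stem = nn.Conv2d(3, 8, 3, stride=2, padding=1)
            self.stage1 = nn.Conv2d(8, 16, 3, stride=2, padding=1)
            self.stage2 = nn.Conv2d(16, 32, 3, stride=2, padding=1)
            # metadata consumed by _get_feature_info(); field names assumed from FeatureInfo usage
            self.feature_info = [
                dict(module='stem', num_chs=8, reduction=2),
                dict(module='stage1', num_chs=16, reduction=4),
                dict(module='stage2', num_chs=32, reduction=8),
            ]

        def forward(self, x):
            return self.stage2(self.stage1(self.stem(x)))

    features = FeatureDictNet(ToyNet(), out_indices=(0, 1, 2))
    out = features(torch.randn(1, 3, 64, 64))
    # keys are str(out_index) unless out_map overrides them, per the str() coercion above
    print(list(out.keys()))  # ['0', '1', '2']
    print(out['2'].shape)    # torch.Size([1, 32, 8, 8])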
@@ -213,131 +204,74 @@ class LayerGetterDict(nn.ModuleDict):
                     out[out_id] = x
         return out
 
+    def forward(self, x) -> Dict[str, torch.Tensor]:
+        return self._collect(x)
+
 
-class LayerGetterList(nn.Sequential):
-    """
-    Module wrapper that returns intermediate layers from a model as a list
-
-    Originally based on concepts from IntermediateLayerGetter at
-    https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py
-
-    It has a strong assumption that the modules have been registered into the model in the same
-    order as they are used. This means that one should **not** reuse the same nn.Module twice
-    in the forward if you want this to work.
-
-    Additionally, it is only able to query submodules that are directly assigned to the model
-    class (`model.feature1`) or at most one Sequential container deep (`model.features.1`) so
-    long as `features` is a sequential container assigned to the model and flatten_sequent=True.
-
-    All Sequential containers that are directly assigned to the original model will have their
-    modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
-
-    Arguments:
-        model (nn.Module): model on which we will extract the features
-        return_layers (Dict[name, new_name]): a dict containing the names
-            of the modules for which the activations will be returned as
-            the key of the dict, and the value of the dict is the name
-            of the returned activation (which the user can specify).
-        concat (bool): whether to concatenate intermediate features that are lists or tuples
-            vs select element [0]
-        flatten_sequential (bool): whether to flatten sequential modules assigned to model
-    """
-
-    def __init__(self, model, return_layers, concat=False, flatten_sequential=False):
-        super(LayerGetterList, self).__init__()
-        self.return_layers = {}
-        self.concat = concat
-        modules = _module_list(model, flatten_sequential=flatten_sequential)
-        remaining = set(return_layers.keys())
-        for new_name, orig_name, module in modules:
-            self.add_module(new_name, module)
-            if orig_name in remaining:
-                self.return_layers[new_name] = return_layers[orig_name]
-                remaining.remove(orig_name)
-            if not remaining:
-                break
-        assert not remaining and len(self.return_layers) == len(return_layers), \
-            f'Return layers ({remaining}) are not present in model'
-
-    def forward(self, x) -> List[torch.Tensor]:
-        out = []
-        for name, module in self.named_children():
-            x = module(x)
-            if name in self.return_layers:
-                if isinstance(x, (tuple, list)):
-                    # If model tap is a tuple or list, concat or select first element
-                    # FIXME this may need to be more generic / flexible for some nets
-                    out.append(torch.cat(x, 1) if self.concat else x[0])
-                else:
-                    out.append(x)
-        return out
-
-
-def _resolve_feature_info(net, out_indices, feature_info=None):
-    if feature_info is None:
-        feature_info = getattr(net, 'feature_info')
-    if isinstance(feature_info, FeatureInfo):
-        return feature_info.from_other(out_indices)
-    elif isinstance(feature_info, (list, tuple)):
-        return FeatureInfo(net.feature_info, out_indices)
-    else:
-        assert False, "Provided feature_info is not valid"
-
-
-def _get_return_layers(feature_info, out_map):
-    module_names = feature_info.module_name()
-    return_layers = {}
-    for i, name in enumerate(module_names):
-        return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
-    return return_layers
-
-
-class FeatureNet(nn.Module):
-    """ FeatureNet
-
-    Wrap a model and extract features as specified by the out indices, the network
-    is partially re-built from contained modules using the LayerGetters.
-
-    Please read the docstrings of the LayerGetter classes, they will not work on all models.
+class FeatureListNet(FeatureDictNet):
+    """ Feature extractor with list return
+
+    See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints.
+    In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool.
     """
     def __init__(
-            self, net,
-            out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, use_hooks=False,
-            feature_info=None, feature_concat=False, flatten_sequential=False):
-        super(FeatureNet, self).__init__()
-        self.feature_info = _resolve_feature_info(net, out_indices, feature_info)
-        if use_hooks:
-            self.body = LayerGetterHooks(net, self.feature_info, out_as_dict=out_as_dict, out_map=out_map)
-        else:
-            return_layers = _get_return_layers(self.feature_info, out_map)
-            lg_args = dict(return_layers=return_layers, concat=feature_concat, flatten_sequential=flatten_sequential)
-            self.body = LayerGetterDict(net, **lg_args) if out_as_dict else LayerGetterList(net, **lg_args)
+            self, model,
+            out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
+        super(FeatureListNet, self).__init__(
+            model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat,
+            flatten_sequential=flatten_sequential)
 
-    def forward(self, x):
-        output = self.body(x)
-        return output
+    def forward(self, x) -> (List[torch.Tensor]):
+        return list(self._collect(x).values())
 
 
-class FeatureHookNet(nn.Module):
+class FeatureHookNet(nn.ModuleDict):
     """ FeatureHookNet
 
-    Wrap a model and extract features specified by the out indices.
+    Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.
 
-    Features are extracted via hooks without modifying the underlying network in any way. If only
-    part of the model is used it is up to the caller to remove unneeded layers as this wrapper
-    does not rewrite and remove unused top-level modules like FeatureNet with LayerGetter.
+    If `no_rewrite` is True, features are extracted via hooks without modifying the underlying
+    network in any way.
+
+    If `no_rewrite` is False, the model will be re-written as in the
+    FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one.
+
+    FIXME this does not currently work with Torchscript, see FeatureHooks class
     """
     def __init__(
-            self, net,
-            out_indices=(0, 1, 2, 3, 4), out_as_dict=False, out_map=None,
-            feature_info=None, feature_concat=False):
+            self, model,
+            out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False,
+            feature_concat=False, flatten_sequential=False, default_hook_type='forward'):
         super(FeatureHookNet, self).__init__()
-        self.feature_info = _resolve_feature_info(net, out_indices, feature_info)
-        self.body = net
-        self.hooks = FeatureHooks(
-            self.feature_info, self.body.named_modules(), out_as_dict=out_as_dict, out_map=out_map)
+        assert not torch.jit.is_scripting()
+        self.feature_info = _get_feature_info(model, out_indices)
+        self.out_as_dict = out_as_dict
+        layers = OrderedDict()
+        hooks = []
+        if no_rewrite:
+            assert not flatten_sequential
+            if hasattr(model, 'reset_classifier'):  # make sure classifier is removed?
+                model.reset_classifier(0)
+            layers['body'] = model
+            hooks.extend(self.feature_info)
+        else:
+            modules = _module_list(model, flatten_sequential=flatten_sequential)
+            remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
+                         for f in self.feature_info}
+            for new_name, old_name, module in modules:
+                layers[new_name] = module
+                for fn, fm in module.named_modules(prefix=old_name):
+                    if fn in remaining:
+                        hooks.append(dict(module=fn, hook_type=remaining[fn]))
+                        del remaining[fn]
+                if not remaining:
+                    break
+            assert not remaining, f'Return layers ({remaining}) are not present in model'
+        self.update(layers)
+        self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
 
     def forward(self, x):
-        self.body(x)
-        return self.hooks.get_output(x.device)
+        for name, module in self.items():
+            x = module(x)
+        out = self.hooks.get_output(x.device)
+        return out if self.out_as_dict else list(out.values())
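Note for reviewers (illustration only, not part of the patch): the list and hook based wrappers from this hunk driven the same way, reusing the hypothetical ToyNet from the previous sketch; FeatureListNet and FeatureHookNet from the patched module are assumed to be in scope.

    import torch

    # rebuilt-module wrapper, returns a plain List[Tensor] ordered by out_index
    body = FeatureListNet(ToyNet(), out_indices=(0, 1, 2))
    feats = body(torch.randn(1, 3, 64, 64))

    # hook-based wrapper: taps are collected by forward hooks registered in __init__;
    # returns a list by default, or a dict keyed by module name when out_as_dict=True
    hook_body = FeatureHookNet(ToyNet(), out_indices=(0, 1, 2), out_as_dict=True)
    feats = hook_body(torch.randn(1, 3, 64, 64))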