@@ -11,10 +11,11 @@ Hacked together by / Copyright 2020 Ross Wightman
 from collections import OrderedDict, defaultdict
 from copy import deepcopy
 from functools import partial
-from typing import Dict, List, Tuple
+from typing import Dict, List, Sequence, Tuple, Union
 
 import torch
 import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
 
 
 __all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
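Note: the new `from torch.utils.checkpoint import checkpoint` import backs the gradient checkpointing added further down; `checkpoint(fn, *args)` runs `fn` without storing intermediate activations and recomputes them on the backward pass. A minimal sketch of the call pattern (the layer and shapes are illustrative, not part of this PR):

    import torch
    from torch.utils.checkpoint import checkpoint

    layer = torch.nn.Linear(8, 8)              # illustrative module
    x = torch.randn(2, 8, requires_grad=True)
    y = checkpoint(layer, x)                   # same output as layer(x), lower activation memory
    y.sum().backward()                         # activations recomputed during backward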
@@ -88,12 +89,20 @@ class FeatureHooks:
     """ Feature Hook Helper
 
     This module helps with the setup and extraction of hooks for extracting features from
-    internal nodes in a model by node name. This works quite well in eager Python but needs
-    redesign for torchscript.
+    internal nodes in a model by node name.
+
+    FIXME This works well in eager Python but needs redesign for torchscript.
     """
 
-    def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'):
+    def __init__(
+            self,
+            hooks: Sequence[str],
+            named_modules: dict,
+            out_map: Sequence[Union[int, str]] = None,
+            default_hook_type: str = 'forward',
+    ):
         # setup feature hooks
+        self._feature_outputs = defaultdict(OrderedDict)
         modules = {k: v for k, v in named_modules}
         for i, h in enumerate(hooks):
             hook_name = h['module']
@@ -107,7 +116,6 @@ class FeatureHooks:
                 m.register_forward_hook(hook_fn)
             else:
                 assert False, "Unsupported hook type"
-        self._feature_outputs = defaultdict(OrderedDict)
 
     def _collect_output_hook(self, hook_id, *args):
         x = args[-1]  # tensor we want is last argument, output for fwd, input for fwd_pre
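For context, a rough sketch of how FeatureHooks is wired up on its own: hook specs are dicts keyed by 'module' (matching the `h['module']` access above), and `get_output()` returns the captured tensors after a forward pass. The torchvision model and node names here are illustrative assumptions, not part of this change:

    import torch
    import torchvision

    model = torchvision.models.resnet18()                      # illustrative backbone
    hook_specs = [{'module': 'layer2'}, {'module': 'layer4', 'hook_type': 'forward'}]
    feature_hooks = FeatureHooks(hook_specs, model.named_modules())
    x = torch.randn(1, 3, 224, 224)
    _ = model(x)                                               # hooks fire during this forward
    features = feature_hooks.get_output(x.device)              # OrderedDict: module name -> tensor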
@@ -167,23 +175,30 @@ class FeatureDictNet(nn.ModuleDict):
     one Sequential container deep (`model.features.1`, with flatten_sequent=True) can be captured.
     All Sequential containers that are directly assigned to the original model will have their
     modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
-
-    Arguments:
-        model (nn.Module): model from which we will extract the features
-        out_indices (tuple[int]): model output indices to extract features for
-        out_map (sequence): list or tuple specifying desired return id for each out index,
-            otherwise str(index) is used
-        feature_concat (bool): whether to concatenate intermediate features that are lists or tuples
-            vs select element [0]
-        flatten_sequential (bool): whether to flatten sequential modules assigned to model
     """
     def __init__(
-            self, model,
-            out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
+            self,
+            model: nn.Module,
+            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
+            out_map: Sequence[Union[int, str]] = None,
+            feature_concat: bool = False,
+            flatten_sequential: bool = False,
+    ):
+        """
+        Args:
+            model: Model from which to extract features.
+            out_indices: Output indices of the model features to extract.
+            out_map: Return id mapping for each output index, otherwise str(index) is used.
+            feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
+                first element e.g. `x[0]`
+            flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
+        """
         super(FeatureDictNet, self).__init__()
         self.feature_info = _get_feature_info(model, out_indices)
         self.concat = feature_concat
+        self.grad_checkpointing = False
         self.return_layers = {}
+
         return_layers = _get_return_layers(self.feature_info, out_map)
         modules = _module_list(model, flatten_sequential=flatten_sequential)
         remaining = set(return_layers.keys())
@@ -200,10 +215,21 @@ class FeatureDictNet(nn.ModuleDict):
             f'Return layers ({remaining}) are not present in model'
         self.update(layers)
 
+    def set_grad_checkpointing(self, enable: bool = True):
+        self.grad_checkpointing = enable
+
     def _collect(self, x) -> (Dict[str, torch.Tensor]):
         out = OrderedDict()
-        for name, module in self.items():
-            x = module(x)
+        for i, (name, module) in enumerate(self.items()):
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                # Skipping checkpoint of first module because need a gradient at input
+                # Skipping last because networks with in-place ops might fail w/ checkpointing enabled
+                # NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues
+                first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
+                x = module(x) if first_or_last_module else checkpoint(module, x)
+            else:
+                x = module(x)
+
             if name in self.return_layers:
                 out_id = self.return_layers[name]
                 if isinstance(x, (tuple, list)):
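Together with the checkpoint import above, this gives FeatureDictNet an opt-in gradient checkpointing path that wraps every module except the first and last. A hedged usage sketch (the timm backbone is an illustrative assumption; any model exposing `feature_info` should do):

    import timm
    import torch

    backbone = timm.create_model('resnet50')                  # illustrative model with .feature_info
    features = FeatureDictNet(backbone, out_indices=(1, 2, 3, 4))
    features.set_grad_checkpointing(True)                      # checkpoints all but first/last module
    out = features(torch.randn(2, 3, 224, 224))
    for k, v in out.items():
        print(k, tuple(v.shape))                               # keys are str(index) or out_map ids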
@@ -221,15 +247,29 @@ class FeatureDictNet(nn.ModuleDict):
 class FeatureListNet(FeatureDictNet):
     """ Feature extractor with list return
 
-    See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints.
-    In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool.
+    A specialization of FeatureDictNet that always returns features as a list (values() of dict).
     """
     def __init__(
-            self, model,
-            out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
+            self,
+            model: nn.Module,
+            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
+            feature_concat: bool = False,
+            flatten_sequential: bool = False,
+    ):
+        """
+        Args:
+            model: Model from which to extract features.
+            out_indices: Output indices of the model features to extract.
+            feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
+                first element e.g. `x[0]`
+            flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
+        """
         super(FeatureListNet, self).__init__(
-            model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat,
-            flatten_sequential=flatten_sequential)
+            model,
+            out_indices=out_indices,
+            feature_concat=feature_concat,
+            flatten_sequential=flatten_sequential,
+        )
 
     def forward(self, x) -> (List[torch.Tensor]):
         return list(self._collect(x).values())
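FeatureListNet keeps the same collection logic but returns `list(dict.values())`, which is what timm's `features_only=True` models typically hand back. A minimal sketch (the model choice is illustrative):

    import timm
    import torch

    backbone = timm.create_model('resnet50')                  # illustrative model with .feature_info
    features = FeatureListNet(backbone, out_indices=(0, 1, 2, 3, 4))
    outs = features(torch.randn(1, 3, 224, 224))
    print([tuple(o.shape) for o in outs])                      # feature maps ordered shallow -> deep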
@@ -249,13 +289,33 @@ class FeatureHookNet(nn.ModuleDict):
     FIXME this does not currently work with Torchscript, see FeatureHooks class
     """
     def __init__(
-            self, model,
-            out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False,
-            feature_concat=False, flatten_sequential=False, default_hook_type='forward'):
+            self,
+            model: nn.Module,
+            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
+            out_map: Sequence[Union[int, str]] = None,
+            out_as_dict: bool = False,
+            no_rewrite: bool = False,
+            flatten_sequential: bool = False,
+            default_hook_type: str = 'forward',
+    ):
+        """
+
+        Args:
+            model: Model from which to extract features.
+            out_indices: Output indices of the model features to extract.
+            out_map: Return id mapping for each output index, otherwise str(index) is used.
+            out_as_dict: Output features as a dict.
+            no_rewrite: Enforce that model is not re-written if True, ie no modules are removed / changed.
+                flatten_sequential arg must also be False if this is set True.
+            flatten_sequential: Re-write modules by flattening first two levels of nn.Sequential containers.
+            default_hook_type: The default hook type to use if not specified in model.feature_info.
+        """
         super(FeatureHookNet, self).__init__()
         assert not torch.jit.is_scripting()
         self.feature_info = _get_feature_info(model, out_indices)
         self.out_as_dict = out_as_dict
+        self.grad_checkpointing = False
+
         layers = OrderedDict()
         hooks = []
         if no_rewrite:
@@ -266,8 +326,10 @@ class FeatureHookNet(nn.ModuleDict):
             hooks.extend(self.feature_info.get_dicts())
         else:
             modules = _module_list(model, flatten_sequential=flatten_sequential)
-            remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
-                         for f in self.feature_info.get_dicts()}
+            remaining = {
+                f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
+                for f in self.feature_info.get_dicts()
+            }
             for new_name, old_name, module in modules:
                 layers[new_name] = module
                 for fn, fm in module.named_modules(prefix=old_name):
@@ -280,8 +342,18 @@ class FeatureHookNet(nn.ModuleDict):
         self.update(layers)
         self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
 
+    def set_grad_checkpointing(self, enable: bool = True):
+        self.grad_checkpointing = enable
+
     def forward(self, x):
-        for name, module in self.items():
-            x = module(x)
+        for i, (name, module) in enumerate(self.items()):
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                # Skipping checkpoint of first module because need a gradient at input
+                # Skipping last because networks with in-place ops might fail w/ checkpointing enabled
+                # NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues
+                first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
+                x = module(x) if first_or_last_module else checkpoint(module, x)
+            else:
+                x = module(x)
         out = self.hooks.get_output(x.device)
         return out if self.out_as_dict else list(out.values())
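FeatureHookNet runs the wrapped modules in order and captures features via the FeatureHooks helper above, and the same first/last checkpointing skip rule now applies here. A rough usage sketch (the model choice is an illustrative assumption):

    import timm
    import torch

    backbone = timm.create_model('resnet50')                  # illustrative model with .feature_info
    features = FeatureHookNet(backbone, out_indices=(1, 2, 3, 4), out_as_dict=True)
    features.set_grad_checkpointing(True)                      # same skip rule as FeatureDictNet
    out = features(torch.randn(1, 3, 224, 224))                # dict of hooked feature tensors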