Cleanup, refactoring of feature extraction code, add tests, fix tests, non-hook feature extraction working with torchscript

pull/175/head
Ross Wightman 4 years ago
parent 6eec3fb4a4
commit 4e61c6a12d

@@ -106,3 +106,26 @@ def test_model_forward_torchscript(model_name, batch_size):
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'
EXCLUDE_FEAT_FILTERS = [
'hrnet*', '*pruned*', # hopefully fix at some point
'legacy*', # not going to bother
]
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_features(model_name, batch_size):
"""Run a single forward pass with each model in feature extraction mode"""
model = create_model(model_name, pretrained=False, features_only=True)
model.eval()
expected_channels = model.feature_info.channels()
assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6
input_size = (3, 128, 128) # jit compile is already a bit slow and we've tested normal res already...
outputs = model(torch.randn((batch_size, *input_size)))
assert len(expected_channels) == len(outputs)
for e, o in zip(expected_channels, outputs):
assert e == o.shape[1]
assert o.shape[0] == batch_size
assert not torch.isnan(o).any()
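For context, a minimal sketch of the features_only path this test exercises, assuming timm is installed; the model name and input size below are illustrative, not part of this commit:

import torch
from timm import create_model

# Build a backbone in feature extraction mode; the wrapper returns one tensor
# per feature level instead of classification logits.
model = create_model('resnet18', pretrained=False, features_only=True)
model.eval()
print(model.feature_info.channels())    # e.g. [64, 64, 128, 256, 512]
print(model.feature_info.reduction())   # e.g. [2, 4, 8, 16, 32]

with torch.no_grad():
    features = model(torch.randn(1, 3, 128, 128))
for f in features:
    print(f.shape)  # channel dims line up with feature_info.channels()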

@@ -75,8 +75,14 @@ class FeatureInfo:
class FeatureHooks:
""" Feature Hook Helper
def __init__(self, hooks, named_modules, out_as_dict=False, out_map=None, default_hook_type='forward'):
This module helps with the setup and extraction of hooks for extracting features from
internal nodes in a model by node name. This works quite well in eager Python but needs
redesign for torchscript.
"""
def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'):
# setup feature hooks
modules = {k: v for k, v in named_modules}
for i, h in enumerate(hooks):
@@ -92,7 +98,6 @@ class FeatureHooks:
else:
assert False, "Unsupported hook type"
self._feature_outputs = defaultdict(OrderedDict)
self.out_as_dict = out_as_dict
def _collect_output_hook(self, hook_id, *args):
x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre
@@ -100,11 +105,8 @@ class FeatureHooks:
x = x[0] # unwrap input tuple
self._feature_outputs[x.device][hook_id] = x
def get_output(self, device) -> List[torch.tensor]: # FIXME deal with diff return types for torchscript?
if self.out_as_dict:
output = self._feature_outputs[device]
else:
output = list(self._feature_outputs[device].values())
def get_output(self, device) -> Dict[str, torch.tensor]:
output = self._feature_outputs[device]
self._feature_outputs[device] = OrderedDict() # clear after reading
return output
@@ -123,83 +125,72 @@ def _module_list(module, flatten_sequential=False):
return ml
class LayerGetterHooks(nn.ModuleDict):
""" LayerGetterHooks
TODO
"""
def __init__(self, model, feature_info, flatten_sequential=False, out_as_dict=False, out_map=None,
default_hook_type='forward'):
modules = _module_list(model, flatten_sequential=flatten_sequential)
remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type for f in feature_info}
layers = OrderedDict()
hooks = []
for new_name, old_name, module in modules:
layers[new_name] = module
for fn, fm in module.named_modules(prefix=old_name):
if fn in remaining:
hooks.append(dict(module=fn, hook_type=remaining[fn]))
del remaining[fn]
if not remaining:
break
assert not remaining, f'Return layers ({remaining}) are not present in model'
super(LayerGetterHooks, self).__init__(layers)
self.hooks = FeatureHooks(hooks, model.named_modules(), out_as_dict=out_as_dict, out_map=out_map)
def _get_feature_info(net, out_indices):
feature_info = getattr(net, 'feature_info')
if isinstance(feature_info, FeatureInfo):
return feature_info.from_other(out_indices)
elif isinstance(feature_info, (list, tuple)):
return FeatureInfo(net.feature_info, out_indices)
else:
assert False, "Provided feature_info is not valid"
def forward(self, x) -> Dict[Any, torch.Tensor]:
for name, module in self.items():
x = module(x)
return self.hooks.get_output(x.device)
def _get_return_layers(feature_info, out_map):
module_names = feature_info.module_name()
return_layers = {}
for i, name in enumerate(module_names):
return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
return return_layers
class LayerGetterDict(nn.ModuleDict):
"""
Module wrapper that returns intermediate layers from a model as a dictionary
Originally based on concepts from IntermediateLayerGetter at
https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py
class FeatureDictNet(nn.ModuleDict):
""" Feature extractor with OrderedDict return
It has a strong assumption that the modules have been registered into the model in the same
order as they are used. This means that one should **not** reuse the same nn.Module twice
in the forward if you want this to work.
Wrap a model and extract features as specified by the out indices, the network is
partially re-built from contained modules.
Additionally, it is only able to query submodules that are directly assigned to the model
class (`model.feature1`) or at most one Sequential container deep (`model.features.1`, so
long as `features` is a sequential container assigned to the model).
There is a strong assumption that the modules have been registered into the model in the same
order as they are used. There should be no reuse of the same nn.Module more than once, including
trivial modules like `self.relu = nn.ReLU`.
Only submodules that are directly assigned to the model class (`model.feature1`) or at most
one Sequential container deep (`model.features.1`, with flatten_sequential=True) can be captured.
All Sequential containers that are directly assigned to the original model will have their
modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
Arguments:
model (nn.Module): model on which we will extract the features
return_layers (Dict[name, new_name]): a dict containing the names
of the modules for which the activations will be returned as
the key of the dict, and the value of the dict is the name
of the returned activation (which the user can specify).
concat (bool): whether to concatenate intermediate features that are lists or tuples
model (nn.Module): model from which we will extract the features
out_indices (tuple[int]): model output indices to extract features for
out_map (sequence): list or tuple specifying desired return id for each out index,
otherwise str(index) is used
feature_concat (bool): whether to concatenate intermediate features that are lists or tuples
vs select element [0]
flatten_sequential (bool): whether to flatten sequential modules assigned to model
"""
def __init__(self, model, return_layers, concat=False, flatten_sequential=False):
def __init__(
self, model,
out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
super(FeatureDictNet, self).__init__()
self.feature_info = _get_feature_info(model, out_indices)
self.concat = feature_concat
self.return_layers = {}
self.concat = concat
return_layers = _get_return_layers(self.feature_info, out_map)
modules = _module_list(model, flatten_sequential=flatten_sequential)
remaining = set(return_layers.keys())
layers = OrderedDict()
for new_name, old_name, module in modules:
layers[new_name] = module
if old_name in remaining:
self.return_layers[new_name] = return_layers[old_name]
# return id has to be consistently str type for torchscript
self.return_layers[new_name] = str(return_layers[old_name])
remaining.remove(old_name)
if not remaining:
break
assert not remaining and len(self.return_layers) == len(return_layers), \
f'Return layers ({remaining}) are not present in model'
super(LayerGetterDict, self).__init__(layers)
self.update(layers)
def forward(self, x) -> Dict[Any, torch.Tensor]:
def _collect(self, x) -> (Dict[str, torch.Tensor]):
out = OrderedDict()
for name, module in self.items():
x = module(x)
@@ -213,131 +204,74 @@ class LayerGetterDict(nn.ModuleDict):
out[out_id] = x
return out
def forward(self, x) -> Dict[str, torch.Tensor]:
return self._collect(x)
class LayerGetterList(nn.Sequential):
"""
Module wrapper that returns intermediate layers from a model as a list
Originally based on concepts from IntermediateLayerGetter at
https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py
It has a strong assumption that the modules have been registered into the model in the same
order as they are used. This means that one should **not** reuse the same nn.Module twice
in the forward if you want this to work.
Additionally, it is only able to query submodules that are directly assigned to the model
class (`model.feature1`) or at most one Sequential container deep (`model.features.1`) so
long as `features` is a sequential container assigned to the model and flatten_sequential=True.
All Sequential containers that are directly assigned to the original model will have their
modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
Arguments:
model (nn.Module): model on which we will extract the features
return_layers (Dict[name, new_name]): a dict containing the names
of the modules for which the activations will be returned as
the key of the dict, and the value of the dict is the name
of the returned activation (which the user can specify).
concat (bool): whether to concatenate intermediate features that are lists or tuples
vs select element [0]
flatten_sequential (bool): whether to flatten sequential modules assigned to model
class FeatureListNet(FeatureDictNet):
""" Feature extractor with list return
"""
def __init__(self, model, return_layers, concat=False, flatten_sequential=False):
super(LayerGetterList, self).__init__()
self.return_layers = {}
self.concat = concat
modules = _module_list(model, flatten_sequential=flatten_sequential)
remaining = set(return_layers.keys())
for new_name, orig_name, module in modules:
self.add_module(new_name, module)
if orig_name in remaining:
self.return_layers[new_name] = return_layers[orig_name]
remaining.remove(orig_name)
if not remaining:
break
assert not remaining and len(self.return_layers) == len(return_layers), \
f'Return layers ({remaining}) are not present in model'
def forward(self, x) -> List[torch.Tensor]:
out = []
for name, module in self.named_children():
x = module(x)
if name in self.return_layers:
if isinstance(x, (tuple, list)):
# If model tap is a tuple or list, concat or select first element
# FIXME this may need to be more generic / flexible for some nets
out.append(torch.cat(x, 1) if self.concat else x[0])
else:
out.append(x)
return out
def _resolve_feature_info(net, out_indices, feature_info=None):
if feature_info is None:
feature_info = getattr(net, 'feature_info')
if isinstance(feature_info, FeatureInfo):
return feature_info.from_other(out_indices)
elif isinstance(feature_info, (list, tuple)):
return FeatureInfo(net.feature_info, out_indices)
else:
assert False, "Provided feature_info is not valid"
def _get_return_layers(feature_info, out_map):
module_names = feature_info.module_name()
return_layers = {}
for i, name in enumerate(module_names):
return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
return return_layers
class FeatureNet(nn.Module):
""" FeatureNet
Wrap a model and extract features as specified by the out indices, the network
is partially re-built from contained modules using the LayerGetters.
Please read the docstrings of the LayerGetter classes, they will not work on all models.
See docstring for FeatureDictNet above; this class exists only to appease TorchScript typing constraints.
In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool.
"""
def __init__(
self, net,
out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, use_hooks=False,
feature_info=None, feature_concat=False, flatten_sequential=False):
super(FeatureNet, self).__init__()
self.feature_info = _resolve_feature_info(net, out_indices, feature_info)
if use_hooks:
self.body = LayerGetterHooks(net, self.feature_info, out_as_dict=out_as_dict, out_map=out_map)
else:
return_layers = _get_return_layers(self.feature_info, out_map)
lg_args = dict(return_layers=return_layers, concat=feature_concat, flatten_sequential=flatten_sequential)
self.body = LayerGetterDict(net, **lg_args) if out_as_dict else LayerGetterList(net, **lg_args)
self, model,
out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
super(FeatureListNet, self).__init__(
model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat,
flatten_sequential=flatten_sequential)
def forward(self, x):
output = self.body(x)
return output
def forward(self, x) -> (List[torch.Tensor]):
return list(self._collect(x).values())
class FeatureHookNet(nn.Module):
class FeatureHookNet(nn.ModuleDict):
""" FeatureHookNet
Wrap a model and extract features specified by the out indices.
Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.
Features are extracted via hooks without modifying the underlying network in any way. If only
part of the model is used it is up to the caller to remove unneeded layers as this wrapper
does not rewrite and remove unused top-level modules like FeatureNet with LayerGetter.
If `no_rewrite` is True, features are extracted via hooks without modifying the underlying
network in any way.
If `no_rewrite` is False, the model will be re-written as in the
FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one.
FIXME this does not currently work with Torchscript, see FeatureHooks class
"""
def __init__(
self, net,
out_indices=(0, 1, 2, 3, 4), out_as_dict=False, out_map=None,
feature_info=None, feature_concat=False):
self, model,
out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False,
feature_concat=False, flatten_sequential=False, default_hook_type='forward'):
super(FeatureHookNet, self).__init__()
self.feature_info = _resolve_feature_info(net, out_indices, feature_info)
self.body = net
self.hooks = FeatureHooks(
self.feature_info, self.body.named_modules(), out_as_dict=out_as_dict, out_map=out_map)
assert not torch.jit.is_scripting()
self.feature_info = _get_feature_info(model, out_indices)
self.out_as_dict = out_as_dict
layers = OrderedDict()
hooks = []
if no_rewrite:
assert not flatten_sequential
if hasattr(model, 'reset_classifier'): # make sure classifier is removed?
model.reset_classifier(0)
layers['body'] = model
hooks.extend(self.feature_info)
else:
modules = _module_list(model, flatten_sequential=flatten_sequential)
remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
for f in self.feature_info}
for new_name, old_name, module in modules:
layers[new_name] = module
for fn, fm in module.named_modules(prefix=old_name):
if fn in remaining:
hooks.append(dict(module=fn, hook_type=remaining[fn]))
del remaining[fn]
if not remaining:
break
assert not remaining, f'Return layers ({remaining}) are not present in model'
self.update(layers)
self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
def forward(self, x):
self.body(x)
return self.hooks.get_output(x.device)
for name, module in self.items():
x = module(x)
out = self.hooks.get_output(x.device)
return out if self.out_as_dict else list(out.values())
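Since the hook mechanics are spread across the hunks above, here is a small self-contained sketch of the pattern FeatureHooks implements; the class and names below are illustrative stand-ins, not the timm API:

from collections import OrderedDict, defaultdict
import torch
import torch.nn as nn

class TinyHookCollector:
    """Illustrative stand-in for FeatureHooks: grab outputs of named modules."""
    def __init__(self, model, module_names):
        self._outputs = defaultdict(OrderedDict)   # device -> {name: tensor}
        modules = dict(model.named_modules())
        for name in module_names:
            modules[name].register_forward_hook(self._make_hook(name))

    def _make_hook(self, name):
        def hook(module, inputs, output):
            # key by device so parallel replicas don't overwrite each other
            self._outputs[output.device][name] = output
        return hook

    def get_output(self, device):
        out = self._outputs[device]
        self._outputs[device] = OrderedDict()      # clear after reading
        return out

# usage sketch
net = nn.Sequential(OrderedDict(stem=nn.Conv2d(3, 8, 3), act=nn.ReLU()))
grabber = TinyHookCollector(net, ['stem'])
x = torch.randn(1, 3, 16, 16)
net(x)
print(list(grabber.get_output(x.device).keys()))   # ['stem']

Returning a Dict[str, Tensor] from get_output mirrors the change above; what still blocks scripting is hook registration itself, which TorchScript did not support at the time.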

@@ -252,7 +252,7 @@ class Xception65(nn.Module):
def _create_gluon_xception(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
Xception65, variant, pretrained, default_cfg=default_cfgs[variant],
feature_cfg=dict(use_hooks=True), **kwargs)
feature_cfg=dict(feature_cls='hook'), **kwargs)
@register_model

@@ -8,7 +8,7 @@ import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from .features import FeatureNet, FeatureHookNet
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
from .layers import Conv2dSame
@@ -234,15 +234,15 @@ def build_model_with_cfg(
filter_fn=pretrained_filter_fn, strict=pretrained_strict)
if features:
feature_cls = feature_cfg.pop('feature_cls', FeatureNet)
if isinstance(feature_cls, str):
feature_cls = feature_cls.lower()
if feature_cls == 'hook' or feature_cls == 'featurehooknet':
feature_cls = FeatureHookNet
else:
assert False, f'Unknown feature class {feature_cls}'
if feature_cls == FeatureHookNet and hasattr(model, 'reset_classifier'):
model.reset_classifier(0)
feature_cls = FeatureListNet
if 'feature_cls' in feature_cfg:
feature_cls = feature_cfg.pop('feature_cls')
if isinstance(feature_cls, str):
feature_cls = feature_cls.lower()
if 'hook' in feature_cls:
feature_cls = FeatureHookNet
else:
assert False, f'Unknown feature class {feature_cls}'
model = feature_cls(model, **feature_cfg)
return model
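As a usage-level illustration of this dispatch, a hedged sketch (model name and out_indices are illustrative): features_only now defaults to the list-returning wrapper, with out_indices forwarded through feature_cfg:

import torch
import timm

# Default path: FeatureListNet partially re-builds the model from its modules.
model = timm.create_model('resnet18', features_only=True, out_indices=(1, 2, 3, 4))
print(type(model).__name__)         # FeatureListNet
feats = model(torch.randn(1, 3, 224, 224))
print([f.shape[1] for f in feats])  # e.g. [64, 128, 256, 512]

Models whose entrypoints register feature_cfg=dict(feature_cls='hook'), like the Xception variants in this commit, resolve to FeatureHookNet instead.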

@@ -211,7 +211,8 @@ class MobileNetV3Features(nn.Module):
return features
else:
self.blocks(x)
return self.feature_hooks.get_output(x.device)
out = self.feature_hooks.get_output(x.device)
return list(out.values())
def _create_mnv3(model_kwargs, variant, pretrained=False):
@@ -220,6 +221,7 @@ def _create_mnv3(model_kwargs, variant, pretrained=False):
model_kwargs.pop('num_classes', 0)
model_kwargs.pop('num_features', 0)
model_kwargs.pop('head_conv', None)
model_kwargs.pop('head_bias', None)
model_cls = MobileNetV3Features
else:
load_strict = True

@@ -554,7 +554,7 @@ class NASNetALarge(nn.Module):
def _create_nasnet(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
NASNetALarge, variant, pretrained, default_cfg=default_cfgs[variant],
feature_cfg=dict(feature_cls='hook'), # not possible to re-write this model, must use FeatureHookNet
feature_cfg=dict(feature_cls='hook', no_rewrite=True), # not possible to re-write this model
**kwargs)

@@ -337,7 +337,7 @@ class PNASNet5Large(nn.Module):
def _create_pnasnet(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
PNASNet5Large, variant, pretrained, default_cfg=default_cfgs[variant],
feature_cfg=dict(feature_cls='hook'), # not possible to re-write this model, must use FeatureHookNet
feature_cfg=dict(feature_cls='hook', no_rewrite=True), # not possible to re-write this model
**kwargs)
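Since NASNet/PNASNet cannot be re-built module by module, no_rewrite=True keeps the whole network as a single child and relies purely on hooks. A hedged usage sketch (model name illustrative, no pretrained weights loaded):

import torch
import timm

model = timm.create_model('pnasnet5large', features_only=True).eval()
print(type(model).__name__)              # FeatureHookNet
with torch.no_grad():
    feats = model(torch.randn(1, 3, 331, 331))
print([f.shape[1] for f in feats])       # one entry per hooked feature level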

@@ -74,6 +74,29 @@ class SequentialList(nn.Sequential):
return x
class SelectSeq(nn.Module):
def __init__(self, mode='index', index=0):
super(SelectSeq, self).__init__()
self.mode = mode
self.index = index
@torch.jit._overload_method # noqa: F811
def forward(self, x):
# type: (List[torch.Tensor]) -> (torch.Tensor)
pass
@torch.jit._overload_method # noqa: F811
def forward(self, x):
# type: (Tuple[torch.Tensor]) -> (torch.Tensor)
pass
def forward(self, x) -> torch.Tensor:
if self.mode == 'index':
return x[self.index]
else:
return torch.cat(x, dim=1)
def conv_bn(in_chs, out_chs, k=3, stride=1, padding=None, dilation=1):
if padding is None:
padding = ((stride - 1) + dilation * (k - 1)) // 2
@@ -137,8 +160,10 @@ class SelecSLS(nn.Module):
self.stem = conv_bn(in_chans, 32, stride=2)
self.features = SequentialList(*[cfg['block'](*block_args) for block_args in cfg['features']])
self.from_seq = SelectSeq() # from List[tensor] -> Tensor in module compatible way
self.head = nn.Sequential(*[conv_bn(*conv_args) for conv_args in cfg['head']])
self.num_features = cfg['num_features']
self.feature_info = cfg['feature_info']
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
@@ -165,7 +190,7 @@ class SelecSLS(nn.Module):
def forward_features(self, x):
x = self.stem(x)
x = self.features(x)
x = self.head(x[0])
x = self.head(self.from_seq(x))
return x
def forward(self, x):
@@ -297,6 +322,7 @@ def _create_selecsls(variant, pretrained, model_kwargs):
])
else:
raise ValueError('Invalid net configuration ' + variant + ' !!!')
cfg['feature_info'] = feature_info
# this model can do 6 feature levels by default, unlike most others, leave as 0-4 to avoid surprises?
return build_model_with_cfg(
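The @torch.jit._overload_method pair on SelectSeq is what lets a scripted module accept either a List or a Tuple of tensors. A minimal, self-contained sketch of the same pattern, using an illustrative class rather than the diff's SelectSeq:

from typing import List, Tuple
import torch
import torch.nn as nn

class PickFirst(nn.Module):
    """Illustrative: TorchScript-friendly selection from a List/Tuple of tensors."""
    @torch.jit._overload_method  # noqa: F811
    def forward(self, x):
        # type: (List[torch.Tensor]) -> (torch.Tensor)
        pass

    @torch.jit._overload_method  # noqa: F811
    def forward(self, x):
        # type: (Tuple[torch.Tensor]) -> (torch.Tensor)
        pass

    def forward(self, x) -> torch.Tensor:
        return x[0]

scripted = torch.jit.script(PickFirst())
print(scripted([torch.randn(1, 8), torch.randn(1, 16)]).shape)  # torch.Size([1, 8])

The overloads give TorchScript a concrete input type for each call pattern while eager mode keeps a single forward.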

@@ -160,6 +160,9 @@ class Bottleneck(nn.Module):
conv2d_iabn(planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3),
aa_layer(channels=planes, filt_size=3, stride=2))
reduce_layer_planes = max(planes * self.expansion // 8, 64)
self.se = FastSEModule(planes, reduce_layer_planes) if use_se else None
self.conv3 = conv2d_iabn(
planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity")
@@ -167,9 +170,6 @@ class Bottleneck(nn.Module):
self.downsample = downsample
self.stride = stride
reduce_layer_planes = max(planes * self.expansion // 8, 64)
self.se = FastSEModule(planes, reduce_layer_planes) if use_se else None
def forward(self, x):
if self.downsample is not None:
residual = self.downsample(x)
@@ -225,8 +225,8 @@ class TResNet(nn.Module):
dict(num_chs=self.planes, reduction=2, module=''), # Not with S2D?
dict(num_chs=self.planes, reduction=4, module='body.layer1'),
dict(num_chs=self.planes * 2, reduction=8, module='body.layer2'),
dict(num_chs=self.planes * 4, reduction=16, module='body.layer3'),
dict(num_chs=self.planes * 8, reduction=32, module='body.layer4'),
dict(num_chs=self.planes * 4 * Bottleneck.expansion, reduction=16, module='body.layer3'),
dict(num_chs=self.planes * 8 * Bottleneck.expansion, reduction=32, module='body.layer4'),
]
# head
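The corrected num_chs values account for the Bottleneck expansion used in layers 3 and 4 (layers 1 and 2 use BasicBlock-style blocks with expansion 1, so only the Bottleneck stages needed the fix). A quick arithmetic check with illustrative TResNet-M numbers:

# Illustrative check of the corrected channel math (planes=64 as in TResNet-M)
planes, expansion = 64, 4                 # Bottleneck.expansion == 4
print(planes * 4 * expansion)             # body.layer3 -> 1024 channels
print(planes * 8 * expansion)             # body.layer4 -> 2048 channels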

@@ -228,7 +228,7 @@ class Xception(nn.Module):
def _xception(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
Xception, variant, pretrained, default_cfg=default_cfgs[variant],
feature_cfg=dict(use_hooks=True), **kwargs)
feature_cfg=dict(feature_cls='hook'), **kwargs)
@register_model

@@ -174,7 +174,7 @@ class XceptionAligned(nn.Module):
def _xception(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
XceptionAligned, variant, pretrained, default_cfg=default_cfgs[variant],
feature_cfg=dict(flatten_sequential=True, use_hooks=True), **kwargs)
feature_cfg=dict(flatten_sequential=True, feature_cls='hook'), **kwargs)
@register_model

@@ -100,6 +100,8 @@ def validate(args):
# might as well try to validate something
args.pretrained = args.pretrained or not args.checkpoint
args.prefetcher = not args.no_prefetcher
if args.legacy_jit:
set_jit_legacy()
# create model
model = create_model(
@@ -119,8 +121,6 @@
model, test_time_pool = apply_test_time_pool(model, data_config, args)
if args.torchscript:
if args.legacy_jit:
set_jit_legacy()
torch.jit.optimized_execution(True)
model = torch.jit.script(model)
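As a quick way to exercise the non-hook feature path under TorchScript outside of validate.py, a hedged sketch (model name and resolution are illustrative):

import torch
import timm

# FeatureListNet (the non-hook path) is the case this commit makes scriptable;
# hook-based extractors still are not (see the FeatureHooks FIXME above).
model = timm.create_model('resnet18', features_only=True).eval()
scripted = torch.jit.script(model)
with torch.no_grad():
    feats = scripted(torch.randn(1, 3, 128, 128))
print([tuple(f.shape) for f in feats])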
