pytorch-image-models/timm/models/features.py

""" PyTorch Feature Extraction Helpers
A collection of classes, functions, modules to help extract features from models
and provide a common interface for describing them.
Hacked together by Ross Wightman
"""
from collections import OrderedDict
from typing import Dict, List, Tuple, Any
from copy import deepcopy

import torch
import torch.nn as nn

from .feature_hooks import FeatureHooks


class FeatureInfo:

    def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
        prev_reduction = 1
        for fi in feature_info:
            # sanity check the mandatory fields, there may be additional fields depending on the model
            assert 'num_chs' in fi and fi['num_chs'] > 0
            assert 'reduction' in fi and fi['reduction'] >= prev_reduction
            prev_reduction = fi['reduction']
            assert 'module' in fi
        self.out_indices = out_indices
        self.info = feature_info

    def from_other(self, out_indices: Tuple[int]):
        return FeatureInfo(deepcopy(self.info), out_indices)

    def channels(self, idx=None):
        """ feature channels accessor
        if idx == None, returns feature channel count at each output index
        if idx is an integer, return feature channel count for that feature module index
        """
        if isinstance(idx, int):
            return self.info[idx]['num_chs']
        return [self.info[i]['num_chs'] for i in self.out_indices]

    def reduction(self, idx=None):
        """ feature reduction (output stride) accessor
        if idx == None, returns feature reduction factor at each output index
        if idx is an integer, return feature reduction factor at that feature module index
        """
        if isinstance(idx, int):
            return self.info[idx]['reduction']
        return [self.info[i]['reduction'] for i in self.out_indices]

    def module_name(self, idx=None):
        """ feature module name accessor
        if idx == None, returns feature module name at each output index
        if idx is an integer, return feature module name at that feature module index
        """
        if isinstance(idx, int):
            return self.info[idx]['module']
        return [self.info[i]['module'] for i in self.out_indices]

    def get_by_key(self, idx=None, keys=None):
        """ return info dicts for specified keys (or all if None) at specified idx (or out_indices if None)
        """
        if isinstance(idx, int):
            return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
        if keys is None:
            return [self.info[i] for i in self.out_indices]
        else:
            return [{k: self.info[i][k] for k in keys} for i in self.out_indices]

    def __getitem__(self, item):
        return self.info[item]

    def __len__(self):
        return len(self.info)
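
# A minimal FeatureInfo usage sketch (illustrative only; the dict values below are
# made up for the example and do not come from any particular model):
#
#   info = FeatureInfo([
#       dict(num_chs=64, reduction=4, module='layer1'),
#       dict(num_chs=128, reduction=8, module='layer2'),
#   ], out_indices=(0, 1))
#   info.channels()      # -> [64, 128]
#   info.reduction()     # -> [4, 8]
#   info.module_name(1)  # -> 'layer2'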


def _module_list(module, flatten_sequential=False):
    # a yield/iter would be better for this but wouldn't be compatible with torchscript
    ml = []
    for name, module in module.named_children():
        if flatten_sequential and isinstance(module, nn.Sequential):
            # first level of Sequential containers is flattened into containing model
            for child_name, child_module in module.named_children():
                combined = [name, child_name]
                ml.append(('_'.join(combined), '.'.join(combined), child_module))
        else:
            ml.append((name, name, module))
    return ml
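
# Illustrative sketch of the (new_name, old_name, module) tuples produced with
# flatten_sequential=True (the toy model below is hypothetical):
#
#   m = nn.Sequential()
#   m.features = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
#   _module_list(m, flatten_sequential=True)
#   # -> [('features_0', 'features.0', Conv2d(...)), ('features_1', 'features.1', ReLU())]
#
# The first name is attribute-safe (no dots) for re-registering the child, while the
# second matches what model.named_modules() reports.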


class LayerGetterHooks(nn.ModuleDict):
    """ LayerGetterHooks
    Keeps only the (optionally flattened) top-level modules needed to reach the feature
    modules named in feature_info and registers feature hooks on them. Layers past the
    last hooked module are dropped.
    """

    def __init__(self, model, feature_info, flatten_sequential=False, out_as_dict=False, out_map=None,
                 default_hook_type='forward'):
        modules = _module_list(model, flatten_sequential=flatten_sequential)
        remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type for f in feature_info}
        layers = OrderedDict()
        hooks = []
        for new_name, old_name, module in modules:
            layers[new_name] = module
            for fn, fm in module.named_modules(prefix=old_name):
                if fn in remaining:
                    hooks.append(dict(module=fn, hook_type=remaining[fn]))
                    del remaining[fn]
            if not remaining:
                break
        assert not remaining, f'Return layers ({remaining}) are not present in model'
        super(LayerGetterHooks, self).__init__(layers)
        self.hooks = FeatureHooks(hooks, model.named_modules(), out_as_dict=out_as_dict, out_map=out_map)

    def forward(self, x) -> Dict[Any, torch.Tensor]:
        for name, module in self.items():
            x = module(x)
        return self.hooks.get_output(x.device)
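
# Usage sketch (illustrative; assumes FeatureHooks registers one hook per entry and
# that `model` is a hypothetical module containing submodules named below):
#
#   fi = [dict(module='layer1'), dict(module='layer2')]
#   body = LayerGetterHooks(model, fi)
#   feats = body(torch.randn(1, 3, 224, 224))  # hook outputs keyed per FeatureHooks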


class LayerGetterDict(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model as a dictionary

    Originally based on concepts from IntermediateLayerGetter at
    https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py

    It has a strong assumption that the modules have been registered into the model in the same
    order as they are used. This means that one should **not** reuse the same nn.Module twice
    in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly assigned to the model
    class (`model.feature1`) or at most one Sequential container deep (`model.features.1`, so
    long as `features` is a sequential container assigned to the model).

    All Sequential containers that are directly assigned to the original model will have their
    modules assigned to this module, with the name `model.features.1` being changed to `model.features_1`.

    Arguments:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).
        concat (bool): whether to concatenate intermediate features that are lists or tuples
            vs select element [0]
        flatten_sequential (bool): whether to flatten sequential modules assigned to model
    """

    def __init__(self, model, return_layers, concat=False, flatten_sequential=False):
        self.return_layers = {}
        self.concat = concat
        modules = _module_list(model, flatten_sequential=flatten_sequential)
        remaining = set(return_layers.keys())
        layers = OrderedDict()
        for new_name, old_name, module in modules:
            layers[new_name] = module
            if old_name in remaining:
                self.return_layers[new_name] = return_layers[old_name]
                remaining.remove(old_name)
            if not remaining:
                break
        assert not remaining and len(self.return_layers) == len(return_layers), \
            f'Return layers ({remaining}) are not present in model'
        super(LayerGetterDict, self).__init__(layers)

    def forward(self, x) -> Dict[Any, torch.Tensor]:
        out = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                out_id = self.return_layers[name]
                if isinstance(x, (tuple, list)):
                    # If model tap is a tuple or list, concat or select first element
                    # FIXME this may need to be more generic / flexible for some nets
                    out[out_id] = torch.cat(x, 1) if self.concat else x[0]
                else:
                    out[out_id] = x
        return out
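
# Usage sketch (illustrative, assuming torchvision is available; the shapes shown are
# for a standard ResNet-18 at 224x224 input):
#
#   import torchvision
#   net = torchvision.models.resnet18()
#   getter = LayerGetterDict(net, return_layers={'layer2': 'mid', 'layer4': 'deep'})
#   out = getter(torch.randn(1, 3, 224, 224))
#   out['mid'].shape   # torch.Size([1, 128, 28, 28])
#   out['deep'].shape  # torch.Size([1, 512, 7, 7])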


class LayerGetterList(nn.Sequential):
    """
    Module wrapper that returns intermediate layers from a model as a list

    Originally based on concepts from IntermediateLayerGetter at
    https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py

    It has a strong assumption that the modules have been registered into the model in the same
    order as they are used. This means that one should **not** reuse the same nn.Module twice
    in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly assigned to the model
    class (`model.feature1`) or at most one Sequential container deep (`model.features.1`), so
    long as `features` is a sequential container assigned to the model and flatten_sequential=True.

    All Sequential containers that are directly assigned to the original model will have their
    modules assigned to this module, with the name `model.features.1` being changed to `model.features_1`.

    Arguments:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).
        concat (bool): whether to concatenate intermediate features that are lists or tuples
            vs select element [0]
        flatten_sequential (bool): whether to flatten sequential modules assigned to model
    """

    def __init__(self, model, return_layers, concat=False, flatten_sequential=False):
        super(LayerGetterList, self).__init__()
        self.return_layers = {}
        self.concat = concat
        modules = _module_list(model, flatten_sequential=flatten_sequential)
        remaining = set(return_layers.keys())
        for new_name, orig_name, module in modules:
            self.add_module(new_name, module)
            if orig_name in remaining:
                self.return_layers[new_name] = return_layers[orig_name]
                remaining.remove(orig_name)
            if not remaining:
                break
        assert not remaining and len(self.return_layers) == len(return_layers), \
            f'Return layers ({remaining}) are not present in model'

    def forward(self, x) -> List[torch.Tensor]:
        out = []
        for name, module in self.named_children():
            x = module(x)
            if name in self.return_layers:
                if isinstance(x, (tuple, list)):
                    # If model tap is a tuple or list, concat or select first element
                    # FIXME this may need to be more generic / flexible for some nets
                    out.append(torch.cat(x, 1) if self.concat else x[0])
                else:
                    out.append(x)
        return out
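
# Usage sketch (illustrative, same setup as the LayerGetterDict example above, but
# features come back positionally, in traversal order):
#
#   getter = LayerGetterList(net, return_layers={'layer2': 0, 'layer4': 1})
#   feats = getter(torch.randn(1, 3, 224, 224))  # [mid_tensor, deep_tensor]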


def _resolve_feature_info(net, out_indices, feature_info=None):
    if feature_info is None:
        feature_info = getattr(net, 'feature_info')
    if isinstance(feature_info, FeatureInfo):
        return feature_info.from_other(out_indices)
    elif isinstance(feature_info, (list, tuple)):
        # use the passed-in (or resolved) list, net itself may not have a feature_info attr
        return FeatureInfo(feature_info, out_indices)
    else:
        assert False, "Provided feature_info is not valid"


def _get_return_layers(feature_info, out_map):
    module_names = feature_info.module_name()
    return_layers = {}
    for i, name in enumerate(module_names):
        return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
    return return_layers
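
# Example (illustrative): for a FeatureInfo with module names ['layer1', 'layer2'] and
# out_indices=(0, 1), no out_map yields {'layer1': 0, 'layer2': 1}, while
# out_map=('a', 'b') yields {'layer1': 'a', 'layer2': 'b'}.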


class FeatureNet(nn.Module):
    """ FeatureNet

    Wrap a model and extract features as specified by the out indices. The network
    is partially re-built from contained modules using the LayerGetters.

    Please read the docstrings of the LayerGetter classes; they will not work on all models.
    """

    def __init__(
            self, net,
            out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, use_hooks=False,
            feature_info=None, feature_concat=False, flatten_sequential=False):
        super(FeatureNet, self).__init__()
        self.feature_info = _resolve_feature_info(net, out_indices, feature_info)
        if use_hooks:
            self.body = LayerGetterHooks(net, self.feature_info, out_as_dict=out_as_dict, out_map=out_map)
        else:
            return_layers = _get_return_layers(self.feature_info, out_map)
            lg_args = dict(return_layers=return_layers, concat=feature_concat, flatten_sequential=flatten_sequential)
            self.body = LayerGetterDict(net, **lg_args) if out_as_dict else LayerGetterList(net, **lg_args)

    def forward(self, x):
        output = self.body(x)
        return output
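
# Usage sketch (illustrative; the feature_info dicts below are hand-written for a
# standard torchvision ResNet-18 rather than coming from a timm model, which would
# normally supply its own net.feature_info):
#
#   import torchvision
#   net = torchvision.models.resnet18()
#   fi = [
#       dict(num_chs=64, reduction=4, module='layer1'),
#       dict(num_chs=128, reduction=8, module='layer2'),
#       dict(num_chs=256, reduction=16, module='layer3'),
#       dict(num_chs=512, reduction=32, module='layer4'),
#   ]
#   fnet = FeatureNet(net, out_indices=(0, 1, 2, 3), feature_info=fi)
#   feats = fnet(torch.randn(1, 3, 224, 224))  # list of 4 feature maps
#   [f.shape[1] for f in feats]                # -> [64, 128, 256, 512]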


class FeatureHookNet(nn.Module):
    """ FeatureHookNet

    Wrap a model and extract features specified by the out indices.

    Features are extracted via hooks without modifying the underlying network in any way. If only
    part of the model is needed, it is up to the caller to remove the unneeded layers, as this
    wrapper does not rewrite or remove unused top-level modules the way FeatureNet does with
    the LayerGetters.
    """

    def __init__(
            self, net,
            out_indices=(0, 1, 2, 3, 4), out_as_dict=False, out_map=None,
            feature_info=None, feature_concat=False):
        super(FeatureHookNet, self).__init__()
        self.feature_info = _resolve_feature_info(net, out_indices, feature_info)
        self.body = net
        self.hooks = FeatureHooks(
            self.feature_info, self.body.named_modules(), out_as_dict=out_as_dict, out_map=out_map)

    def forward(self, x):
        self.body(x)
        return self.hooks.get_output(x.device)
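
# Usage sketch (illustrative; assumes FeatureHooks hooks each feature_info entry and
# reuses the hand-written ResNet-18 feature_info `fi` from the FeatureNet example):
#
#   hnet = FeatureHookNet(net, out_indices=(0, 1, 2, 3), feature_info=fi)
#   feats = hnet(torch.randn(1, 3, 224, 224))  # hook-captured features
#
# Unlike FeatureNet, the full original model runs here (including avgpool/fc for a
# ResNet); the hooks only observe, they do not prune.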