You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/timm/models/feature_hooks.py

47 lines
1.7 KiB

""" PyTorch Feature Hook Helper
This class helps gather features from a network via hooks specified on the module name.
Hacked together by Ross Wightman
"""
from collections import defaultdict, OrderedDict
from functools import partial, partialmethod
from typing import List, Union

import torch
class FeatureHooks:
    """Collect intermediate tensors from a network via module hooks.

    Each hook spec is a mapping with a ``'module'`` key (a name present in
    ``named_modules``) and an optional ``'hook_type'`` key (``'forward'`` or
    ``'forward_pre'``). When a hooked module runs, the captured tensor is
    stored per device, keyed by the hook id, until :meth:`get_output` reads
    and clears it.

    Args:
        hooks: Sequence of hook spec mappings as described above.
        named_modules: Iterable of ``(name, module)`` pairs used to resolve
            the ``'module'`` names in ``hooks``.
        out_as_dict: If true, :meth:`get_output` returns the ordered mapping
            of hook id -> tensor instead of a list of tensors.
        out_map: Optional sequence giving the hook id for each entry of
            ``hooks`` (indexed by position). When empty/``None`` the module
            name is used as the hook id.
        default_hook_type: Hook type used for specs without ``'hook_type'``.

    Raises:
        KeyError: If a spec lacks ``'module'`` or names an unknown module.
        AssertionError: If a hook type is neither ``'forward'`` nor
            ``'forward_pre'``.
    """

    def __init__(self, hooks, named_modules, out_as_dict=False, out_map=None, default_hook_type='forward'):
        # Create the storage before any hook is attached so that a registered
        # hook always has somewhere to write.
        self._feature_outputs = defaultdict(OrderedDict)
        self.out_as_dict = out_as_dict

        modules = dict(named_modules)

        # First pass: resolve and validate every spec without touching the
        # modules. If any spec is invalid, nothing has been registered yet,
        # so the model is not left with stray hooks.
        pending = []
        for i, h in enumerate(hooks):
            hook_name = h['module']
            m = modules[hook_name]
            hook_id = out_map[i] if out_map else hook_name
            hook_type = h.get('hook_type', default_hook_type)
            if hook_type == 'forward_pre':
                register = m.register_forward_pre_hook
            elif hook_type == 'forward':
                register = m.register_forward_hook
            else:
                # Explicit raise rather than `assert False`: asserts are
                # removed under `python -O`, which would silently skip the
                # unsupported hook instead of reporting it.
                raise AssertionError("Unsupported hook type")
            pending.append((register, partial(self._collect_output_hook, hook_id)))

        # Second pass: every spec is valid, attach the hooks in order.
        for register, hook_fn in pending:
            register(hook_fn)

    def _collect_output_hook(self, hook_id, *args):
        """Hook callback: stash the tensor of interest under ``hook_id``.

        The tensor we want is the last positional argument: the output for a
        forward hook, the input for a forward-pre hook.
        """
        x = args[-1]
        if isinstance(x, tuple):
            x = x[0]  # unwrap input tuple
        # Bucket by device so readers can fetch features per device.
        self._feature_outputs[x.device][hook_id] = x

    def get_output(self, device) -> Union[OrderedDict, List[torch.Tensor]]:
        """Return the features captured on ``device`` and clear them.

        Returns:
            The ordered mapping of hook id -> tensor when ``out_as_dict`` is
            set, otherwise a list of the tensors in capture order.
        """
        collected = self._feature_outputs[device]
        self._feature_outputs[device] = OrderedDict()  # clear after reading
        if self.out_as_dict:
            return collected
        return list(collected.values())