You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/timm/models/feature_hooks.py

45 lines
1.5 KiB

""" PyTorch Feature Hook Helper
This class helps gather features from a network via hooks specified on the module name.
Hacked together by Ross Wightman
"""
from collections import defaultdict, OrderedDict
from functools import partial
from typing import Dict, List, Union

import torch
class FeatureHooks:
    """Collect intermediate features from a network via forward / forward-pre hooks.

    Hooks are registered on submodules looked up by name. Each time a hooked module
    runs, the captured tensor is stored per device under the hook's module name;
    ``get_output`` returns (and clears) everything captured for a given device.

    Args:
        hooks: Iterable of dicts, each with a ``'module'`` key (the submodule name, as
            yielded by ``named_modules``) and a ``'hook_type'`` key, either
            ``'forward'`` (capture the module's output) or ``'forward_pre'``
            (capture the module's input).
        named_modules: Iterable of ``(name, module)`` pairs, e.g.
            ``model.named_modules()``.
        output_as_dict: If True, ``get_output`` returns an ordered
            ``name -> tensor`` mapping; otherwise a list of tensors in the order
            they were first captured.

    Raises:
        KeyError: If a hook references a module name not present in ``named_modules``.
        ValueError: If a hook has an unsupported ``'hook_type'``.
    """

    _SUPPORTED_HOOK_TYPES = ('forward', 'forward_pre')

    def __init__(self, hooks, named_modules, output_as_dict: bool = False):
        # Storage exists before any hook is registered, so it is always present
        # by the time a hook can fire.
        self._feature_outputs = defaultdict(OrderedDict)
        self.output_as_dict = output_as_dict

        modules = dict(named_modules)

        # Resolve and validate every hook before registering any of them, so a bad
        # entry late in the list does not leave earlier hooks attached to the model.
        resolved = []
        for h in hooks:
            hook_name = h['module']
            m = modules[hook_name]
            hook_type = h['hook_type']
            if hook_type not in self._SUPPORTED_HOOK_TYPES:
                # Explicit raise rather than `assert False`: asserts are stripped
                # under `python -O`, which would silently skip this hook.
                raise ValueError(f"Unsupported hook type: {hook_type!r}")
            resolved.append((hook_name, m, hook_type))

        for hook_name, m, hook_type in resolved:
            hook_fn = partial(self._collect_output_hook, hook_name)
            if hook_type == 'forward_pre':
                m.register_forward_pre_hook(hook_fn)
            else:  # 'forward'
                m.register_forward_hook(hook_fn)

    def _collect_output_hook(self, name, *args):
        """Store the tensor handed to a hook under ``name`` for its device.

        ``name`` is bound via ``functools.partial``; PyTorch supplies the rest:
        ``(module, input)`` for forward-pre hooks and ``(module, input, output)``
        for forward hooks. Returns None, so the module's input/output is unchanged.
        """
        x = args[-1]  # output for forward hooks, input for forward-pre hooks
        if isinstance(x, tuple):
            x = x[0]  # forward-pre inputs arrive as a tuple; keep the first element
        # Keyed by device - presumably so model replicas on different devices
        # (e.g. under DataParallel) don't overwrite each other's features.
        self._feature_outputs[x.device][name] = x

    def get_output(self, device) -> Union[Dict[str, torch.Tensor], List[torch.Tensor]]:
        """Return and clear the features captured on ``device``.

        Returns:
            An ordered ``name -> tensor`` dict when ``output_as_dict`` is True,
            otherwise a list of the captured tensors. Empty if nothing was captured
            on ``device`` since the last call.
        """
        # pop() both reads and clears this device's captures in one step.
        captured = self._feature_outputs.pop(device, OrderedDict())
        if self.output_as_dict:
            return captured
        return list(captured.values())