You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
45 lines
1.5 KiB
45 lines
1.5 KiB
""" PyTorch Feature Hook Helper

This class helps gather features from a network via hooks registered on modules selected by module name.

Hacked together by Ross Wightman
"""
|
|
from collections import OrderedDict, defaultdict
from functools import partial
from typing import Dict, List, Union

import torch
|
|
|
|
|
|
class FeatureHooks:
|
|
|
|
def __init__(self, hooks, named_modules, output_as_dict=False):
|
|
# setup feature hooks
|
|
modules = {k: v for k, v in named_modules}
|
|
for h in hooks:
|
|
hook_name = h['module']
|
|
m = modules[hook_name]
|
|
hook_fn = partial(self._collect_output_hook, hook_name)
|
|
if h['hook_type'] == 'forward_pre':
|
|
m.register_forward_pre_hook(hook_fn)
|
|
elif h['hook_type'] == 'forward':
|
|
m.register_forward_hook(hook_fn)
|
|
else:
|
|
assert False, "Unsupported hook type"
|
|
self._feature_outputs = defaultdict(OrderedDict)
|
|
self.output_as_dict = output_as_dict
|
|
|
|
def _collect_output_hook(self, name, *args):
|
|
x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre
|
|
if isinstance(x, tuple):
|
|
x = x[0] # unwrap input tuple
|
|
self._feature_outputs[x.device][name] = x
|
|
|
|
def get_output(self, device) -> List[torch.tensor]:
|
|
if self.output_as_dict:
|
|
output = self._feature_outputs[device]
|
|
else:
|
|
output = list(self._feature_outputs[device].values())
|
|
self._feature_outputs[device] = OrderedDict() # clear after reading
|
|
return output
|