""" PyTorch Feature Hook Helper This class helps gather features from a network via hooks specified on the module name. Hacked together by Ross Wightman """ import torch from collections import defaultdict, OrderedDict from functools import partial, partialmethod from typing import List class FeatureHooks: def __init__(self, hooks, named_modules, out_as_dict=False, out_map=None, default_hook_type='forward'): # setup feature hooks modules = {k: v for k, v in named_modules} for i, h in enumerate(hooks): hook_name = h['module'] m = modules[hook_name] hook_id = out_map[i] if out_map else hook_name hook_fn = partial(self._collect_output_hook, hook_id) hook_type = h['hook_type'] if 'hook_type' in h else default_hook_type if hook_type == 'forward_pre': m.register_forward_pre_hook(hook_fn) elif hook_type == 'forward': m.register_forward_hook(hook_fn) else: assert False, "Unsupported hook type" self._feature_outputs = defaultdict(OrderedDict) self.out_as_dict = out_as_dict def _collect_output_hook(self, hook_id, *args): x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre if isinstance(x, tuple): x = x[0] # unwrap input tuple self._feature_outputs[x.device][hook_id] = x def get_output(self, device) -> List[torch.tensor]: if self.out_as_dict: output = self._feature_outputs[device] else: output = list(self._feature_outputs[device].values()) self._feature_outputs[device] = OrderedDict() # clear after reading return output