|
|
|
@ -8,7 +8,7 @@ Hacked together by Ross Wightman
|
|
|
|
|
from collections import OrderedDict, defaultdict
|
|
|
|
|
from copy import deepcopy
|
|
|
|
|
from functools import partial
|
|
|
|
|
from typing import Dict, List, Tuple, Any
|
|
|
|
|
from typing import Dict, List, Tuple
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
@ -30,42 +30,46 @@ class FeatureInfo:
|
|
|
|
|
def from_other(self, out_indices: Tuple[int]):
|
|
|
|
|
return FeatureInfo(deepcopy(self.info), out_indices)
|
|
|
|
|
|
|
|
|
|
def get(self, key, idx=None):
|
|
|
|
|
""" Get value by key at specified index (indices)
|
|
|
|
|
if idx == None, returns value for key at each output index
|
|
|
|
|
if idx is an integer, return value for that feature module index (ignoring output indices)
|
|
|
|
|
if idx is a list/tupple, return value for each module index (ignoring output indices)
|
|
|
|
|
"""
|
|
|
|
|
if idx is None:
|
|
|
|
|
return [self.info[i][key] for i in self.out_indices]
|
|
|
|
|
if isinstance(idx, (tuple, list)):
|
|
|
|
|
return [self.info[i][key] for i in idx]
|
|
|
|
|
else:
|
|
|
|
|
return self.info[idx][key]
|
|
|
|
|
|
|
|
|
|
def get_dicts(self, keys=None, idx=None):
|
|
|
|
|
""" return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
|
|
|
|
|
"""
|
|
|
|
|
if idx is None:
|
|
|
|
|
if keys is None:
|
|
|
|
|
return [self.info[i] for i in self.out_indices]
|
|
|
|
|
else:
|
|
|
|
|
return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
|
|
|
|
|
if isinstance(idx, (tuple, list)):
|
|
|
|
|
return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx]
|
|
|
|
|
else:
|
|
|
|
|
return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
|
|
|
|
|
|
|
|
|
|
def channels(self, idx=None):
|
|
|
|
|
""" feature channels accessor
|
|
|
|
|
if idx == None, returns feature channel count at each output index
|
|
|
|
|
if idx is an integer, return feature channel count for that feature module index
|
|
|
|
|
"""
|
|
|
|
|
if isinstance(idx, int):
|
|
|
|
|
return self.info[idx]['num_chs']
|
|
|
|
|
return [self.info[i]['num_chs'] for i in self.out_indices]
|
|
|
|
|
return self.get('num_chs', idx)
|
|
|
|
|
|
|
|
|
|
def reduction(self, idx=None):
|
|
|
|
|
""" feature reduction (output stride) accessor
|
|
|
|
|
if idx == None, returns feature reduction factor at each output index
|
|
|
|
|
if idx is an integer, return feature channel count at that feature module index
|
|
|
|
|
"""
|
|
|
|
|
if isinstance(idx, int):
|
|
|
|
|
return self.info[idx]['reduction']
|
|
|
|
|
return [self.info[i]['reduction'] for i in self.out_indices]
|
|
|
|
|
return self.get('reduction', idx)
|
|
|
|
|
|
|
|
|
|
def module_name(self, idx=None):
|
|
|
|
|
""" feature module name accessor
|
|
|
|
|
if idx == None, returns feature module name at each output index
|
|
|
|
|
if idx is an integer, return feature module name at that feature module index
|
|
|
|
|
"""
|
|
|
|
|
if isinstance(idx, int):
|
|
|
|
|
return self.info[idx]['module']
|
|
|
|
|
return [self.info[i]['module'] for i in self.out_indices]
|
|
|
|
|
|
|
|
|
|
def get_by_key(self, idx=None, keys=None):
|
|
|
|
|
""" return info dicts for specified keys (or all if None) at specified idx (or out_indices if None)
|
|
|
|
|
"""
|
|
|
|
|
if isinstance(idx, int):
|
|
|
|
|
return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
|
|
|
|
|
if keys is None:
|
|
|
|
|
return [self.info[i] for i in self.out_indices]
|
|
|
|
|
else:
|
|
|
|
|
return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
|
|
|
|
|
return self.get('module', idx)
|
|
|
|
|
|
|
|
|
|
def __getitem__(self, item):
|
|
|
|
|
return self.info[item]
|
|
|
|
@ -253,11 +257,11 @@ class FeatureHookNet(nn.ModuleDict):
|
|
|
|
|
if hasattr(model, 'reset_classifier'): # make sure classifier is removed?
|
|
|
|
|
model.reset_classifier(0)
|
|
|
|
|
layers['body'] = model
|
|
|
|
|
hooks.extend(self.feature_info)
|
|
|
|
|
hooks.extend(self.feature_info.get_dicts())
|
|
|
|
|
else:
|
|
|
|
|
modules = _module_list(model, flatten_sequential=flatten_sequential)
|
|
|
|
|
remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
|
|
|
|
|
for f in self.feature_info}
|
|
|
|
|
for f in self.feature_info.get_dicts()}
|
|
|
|
|
for new_name, old_name, module in modules:
|
|
|
|
|
layers[new_name] = module
|
|
|
|
|
for fn, fm in module.named_modules(prefix=old_name):
|
|
|
|
|