Cleanup FeatureInfo getters, add TF models sourced Xception41/65/71 weights

pull/175/head
Ross Wightman 5 years ago
parent 7ba5a384d3
commit 08016e839d

@ -441,7 +441,7 @@ class EfficientNetFeatures(nn.Module):
# Register feature extraction hooks with FeatureHooks helper
self.feature_hooks = None
if feature_location != 'bottleneck':
hooks = self.feature_info.get_by_key(keys=('module', 'hook_type'))
hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
def forward(self, x) -> List[torch.Tensor]:

@ -8,7 +8,7 @@ Hacked together by Ross Wightman
from collections import OrderedDict, defaultdict
from copy import deepcopy
from functools import partial
from typing import Dict, List, Tuple, Any
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
@ -30,42 +30,46 @@ class FeatureInfo:
def from_other(self, out_indices: Tuple[int]):
return FeatureInfo(deepcopy(self.info), out_indices)
def get(self, key, idx=None):
""" Get value by key at specified index (indices)
if idx == None, returns value for key at each output index
if idx is an integer, return value for that feature module index (ignoring output indices)
if idx is a list/tupple, return value for each module index (ignoring output indices)
"""
if idx is None:
return [self.info[i][key] for i in self.out_indices]
if isinstance(idx, (tuple, list)):
return [self.info[i][key] for i in idx]
else:
return self.info[idx][key]
def get_dicts(self, keys=None, idx=None):
""" return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
"""
if idx is None:
if keys is None:
return [self.info[i] for i in self.out_indices]
else:
return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
if isinstance(idx, (tuple, list)):
return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx]
else:
return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
def channels(self, idx=None):
""" feature channels accessor
if idx == None, returns feature channel count at each output index
if idx is an integer, return feature channel count for that feature module index
"""
if isinstance(idx, int):
return self.info[idx]['num_chs']
return [self.info[i]['num_chs'] for i in self.out_indices]
return self.get('num_chs', idx)
def reduction(self, idx=None):
""" feature reduction (output stride) accessor
if idx == None, returns feature reduction factor at each output index
if idx is an integer, return feature channel count at that feature module index
"""
if isinstance(idx, int):
return self.info[idx]['reduction']
return [self.info[i]['reduction'] for i in self.out_indices]
return self.get('reduction', idx)
def module_name(self, idx=None):
""" feature module name accessor
if idx == None, returns feature module name at each output index
if idx is an integer, return feature module name at that feature module index
"""
if isinstance(idx, int):
return self.info[idx]['module']
return [self.info[i]['module'] for i in self.out_indices]
def get_by_key(self, idx=None, keys=None):
""" return info dicts for specified keys (or all if None) at specified idx (or out_indices if None)
"""
if isinstance(idx, int):
return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
if keys is None:
return [self.info[i] for i in self.out_indices]
else:
return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
return self.get('module', idx)
def __getitem__(self, item):
return self.info[item]
@ -253,11 +257,11 @@ class FeatureHookNet(nn.ModuleDict):
if hasattr(model, 'reset_classifier'): # make sure classifier is removed?
model.reset_classifier(0)
layers['body'] = model
hooks.extend(self.feature_info)
hooks.extend(self.feature_info.get_dicts())
else:
modules = _module_list(model, flatten_sequential=flatten_sequential)
remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
for f in self.feature_info}
for f in self.feature_info.get_dicts()}
for new_name, old_name, module in modules:
layers[new_name] = module
for fn, fm in module.named_modules(prefix=old_name):

@ -186,7 +186,7 @@ class MobileNetV3Features(nn.Module):
# Register feature extraction hooks with FeatureHooks helper
self.feature_hooks = None
if feature_location != 'bottleneck':
hooks = self.feature_info.get_by_key(keys=('module', 'hook_type'))
hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
def forward(self, x) -> List[torch.Tensor]:

@ -31,9 +31,12 @@ def _cfg(url='', **kwargs):
default_cfgs = dict(
xception41=_cfg(url=''),
xception65=_cfg(url=''),
xception71=_cfg(url=''),
xception41=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_41-e6439c97.pth'),
xception65=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_65-c9ae96e8.pth'),
xception71=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_71-8eec7df1.pth'),
)
@ -216,7 +219,6 @@ def xception65(pretrained=False, **kwargs):
return _xception('xception65', pretrained=pretrained, **model_args)
@register_model
def xception71(pretrained=False, **kwargs):
""" Modified Aligned Xception-71

Loading…
Cancel
Save