From 08016e839d73a745be5cdac86ee825bd877defff Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 24 Jul 2020 17:59:21 -0700 Subject: [PATCH] Cleanup FeatureInfo getters, add TF models sourced Xception41/65/71 weights --- timm/models/efficientnet.py | 2 +- timm/models/features.py | 60 ++++++++++++++++++--------------- timm/models/mobilenetv3.py | 2 +- timm/models/xception_aligned.py | 10 +++--- 4 files changed, 40 insertions(+), 34 deletions(-) diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 21be2a96..08d1df7e 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -441,7 +441,7 @@ class EfficientNetFeatures(nn.Module): # Register feature extraction hooks with FeatureHooks helper self.feature_hooks = None if feature_location != 'bottleneck': - hooks = self.feature_info.get_by_key(keys=('module', 'hook_type')) + hooks = self.feature_info.get_dicts(keys=('module', 'hook_type')) self.feature_hooks = FeatureHooks(hooks, self.named_modules()) def forward(self, x) -> List[torch.Tensor]: diff --git a/timm/models/features.py b/timm/models/features.py index 757811af..7329851c 100644 --- a/timm/models/features.py +++ b/timm/models/features.py @@ -8,7 +8,7 @@ Hacked together by Ross Wightman from collections import OrderedDict, defaultdict from copy import deepcopy from functools import partial -from typing import Dict, List, Tuple, Any +from typing import Dict, List, Tuple import torch import torch.nn as nn @@ -30,42 +30,46 @@ class FeatureInfo: def from_other(self, out_indices: Tuple[int]): return FeatureInfo(deepcopy(self.info), out_indices) + def get(self, key, idx=None): + """ Get value by key at specified index (indices) + if idx == None, returns value for key at each output index + if idx is an integer, return value for that feature module index (ignoring output indices) + if idx is a list/tupple, return value for each module index (ignoring output indices) + """ + if idx is None: + return [self.info[i][key] for i in self.out_indices] + if isinstance(idx, (tuple, list)): + return [self.info[i][key] for i in idx] + else: + return self.info[idx][key] + + def get_dicts(self, keys=None, idx=None): + """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None) + """ + if idx is None: + if keys is None: + return [self.info[i] for i in self.out_indices] + else: + return [{k: self.info[i][k] for k in keys} for i in self.out_indices] + if isinstance(idx, (tuple, list)): + return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx] + else: + return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys} + def channels(self, idx=None): """ feature channels accessor - if idx == None, returns feature channel count at each output index - if idx is an integer, return feature channel count for that feature module index """ - if isinstance(idx, int): - return self.info[idx]['num_chs'] - return [self.info[i]['num_chs'] for i in self.out_indices] + return self.get('num_chs', idx) def reduction(self, idx=None): """ feature reduction (output stride) accessor - if idx == None, returns feature reduction factor at each output index - if idx is an integer, return feature channel count at that feature module index """ - if isinstance(idx, int): - return self.info[idx]['reduction'] - return [self.info[i]['reduction'] for i in self.out_indices] + return self.get('reduction', idx) def module_name(self, idx=None): """ feature module name accessor - if idx == None, returns feature module name at each output index - if idx is an integer, return feature module name at that feature module index - """ - if isinstance(idx, int): - return self.info[idx]['module'] - return [self.info[i]['module'] for i in self.out_indices] - - def get_by_key(self, idx=None, keys=None): - """ return info dicts for specified keys (or all if None) at specified idx (or out_indices if None) """ - if isinstance(idx, int): - return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys} - if keys is None: - return [self.info[i] for i in self.out_indices] - else: - return [{k: self.info[i][k] for k in keys} for i in self.out_indices] + return self.get('module', idx) def __getitem__(self, item): return self.info[item] @@ -253,11 +257,11 @@ class FeatureHookNet(nn.ModuleDict): if hasattr(model, 'reset_classifier'): # make sure classifier is removed? model.reset_classifier(0) layers['body'] = model - hooks.extend(self.feature_info) + hooks.extend(self.feature_info.get_dicts()) else: modules = _module_list(model, flatten_sequential=flatten_sequential) remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type - for f in self.feature_info} + for f in self.feature_info.get_dicts()} for new_name, old_name, module in modules: layers[new_name] = module for fn, fm in module.named_modules(prefix=old_name): diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 7e4af274..9e98394c 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -186,7 +186,7 @@ class MobileNetV3Features(nn.Module): # Register feature extraction hooks with FeatureHooks helper self.feature_hooks = None if feature_location != 'bottleneck': - hooks = self.feature_info.get_by_key(keys=('module', 'hook_type')) + hooks = self.feature_info.get_dicts(keys=('module', 'hook_type')) self.feature_hooks = FeatureHooks(hooks, self.named_modules()) def forward(self, x) -> List[torch.Tensor]: diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py index 8303af27..b6bd8944 100644 --- a/timm/models/xception_aligned.py +++ b/timm/models/xception_aligned.py @@ -31,9 +31,12 @@ def _cfg(url='', **kwargs): default_cfgs = dict( - xception41=_cfg(url=''), - xception65=_cfg(url=''), - xception71=_cfg(url=''), + xception41=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_41-e6439c97.pth'), + xception65=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_65-c9ae96e8.pth'), + xception71=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_71-8eec7df1.pth'), ) @@ -216,7 +219,6 @@ def xception65(pretrained=False, **kwargs): return _xception('xception65', pretrained=pretrained, **model_args) - @register_model def xception71(pretrained=False, **kwargs): """ Modified Aligned Xception-71