From 7be299504fd0d49619abf027ee48d4e33af0a51c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 3 Jun 2020 00:00:37 -0700 Subject: [PATCH] Add missing feature_info() on MobileNetV3, make hook feature output order/type consistent with bottleneck (list, decreasing fmap size) --- timm/models/efficientnet.py | 5 ++++- timm/models/feature_hooks.py | 7 +++++-- timm/models/mobilenetv3.py | 14 +++++++++++++- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index fbd7f420..47cd0b9d 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -24,9 +24,12 @@ An implementation of EfficienNet that covers variety of related models with effi Hacked together by Ross Wightman """ +import torch import torch.nn as nn import torch.nn.functional as F +from typing import List + from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights @@ -471,7 +474,7 @@ class EfficientNetFeatures(nn.Module): return self._feature_info[idx] return [self._feature_info[i] for i in self.out_indices] - def forward(self, x): + def forward(self, x) -> List[torch.Tensor]: x = self.conv_stem(x) x = self.bn1(x) x = self.act1(x) diff --git a/timm/models/feature_hooks.py b/timm/models/feature_hooks.py index 8ffcda86..7b7b3da1 100644 --- a/timm/models/feature_hooks.py +++ b/timm/models/feature_hooks.py @@ -1,5 +1,8 @@ +import torch + from collections import defaultdict, OrderedDict from functools import partial +from typing import List class FeatureHooks: @@ -25,7 +28,7 @@ class FeatureHooks: x = x[0] # unwrap input tuple self._feature_outputs[x.device][name] = x - def get_output(self, device): - output = tuple(self._feature_outputs[device].values())[::-1] + def get_output(self, device) -> List[torch.tensor]: + output = list(self._feature_outputs[device].values()) self._feature_outputs[device] = OrderedDict() # clear after reading return output diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index e1a700b0..9c4a9af5 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -7,9 +7,12 @@ Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244 Hacked together by Ross Wightman """ +import torch import torch.nn as nn import torch.nn.functional as F +from typing import List + from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights @@ -206,7 +209,16 @@ class MobileNetV3Features(nn.Module): return self._feature_info[idx]['num_chs'] return [self._feature_info[i]['num_chs'] for i in self.out_indices] - def forward(self, x): + def feature_info(self, idx=None): + """ Feature Channel Shortcut + Returns feature channel count for each output index if idx == None. If idx is an integer, will + return feature channel count for that feature block index (independent of out_indices setting). + """ + if isinstance(idx, int): + return self._feature_info[idx] + return [self._feature_info[i] for i in self.out_indices] + + def forward(self, x) -> List[torch.Tensor]: x = self.conv_stem(x) x = self.bn1(x) x = self.act1(x)