Add missing feature_info() on MobileNetV3, make hook feature output order/type consistent with bottleneck (list, decreasing fmap size)

pull/155/head
Ross Wightman 4 years ago
parent 88129b2569
commit 7be299504f

@ -24,9 +24,12 @@ An implementation of EfficienNet that covers variety of related models with effi
Hacked together by Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
@ -471,7 +474,7 @@ class EfficientNetFeatures(nn.Module):
return self._feature_info[idx]
return [self._feature_info[i] for i in self.out_indices]
def forward(self, x):
def forward(self, x) -> List[torch.Tensor]:
x = self.conv_stem(x)
x = self.bn1(x)
x = self.act1(x)

@ -1,5 +1,8 @@
import torch
from collections import defaultdict, OrderedDict
from functools import partial
from typing import List
class FeatureHooks:
@ -25,7 +28,7 @@ class FeatureHooks:
x = x[0] # unwrap input tuple
self._feature_outputs[x.device][name] = x
def get_output(self, device):
output = tuple(self._feature_outputs[device].values())[::-1]
def get_output(self, device) -> List[torch.tensor]:
output = list(self._feature_outputs[device].values())
self._feature_outputs[device] = OrderedDict() # clear after reading
return output

@ -7,9 +7,12 @@ Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
Hacked together by Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
@ -206,7 +209,16 @@ class MobileNetV3Features(nn.Module):
return self._feature_info[idx]['num_chs']
return [self._feature_info[i]['num_chs'] for i in self.out_indices]
def forward(self, x):
def feature_info(self, idx=None):
""" Feature Channel Shortcut
Returns feature channel count for each output index if idx == None. If idx is an integer, will
return feature channel count for that feature block index (independent of out_indices setting).
"""
if isinstance(idx, int):
return self._feature_info[idx]
return [self._feature_info[i] for i in self.out_indices]
def forward(self, x) -> List[torch.Tensor]:
x = self.conv_stem(x)
x = self.bn1(x)
x = self.act1(x)

Loading…
Cancel
Save