Add missing feature_info() on MobileNetV3, make hook feature output order/type consistent with bottleneck (list, decreasing fmap size)

pull/155/head
Ross Wightman 5 years ago
parent 88129b2569
commit 7be299504f

@ -24,9 +24,12 @@ An implementation of EfficienNet that covers variety of related models with effi
Hacked together by Ross Wightman Hacked together by Ross Wightman
""" """
import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from typing import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
@ -471,7 +474,7 @@ class EfficientNetFeatures(nn.Module):
return self._feature_info[idx] return self._feature_info[idx]
return [self._feature_info[i] for i in self.out_indices] return [self._feature_info[i] for i in self.out_indices]
def forward(self, x): def forward(self, x) -> List[torch.Tensor]:
x = self.conv_stem(x) x = self.conv_stem(x)
x = self.bn1(x) x = self.bn1(x)
x = self.act1(x) x = self.act1(x)

@ -1,5 +1,8 @@
import torch
from collections import defaultdict, OrderedDict from collections import defaultdict, OrderedDict
from functools import partial from functools import partial
from typing import List
class FeatureHooks: class FeatureHooks:
@ -25,7 +28,7 @@ class FeatureHooks:
x = x[0] # unwrap input tuple x = x[0] # unwrap input tuple
self._feature_outputs[x.device][name] = x self._feature_outputs[x.device][name] = x
def get_output(self, device): def get_output(self, device) -> List[torch.tensor]:
output = tuple(self._feature_outputs[device].values())[::-1] output = list(self._feature_outputs[device].values())
self._feature_outputs[device] = OrderedDict() # clear after reading self._feature_outputs[device] = OrderedDict() # clear after reading
return output return output

@ -7,9 +7,12 @@ Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
Hacked together by Ross Wightman Hacked together by Ross Wightman
""" """
import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from typing import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
@ -206,7 +209,16 @@ class MobileNetV3Features(nn.Module):
return self._feature_info[idx]['num_chs'] return self._feature_info[idx]['num_chs']
return [self._feature_info[i]['num_chs'] for i in self.out_indices] return [self._feature_info[i]['num_chs'] for i in self.out_indices]
def forward(self, x): def feature_info(self, idx=None):
""" Feature Channel Shortcut
Returns feature channel count for each output index if idx == None. If idx is an integer, will
return feature channel count for that feature block index (independent of out_indices setting).
"""
if isinstance(idx, int):
return self._feature_info[idx]
return [self._feature_info[i] for i in self.out_indices]
def forward(self, x) -> List[torch.Tensor]:
x = self.conv_stem(x) x = self.conv_stem(x)
x = self.bn1(x) x = self.bn1(x)
x = self.act1(x) x = self.act1(x)

Loading…
Cancel
Save