Fix stem width for very small MobileNetV3 arch defs

pull/1091/head
Ross Wightman 3 years ago committed by Ross Wightman
parent edd3d73695
commit fa81164378

@@ -110,8 +110,9 @@ class MobileNetV3(nn.Module):
* LCNet - https://arxiv.org/abs/2109.15099
"""
- def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
- pad_type='', act_layer=None, norm_layer=None, se_layer=None, se_from_exp=True,
+ def __init__(
+ self, block_args, num_classes=1000, in_chans=3, stem_size=16, fix_stem=False, num_features=1280,
+ head_bias=True, pad_type='', act_layer=None, norm_layer=None, se_layer=None, se_from_exp=True,
round_chs_fn=round_channels, drop_rate=0., drop_path_rate=0., global_pool='avg'):
super(MobileNetV3, self).__init__()
act_layer = act_layer or nn.ReLU
@@ -122,6 +123,7 @@ class MobileNetV3(nn.Module):
self.drop_rate = drop_rate
# Stem
+ if not fix_stem:
stem_size = round_chs_fn(stem_size)
self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size)
@@ -188,8 +190,8 @@ class MobileNetV3Features(nn.Module):
"""
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3,
- stem_size=16, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=True,
- act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
+ stem_size=16, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels,
+ se_from_exp=True, act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
super(MobileNetV3Features, self).__init__()
act_layer = act_layer or nn.ReLU
norm_layer = norm_layer or nn.BatchNorm2d
@@ -197,6 +199,7 @@ class MobileNetV3Features(nn.Module):
self.drop_rate = drop_rate
# Stem
+ if not fix_stem:
stem_size = round_chs_fn(stem_size)
self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size)
@@ -381,6 +384,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
block_args=decode_arch_def(arch_def),
num_features=num_features,
stem_size=16,
+ fix_stem=channel_multiplier < 0.75,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=act_layer,

Loading…
Cancel
Save