Fix stem width for really small mobilenetv3 arch defs

pull/1091/head
Ross Wightman 3 years ago committed by Ross Wightman
parent edd3d73695
commit fa81164378

@ -110,9 +110,10 @@ class MobileNetV3(nn.Module):
* LCNet - https://arxiv.org/abs/2109.15099 * LCNet - https://arxiv.org/abs/2109.15099
""" """
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True, def __init__(
pad_type='', act_layer=None, norm_layer=None, se_layer=None, se_from_exp=True, self, block_args, num_classes=1000, in_chans=3, stem_size=16, fix_stem=False, num_features=1280,
round_chs_fn=round_channels, drop_rate=0., drop_path_rate=0., global_pool='avg'): head_bias=True, pad_type='', act_layer=None, norm_layer=None, se_layer=None, se_from_exp=True,
round_chs_fn=round_channels, drop_rate=0., drop_path_rate=0., global_pool='avg'):
super(MobileNetV3, self).__init__() super(MobileNetV3, self).__init__()
act_layer = act_layer or nn.ReLU act_layer = act_layer or nn.ReLU
norm_layer = norm_layer or nn.BatchNorm2d norm_layer = norm_layer or nn.BatchNorm2d
@ -122,7 +123,8 @@ class MobileNetV3(nn.Module):
self.drop_rate = drop_rate self.drop_rate = drop_rate
# Stem # Stem
stem_size = round_chs_fn(stem_size) if not fix_stem:
stem_size = round_chs_fn(stem_size)
self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size) self.bn1 = norm_layer(stem_size)
self.act1 = act_layer(inplace=True) self.act1 = act_layer(inplace=True)
@ -188,8 +190,8 @@ class MobileNetV3Features(nn.Module):
""" """
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3, def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3,
stem_size=16, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=True, stem_size=16, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels,
act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.): se_from_exp=True, act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
super(MobileNetV3Features, self).__init__() super(MobileNetV3Features, self).__init__()
act_layer = act_layer or nn.ReLU act_layer = act_layer or nn.ReLU
norm_layer = norm_layer or nn.BatchNorm2d norm_layer = norm_layer or nn.BatchNorm2d
@ -197,7 +199,8 @@ class MobileNetV3Features(nn.Module):
self.drop_rate = drop_rate self.drop_rate = drop_rate
# Stem # Stem
stem_size = round_chs_fn(stem_size) if not fix_stem:
stem_size = round_chs_fn(stem_size)
self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size) self.bn1 = norm_layer(stem_size)
self.act1 = act_layer(inplace=True) self.act1 = act_layer(inplace=True)
@ -381,6 +384,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
block_args=decode_arch_def(arch_def), block_args=decode_arch_def(arch_def),
num_features=num_features, num_features=num_features,
stem_size=16, stem_size=16,
fix_stem=channel_multiplier < 0.75,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier), round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=act_layer, act_layer=act_layer,

Loading…
Cancel
Save