Add the MnasNet-B1 variant weights, add A1/B1 model names as in the stand-alone repo, remove a bit of unused code

pull/13/head
Ross Wightman 5 years ago
parent c1a84ecb22
commit f3134973b5

@ -69,10 +69,11 @@ I've leveraged the training scripts in this repository to train a few of the mod
| seresnext26_32x4d | 77.104 (22.896) | 93.316 (6.684) | 16.8M | bicubic |
| efficientnet_b0 | 76.912 (23.088) | 93.210 (6.790) | 5.29M | bicubic |
| mobilenetv3_100 | 75.634 (24.366) | 92.708 (7.292) | 5.5M | bicubic |
| semnasnet_100 | 75.448 (24.552) | 92.604 (7.396) | 3.89M | bicubic |
| mnasnet_a1 | 75.448 (24.552) | 92.604 (7.396) | 3.89M | bicubic |
| fbnetc_100 | 75.124 (24.876) | 92.386 (7.614) | 5.6M | bilinear |
| resnet34 | 75.110 (24.890) | 92.284 (7.716) | 22M | bilinear |
| seresnet34 | 74.808 (25.192) | 92.124 (7.876) | 22M | bilinear |
| mnasnet_b1 | 74.658 (25.342) | 92.114 (7.886) | 4.38M | bicubic |
| spnasnet_100 | 74.084 (25.916) | 91.818 (8.182) | 4.42M | bilinear |
| seresnet18 | 71.742 (28.258) | 90.334 (9.666) | 11.8M | bicubic |

@ -28,8 +28,8 @@ from models.conv2d_same import sconv2d
from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
_models = [
'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_140', 'semnasnet_050', 'semnasnet_075',
'semnasnet_100', 'semnasnet_140', 'mnasnet_small', 'mobilenetv1_100', 'mobilenetv2_100',
'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_b1', 'mnasnet_140', 'semnasnet_050', 'semnasnet_075',
'semnasnet_100', 'mnasnet_a1', 'semnasnet_140', 'mnasnet_small', 'mobilenetv1_100', 'mobilenetv2_100',
'mobilenetv3_050', 'mobilenetv3_075', 'mobilenetv3_100', 'chamnetv1_100', 'chamnetv2_100',
'fbnetc_100', 'spnasnet_100', 'tflite_mnasnet_100', 'tflite_semnasnet_100', 'efficientnet_b0',
'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'tf_efficientnet_b0',
@ -50,7 +50,9 @@ def _cfg(url='', **kwargs):
default_cfgs = {
'mnasnet_050': _cfg(url=''),
'mnasnet_075': _cfg(url=''),
'mnasnet_100': _cfg(url=''),
'mnasnet_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth',
interpolation='bicubic'),
'tflite_mnasnet_100': _cfg(
url='https://www.dropbox.com/s/q55ir3tx8mpeyol/tflite_mnasnet_100-31639cdc.pth?dl=1',
interpolation='bicubic'),
@ -161,8 +163,7 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
is assumed to indicate the block type.
leading string - block type (
ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act,
ca = Cascade3x3, and possibly more)
ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct)
r - number of repeat blocks,
k - kernel size,
s - strides (1-9),
@ -227,15 +228,6 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
block_args['pw_group'] = options['g']
if options['g'] > 1:
block_args['shuffle_type'] = 'mid'
elif block_type == 'ca':
block_args = dict(
block_type=block_type,
kernel_size=int(options['k']),
out_chs=int(options['c']),
stride=int(options['s']),
act_fn=act_fn,
noskip=noskip,
)
elif block_type == 'ds' or block_type == 'dsa':
block_args = dict(
block_type=block_type,
@ -345,8 +337,6 @@ class _BlockBuilder:
elif bt == 'ds' or bt == 'dsa':
ba['drop_connect_rate'] = self.drop_connect_rate
block = DepthwiseSeparableConv(**ba)
elif bt == 'ca':
block = CascadeConv(**ba)
elif bt == 'cn':
block = ConvBnAct(**ba)
else:
@ -565,36 +555,6 @@ class DepthwiseSeparableConv(nn.Module):
return x
class CascadeConv(nn.Sequential):
# FIXME haven't used yet
def __init__(self, in_chs, out_chs, kernel_size=3, stride=2, act_fn=F.relu, noskip=False,
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
folded_bn=False, padding_same=False):
super(CascadeConv, self).__init__()
assert stride in [1, 2]
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
self.act_fn = act_fn
padding = _padding_arg(1, padding_same)
self.conv1 = sconv2d(in_chs, in_chs, kernel_size, stride=stride, padding=padding, bias=folded_bn)
self.bn1 = None if folded_bn else nn.BatchNorm2d(in_chs, momentum=bn_momentum, eps=bn_eps)
self.conv2 = sconv2d(in_chs, out_chs, kernel_size, stride=1, padding=padding, bias=folded_bn)
self.bn2 = None if folded_bn else nn.BatchNorm2d(out_chs, momentum=bn_momentum, eps=bn_eps)
def forward(self, x):
residual = x
x = self.conv1(x)
if self.bn1 is not None:
x = self.bn1(x)
x = self.act_fn(x)
x = self.conv2(x)
if self.bn2 is not None:
x = self.bn2(x)
if self.has_residual:
x += residual
return x
class InvertedResidual(nn.Module):
""" Inverted residual block w/ optional SE"""
@ -699,7 +659,6 @@ class GenEfficientNet(nn.Module):
super(GenEfficientNet, self).__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
self.drop_connect_rate = drop_connect_rate
self.act_fn = act_fn
self.num_features = num_features
@ -730,7 +689,7 @@ class GenEfficientNet(nn.Module):
nn.BatchNorm2d(self.num_features, momentum=bn_momentum, eps=bn_eps)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.classifier = nn.Linear(self.num_features, self.num_classes)
self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), self.num_classes)
for m in self.modules():
if weight_init == 'goog':
@ -1220,6 +1179,11 @@ def mnasnet_100(num_classes, in_chans=3, pretrained=False, **kwargs):
return model
def mnasnet_b1(num_classes, in_chans=3, pretrained=False, **kwargs):
""" MNASNet B1, depth multiplier of 1.0. """
return mnasnet_100(num_classes, in_chans, pretrained, **kwargs)
def tflite_mnasnet_100(num_classes, in_chans=3, pretrained=False, **kwargs):
""" MNASNet B1, depth multiplier of 1.0. """
default_cfg = default_cfgs['tflite_mnasnet_100']
@ -1273,6 +1237,11 @@ def semnasnet_100(num_classes, in_chans=3, pretrained=False, **kwargs):
return model
def mnasnet_a1(num_classes, in_chans=3, pretrained=False, **kwargs):
""" MNASNet A1 (w/ SE), depth multiplier of 1.0. """
return semnasnet_100(num_classes, in_chans, pretrained, **kwargs)
def tflite_semnasnet_100(num_classes, in_chans=3, pretrained=False, **kwargs):
""" MNASNet A1, depth multiplier of 1.0. """
default_cfg = default_cfgs['tflite_semnasnet_100']

Loading…
Cancel
Save