Clean up model_factory imports, make __all__ consistent across model modules, and fix the inception_v4 pretrained weight URL

pull/6/head
Ross Wightman 5 years ago
parent e6c14427c0
commit 6bff9c75dc

@ -12,7 +12,8 @@ from models.adaptive_avgmax_pool import *
from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import re
__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
_models = ['densenet121', 'densenet169', 'densenet201', 'densenet161']
__all__ = ['DenseNet'] + _models
def _cfg(url=''):

@ -19,7 +19,8 @@ from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import select_adaptive_pool2d
from data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD
__all__ = ['DPN', 'dpn68', 'dpn92', 'dpn98', 'dpn131', 'dpn107']
_models = ['dpn68', 'dpn68b', 'dpn92', 'dpn98', 'dpn131', 'dpn107']
__all__ = ['DPN'] + _models
def _cfg(url=''):
@ -32,18 +33,12 @@ def _cfg(url=''):
default_cfgs = {
'dpn68':
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn68-66bebafa7.pth'),
'dpn68b_extra':
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn68b_extra-84854c156.pth'),
'dpn92_extra':
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn92_extra-b040e4a9b.pth'),
'dpn98':
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn98-5b90dec4d.pth'),
'dpn131':
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn131-71dfe43e0.pth'),
'dpn107_extra':
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn107_extra-1ac7121e2.pth')
'dpn68': _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn68-66bebafa7.pth'),
'dpn68b_extra': _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn68b_extra-84854c156.pth'),
'dpn92_extra': _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn92_extra-b040e4a9b.pth'),
'dpn98': _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn98-5b90dec4d.pth'),
'dpn131': _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn131-71dfe43e0.pth'),
'dpn107_extra': _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn107_extra-1ac7121e2.pth')
}

@ -26,10 +26,12 @@ from models.adaptive_avgmax_pool import SelectAdaptivePool2d
from models.conv2d_same import sconv2d
from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['GenMobileNet', 'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_140',
'semnasnet_050', 'semnasnet_075', 'semnasnet_100', 'semnasnet_140', 'mnasnet_small',
'mobilenetv1_100', 'mobilenetv2_100', 'mobilenetv3_050', 'mobilenetv3_075', 'mobilenetv3_100',
'chamnetv1_100', 'chamnetv2_100', 'fbnetc_100', 'spnasnet_100']
_models = [
'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_140', 'semnasnet_050', 'semnasnet_075',
'semnasnet_100', 'semnasnet_140', 'mnasnet_small', 'mobilenetv1_100', 'mobilenetv2_100',
'mobilenetv3_050', 'mobilenetv3_075', 'mobilenetv3_100', 'chamnetv1_100', 'chamnetv2_100',
'fbnetc_100', 'spnasnet_100', 'tflite_mnasnet_100', 'tflite_semnasnet_100']
__all__ = ['GenMobileNet', 'genmobilenet_model_names'] + _models
def _cfg(url='', **kwargs):
@ -67,7 +69,7 @@ default_cfgs = {
'spnasnet_100': _cfg(url='https://www.dropbox.com/s/iieopt18rytkgaa/spnasnet_100-048bc3f4.pth?dl=1'),
}
_DEBUG = True
_DEBUG = False
# Default args for PyTorch BN impl
_BN_MOMENTUM_PT_DEFAULT = 0.1
@ -266,7 +268,7 @@ class _BlockBuilder:
def __init__(self, depth_multiplier=1.0, depth_divisor=8, min_depth=None,
act_fn=None, se_gate_fn=torch.sigmoid, se_reduce_mid=False,
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
folded_bn=False, padding_same=False):
folded_bn=False, padding_same=False, verbose=False):
self.depth_multiplier = depth_multiplier
self.depth_divisor = depth_divisor
self.min_depth = min_depth
@ -277,6 +279,7 @@ class _BlockBuilder:
self.bn_eps = bn_eps
self.folded_bn = folded_bn
self.padding_same = padding_same
self.verbose = verbose
self.in_chs = None
def _round_channels(self, chs):
@ -293,7 +296,7 @@ class _BlockBuilder:
# block act fn overrides the model default
ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn
assert ba['act_fn'] is not None
if _DEBUG:
if self.verbose:
print('args:', ba)
# could replace this if with lambdas or functools binding if variety increases
if bt == 'ir':
@ -315,7 +318,7 @@ class _BlockBuilder:
blocks = []
# each stack (stage) contains a list of block arguments
for block_idx, ba in enumerate(stack_args):
if _DEBUG:
if self.verbose:
print('block', block_idx, end=', ')
if block_idx >= 1:
# only the first block in any stack/stage can have a stride > 1
@ -334,18 +337,18 @@ class _BlockBuilder:
List of block stacks (each stack wrapped in nn.Sequential)
"""
arch_args = _decode_arch_def(arch_def) # convert and expand string defs to arg dicts
if _DEBUG:
if self.verbose:
print('Building model trunk with %d stacks (stages)...' % len(arch_args))
self.in_chs = in_chs
blocks = []
# outer list of arch_args defines the stacks ('stages' by some conventions)
for stack_idx, stack in enumerate(arch_args):
if _DEBUG:
if self.verbose:
print('stack', stack_idx)
assert isinstance(stack, list)
stack = self._make_stack(stack)
blocks.append(stack)
if _DEBUG:
if self.verbose:
print()
return blocks
@ -631,7 +634,7 @@ class GenMobileNet(nn.Module):
builder = _BlockBuilder(
depth_multiplier, depth_divisor, min_depth,
act_fn, se_gate_fn, se_reduce_mid,
bn_momentum, bn_eps, folded_bn, padding_same)
bn_momentum, bn_eps, folded_bn, padding_same, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(in_chs, block_args))
in_chs = builder.in_chs
@ -1265,3 +1268,7 @@ def spnasnet_100(num_classes, in_chans=3, pretrained=False, **kwargs):
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def genmobilenet_model_names():
    """Return the set of valid GenMobileNet model entrypoint names.

    Built from the module-level ``_models`` list so factory code can
    test membership without importing each entrypoint individually.
    """
    return {name for name in _models}

@ -11,13 +11,14 @@ from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import SelectAdaptivePool2d
from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['GluonResNet', 'gluon_resnet18_v1b', 'gluon_resnet34_v1b', 'gluon_resnet50_v1b', 'gluon_resnet101_v1b',
'gluon_resnet152_v1b', 'gluon_resnet50_v1c', 'gluon_resnet101_v1c', 'gluon_resnet152_v1c', 'gluon_resnet50_v1d',
'gluon_resnet101_v1d', 'gluon_resnet152_v1d', 'gluon_resnet50_v1e', 'gluon_resnet101_v1e', 'gluon_resnet152_v1e',
'gluon_resnet50_v1s', 'gluon_resnet101_v1s', 'gluon_resnet152_v1s', 'gluon_resnext50_32x4d',
'gluon_resnext101_32x4d', 'gluon_resnext101_64x4d', 'gluon_resnext152_32x4d', 'gluon_seresnext50_32x4d',
'gluon_seresnext101_32x4d', 'gluon_seresnext101_64x4d', 'gluon_seresnext152_32x4d', 'gluon_senet154'
]
_models = [
'gluon_resnet18_v1b', 'gluon_resnet34_v1b', 'gluon_resnet50_v1b', 'gluon_resnet101_v1b', 'gluon_resnet152_v1b',
'gluon_resnet50_v1c', 'gluon_resnet101_v1c', 'gluon_resnet152_v1c', 'gluon_resnet50_v1d', 'gluon_resnet101_v1d',
'gluon_resnet152_v1d', 'gluon_resnet50_v1e', 'gluon_resnet101_v1e', 'gluon_resnet152_v1e', 'gluon_resnet50_v1s',
'gluon_resnet101_v1s', 'gluon_resnet152_v1s', 'gluon_resnext50_32x4d', 'gluon_resnext101_32x4d',
'gluon_resnext101_64x4d', 'gluon_resnext152_32x4d', 'gluon_seresnext50_32x4d', 'gluon_seresnext101_32x4d',
'gluon_seresnext101_64x4d', 'gluon_seresnext152_32x4d', 'gluon_senet154']
__all__ = ['GluonResNet'] + _models
def _cfg(url='', **kwargs):

@ -9,6 +9,9 @@ from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import *
from data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
_models = ['inception_resnet_v2']
__all__ = ['InceptionResnetV2'] + _models
default_cfgs = {
'inception_resnet_v2': {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth',

@ -2,6 +2,9 @@ from torchvision.models import Inception3
from models.helpers import load_pretrained
from data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
_models = ['inception_v3', 'tf_inception_v3', 'adv_inception_v3', 'gluon_inception_v3']
__all__ = _models
default_cfgs = {
# original PyTorch weights, ported from Tensorflow but modified
'inception_v3': {

@ -9,13 +9,16 @@ from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import *
from data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
_models = ['inception_v4']
__all__ = ['InceptionV4'] + _models
default_cfgs = {
'inception_v4': {
'url': 'http://webia.lip6.fr/~cadene/Downloads/inceptionv4-97ef9c30.pth',
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth',
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'features.0.conv', 'classifier': 'classif',
'first_conv': 'features.0.conv', 'classifier': 'last_linear',
}
}
@ -268,7 +271,7 @@ class InceptionV4(nn.Module):
Inception_C(),
Inception_C(),
)
self.classif = nn.Linear(self.num_features, num_classes)
self.last_linear = nn.Linear(self.num_features, num_classes)
def get_classifier(self):
    """Return the final classifier layer of the network.

    Fix: the classifier attribute was renamed from ``classif`` to
    ``last_linear`` (see the ``nn.Linear`` assignment in ``__init__`` and
    the ``forward`` pass), but this accessor still referenced the old
    ``self.classif`` name, which would raise AttributeError at runtime.
    """
    return self.last_linear
@ -289,7 +292,7 @@ class InceptionV4(nn.Module):
x = self.forward_features(x)
if self.drop_rate > 0:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.classif(x)
x = self.last_linear(x)
return x

@ -1,38 +1,18 @@
from models.inception_v4 import inception_v4
from models.inception_resnet_v2 import inception_resnet_v2
from models.densenet import densenet161, densenet121, densenet169, densenet201
from models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152, \
resnext50_32x4d, resnext101_32x4d, resnext101_64x4d, resnext152_32x4d
from models.dpn import dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107
from models.senet import seresnet18, seresnet34, seresnet50, seresnet101, seresnet152, \
seresnext26_32x4d, seresnext50_32x4d, seresnext101_32x4d
from models.xception import xception
from models.pnasnet import pnasnet5large
from models.genmobilenet import \
mnasnet_050, mnasnet_075, mnasnet_100, mnasnet_140, tflite_mnasnet_100,\
semnasnet_050, semnasnet_075, semnasnet_100, semnasnet_140, tflite_semnasnet_100, mnasnet_small,\
mobilenetv1_100, mobilenetv2_100, mobilenetv3_050, mobilenetv3_075, mobilenetv3_100,\
fbnetc_100, chamnetv1_100, chamnetv2_100, spnasnet_100
from models.inception_v3 import inception_v3, gluon_inception_v3, tf_inception_v3, adv_inception_v3
from models.gluon_resnet import gluon_resnet18_v1b, gluon_resnet34_v1b, gluon_resnet50_v1b, gluon_resnet101_v1b, \
gluon_resnet152_v1b, gluon_resnet50_v1c, gluon_resnet101_v1c, gluon_resnet152_v1c, \
gluon_resnet50_v1d, gluon_resnet101_v1d, gluon_resnet152_v1d, \
gluon_resnet50_v1e, gluon_resnet101_v1e, gluon_resnet152_v1e, \
gluon_resnet50_v1s, gluon_resnet101_v1s, gluon_resnet152_v1s, \
gluon_resnext50_32x4d, gluon_resnext101_32x4d , gluon_resnext101_64x4d, gluon_resnext152_32x4d, \
gluon_seresnext50_32x4d, gluon_seresnext101_32x4d, gluon_seresnext101_64x4d, gluon_seresnext152_32x4d, \
gluon_senet154
from models.inception_v4 import *
from models.inception_resnet_v2 import *
from models.densenet import *
from models.resnet import *
from models.dpn import *
from models.senet import *
from models.xception import *
from models.pnasnet import *
from models.genmobilenet import *
from models.inception_v3 import *
from models.gluon_resnet import *
from models.helpers import load_checkpoint
def _is_genmobilenet(name):
genmobilenets = ['mnasnet', 'semnasnet', 'fbnet', 'chamnet', 'mobilenet']
if any([name.startswith(x) for x in genmobilenets]):
return True
return False
def create_model(
model_name='resnet50',
pretrained=None,
@ -44,8 +24,7 @@ def create_model(
margs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained)
# Not all models have support for batchnorm params passed as args, only genmobilenet variants
# FIXME better way to do this without pushing support into every other model fn?
supports_bn_params = _is_genmobilenet(model_name)
supports_bn_params = model_name in genmobilenet_model_names()
if not supports_bn_params and any([x in kwargs for x in ['bn_tf', 'bn_momentum', 'bn_eps']]):
kwargs.pop('bn_tf', None)
kwargs.pop('bn_momentum', None)

@ -15,6 +15,9 @@ import torch.nn.functional as F
from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import SelectAdaptivePool2d
_models = ['pnasnet5large']
__all__ = ['PNASNet5Large'] + _models
default_cfgs = {
'pnasnet5large': {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth',

@ -12,8 +12,9 @@ from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import SelectAdaptivePool2d
from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
_models = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d', 'resnext152_32x4d']
__all__ = ['ResNet'] + _models
def _cfg(url='', **kwargs):

@ -19,8 +19,9 @@ from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import SelectAdaptivePool2d
from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['SENet', 'senet154', 'seresnet50', 'seresnet101', 'seresnet152',
_models = ['senet154', 'seresnet50', 'seresnet101', 'seresnet152',
'seresnext50_32x4d', 'seresnext101_32x4d']
__all__ = ['SENet'] + _models
def _cfg(url='', **kwargs):

@ -30,8 +30,8 @@ import torch.nn.functional as F
from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import select_adaptive_pool2d
__all__ = ['xception']
_models = ['xception']
__all__ = ['Xception'] + _models
default_cfgs = {
'xception': {

Loading…
Cancel
Save