parent
8512436436
commit
171c0b88b6
@ -1,2 +1,2 @@
|
|||||||
from .version import __version__
|
from .version import __version__
|
||||||
from .models import create_model
|
from .models import create_model, list_models, is_model, list_modules, model_entrypoint
|
||||||
|
@ -0,0 +1,44 @@
|
|||||||
|
from .registry import is_model, is_model_in_modules, model_entrypoint
|
||||||
|
from .helpers import load_checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
def create_model(
|
||||||
|
model_name,
|
||||||
|
pretrained=False,
|
||||||
|
num_classes=1000,
|
||||||
|
in_chans=3,
|
||||||
|
checkpoint_path='',
|
||||||
|
**kwargs):
|
||||||
|
"""Create a model
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name (str): name of model to instantiate
|
||||||
|
pretrained (bool): load pretrained ImageNet-1k weights if true
|
||||||
|
num_classes (int): number of classes for final fully connected layer (default: 1000)
|
||||||
|
in_chans (int): number of input channels / colors (default: 3)
|
||||||
|
checkpoint_path (str): path of checkpoint to load after model is initialized
|
||||||
|
|
||||||
|
Keyword Args:
|
||||||
|
drop_rate (float): dropout rate for training (default: 0.0)
|
||||||
|
global_pool (str): global pool type (default: 'avg')
|
||||||
|
**: other kwargs are model specific
|
||||||
|
"""
|
||||||
|
margs = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans)
|
||||||
|
|
||||||
|
# Not all models have support for batchnorm params passed as args, only gen_efficientnet variants
|
||||||
|
supports_bn_params = is_model_in_modules(model_name, ['gen_efficientnet'])
|
||||||
|
if not supports_bn_params and any([x in kwargs for x in ['bn_tf', 'bn_momentum', 'bn_eps']]):
|
||||||
|
kwargs.pop('bn_tf', None)
|
||||||
|
kwargs.pop('bn_momentum', None)
|
||||||
|
kwargs.pop('bn_eps', None)
|
||||||
|
|
||||||
|
if is_model(model_name):
|
||||||
|
create_fn = model_entrypoint(model_name)
|
||||||
|
model = create_fn(**margs, **kwargs)
|
||||||
|
else:
|
||||||
|
raise RuntimeError('Unknown model (%s)' % model_name)
|
||||||
|
|
||||||
|
if checkpoint_path:
|
||||||
|
load_checkpoint(model, checkpoint_path)
|
||||||
|
|
||||||
|
return model
|
@ -1,42 +0,0 @@
|
|||||||
from .inception_v4 import *
|
|
||||||
from .inception_resnet_v2 import *
|
|
||||||
from .densenet import *
|
|
||||||
from .resnet import *
|
|
||||||
from .dpn import *
|
|
||||||
from .senet import *
|
|
||||||
from .xception import *
|
|
||||||
from .pnasnet import *
|
|
||||||
from .gen_efficientnet import *
|
|
||||||
from .inception_v3 import *
|
|
||||||
from .gluon_resnet import *
|
|
||||||
|
|
||||||
from .helpers import load_checkpoint
|
|
||||||
|
|
||||||
|
|
||||||
def create_model(
|
|
||||||
model_name,
|
|
||||||
pretrained=False,
|
|
||||||
num_classes=1000,
|
|
||||||
in_chans=3,
|
|
||||||
checkpoint_path='',
|
|
||||||
**kwargs):
|
|
||||||
|
|
||||||
margs = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans)
|
|
||||||
|
|
||||||
# Not all models have support for batchnorm params passed as args, only gen_efficientnet variants
|
|
||||||
supports_bn_params = model_name in gen_efficientnet_model_names()
|
|
||||||
if not supports_bn_params and any([x in kwargs for x in ['bn_tf', 'bn_momentum', 'bn_eps']]):
|
|
||||||
kwargs.pop('bn_tf', None)
|
|
||||||
kwargs.pop('bn_momentum', None)
|
|
||||||
kwargs.pop('bn_eps', None)
|
|
||||||
|
|
||||||
if model_name in globals():
|
|
||||||
create_fn = globals()[model_name]
|
|
||||||
model = create_fn(**margs, **kwargs)
|
|
||||||
else:
|
|
||||||
raise RuntimeError('Unknown model (%s)' % model_name)
|
|
||||||
|
|
||||||
if checkpoint_path:
|
|
||||||
load_checkpoint(model, checkpoint_path)
|
|
||||||
|
|
||||||
return model
|
|
@ -0,0 +1,78 @@
|
|||||||
|
import sys
|
||||||
|
import re
|
||||||
|
import fnmatch
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules']
|
||||||
|
|
||||||
|
_module_to_models = defaultdict(set)
|
||||||
|
_model_to_module = {}
|
||||||
|
_model_entrypoints = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_model(fn):
|
||||||
|
mod = sys.modules[fn.__module__]
|
||||||
|
module_name_split = fn.__module__.split('.')
|
||||||
|
module_name = module_name_split[-1] if len(module_name_split) else ''
|
||||||
|
if hasattr(mod, '__all__'):
|
||||||
|
mod.__all__.append(fn.__name__)
|
||||||
|
else:
|
||||||
|
mod.__all__ = [fn.__name__]
|
||||||
|
_model_entrypoints[fn.__name__] = fn
|
||||||
|
_model_to_module[fn.__name__] = module_name
|
||||||
|
_module_to_models[module_name].add(fn.__name__)
|
||||||
|
return fn
|
||||||
|
|
||||||
|
|
||||||
|
def _natural_key(string_):
|
||||||
|
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
|
||||||
|
|
||||||
|
|
||||||
|
def list_models(filter='', module=''):
|
||||||
|
""" Return list of available model names, sorted alphabetically
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filter (str) - Wildcard filter string that works with fnmatch
|
||||||
|
module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet')
|
||||||
|
|
||||||
|
Example:
|
||||||
|
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
|
||||||
|
model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
|
||||||
|
"""
|
||||||
|
if module:
|
||||||
|
models = list(_module_to_models[module])
|
||||||
|
else:
|
||||||
|
models = _model_entrypoints.keys()
|
||||||
|
if filter:
|
||||||
|
models = fnmatch.filter(models, filter)
|
||||||
|
return list(sorted(models, key=_natural_key))
|
||||||
|
|
||||||
|
|
||||||
|
def is_model(model_name):
|
||||||
|
""" Check if a model name exists
|
||||||
|
"""
|
||||||
|
return model_name in _model_entrypoints
|
||||||
|
|
||||||
|
|
||||||
|
def model_entrypoint(model_name):
|
||||||
|
"""Fetch a model entrypoint for specified model name
|
||||||
|
"""
|
||||||
|
return _model_entrypoints[model_name]
|
||||||
|
|
||||||
|
|
||||||
|
def list_modules():
|
||||||
|
""" Return list of module names that contain models / model entrypoints
|
||||||
|
"""
|
||||||
|
modules = _module_to_models.keys()
|
||||||
|
return list(sorted(modules))
|
||||||
|
|
||||||
|
|
||||||
|
def is_model_in_modules(model_name, module_names):
|
||||||
|
"""Check if a model exists within a subset of modules
|
||||||
|
Args:
|
||||||
|
model_name (str) - name of model to check
|
||||||
|
module_names (tuple, list, set) - names of modules to search in
|
||||||
|
"""
|
||||||
|
assert isinstance(module_names, (tuple, list, set))
|
||||||
|
return any(model_name in _module_to_models[n] for n in module_names)
|
||||||
|
|
Loading…
Reference in new issue