parent
8512436436
commit
171c0b88b6
@ -1,2 +1,2 @@
|
||||
from .version import __version__
|
||||
from .models import create_model
|
||||
from .models import create_model, list_models, is_model, list_modules, model_entrypoint
|
||||
|
@ -0,0 +1,44 @@
|
||||
from .registry import is_model, is_model_in_modules, model_entrypoint
|
||||
from .helpers import load_checkpoint
|
||||
|
||||
|
||||
def create_model(
|
||||
model_name,
|
||||
pretrained=False,
|
||||
num_classes=1000,
|
||||
in_chans=3,
|
||||
checkpoint_path='',
|
||||
**kwargs):
|
||||
"""Create a model
|
||||
|
||||
Args:
|
||||
model_name (str): name of model to instantiate
|
||||
pretrained (bool): load pretrained ImageNet-1k weights if true
|
||||
num_classes (int): number of classes for final fully connected layer (default: 1000)
|
||||
in_chans (int): number of input channels / colors (default: 3)
|
||||
checkpoint_path (str): path of checkpoint to load after model is initialized
|
||||
|
||||
Keyword Args:
|
||||
drop_rate (float): dropout rate for training (default: 0.0)
|
||||
global_pool (str): global pool type (default: 'avg')
|
||||
**: other kwargs are model specific
|
||||
"""
|
||||
margs = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans)
|
||||
|
||||
# Not all models have support for batchnorm params passed as args, only gen_efficientnet variants
|
||||
supports_bn_params = is_model_in_modules(model_name, ['gen_efficientnet'])
|
||||
if not supports_bn_params and any([x in kwargs for x in ['bn_tf', 'bn_momentum', 'bn_eps']]):
|
||||
kwargs.pop('bn_tf', None)
|
||||
kwargs.pop('bn_momentum', None)
|
||||
kwargs.pop('bn_eps', None)
|
||||
|
||||
if is_model(model_name):
|
||||
create_fn = model_entrypoint(model_name)
|
||||
model = create_fn(**margs, **kwargs)
|
||||
else:
|
||||
raise RuntimeError('Unknown model (%s)' % model_name)
|
||||
|
||||
if checkpoint_path:
|
||||
load_checkpoint(model, checkpoint_path)
|
||||
|
||||
return model
|
@ -1,42 +0,0 @@
|
||||
from .inception_v4 import *
|
||||
from .inception_resnet_v2 import *
|
||||
from .densenet import *
|
||||
from .resnet import *
|
||||
from .dpn import *
|
||||
from .senet import *
|
||||
from .xception import *
|
||||
from .pnasnet import *
|
||||
from .gen_efficientnet import *
|
||||
from .inception_v3 import *
|
||||
from .gluon_resnet import *
|
||||
|
||||
from .helpers import load_checkpoint
|
||||
|
||||
|
||||
def create_model(
|
||||
model_name,
|
||||
pretrained=False,
|
||||
num_classes=1000,
|
||||
in_chans=3,
|
||||
checkpoint_path='',
|
||||
**kwargs):
|
||||
|
||||
margs = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans)
|
||||
|
||||
# Not all models have support for batchnorm params passed as args, only gen_efficientnet variants
|
||||
supports_bn_params = model_name in gen_efficientnet_model_names()
|
||||
if not supports_bn_params and any([x in kwargs for x in ['bn_tf', 'bn_momentum', 'bn_eps']]):
|
||||
kwargs.pop('bn_tf', None)
|
||||
kwargs.pop('bn_momentum', None)
|
||||
kwargs.pop('bn_eps', None)
|
||||
|
||||
if model_name in globals():
|
||||
create_fn = globals()[model_name]
|
||||
model = create_fn(**margs, **kwargs)
|
||||
else:
|
||||
raise RuntimeError('Unknown model (%s)' % model_name)
|
||||
|
||||
if checkpoint_path:
|
||||
load_checkpoint(model, checkpoint_path)
|
||||
|
||||
return model
|
@ -0,0 +1,78 @@
|
||||
import sys
|
||||
import re
|
||||
import fnmatch
|
||||
from collections import defaultdict
|
||||
|
||||
__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules']
|
||||
|
||||
_module_to_models = defaultdict(set)
|
||||
_model_to_module = {}
|
||||
_model_entrypoints = {}
|
||||
|
||||
|
||||
def register_model(fn):
|
||||
mod = sys.modules[fn.__module__]
|
||||
module_name_split = fn.__module__.split('.')
|
||||
module_name = module_name_split[-1] if len(module_name_split) else ''
|
||||
if hasattr(mod, '__all__'):
|
||||
mod.__all__.append(fn.__name__)
|
||||
else:
|
||||
mod.__all__ = [fn.__name__]
|
||||
_model_entrypoints[fn.__name__] = fn
|
||||
_model_to_module[fn.__name__] = module_name
|
||||
_module_to_models[module_name].add(fn.__name__)
|
||||
return fn
|
||||
|
||||
|
||||
def _natural_key(string_):
|
||||
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
|
||||
|
||||
|
||||
def list_models(filter='', module=''):
|
||||
""" Return list of available model names, sorted alphabetically
|
||||
|
||||
Args:
|
||||
filter (str) - Wildcard filter string that works with fnmatch
|
||||
module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet')
|
||||
|
||||
Example:
|
||||
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
|
||||
model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
|
||||
"""
|
||||
if module:
|
||||
models = list(_module_to_models[module])
|
||||
else:
|
||||
models = _model_entrypoints.keys()
|
||||
if filter:
|
||||
models = fnmatch.filter(models, filter)
|
||||
return list(sorted(models, key=_natural_key))
|
||||
|
||||
|
||||
def is_model(model_name):
|
||||
""" Check if a model name exists
|
||||
"""
|
||||
return model_name in _model_entrypoints
|
||||
|
||||
|
||||
def model_entrypoint(model_name):
|
||||
"""Fetch a model entrypoint for specified model name
|
||||
"""
|
||||
return _model_entrypoints[model_name]
|
||||
|
||||
|
||||
def list_modules():
|
||||
""" Return list of module names that contain models / model entrypoints
|
||||
"""
|
||||
modules = _module_to_models.keys()
|
||||
return list(sorted(modules))
|
||||
|
||||
|
||||
def is_model_in_modules(model_name, module_names):
|
||||
"""Check if a model exists within a subset of modules
|
||||
Args:
|
||||
model_name (str) - name of model to check
|
||||
module_names (tuple, list, set) - names of modules to search in
|
||||
"""
|
||||
assert isinstance(module_names, (tuple, list, set))
|
||||
return any(model_name in _module_to_models[n] for n in module_names)
|
||||
|
Loading…
Reference in new issue