@ -1,5 +1,6 @@
from . registry import is_model , is_model_in_modules , model_entrypoint
from . registry import is_model , is_model_in_modules , model_entrypoint
from . helpers import load_checkpoint
from . helpers import load_checkpoint
from . layers import set_layer_config
def create_model (
def create_model (
@ -8,6 +9,9 @@ def create_model(
num_classes = 1000 ,
num_classes = 1000 ,
in_chans = 3 ,
in_chans = 3 ,
checkpoint_path = ' ' ,
checkpoint_path = ' ' ,
scriptable = None ,
exportable = None ,
no_jit = None ,
* * kwargs ) :
* * kwargs ) :
""" Create a model
""" Create a model
@ -17,13 +21,16 @@ def create_model(
num_classes ( int ) : number of classes for final fully connected layer ( default : 1000 )
num_classes ( int ) : number of classes for final fully connected layer ( default : 1000 )
in_chans ( int ) : number of input channels / colors ( default : 3 )
in_chans ( int ) : number of input channels / colors ( default : 3 )
checkpoint_path ( str ) : path of checkpoint to load after model is initialized
checkpoint_path ( str ) : path of checkpoint to load after model is initialized
scriptable ( bool ) : set layer config so that model is jit scriptable ( not working for all models yet )
exportable ( bool ) : set layer config so that model is traceable / ONNX exportable ( not fully impl / obeyed yet )
no_jit ( bool ) : set layer config so that model doesn ' t utilize jit scripted layers (so far activations only)
Keyword Args :
Keyword Args :
drop_rate ( float ) : dropout rate for training ( default : 0.0 )
drop_rate ( float ) : dropout rate for training ( default : 0.0 )
global_pool ( str ) : global pool type ( default : ' avg ' )
global_pool ( str ) : global pool type ( default : ' avg ' )
* * : other kwargs are model specific
* * : other kwargs are model specific
"""
"""
m args = dict ( pretrained = pretrained , num_classes = num_classes , in_chans = in_chans )
m odel_ args = dict ( pretrained = pretrained , num_classes = num_classes , in_chans = in_chans )
# Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args
# Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args
is_efficientnet = is_model_in_modules ( model_name , [ ' efficientnet ' , ' mobilenetv3 ' ] )
is_efficientnet = is_model_in_modules ( model_name , [ ' efficientnet ' , ' mobilenetv3 ' ] )
@ -47,11 +54,12 @@ def create_model(
if kwargs . get ( ' drop_path_rate ' , None ) is None :
if kwargs . get ( ' drop_path_rate ' , None ) is None :
kwargs . pop ( ' drop_path_rate ' , None )
kwargs . pop ( ' drop_path_rate ' , None )
if is_model ( model_name ) :
with set_layer_config ( scriptable = scriptable , exportable = exportable , no_jit = no_jit ) :
create_fn = model_entrypoint ( model_name )
if is_model ( model_name ) :
model = create_fn ( * * margs , * * kwargs )
create_fn = model_entrypoint ( model_name )
else :
model = create_fn ( * * model_args , * * kwargs )
raise RuntimeError ( ' Unknown model ( %s ) ' % model_name )
else :
raise RuntimeError ( ' Unknown model ( %s ) ' % model_name )
if checkpoint_path :
if checkpoint_path :
load_checkpoint ( model , checkpoint_path )
load_checkpoint ( model , checkpoint_path )