@ -147,6 +147,15 @@ def test_model_default_cfgs(model_name, batch_size):
# FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ
assert outputs . shape [ - 1 ] == pool_size [ - 1 ] and outputs . shape [ - 2 ] == pool_size [ - 2 ]
if ' pruned ' not in model_name : # FIXME better pruned model handling
# test classifier + global pool deletion via __init__
model = create_model ( model_name , pretrained = False , num_classes = 0 , global_pool = ' ' ) . eval ( )
outputs = model . forward ( input_tensor )
assert len ( outputs . shape ) == 4
if not isinstance ( model , timm . models . MobileNetV3 ) and not isinstance ( model , timm . models . GhostNet ) :
# FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ
assert outputs . shape [ - 1 ] == pool_size [ - 1 ] and outputs . shape [ - 2 ] == pool_size [ - 2 ]
# check classifier name matches default_cfg
classifier = cfg [ ' classifier ' ]
if not isinstance ( classifier , ( tuple , list ) ) :
@ -193,6 +202,13 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
assert len ( outputs . shape ) == 2
assert outputs . shape [ 1 ] == model . num_features
model = create_model ( model_name , pretrained = False , num_classes = 0 ) . eval ( )
outputs = model . forward ( input_tensor )
if isinstance ( outputs , tuple ) :
outputs = outputs [ 0 ]
assert len ( outputs . shape ) == 2
assert outputs . shape [ 1 ] == model . num_features
# check classifier name matches default_cfg
classifier = cfg [ ' classifier ' ]
if not isinstance ( classifier , ( tuple , list ) ) :
@ -217,6 +233,7 @@ if 'GITHUB_ACTIONS' not in os.environ:
""" Create that pretrained weights load, verify support for in_chans != 3 while doing so. """
in_chans = 3 if ' pruned ' in model_name else 1 # pruning not currently supported with in_chans change
create_model ( model_name , pretrained = True , in_chans = in_chans , num_classes = 5 )
create_model ( model_name , pretrained = True , in_chans = in_chans , num_classes = 0 )
@pytest.mark.timeout ( 120 )
@pytest.mark.parametrize ( ' model_name ' , list_models ( pretrained = True , exclude_filters = NON_STD_FILTERS ) )