Small tweak to tests for tnt model, reorder model imports.

pull/429/head
Ross Wightman 4 years ago
parent b27a4e0d88
commit 51febd869b

@ -15,6 +15,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
# transformer models don't support many of the spatial / feature based model functionalities # transformer models don't support many of the spatial / feature based model functionalities
NON_STD_FILTERS = ['vit_*', 'tnt_*'] NON_STD_FILTERS = ['vit_*', 'tnt_*']
NUM_NON_STD = len(NON_STD_FILTERS)
# exclude models that cause specific test failures # exclude models that cause specific test failures
if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system(): if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
@ -31,7 +32,7 @@ MAX_FWD_FEAT_SIZE = 448
@pytest.mark.timeout(120) @pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS[:-1])) @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS[:-NUM_NON_STD]))
@pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('batch_size', [1])
def test_model_forward(model_name, batch_size): def test_model_forward(model_name, batch_size):
"""Run a single forward pass with each model""" """Run a single forward pass with each model"""

@ -6,6 +6,7 @@ from .dpn import *
from .efficientnet import * from .efficientnet import *
from .gluon_resnet import * from .gluon_resnet import *
from .gluon_xception import * from .gluon_xception import *
from .hardcorenas import *
from .hrnet import * from .hrnet import *
from .inception_resnet_v2 import * from .inception_resnet_v2 import *
from .inception_v3 import * from .inception_v3 import *
@ -23,14 +24,13 @@ from .rexnet import *
from .selecsls import * from .selecsls import *
from .senet import * from .senet import *
from .sknet import * from .sknet import *
from .tnt import *
from .tresnet import * from .tresnet import *
from .vgg import * from .vgg import *
from .vision_transformer import * from .vision_transformer import *
from .vovnet import * from .vovnet import *
from .xception import * from .xception import *
from .xception_aligned import * from .xception_aligned import *
from .hardcorenas import *
from .tnt import *
from .factory import create_model, split_model_name, safe_model_name from .factory import create_model, split_model_name, safe_model_name
from .helpers import load_checkpoint, resume_checkpoint, model_parameters from .helpers import load_checkpoint, resume_checkpoint, model_parameters

Loading…
Cancel
Save