From 51febd869b4afa833838ad33a5e7e1781c70095c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 29 Mar 2021 11:33:08 -0700 Subject: [PATCH] Small tweak to tests for tnt model, reorder model imports. --- tests/test_models.py | 3 ++- timm/models/__init__.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 8e2b674e..1f70d115 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -15,6 +15,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): # transformer models don't support many of the spatial / feature based model functionalities NON_STD_FILTERS = ['vit_*', 'tnt_*'] +NUM_NON_STD = len(NON_STD_FILTERS) # exclude models that cause specific test failures if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system(): @@ -31,7 +32,7 @@ MAX_FWD_FEAT_SIZE = 448 @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS[:-1])) +@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS[:-NUM_NON_STD])) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward(model_name, batch_size): """Run a single forward pass with each model""" diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 08e2b59a..ab3c4b2f 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -6,6 +6,7 @@ from .dpn import * from .efficientnet import * from .gluon_resnet import * from .gluon_xception import * +from .hardcorenas import * from .hrnet import * from .inception_resnet_v2 import * from .inception_v3 import * @@ -23,14 +24,13 @@ from .rexnet import * from .selecsls import * from .senet import * from .sknet import * +from .tnt import * from .tresnet import * from .vgg import * from .vision_transformer import * from .vovnet import * from .xception import * from .xception_aligned import * -from .hardcorenas import * -from .tnt import * from .factory import create_model, split_model_name, safe_model_name from .helpers import load_checkpoint, resume_checkpoint, model_parameters