diff --git a/tests/test_models.py b/tests/test_models.py
index 822e0f2f..5c79dd2e 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -1,15 +1,24 @@
 import pytest
 import torch
+import platform
+import os
+import fnmatch
 
 from timm import list_models, create_model
 
-MAX_FWD_SIZE = 320
+
+if 'GITHUB_ACTIONS' in os.environ and 'Linux' in platform.system():
+    # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
+    EXCLUDE_FILTERS = ['*efficientnet_l2*']
+else:
+    EXCLUDE_FILTERS = []
+MAX_FWD_SIZE = 384
 MAX_BWD_SIZE = 128
 MAX_FWD_FEAT_SIZE = 448
 
 
 @pytest.mark.timeout(120)
-@pytest.mark.parametrize('model_name', list_models())
+@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_forward(model_name, batch_size):
     """Run a single forward pass with each model"""
@@ -28,7 +37,8 @@ def test_model_forward(model_name, batch_size):
 
 
 @pytest.mark.timeout(120)
-@pytest.mark.parametrize('model_name', list_models(exclude_filters='dla*'))  # DLA models have an issue TBD
+# DLA models have an issue TBD, add them to exclusions
+@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + ['dla*']))
 @pytest.mark.parametrize('batch_size', [2])
 def test_model_backward(model_name, batch_size):
     """Run a single forward pass with each model"""
@@ -65,7 +75,8 @@ def test_model_default_cfgs(model_name, batch_size):
     pool_size = cfg['pool_size']
     input_size = model.default_cfg['input_size']
 
-    if all([x <= MAX_FWD_FEAT_SIZE for x in input_size]) and 'efficientnet_l2' not in model_name:
+    if all([x <= MAX_FWD_FEAT_SIZE for x in input_size]) and \
+            not any([fnmatch.fnmatch(model_name, x) for x in EXCLUDE_FILTERS]):
         # pool size only checked if default res <= 448 * 448 to keep resource down
         input_size = tuple([min(x, MAX_FWD_FEAT_SIZE) for x in input_size])
         outputs = model.forward_features(torch.randn((batch_size, *input_size)))
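
For context on how the new `EXCLUDE_FILTERS` globs behave, here is a minimal standalone sketch of the `fnmatch` check added to `test_model_default_cfgs` above (the `is_excluded` helper and the sample model names are illustrative only, not part of the patch):

```python
import fnmatch

# Illustrative sketch only: mirrors the `not any([fnmatch.fnmatch(...)])` check in the diff.
EXCLUDE_FILTERS = ['*efficientnet_l2*']

def is_excluded(model_name, exclude_filters=EXCLUDE_FILTERS):
    """Return True if model_name matches any exclusion glob pattern."""
    return any(fnmatch.fnmatch(model_name, pattern) for pattern in exclude_filters)

assert is_excluded('tf_efficientnet_l2_ns')   # contains 'efficientnet_l2', so it is skipped on CI
assert not is_excluded('efficientnet_b0')     # smaller models still run on the GitHub runner
```

The same glob patterns are passed to `list_models(exclude_filters=...)`, which presumably applies matching of this style when building the parametrized model list, so both the test parametrization and the in-test size check stay consistent.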