From 5db74521736f6d6caef4e0cd8aba1ad624ae9390 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 25 May 2021 14:11:36 -0700
Subject: [PATCH] Fix visformer in_chans stem handling

---
 tests/test_models.py | 2 +-
 timm/models/visformer.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index de664068..44cb3ba2 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -190,7 +190,7 @@ EXCLUDE_JIT_FILTERS = [
 def test_model_forward_torchscript(model_name, batch_size):
     """Run a single forward pass with each model"""
     input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
-    if max(input_size) > MAX_JIT_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional
+    if max(input_size) > MAX_JIT_SIZE:
         pytest.skip("Fixed input size model > limit.")
 
     with set_scriptable(True):
diff --git a/timm/models/visformer.py b/timm/models/visformer.py
index 936f1ddf..33a2fe87 100644
--- a/timm/models/visformer.py
+++ b/timm/models/visformer.py
@@ -26,7 +26,7 @@ def _cfg(url='', **kwargs):
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
         'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'patch_embed.proj', 'classifier': 'head',
+        'first_conv': 'stem.0', 'classifier': 'head',
         **kwargs
     }
 
@@ -183,7 +183,7 @@ class Visformer(nn.Module):
                 img_size //= 8
             else:
                 self.stem = nn.Sequential(
-                    nn.Conv2d(3, self.init_channels, 7, stride=2, padding=3, bias=False),
+                    nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False),
                     nn.BatchNorm2d(self.init_channels),
                     nn.ReLU(inplace=True)
                 )