diff --git a/tests/test_models.py b/tests/test_models.py index 94744483..0f9b8c0b 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -27,7 +27,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): NON_STD_FILTERS = [ 'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*', - 'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*'] + 'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*'] NUM_NON_STD = len(NON_STD_FILTERS) # exclude models that cause specific test failures diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py index 0a54f7fe..2da323cf 100644 --- a/timm/models/efficientformer.py +++ b/timm/models/efficientformer.py @@ -26,7 +26,7 @@ from .registry import register_model def _cfg(url='', **kwargs): return { 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'fixed_input_size': True, 'crop_pct': .95, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'stem.conv1', 'classifier': 'head', diff --git a/timm/models/gcvit.py b/timm/models/gcvit.py index 3e2cd96a..c134b7c2 100644 --- a/timm/models/gcvit.py +++ b/timm/models/gcvit.py @@ -209,8 +209,7 @@ class WindowAttentionGlobal(nn.Module): def forward(self, x, q_global: Optional[torch.Tensor] = None): B, N, C = x.shape - if self.use_global: - _assert(q_global is not None, 'q_global must be passed in global mode') + if self.use_global and q_global is not None: _assert(x.shape[-1] == q_global.shape[-1], 'x and q_global seq lengths should be equal') kv = self.qkv(x) diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py index 551a8325..1f698fbc 100644 --- a/timm/models/pvt_v2.py +++ b/timm/models/pvt_v2.py @@ -286,7 +286,6 @@ class PyramidVisionTransformerV2(nn.Module): self.num_classes = num_classes assert global_pool in ('avg', '') self.global_pool = global_pool - self.img_size = to_2tuple(img_size) if img_size is not None else None self.depths = depths num_stages = len(depths) mlp_ratios = to_ntuple(num_stages)(mlp_ratios) @@ -324,7 +323,8 @@ class PyramidVisionTransformerV2(nn.Module): cur += depths[i] # classification head - self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() + self.num_features = embed_dims[-1] + self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() self.apply(self._init_weights)