diff --git a/tests/test_models.py b/tests/test_models.py index f06ddd95..6489892c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -202,13 +202,15 @@ def test_model_default_cfgs_non_std(model_name, batch_size): pytest.skip("Fixed input size model > limit.") input_tensor = torch.randn((batch_size, *input_size)) + feat_dim = getattr(model, 'feature_dim', None) outputs = model.forward_features(input_tensor) if isinstance(outputs, (tuple, list)): # cannot currently verify multi-tensor output. pass else: - feat_dim = -1 if outputs.ndim == 3 else 1 + if feat_dim is None: + feat_dim = -1 if outputs.ndim == 3 else 1 assert outputs.shape[feat_dim] == model.num_features # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features @@ -216,14 +218,16 @@ def test_model_default_cfgs_non_std(model_name, batch_size): outputs = model.forward(input_tensor) if isinstance(outputs, (tuple, list)): outputs = outputs[0] - feat_dim = -1 if outputs.ndim == 3 else 1 + if feat_dim is None: + feat_dim = -1 if outputs.ndim == 3 else 1 assert outputs.shape[feat_dim] == model.num_features, 'pooled num_features != config' model = create_model(model_name, pretrained=False, num_classes=0).eval() outputs = model.forward(input_tensor) if isinstance(outputs, (tuple, list)): outputs = outputs[0] - feat_dim = -1 if outputs.ndim == 3 else 1 + if feat_dim is None: + feat_dim = -1 if outputs.ndim == 3 else 1 assert outputs.shape[feat_dim] == model.num_features # check classifier name matches default_cfg diff --git a/timm/models/sequencer.py b/timm/models/sequencer.py index 5fff04d1..b1ae92a4 100644 --- a/timm/models/sequencer.py +++ b/timm/models/sequencer.py @@ -288,6 +288,7 @@ class Sequencer2D(nn.Module): self.num_classes = num_classes self.global_pool = global_pool self.num_features = embed_dims[-1] # num_features for consistency with other models + self.feature_dim = -1 # channel dim index for feature outputs (rank 4, NHWC) self.embed_dims = embed_dims self.stem = PatchEmbed( img_size=img_size, patch_size=patch_sizes[0], in_chans=in_chans, @@ -333,7 +334,7 @@ class Sequencer2D(nn.Module): def reset_classifier(self, num_classes, global_pool=None): self.num_classes = num_classes - if self.global_pool is not None: + if global_pool is not None: assert global_pool in ('', 'avg') self.global_pool = global_pool self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()