diff --git a/tests/test_models.py b/tests/test_models.py index bad2a78c..57d78a8e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -188,25 +188,22 @@ def test_model_default_cfgs_non_std(model_name, batch_size): input_tensor = torch.randn((batch_size, *input_size)) - # test forward_features (always unpooled) - if 'crossvit' not in model_name: - # FIXME remove crossvit exception - outputs = model.forward_features(input_tensor) - if isinstance(outputs, tuple): - outputs = outputs[0] - assert outputs.shape[1] == model.num_features + outputs = model.forward_features(input_tensor) + if isinstance(outputs, (tuple, list)): + outputs = outputs[0] + assert outputs.shape[1] == model.num_features # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features model.reset_classifier(0) outputs = model.forward(input_tensor) - if isinstance(outputs, tuple): + if isinstance(outputs, (tuple, list)): outputs = outputs[0] assert len(outputs.shape) == 2 assert outputs.shape[1] == model.num_features model = create_model(model_name, pretrained=False, num_classes=0).eval() outputs = model.forward(input_tensor) - if isinstance(outputs, tuple): + if isinstance(outputs, (tuple, list)): outputs = outputs[0] assert len(outputs.shape) == 2 assert outputs.shape[1] == model.num_features diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py index 9eee9dee..12eebdc5 100644 --- a/timm/models/crossvit.py +++ b/timm/models/crossvit.py @@ -268,12 +268,9 @@ class CrossViT(nn.Module): super().__init__() self.num_classes = num_classes - if not isinstance(img_size, (tuple, list)): - img_size = to_2tuple(img_size) - self.img_size = img_size - if not isinstance(img_scale, (tuple, list)): - img_scale = to_2tuple(img_scale) - self.img_size_scaled = [tuple([int(sj * si) for sj in img_size]) for si in img_scale] + self.img_size = to_2tuple(img_size) + img_scale = to_2tuple(img_scale) + self.img_size_scaled = [tuple([int(sj * si) for sj in self.img_size]) for si in img_scale] num_patches = _compute_num_patches(self.img_size_scaled, patch_size) self.num_branches = len(patch_size) self.embed_dim = embed_dim @@ -346,7 +343,7 @@ class CrossViT(nn.Module): xs = [] for i, patch_embed in enumerate(self.patch_embed): ss = self.img_size_scaled[i] - x_ = torch.nn.functional.interpolate(x, size=ss, mode='bicubic') if H != ss[0] else x + x_ = torch.nn.functional.interpolate(x, size=ss, mode='bicubic', align_corners=False) if H != ss[0] else x tmp = patch_embed(x_) cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1 # hard-coded for torch jit script cls_tokens = cls_tokens.expand(B, -1, -1) @@ -361,15 +358,12 @@ class CrossViT(nn.Module): # NOTE: was before branch token section, move to here to assure all branch token are before layer norm xs = [norm(xs[i]) for i, norm in enumerate(self.norm)] - return tuple([x[:, 0] for x in xs]) + return [x[:, 0] for x in xs] def forward(self, x): xs = self.forward_features(x) ce_logits = [head(xs[i]) for i, head in enumerate(self.head)] - if isinstance(self.head[0], nn.Identity): - # FIXME to pass current passthrough features tests, could use better approach - ce_logits = tuple(ce_logits) - else: + if not isinstance(self.head[0], nn.Identity): ce_logits = torch.mean(torch.stack(ce_logits, dim=0), dim=0) return ce_logits