diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py index f9296b74..ff529064 100644 --- a/timm/models/crossvit.py +++ b/timm/models/crossvit.py @@ -243,7 +243,8 @@ class CrossViT(nn.Module): num_patches = _compute_num_patches(img_size, patch_size) self.num_branches = len(patch_size) - + self.embed_dim = embed_dim + self.num_features = embed_dim[0] # to pass the tests self.patch_embed = nn.ModuleList() # hard-coded for torch jit script @@ -274,7 +275,6 @@ class CrossViT(nn.Module): for i in range(self.num_branches): if hasattr(self, f'pos_embed_{i}'): - # if self.pos_embed[i].requires_grad: trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02) trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02) @@ -301,7 +301,7 @@ class CrossViT(nn.Module): def reset_classifier(self, num_classes, global_pool=''): self.num_classes = num_classes - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.ModuleList([nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in range(self.num_branches)]) def forward_features(self, x): B, C, H, W = x.shape