fix bug for reset classifier and fix for validating the dimension

pull/841/head
Richard Chen 3 years ago
parent 3718c5a5bd
commit 9fe5798bee

@ -243,7 +243,8 @@ class CrossViT(nn.Module):
num_patches = _compute_num_patches(img_size, patch_size)
self.num_branches = len(patch_size)
self.embed_dim = embed_dim
self.num_features = embed_dim[0] # to pass the tests
self.patch_embed = nn.ModuleList()
# hard-coded for torch jit script
@ -274,7 +275,6 @@ class CrossViT(nn.Module):
for i in range(self.num_branches):
if hasattr(self, f'pos_embed_{i}'):
# if self.pos_embed[i].requires_grad:
trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02)
@ -301,7 +301,7 @@ class CrossViT(nn.Module):
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.head = nn.ModuleList([nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in range(self.num_branches)])
def forward_features(self, x):
B, C, H, W = x.shape

Loading…
Cancel
Save