|
|
@ -243,7 +243,8 @@ class CrossViT(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
num_patches = _compute_num_patches(img_size, patch_size)
|
|
|
|
num_patches = _compute_num_patches(img_size, patch_size)
|
|
|
|
self.num_branches = len(patch_size)
|
|
|
|
self.num_branches = len(patch_size)
|
|
|
|
|
|
|
|
self.embed_dim = embed_dim
|
|
|
|
|
|
|
|
self.num_features = embed_dim[0] # to pass the tests
|
|
|
|
self.patch_embed = nn.ModuleList()
|
|
|
|
self.patch_embed = nn.ModuleList()
|
|
|
|
|
|
|
|
|
|
|
|
# hard-coded for torch jit script
|
|
|
|
# hard-coded for torch jit script
|
|
|
@ -274,7 +275,6 @@ class CrossViT(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
for i in range(self.num_branches):
|
|
|
|
for i in range(self.num_branches):
|
|
|
|
if hasattr(self, f'pos_embed_{i}'):
|
|
|
|
if hasattr(self, f'pos_embed_{i}'):
|
|
|
|
# if self.pos_embed[i].requires_grad:
|
|
|
|
|
|
|
|
trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
|
|
|
|
trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
|
|
|
|
trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02)
|
|
|
|
trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02)
|
|
|
|
|
|
|
|
|
|
|
@ -301,7 +301,7 @@ class CrossViT(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
def reset_classifier(self, num_classes, global_pool=''):
|
|
|
|
def reset_classifier(self, num_classes, global_pool=''):
|
|
|
|
self.num_classes = num_classes
|
|
|
|
self.num_classes = num_classes
|
|
|
|
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
|
|
|
self.head = nn.ModuleList([nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in range(self.num_branches)])
|
|
|
|
|
|
|
|
|
|
|
|
def forward_features(self, x):
|
|
|
|
def forward_features(self, x):
|
|
|
|
B, C, H, W = x.shape
|
|
|
|
B, C, H, W = x.shape
|
|
|
|