@@ -383,6 +383,8 @@ class DaViT(nn.Module):
             img_size=224,
             num_classes=1000,
-            global_pool='avg'
+            global_pool='avg',
+            features_only=False,
     ):
         super().__init__()
@@ -398,7 +399,8 @@ class DaViT(nn.Module):
         self.num_features = embed_dims[-1]
         self.drop_rate = drop_rate
         self.grad_checkpointing = False
+        self.features_only = features_only
 
         self.patch_embeds = nn.ModuleList([
             PatchEmbed(patch_size=patch_size if i == 0 else 2,
@@ -567,14 +568,19 @@ class DaViT(nn.Module):
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
         return self.head(x, pre_logits=pre_logits)
 
-    def forward(self, x):
+    def forward_classification(self, x):
         x = self.forward_features(x)
         x = self.forward_head(x)
         return x
 
+    def forward(self, x):
+        if self.features_only:
+            return self.forward_features_full(x)
+        else:
+            return self.forward_classification(x)
 
 
 def checkpoint_filter_fn(state_dict, model):
     """ Remap MSFT checkpoints -> timm """
     if 'head.norm.weight' in state_dict:
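
For reviewers, a minimal usage sketch of the new flag once the patch is applied. The import path, constructor defaults, and the return structure of forward_features_full (which is defined elsewhere in this PR, not in the hunks above) are assumptions; the diff only establishes that forward() dispatches on self.features_only.

import torch
from timm.models.davit import DaViT  # assumed import path for the patched class

x = torch.randn(1, 3, 224, 224)

# Default path: features_only=False, so forward() dispatches to
# forward_classification() and returns pooled logits.
clf = DaViT(num_classes=1000)
logits = clf(x)  # expected shape: (1, 1000)

# Feature path: features_only=True, so forward() dispatches to
# forward_features_full(); its return structure is defined elsewhere
# in this PR (assumed here: per-stage feature maps).
backbone = DaViT(num_classes=1000, features_only=True)
feats = backbone(x)

Note that this branches inside forward() on an instance attribute, whereas timm's create_model(..., features_only=True) normally handles feature extraction by wrapping the model (e.g. FeatureListNet) instead; the flag keeps the change local to DaViT.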