Update davit.py

pull/1630/head
Fredo Guan · 3 years ago
parent f7a8fb9f97
commit 3b993a9301

@@ -383,6 +383,7 @@ class DaViT(nn.Module):
             img_size=224,
             num_classes=1000,
             global_pool='avg',
+            features_only=False,
     ):
         super().__init__()
@@ -398,7 +399,8 @@ class DaViT(nn.Module):
         self.num_features = embed_dims[-1]
         self.drop_rate=drop_rate
         self.grad_checkpointing = False
+        self.features_only = features_only
 
         self.patch_embeds = nn.ModuleList([
             PatchEmbed(patch_size=patch_size if i == 0 else 2,
@@ -567,14 +568,19 @@ class DaViT(nn.Module):
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
         return self.head(x, pre_logits=pre_logits)
 
-    def forward(self, x):
+    def forward_classification(self, x):
         x = self.forward_features(x)
         x = self.forward_head(x)
         return x
+    def forward(self, x):
+        if self.features_only:
+            return self.forward_features_full(x)
+        else:
+            return self.forward_classification(x)
 
 
 def checkpoint_filter_fn(state_dict, model):
     """ Remap MSFT checkpoints -> timm """
     if 'head.norm.weight' in state_dict:
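
After this change, forward() dispatches on the new features_only flag. A minimal usage sketch follows, assuming the constructor's remaining arguments all have defaults (as the diff suggests) and that forward_features_full, added elsewhere in this branch, returns intermediate per-stage feature maps; the input size and shapes below are illustrative, not taken from the diff:

import torch
from timm.models.davit import DaViT

# features_only=True routes forward() through forward_features_full(),
# which is assumed here to return per-stage feature maps.
feat_model = DaViT(features_only=True)
feat_model.eval()

x = torch.randn(1, 3, 224, 224)  # illustrative input size
with torch.no_grad():
    features = feat_model(x)

# With the default features_only=False, forward() routes through
# forward_classification(), i.e. forward_features() + forward_head(),
# and returns classification logits of shape (1, num_classes).
cls_model = DaViT(features_only=False)
cls_model.eval()
with torch.no_grad():
    logits = cls_model(x)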
