Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent f7a8fb9f97
commit 3b993a9301

@ -383,6 +383,7 @@ class DaViT(nn.Module):
img_size=224,
num_classes=1000,
global_pool='avg'
features_only = False
):
super().__init__()
@ -398,7 +399,7 @@ class DaViT(nn.Module):
self.num_features = embed_dims[-1]
self.drop_rate=drop_rate
self.grad_checkpointing = False
self.features_only = False
self.patch_embeds = nn.ModuleList([
PatchEmbed(patch_size=patch_size if i == 0 else 2,
@ -567,13 +568,18 @@ class DaViT(nn.Module):
return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
def forward(self, x):
def forward_classification(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def forward(self, x):
if self.features_only == True:
return forward_features_full(x)
else:
return forward(x)
def checkpoint_filter_fn(state_dict, model):
""" Remap MSFT checkpoints -> timm """

Loading…
Cancel
Save