Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 007be8319b
commit eef8673a93

@ -456,6 +456,13 @@ class DaViT(nn.Module):
self.norms = norm_layer(self.num_features)
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
self.apply(self._init_weights)
if self.features_only == True:
self.forward = self.forward_features_full
else:
self.forward = self.forward_classification
def _init_weights(self, m):
if isinstance(m, nn.Linear):
@ -574,12 +581,13 @@ class DaViT(nn.Module):
x = self.forward_features(x)
x = self.forward_head(x)
return x
'''
def forward(self, x):
if self.features_only == True:
return self.forward_features_full(x)
else:
return self.forward_classification(x)
'''
def checkpoint_filter_fn(state_dict, model):
""" Remap MSFT checkpoints -> timm """

Loading…
Cancel
Save