@@ -457,6 +457,13 @@ class DaViT(nn.Module):
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
 
         self.apply(self._init_weights)
+
+        if self.features_only == True:
+            self.forward = self.forward_features_full
+        else:
+            self.forward = self.forward_classification
+
+
     def _init_weights(self, m):
         if isinstance(m, nn.Linear):
             trunc_normal_(m.weight, std=.02)
@@ -574,12 +581,13 @@ class DaViT(nn.Module):
         x = self.forward_features(x)
         x = self.forward_head(x)
         return x
-
+    '''
     def forward(self, x):
         if self.features_only == True:
             return self.forward_features_full(x)
         else:
             return self.forward_classification(x)
+    '''
 
 def checkpoint_filter_fn(state_dict, model):
     """ Remap MSFT checkpoints -> timm """
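For readers skimming the diff: the first hunk binds self.forward once at construction time, so the module either returns intermediate features (features_only) or classification output, and the second hunk disables the old branching def forward() by wrapping it in a triple-quoted string. Below is a minimal standalone sketch of that rebinding pattern, assuming plain PyTorch; the class TinyNet and its layers are illustrative stand-ins, not code from timm's davit.py.

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    # Illustrative only -- not timm's DaViT. Shows the pattern from the hunk above:
    # bind self.forward once in __init__ based on a features_only flag.
    def __init__(self, features_only: bool = False):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
        self.head = nn.Linear(8, 10)
        # The instance attribute shadows nn.Module.forward, so model(x) dispatches
        # to whichever method was chosen here, with no per-call branch.
        self.forward = self.forward_features if features_only else self.forward_classification

    def forward_features(self, x):
        return self.stem(x)                   # feature map, e.g. [N, 8, 16, 16]

    def forward_classification(self, x):
        x = self.stem(x).mean(dim=(2, 3))     # global average pool -> [N, 8]
        return self.head(x)                   # logits [N, 10]

feats = TinyNet(features_only=True)(torch.randn(1, 3, 32, 32))
logits = TinyNet(features_only=False)(torch.randn(1, 3, 32, 32))

Binding the method once avoids re-checking the flag on every call, but the choice is fixed at construction; switching modes afterwards means rebinding self.forward (or calling the underlying methods directly).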