Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 7997ddc30d
commit f6993dbf20

@ -458,13 +458,13 @@ class DaViT(nn.Module):
self.apply(self._init_weights) self.apply(self._init_weights)
#self.forward = self._get_forward_fn() self.forward = self._get_forward_fn()
'''
if self._features_only == True: if self._features_only == True:
self.forward = self.forward_features_full self.forward = self.forward_features_full
else: else:
self.forward = self.forward_classification self.forward = self.forward_classification
'''
''' '''
def _get_forward_fn(self): def _get_forward_fn(self):
@ -609,13 +609,10 @@ class DaViT(nn.Module):
x = self.forward_features(x) x = self.forward_features(x)
x = self.forward_head(x) x = self.forward_head(x)
return x return x
'''
def forward(self, x): def forward(self, x):
if self.features_only == True: return x
return self.forward_features_full(x)
else:
return self.forward_classification(x)
'''
def checkpoint_filter_fn(state_dict, model): def checkpoint_filter_fn(state_dict, model):
""" Remap MSFT checkpoints -> timm """ """ Remap MSFT checkpoints -> timm """

Loading…
Cancel
Save