|
|
|
@ -519,7 +519,9 @@ class DaViT(nn.Module):
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
return self.forward_classifier(self, x)
|
|
|
|
|
x = self.forward_classifier(self, x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DaViTFeatures(DaViT):
|
|
|
|
|
|
|
|
|
@ -528,8 +530,8 @@ class DaViTFeatures(DaViT):
|
|
|
|
|
self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_inices', (1, 2, 3, 4)))
|
|
|
|
|
|
|
|
|
|
def forward(self, x) -> List[Tensor]:
|
|
|
|
|
return self.forward_pyramid_features(self, x)
|
|
|
|
|
|
|
|
|
|
x = self.forward_pyramid_features(self, x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def checkpoint_filter_fn(state_dict, model):
|
|
|
|
|