Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 16bc4c078b
commit d8cce7d1e4

@ -525,8 +525,8 @@ class DaViT(nn.Module):
class DaViTFeatures(DaViT):
def __init__(self, *args, **kwargs):
super(DaViT, self).__init__(*args, **kwargs)
self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_inices', (1, 2, 3, 4)))
super().__init__(*args, **kwargs)
self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_inices', (0, 1, 2, 3)))
def forward(self, x) -> List[Tensor]:
return self.forward_pyramid_features(x)

Loading…
Cancel
Save