|
|
@ -524,8 +524,8 @@ class DaViT(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
class DaViTFeatures(DaViT):
|
|
|
|
class DaViTFeatures(DaViT):
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(*args, **kwargs):
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
super(self).__init__(*args, **kwargs)
|
|
|
|
super(DaViT, self).__init__(*args, **kwargs)
|
|
|
|
self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_inices', (1, 2, 3, 4)))
|
|
|
|
self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_inices', (1, 2, 3, 4)))
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x) -> List[Tensor]:
|
|
|
|
def forward(self, x) -> List[Tensor]:
|
|
|
|