|
|
@ -493,6 +493,14 @@ class DaViT(nn.Module):
|
|
|
|
# non-normalized pyramid features + corresponding sizes
|
|
|
|
# non-normalized pyramid features + corresponding sizes
|
|
|
|
return features, sizes
|
|
|
|
return features, sizes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward_pyramid_features(self, x) -> List[Tensor]:
    """Return the per-stage outputs of ``forward_network`` as channel-first maps.

    Args:
        x: Input passed unchanged to ``self.forward_network``.

    Returns:
        One tensor per stage. Stage ``i`` is viewed as
        ``(-1, H, W, self.embed_dims[i])`` using that stage's ``(H, W)`` size
        and then permuted to ``(-1, self.embed_dims[i], H, W)``.
    """
    stage_outputs, stage_sizes = self.forward_network(x)
    pyramid = []
    for stage_idx, tokens in enumerate(stage_outputs):
        height, width = stage_sizes[stage_idx]
        # Restore the spatial grid with channels last ...
        channels_last = tokens.view(-1, height, width, self.embed_dims[stage_idx])
        # ... then move channels to dim 1 and materialise a contiguous copy.
        pyramid.append(channels_last.permute(0, 3, 1, 2).contiguous())
    return pyramid
|
|
|
|
|
|
|
|
|
|
|
|
def forward_features(self, x):
|
|
|
|
def forward_features(self, x):
|
|
|
|
x, sizes = self.forward_network(x)
|
|
|
|
x, sizes = self.forward_network(x)
|
|
|
@ -505,10 +513,13 @@ class DaViT(nn.Module):
|
|
|
|
def forward_head(self, x, pre_logits: bool = False):
    """Apply ``self.head`` to ``x``.

    Args:
        x: Input handed to the head as its positional argument.
        pre_logits: Forwarded to the head as the ``pre_logits`` keyword.

    Returns:
        Whatever ``self.head`` returns.
    """
    head = self.head
    return head(x, pre_logits=pre_logits)
|
|
|
|
|
|
|
|
|
|
|
|
def forward_classifier(self, x):
    """Run ``forward_features`` and then ``forward_head`` on ``x``.

    Args:
        x: Input passed to ``self.forward_features``.

    Returns:
        The result of ``self.forward_head`` applied to the extracted features.
    """
    features = self.forward_features(x)
    return self.forward_head(features)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
    """Default forward pass: delegate to ``forward_classifier``.

    Args:
        x: Input passed to ``self.forward_classifier``.

    Returns:
        Whatever ``self.forward_classifier`` returns for ``x``.
    """
    # ``self.forward_classifier`` is already bound, so ``self`` must not be
    # passed again; doing so supplied two positional arguments to a method
    # that accepts only ``x`` and raised TypeError on every call.
    return self.forward_classifier(x)
|
|
|
|
|
|
|
|
|
|
|
|
class DaViTFeatures(DaViT):
|
|
|
|
class DaViTFeatures(DaViT):
|
|
|
|
|
|
|
|
|
|
|
@ -517,13 +528,8 @@ class DaViTFeatures(DaViT):
|
|
|
|
self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_inices', (1, 2, 3, 4)))
|
|
|
|
self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_inices', (1, 2, 3, 4)))
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x) -> List[Tensor]:
    """Feature-extraction forward pass: delegate to ``forward_pyramid_features``.

    Args:
        x: Input passed to ``self.forward_pyramid_features``.

    Returns:
        The list of tensors produced by ``self.forward_pyramid_features``.
    """
    # The bound method already carries ``self``; passing it explicitly as
    # well raised TypeError because ``forward_pyramid_features`` accepts
    # only ``x``. The reshape loop that used to live here is now owned by
    # ``forward_pyramid_features`` and is not repeated.
    return self.forward_pyramid_features(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def checkpoint_filter_fn(state_dict, model):
|
|
|
|
def checkpoint_filter_fn(state_dict, model):
|
|
|
|