@@ -485,7 +485,7 @@ class DaViT(nn.Module):
     @torch.jit.ignore
     def _update_forward_fn(self):
         if self._features_only == True:
-            self.forward = self.forward_features_full
+            self.forward = self.forward_pyramid_features
         else:
             self.forward = self.forward_classification
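
Note on the hunk above: `_update_forward_fn` rebinds `forward` on the instance, so the same model object returns either classification output or a list of pyramid feature maps depending on `self._features_only`; the `@torch.jit.ignore` keeps TorchScript from attempting to script the rebinding itself. A minimal sketch of the pattern, assuming a toy module (`Toy` and its stand-in forward variants are illustrative names, not part of this PR):

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    """Illustrative module using the same forward-rebinding trick."""

    def __init__(self, features_only: bool = False):
        super().__init__()
        self._features_only = features_only
        self._update_forward_fn()

    @torch.jit.ignore
    def _update_forward_fn(self):
        # Rebinding an instance attribute shadows the class-level forward,
        # so nn.Module.__call__ dispatches to whichever variant is active.
        if self._features_only:
            self.forward = self.forward_pyramid_features
        else:
            self.forward = self.forward_classification

    def forward_pyramid_features(self, x):
        return [x]           # stand-in for the per-stage feature maps

    def forward_classification(self, x):
        return x.sum(dim=1)  # stand-in for logits

print(type(Toy(features_only=True)(torch.randn(2, 3))))   # <class 'list'>
print(Toy(features_only=False)(torch.randn(2, 3)).shape)  # torch.Size([2])
```
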
@@ -527,7 +527,7 @@ class DaViT(nn.Module):
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

-    def forward_features_full(self, x):
+    def forward_network(self, x):
         #x, size = self.patch_embeds[0](x, (x.size(2), x.size(3)))
         size: Tuple[int, int] = (x.size(2), x.size(3))
         features = [x]
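
The renamed `forward_network` keeps the spatial size as an explicit `Tuple[int, int]` (the annotation keeps the tuple type stable under TorchScript) because the attention stages consume flattened `(B, N, C)` tokens, from which H and W can no longer be read off the tensor. A quick round trip showing why the size has to travel alongside the tokens (shapes here are illustrative):

```python
import torch

x = torch.randn(2, 96, 56, 56)           # (B, C, H, W), e.g. after a patch embed
size = (x.size(2), x.size(3))            # record H, W before flattening
tokens = x.flatten(2).transpose(1, 2)    # (B, H*W, C) as seen by the attention blocks
H, W = size
restored = tokens.transpose(1, 2).reshape(2, 96, H, W)
assert torch.equal(x, restored)          # lossless round trip given (H, W)
```
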
@@ -601,9 +601,18 @@ class DaViT(nn.Module):
+        # non-normalized pyramid features + corresponding sizes
         return features[1:], sizes[:-1]

+    def forward_pyramid_features(self, x):
+        x, sizes = self.forward_network(x)
+        outs = []
+        for i, out in enumerate(x):
+            H, W = sizes[i]
+            outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous())
+        return outs
+
     def forward_features(self, x):
-        x, sizes = self.forward_features_full(x)
+        x, sizes = self.forward_network(x)
         # take final feature and norm
         x = self.norms(x[-1])
         H, W = sizes[-1]
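
`forward_pyramid_features` restores each stage's token sequence to an NCHW map via `view(-1, H, W, C).permute(0, 3, 1, 2)`, with `.contiguous()` so downstream dense heads get a contiguous tensor. A small self-check of that reshape, with illustrative dimensions:

```python
import torch

B, H, W, C = 2, 14, 14, 384
tokens = torch.randn(B, H * W, C)  # one stage's output: (B, N, C) with N = H*W
fmap = tokens.view(-1, H, W, C).permute(0, 3, 1, 2).contiguous()
assert fmap.shape == (B, C, H, W)
# Row-major layout means token n maps to pixel (n // W, n % W):
assert torch.equal(fmap[0, :, 1, 2], tokens[0, 1 * W + 2])
```

For the usual four-stage DaViT configuration this yields one map per stage, which is the shape convention `features_only` consumers expect.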