Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent e643c36488
commit e7c5ab9d1e

@ -485,7 +485,7 @@ class DaViT(nn.Module):
@torch.jit.ignore
def _update_forward_fn(self):
if self._features_only == True:
self.forward = self.forward_features_full
self.forward = self.forward_pyramid_features
else:
self.forward = self.forward_classification
@ -527,7 +527,7 @@ class DaViT(nn.Module):
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
def forward_features_full(self, x):
def forward_network(self, x):
#x, size = self.patch_embeds[0](x, (x.size(2), x.size(3)))
size: Tuple[int, int] = (x.size(2), x.size(3))
features = [x]
@ -601,9 +601,18 @@ class DaViT(nn.Module):
# non-normalized pyramid features + corresponding sizes
return features[1:], sizes[:-1]
def forward_pyramid_features(self, x):
x, sizes = self.forward_network(x)
outs = []
for i, out in enumerate(x):
H, W = sizes[i]
outs.append(x.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous())
return outs
def forward_features(self, x):
x, sizes = self.forward_features_full(x)
x, sizes = self.forward_network(x)
# take final feature and norm
x = self.norms(x[-1])
H, W = sizes[-1]

Loading…
Cancel
Save