Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 447dd0d23f
commit 4c8c7faa12

@ -792,34 +792,36 @@ class DaViT(nn.Module):
def forward_network(self, x: Tensor) -> List[Tensor]:
    """Run ``x`` through every stage and collect the per-stage outputs.

    Args:
        x: Input tensor (presumably NCHW — TODO confirm against the stem).

    Returns:
        Non-normalized pyramid features, one tensor per stage. The last
        stage's output is not appended a second time, since ``features[-1]``
        already holds it after the in-place update.
    """
    features = [x]
    for stage in self.stages:
        features[-1] = stage(features[-1])
        # Don't append the output of the last stage; it is already in place.
        if len(features) < self.num_stages:
            features.append(features[-1])
    return features
def forward_pyramid_features(self, x) -> List[Tensor]:
    """Return the raw multi-scale pyramid features for input ``x``.

    Args:
        x: Input tensor, forwarded unchanged to ``forward_network``.

    Returns:
        The list of per-stage feature tensors exactly as produced by
        ``forward_network`` — no spatial reshape or normalization applied.
    """
    # NOTE(review): an earlier revision reshaped each level to NCHW here;
    # that code was disabled, so levels are returned in network-native form.
    return self.forward_network(x)
def forward_features(self, x):
    """Run the full network and normalize the final stage's features.

    Args:
        x: Input tensor, forwarded unchanged to ``forward_network``.

    Returns:
        The last pyramid level passed through ``self.norms``. The previous
        view/permute to NCHW was disabled, so the tensor keeps the layout
        produced by the final stage.
    """
    features = self.forward_network(x)
    # Only the deepest level feeds the head; apply the output norm to it.
    return self.norms(features[-1])
def forward_head(self, x, pre_logits: bool = False):

Loading…
Cancel
Save