Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 447dd0d23f
commit 4c8c7faa12

@ -792,34 +792,36 @@ class DaViT(nn.Module):
def forward_network(self, x : Tensor): def forward_network(self, x : Tensor):
size: Tuple[int, int] = (x.size(2), x.size(3)) size: Tuple[int, int] = (x.size(2), x.size(3))
features = [x] features = [x]
sizes = [size] #sizes = [size]
for stage in self.stages: for stage in self.stages:
features[-1], sizes[-1] = stage(features[-1], sizes[-1]) features[-1] = stage(features[-1])
# don't append outputs of last stage, since they are already there # don't append outputs of last stage, since they are already there
if(len(features) < self.num_stages): if(len(features) < self.num_stages):
features.append(features[-1]) features.append(features[-1])
sizes.append(sizes[-1])
# non-normalized pyramid features + corresponding sizes # non-normalized pyramid features + corresponding sizes
return features, sizes return features
def forward_pyramid_features(self, x) -> List[Tensor]: def forward_pyramid_features(self, x) -> List[Tensor]:
x, sizes = self.forward_network(x) x = self.forward_network(x)
'''
outs = [] outs = []
for i, out in enumerate(x): for i, out in enumerate(x):
H, W = sizes[i] H, W = sizes[i]
outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous()) outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous())
return outs '''
return x
def forward_features(self, x): def forward_features(self, x):
x, sizes = self.forward_network(x) x = self.forward_network(x)
# take final feature and norm # take final feature and norm
x = self.norms(x[-1]) x = self.norms(x[-1])
H, W = sizes[-1] #H, W = sizes[-1]
x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous() #x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
return x return x
def forward_head(self, x, pre_logits: bool = False): def forward_head(self, x, pre_logits: bool = False):

Loading…
Cancel
Save