|
|
@ -443,7 +443,7 @@ class DaViT(nn.Module):
|
|
|
|
self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction = 2, module=f'stage_{stage_id}')]
|
|
|
|
self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction = 2, module=f'stage_{stage_id}')]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.norm = norm_layer(self.num_features)
|
|
|
|
self.norms = norm_layer(self.num_features)
|
|
|
|
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
|
|
|
|
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
|
|
|
|
self.apply(self._init_weights)
|
|
|
|
self.apply(self._init_weights)
|
|
|
|
|
|
|
|
|
|
|
@ -507,7 +507,7 @@ class DaViT(nn.Module):
|
|
|
|
def forward_features(self, x):
|
|
|
|
def forward_features(self, x):
|
|
|
|
x, sizes = self.forward_network(x)
|
|
|
|
x, sizes = self.forward_network(x)
|
|
|
|
# take final feature and norm
|
|
|
|
# take final feature and norm
|
|
|
|
x = self.norm(x[-1])
|
|
|
|
x = self.norms(x[-1])
|
|
|
|
H, W = sizes[-1]
|
|
|
|
H, W = sizes[-1]
|
|
|
|
x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
|
|
|
|
x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
|
|
|
|
return x
|
|
|
|
return x
|
|
|
|