@@ -241,7 +241,7 @@ class HybridEmbed(nn.Module):
                 training = backbone.training
                 if training:
                     backbone.eval()
-                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
+                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]), return_featuremaps=True)  # it works with osnet
                 if isinstance(o, (list, tuple)):
                     o = o[-1]  # last feature if backbone outputs list/tuple of features
                 feature_size = o.shape[-2:]
@@ -257,7 +257,7 @@ class HybridEmbed(nn.Module):
         self.proj = nn.Conv2d(feature_dim, embed_dim, 1)
 
     def forward(self, x):
-        x = self.backbone(x)
+        x = self.backbone(x, return_featuremaps=True)  # it works with osnet
         if isinstance(x, (list, tuple)):
             x = x[-1]  # last feature if backbone outputs list/tuple of features
         x = self.proj(x).flatten(2).transpose(1, 2)
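
Both HybridEmbed changes above assume a backbone whose forward accepts a return_featuremaps flag and returns a 4D spatial feature map (as torchreid's OSNet does). Below is a minimal sketch of how the patched class consumes such a backbone; ToyFeatureBackbone is a hypothetical stand-in used only to illustrate the call signature, and HybridEmbed is assumed importable from the module being patched.

import torch
import torch.nn as nn

class ToyFeatureBackbone(nn.Module):
    # Hypothetical stand-in for an OSNet-style CNN: forward(x, return_featuremaps=True)
    # yields the spatial feature map instead of pooled re-ID features.
    def __init__(self, out_dim=512):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, out_dim, kernel_size=3, stride=16, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, return_featuremaps=False):
        feats = self.stem(x)                 # (B, out_dim, H/16, W/16)
        return feats if return_featuremaps else feats.mean(dim=(2, 3))

backbone = ToyFeatureBackbone(out_dim=512)
# HybridEmbed is the class patched above; feature_size is left as None so the
# dummy forward pass in __init__ infers the (16, 8) feature-map size.
embed = HybridEmbed(backbone, img_size=(256, 128), in_chans=3, embed_dim=768)
tokens = embed(torch.randn(2, 3, 256, 128))  # proj -> flatten(2) -> transpose(1, 2)
print(tokens.shape)                          # (2, 128, 768): 16*8 spatial positions become tokens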
@@ -299,7 +299,7 @@ class VisionTransformer(nn.Module):
         if hybrid_backbone is not None:
             self.patch_embed = HybridEmbed(
-                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
+                hybrid_backbone, img_size=img_size, feature_size=None, in_chans=in_chans, embed_dim=embed_dim)
         else:
             self.patch_embed = PatchEmbed(
                 img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
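
This hunk only makes the default feature_size=None explicit, so HybridEmbed still infers the feature-map size via the dummy forward pass. A hedged construction sketch follows, reusing the toy backbone from the previous example; the keyword names are taken from the lines in this diff, and all remaining VisionTransformer arguments are assumed to keep their defaults.

# Hybrid path: the CNN supplies the feature map and HybridEmbed turns it into tokens.
cnn = ToyFeatureBackbone(out_dim=512)
vit_hybrid = VisionTransformer(img_size=(256, 128), hybrid_backbone=cnn, embed_dim=768)

# Plain path: no backbone, so PatchEmbed slices the image into patch_size x patch_size tokens.
vit_plain = VisionTransformer(img_size=(256, 128), patch_size=16, embed_dim=768)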
@@ -322,7 +322,7 @@ class VisionTransformer(nn.Module):
             self.num_features = representation_size
             self.pre_logits = nn.Sequential(OrderedDict([
                 ('fc', nn.Linear(embed_dim, representation_size)),
-                ('act', nn.Tanh())
+                ('act', nn.Identity())  # ('act', nn.Tanh())
             ]))
         else:
             self.pre_logits = nn.Identity()
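
The last hunk removes the Tanh squashing from the optional representation head, leaving pre_logits as a plain linear projection. A small standalone sketch of the resulting head; the 768/512 sizes are illustrative, not taken from the diff.

import torch
import torch.nn as nn
from collections import OrderedDict

embed_dim, representation_size = 768, 512    # illustrative sizes

# Patched head: a linear projection followed by Identity, i.e. no non-linearity.
pre_logits = nn.Sequential(OrderedDict([
    ('fc', nn.Linear(embed_dim, representation_size)),
    ('act', nn.Identity()),                  # was ('act', nn.Tanh()) before this change
]))

cls_token = torch.randn(4, embed_dim)        # e.g. the class token for a batch of 4 images
features = pre_logits(cls_token)             # (4, 512), values no longer squashed into (-1, 1)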