editing hybrid backbone for osnet specific architecture | to be generalized

pull/434/head
szingaro 4 years ago
parent 7987f0c83d
commit 203219f906

@ -241,7 +241,7 @@ class HybridEmbed(nn.Module):
training = backbone.training training = backbone.training
if training: if training:
backbone.eval() backbone.eval()
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1])) o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]), return_featuremaps=True) # it works with osnet
if isinstance(o, (list, tuple)): if isinstance(o, (list, tuple)):
o = o[-1] # last feature if backbone outputs list/tuple of features o = o[-1] # last feature if backbone outputs list/tuple of features
feature_size = o.shape[-2:] feature_size = o.shape[-2:]
@ -257,7 +257,7 @@ class HybridEmbed(nn.Module):
self.proj = nn.Conv2d(feature_dim, embed_dim, 1) self.proj = nn.Conv2d(feature_dim, embed_dim, 1)
def forward(self, x): def forward(self, x):
x = self.backbone(x) x = self.backbone(x, return_featuremaps=True) # it works with osnet
if isinstance(x, (list, tuple)): if isinstance(x, (list, tuple)):
x = x[-1] # last feature if backbone outputs list/tuple of features x = x[-1] # last feature if backbone outputs list/tuple of features
x = self.proj(x).flatten(2).transpose(1, 2) x = self.proj(x).flatten(2).transpose(1, 2)
@ -299,7 +299,7 @@ class VisionTransformer(nn.Module):
if hybrid_backbone is not None: if hybrid_backbone is not None:
self.patch_embed = HybridEmbed( self.patch_embed = HybridEmbed(
hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) hybrid_backbone, img_size=img_size, feature_size=None , in_chans=in_chans, embed_dim=embed_dim)
else: else:
self.patch_embed = PatchEmbed( self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
@ -322,7 +322,7 @@ class VisionTransformer(nn.Module):
self.num_features = representation_size self.num_features = representation_size
self.pre_logits = nn.Sequential(OrderedDict([ self.pre_logits = nn.Sequential(OrderedDict([
('fc', nn.Linear(embed_dim, representation_size)), ('fc', nn.Linear(embed_dim, representation_size)),
('act', nn.Tanh()) ('act', nn.Identity()) #('act', nn.Tanh())
])) ]))
else: else:
self.pre_logits = nn.Identity() self.pre_logits = nn.Identity()

Loading…
Cancel
Save