fixing master bug on hardcoded osnet architecture

pull/434/head
szingaro 4 years ago
parent 7a2076ae04
commit d8131d5d47

@ -241,7 +241,7 @@ class HybridEmbed(nn.Module):
training = backbone.training training = backbone.training
if training: if training:
backbone.eval() backbone.eval()
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]), return_featuremaps=True) # it works with osnet o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
if isinstance(o, (list, tuple)): if isinstance(o, (list, tuple)):
o = o[-1] # last feature if backbone outputs list/tuple of features o = o[-1] # last feature if backbone outputs list/tuple of features
feature_size = o.shape[-2:] feature_size = o.shape[-2:]
@ -257,7 +257,7 @@ class HybridEmbed(nn.Module):
self.proj = nn.Conv2d(feature_dim, embed_dim, 1) self.proj = nn.Conv2d(feature_dim, embed_dim, 1)
def forward(self, x): def forward(self, x):
x = self.backbone(x, return_featuremaps=True) # it works with osnet x = self.backbone(x)
if isinstance(x, (list, tuple)): if isinstance(x, (list, tuple)):
x = x[-1] # last feature if backbone outputs list/tuple of features x = x[-1] # last feature if backbone outputs list/tuple of features
x = self.proj(x).flatten(2).transpose(1, 2) x = self.proj(x).flatten(2).transpose(1, 2)

Loading…
Cancel
Save