@@ -241,7 +241,7 @@ class HybridEmbed(nn.Module):
                 training = backbone.training
                 if training:
                     backbone.eval()
-                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
+                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]), return_featuremaps=True)  # it works with osnet
                 if isinstance(o, (list, tuple)):
                     o = o[-1]  # last feature if backbone outputs list/tuple of features
                 feature_size = o.shape[-2:]
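Note that passing `return_featuremaps=True` unconditionally ties `HybridEmbed` to torchreid-style backbones such as OSNet; a plain timm CNN would raise a `TypeError` on the unexpected keyword. A more defensive version of the dummy-forward probe could feature-detect the kwarg instead. This is a sketch under that assumption, not part of the diff; `_probe_backbone` is a hypothetical helper name:

```python
import inspect

import torch


def _probe_backbone(backbone, in_chans, img_size):
    """Run the dummy forward HybridEmbed uses to infer the feature grid."""
    dummy = torch.zeros(1, in_chans, img_size[0], img_size[1])
    # Only torchreid-style backbones (e.g. OSNet) accept return_featuremaps;
    # fall back to a plain call for everything else.
    if 'return_featuremaps' in inspect.signature(backbone.forward).parameters:
        return backbone(dummy, return_featuremaps=True)
    return backbone(dummy)
```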
@@ -257,7 +257,7 @@ class HybridEmbed(nn.Module):
         self.proj = nn.Conv2d(feature_dim, embed_dim, 1)
 
     def forward(self, x):
-        x = self.backbone(x)
+        x = self.backbone(x, return_featuremaps=True)  # it works with osnet
         if isinstance(x, (list, tuple)):
             x = x[-1]  # last feature if backbone outputs list/tuple of features
         x = self.proj(x).flatten(2).transpose(1, 2)
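With both hunks applied, the patched `HybridEmbed` can tokenize OSNet feature maps directly. A minimal usage sketch, assuming torchreid is installed and the patched file is importable as `vision_transformer` (the import path and `num_classes` value are assumptions):

```python
import torch
import torchreid

from vision_transformer import HybridEmbed  # the module patched above; path assumed

# OSNet-x1.0 from torchreid; the classifier head goes unused because the
# patched HybridEmbed always calls the backbone with return_featuremaps=True.
backbone = torchreid.models.build_model(name='osnet_x1_0', num_classes=1, pretrained=False)

embed = HybridEmbed(backbone, img_size=(256, 128), in_chans=3, embed_dim=768)
tokens = embed(torch.randn(2, 3, 256, 128))
print(tokens.shape)  # torch.Size([2, 128, 768]): a 16x8 OSNet feature map, projected
```

Since OSNet downsamples by 16, a 256x128 person-ReID crop yields a 16x8 grid, i.e. 128 tokens after the 1x1 projection.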
@@ -784,4 +784,4 @@ def vit_deit_base_distilled_patch16_384(pretrained=False, **kwargs):
     model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer(
         'vit_deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
-    return model
\ No newline at end of file
+    return model
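As a quick smoke test of the factory this last hunk touches (a sketch; the import path and the upstream timm default of `num_classes=1000` are assumptions):

```python
import torch

from vision_transformer import vit_deit_base_distilled_patch16_384  # path assumed

model = vit_deit_base_distilled_patch16_384(pretrained=False).eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 384, 384))
print(out.shape)  # torch.Size([1, 1000]); eval mode averages the class and distill heads
```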