|
|
|
@ -291,10 +291,10 @@ class Twins(nn.Module):
|
|
|
|
|
patch_size = 2
|
|
|
|
|
|
|
|
|
|
self.task = task
|
|
|
|
|
self.F4=F4
|
|
|
|
|
self.extra_norm = extra_norm
|
|
|
|
|
self.strides = strides
|
|
|
|
|
if self.task == 'seg':
|
|
|
|
|
self.F4=F4
|
|
|
|
|
self.extra_norm = extra_norm
|
|
|
|
|
self.strides = strides
|
|
|
|
|
if self.extra_norm:
|
|
|
|
|
self.norm_list = nn.ModuleList()
|
|
|
|
|
for dim in embed_dims:
|
|
|
|
@ -363,7 +363,6 @@ class Twins(nn.Module):
|
|
|
|
|
m.bias.data.zero_()
|
|
|
|
|
|
|
|
|
|
def forward_features(self, x):
|
|
|
|
|
|
|
|
|
|
if self.task == 'cls':
|
|
|
|
|
B = x.shape[0]
|
|
|
|
|
for i, (embed, drop, blocks, pos_blk) in enumerate(
|
|
|
|
@ -406,7 +405,6 @@ class Twins(nn.Module):
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_twins(variant, pretrained=False, **kwargs):
|
|
|
|
|
if kwargs.get('features_only', None):
|
|
|
|
|
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
|
|
|
|