@@ -382,8 +382,8 @@ class DaViT(nn.Module):
         attn_drop_rate=0.,
         img_size=224,
         num_classes=1000,
-        global_pool='avg',
-        features_only = False
+        global_pool='avg'#,
+        #features_only = False
     ):
         super().__init__()
 
@@ -399,7 +399,8 @@ class DaViT(nn.Module):
         self.num_features = embed_dims[-1]
         self.drop_rate=drop_rate
         self.grad_checkpointing = False
-        self._features_only = features_only
+        self._features_only = False
+        self.feature_info = []
 
         self.patch_embeds = nn.ModuleList([
             PatchEmbed(patch_size=patch_size if i == 0 else 2,
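For context on where this hunk is heading: once a model populates `self.feature_info`, timm's standard factory can wrap it for feature extraction, so `features_only` no longer needs to be a constructor argument on the model class itself. A usage sketch, assuming the `davit_base` name this PR registers (the registration is not shown in these hunks):

```python
import timm
import torch

# features_only is handled by timm.create_model, not by DaViT.__init__;
# the returned wrapper reads model.feature_info to describe each stage.
model = timm.create_model('davit_base', pretrained=False, features_only=True)

feats = model(torch.randn(1, 3, 224, 224))  # list of per-stage feature maps
print(model.feature_info.channels())        # per-stage channel counts
print(model.feature_info.reduction())       # per-stage total strides
```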
@@ -438,7 +439,10 @@ class DaViT(nn.Module):
                     for attention_id, attention_type in enumerate(attention_types)]
                 ) for layer_id, item in enumerate(block_param)
             ])
-            main_blocks.append(block)
+
+            main_blocks.append((f'block.{block_id}', block))
+
+            self.feature_info += [dict(num_ch=self.embed_dims[block_id], reduction = 2, module=f'block.{block_id}')]
         self.main_blocks = nn.ModuleList(main_blocks)
 
         '''
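Two details in the hunk above are worth flagging: timm's `FeatureInfo` consumers expect the key `num_chs` rather than `num_ch`, and `nn.ModuleList` registers modules only, so a list of `(name, module)` tuples fed to `nn.ModuleList` raises a `TypeError`. The tuple form is the input convention for `nn.Sequential(OrderedDict(...))`. A minimal sketch of that named-submodule pattern, with made-up stand-in blocks (submodule names may not contain `.`, so the dotted path can live only in `feature_info`):

```python
from collections import OrderedDict

import torch.nn as nn

# Hypothetical stand-ins for the DaViT stage blocks in this diff.
embed_dims = (96, 192, 384, 768)  # assumed stage widths, not from the hunks

named_blocks = []
feature_info = []
for block_id, dim in enumerate(embed_dims):
    block = nn.Sequential(nn.Conv2d(dim, dim, 3, padding=1), nn.GELU())
    # Submodule names may not contain '.', so the container key is a plain
    # index; the dotted path appears only in feature_info's 'module' field.
    named_blocks.append((str(block_id), block))
    feature_info += [dict(num_chs=dim, reduction=2, module=f'blocks.{block_id}')]

blocks = nn.Sequential(OrderedDict(named_blocks))  # registration demo only
print(blocks[0])  # stages are addressable by index (or by their string name)
```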
@@ -457,7 +461,7 @@ class DaViT(nn.Module):
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
         self.apply(self._init_weights)
 
-        self._update_forward_fn()
+        #self._update_forward_fn()
 
         #self.forward = self._get_forward_fn()
         '''
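Commenting out `self._update_forward_fn()` (alongside the already-commented `self.forward = self._get_forward_fn()` variant) fits TorchScript's constraints: `torch.jit.script` compiles the class's `forward` method, so reassigning `forward` on an instance at runtime is lost in the scripted module. The script-friendly shape is a flag check inside `forward`; a toy sketch of that contrast, not the DaViT code:

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self, features_only: bool = False):
        super().__init__()
        self._features_only = features_only
        self.body = nn.Linear(4, 4)
        self.head = nn.Linear(4, 2)

    def forward(self, x):
        # Branch on an attribute instead of swapping self.forward at
        # runtime; torch.jit.script compiles this method as written.
        x = self.body(x)
        if self._features_only:
            return x
        return self.head(x)

scripted = torch.jit.script(Toy(features_only=True))
print(scripted(torch.randn(1, 4)).shape)  # torch.Size([1, 4])
```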
@@ -482,6 +486,7 @@ class DaViT(nn.Module):
         else:
             return self.forward_classification
         '''
 
+    @torch.jit.ignore
     def _update_forward_fn(self):
         if self._features_only == True:
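The `@torch.jit.ignore` added here is the standard escape hatch for Python-only helpers: the decorated method is left to the Python interpreter and excluded from TorchScript compilation, so its dynamic behavior cannot trip up `torch.jit.script`. A toy sketch of the decorator, again an illustration rather than the DaViT code:

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    @torch.jit.ignore
    def param_summary(self) -> str:
        # torch.jit.ignore tells the compiler to leave this method to the
        # Python interpreter rather than attempt to compile it.
        return f'{sum(p.numel() for p in self.parameters())} params'

    def forward(self, x):
        return self.fc(x)

scripted = torch.jit.script(Toy())  # scripts fine; param_summary is skipped
print(scripted(torch.randn(2, 4)).shape)  # torch.Size([2, 4])
```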