|
|
|
@ -473,7 +473,7 @@ class DaViT(nn.Module):
|
|
|
|
|
self.grad_checkpointing = False
|
|
|
|
|
self.feature_info = []
|
|
|
|
|
|
|
|
|
|
self.patch_embed = None
|
|
|
|
|
self.stem = None
|
|
|
|
|
stages = []
|
|
|
|
|
|
|
|
|
|
for stage_id in range(self.num_stages):
|
|
|
|
@ -497,7 +497,7 @@ class DaViT(nn.Module):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if stage_id == 0:
|
|
|
|
|
self.patch_embed = stage.patch_embed
|
|
|
|
|
self.stem = stage.patch_embed
|
|
|
|
|
stage.patch_embed = nn.Identity()
|
|
|
|
|
|
|
|
|
|
stages.append(stage)
|
|
|
|
@ -506,7 +506,7 @@ class DaViT(nn.Module):
|
|
|
|
|
|
|
|
|
|
self.stages = nn.Sequential(*stages)
|
|
|
|
|
|
|
|
|
|
self.norms = norm_layer(self.num_features)
|
|
|
|
|
self.norm = norm_layer(self.num_features)
|
|
|
|
|
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
|
|
|
|
|
self.apply(self._init_weights)
|
|
|
|
|
|
|
|
|
@ -536,10 +536,7 @@ class DaViT(nn.Module):
|
|
|
|
|
def forward_features(self, x):
|
|
|
|
|
x = self.patch_embed(x)
|
|
|
|
|
x = self.stages(x)
|
|
|
|
|
# take final feature and norm
|
|
|
|
|
x = self.norms(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
|
|
|
|
#H, W = sizes[-1]
|
|
|
|
|
#x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
|
|
|
|
|
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
def forward_head(self, x, pre_logits: bool = False):
|
|
|
|
@ -567,7 +564,7 @@ def checkpoint_filter_fn(state_dict, model):
|
|
|
|
|
|
|
|
|
|
k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.patch_embed', k)
|
|
|
|
|
k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k)
|
|
|
|
|
k = k.replace('stages.0.patch_embed', 'patch_embed')
|
|
|
|
|
k = k.replace('stages.0.patch_embed', 'stem')
|
|
|
|
|
k = k.replace('head.', 'head.fc.')
|
|
|
|
|
k = k.replace('cpe.0', 'cpe1')
|
|
|
|
|
k = k.replace('cpe.1', 'cpe2')
|
|
|
|
@ -577,8 +574,6 @@ def checkpoint_filter_fn(state_dict, model):
|
|
|
|
|
|
|
|
|
|
def _create_davit(variant, pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
|
|
|
|
|
out_indices = kwargs.pop('out_indices', default_out_indices)
|
|
|
|
|
|
|
|
|
@ -594,7 +589,7 @@ def _create_davit(variant, pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _cfg(url='', **kwargs): # not sure how this should be set up
|
|
|
|
|
def _cfg(url='', **kwargs):
|
|
|
|
|
return {
|
|
|
|
|
'url': url,
|
|
|
|
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
|
|
|
|