Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 4a936b92a9
commit 8ac5ec2d88

@ -382,8 +382,8 @@ class DaViT(nn.Module):
attn_drop_rate=0.,
img_size=224,
num_classes=1000,
global_pool='avg',
features_only = False
global_pool='avg'#,
#features_only = False
):
super().__init__()
@ -399,7 +399,8 @@ class DaViT(nn.Module):
self.num_features = embed_dims[-1]
self.drop_rate=drop_rate
self.grad_checkpointing = False
self._features_only = False
self._features_only = features_only
self.feature_info = []
self.patch_embeds = nn.ModuleList([
PatchEmbed(patch_size=patch_size if i == 0 else 2,
@ -438,7 +439,10 @@ class DaViT(nn.Module):
for attention_id, attention_type in enumerate(attention_types)]
) for layer_id, item in enumerate(block_param)
])
main_blocks.append(block)
main_blocks.append((f'block.{block_id}', block))
self.feature_info += [dict(num_ch=self.embed_dims[block_id], reduction = 2, module=f'block.{block_id}')]
self.main_blocks = nn.ModuleList(main_blocks)
'''
@ -457,7 +461,7 @@ class DaViT(nn.Module):
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
self.apply(self._init_weights)
self._update_forward_fn()
#self._update_forward_fn()
#self.forward = self._get_forward_fn()
'''
@ -482,6 +486,7 @@ class DaViT(nn.Module):
else:
return self.forward_classification
'''
@torch.jit.ignore
def _update_forward_fn(self):
if self._features_only == True:

Loading…
Cancel
Save