From fbf4396115f015995b562b5ab6ce54139f1bccf9 Mon Sep 17 00:00:00 2001
From: Fredo Guan
Date: Fri, 16 Dec 2022 02:31:48 -0800
Subject: [PATCH] Update davit.py

---
 timm/models/davit.py | 39 +++++++++++++++++++++------------------
 1 file changed, 21 insertions(+), 18 deletions(-)

diff --git a/timm/models/davit.py b/timm/models/davit.py
index 061a68bf..8bae9938 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -341,11 +341,11 @@ class SpatialBlock(nn.Module):
 class DaViTStage(nn.Module):
     def __init__(
             self,
-            #in_chs,
-            dim,
+            in_chs,
+            out_chs,
             depth = 1,
-            #patch_size = 4,
-            #overlapped_patch = False,
+            patch_size = 4,
+            overlapped_patch = False,
             attention_types = ('spatial', 'channel'),
             num_heads = 3,
             window_size = 7,
@@ -361,14 +361,12 @@ class DaViTStage(nn.Module):
         self.grad_checkpointing = False

         # patch embedding layer at the beginning of each stage
-        '''
         self.patch_embed = PatchEmbed(
             patch_size=patch_size,
             in_chans=in_chs,
             embed_dim=out_chs,
             overlapped=overlapped_patch
         )
-        '''
         '''
         repeating alternating attention blocks in each stage
         default: (spatial -> channel) x depth
@@ -384,7 +382,7 @@
             for attention_id, attention_type in enumerate(attention_types):
                 if attention_type == 'spatial':
                     dual_attention_block.append(SpatialBlock(
-                        dim=dim,
+                        dim=out_chs,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias,
@@ -396,7 +394,7 @@
                     ))
                 elif attention_type == 'channel':
                     dual_attention_block.append(ChannelBlock(
-                        dim=dim,
+                        dim=out_chs,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias,
@@ -411,7 +409,7 @@
         self.blocks = nn.Sequential(*stage_blocks)

     def forward(self, x : Tensor):
-        #x = self.patch_embed(x)
+        x = self.patch_embed(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
             x = checkpoint_seq(self.blocks, x)
         else:
@@ -474,23 +472,25 @@ class DaViT(nn.Module):
         self.drop_rate=drop_rate
         self.grad_checkpointing = False
         self.feature_info = []
+
+        self.patch_embed = PatchEmbed(
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dims[0],
+            overlapped=overlapped_patch
+        )

         stages = []

         for stage_id in range(self.num_stages):
             stage_drop_rates = dpr[len(attention_types) * sum(depths[:stage_id]):len(attention_types) * sum(depths[:stage_id + 1])]
-            print(stage_drop_rates)
-
-            patch_embed = PatchEmbed(
-                patch_size=patch_size if stage_id == 0 else 2,
-                in_chans=in_chans if stage_id == 0 else embed_dims[stage_id - 1],
-                embed_dim=embed_dims[stage_id],
-                overlapped=overlapped_patch
-            )

             stage = DaViTStage(
+                in_chans if stage_id == 0 else embed_dims[stage_id - 1],
                 embed_dims[stage_id],
                 depth = depths[stage_id],
+                patch_size = patch_size if stage_id == 0 else 2,
+                overlapped_patch = overlapped_patch,
                 attention_types = attention_types,
                 num_heads = num_heads[stage_id],
                 window_size = window_size,
@@ -502,7 +502,9 @@
                 cpe_act = cpe_act
             )

-            stages.append(patch_embed)
+            if stage_id == 0:
+                stage.patch_embed = nn.Identity()
+            stages.append(stage)

             self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]

@@ -537,6 +539,7 @@
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

     def forward_features(self, x):
+        x = self.patch_embed(x)
         x = self.stages(x)
         # take final feature and norm
         x = self.norms(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
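
Note: the plain-Python sketch below illustrates the construction pattern this patch
moves to — each stage owns its own patch embed, the model keeps a top-level stem,
and stage 0's internal embed is bypassed with nn.Identity() so the stem is not
applied twice. "ToyStage" and "ToyModel" are hypothetical stand-ins for DaViTStage
and DaViT, with stub convolutions in place of the real attention blocks; this is a
minimal sketch of the wiring, not the actual timm implementation.

    import torch
    import torch.nn as nn


    class ToyStage(nn.Module):
        # Stand-in for DaViTStage after this patch: the stage owns its
        # downsampling patch embed and applies it at the start of forward().
        def __init__(self, in_chs, out_chs, patch_size=2):
            super().__init__()
            self.patch_embed = nn.Conv2d(in_chs, out_chs, patch_size, stride=patch_size)
            # stub for the stage's dual attention blocks
            self.blocks = nn.Sequential(nn.Conv2d(out_chs, out_chs, 3, padding=1), nn.GELU())

        def forward(self, x):
            x = self.patch_embed(x)
            return self.blocks(x)


    class ToyModel(nn.Module):
        # Stand-in for DaViT after this patch: a model-level stem embeds the
        # input once; stage 0's internal patch_embed is replaced by
        # nn.Identity() since the stem already did that projection.
        def __init__(self, in_chans=3, embed_dims=(64, 128, 192), patch_size=4):
            super().__init__()
            self.patch_embed = nn.Conv2d(in_chans, embed_dims[0], patch_size, stride=patch_size)
            stages = []
            for stage_id, dim in enumerate(embed_dims):
                stage = ToyStage(
                    in_chans if stage_id == 0 else embed_dims[stage_id - 1],
                    dim,
                    patch_size=patch_size if stage_id == 0 else 2,
                )
                if stage_id == 0:
                    # stage 0 receives features already embedded by the stem
                    stage.patch_embed = nn.Identity()
                stages.append(stage)
            self.stages = nn.Sequential(*stages)

        def forward(self, x):
            x = self.patch_embed(x)
            return self.stages(x)


    x = torch.randn(1, 3, 224, 224)
    print(ToyModel()(x).shape)  # torch.Size([1, 192, 14, 14])

One consequence of this layout, which the patch relies on: keeping a uniform
stage.patch_embed attribute (Identity for stage 0) lets forward() be identical
across stages instead of special-casing the first one.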