diff --git a/timm/models/davit.py b/timm/models/davit.py
index bf8a8377..c2f3fb2b 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -341,11 +341,11 @@ class SpatialBlock(nn.Module):
 class DaViTStage(nn.Module):
     def __init__(
             self,
-            in_chs,
-            out_chs,
+            #in_chs,
+            dim,
             depth = 1,
-            patch_size = 4,
-            overlapped_patch = False,
+            #patch_size = 4,
+            #overlapped_patch = False,
             attention_types = ('spatial', 'channel'),
             num_heads = 3,
             window_size = 7,
@@ -361,12 +361,14 @@ class DaViTStage(nn.Module):
         self.grad_checkpointing = False
 
         # patch embedding layer at the beginning of each stage
+        '''
         self.patch_embed = PatchEmbed(
             patch_size=patch_size,
             in_chans=in_chs,
             embed_dim=out_chs,
             overlapped=overlapped_patch
         )
+        '''
         '''
         repeating alternating attention blocks in each stage
         default: (spatial -> channel) x depth
@@ -382,7 +384,7 @@
             for attention_id, attention_type in enumerate(attention_types):
                 if attention_type == 'spatial':
                     dual_attention_block.append(SpatialBlock(
-                        dim=out_chs,
+                        dim=dim,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias,
@@ -394,7 +396,7 @@
                     ))
                 elif attention_type == 'channel':
                     dual_attention_block.append(ChannelBlock(
-                        dim=out_chs,
+                        dim=dim,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias,
@@ -477,13 +479,17 @@ class DaViT(nn.Module):
 
         for stage_id in range(self.num_stages):
             stage_drop_rates = dpr[len(attention_types) * sum(depths[:stage_id]):len(attention_types) * sum(depths[:stage_id + 1])]
+
+            patch_embed = PatchEmbed(
+                patch_size=patch_size if stage_id == 0 else 2,
+                in_chans=in_chans if stage_id == 0 else embed_dims[stage_id - 1],
+                embed_dim=embed_dims[stage_id],
+                overlapped=overlapped_patch
+            )
 
             stage = DaViTStage(
-                in_chans if stage_id == 0 else embed_dims[stage_id - 1],
                 embed_dims[stage_id],
                 depth = depths[stage_id],
-                patch_size = patch_size if stage_id == 0 else 2,
-                overlapped_patch = overlapped_patch,
                 attention_types = attention_types,
                 num_heads = num_heads[stage_id],
                 window_size = window_size,
@@ -495,6 +501,7 @@
                 cpe_act = cpe_act
             )
 
+            stages.append(patch_embed)
             stages.append(stage)
             self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]
 
@@ -598,6 +605,7 @@ def _cfg(url='', **kwargs):
 
 
 # not sure how this should be set up
+# TODO contact authors to get larger pretrained models
default_cfgs = generate_default_cfgs({
     # official microsoft weights from https://github.com/dingmyu/davit
     'davit_tiny.msft_in1k': _cfg(
@@ -631,8 +639,6 @@ def davit_base(pretrained=False, **kwargs):
         num_heads=(4, 8, 16, 32), **kwargs)
     return _create_davit('davit_base', pretrained=pretrained, **model_kwargs)
 
-
-# TODO contact authors to get larger pretrained models
 @register_model
 def davit_large(pretrained=False, **kwargs):
     model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536),
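
Note on the refactor: the diff hoists PatchEmbed out of DaViTStage, so the trunk becomes a flat sequence alternating a downsampling embed (stride-4 stem, then stride-2 per stage) with a stage of attention blocks. Below is a minimal, self-contained sketch of that layout; ToyPatchEmbed and ToyStage are hypothetical stand-ins, not the real timm modules (which also thread a spatial size tuple through forward, so the real code likely needs a size-aware replacement for nn.Sequential).

# Sketch only: illustrates the alternating PatchEmbed/stage layout this diff
# moves to. ToyPatchEmbed/ToyStage are stand-ins, not timm's davit modules.
import torch
import torch.nn as nn

class ToyPatchEmbed(nn.Module):
    def __init__(self, in_chans, embed_dim, stride):
        super().__init__()
        # stride-4 stem for stage 0, stride-2 downsampling afterwards
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=stride, stride=stride)

    def forward(self, x):
        return self.proj(x)

class ToyStage(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # stand-in for the spatial/channel attention blocks; keeps H, W, C
        self.block = nn.Conv2d(dim, dim, kernel_size=3, padding=1)

    def forward(self, x):
        return self.block(x)

embed_dims = (96, 192, 384, 768)
modules = []
for stage_id, dim in enumerate(embed_dims):
    modules.append(ToyPatchEmbed(
        in_chans=3 if stage_id == 0 else embed_dims[stage_id - 1],
        embed_dim=dim,
        stride=4 if stage_id == 0 else 2,
    ))
    modules.append(ToyStage(dim))
trunk = nn.Sequential(*modules)

x = torch.randn(1, 3, 224, 224)
print(trunk(x).shape)  # torch.Size([1, 768, 7, 7])

The printed shape reflects the overall stride of 32 (4 * 2 * 2 * 2) produced by the per-stage patch embeds, matching the reduction the original fused DaViTStage design gave.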