@@ -341,11 +341,11 @@ class SpatialBlock(nn.Module):
 class DaViTStage(nn.Module):
     def __init__(
             self,
-            in_chs,
-            out_chs,
+            #in_chs,
+            dim,
             depth = 1,
-            patch_size = 4,
-            overlapped_patch = False,
+            #patch_size = 4,
+            #overlapped_patch = False,
             attention_types = ('spatial', 'channel'),
             num_heads = 3,
             window_size = 7,
@@ -361,12 +361,14 @@ class DaViTStage(nn.Module):
         self.grad_checkpointing = False
 
         # patch embedding layer at the beginning of each stage
+        '''
         self.patch_embed = PatchEmbed(
             patch_size=patch_size,
             in_chans=in_chs,
             embed_dim=out_chs,
             overlapped=overlapped_patch
         )
+        '''
         '''
         repeating alternating attention blocks in each stage
         default: (spatial -> channel) x depth
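The docstring in this hunk is the heart of the design: every depth unit of a stage expands into one block per entry in attention_types, giving the default spatial -> channel alternation. A minimal sketch of that expansion, with plain strings standing in for the PR's SpatialBlock/ChannelBlock instances:

    def stage_block_order(depth, attention_types=('spatial', 'channel')):
        # one dual-attention unit (one block per attention type, in order) per depth step
        order = []
        for _ in range(depth):
            order.extend(attention_types)
        return order

    print(stage_block_order(3))
    # ['spatial', 'channel', 'spatial', 'channel', 'spatial', 'channel']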
@@ -382,7 +384,7 @@ class DaViTStage(nn.Module):
             for attention_id, attention_type in enumerate(attention_types):
                 if attention_type == 'spatial':
                     dual_attention_block.append(SpatialBlock(
-                        dim=out_chs,
+                        dim=dim,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias,
@@ -394,7 +396,7 @@ class DaViTStage(nn.Module):
                     ))
                 elif attention_type == 'channel':
                     dual_attention_block.append(ChannelBlock(
-                        dim=out_chs,
+                        dim=dim,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias,
@@ -477,13 +479,18 @@ class DaViT(nn.Module):
 
         for stage_id in range(self.num_stages):
             stage_drop_rates = dpr[len(attention_types) * sum(depths[:stage_id]):len(attention_types) * sum(depths[:stage_id + 1])]
+            print(stage_drop_rates)
 
+            patch_embed = PatchEmbed(
+                patch_size=patch_size if stage_id == 0 else 2,
+                in_chans=in_chans if stage_id == 0 else embed_dims[stage_id - 1],
+                embed_dim=embed_dims[stage_id],
+                overlapped=overlapped_patch
+            )
+
             stage = DaViTStage(
-                in_chans if stage_id == 0 else embed_dims[stage_id - 1],
                 embed_dims[stage_id],
                 depth = depths[stage_id],
-                patch_size = patch_size if stage_id == 0 else 2,
-                overlapped_patch = overlapped_patch,
                 attention_types = attention_types,
                 num_heads = num_heads[stage_id],
                 window_size = window_size,
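The stage_drop_rates slice above hands each stage a contiguous chunk of one global stochastic-depth schedule, with len(attention_types) rates per depth unit (one for the spatial block, one for the channel block). A worked example with illustrative numbers (davit_tiny-style depths and drop_path_rate=0.1 are assumptions for illustration, not values taken from this diff):

    import torch

    depths = (1, 1, 3, 1)
    attention_types = ('spatial', 'channel')
    drop_path_rate = 0.1

    # one rate per block: 2 blocks (spatial + channel) per depth unit
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, len(attention_types) * sum(depths))]

    for stage_id in range(len(depths)):
        start = len(attention_types) * sum(depths[:stage_id])
        stop = len(attention_types) * sum(depths[:stage_id + 1])
        print(stage_id, dpr[start:stop])
    # stages 0, 1, 3 each receive 2 rates; stage 2 receives 6 of the 12,
    # since it holds 3 of the 6 depth units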
@@ -495,6 +502,7 @@ class DaViT(nn.Module):
                 cpe_act = cpe_act
             )
 
+            stages.append(patch_embed)
             stages.append(stage)
             self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]
 
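The feature_info entries appended here are what timm's feature-extraction wrapper reads. Assuming this wiring ends up behaving like other timm backbones, per-stage feature maps would be pulled out as follows:

    import timm
    import torch

    # Assumes the davit_tiny entrypoint from this PR is importable through timm
    # and that features_only works off the feature_info registered above.
    m = timm.create_model('davit_tiny', features_only=True, pretrained=False)
    feats = m(torch.randn(1, 3, 224, 224))
    for f in feats:
        print(tuple(f.shape))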
@@ -598,6 +606,7 @@ def _cfg(url='', **kwargs): # not sure how this should be set up
 
 
+# TODO contact authors to get larger pretrained models
 default_cfgs = generate_default_cfgs({
     # official microsoft weights from https://github.com/dingmyu/davit
     'davit_tiny.msft_in1k': _cfg(
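For reference, the davit_tiny.msft_in1k tag registered above is selected through the usual timm factory; this sketch assumes the PR is installed and the ported Microsoft weights resolve for that tag:

    import timm
    import torch

    # Assumes this PR's davit_tiny entrypoint and the msft_in1k weights are reachable.
    model = timm.create_model('davit_tiny.msft_in1k', pretrained=True)
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # torch.Size([1, 1000]) with the ImageNet-1k head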
@@ -631,8 +640,6 @@ def davit_base(pretrained=False, **kwargs):
         num_heads=(4, 8, 16, 32), **kwargs)
     return _create_davit('davit_base', pretrained=pretrained, **model_kwargs)
 
 
-
-# TODO contact authors to get larger pretrained models
 @register_model
 def davit_large(pretrained=False, **kwargs):
     model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536),
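Both entrypoints follow the same registration pattern: a dict of architecture kwargs handed to _create_davit under a registered name. A quick size check of the two configs above (random init, no pretrained weights; assumes both names resolve through timm):

    import timm

    # Rough parameter-count comparison of the two registered configs above.
    for name in ('davit_base', 'davit_large'):
        m = timm.create_model(name, pretrained=False)
        n_params = sum(p.numel() for p in m.parameters())
        print(f'{name}: {n_params / 1e6:.1f}M params')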