@@ -341,11 +341,11 @@ class SpatialBlock(nn.Module):
 class DaViTStage(nn.Module):
     def __init__(
             self,
-            #in_chs,
-            dim,
+            in_chs,
+            out_chs,
             depth = 1,
-            #patch_size = 4,
-            #overlapped_patch = False,
+            patch_size = 4,
+            overlapped_patch = False,
             attention_types = ('spatial', 'channel'),
             num_heads = 3,
             window_size = 7,
@@ -361,14 +361,12 @@ class DaViTStage(nn.Module):
         self.grad_checkpointing = False
 
         # patch embedding layer at the beginning of each stage
-        '''
         self.patch_embed = PatchEmbed(
             patch_size=patch_size,
             in_chans=in_chs,
             embed_dim=out_chs,
             overlapped=overlapped_patch
         )
-        '''
         '''
         repeating alternating attention blocks in each stage
         default: (spatial -> channel) x depth
@@ -384,7 +382,7 @@ class DaViTStage(nn.Module):
             for attention_id, attention_type in enumerate(attention_types):
                 if attention_type == 'spatial':
                     dual_attention_block.append(SpatialBlock(
-                        dim=dim,
+                        dim=out_chs,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias,
@@ -396,7 +394,7 @@ class DaViTStage(nn.Module):
                     ))
                 elif attention_type == 'channel':
                     dual_attention_block.append(ChannelBlock(
-                        dim=dim,
+                        dim=out_chs,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias,
@@ -411,7 +409,7 @@ class DaViTStage(nn.Module):
         self.blocks = nn.Sequential(*stage_blocks)
 
     def forward(self, x : Tensor):
-        #x = self.patch_embed(x)
+        x = self.patch_embed(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
             x = checkpoint_seq(self.blocks, x)
         else:
@@ -474,23 +472,25 @@ class DaViT(nn.Module):
         self.drop_rate=drop_rate
         self.grad_checkpointing = False
         self.feature_info = []
 
+        self.patch_embed = PatchEmbed(
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dims[0],
+            overlapped=overlapped_patch
+        )
+
         stages = []
 
         for stage_id in range(self.num_stages):
             stage_drop_rates = dpr[len(attention_types) * sum(depths[:stage_id]):len(attention_types) * sum(depths[:stage_id + 1])]
-            print(stage_drop_rates)
-
-            patch_embed = PatchEmbed(
-                patch_size=patch_size if stage_id == 0 else 2,
-                in_chans=in_chans if stage_id == 0 else embed_dims[stage_id - 1],
-                embed_dim=embed_dims[stage_id],
-                overlapped=overlapped_patch
-            )
 
             stage = DaViTStage(
+                in_chans if stage_id == 0 else embed_dims[stage_id - 1],
                 embed_dims[stage_id],
                 depth = depths[stage_id],
+                patch_size = patch_size if stage_id == 0 else 2,
+                overlapped_patch = overlapped_patch,
                 attention_types = attention_types,
                 num_heads = num_heads[stage_id],
                 window_size = window_size,
@@ -502,7 +502,9 @@ class DaViT(nn.Module):
                 cpe_act = cpe_act
             )
 
-            stages.append(patch_embed)
+            if stage_id == 0:
+                stage.patch_embed = nn.Identity()
+
             stages.append(stage)
             self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]
 
@@ -537,6 +539,7 @@ class DaViT(nn.Module):
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
 
     def forward_features(self, x):
+        x = self.patch_embed(x)
         x = self.stages(x)
         # take final feature and norm
         x = self.norms(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)