Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 3cff8d0ff4
commit fbf4396115

@@ -341,11 +341,11 @@ class SpatialBlock(nn.Module):
 class DaViTStage(nn.Module):
     def __init__(
             self,
-            #in_chs,
-            dim,
+            in_chs,
             out_chs,
             depth = 1,
-            #patch_size = 4,
-            #overlapped_patch = False,
+            patch_size = 4,
+            overlapped_patch = False,
             attention_types = ('spatial', 'channel'),
             num_heads = 3,
             window_size = 7,
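
Note: DaViTStage now takes an explicit in_chs/out_chs pair and owns the patch-embedding
options that were previously commented out. A minimal construction sketch; the values are
illustrative (roughly a DaViT-Tiny stage 1), not taken from the diff:

    # Channels double across the stage boundary and the stage's own patch embed
    # performs the 2x downsample (stage 0 keeps the stem's patch_size; see the
    # DaViT.__init__ hunk further down).
    stage = DaViTStage(
        96,                                      # in_chs: channels from the previous stage
        192,                                     # out_chs: channels this stage emits
        depth=1,
        patch_size=2,                            # per-stage downsample factor
        overlapped_patch=False,
        attention_types=('spatial', 'channel'),
        num_heads=6,
        window_size=7,
    )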
@@ -361,14 +361,12 @@ class DaViTStage(nn.Module):
         self.grad_checkpointing = False
         # patch embedding layer at the beginning of each stage
-        '''
         self.patch_embed = PatchEmbed(
             patch_size=patch_size,
             in_chans=in_chs,
             embed_dim=out_chs,
             overlapped=overlapped_patch
         )
-        '''
         '''
         repeating alternating attention blocks in each stage
         default: (spatial -> channel) x depth
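
Note: with the stray quotes removed, each stage builds a live PatchEmbed. For the
non-overlapped case it behaves like a strided conv projection (plus a norm), so the
shape contract is a channel projection with a patch_size-fold spatial reduction. A
self-contained sketch using a Conv2d stand-in (a simplification of the real class):

    import torch
    import torch.nn as nn

    # Stand-in for PatchEmbed(patch_size=2, in_chans=96, embed_dim=192, overlapped=False)
    pe = nn.Conv2d(96, 192, kernel_size=2, stride=2)
    x = torch.randn(2, 96, 56, 56)   # (B, in_chs, H, W)
    print(pe(x).shape)               # torch.Size([2, 192, 28, 28]): 2x spatial reduction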
@@ -384,7 +382,7 @@ class DaViTStage(nn.Module):
             for attention_id, attention_type in enumerate(attention_types):
                 if attention_type == 'spatial':
                     dual_attention_block.append(SpatialBlock(
-                        dim=dim,
+                        dim=out_chs,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias,
@@ -396,7 +394,7 @@ class DaViTStage(nn.Module):
                     ))
                 elif attention_type == 'channel':
                     dual_attention_block.append(ChannelBlock(
-                        dim=dim,
+                        dim=out_chs,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias,
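
Note: both block types are now sized from out_chs instead of the removed dim argument,
since the stage's own patch_embed has already projected the features to out_chs channels.
A stand-alone illustration of the mismatch the old sizing would hit (shapes hypothetical):

    import torch
    import torch.nn as nn

    norm = nn.LayerNorm(96)            # a block still sized with the old dim=96
    tokens = torch.randn(2, 49, 192)   # but the stage now emits out_chs=192
    try:
        norm(tokens)
    except RuntimeError as err:
        print(err)                     # normalized-shape mismatch on the channel dim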
@@ -411,7 +409,7 @@ class DaViTStage(nn.Module):
         self.blocks = nn.Sequential(*stage_blocks)
     def forward(self, x : Tensor):
-        #x = self.patch_embed(x)
+        x = self.patch_embed(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
             x = checkpoint_seq(self.blocks, x)
         else:
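
Note: forward() now runs the stage-local patch embed before the block stack. The two
branches after it are functionally equivalent; the checkpointed path recomputes
activations during backward to cut peak memory. A rough stand-alone equivalent using
the stock PyTorch helper (timm's checkpoint_seq plays the same role for this Sequential):

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint_sequential

    blocks = nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.GELU(), nn.Conv2d(8, 8, 3, padding=1))
    x = torch.randn(1, 8, 16, 16, requires_grad=True)
    y = checkpoint_sequential(blocks, 2, x)   # same result as blocks(x), lower peak memory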
@@ -474,23 +472,25 @@ class DaViT(nn.Module):
         self.drop_rate=drop_rate
         self.grad_checkpointing = False
         self.feature_info = []
+        self.patch_embed = PatchEmbed(
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dims[0],
+            overlapped=overlapped_patch
+        )
         stages = []
         for stage_id in range(self.num_stages):
             stage_drop_rates = dpr[len(attention_types) * sum(depths[:stage_id]):len(attention_types) * sum(depths[:stage_id + 1])]
             print(stage_drop_rates)
-            patch_embed = PatchEmbed(
-                patch_size=patch_size if stage_id == 0 else 2,
-                in_chans=in_chans if stage_id == 0 else embed_dims[stage_id - 1],
-                embed_dim=embed_dims[stage_id],
-                overlapped=overlapped_patch
-            )
             stage = DaViTStage(
                 in_chans if stage_id == 0 else embed_dims[stage_id - 1],
                 embed_dims[stage_id],
                 depth = depths[stage_id],
+                patch_size = patch_size if stage_id == 0 else 2,
+                overlapped_patch = overlapped_patch,
                 attention_types = attention_types,
                 num_heads = num_heads[stage_id],
                 window_size = window_size,
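
Note: DaViT now builds a single model-level patch_embed as the stem and hands the
downsample settings to each stage: the stem reduces by the configured patch_size, and
every later stage reduces by 2. Resulting feature-map strides for the usual 4-stage,
patch_size=4 configuration:

    patch_size, num_stages = 4, 4
    strides = [patch_size * (2 ** i) for i in range(num_stages)]
    print(strides)   # [4, 8, 16, 32]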
@@ -502,7 +502,9 @@ class DaViT(nn.Module):
                 cpe_act = cpe_act
             )
-            stages.append(patch_embed)
+            if stage_id == 0:
+                stage.patch_embed = nn.Identity()
             stages.append(stage)
             self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]
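
Note: because the model-level patch_embed has already embedded the input, stage 0's
internal patch embed would otherwise downsample twice; replacing it with nn.Identity()
keeps the construction loop uniform while making the first stage's embed a no-op. The
same pattern in miniature (hypothetical modules):

    import torch.nn as nn

    # Build every stage's downsampler the same way, then neutralize the first one:
    downsamplers = [nn.Conv2d(8, 8, kernel_size=2, stride=2) for _ in range(4)]
    downsamplers[0] = nn.Identity()   # mirrors `stage.patch_embed = nn.Identity()`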
@@ -537,6 +539,7 @@ class DaViT(nn.Module):
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
     def forward_features(self, x):
+        x = self.patch_embed(x)
         x = self.stages(x)
         # take final feature and norm
         x = self.norms(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
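
Note: forward_features now applies the stem before the stage stack. The permute pair
around self.norms exists because nn.LayerNorm normalizes the trailing dimension, so the
NCHW feature map is flipped to NHWC for the norm and back afterwards:

    import torch
    import torch.nn as nn

    norm = nn.LayerNorm(768)                                # channels-last norm
    x = torch.randn(2, 768, 7, 7)                           # (B, C, H, W)
    x = norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)     # normalized, back to NCHW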
