Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent c307f4ac3d
commit c6973a3ebe

@@ -341,11 +341,11 @@ class SpatialBlock(nn.Module):
class DaViTStage(nn.Module):
def __init__(
self,
in_chs,
out_chs,
#in_chs,
dim,
depth = 1,
patch_size = 4,
overlapped_patch = False,
#patch_size = 4,
#overlapped_patch = False,
attention_types = ('spatial', 'channel'),
num_heads = 3,
window_size = 7,
@@ -361,12 +361,14 @@ class DaViTStage(nn.Module):
self.grad_checkpointing = False
# patch embedding layer at the beginning of each stage
'''
self.patch_embed = PatchEmbed(
patch_size=patch_size,
in_chans=in_chs,
embed_dim=out_chs,
overlapped=overlapped_patch
)
'''
'''
repeating alternating attention blocks in each stage
default: (spatial -> channel) x depth
@@ -382,7 +384,7 @@ class DaViTStage(nn.Module):
for attention_id, attention_type in enumerate(attention_types):
if attention_type == 'spatial':
dual_attention_block.append(SpatialBlock(
dim=out_chs,
dim=dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
@@ -394,7 +396,7 @@ class DaViTStage(nn.Module):
))
elif attention_type == 'channel':
dual_attention_block.append(ChannelBlock(
dim=out_chs,
dim=dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
@@ -477,13 +479,18 @@ class DaViT(nn.Module):
for stage_id in range(self.num_stages):
stage_drop_rates = dpr[len(attention_types) * sum(depths[:stage_id]):len(attention_types) * sum(depths[:stage_id + 1])]
print(stage_drop_rates)
patch_embed = PatchEmbed(
patch_size=patch_size if stage_id == 0 else 2,
in_chans=in_chans if stage_id == 0 else embed_dims[stage_id - 1],
embed_dim=embed_dims[stage_id],
overlapped=overlapped_patch
)
stage = DaViTStage(
in_chans if stage_id == 0 else embed_dims[stage_id - 1],
embed_dims[stage_id],
depth = depths[stage_id],
patch_size = patch_size if stage_id == 0 else 2,
overlapped_patch = overlapped_patch,
attention_types = attention_types,
num_heads = num_heads[stage_id],
window_size = window_size,
@@ -495,6 +502,7 @@ class DaViT(nn.Module):
cpe_act = cpe_act
)
stages.append(patch_embed)
stages.append(stage)
self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]
@@ -598,6 +606,7 @@ def _cfg(url='', **kwargs): # not sure how this should be set up
# TODO contact authors to get larger pretrained models
default_cfgs = generate_default_cfgs({
# official microsoft weights from https://github.com/dingmyu/davit
'davit_tiny.msft_in1k': _cfg(
@@ -631,8 +640,6 @@ def davit_base(pretrained=False, **kwargs):
num_heads=(4, 8, 16, 32), **kwargs)
return _create_davit('davit_base', pretrained=pretrained, **model_kwargs)
# TODO contact authors to get larger pretrained models
@register_model
def davit_large(pretrained=False, **kwargs):
model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536),
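
The net effect of this commit is that DaViTStage no longer builds its own PatchEmbed: each stage is parameterized by a single dim, and the DaViT constructor now creates the per-stage PatchEmbed itself, appending it to stages immediately before the stage it feeds. Below is a minimal, self-contained sketch of that build loop. PatchEmbedStub and StageStub are hypothetical stand-ins (not the real PatchEmbed / DaViTStage classes) so the example runs without the rest of davit.py, and the hyperparameter values are illustrative, davit_tiny-like defaults.

import torch.nn as nn

class PatchEmbedStub(nn.Module):
    # Stand-in for PatchEmbed: a strided conv that changes channel width
    # and downsamples by `patch_size`. `overlapped` is accepted but ignored here.
    def __init__(self, patch_size, in_chans, embed_dim, overlapped=False):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        return self.proj(x)

class StageStub(nn.Module):
    # Stand-in for DaViTStage(dim, depth=...): depth repetitions of the
    # (spatial -> channel) dual-attention pair, shown here as Identity blocks.
    def __init__(self, dim, depth=1, attention_types=('spatial', 'channel')):
        super().__init__()
        self.blocks = nn.Sequential(
            *[nn.Identity() for _ in range(depth * len(attention_types))])

    def forward(self, x):
        return self.blocks(x)

# Illustrative, davit_tiny-like shape hyperparameters.
embed_dims, depths = (96, 192, 384, 768), (1, 1, 3, 1)
in_chans, patch_size, overlapped_patch = 3, 4, False

stages = nn.ModuleList()
for stage_id in range(len(embed_dims)):
    # Patch embedding is now created here, outside the stage.
    patch_embed = PatchEmbedStub(
        patch_size=patch_size if stage_id == 0 else 2,
        in_chans=in_chans if stage_id == 0 else embed_dims[stage_id - 1],
        embed_dim=embed_dims[stage_id],
        overlapped=overlapped_patch,
    )
    # The stage itself only needs its working dimension.
    stage = StageStub(embed_dims[stage_id], depth=depths[stage_id])
    stages.append(patch_embed)
    stages.append(stage)

The forward path is not shown in this diff; presumably the modules in stages are applied in order, so each PatchEmbed adjusts resolution and channel width before its paired stage runs.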
