Davit revised (#4)

* Update davit.py

pull/1630/head
Fredo Guan 3 years ago committed by GitHub
parent 3f14094fcd
commit 52093607f8

@@ -34,6 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 __all__ = ['DaViT']
 
 
+# modified nn.Sequential that includes a size tuple in the forward function
 class SequentialWithSize(nn.Sequential):
     def forward(self, x : Tensor, size: Tuple[int, int]):
         for module in self._modules.values():
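For reference, a minimal sketch completing the SequentialWithSize fragment shown in the hunk above; the exact body in davit.py may differ slightly, but the idea is an nn.Sequential whose forward also threads an (H, W) size tuple through every child:

    from typing import Tuple
    import torch.nn as nn
    from torch import Tensor

    class SequentialWithSize(nn.Sequential):
        # every child module is expected to accept and return (x, size)
        def forward(self, x: Tensor, size: Tuple[int, int]):
            for module in self._modules.values():
                x, size = module(x, size)
            return x, size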
@@ -171,13 +172,13 @@ class ChannelBlock(nn.Module):
                  ffn=True, cpe_act=False):
         super().__init__()
-        self.cpe = nn.ModuleList([ConvPosEnc(dim=dim, k=3, act=cpe_act),
-                                  ConvPosEnc(dim=dim, k=3, act=cpe_act)])
+        self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
         self.ffn = ffn
         self.norm1 = norm_layer(dim)
         self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
         if self.ffn:
             self.norm2 = norm_layer(dim)
             mlp_hidden_dim = int(dim * mlp_ratio)
@@ -188,12 +189,12 @@ class ChannelBlock(nn.Module):
     def forward(self, x : Tensor, size: Tuple[int, int]):
-        x = self.cpe[0](x, size)
+        x = self.cpe1(x, size)
         cur = self.norm1(x)
         cur = self.attn(cur)
         x = x + self.drop_path(cur)
-        x = self.cpe[1](x, size)
+        x = self.cpe2(x, size)
         if self.ffn:
             x = x + self.drop_path(self.mlp(self.norm2(x)))
         return x, size
@@ -292,9 +293,8 @@ class SpatialBlock(nn.Module):
         self.num_heads = num_heads
         self.window_size = window_size
         self.mlp_ratio = mlp_ratio
-        self.cpe = nn.ModuleList([ConvPosEnc(dim=dim, k=3, act=cpe_act),
-                                  ConvPosEnc(dim=dim, k=3, act=cpe_act)])
+        self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
         self.norm1 = norm_layer(dim)
         self.attn = WindowAttention(
             dim,
@@ -303,7 +303,8 @@ class SpatialBlock(nn.Module):
             qkv_bias=qkv_bias)
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
         if self.ffn:
             self.norm2 = norm_layer(dim)
             mlp_hidden_dim = int(dim * mlp_ratio)
@@ -318,7 +319,7 @@ class SpatialBlock(nn.Module):
         H, W = size
         B, L, C = x.shape
-        shortcut = self.cpe[0](x, size)
+        shortcut = self.cpe1(x, size)
         x = self.norm1(shortcut)
         x = x.view(B, H, W, C)
@@ -347,7 +348,7 @@ class SpatialBlock(nn.Module):
         x = x.view(B, H * W, C)
         x = shortcut + self.drop_path(x)
-        x = self.cpe[1](x, size)
+        x = self.cpe2(x, size)
         if self.ffn:
             x = x + self.drop_path(self.mlp(self.norm2(x)))
         return x, size
@@ -417,12 +418,10 @@ class DaViTStage(nn.Module):
     def forward(self, x : Tensor, size: Tuple[int, int]):
         x, size = self.patch_embed(x, size)
-        for block in self.blocks:
-            for layer in block:
-                if self.grad_checkpointing and not torch.jit.is_scripting():
-                    x, size = checkpoint.checkpoint(layer, x, size)
-                else:
-                    x, size = layer(x, size)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x, size = checkpoint_seq(self.blocks, x, size)
+        else:
+            x, size = self.blocks(x, size)
         return x, size
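The stage forward now hands the whole SequentialWithSize container to checkpoint_seq instead of checkpointing layer by layer. As a rough illustration only (not timm's actual helper), sequential gradient checkpointing over blocks that take and return (x, size) boils down to:

    import torch.utils.checkpoint as cp

    def checkpoint_blocks_with_size(blocks, x, size):
        # recompute each block's activations during backward instead of storing them,
        # mirroring the per-layer checkpoint.checkpoint(layer, x, size) calls removed above
        for block in blocks:
            x, size = cp.checkpoint(block, x, size)
        return x, size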
@@ -490,8 +489,8 @@ class DaViT(nn.Module):
             stage = DaViTStage(
                 in_chans if stage_id == 0 else embed_dims[stage_id - 1],
                 embed_dims[stage_id],
-                depth = 1,
-                patch_size = patch_size,
+                depth = depths[stage_id],
+                patch_size = patch_size if stage_id == 0 else 2,
                 overlapped_patch = overlapped_patch,
                 attention_types = attention_types,
                 num_heads = num_heads[stage_id],
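With this change each stage gets its own depth from `depths`, and only the stem stage embeds with the full `patch_size`; every later stage downsamples by 2. A quick illustration, assuming DaViT-T style defaults `depths = (1, 1, 3, 1)` and `patch_size = 4` (assumed values, for illustration only):

    depths, patch_size = (1, 1, 3, 1), 4  # assumed defaults
    for stage_id in range(len(depths)):
        stage_depth = depths[stage_id]
        stage_patch = patch_size if stage_id == 0 else 2
        print((stage_id, stage_depth, stage_patch))
    # -> (0, 1, 4), (1, 1, 2), (2, 3, 2), (3, 1, 2)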
@@ -602,11 +601,14 @@ def checkpoint_filter_fn(state_dict, model):
     if 'state_dict' in state_dict:
         state_dict = state_dict['state_dict']
+    import re
     out_dict = {}
     for k, v in state_dict.items():
-        k = k.replace('main_blocks.', 'stages.stage_')
+        k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.patch_embed', k)
+        k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k)
         k = k.replace('head.', 'head.fc.')
+        k = k.replace('cpe.0', 'cpe1')
+        k = k.replace('cpe.1', 'cpe2')
         out_dict[k] = v
     return out_dict
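To make the remapping concrete, a small worked example of how the filter rewrites keys; the sample key names are assumptions about the original checkpoint layout, not taken from a real file:

    import re

    k = 'patch_embeds.0.proj.weight'  # hypothetical original key
    k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.patch_embed', k)
    # -> 'stages.0.patch_embed.proj.weight' (matches the new 'first_conv' below)

    k = 'main_blocks.1.0.0.cpe.0.proj.weight'  # hypothetical original key
    k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k)
    k = k.replace('cpe.0', 'cpe1').replace('cpe.1', 'cpe2')
    # -> 'stages.1.blocks.0.0.cpe1.proj.weight'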
@@ -642,7 +644,7 @@ def _cfg(url='', **kwargs): # not sure how this should be set up
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
         'crop_pct': 0.875, 'interpolation': 'bilinear',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'patch_embeds.0.proj', 'classifier': 'head.fc',
+        'first_conv': 'stages.0.patch_embed.proj', 'classifier': 'head.fc',
         **kwargs
     }
