@@ -34,6 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 
 __all__ = ['DaViT']
 
+# modified nn.Sequential that includes a size tuple in the forward function
 class SequentialWithSize(nn.Sequential):
     def forward(self, x : Tensor, size: Tuple[int, int]):
         for module in self._modules.values():
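Note: `SequentialWithSize` is what lets a whole stage's blocks be driven as one container even though every DaViT block takes and returns an `(x, size)` pair. The hunk above elides the loop body; below is a minimal self-contained sketch of the intended contract, with `DummyBlock` as a hypothetical stand-in for `ChannelBlock`/`SpatialBlock` and the chaining lines completed as the hunk implies.

# Minimal sketch of the SequentialWithSize contract. DummyBlock is a
# hypothetical stand-in for ChannelBlock/SpatialBlock, which share the
# forward(x, size) -> (x, size) signature.
from typing import Tuple
import torch
from torch import nn, Tensor

class DummyBlock(nn.Module):
    def forward(self, x: Tensor, size: Tuple[int, int]):
        return x + 1, size  # passes the size tuple through unchanged

class SequentialWithSize(nn.Sequential):
    def forward(self, x: Tensor, size: Tuple[int, int]):
        for module in self._modules.values():
            x, size = module(x, size)
        return x, size

blocks = SequentialWithSize(DummyBlock(), DummyBlock())
x, size = blocks(torch.zeros(1, 56 * 56, 96), (56, 56))
print(x.mean().item(), size)  # -> 2.0 (56, 56)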
@@ -171,12 +172,12 @@ class ChannelBlock(nn.Module):
                  ffn=True, cpe_act=False):
         super().__init__()
 
-        self.cpe = nn.ModuleList([ConvPosEnc(dim=dim, k=3, act=cpe_act),
-                                  ConvPosEnc(dim=dim, k=3, act=cpe_act)])
+        self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
         self.ffn = ffn
         self.norm1 = norm_layer(dim)
         self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
 
         if self.ffn:
             self.norm2 = norm_layer(dim)
@@ -188,12 +189,12 @@ class ChannelBlock(nn.Module):
 
     def forward(self, x : Tensor, size: Tuple[int, int]):
-        x = self.cpe[0](x, size)
+        x = self.cpe1(x, size)
         cur = self.norm1(x)
         cur = self.attn(cur)
         x = x + self.drop_path(cur)
 
-        x = self.cpe[1](x, size)
+        x = self.cpe2(x, size)
         if self.ffn:
             x = x + self.drop_path(self.mlp(self.norm2(x)))
         return x, size
@@ -292,9 +293,8 @@ class SpatialBlock(nn.Module):
         self.num_heads = num_heads
         self.window_size = window_size
         self.mlp_ratio = mlp_ratio
-        self.cpe = nn.ModuleList([ConvPosEnc(dim=dim, k=3, act=cpe_act),
-                                  ConvPosEnc(dim=dim, k=3, act=cpe_act)])
+        self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
         self.norm1 = norm_layer(dim)
         self.attn = WindowAttention(
             dim,
@@ -303,6 +303,7 @@ class SpatialBlock(nn.Module):
             qkv_bias=qkv_bias)
 
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
 
         if self.ffn:
             self.norm2 = norm_layer(dim)
@@ -318,7 +319,7 @@ class SpatialBlock(nn.Module):
         H, W = size
         B, L, C = x.shape
 
-        shortcut = self.cpe[0](x, size)
+        shortcut = self.cpe1(x, size)
         x = self.norm1(shortcut)
         x = x.view(B, H, W, C)
 
@@ -347,7 +348,7 @@ class SpatialBlock(nn.Module):
         x = x.view(B, H * W, C)
         x = shortcut + self.drop_path(x)
 
-        x = self.cpe[1](x, size)
+        x = self.cpe2(x, size)
         if self.ffn:
             x = x + self.drop_path(self.mlp(self.norm2(x)))
         return x, size
@@ -417,12 +418,10 @@ class DaViTStage(nn.Module):
 
     def forward(self, x : Tensor, size: Tuple[int, int]):
         x, size = self.patch_embed(x, size)
-        for block in self.blocks:
-            for layer in block:
-                if self.grad_checkpointing and not torch.jit.is_scripting():
-                    x, size = checkpoint.checkpoint(layer, x, size)
-                else:
-                    x, size = layer(x, size)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x, size = checkpoint_seq(self.blocks, x, size)
+        else:
+            x, size = self.blocks(x, size)
         return x, size
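Note: the rewritten `DaViTStage.forward` checkpoints the stage's blocks as one sequence instead of looping over nested block lists. timm's stock `checkpoint_seq` threads a single tensor through the sequence, so the call here presumably relies on a size-aware variant; a rough sketch of what such a helper could look like, built on `torch.utils.checkpoint` (the name `checkpoint_seq_with_size` is illustrative, not timm API):

import torch.utils.checkpoint as cp

def checkpoint_seq_with_size(blocks, x, size):
    # torch.utils.checkpoint.checkpoint passes non-tensor args (the size
    # tuple) through untouched and recomputes each block on the backward pass
    for block in blocks:
        x, size = cp.checkpoint(block, x, size)
    return x, size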
@@ -490,8 +489,8 @@ class DaViT(nn.Module):
             stage = DaViTStage(
                 in_chans if stage_id == 0 else embed_dims[stage_id - 1],
                 embed_dims[stage_id],
-                depth = 1,
-                patch_size = patch_size,
+                depth = depths[stage_id],
+                patch_size = patch_size if stage_id == 0 else 2,
                 overlapped_patch = overlapped_patch,
                 attention_types = attention_types,
                 num_heads = num_heads[stage_id],
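Note: this fixes the stage wiring so each stage receives its configured depth rather than a hard-coded 1, and only the stem stage uses the image-level patch size; later stages downsample by a factor of 2, matching the reference DaViT design.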
@@ -602,11 +601,14 @@ def checkpoint_filter_fn(state_dict, model):
 
     if 'state_dict' in state_dict:
         state_dict = state_dict['state_dict']
 
+    import re
     out_dict = {}
     for k, v in state_dict.items():
-        k = k.replace('main_blocks.', 'stages.stage_')
+        k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.patch_embed', k)
+        k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k)
         k = k.replace('head.', 'head.fc.')
+        k = k.replace('cpe.0', 'cpe1')
+        k = k.replace('cpe.1', 'cpe2')
         out_dict[k] = v
     return out_dict
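For reference, the remap can be exercised in isolation. The sample keys below are hypothetical MSFT-style names, used only to illustrate the substitutions:

import re

# Same substitutions as checkpoint_filter_fn above, applied to
# hypothetical MSFT-style keys to show the resulting timm names.
def remap_key(k: str) -> str:
    k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.patch_embed', k)
    k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k)
    k = k.replace('head.', 'head.fc.')
    k = k.replace('cpe.0', 'cpe1')
    k = k.replace('cpe.1', 'cpe2')
    return k

print(remap_key('patch_embeds.0.proj.weight'))           # stages.0.patch_embed.proj.weight
print(remap_key('main_blocks.1.0.0.cpe.0.proj.weight'))  # stages.1.blocks.0.0.cpe1.proj.weight
print(remap_key('head.weight'))                          # head.fc.weight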
@@ -642,7 +644,7 @@ def _cfg(url='', **kwargs): # not sure how this should be set up
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
         'crop_pct': 0.875, 'interpolation': 'bilinear',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'patch_embeds.0.proj', 'classifier': 'head.fc',
+        'first_conv': 'stages.0.patch_embed.proj', 'classifier': 'head.fc',
         **kwargs
     }