diff --git a/timm/models/davit.py b/timm/models/davit.py index c8b10b23..dfc4203d 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -34,6 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD __all__ = ['DaViT'] +# modified nn.Sequential that includes a size tuple in the forward function class SequentialWithSize(nn.Sequential): def forward(self, x : Tensor, size: Tuple[int, int]): for module in self._modules.values(): @@ -171,13 +172,13 @@ class ChannelBlock(nn.Module): ffn=True, cpe_act=False): super().__init__() - self.cpe = nn.ModuleList([ConvPosEnc(dim=dim, k=3, act=cpe_act), - ConvPosEnc(dim=dim, k=3, act=cpe_act)]) + self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act) self.ffn = ffn self.norm1 = norm_layer(dim) self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - + self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act) + if self.ffn: self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) @@ -188,12 +189,12 @@ class ChannelBlock(nn.Module): def forward(self, x : Tensor, size: Tuple[int, int]): - x = self.cpe[0](x, size) + x = self.cpe1(x, size) cur = self.norm1(x) cur = self.attn(cur) x = x + self.drop_path(cur) - x = self.cpe[1](x, size) + x = self.cpe2(x, size) if self.ffn: x = x + self.drop_path(self.mlp(self.norm2(x))) return x, size @@ -292,9 +293,8 @@ class SpatialBlock(nn.Module): self.num_heads = num_heads self.window_size = window_size self.mlp_ratio = mlp_ratio - self.cpe = nn.ModuleList([ConvPosEnc(dim=dim, k=3, act=cpe_act), - ConvPosEnc(dim=dim, k=3, act=cpe_act)]) - + + self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act) self.norm1 = norm_layer(dim) self.attn = WindowAttention( dim, @@ -303,7 +303,8 @@ class SpatialBlock(nn.Module): qkv_bias=qkv_bias) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - + self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act) + if self.ffn: self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) @@ -318,7 +319,7 @@ class SpatialBlock(nn.Module): H, W = size B, L, C = x.shape - shortcut = self.cpe[0](x, size) + shortcut = self.cpe1(x, size) x = self.norm1(shortcut) x = x.view(B, H, W, C) @@ -347,7 +348,7 @@ class SpatialBlock(nn.Module): x = x.view(B, H * W, C) x = shortcut + self.drop_path(x) - x = self.cpe[1](x, size) + x = self.cpe2(x, size) if self.ffn: x = x + self.drop_path(self.mlp(self.norm2(x))) return x, size @@ -417,12 +418,10 @@ class DaViTStage(nn.Module): def forward(self, x : Tensor, size: Tuple[int, int]): x, size = self.patch_embed(x, size) - for block in self.blocks: - for layer in block: - if self.grad_checkpointing and not torch.jit.is_scripting(): - x, size = checkpoint.checkpoint(layer, x, size) - else: - x, size = layer(x, size) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x, size = checkpoint_seq(self.blocks, x, size) + else: + x, size = self.blocks(x, size) return x, size @@ -490,8 +489,8 @@ class DaViT(nn.Module): stage = DaViTStage( in_chans if stage_id == 0 else embed_dims[stage_id - 1], embed_dims[stage_id], - depth = 1, - patch_size = patch_size, + depth = depths[stage_id], + patch_size = patch_size if stage_id == 0 else 2, overlapped_patch = overlapped_patch, attention_types = attention_types, num_heads = num_heads[stage_id], @@ -602,11 +601,14 @@ def checkpoint_filter_fn(state_dict, model): if 'state_dict' in state_dict: state_dict = state_dict['state_dict'] - + import re out_dict = {} for k, v in state_dict.items(): - k = k.replace('main_blocks.', 'stages.stage_') + k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.patch_embed', k) + k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k) k = k.replace('head.', 'head.fc.') + k = k.replace('cpe.0', 'cpe1') + k = k.replace('cpe.1', 'cpe2') out_dict[k] = v return out_dict @@ -642,7 +644,7 @@ def _cfg(url='', **kwargs): # not sure how this should be set up 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bilinear', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embeds.0.proj', 'classifier': 'head.fc', + 'first_conv': 'stages.0.patch_embed.proj', 'classifier': 'head.fc', **kwargs }