diff --git a/timm/models/davit.py b/timm/models/davit.py index 74d51f73..94c7c8dd 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -34,6 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD __all__ = ['DaViT'] + # modified nn.Sequential that includes a size tuple in the forward function # FIXME doesn't work with torchscript/JIT # Module 'SequentialWithSize' has no attribute '_modules' @@ -43,6 +44,7 @@ class SequentialWithSize(nn.Sequential): x, size = module(x, size) return x, size + class ConvPosEnc(nn.Module): def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'): @@ -190,7 +192,9 @@ class ChannelBlock(nn.Module): def forward(self, x : Tensor, size: Tuple[int, int]): + x = self.cpe1(x, size) + cur = self.norm1(x) cur = self.attn(cur) x = x + self.drop_path(cur) @@ -520,7 +524,6 @@ class DaViT(nn.Module): self.norms = norm_layer(self.num_features) self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) self.apply(self._init_weights) - def _init_weights(self, m): if isinstance(m, nn.Linear): @@ -551,7 +554,6 @@ class DaViT(nn.Module): size: Tuple[int, int] = (x.size(2), x.size(3)) x, size = self.stages(x, size) - # take final feature and norm x = self.norms(x) H, W = size x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous() @@ -567,7 +569,6 @@ class DaViT(nn.Module): def forward(self, x): return self.forward_classifier(x) - def checkpoint_filter_fn(state_dict, model): @@ -577,6 +578,7 @@ def checkpoint_filter_fn(state_dict, model): if 'state_dict' in state_dict: state_dict = state_dict['state_dict'] + import re out_dict = {} for k, v in state_dict.items(): @@ -590,6 +592,7 @@ def checkpoint_filter_fn(state_dict, model): def _create_davit(variant, pretrained=False, **kwargs): + default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1)))) out_indices = kwargs.pop('out_indices', default_out_indices) model = build_model_with_cfg( @@ -599,7 +602,6 @@ def _create_davit(variant, pretrained=False, **kwargs): pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs) - return model