@@ -34,6 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['DaViT']


# modified nn.Sequential that includes a size tuple in the forward function
# FIXME doesn't work with torchscript/JIT
# Module 'SequentialWithSize' has no attribute '_modules'
@@ -43,6 +44,7 @@ class SequentialWithSize(nn.Sequential):
            x, size = module(x, size)
        return x, size
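
# NOTE: illustrative sketch, not part of the patch. Only two lines of the container's
# forward() survive in the hunk above; based on those lines and the '_modules' comment,
# the size-carrying Sequential presumably looks roughly like the class below (nn, Tensor
# and Tuple are imported at the top of the real module; the sketch name is made up to
# avoid clashing with the actual class).
class _SequentialWithSizeSketch(nn.Sequential):
    def forward(self, x: Tensor, size: Tuple[int, int]):
        # walk the registered children in order, threading the (tokens, grid size) pair
        for module in self._modules.values():
            x, size = module(x, size)
        return x, size
# TorchScript cannot compile the `self._modules` access above, which is what the FIXME
# note refers to.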


class ConvPosEnc(nn.Module):
    def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
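    # NOTE: the body of ConvPosEnc is elided in this excerpt. In DaViT this module is a
    # depthwise-convolution position encoding: tokens are folded back onto their H x W
    # grid, passed through a k x k depthwise conv with optional norm/activation, and
    # added back residually. A rough sketch matching the signature above (layer names
    # and the norm handling are assumptions):
    #
    #     self.proj = nn.Conv2d(dim, dim, k, stride=1, padding=k // 2, groups=dim)
    #     self.norm = nn.BatchNorm2d(dim) if normtype == 'batch' else nn.Identity()
    #     self.act = nn.GELU() if act else nn.Identity()
    #
    # and in forward(x, size):
    #
    #     B, N, C = x.shape
    #     H, W = size
    #     feat = x.transpose(1, 2).view(B, C, H, W)    # (B, N, C) -> (B, C, H, W)
    #     feat = self.act(self.norm(self.proj(feat)))
    #     x = x + feat.flatten(2).transpose(1, 2)      # residual add in token layout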

@@ -190,7 +192,9 @@ class ChannelBlock(nn.Module):
    def forward(self, x : Tensor, size: Tuple[int, int]):
        x = self.cpe1(x, size)  # conv position encoding on the token grid
        cur = self.norm1(x)  # pre-norm
        cur = self.attn(cur)  # channel-wise (transposed) self-attention
        x = x + self.drop_path(cur)  # residual with stochastic depth
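        # NOTE: the rest of the block is elided in this excerpt; it presumably continues
        # with the usual pre-norm MLP branch (names below are assumptions), roughly:
        #     x = self.cpe2(x, size)
        #     x = x + self.drop_path(self.mlp(self.norm2(x)))
        #     return x, size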

@@ -520,7 +524,6 @@ class DaViT(nn.Module):
        self.norms = norm_layer(self.num_features)
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
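            # NOTE: the init body is elided in this excerpt; timm ViT-style models
            # typically do, roughly:
            #     trunc_normal_(m.weight, std=.02)
            #     if m.bias is not None:
            #         nn.init.constant_(m.bias, 0)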

@@ -551,7 +554,6 @@ class DaViT(nn.Module):
        size: Tuple[int, int] = (x.size(2), x.size(3))
        x, size = self.stages(x, size)
        # take final feature and norm
        x = self.norms(x)
        H, W = size
        x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()  # (B, H*W, C) -> (B, C, H, W)
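        # NOTE (not part of the patch): the view/permute above turns the flattened token
        # sequence back into a 2D feature map. A standalone sanity check of that reshape,
        # with arbitrary shapes:
        #     B, H, W, C = 2, 7, 7, 768
        #     tokens = torch.randn(B, H * W, C)                    # (B, N, C) token layout
        #     fmap = tokens.view(-1, H, W, C).permute(0, 3, 1, 2)  # -> (B, C, H, W)
        #     assert fmap.shape == (B, C, H, W)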

@@ -567,7 +569,6 @@ class DaViT(nn.Module):
    def forward(self, x):
        return self.forward_classifier(x)


def checkpoint_filter_fn(state_dict, model):
@@ -577,6 +578,7 @@ def checkpoint_filter_fn(state_dict, model):
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
    import re
    out_dict = {}
    for k, v in state_dict.items():
@@ -590,6 +592,7 @@ def checkpoint_filter_fn(state_dict, model):
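        # NOTE: the loop body and the exact key renames are elided in this excerpt. timm
        # checkpoint filter functions typically rewrite the original parameter names into
        # the new module layout and collect them, roughly (patterns here are placeholders,
        # not the real mapping):
        #     k = re.sub(r'old_prefix\.([0-9]+)', r'new_prefix.\1', k)
        #     out_dict[k] = v
        # with the function ending in:
        #     return out_dict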


def _create_davit(variant, pretrained=False, **kwargs):
    # emit a feature map for every stage by default
    default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
    out_indices = kwargs.pop('out_indices', default_out_indices)
    model = build_model_with_cfg(
@@ -599,7 +602,6 @@ def _create_davit(variant, pretrained=False, **kwargs):
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        **kwargs)
    return model
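
# NOTE: illustrative usage, not part of the patch. timm model entry points are thin
# wrappers around a helper like the one above; a typical registration looks roughly
# like this (the entry-point name and hyper-parameters below are placeholders):
#
#     @register_model
#     def davit_example(pretrained=False, **kwargs):
#         model_kwargs = dict(depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768), **kwargs)
#         return _create_davit('davit_example', pretrained=pretrained, **model_kwargs)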