Merge branch 'main' into davit_std

pull/1630/head
Fredo Guan 3 years ago committed by GitHub
commit f42ee614a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -34,6 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['DaViT']
# modified nn.Sequential that includes a size tuple in the forward function
# FIXME doesn't work with torchscript/JIT
# Module 'SequentialWithSize' has no attribute '_modules'
@ -43,6 +44,7 @@ class SequentialWithSize(nn.Sequential):
x, size = module(x, size)
return x, size
class ConvPosEnc(nn.Module):
def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
@ -190,7 +192,9 @@ class ChannelBlock(nn.Module):
def forward(self, x : Tensor, size: Tuple[int, int]):
x = self.cpe1(x, size)
cur = self.norm1(x)
cur = self.attn(cur)
x = x + self.drop_path(cur)
@ -520,7 +524,6 @@ class DaViT(nn.Module):
self.norms = norm_layer(self.num_features)
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
@ -551,7 +554,6 @@ class DaViT(nn.Module):
size: Tuple[int, int] = (x.size(2), x.size(3))
x, size = self.stages(x, size)
# take final feature and norm
x = self.norms(x)
H, W = size
x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
@ -567,7 +569,6 @@ class DaViT(nn.Module):
def forward(self, x):
return self.forward_classifier(x)
def checkpoint_filter_fn(state_dict, model):
@ -577,6 +578,7 @@ def checkpoint_filter_fn(state_dict, model):
if 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
import re
out_dict = {}
for k, v in state_dict.items():
@ -590,6 +592,7 @@ def checkpoint_filter_fn(state_dict, model):
def _create_davit(variant, pretrained=False, **kwargs):
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
out_indices = kwargs.pop('out_indices', default_out_indices)
model = build_model_with_cfg(
@ -599,7 +602,6 @@ def _create_davit(variant, pretrained=False, **kwargs):
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs)
return model

Loading…
Cancel
Save