|
|
@ -36,23 +36,13 @@ __all__ = ['DaViT']
|
|
|
|
|
|
|
|
|
|
|
|
# modified nn.Sequential that includes a size tuple in the forward function
|
|
|
|
# modified nn.Sequential that includes a size tuple in the forward function
|
|
|
|
# FIXME doesn't work with torchscript/JIT
|
|
|
|
# FIXME doesn't work with torchscript/JIT
|
|
|
|
|
|
|
|
# Module 'SequentialWithSize' has no attribute '_modules'
|
|
|
|
class SequentialWithSize(nn.Sequential):
|
|
|
|
class SequentialWithSize(nn.Sequential):
|
|
|
|
def forward(self, x : Tensor, size: Tuple[int, int]):
|
|
|
|
def forward(self, x : Tensor, size: Tuple[int, int]):
|
|
|
|
for module in self._modules.values():
|
|
|
|
for module in self._modules.values():
|
|
|
|
x, size = module(x, size)
|
|
|
|
x, size = module(x, size)
|
|
|
|
return x, size
|
|
|
|
return x, size
|
|
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
class SequentialWithSize(nn.Sequential):
|
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
|
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x : Tensor, size: Tuple[int, int]):
|
|
|
|
|
|
|
|
for module in self._modules.values():
|
|
|
|
|
|
|
|
x, size = module(x, size)
|
|
|
|
|
|
|
|
return x, size
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
class ConvPosEnc(nn.Module):
|
|
|
|
class ConvPosEnc(nn.Module):
|
|
|
|
def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
|
|
|
|
def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
|
|
|
|
|
|
|
|
|
|
|
@ -555,36 +545,6 @@ class DaViT(nn.Module):
|
|
|
|
global_pool = self.head.global_pool.pool_type
|
|
|
|
global_pool = self.head.global_pool.pool_type
|
|
|
|
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
|
|
|
|
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
|
|
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
def forward_network(self, x : Tensor):
|
|
|
|
|
|
|
|
size: Tuple[int, int] = (x.size(2), x.size(3))
|
|
|
|
|
|
|
|
features = [x]
|
|
|
|
|
|
|
|
sizes = [size]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for stage in self.stages:
|
|
|
|
|
|
|
|
features[-1], sizes[-1] = stage(features[-1], sizes[-1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# don't append outputs of last stage, since they are already there
|
|
|
|
|
|
|
|
if(len(features) < self.num_stages):
|
|
|
|
|
|
|
|
features.append(features[-1])
|
|
|
|
|
|
|
|
sizes.append(sizes[-1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# non-normalized pyramid features + corresponding sizes
|
|
|
|
|
|
|
|
return features, sizes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward_pyramid_features(self, x) -> List[Tensor]:
|
|
|
|
|
|
|
|
x, sizes = self.forward_network(x)
|
|
|
|
|
|
|
|
outs = []
|
|
|
|
|
|
|
|
for i, out in enumerate(x):
|
|
|
|
|
|
|
|
H, W = sizes[i]
|
|
|
|
|
|
|
|
outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return outs
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward_features(self, x):
|
|
|
|
def forward_features(self, x):
|
|
|
|
#x, sizes = self.forward_network(x)
|
|
|
|
#x, sizes = self.forward_network(x)
|
|
|
@ -609,18 +569,6 @@ class DaViT(nn.Module):
|
|
|
|
return self.forward_classifier(x)
|
|
|
|
return self.forward_classifier(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
class DaViTFeatures(DaViT):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
|
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
|
|
|
self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_indices', (0, 1, 2, 3)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x) -> List[Tensor]:
|
|
|
|
|
|
|
|
return self.forward_pyramid_features(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def checkpoint_filter_fn(state_dict, model):
|
|
|
|
def checkpoint_filter_fn(state_dict, model):
|
|
|
|
""" Remap MSFT checkpoints -> timm """
|
|
|
|
""" Remap MSFT checkpoints -> timm """
|
|
|
@ -641,30 +589,6 @@ def checkpoint_filter_fn(state_dict, model):
|
|
|
|
return out_dict
|
|
|
|
return out_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
def _create_davit(variant, pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
model_cls = DaViT
|
|
|
|
|
|
|
|
features_only = False
|
|
|
|
|
|
|
|
kwargs_filter = None
|
|
|
|
|
|
|
|
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
|
|
|
|
|
|
|
|
out_indices = kwargs.pop('out_indices', default_out_indices)
|
|
|
|
|
|
|
|
if kwargs.pop('features_only', False):
|
|
|
|
|
|
|
|
model_cls = DaViTFeatures
|
|
|
|
|
|
|
|
kwargs_filter = ('num_classes', 'global_pool')
|
|
|
|
|
|
|
|
features_only = True
|
|
|
|
|
|
|
|
model = build_model_with_cfg(
|
|
|
|
|
|
|
|
model_cls,
|
|
|
|
|
|
|
|
variant,
|
|
|
|
|
|
|
|
pretrained,
|
|
|
|
|
|
|
|
pretrained_filter_fn=checkpoint_filter_fn,
|
|
|
|
|
|
|
|
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
|
|
|
|
|
|
|
|
**kwargs)
|
|
|
|
|
|
|
|
if features_only:
|
|
|
|
|
|
|
|
model.pretrained_cfg = pretrained_cfg_for_features(model.default_cfg)
|
|
|
|
|
|
|
|
model.default_cfg = model.pretrained_cfg # backwards compat
|
|
|
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
def _create_davit(variant, pretrained=False, **kwargs):
|
|
|
|
def _create_davit(variant, pretrained=False, **kwargs):
|
|
|
|
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
|
|
|
|
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
|
|
|
|
out_indices = kwargs.pop('out_indices', default_out_indices)
|
|
|
|
out_indices = kwargs.pop('out_indices', default_out_indices)
|
|
|
@ -699,6 +623,9 @@ default_cfgs = generate_default_cfgs({
|
|
|
|
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_small_d1ecf281.pth.tar"),
|
|
|
|
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_small_d1ecf281.pth.tar"),
|
|
|
|
'davit_base.msft_in1k': _cfg(
|
|
|
|
'davit_base.msft_in1k': _cfg(
|
|
|
|
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_base_67d9ac26.pth.tar"),
|
|
|
|
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_base_67d9ac26.pth.tar"),
|
|
|
|
|
|
|
|
'davit_large': _cfg(),
|
|
|
|
|
|
|
|
'davit_huge': _cfg(),
|
|
|
|
|
|
|
|
'davit_giant': _cfg(),
|
|
|
|
})
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -721,7 +648,7 @@ def davit_base(pretrained=False, **kwargs):
|
|
|
|
num_heads=(4, 8, 16, 32), **kwargs)
|
|
|
|
num_heads=(4, 8, 16, 32), **kwargs)
|
|
|
|
return _create_davit('davit_base', pretrained=pretrained, **model_kwargs)
|
|
|
|
return _create_davit('davit_base', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
''' models without weights
|
|
|
|
|
|
|
|
# TODO contact authors to get larger pretrained models
|
|
|
|
# TODO contact authors to get larger pretrained models
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def davit_large(pretrained=False, **kwargs):
|
|
|
|
def davit_large(pretrained=False, **kwargs):
|
|
|
@ -740,4 +667,3 @@ def davit_giant(pretrained=False, **kwargs):
|
|
|
|
model_kwargs = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072),
|
|
|
|
model_kwargs = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072),
|
|
|
|
num_heads=(12, 24, 48, 96), **kwargs)
|
|
|
|
num_heads=(12, 24, 48, 96), **kwargs)
|
|
|
|
return _create_davit('davit_giant', pretrained=pretrained, **model_kwargs)
|
|
|
|
return _create_davit('davit_giant', pretrained=pretrained, **model_kwargs)
|
|
|
|
'''
|
|
|
|
|