|
|
@ -535,89 +535,89 @@ class DaViT(nn.Module):
|
|
|
|
x = self.forward_head(x)
|
|
|
|
x = self.forward_head(x)
|
|
|
|
return x
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
def checkpoint_filter_fn(state_dict, model):
    """ Remap MSFT checkpoints -> timm """
    # Already in timm layout (or not an MSFT checkpoint): pass through as-is.
    if 'head.norm.weight' in state_dict:
        return state_dict  # non-MSFT checkpoint
    if 'state_dict' in state_dict:
        # Unwrap checkpoints saved as {'state_dict': ...}.
        state_dict = state_dict['state_dict']
    # Rename the MSFT final-norm keys ('norms.*') to the timm head layout
    # ('head.norm.*'); all other keys pass through unchanged.
    # (Removed an unused function-local `import re` from the original.)
    return {k.replace('norms.', 'head.norm.'): v for k, v in state_dict.items()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_davit(variant, pretrained=False, **kwargs):
    """Instantiate a DaViT variant through timm's model builder.

    Remaps pretrained MSFT checkpoints via ``checkpoint_filter_fn`` before
    the weights are loaded.
    """
    return build_model_with_cfg(
        DaViT, variant, pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _cfg(url='', **kwargs):
    """Return a default pretrained-config dict for a DaViT variant.

    ``kwargs`` override any of the baseline entries below.
    """
    cfg = {
        'url': url,
        'num_classes': 1000,
        'input_size': (3, 224, 224),
        'pool_size': (7, 7),
        'crop_pct': 0.875,
        'interpolation': 'bilinear',
        'mean': IMAGENET_DEFAULT_MEAN,
        'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embeds.0.proj',
        'classifier': 'head.fc',
    }
    cfg.update(kwargs)
    return cfg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Pretrained-weight configs, keyed '<arch>.<tag>'.  URLs point at re-hosted
# MSFT checkpoints; their keys are remapped at load time by
# checkpoint_filter_fn above.
default_cfgs = generate_default_cfgs({
    'davit_tiny.msft_in1k': _cfg(
        url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/untagged-b2178bcf50f43d660d99/davit_tiny_ed28dd55.pth.tar"),
    'davit_small.msft_in1k': _cfg(
        url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/untagged-b2178bcf50f43d660d99/davit_small_d1ecf281.pth.tar"),
    'davit_base.msft_in1k': _cfg(
        url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/untagged-b2178bcf50f43d660d99/davit_base_67d9ac26.pth.tar"),
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def davit_tiny(pretrained=False, **kwargs):
    """DaViT-Tiny: depths (1, 1, 3, 1), dims 96->768, heads 3->24."""
    return _create_davit(
        'davit_tiny', pretrained=pretrained,
        depths=(1, 1, 3, 1),
        embed_dims=(96, 192, 384, 768),
        num_heads=(3, 6, 12, 24),
        **kwargs,
    )
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def davit_small(pretrained=False, **kwargs):
    """DaViT-Small: depths (1, 1, 9, 1), dims 96->768, heads 3->24."""
    return _create_davit(
        'davit_small', pretrained=pretrained,
        depths=(1, 1, 9, 1),
        embed_dims=(96, 192, 384, 768),
        num_heads=(3, 6, 12, 24),
        **kwargs,
    )
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def davit_base(pretrained=False, **kwargs):
    """DaViT-Base: depths (1, 1, 9, 1), dims 128->1024, heads 4->32."""
    return _create_davit(
        'davit_base', pretrained=pretrained,
        depths=(1, 1, 9, 1),
        embed_dims=(128, 256, 512, 1024),
        num_heads=(4, 8, 16, 32),
        **kwargs,
    )
|
|
|
|
|
|
|
|
|
|
|
|
# NOTE(review): the larger DaViT variants below are deliberately held inside a
# bare string literal (dead code) -- presumably to keep them out of the model
# registry until pretrained weights exist.  Confirm intent before enabling.
''' models without weights
# TODO contact authors to get larger pretrained models
@register_model
def davit_large(pretrained=False, **kwargs):
    model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536),
        num_heads=(6, 12, 24, 48), **kwargs)
    return _create_davit('davit_large', pretrained=pretrained, **model_kwargs)

@register_model
def davit_huge(pretrained=False, **kwargs):
    model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048),
        num_heads=(8, 16, 32, 64), **kwargs)
    return _create_davit('davit_huge', pretrained=pretrained, **model_kwargs)

@register_model
def davit_giant(pretrained=False, **kwargs):
    model_kwargs = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072),
        num_heads=(12, 24, 48, 96), **kwargs)
    return _create_davit('davit_giant', pretrained=pretrained, **model_kwargs)
'''
|