|
|
@ -535,7 +535,7 @@ class DaViT(nn.Module):
|
|
|
|
x = self.forward_head(x)
|
|
|
|
x = self.forward_head(x)
|
|
|
|
return x
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
def checkpoint_filter_fn(state_dict, model):
|
|
|
|
def checkpoint_filter_fn(state_dict, model):
|
|
|
|
""" Remap MSFT checkpoints -> timm """
|
|
|
|
""" Remap MSFT checkpoints -> timm """
|
|
|
|
if 'head.norm.weight' in state_dict:
|
|
|
|
if 'head.norm.weight' in state_dict:
|
|
|
|
return state_dict # non-MSFT checkpoint
|
|
|
|
return state_dict # non-MSFT checkpoint
|
|
|
@ -552,14 +552,14 @@ class DaViT(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_davit(variant, pretrained=False, **kwargs):
|
|
|
|
def _create_davit(variant, pretrained=False, **kwargs):
|
|
|
|
model = build_model_with_cfg(DaViT, variant, pretrained,
|
|
|
|
model = build_model_with_cfg(DaViT, variant, pretrained,
|
|
|
|
pretrained_filter_fn=checkpoint_filter_fn, **kwargs)
|
|
|
|
pretrained_filter_fn=checkpoint_filter_fn, **kwargs)
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _cfg(url='', **kwargs): # not sure how this should be set up
|
|
|
|
def _cfg(url='', **kwargs): # not sure how this should be set up
|
|
|
|
return {
|
|
|
|
return {
|
|
|
|
'url': url,
|
|
|
|
'url': url,
|
|
|
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
|
|
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
|
|
@ -571,53 +571,53 @@ class DaViT(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
default_cfgs = generate_default_cfgs({
|
|
|
|
default_cfgs = generate_default_cfgs({
|
|
|
|
|
|
|
|
|
|
|
|
'davit_tiny.msft_in1k': _cfg(
|
|
|
|
'davit_tiny.msft_in1k': _cfg(
|
|
|
|
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/untagged-b2178bcf50f43d660d99/davit_tiny_ed28dd55.pth.tar"),
|
|
|
|
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/untagged-b2178bcf50f43d660d99/davit_tiny_ed28dd55.pth.tar"),
|
|
|
|
'davit_small.msft_in1k': _cfg(
|
|
|
|
'davit_small.msft_in1k': _cfg(
|
|
|
|
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/untagged-b2178bcf50f43d660d99/davit_small_d1ecf281.pth.tar"),
|
|
|
|
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/untagged-b2178bcf50f43d660d99/davit_small_d1ecf281.pth.tar"),
|
|
|
|
'davit_base.msft_in1k': _cfg(
|
|
|
|
'davit_base.msft_in1k': _cfg(
|
|
|
|
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/untagged-b2178bcf50f43d660d99/davit_base_67d9ac26.pth.tar"),
|
|
|
|
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/untagged-b2178bcf50f43d660d99/davit_base_67d9ac26.pth.tar"),
|
|
|
|
})
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def davit_tiny(pretrained=False, **kwargs):
|
|
|
|
def davit_tiny(pretrained=False, **kwargs):
|
|
|
|
model_kwargs = dict(depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768),
|
|
|
|
model_kwargs = dict(depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768),
|
|
|
|
num_heads=(3, 6, 12, 24), **kwargs)
|
|
|
|
num_heads=(3, 6, 12, 24), **kwargs)
|
|
|
|
return _create_davit('davit_tiny', pretrained=pretrained, **model_kwargs)
|
|
|
|
return _create_davit('davit_tiny', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def davit_small(pretrained=False, **kwargs):
|
|
|
|
def davit_small(pretrained=False, **kwargs):
|
|
|
|
model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(96, 192, 384, 768),
|
|
|
|
model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(96, 192, 384, 768),
|
|
|
|
num_heads=(3, 6, 12, 24), **kwargs)
|
|
|
|
num_heads=(3, 6, 12, 24), **kwargs)
|
|
|
|
return _create_davit('davit_small', pretrained=pretrained, **model_kwargs)
|
|
|
|
return _create_davit('davit_small', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def davit_base(pretrained=False, **kwargs):
|
|
|
|
def davit_base(pretrained=False, **kwargs):
|
|
|
|
model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024),
|
|
|
|
model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024),
|
|
|
|
num_heads=(4, 8, 16, 32), **kwargs)
|
|
|
|
num_heads=(4, 8, 16, 32), **kwargs)
|
|
|
|
return _create_davit('davit_base', pretrained=pretrained, **model_kwargs)
|
|
|
|
return _create_davit('davit_base', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
''' models without weights
|
|
|
|
''' models without weights
|
|
|
|
# TODO contact authors to get larger pretrained models
|
|
|
|
# TODO contact authors to get larger pretrained models
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def davit_large(pretrained=False, **kwargs):
|
|
|
|
def davit_large(pretrained=False, **kwargs):
|
|
|
|
model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536),
|
|
|
|
model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536),
|
|
|
|
num_heads=(6, 12, 24, 48), **kwargs)
|
|
|
|
num_heads=(6, 12, 24, 48), **kwargs)
|
|
|
|
return _create_davit('davit_large', pretrained=pretrained, **model_kwargs)
|
|
|
|
return _create_davit('davit_large', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def davit_huge(pretrained=False, **kwargs):
|
|
|
|
def davit_huge(pretrained=False, **kwargs):
|
|
|
|
model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048),
|
|
|
|
model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048),
|
|
|
|
num_heads=(8, 16, 32, 64), **kwargs)
|
|
|
|
num_heads=(8, 16, 32, 64), **kwargs)
|
|
|
|
return _create_davit('davit_huge', pretrained=pretrained, **model_kwargs)
|
|
|
|
return _create_davit('davit_huge', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def davit_giant(pretrained=False, **kwargs):
|
|
|
|
def davit_giant(pretrained=False, **kwargs):
|
|
|
|
model_kwargs = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072),
|
|
|
|
model_kwargs = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072),
|
|
|
|
num_heads=(12, 24, 48, 96), **kwargs)
|
|
|
|
num_heads=(12, 24, 48, 96), **kwargs)
|
|
|
|
return _create_davit('davit_giant', pretrained=pretrained, **model_kwargs)
|
|
|
|
return _create_davit('davit_giant', pretrained=pretrained, **model_kwargs)
|
|
|
|
'''
|
|
|
|
'''
|