From 6521662d588077f48f11813f83e93332b10ae7da Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 6 Dec 2022 17:02:08 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 152 +++++++++++++++++++++---------------------- 1 file changed, 76 insertions(+), 76 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index f7607f2c..5c73514a 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -535,89 +535,89 @@ class DaViT(nn.Module): x = self.forward_head(x) return x - def checkpoint_filter_fn(state_dict, model): - """ Remap MSFT checkpoints -> timm """ - if 'head.norm.weight' in state_dict: - return state_dict # non-MSFT checkpoint - - if 'state_dict' in state_dict: - state_dict = state_dict['state_dict'] - - out_dict = {} - import re - for k, v in state_dict.items(): - k = k.replace('norms.', 'head.norm.') - out_dict[k] = v - return out_dict - - - - def _create_davit(variant, pretrained=False, **kwargs): - model = build_model_with_cfg(DaViT, variant, pretrained, - pretrained_filter_fn=checkpoint_filter_fn, **kwargs) - return model - +def checkpoint_filter_fn(state_dict, model): + """ Remap MSFT checkpoints -> timm """ + if 'head.norm.weight' in state_dict: + return state_dict # non-MSFT checkpoint + if 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] - def _cfg(url='', **kwargs): # not sure how this should be set up - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.875, 'interpolation': 'bilinear', - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embeds.0.proj', 'classifier': 'head.fc', - **kwargs - } + out_dict = {} + import re + for k, v in state_dict.items(): + k = k.replace('norms.', 'head.norm.') + out_dict[k] = v + return out_dict - default_cfgs = generate_default_cfgs({ - - 'davit_tiny.msft_in1k': _cfg( - url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/untagged-b2178bcf50f43d660d99/davit_tiny_ed28dd55.pth.tar"), - 'davit_small.msft_in1k': _cfg( - url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/untagged-b2178bcf50f43d660d99/davit_small_d1ecf281.pth.tar"), - 'davit_base.msft_in1k': _cfg( - url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/untagged-b2178bcf50f43d660d99/davit_base_67d9ac26.pth.tar"), +def _create_davit(variant, pretrained=False, **kwargs): + model = build_model_with_cfg(DaViT, variant, pretrained, + pretrained_filter_fn=checkpoint_filter_fn, **kwargs) + return model + + + +def _cfg(url='', **kwargs): # not sure how this should be set up + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embeds.0.proj', 'classifier': 'head.fc', + **kwargs + } + + + +default_cfgs = generate_default_cfgs({ + +'davit_tiny.msft_in1k': _cfg( + url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/untagged-b2178bcf50f43d660d99/davit_tiny_ed28dd55.pth.tar"), +'davit_small.msft_in1k': _cfg( + url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/untagged-b2178bcf50f43d660d99/davit_small_d1ecf281.pth.tar"), +'davit_base.msft_in1k': _cfg( + url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/untagged-b2178bcf50f43d660d99/davit_base_67d9ac26.pth.tar"), }) + + +@register_model +def davit_tiny(pretrained=False, **kwargs): + model_kwargs = dict(depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768), + num_heads=(3, 6, 12, 24), **kwargs) + return _create_davit('davit_tiny', pretrained=pretrained, **model_kwargs) +@register_model +def davit_small(pretrained=False, **kwargs): + model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(96, 192, 384, 768), + num_heads=(3, 6, 12, 24), **kwargs) + return _create_davit('davit_small', pretrained=pretrained, **model_kwargs) - @register_model - def davit_tiny(pretrained=False, **kwargs): - model_kwargs = dict(depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768), - num_heads=(3, 6, 12, 24), **kwargs) - return _create_davit('davit_tiny', pretrained=pretrained, **model_kwargs) - - @register_model - def davit_small(pretrained=False, **kwargs): - model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(96, 192, 384, 768), - num_heads=(3, 6, 12, 24), **kwargs) - return _create_davit('davit_small', pretrained=pretrained, **model_kwargs) - - @register_model - def davit_base(pretrained=False, **kwargs): - model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), - num_heads=(4, 8, 16, 32), **kwargs) - return _create_davit('davit_base', pretrained=pretrained, **model_kwargs) +@register_model +def davit_base(pretrained=False, **kwargs): + model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), + num_heads=(4, 8, 16, 32), **kwargs) + return _create_davit('davit_base', pretrained=pretrained, **model_kwargs) + +''' models without weights +# TODO contact authors to get larger pretrained models +@register_model +def davit_large(pretrained=False, **kwargs): + model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536), + num_heads=(6, 12, 24, 48), **kwargs) + return _create_davit('davit_large', pretrained=pretrained, **model_kwargs) - ''' models without weights - # TODO contact authors to get larger pretrained models - @register_model - def davit_large(pretrained=False, **kwargs): - model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536), - num_heads=(6, 12, 24, 48), **kwargs) - return _create_davit('davit_large', pretrained=pretrained, **model_kwargs) - - @register_model - def davit_huge(pretrained=False, **kwargs): - model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), - num_heads=(8, 16, 32, 64), **kwargs) - return _create_davit('davit_huge', pretrained=pretrained, **model_kwargs) - - @register_model - def davit_giant(pretrained=False, **kwargs): - model_kwargs = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), - num_heads=(12, 24, 48, 96), **kwargs) - return _create_davit('davit_giant', pretrained=pretrained, **model_kwargs) - ''' \ No newline at end of file +@register_model +def davit_huge(pretrained=False, **kwargs): + model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), + num_heads=(8, 16, 32, 64), **kwargs) + return _create_davit('davit_huge', pretrained=pretrained, **model_kwargs) + +@register_model +def davit_giant(pretrained=False, **kwargs): + model_kwargs = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), + num_heads=(12, 24, 48, 96), **kwargs) + return _create_davit('davit_giant', pretrained=pretrained, **model_kwargs) +''' \ No newline at end of file