Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 284de330cd
commit 6521662d58

@ -535,7 +535,7 @@ class DaViT(nn.Module):
x = self.forward_head(x) x = self.forward_head(x)
return x return x
def checkpoint_filter_fn(state_dict, model): def checkpoint_filter_fn(state_dict, model):
""" Remap MSFT checkpoints -> timm """ """ Remap MSFT checkpoints -> timm """
if 'head.norm.weight' in state_dict: if 'head.norm.weight' in state_dict:
return state_dict # non-MSFT checkpoint return state_dict # non-MSFT checkpoint
@ -552,14 +552,14 @@ class DaViT(nn.Module):
def _create_davit(variant, pretrained=False, **kwargs):
    """Build a DaViT model variant through timm's model factory.

    Args:
        variant: Variant name keyed into ``default_cfgs`` (e.g. 'davit_tiny').
        pretrained: If True, download and load pretrained weights;
            MSFT-format checkpoints are remapped by ``checkpoint_filter_fn``.
        **kwargs: Architecture overrides forwarded to the ``DaViT`` constructor.

    Returns:
        An instantiated ``DaViT`` model.
    """
    model = build_model_with_cfg(
        DaViT, variant, pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs,
    )
    return model
def _cfg(url='', **kwargs): # not sure how this should be set up def _cfg(url='', **kwargs): # not sure how this should be set up
return { return {
'url': url, 'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
@ -571,53 +571,53 @@ class DaViT(nn.Module):
# Pretrained weight configs. The ``.msft_in1k`` weights originate from the
# Microsoft DaViT release, re-hosted (remapped to timm layout on load by
# checkpoint_filter_fn).
default_cfgs = generate_default_cfgs({
    'davit_tiny.msft_in1k': _cfg(
        url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/untagged-b2178bcf50f43d660d99/davit_tiny_ed28dd55.pth.tar"),
    'davit_small.msft_in1k': _cfg(
        url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/untagged-b2178bcf50f43d660d99/davit_small_d1ecf281.pth.tar"),
    'davit_base.msft_in1k': _cfg(
        url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/untagged-b2178bcf50f43d660d99/davit_base_67d9ac26.pth.tar"),
})
@register_model
def davit_tiny(pretrained=False, **kwargs):
    """DaViT-Tiny: depths (1, 1, 3, 1), embed dims (96, 192, 384, 768)."""
    model_kwargs = dict(
        depths=(1, 1, 3, 1),
        embed_dims=(96, 192, 384, 768),
        num_heads=(3, 6, 12, 24),
    )
    # Merge so caller-supplied kwargs override the defaults instead of
    # raising a duplicate-keyword TypeError.
    model_kwargs.update(kwargs)
    return _create_davit('davit_tiny', pretrained=pretrained, **model_kwargs)
@register_model
def davit_small(pretrained=False, **kwargs):
    """DaViT-Small: depths (1, 1, 9, 1), embed dims (96, 192, 384, 768)."""
    model_kwargs = dict(
        depths=(1, 1, 9, 1),
        embed_dims=(96, 192, 384, 768),
        num_heads=(3, 6, 12, 24),
    )
    # Merge so caller-supplied kwargs override the defaults instead of
    # raising a duplicate-keyword TypeError.
    model_kwargs.update(kwargs)
    return _create_davit('davit_small', pretrained=pretrained, **model_kwargs)
@register_model
def davit_base(pretrained=False, **kwargs):
    """DaViT-Base: depths (1, 1, 9, 1), embed dims (128, 256, 512, 1024)."""
    model_kwargs = dict(
        depths=(1, 1, 9, 1),
        embed_dims=(128, 256, 512, 1024),
        num_heads=(4, 8, 16, 32),
    )
    # Merge so caller-supplied kwargs override the defaults instead of
    # raising a duplicate-keyword TypeError.
    model_kwargs.update(kwargs)
    return _create_davit('davit_base', pretrained=pretrained, **model_kwargs)
# NOTE(review): the following larger variants are deliberately disabled by
# being wrapped in a module-level string literal (no pretrained weights
# available yet) — kept verbatim, not executed at import time.
''' models without weights
# TODO contact authors to get larger pretrained models

@register_model
def davit_large(pretrained=False, **kwargs):
    model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536),
        num_heads=(6, 12, 24, 48), **kwargs)
    return _create_davit('davit_large', pretrained=pretrained, **model_kwargs)


@register_model
def davit_huge(pretrained=False, **kwargs):
    model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048),
        num_heads=(8, 16, 32, 64), **kwargs)
    return _create_davit('davit_huge', pretrained=pretrained, **model_kwargs)


@register_model
def davit_giant(pretrained=False, **kwargs):
    model_kwargs = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072),
        num_heads=(12, 24, 48, 96), **kwargs)
    return _create_davit('davit_giant', pretrained=pretrained, **model_kwargs)
'''
Loading…
Cancel
Save