From 2a4f6c13dd75df4a2439e9ba26dc60b38926a62b Mon Sep 17 00:00:00 2001
From: Christoph Reich <34400551+ChristophReich1996@users.noreply.github.com>
Date: Sun, 20 Feb 2022 00:40:22 +0100
Subject: [PATCH] Create model functions

---
 timm/models/swin_transformer_v2_cr.py | 220 +++++++++++++++++++++++++-
 1 file changed, 219 insertions(+), 1 deletion(-)

diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py
index 7adf1ec0..9aad19c0 100644
--- a/timm/models/swin_transformer_v2_cr.py
+++ b/timm/models/swin_transformer_v2_cr.py
@@ -12,6 +12,8 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
 # Licensed under The MIT License [see LICENSE for details]
 # Written by Christoph Reich
 # --------------------------------------------------------
+import logging
+from copy import deepcopy
 from typing import Tuple, Optional, List, Union, Any, Type
 
 import torch
@@ -19,7 +21,78 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
 
-from .layers import DropPath, Mlp
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+
+from timm.models.helpers import build_model_with_cfg, overlay_external_default_cfg
+from timm.models.vision_transformer import checkpoint_filter_fn
+from timm.models.registry import register_model
+from timm.models.layers import DropPath, Mlp
+
+_logger = logging.getLogger(__name__)
+
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'patch_embed.proj', 'classifier': 'head',
+        **kwargs
+    }
+
+
+default_cfgs = {
+    # patch models (my experiments)
+    'swin_v2_cr_tiny_patch4_window12_384': _cfg(
+        url="",
+        input_size=(3, 384, 384), crop_pct=1.0),
+
+    'swin_v2_cr_tiny_patch4_window7_224': _cfg(
+        url="",
+        input_size=(3, 224, 224), crop_pct=1.0),
+
+    'swin_v2_cr_small_patch4_window12_384': _cfg(
+        url="",
+        input_size=(3, 384, 384), crop_pct=1.0),
+
+    'swin_v2_cr_small_patch4_window7_224': _cfg(
+        url="",
+        input_size=(3, 224, 224), crop_pct=1.0),
+
+    'swin_v2_cr_base_patch4_window12_384': _cfg(
+        url="",
+        input_size=(3, 384, 384), crop_pct=1.0),
+
+    'swin_v2_cr_base_patch4_window7_224': _cfg(
+        url="",
+        input_size=(3, 224, 224), crop_pct=1.0),
+
+    'swin_v2_cr_large_patch4_window12_384': _cfg(
+        url="",
+        input_size=(3, 384, 384), crop_pct=1.0),
+
+    'swin_v2_cr_large_patch4_window7_224': _cfg(
+        url="",
+        input_size=(3, 224, 224), crop_pct=1.0),
+
+    'swin_v2_cr_huge_patch4_window12_384': _cfg(
+        url="",
+        input_size=(3, 384, 384), crop_pct=1.0),
+
+    'swin_v2_cr_huge_patch4_window7_224': _cfg(
+        url="",
+        input_size=(3, 224, 224), crop_pct=1.0),
+
+    'swin_v2_cr_giant_patch4_window12_384': _cfg(
+        url="",
+        input_size=(3, 384, 384), crop_pct=1.0),
+
+    'swin_v2_cr_giant_patch4_window7_224': _cfg(
+        url="",
+        input_size=(3, 224, 224), crop_pct=1.0),
+}
 
 
 def bchw_to_bhwc(input: torch.Tensor) -> torch.Tensor:
@@ -958,3 +1031,148 @@ class SwinTransformerV2CR(nn.Module):
         # Predict classification
         classification: torch.Tensor = self.head(output)
         return classification
+
+
+def _create_swin_transformer_v2_cr(variant, pretrained=False, default_cfg=None, **kwargs):
+    if default_cfg is None:
+        default_cfg = deepcopy(default_cfgs[variant])
+    overlay_external_default_cfg(default_cfg, kwargs)
+    default_num_classes = default_cfg['num_classes']
+    default_img_size = default_cfg['input_size'][-2:]
+
+    num_classes = kwargs.pop('num_classes', default_num_classes)
+    img_size = kwargs.pop('img_size', default_img_size)
+    if kwargs.get('features_only', None):
+        raise RuntimeError('features_only not implemented for Swin Transformer V2 (CR) models.')
+
+    model = build_model_with_cfg(
+        SwinTransformerV2CR, variant, pretrained,
+        default_cfg=default_cfg,
+        img_size=img_size,
+        num_classes=num_classes,
+        pretrained_filter_fn=checkpoint_filter_fn,
+        **kwargs)
+
+    return model
+
+
+@register_model
+def swin_v2_cr_tiny_patch4_window12_384(pretrained=False, **kwargs):
+    """ Swin-T V2 CR @ 384x384, trained ImageNet-1k
+    """
+    model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=96, depths=(2, 2, 6, 2),
+                        num_heads=(3, 6, 12, 24), **kwargs)
+    return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_patch4_window12_384', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def swin_v2_cr_tiny_patch4_window7_224(pretrained=False, **kwargs):
+    """ Swin-T V2 CR @ 224x224, trained ImageNet-1k
+    """
+    model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 6, 2),
+                        num_heads=(3, 6, 12, 24), **kwargs)
+    return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_patch4_window7_224', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def swin_v2_cr_small_patch4_window12_384(pretrained=False, **kwargs):
+    """ Swin-S V2 CR @ 384x384, trained ImageNet-1k
+    """
+    model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=96, depths=(2, 2, 18, 2),
+                        num_heads=(3, 6, 12, 24), **kwargs)
+    return _create_swin_transformer_v2_cr('swin_v2_cr_small_patch4_window12_384', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def swin_v2_cr_small_patch4_window7_224(pretrained=False, **kwargs):
+    """ Swin-S V2 CR @ 224x224, trained ImageNet-1k
+    """
+    model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2),
+                        num_heads=(3, 6, 12, 24), **kwargs)
+    return _create_swin_transformer_v2_cr('swin_v2_cr_small_patch4_window7_224', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def swin_v2_cr_base_patch4_window12_384(pretrained=False, **kwargs):
+    """ Swin-B V2 CR @ 384x384, trained ImageNet-1k
+    """
+    model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2),
+                        num_heads=(4, 8, 16, 32), **kwargs)
+    return _create_swin_transformer_v2_cr('swin_v2_cr_base_patch4_window12_384', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def swin_v2_cr_base_patch4_window7_224(pretrained=False, **kwargs):
+    """ Swin-B V2 CR @ 224x224, trained ImageNet-1k
+    """
+    model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2),
+                        num_heads=(4, 8, 16, 32), **kwargs)
+    return _create_swin_transformer_v2_cr('swin_v2_cr_base_patch4_window7_224', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def swin_v2_cr_large_patch4_window12_384(pretrained=False, **kwargs):
+    """ Swin-L V2 CR @ 384x384, trained ImageNet-1k
+    """
+    model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2),
+                        num_heads=(6, 12, 24, 48), **kwargs)
+    return _create_swin_transformer_v2_cr('swin_v2_cr_large_patch4_window12_384', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def swin_v2_cr_large_patch4_window7_224(pretrained=False, **kwargs):
+    """ Swin-L V2 CR @ 224x224, trained ImageNet-1k
+    """
+    model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2),
+                        num_heads=(6, 12, 24, 48), **kwargs)
+    return _create_swin_transformer_v2_cr('swin_v2_cr_large_patch4_window7_224', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def swin_v2_cr_huge_patch4_window12_384(pretrained=False, **kwargs):
+    """ Swin-H V2 CR @ 384x384, trained ImageNet-1k
+    """
+    model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=352, depths=(2, 2, 18, 2),
+                        num_heads=(11, 22, 44, 88), **kwargs)
+    return _create_swin_transformer_v2_cr('swin_v2_cr_huge_patch4_window12_384', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def swin_v2_cr_huge_patch4_window7_224(pretrained=False, **kwargs):
+    """ Swin-H V2 CR @ 224x224, trained ImageNet-1k
+    """
+    model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=352, depths=(2, 2, 18, 2),
+                        num_heads=(11, 22, 44, 88), **kwargs)
+    return _create_swin_transformer_v2_cr('swin_v2_cr_huge_patch4_window7_224', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def swin_v2_cr_giant_patch4_window12_384(pretrained=False, **kwargs):
+    """ Swin-G V2 CR @ 384x384, trained ImageNet-1k
+    """
+    model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=512, depths=(2, 2, 42, 2),
+                        num_heads=(16, 32, 64, 128), **kwargs)
+    return _create_swin_transformer_v2_cr('swin_v2_cr_giant_patch4_window12_384', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def swin_v2_cr_giant_patch4_window7_224(pretrained=False, **kwargs):
+    """ Swin-G V2 CR @ 224x224, trained ImageNet-1k
+    """
+    model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=512, depths=(2, 2, 42, 2),
+                        num_heads=(16, 32, 64, 128), **kwargs)
+    return _create_swin_transformer_v2_cr('swin_v2_cr_giant_patch4_window7_224', pretrained=pretrained, **model_kwargs)
+
+
+if __name__ == '__main__':
+    model = swin_v2_cr_tiny_patch4_window12_384(pretrained=False)
+    model = swin_v2_cr_tiny_patch4_window7_224(pretrained=False)
+
+    model = swin_v2_cr_small_patch4_window12_384(pretrained=False)
+    model = swin_v2_cr_small_patch4_window7_224(pretrained=False)
+
+    model = swin_v2_cr_base_patch4_window12_384(pretrained=False)
+    model = swin_v2_cr_base_patch4_window7_224(pretrained=False)
+
+    model = swin_v2_cr_large_patch4_window12_384(pretrained=False)
+    model = swin_v2_cr_large_patch4_window7_224(pretrained=False)
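
Usage note: once this patch is applied to a local timm checkout (e.g. installed with `pip install -e .`), the `@register_model` decorator makes every variant above resolvable through timm's standard factory. A minimal smoke test might look like the sketch below; `pretrained=False` is required here because the `default_cfgs` entries carry empty weight URLs.

    import timm
    import torch

    # Create one of the newly registered variants by name.
    model = timm.create_model('swin_v2_cr_tiny_patch4_window7_224', pretrained=False)
    model.eval()

    # One forward pass at the configured 224x224 input resolution.
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # torch.Size([1, 1000])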