|
|
|
@ -12,6 +12,8 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
|
|
|
|
|
# Licensed under The MIT License [see LICENSE for details]
|
|
|
|
|
# Written by Christoph Reich
|
|
|
|
|
# --------------------------------------------------------
|
|
|
|
|
import logging
|
|
|
|
|
from copy import deepcopy
|
|
|
|
|
from typing import Tuple, Optional, List, Union, Any, Type
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
@ -19,7 +21,81 @@ import torch.nn as nn
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
import torch.utils.checkpoint as checkpoint
|
|
|
|
|
|
|
|
|
|
from .layers import DropPath, Mlp
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
|
# from .helpers import build_model_with_cfg, overlay_external_default_cfg
|
|
|
|
|
# from .vision_transformer import checkpoint_filter_fn
|
|
|
|
|
# from .registry import register_model
|
|
|
|
|
# from .layers import DropPath, Mlp
|
|
|
|
|
|
|
|
|
|
from timm.models.helpers import build_model_with_cfg, overlay_external_default_cfg
|
|
|
|
|
from timm.models.vision_transformer import checkpoint_filter_fn
|
|
|
|
|
from timm.models.registry import register_model
|
|
|
|
|
from timm.models.layers import DropPath, Mlp
|
|
|
|
|
|
|
|
|
|
_logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _cfg(url='', **kwargs):
|
|
|
|
|
return {
|
|
|
|
|
'url': url,
|
|
|
|
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
|
|
|
|
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
|
|
|
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
|
|
|
|
'first_conv': 'patch_embed.proj', 'classifier': 'head',
|
|
|
|
|
**kwargs
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
default_cfgs = {
|
|
|
|
|
# patch models (my experiments)
|
|
|
|
|
'swin_v2_cr_tiny_patch4_window12_384': _cfg(
|
|
|
|
|
url="",
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
|
|
|
|
|
|
'swin_v2_cr_tiny_patch4_window7_224': _cfg(
|
|
|
|
|
url="",
|
|
|
|
|
input_size=(3, 224, 224), crop_pct=1.0),
|
|
|
|
|
|
|
|
|
|
'swin_v2_cr_small_patch4_window12_384': _cfg(
|
|
|
|
|
url="",
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
|
|
|
|
|
|
'swin_v2_cr_small_patch4_window7_224': _cfg(
|
|
|
|
|
url="",
|
|
|
|
|
input_size=(3, 224, 224), crop_pct=1.0),
|
|
|
|
|
|
|
|
|
|
'swin_v2_cr_base_patch4_window12_384': _cfg(
|
|
|
|
|
url="",
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
|
|
|
|
|
|
'swin_v2_cr_base_patch4_window7_224': _cfg(
|
|
|
|
|
url="",
|
|
|
|
|
input_size=(3, 224, 224), crop_pct=1.0),
|
|
|
|
|
|
|
|
|
|
'swin_v2_cr_large_patch4_window12_384': _cfg(
|
|
|
|
|
url="",
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
|
|
|
|
|
|
'swin_v2_cr_large_patch4_window7_224': _cfg(
|
|
|
|
|
url="",
|
|
|
|
|
input_size=(3, 224, 224), crop_pct=1.0),
|
|
|
|
|
|
|
|
|
|
'swin_v2_cr_huge_patch4_window12_384': _cfg(
|
|
|
|
|
url="",
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
|
|
|
|
|
|
'swin_v2_cr_huge_patch4_window7_224': _cfg(
|
|
|
|
|
url="",
|
|
|
|
|
input_size=(3, 224, 224), crop_pct=1.0),
|
|
|
|
|
|
|
|
|
|
'swin_v2_cr_giant_patch4_window12_384': _cfg(
|
|
|
|
|
url="",
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
|
|
|
|
|
|
'swin_v2_cr_giant_patch4_window7_224': _cfg(
|
|
|
|
|
url="",
|
|
|
|
|
input_size=(3, 224, 224), crop_pct=1.0),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def bchw_to_bhwc(input: torch.Tensor) -> torch.Tensor:
|
|
|
|
@ -958,3 +1034,148 @@ class SwinTransformerV2CR(nn.Module):
|
|
|
|
|
# Predict classification
|
|
|
|
|
classification: torch.Tensor = self.head(output)
|
|
|
|
|
return classification
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_swin_transformer_v2_cr(variant, pretrained=False, default_cfg=None, **kwargs):
|
|
|
|
|
if default_cfg is None:
|
|
|
|
|
default_cfg = deepcopy(default_cfgs[variant])
|
|
|
|
|
overlay_external_default_cfg(default_cfg, kwargs)
|
|
|
|
|
default_num_classes = default_cfg['num_classes']
|
|
|
|
|
default_img_size = default_cfg['input_size'][-2:]
|
|
|
|
|
|
|
|
|
|
num_classes = kwargs.pop('num_classes', default_num_classes)
|
|
|
|
|
img_size = kwargs.pop('img_size', default_img_size)
|
|
|
|
|
if kwargs.get('features_only', None):
|
|
|
|
|
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
|
|
|
|
|
|
|
|
|
model = build_model_with_cfg(
|
|
|
|
|
SwinTransformerV2CR, variant, pretrained,
|
|
|
|
|
default_cfg=default_cfg,
|
|
|
|
|
img_size=img_size,
|
|
|
|
|
num_classes=num_classes,
|
|
|
|
|
pretrained_filter_fn=checkpoint_filter_fn,
|
|
|
|
|
**kwargs)
|
|
|
|
|
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def swin_v2_cr_tiny_patch4_window12_384(pretrained=False, **kwargs):
|
|
|
|
|
""" Swin-T V2 CR @ 384x384, trained ImageNet-1k
|
|
|
|
|
"""
|
|
|
|
|
model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=96, depths=(2, 2, 6, 2),
|
|
|
|
|
num_heads=(3, 6, 12, 24), **kwargs)
|
|
|
|
|
return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_patch4_window12_384', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def swin_v2_cr_tiny_patch4_window7_224(pretrained=False, **kwargs):
|
|
|
|
|
""" Swin-T V2 CR @ 224x224, trained ImageNet-1k
|
|
|
|
|
"""
|
|
|
|
|
model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 6, 2),
|
|
|
|
|
num_heads=(3, 6, 12, 24), **kwargs)
|
|
|
|
|
return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def swin_v2_cr_small_patch4_window12_384(pretrained=False, **kwargs):
|
|
|
|
|
""" Swin-S V2 CR @ 384x384, trained ImageNet-1k
|
|
|
|
|
"""
|
|
|
|
|
model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2),
|
|
|
|
|
num_heads=(3, 6, 12, 24), **kwargs)
|
|
|
|
|
return _create_swin_transformer_v2_cr('swin_v2_cr_small_patch4_window12_384', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def swin_v2_cr_small_patch4_window7_224(pretrained=False, **kwargs):
|
|
|
|
|
""" Swin-S V2 CR @ 224x224, trained ImageNet-1k
|
|
|
|
|
"""
|
|
|
|
|
model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2),
|
|
|
|
|
num_heads=(3, 6, 12, 24), **kwargs)
|
|
|
|
|
return _create_swin_transformer_v2_cr('swin_v2_cr_small_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def swin_v2_cr_base_patch4_window12_384(pretrained=False, **kwargs):
|
|
|
|
|
""" Swin-B V2 CR @ 384x384, trained ImageNet-1k
|
|
|
|
|
"""
|
|
|
|
|
model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2),
|
|
|
|
|
num_heads=(4, 8, 16, 32), **kwargs)
|
|
|
|
|
return _create_swin_transformer_v2_cr('swin_v2_cr_base_patch4_window12_384', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def swin_v2_cr_base_patch4_window7_224(pretrained=False, **kwargs):
|
|
|
|
|
""" Swin-B V2 CR @ 224x224, trained ImageNet-1k
|
|
|
|
|
"""
|
|
|
|
|
model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2),
|
|
|
|
|
num_heads=(4, 8, 16, 32), **kwargs)
|
|
|
|
|
return _create_swin_transformer_v2_cr('swin_v2_cr_base_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def swin_v2_cr_large_patch4_window12_384(pretrained=False, **kwargs):
|
|
|
|
|
""" Swin-L V2 CR @ 384x384, trained ImageNet-1k
|
|
|
|
|
"""
|
|
|
|
|
model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2),
|
|
|
|
|
num_heads=(6, 12, 24, 48), **kwargs)
|
|
|
|
|
return _create_swin_transformer_v2_cr('swin_v2_cr_large_patch4_window12_384', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def swin_v2_cr_large_patch4_window7_224(pretrained=False, **kwargs):
|
|
|
|
|
""" Swin-L V2 CR @ 224x224, trained ImageNet-1k
|
|
|
|
|
"""
|
|
|
|
|
model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2),
|
|
|
|
|
num_heads=(6, 12, 24, 48), **kwargs)
|
|
|
|
|
return _create_swin_transformer_v2_cr('swin_v2_cr_large_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def swin_v2_cr_huge_patch4_window12_384(pretrained=False, **kwargs):
|
|
|
|
|
""" Swin-H V2 CR @ 384x384, trained ImageNet-1k
|
|
|
|
|
"""
|
|
|
|
|
model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=352, depths=(2, 2, 18, 2),
|
|
|
|
|
num_heads=(6, 12, 24, 48), **kwargs)
|
|
|
|
|
return _create_swin_transformer_v2_cr('swin_v2_cr_huge_patch4_window12_384', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def swin_v2_cr_huge_patch4_window7_224(pretrained=False, **kwargs):
|
|
|
|
|
""" Swin-H V2 CR @ 224x224, trained ImageNet-1k
|
|
|
|
|
"""
|
|
|
|
|
model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=352, depths=(2, 2, 18, 2),
|
|
|
|
|
num_heads=(11, 22, 44, 88), **kwargs)
|
|
|
|
|
return _create_swin_transformer_v2_cr('swin_v2_cr_huge_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def swin_v2_cr_giant_patch4_window12_384(pretrained=False, **kwargs):
|
|
|
|
|
""" Swin-G V2 CR @ 384x384, trained ImageNet-1k
|
|
|
|
|
"""
|
|
|
|
|
model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=512, depths=(2, 2, 18, 2),
|
|
|
|
|
num_heads=(16, 32, 64, 128), **kwargs)
|
|
|
|
|
return _create_swin_transformer_v2_cr('swin_v2_cr_giant_patch4_window12_384', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def swin_v2_cr_giant_patch4_window7_224(pretrained=False, **kwargs):
|
|
|
|
|
""" Swin-G V2 CR @ 224x224, trained ImageNet-1k
|
|
|
|
|
"""
|
|
|
|
|
model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=512, depths=(2, 2, 42, 2),
|
|
|
|
|
num_heads=(16, 32, 64, 128), **kwargs)
|
|
|
|
|
return _create_swin_transformer_v2_cr('swin_v2_cr_giant_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
model = swin_v2_cr_tiny_patch4_window12_384(pretrained=False)
|
|
|
|
|
model = swin_v2_cr_tiny_patch4_window7_224(pretrained=False)
|
|
|
|
|
|
|
|
|
|
model = swin_v2_cr_small_patch4_window12_384(pretrained=False)
|
|
|
|
|
model = swin_v2_cr_small_patch4_window7_224(pretrained=False)
|
|
|
|
|
|
|
|
|
|
model = swin_v2_cr_base_patch4_window12_384(pretrained=False)
|
|
|
|
|
model = swin_v2_cr_base_patch4_window7_224(pretrained=False)
|
|
|
|
|
|
|
|
|
|
model = swin_v2_cr_large_patch4_window12_384(pretrained=False)
|
|
|
|
|
model = swin_v2_cr_large_patch4_window7_224(pretrained=False)
|
|
|
|
|