Create model functions

pull/1150/head
Christoph Reich 3 years ago
parent 87b4d7a29a
commit 2a4f6c13dd

@ -12,6 +12,8 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# Written by Christoph Reich # Written by Christoph Reich
# -------------------------------------------------------- # --------------------------------------------------------
import logging
from copy import deepcopy
from typing import Tuple, Optional, List, Union, Any, Type from typing import Tuple, Optional, List, Union, Any, Type
import torch import torch
@ -19,7 +21,81 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint import torch.utils.checkpoint as checkpoint
from .layers import DropPath, Mlp from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
# from .helpers import build_model_with_cfg, overlay_external_default_cfg
# from .vision_transformer import checkpoint_filter_fn
# from .registry import register_model
# from .layers import DropPath, Mlp
from timm.models.helpers import build_model_with_cfg, overlay_external_default_cfg
from timm.models.vision_transformer import checkpoint_filter_fn
from timm.models.registry import register_model
from timm.models.layers import DropPath, Mlp
_logger = logging.getLogger(__name__)
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = {
# patch models (my experiments)
'swin_v2_cr_tiny_patch4_window12_384': _cfg(
url="",
input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_tiny_patch4_window7_224': _cfg(
url="",
input_size=(3, 224, 224), crop_pct=1.0),
'swin_v2_cr_small_patch4_window12_384': _cfg(
url="",
input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_small_patch4_window7_224': _cfg(
url="",
input_size=(3, 224, 224), crop_pct=1.0),
'swin_v2_cr_base_patch4_window12_384': _cfg(
url="",
input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_base_patch4_window7_224': _cfg(
url="",
input_size=(3, 224, 224), crop_pct=1.0),
'swin_v2_cr_large_patch4_window12_384': _cfg(
url="",
input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_large_patch4_window7_224': _cfg(
url="",
input_size=(3, 224, 224), crop_pct=1.0),
'swin_v2_cr_huge_patch4_window12_384': _cfg(
url="",
input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_huge_patch4_window7_224': _cfg(
url="",
input_size=(3, 224, 224), crop_pct=1.0),
'swin_v2_cr_giant_patch4_window12_384': _cfg(
url="",
input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_giant_patch4_window7_224': _cfg(
url="",
input_size=(3, 224, 224), crop_pct=1.0),
}
def bchw_to_bhwc(input: torch.Tensor) -> torch.Tensor: def bchw_to_bhwc(input: torch.Tensor) -> torch.Tensor:
@ -958,3 +1034,148 @@ class SwinTransformerV2CR(nn.Module):
# Predict classification # Predict classification
classification: torch.Tensor = self.head(output) classification: torch.Tensor = self.head(output)
return classification return classification
def _create_swin_transformer_v2_cr(variant, pretrained=False, default_cfg=None, **kwargs):
if default_cfg is None:
default_cfg = deepcopy(default_cfgs[variant])
overlay_external_default_cfg(default_cfg, kwargs)
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-2:]
num_classes = kwargs.pop('num_classes', default_num_classes)
img_size = kwargs.pop('img_size', default_img_size)
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
SwinTransformerV2CR, variant, pretrained,
default_cfg=default_cfg,
img_size=img_size,
num_classes=num_classes,
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
return model
@register_model
def swin_v2_cr_tiny_patch4_window12_384(pretrained=False, **kwargs):
""" Swin-T V2 CR @ 384x384, trained ImageNet-1k
"""
model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=96, depths=(2, 2, 6, 2),
num_heads=(3, 6, 12, 24), **kwargs)
return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_patch4_window12_384', pretrained=pretrained, **model_kwargs)
@register_model
def swin_v2_cr_tiny_patch4_window7_224(pretrained=False, **kwargs):
""" Swin-T V2 CR @ 224x224, trained ImageNet-1k
"""
model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 6, 2),
num_heads=(3, 6, 12, 24), **kwargs)
return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_patch4_window7_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_v2_cr_small_patch4_window12_384(pretrained=False, **kwargs):
""" Swin-S V2 CR @ 384x384, trained ImageNet-1k
"""
model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2),
num_heads=(3, 6, 12, 24), **kwargs)
return _create_swin_transformer_v2_cr('swin_v2_cr_small_patch4_window12_384', pretrained=pretrained, **model_kwargs)
@register_model
def swin_v2_cr_small_patch4_window7_224(pretrained=False, **kwargs):
""" Swin-S V2 CR @ 224x224, trained ImageNet-1k
"""
model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2),
num_heads=(3, 6, 12, 24), **kwargs)
return _create_swin_transformer_v2_cr('swin_v2_cr_small_patch4_window7_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_v2_cr_base_patch4_window12_384(pretrained=False, **kwargs):
""" Swin-B V2 CR @ 384x384, trained ImageNet-1k
"""
model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2),
num_heads=(4, 8, 16, 32), **kwargs)
return _create_swin_transformer_v2_cr('swin_v2_cr_base_patch4_window12_384', pretrained=pretrained, **model_kwargs)
@register_model
def swin_v2_cr_base_patch4_window7_224(pretrained=False, **kwargs):
""" Swin-B V2 CR @ 224x224, trained ImageNet-1k
"""
model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2),
num_heads=(4, 8, 16, 32), **kwargs)
return _create_swin_transformer_v2_cr('swin_v2_cr_base_patch4_window7_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_v2_cr_large_patch4_window12_384(pretrained=False, **kwargs):
""" Swin-L V2 CR @ 384x384, trained ImageNet-1k
"""
model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2),
num_heads=(6, 12, 24, 48), **kwargs)
return _create_swin_transformer_v2_cr('swin_v2_cr_large_patch4_window12_384', pretrained=pretrained, **model_kwargs)
@register_model
def swin_v2_cr_large_patch4_window7_224(pretrained=False, **kwargs):
""" Swin-L V2 CR @ 224x224, trained ImageNet-1k
"""
model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2),
num_heads=(6, 12, 24, 48), **kwargs)
return _create_swin_transformer_v2_cr('swin_v2_cr_large_patch4_window7_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_v2_cr_huge_patch4_window12_384(pretrained=False, **kwargs):
""" Swin-H V2 CR @ 384x384, trained ImageNet-1k
"""
model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=352, depths=(2, 2, 18, 2),
num_heads=(6, 12, 24, 48), **kwargs)
return _create_swin_transformer_v2_cr('swin_v2_cr_huge_patch4_window12_384', pretrained=pretrained, **model_kwargs)
@register_model
def swin_v2_cr_huge_patch4_window7_224(pretrained=False, **kwargs):
""" Swin-H V2 CR @ 224x224, trained ImageNet-1k
"""
model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=352, depths=(2, 2, 18, 2),
num_heads=(11, 22, 44, 88), **kwargs)
return _create_swin_transformer_v2_cr('swin_v2_cr_huge_patch4_window7_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_v2_cr_giant_patch4_window12_384(pretrained=False, **kwargs):
""" Swin-G V2 CR @ 384x384, trained ImageNet-1k
"""
model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=512, depths=(2, 2, 18, 2),
num_heads=(16, 32, 64, 128), **kwargs)
return _create_swin_transformer_v2_cr('swin_v2_cr_giant_patch4_window12_384', pretrained=pretrained, **model_kwargs)
@register_model
def swin_v2_cr_giant_patch4_window7_224(pretrained=False, **kwargs):
""" Swin-G V2 CR @ 224x224, trained ImageNet-1k
"""
model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=512, depths=(2, 2, 42, 2),
num_heads=(16, 32, 64, 128), **kwargs)
return _create_swin_transformer_v2_cr('swin_v2_cr_giant_patch4_window7_224', pretrained=pretrained, **model_kwargs)
if __name__ == '__main__':
model = swin_v2_cr_tiny_patch4_window12_384(pretrained=False)
model = swin_v2_cr_tiny_patch4_window7_224(pretrained=False)
model = swin_v2_cr_small_patch4_window12_384(pretrained=False)
model = swin_v2_cr_small_patch4_window7_224(pretrained=False)
model = swin_v2_cr_base_patch4_window12_384(pretrained=False)
model = swin_v2_cr_base_patch4_window7_224(pretrained=False)
model = swin_v2_cr_large_patch4_window12_384(pretrained=False)
model = swin_v2_cr_large_patch4_window7_224(pretrained=False)

Loading…
Cancel
Save