|
|
|
@ -37,7 +37,7 @@ import torch.utils.checkpoint as checkpoint
|
|
|
|
|
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
|
from .fx_features import register_notrace_function
|
|
|
|
|
from .helpers import build_model_with_cfg, overlay_external_default_cfg, named_apply
|
|
|
|
|
from .helpers import build_model_with_cfg, named_apply
|
|
|
|
|
from .layers import DropPath, Mlp, to_2tuple, _assert
|
|
|
|
|
from .registry import register_model
|
|
|
|
|
from .vision_transformer import checkpoint_filter_fn
|
|
|
|
@ -754,29 +754,10 @@ def init_weights(module: nn.Module, name: str = ''):
|
|
|
|
|
nn.init.zeros_(module.bias)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_swin_transformer_v2_cr(variant, pretrained=False, default_cfg=None, **kwargs):
|
|
|
|
|
if default_cfg is None:
|
|
|
|
|
default_cfg = deepcopy(default_cfgs[variant])
|
|
|
|
|
overlay_external_default_cfg(default_cfg, kwargs)
|
|
|
|
|
default_num_classes = default_cfg['num_classes']
|
|
|
|
|
default_img_size = default_cfg['input_size'][-2:]
|
|
|
|
|
|
|
|
|
|
num_classes = kwargs.pop('num_classes', default_num_classes)
|
|
|
|
|
img_size = kwargs.pop('img_size', default_img_size)
|
|
|
|
|
def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs):
|
|
|
|
|
if kwargs.get('features_only', None):
|
|
|
|
|
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
|
|
|
|
|
|
|
|
|
model = build_model_with_cfg(
|
|
|
|
|
SwinTransformerV2Cr,
|
|
|
|
|
variant,
|
|
|
|
|
pretrained,
|
|
|
|
|
default_cfg=default_cfg,
|
|
|
|
|
img_size=img_size,
|
|
|
|
|
num_classes=num_classes,
|
|
|
|
|
pretrained_filter_fn=checkpoint_filter_fn,
|
|
|
|
|
**kwargs
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
model = build_model_with_cfg(SwinTransformerV2Cr, variant, pretrained, **kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|