swin v2 fixup for latest changes on norm_norm_norm / bits_and_tpu branch

pull/1239/head
Ross Wightman 2 years ago
parent 10fa42b143
commit bb85b09d2a

@ -37,7 +37,7 @@ import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg, overlay_external_default_cfg, named_apply
from .helpers import build_model_with_cfg, named_apply
from .layers import DropPath, Mlp, to_2tuple, _assert
from .registry import register_model
from .vision_transformer import checkpoint_filter_fn
@ -754,29 +754,10 @@ def init_weights(module: nn.Module, name: str = ''):
nn.init.zeros_(module.bias)
def _create_swin_transformer_v2_cr(variant, pretrained=False, default_cfg=None, **kwargs):
if default_cfg is None:
default_cfg = deepcopy(default_cfgs[variant])
overlay_external_default_cfg(default_cfg, kwargs)
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-2:]
num_classes = kwargs.pop('num_classes', default_num_classes)
img_size = kwargs.pop('img_size', default_img_size)
def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
SwinTransformerV2Cr,
variant,
pretrained,
default_cfg=default_cfg,
img_size=img_size,
num_classes=num_classes,
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs
)
model = build_model_with_cfg(SwinTransformerV2Cr, variant, pretrained, **kwargs)
return model

Loading…
Cancel
Save