From bb85b09d2a32b1e5b92790dc9e160081744ba65c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 28 Feb 2022 16:39:16 -0800 Subject: [PATCH] swin v2 fixup for latest changes on norm_norm_norm / bits_and_tpu branch --- timm/models/swin_transformer_v2_cr.py | 25 +++---------------------- 1 file changed, 3 insertions(+), 22 deletions(-) diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index b2915bf8..bb77466f 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -37,7 +37,7 @@ import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg, overlay_external_default_cfg, named_apply +from .helpers import build_model_with_cfg, named_apply from .layers import DropPath, Mlp, to_2tuple, _assert from .registry import register_model from .vision_transformer import checkpoint_filter_fn @@ -754,29 +754,10 @@ def init_weights(module: nn.Module, name: str = ''): nn.init.zeros_(module.bias) -def _create_swin_transformer_v2_cr(variant, pretrained=False, default_cfg=None, **kwargs): - if default_cfg is None: - default_cfg = deepcopy(default_cfgs[variant]) - overlay_external_default_cfg(default_cfg, kwargs) - default_num_classes = default_cfg['num_classes'] - default_img_size = default_cfg['input_size'][-2:] - - num_classes = kwargs.pop('num_classes', default_num_classes) - img_size = kwargs.pop('img_size', default_img_size) +def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') - - model = build_model_with_cfg( - SwinTransformerV2Cr, - variant, - pretrained, - default_cfg=default_cfg, - img_size=img_size, - num_classes=num_classes, - pretrained_filter_fn=checkpoint_filter_fn, - **kwargs - ) - + model = build_model_with_cfg(SwinTransformerV2Cr, variant, pretrained, **kwargs) return model