From 381b2797858248619fe8007fa1c5f5a5d4ab3919 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sat, 19 Jun 2021 22:28:44 -0700
Subject: [PATCH] Add hybrid model fwds back

---
 tests/test_models.py                     |  2 +-
 timm/models/vision_transformer_hybrid.py | 12 ++++++++++++
 2 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index 52a8023a..0a770784 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -173,7 +173,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
     state_dict = model.state_dict()
     cfg = model.default_cfg
 
-    input_size = _get_input_size(model_name=model_name, target=TARGET_FWD_SIZE)
+    input_size = _get_input_size(model=model)
     if max(input_size) > 320:  # FIXME const
         pytest.skip("Fixed input size model > limit.")
 
diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py
index 30330d39..5d725c58 100644
--- a/timm/models/vision_transformer_hybrid.py
+++ b/timm/models/vision_transformer_hybrid.py
@@ -236,6 +236,12 @@ def vit_base_r50_s16_384(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def vit_base_resnet50_384(pretrained=False, **kwargs):
+    # DEPRECATED this is forwarding to model def above for backwards compatibility
+    return vit_base_r50_s16_384(pretrained=pretrained, **kwargs)
+
+
 @register_model
 def vit_large_r50_s32_224(pretrained=False, **kwargs):
     """ R50+ViT-L/S32 hybrid.
@@ -292,6 +298,12 @@ def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
+    # DEPRECATED this is forwarding to model def above for backwards compatibility
+    return vit_base_r50_s16_224_in21k(pretrained=pretrained, **kwargs)
+
+
 @register_model
 def vit_large_r50_s32_224_in21k(pretrained=False, **kwargs):
     """ R50+ViT-L/S32 hybrid. ImageNet-21k.
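
Note (not part of the patch): the two @register_model defs added above restore the pre-rename hybrid model names, so scripts that still request vit_base_resnet50_384 or vit_base_resnet50_224_in21k keep working while building the renamed r50_s16 variants. A minimal usage sketch, assuming a timm checkout that includes this change is installed:

import timm

# Deprecated alias forwards to the renamed model def, so both names should
# build the same architecture (assumption: timm with this patch applied).
old = timm.create_model('vit_base_resnet50_384', pretrained=False)
new = timm.create_model('vit_base_r50_s16_384', pretrained=False)
assert type(old) is type(new)
print(old.default_cfg['input_size'])  # expected (3, 384, 384) for the 384px variant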