Add hybrid model fwds back

cleanup_xla_model_fixes
Ross Wightman 3 years ago
parent 26f04a8e3e
commit 381b279785

@@ -173,7 +173,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
     state_dict = model.state_dict()
     cfg = model.default_cfg
-    input_size = _get_input_size(model_name=model_name, target=TARGET_FWD_SIZE)
+    input_size = _get_input_size(model=model)
     if max(input_size) > 320:  # FIXME const
         pytest.skip("Fixed input size model > limit.")

@@ -236,6 +236,12 @@ def vit_base_r50_s16_384(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def vit_base_resnet50_384(pretrained=False, **kwargs):
+    # DEPRECATED this is forwarding to model def above for backwards compatibility
+    return vit_base_r50_s16_384(pretrained=pretrained, **kwargs)
+
+
 @register_model
 def vit_large_r50_s32_224(pretrained=False, **kwargs):
     """ R50+ViT-L/S32 hybrid.
@@ -292,6 +298,12 @@ def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
+    # DEPRECATED this is forwarding to model def above for backwards compatibility
+    return vit_base_r50_s16_224_in21k(pretrained=pretrained, **kwargs)
+
+
 @register_model
 def vit_large_r50_s32_224_in21k(pretrained=False, **kwargs):
     """ R50+ViT-L/S32 hybrid. ImageNet-21k.
