Add SAM pretrained model defs/weights for ViT B16 and B32 models.

pull/746/head
Ross Wightman 4 years ago
parent ee4d8fc69a
commit 6d8272e92c

@ -23,6 +23,9 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
## What's New
### July 8, 2021
* Add [SAM pretrained](https://arxiv.org/abs/2106.01548) in1k weight for ViT B/16 (`vit_base_patch16_sam_224`) and B/32 (`vit_base_patch32_sam_224`) models.
### July 5, 2021
* Add 'Aggregating Nested Transformer' (NesT) w/ weights converted from official [Flax impl](https://github.com/google-research/nested-transformer). Contributed by [Alexander Soare](https://github.com/alexander-soare).

@ -129,6 +129,12 @@ default_cfgs = {
hf_hub='timm/vit_huge_patch14_224_in21k',
num_classes=21843),
# SAM trained models (https://arxiv.org/abs/2106.01548)
'vit_base_patch32_sam_224': _cfg(
url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'),
'vit_base_patch16_sam_224': _cfg(
url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'),
# deit models (FB weights)
'deit_tiny_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth',
@ -596,7 +602,8 @@ def vit_small_patch16_384(pretrained=False, **kwargs):
@register_model
def vit_base_patch32_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
""" ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs)
@ -672,6 +679,24 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
return model
@register_model
def vit_base_patch16_sam_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
model = _create_vision_transformer('vit_base_patch16_sam_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch32_sam_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
"""
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
model = _create_vision_transformer('vit_base_patch32_sam_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
""" ViT-Tiny (Vit-Ti/16).

Loading…
Cancel
Save